# 📦 ML Image Entity Extraction using OCR (EasyOCR)

In [None]:
import os
import pandas as pd
import numpy as np
import re
from tqdm import tqdm
from PIL import Image
import easyocr

from sklearn.model_selection import train_test_split

from utils import download_images, parse_string
from constants import entity_unit_map


In [None]:
# Training CSV location. Override with the TRAIN_CSV environment variable
# instead of editing the notebook; the default is the original author's path.
TRAIN_CSV = os.environ.get(
    "TRAIN_CSV",
    "/home/ds_yashraj/Projects/Image_Feature_Extraction/dataset/train.csv",
)
SAMPLE_SIZE = 5000
RANDOM_STATE = 42

train_df = pd.read_csv(TRAIN_CSV)
# DataFrame.sample raises ValueError when n exceeds the number of rows,
# so cap the requested sample at the dataset size.
train_df = train_df.sample(
    n=min(SAMPLE_SIZE, len(train_df)), random_state=RANDOM_STATE
).reset_index(drop=True)
# 80/20 split of the sampled rows; note val_df is also drawn from train.csv.
train_df, val_df = train_test_split(train_df, test_size=0.2, random_state=RANDOM_STATE)


In [None]:
# Local image directories per split. This cell previously referenced
# `image_name` before it existed (NameError) and immediately overwrote the
# train path with the test path; a reusable helper replaces it.
TRAIN_IMAGE_DIR = "images/train"
VAL_IMAGE_DIR = "images/val"
TEST_IMAGE_DIR = "images/test"


def image_path_for(image_link, image_dir):
    """Return the local path of the image named by *image_link* inside *image_dir*.

    Only the final path component of *image_link* (e.g. the file name at the
    end of a URL) is kept, mirroring how the validation loop builds paths.
    """
    return os.path.join(image_dir, os.path.basename(image_link))



In [None]:
# One shared EasyOCR reader (English, GPU requested); ocr_predict reuses it
# for every image rather than constructing a new reader per call.
reader = easyocr.Reader(lang_list=["en"], gpu=True)


In [None]:
def ocr_predict(image_path, entity, ocr_reader=None, unit_map=None):
    """Predict ``"<value> <unit>"`` for *entity* from the text OCR'd out of an image.

    The OCR'd text fragments are joined with spaces and lower-cased. The first
    ``<number><optional spaces><unit>`` occurrence whose unit is allowed for
    *entity* is returned, with the number formatted to two decimals.

    Args:
        image_path: Path of the image to read.
        entity: Entity name used to look up the allowed units.
        ocr_reader: Object with an EasyOCR-style ``readtext(path, detail=0)``
            method returning a list of strings; defaults to the module-level
            ``reader``.
        unit_map: Mapping of entity name -> iterable of allowed unit strings;
            defaults to ``entity_unit_map``.

    Returns:
        A string such as ``"12.50 gram"``. Returns ``""`` when the entity has
        no allowed units, when no allowed unit follows a number, or when OCR
        raises.
    """
    if ocr_reader is None:
        ocr_reader = reader
    if unit_map is None:
        unit_map = entity_unit_map

    # With no allowed units the alternation would be empty and match every
    # bare number, producing a unit-less prediction like "12.00 ".
    allowed_units = [u for u in unit_map.get(entity, ()) if u]
    if not allowed_units:
        return ""
    # Regex alternation is leftmost-first, so try longer units first; a unit
    # that is a prefix of another can then never shadow it. Sorting also
    # makes the pattern independent of set iteration order.
    allowed_units.sort(key=lambda u: (-len(u), u))
    pattern = r"(\d+(?:\.\d+)?)\s*(" + "|".join(map(re.escape, allowed_units)) + ")"

    try:
        fragments = ocr_reader.readtext(image_path, detail=0)
    except Exception as exc:
        # Best-effort: one unreadable image must not abort the whole run,
        # but the failure is reported instead of being silently swallowed.
        import logging
        logging.getLogger(__name__).warning("OCR failed for %s: %s", image_path, exc)
        return ""

    text = " ".join(fragments).lower()
    match = re.search(pattern, text)
    if match is None:
        return ""
    value, unit = match.groups()
    return f"{float(value):.2f} {unit}"


In [None]:
# Matches range labels such as "[10.0, 20.0] centimetre": two numbers and a unit.
RANGE_LABEL_RE = re.compile(r"\[(\d+(?:\.\d+)?),\s*(\d+(?:\.\d+)?)\]\s+([a-zA-Z\s]+)")

val_preds = []  # OCR prediction per validation row ("" when nothing found)
val_true = []   # normalised ground truth per row ("" when unparseable)

for _, row in tqdm(val_df.iterrows(), total=len(val_df)):
    raw_val = row['entity_value']
    # Collapse a "[low, high] unit" range label to its midpoint so it can be
    # compared with the single value the OCR predictor emits.
    if isinstance(raw_val, str) and raw_val.startswith('['):
        range_match = RANGE_LABEL_RE.match(raw_val)
        if range_match:
            low, high, unit = range_match.groups()
            raw_val = f"{(float(low) + float(high)) / 2:.2f} {unit}"

    try:
        true_val, true_unit = parse_string(raw_val)
        val_true.append(f"{true_val:.2f} {true_unit}")
    except Exception:
        # Unparseable or missing labels are scored as "no ground truth".
        val_true.append("")

    image_name = os.path.basename(row['image_link'])
    # NOTE(review): val_df is a split of train.csv, yet images are read from
    # "images/val"; confirm the validation images were downloaded there,
    # otherwise every prediction will be "" (ocr_predict returns "" on failure).
    image_path = os.path.join("images/val", image_name)
    pred = ocr_predict(image_path, row['entity_name'])
    val_preds.append(pred)


In [None]:
def extract_parts(s):
    """Normalise a ``"<value> <unit>"`` string to ``"<value:.2f> <unit>"``.

    Parsing is delegated to ``parse_string``. Any failure, whether a parse
    error or a result that cannot be formatted as a float, yields ``""``,
    which the scorer treats as "no value".
    """
    try:
        val, unit = parse_string(s)
        return f"{val:.2f} {unit}"
    except Exception:
        # Deliberate best-effort, but no longer a bare `except:` that would
        # also swallow KeyboardInterrupt / SystemExit.
        return ""

def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or 0 when the denominator is zero/falsy."""
    return numerator / denominator if denominator else 0


# Scoring rules applied to each (prediction, ground truth) pair after
# normalisation with extract_parts:
#   both present and equal      -> TP
#   both present but different  -> FP
#   prediction, no ground truth -> FP
#   ground truth, no prediction -> FN
#   neither present             -> not counted (F1 does not use it)
TP = FP = FN = 0
for pred, true in zip(val_preds, val_true):
    pred_clean = extract_parts(pred)
    true_clean = extract_parts(true)

    if not pred_clean:
        if true_clean:
            FN += 1
        continue

    if true_clean and pred_clean == true_clean:
        TP += 1
    else:
        FP += 1

precision = _safe_ratio(TP, TP + FP)
recall = _safe_ratio(TP, TP + FN)
f1 = _safe_ratio(2 * precision * recall, precision + recall)

print(f"F1 Score: {f1:.4f} | TP: {TP}, FP: {FP}, FN: {FN}")
