In [15]:
from transformers import pipeline
import pandas as pd
import re
from tqdm import tqdm
from time import time

In [16]:
# Question template: "%%" is replaced with the entity name before each query.
question = "What is the %% mentioned in the description?"

# Placeholder value; the inference loop assigns the real OCR text per sample.
context = ""

# Extractive question-answering model, placed on the CUDA device.
qa_pipeline = pipeline(
    task="question-answering",
    model="deepset/roberta-base-squad2-distilled",
    device="cuda",
)



In [66]:
# Canonical unit names accepted for each entity type.
_LENGTH_UNITS = frozenset(
    {"centimetre", "foot", "millimetre", "metre", "inch", "yard"}
)
_WEIGHT_UNITS = frozenset(
    {"milligram", "kilogram", "microgram", "gram", "ounce", "ton", "pound"}
)
_ENTITY_UNIT_MAP = {
    "width": _LENGTH_UNITS,
    "depth": _LENGTH_UNITS,
    "height": _LENGTH_UNITS,
    "item_weight": _WEIGHT_UNITS,
    "voltage": frozenset({"millivolt", "kilovolt", "volt"}),
    "wattage": frozenset({"kilowatt", "watt"}),
    "item_volume": frozenset(
        {
            "cubic foot",
            "microlitre",
            "cup",
            "fluid ounce",
            "centilitre",
            "imperial gallon",
            "pint",
            "decilitre",
            "litre",
            "millilitre",
            "quart",
            "cubic inch",
            "gallon",
            "ounce",
        }
    ),
    "maximum_weight_recommendation": _WEIGHT_UNITS,
}

# Abbreviations (including common OCR misreads such as "0z" or "k9") mapped to
# canonical unit names. A bare "9" (misread "g") is deliberately not listed:
# it cannot be told apart from a digit of the number itself.
_UNIT_ALIASES = {
    "cm": "centimetre",
    "mm": "millimetre",
    "m": "metre",
    "in": "inch",
    "ft": "foot",
    "feet": "foot",
    "yd": "yard",
    "g": "gram",
    "kg": "kilogram",
    "mg": "milligram",
    "lb": "pound",
    "oz": "ounce",
    "ton": "ton",
    "ug": "microgram",
    "lbs": "pound",
    "ozs": "ounce",
    "mv": "millivolt",
    "kv": "kilovolt",
    "v": "volt",
    "kw": "kilowatt",
    "w": "watt",
    "cf": "cubic foot",
    "ul": "microlitre",
    "fl oz": "fluid ounce",
    "cl": "centilitre",
    "gal": "imperial gallon",
    "pt": "pint",
    "dl": "decilitre",
    "l": "litre",
    "ml": "millilitre",
    "qt": "quart",
    "cu in": "cubic inch",
    "gals": "gallon",
    "c in": "cubic inch",
    "cu ft": "cubic foot",
    "o2": "ounce",
    "0z": "ounce",
    "k9": "kilogram",
}

# Letters that OCR commonly produces in place of digits, in both cases
# (the number pattern is matched case-insensitively).
_OCR_DIGIT_FIXES = str.maketrans("IiJjOoDd", "11110000")


def _build_matcher(units):
    """Return ``(compiled pattern, token -> canonical unit)`` for one unit set."""
    lookup = {}
    for name in units:
        lookup[name] = name
        # Also accept the American spelling ("meter", "liter", ...).
        lookup[name.replace("tre", "ter")] = name
    for alias, name in _UNIT_ALIASES.items():
        if name in units:
            lookup[alias] = name
    # Longest tokens first: regex alternation takes the first alternative that
    # matches, so "m" must not be tried before "millimetre" or "gal" before
    # "gallon".
    alternation = "|".join(
        re.escape(token)
        for token in sorted(lookup, key=lambda token: (-len(token), token))
    )
    # The number is matched lazily so that an ambiguous letter is read as part
    # of the unit when it can be ("5dl" is 5 + "dl", not "5d" + "l").
    pattern = re.compile(
        r"(?P<number>[0-9IJOD]+?(?:\.[0-9IJOD]+?)?)\s*(?P<unit>"
        + alternation
        + r")",
        re.IGNORECASE,
    )
    return pattern, lookup


# Compiled once at definition time instead of on every call.
_MATCHERS = {
    entity: _build_matcher(units) for entity, units in _ENTITY_UNIT_MAP.items()
}


def process_text(answer, entity):
    """Extract the first "<number> <unit>" measurement for ``entity`` from text.

    The text is scanned case-insensitively for a number (optionally with a
    decimal part, and tolerating the OCR look-alikes I/J -> 1 and O/D -> 0)
    followed by a unit that is valid for ``entity``. Abbreviations, full unit
    names and American spellings are all normalised to the canonical unit name.

    Args:
        answer: Text to parse (here: the answer span returned by the QA model).
        entity: One of the keys of ``_ENTITY_UNIT_MAP`` (e.g. ``"width"``,
            ``"item_weight"``).

    Returns:
        A string ``"<value> <unit>"`` for the first usable match, or ``None``
        when nothing usable is found. The value is formatted as a float
        (``"42.0 centimetre"``) except for whole numbers above 1000, which are
        formatted as integers (``"2500 gram"``).

    Raises:
        KeyError: If ``entity`` is not a known entity name.
    """
    pattern, lookup = _MATCHERS[entity]

    # "igh" (as in "height"/"weight") is stripped because it produced false
    # positives; decimal commas are normalised to dots.
    cleaned = answer.replace("igh", "").replace(",", ".")

    for found in pattern.finditer(cleaned):
        unit = lookup.get(found.group("unit").lower())
        if unit is None:
            # Case-insensitive matching can admit characters whose lowercase
            # form is not one of the known tokens; skip such matches.
            continue
        digits = found.group("number").translate(_OCR_DIGIT_FIXES)
        try:
            quantity = float(digits)
            if quantity > 1000 and quantity == int(quantity):
                quantity = int(quantity)
        except (ValueError, OverflowError):
            # Not a parseable finite number; try the next candidate.
            continue
        return f"{quantity} {unit}"
    return None

In [67]:
# Read the merged OCR table and discard rows that have no OCR text.
test_data = (
    pd.read_csv("../dataset/merged_data.csv")
    .dropna(subset=["OCR Text"])
)
test_data.head()

  test_data = pd.read_csv("../dataset/merged_data.csv")


Unnamed: 0,index,image_link,group_id,entity_name,OCR Text
0,0,https://m.media-amazon.com/images/I/110EibNycl...,156839,height,3rcn51 44muieetcm
1,1,https://m.media-amazon.com/images/I/11TU2clswz...,792578,width,"SizeWidthLengthOne Size42cm/16.54""200cm/78.74"""
2,2,https://m.media-amazon.com/images/I/11TU2clswz...,792578,height,"SizeWidthLengthOne Size42cm/16.54""200cm/78.74"""
3,3,https://m.media-amazon.com/images/I/11TU2clswz...,792578,depth,"SizeWidthLengthOne Size42cm/16.54""200cm/78.74"""
4,4,https://m.media-amazon.com/images/I/11gHj8dhhr...,792578,depth,"SizeWidthLengthOne Size10.50cm/4.13""90cm/35.43"""


In [68]:
# Test-set listing; its "index" column drives the prediction loop below.
test_input = pd.read_csv("../dataset/test.csv")

In [69]:
# Empty results table with one row (labelled 0..n-1) per test sample.
df = pd.DataFrame(
    index=list(range(len(test_input))),
    columns=["index", "prediction"],
)

In [70]:
# Map each sample's "index" value to its (OCR text, entity name).
# A dict keyed by the column *values* is used on purpose: `x in series` tests
# the Series' row labels, not its values, and a boolean-mask lookup per sample
# rescans the whole table every iteration.
ocr_lookup = {}
for sample_index, ocr_text, entity_name in zip(
    test_data["index"], test_data["OCR Text"], test_data["entity_name"]
):
    # setdefault keeps the first row when an index value is repeated.
    ocr_lookup.setdefault(sample_index, (ocr_text, entity_name))

predictions = []
for index in tqdm(test_input["index"], total=len(test_input)):
    prediction = ""  # samples without OCR text or a parseable answer stay empty
    if index in ocr_lookup:
        context, entity = ocr_lookup[index]
        answer = qa_pipeline(question=question.replace("%%", entity), context=context)
        prediction = process_text(answer["answer"], entity) or ""
    predictions.append(prediction)

# Build the results table once rather than assigning it row by row.
df = pd.DataFrame({"index": list(test_input["index"]), "prediction": predictions})
        

100%|██████████| 131187/131187 [17:02<00:00, 128.29it/s]  


In [73]:
# Keep the timestamped file name so the exact file written here can be
# referred to afterwards, and show it as the cell output.
predictions_path = f"predictions{time()}.csv"
df.to_csv(predictions_path, index=False)
predictions_path

In [74]:
# Run sanity.py on the test file and the predictions file.
# NOTE: the --output_filename value is hard-coded; it must be updated to the
# timestamped name produced by the save cell whenever that cell is re-run.
!python sanity.py --test_filename ../dataset/test.csv --output_filename predictions1726322417.9047391.csv

Parsing successfull for file: predictions1726322417.9047391.csv
