In [71]:
from transformers import pipeline
import pandas as pd
import re
from tqdm import tqdm
from time import time

In [72]:
qa_pipeline = pipeline('question-answering', model="deepset/roberta-base-squad2-distilled", device='cuda')
# qa_pipeline = pipeline('question-answering', model="deepset/roberta-base-squad2", device=0)
context = ""
question = "Answer the numerical value with the relevant unit. What is the %% mentioned in the description?"



In [73]:
def process_text(answer, entity):
    # Clean up the output of the LLM
    entity_unit_map = {
        "width": {"centimetre", "foot", "millimetre", "metre", "inch", "yard"},
        "depth": {"centimetre", "foot", "millimetre", "metre", "inch", "yard"},
        "height": {"centimetre", "foot", "millimetre", "metre", "inch", "yard"},
        "item_weight": {
            "milligram",
            "kilogram",
            "microgram",
            "gram",
            "ounce",
            "ton",
            "pound",
        },
        "voltage": {"millivolt", "kilovolt", "volt"},
        "wattage": {"kilowatt", "watt"},
        "item_volume": {
            "cubic foot",
            "microlitre",
            "cup",
            "fluid ounce",
            "centilitre",
            "imperial gallon",
            "pint",
            "decilitre",
            "litre",
            "millilitre",
            "quart",
            "cubic inch",
            "gallon",
            "ounce",
        },
        "maximum_weight_recommendation": {
            "milligram",
            "kilogram",
            "microgram",
            "gram",
            "ounce",
            "ton",
            "pound",
        },
    }
    unit_short_form = {
        "cm": "centimetre",
        "mm": "millimetre",
        "m": "metre",
        "in": "inch",
        "ft": "foot",
        "yd": "yard",
        "g": "gram",
        "kg": "kilogram",
        "mg": "milligram",
        "lb": "pound",
        "oz": "ounce",
        "ton": "ton",
        "ug": "microgram",
        "lbs": "pound",
        "9": "gram",
        "ozs": "ounce",
        "mv": "millivolt",
        "kv": "kilovolt",
        "v": "volt",
        "kw": "kilowatt",
        "w": "watt",
        "cf": "cubic foot",
        "ul": "microlitre",
        "fl oz": "fluid ounce",
        "cl": "centilitre",
        "gal": "imperial gallon",
        "pt": "pint",
        "dl": "decilitre",
        "l": "litre",
        "ml": "millilitre",
        "qt": "quart",
        "cu in": "cubic inch",
        "gals": "gallon",
        "c in": "cubic inch",
        "cu ft": "cubic foot",
        "o2": "ounce",
        "0z": "ounce",
        "k9": "kilogram",
        "\'\'": "inch",
        "``": "inch",
        "\"": "inch",
        # If no unit is mentioned, assume it is inches
        "": "inch",
        "c": "centimetre",
    }
    # Replace common OCR errors
    answer = answer.replace(",", ".")
    
        


    entity_unit = entity_unit_map[entity]
    regex_map = {
        "item_weight": r"[0-9IJOQSZDLB]+(\.[0-9IJOQSZDLB]+)?\s*(g|kg|mg|ug|oz|ton|lb|lbs|Ibs|bs|1bs|ozs|o2|0z|k9|kilo|milli|micro|ounce|tonne|pound|gram)",
        "width": r"[0-9IJOQSZDLB]+(\.[0-9IJOQSZDLB]+)?\s*(c|cm|mm|m|in|ft|yd|centi|milli|inch|foot|yard|\'\'|``|\"|feet)",
        "depth": r"[0-9IJOQSZDLB]+(\.[0-9IJOQSZDLB]+)?\s*(c|cm|mm|m|in|ft|yd|centi|milli|inch|foot|yard|\'\'|``|\"|feet)",
        "height": r"[0-9IJOQSZDLB]+(\.[0-9IJOQSZDLB]+)?\s*(c|cm|mm|m|in|ft|yd|centi|milli|inch|foot|yard|\'\'|``|\"|feet)",
        "voltage": r"[0-9IJOQSZDLB]+(\.[0-9IJOQSZDLB]+)?\s*(mv|kv|v|milli|kilo)",
        "wattage": r"[0-9IJOQSZDLB]+(\.[0-9IJOQSZDLB]+)?\s*(kw|w|kilo)",
        "item_volume": r"[0-9IJOQSZDLB]+(\.[0-9IJOQSZDLB]+)?\s*(cf|ul|fl oz|cl|gal|pt|dl|l|ml|qt|cu in|gals|c in|cu ft|o2|oz|ozs|0z|cubic|micro|cup|fluid|centi|imperial|pint|decilitre|litre|millilitre|quart|gallon|ounce)",
        "maximum_weight_recommendation": r"[0-9IJOQSZDLB]+(\.[0-9IJOQSZDLB]+)?\s*(g|kg|mg|ug|oz|ton|lb|lbs|Ibs|bs|1bs|ozs|o2|02|0z|k9|kilo|milli|micro|ounce|tonne|pound|gram)",
    }
    reverse_regex_map = {
        "item_weight": r"\s*(g|kg|mg|ug|oz|ton|lb|lbs|Ibs|bs|1bs|ozs|o2|0z|k9|kilo|milli|micro|ounce|tonne|pound|gram)\s*[0-9IJOQSZDLB]+(\.[0-9IJOQSZDLB]+)?",
        "width": r"\s*(c|cm|mm|m|in|ft|yd|centi|milli|inch|foot|yard|\'\'|``|\"|feet)\s*[0-9IJOQSZDLB]+(\.[0-9IJOQSZDLB]+)?",
        "depth": r"\s*(c|cm|mm|m|in|ft|yd|centi|milli|inch|foot|yard|\'\'|``|\"|feet)\s*[0-9IJOQSZDLB]+(\.[0-9IJOQSZDLB]+)?",
        "height": r"\s*(c|cm|mm|m|in|ft|yd|centi|milli|inch|foot|yard|\'\'|``|\"|feet)\s*[0-9IJOQSZDLB]+(\.[0-9IJOQSZDLB]+)?",
        "voltage": r"\s*(mv|kv|v|milli|kilo)\s*[0-9IJOQSZDLB]+(\.[0-9IJOQSZDLB]+)?",
        "wattage": r"\s*(kw|w|kilo)\s*[0-9IJOQSZDLB]+(\.[0-9IJOQSZDLB]+)?",
        "item_volume": r"\s*(cf|ul|fl oz|cl|gal|pt|dl|l|ml|qt|cu in|gals|c in|cu ft|o2|oz|ozs|0z|cubic|micro|cup|fluid|centi|imperial|pint|decilitre|litre|millilitre|quart|gallon|ounce)\s*[0-9IJOQSZDLB]+(\.[0-9IJOQSZDLB]+)?",
        "maximum_weight_recommendation": r"\s*(g|kg|mg|ug|oz|ton|lb|lbs|Ibs|bs|1bs|ozs|o2|02|0z|k9|kilo|milli|micro|ounce|tonne|pound|gram)\s*[0-9IJOQSZDLB]+(\.[0-9IJOQSZDLB]+)?",
    }
    regex_units = {
        "item_weight": r"\s*(g|kg|mg|ug|oz|ton|lb|lbs|Ibs|bs|1bs|ozs|o2|0z|k9|kilo|milli|micro|ounce|tonne|pound|gram)",
        "width": r"\s*(c|cm|mm|m|in|ft|yd|centi|milli|inch|foot|yard|\'\'|``|\"|feet)",
        "depth": r"\s*(c|cm|mm|m|in|ft|yd|centi|milli|inch|foot|yard|\'\'|``|\"|feet)",
        "height": r"\s*(c|cm|mm|m|in|ft|yd|centi|milli|inch|foot|yard|\'\'|``|\"|feet)",
        "voltage": r"\s*(mv|kv|v|milli|kilo)",
        "wattage": r"\s*(kw|w|kilo)",
        "item_volume": r"\s*(cf|ul|fl oz|cl|gal|pt|dl|l|ml|qt|cu in|gals|c in|cu ft|o2|oz|ozs|0z|cubic|micro|cup|fluid|centi|imperial|pint|decilitre|litre|millilitre|quart|gallon|ounce)",
        "maximum_weight_recommendation": r"\s*(g|kg|mg|ug|oz|ton|lb|lbs|Ibs|bs|1bs|ozs|o2|0z|k9|kilo|milli|micro|ounce|tonne|pound|gram)",
    }
    # Remove igh which will cause false positives
    answer = answer.replace("igh", "")
    # Remove ima which will cause false positives
    answer = answer.replace("ima", "")
    # Replace comma with dot
    answer = answer.replace(",", ".")
    # Match the regex ignoring case
    match = re.finditer(regex_map[entity], answer, re.IGNORECASE)
    reverse_match = re.finditer(reverse_regex_map[entity], answer, re.IGNORECASE)

    # Process each match
    for m in match:
        # Extract the number and the unit
        # Extract the matched value
        value = m.group(0)
        # Extract the unit
        unit = re.search(
            regex_units[entity], value, re.IGNORECASE
        )
        if (entity == "height" or entity == "width") and not unit:
            unit = "inch"
            sstr = ""
        elif not unit:
            continue
        else:
            unit = unit.group(0).strip()
            sstr = unit
            unit = unit.lower()
        if unit == "feet":
            unit = "foot"
        # Convert the unit to full form
        if unit in unit_short_form:
            unit = unit_short_form[unit]
        # Check if the unit is valid
        if unit in entity_unit:
            entity_value = value[: value.index(sstr)].strip()
            # Replace I with 1 and O with 0
            entity_value = (
                entity_value.replace("I", "1")
                .replace("i", "1")
                .replace("J", "1")
                .replace("j", "1")
                .replace("O", "0")
                .replace("o", "0")
                .replace("D", "0")
                .replace("B", "8")
                .replace("b", "6")
                .replace("L", "1")
                .replace("l", "1")
                .replace("S", "5")
                .replace("s", "5")
                .replace("Z", "2")
                .replace("z", "2")
                .replace("Q", "0")
                .replace("q", "9")
            )
            try:
                entity_value = float(entity_value)
                if int(entity_value) == entity_value and entity_value > 1000:
                    entity_value = int(entity_value)
            except:
                continue
            return str(entity_value) + " " + unit
    for m in reverse_match:
        value = m.group(0)
        unit = re.search(
            regex_units[entity], value, re.IGNORECASE
        )
        if entity == "height" and not unit:
            unit = "inch"
            sstr = ""
        elif not unit:
            continue
        else:
            unit = unit.group(0).strip()
            sstr = unit
            unit = unit.lower()
        if unit == "feet":
            unit = "foot"
        # Convert the unit to full form
        if unit in unit_short_form:
            unit = unit_short_form[unit]
        # Check if the unit is valid
        if unit in entity_unit:
            entity_value = value.replace(sstr, "").strip()
            # Replace I with 1 and O with 0
            entity_value = (
                entity_value.replace("I", "1")
                .replace("i", "1")
                .replace("J", "1")
                .replace("j", "1")
                .replace("O", "0")
                .replace("o", "0")
                .replace("D", "0")
                .replace("B", "8")
                .replace("b", "6")
                .replace("L", "1")
                .replace("l", "1")
                .replace("S", "5")
                .replace("s", "5")
                .replace("Z", "2")
                .replace("z", "2")
                .replace("Q", "0")
                .replace("q", "9")
            )
            try:
                # Reverse the value
                entity_value = entity_value[::-1]
                entity_value = float(entity_value)
                if int(entity_value) == entity_value and entity_value > 1000:
                    entity_value = int(entity_value)
            except:
                continue
            return str(entity_value) + " " + unit
    return None

In [74]:
test_data = pd.read_csv("../dataset/merged_output (3).csv")
test_data = test_data.dropna(subset=["OCR Text"])
test_data.head()

Unnamed: 0,index,image_link,group_id,entity_name,OCR Text
0,0,https://m.media-amazon.com/images/I/110EibNycl...,156839,height,2.63in 6.68cm 91.44cm - 199.39cm 36in - 78in
1,1,https://m.media-amazon.com/images/I/11TU2clswz...,792578,width,"Size Width Length One Size 42cm/16.54"" 200cm/7..."
2,2,https://m.media-amazon.com/images/I/11TU2clswz...,792578,height,"Size Width Length One Size 42cm/16.54"" 200cm/7..."
3,3,https://m.media-amazon.com/images/I/11TU2clswz...,792578,depth,"Size Width Length One Size 42cm/16.54"" 200cm/7..."
4,4,https://m.media-amazon.com/images/I/11gHj8dhhr...,792578,depth,"Size Width Length One Size 10.50cm/4.13"" 90cm/..."


In [81]:
test_input = pd.read_csv("../dataset/test.csv")

In [82]:
df = pd.DataFrame(columns=["index", "prediction"], index=list(range(len(test_input))))

In [83]:
for i, index in enumerate(tqdm(test_input["index"], total=len(test_input))):
    if index in test_data["index"]:
        row = test_data[test_data["index"] == index]
        if len(row) == 0:
            df.loc[i] = [index, ""]
            continue
        context = row["OCR Text"].values[0]
        entity = row["entity_name"].values[0]
        answer = qa_pipeline(question=question.replace("%%", entity), context=context)
        prediction = process_text(answer["answer"], entity)
        if prediction:
            df.loc[i] = [index, prediction]
        else:
            prediction = process_text(context, entity)
            if prediction:
                df.loc[i] = [index, prediction]
            else:
                df.loc[i] = [index, ""]
    else:
        df.loc[i] = [index, ""]
        

  1%|▏         | 1860/131187 [00:44<51:18, 42.01it/s]  


KeyboardInterrupt: 

In [47]:
df2 = df.copy()
df2.head(n=20)

Unnamed: 0,index,prediction
0,0,199.39 centimetre
1,1,200.0 centimetre
2,2,200.0 centimetre
3,3,200.0 centimetre
4,4,90.0 centimetre
5,5,90.0 centimetre
6,6,90.0 centimetre
7,7,3.56 centimetre
8,8,40.0 centimetre
9,9,40.0 centimetre


In [48]:
df2.to_csv("predictions.csv", index=False)

In [62]:
heights = pd.read_csv("../dataset/height_yolo.csv")
for row in tqdm(heights.iterrows(), total=len(heights)):
    row = row[1]
    index = row["index"]
    text = str(row["OCR Text"])
    if text == "LLM maybe":
        continue
    answer = process_text(text, "height")
    if answer:
        df.loc[df["index"] == index] = [index, answer]

 19%|█▊        | 5995/32282 [00:11<00:51, 514.86it/s]


KeyboardInterrupt: 

In [60]:
widths = pd.read_csv("../dataset/width_text.csv")
for row in tqdm(widths.iterrows(), total=len(widths)):
    row = row[1]
    index = row["index"]
    text = str(row["OCR Text"])
    if text == "LLM maybe":
        continue
    answer = process_text(text, "width")
    if answer:
        df.loc[df["index"] == index] = [index, answer]


  5%|▍         | 1324/26931 [00:02<00:57, 442.17it/s]


KeyboardInterrupt: 

In [52]:
widths_vllm = pd.read_csv("../dataset/width_VLM.csv")
for row in tqdm(widths_vllm.iterrows(), total=len(widths_vllm)):
    row = row[1]
    index = row["index"]
    entity = row["entity_name"]
    text = str(row["prediction"])
    if text == "LLM maybe":
        text = str(row["OCR Text"])
    answer = qa_pipeline(question=question.replace("%%", entity), context=text)
    prediction = process_text(answer["answer"], entity)
    if prediction:
        df.loc[df["index"] == index] = [index, prediction]

100%|██████████| 453/453 [00:12<00:00, 37.40it/s]


In [53]:
widths_vllm = pd.read_csv("../dataset/width_VLM1.csv")
for row in tqdm(widths_vllm.iterrows(), total=len(widths_vllm)):
    row = row[1]
    index = row["index"]
    entity = row["entity_name"]
    text = str(row["prediction"])
    if text == "LLM maybe":
        text = str(row["OCR Text"])
    answer = qa_pipeline(question=question.replace("%%", entity), context=text)
    prediction = process_text(answer["answer"], entity)
    if prediction:
        df.loc[df["index"] == index] = [index, prediction]

100%|██████████| 1080/1080 [00:25<00:00, 41.92it/s]


In [54]:
heights_vllm = pd.read_csv("../dataset/height_VLM.csv")
for row in tqdm(heights_vllm.iterrows(), total=len(heights_vllm)):
    row = row[1]
    index = row["index"]
    entity = row["entity_name"]
    text = str(row["prediction"])
    if text == "LLM maybe":
        text = str(row["OCR Text"])
    answer = qa_pipeline(question=question.replace("%%", entity), context=text)
    prediction = process_text(answer["answer"], entity)
    if prediction:
        df.loc[df["index"] == index] = [index, prediction]

  0%|          | 0/734 [00:00<?, ?it/s]

100%|██████████| 734/734 [00:20<00:00, 36.63it/s]


In [64]:
# weights = pd.read_csv("../dataset/check_weight.csv")
# for row in tqdm(weights.iterrows(), total=len(weights)):
#     row = row[1]
#     index = row["index"]
#     text = str(row["OCR Text"])
#     if text == "LLM maybe":
#         continue
#     answer = process_text(text, "item_weight")
#     if answer:
#         df.loc[df["index"] == index] = [index, answer]

100%|██████████| 12354/12354 [00:18<00:00, 652.36it/s]


In [65]:
df.to_csv(f"predictions{time()}.csv", index=False)

In [66]:
!python sanity.py --test_filename ../dataset/test.csv --output_filename predictions1726435789.0962567.csv

Parsing successfull for file: predictions1726435789.0962567.csv


In [70]:
edge_case = [
    "Load Capacity 243LBS 16 16"" 16""",
    "Nutrition Facts Servings1.Serv.size:1 bar (48g"
]

for case in edge_case:
    answer = qa_pipeline(question=question.replace("%%", "item_weight"), context=case)
    print(f"LLM Output: {answer['answer']}")
    print(process_text(answer["answer"], "item_weight"))

LLM Output: Load Capacity 243LBS 16 16 16
243.0 pound
LLM Output: 1
None
