In [None]:
!pip install transformers einops

import pandas as pd
import torch
from PIL import Image
import requests
from io import BytesIO
from torchvision import transforms
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm

class ProductDataset(torch.utils.data.Dataset):
    """Map-style dataset over a product CSV that downloads each image on access.

    Each item is ``(inputs, target)``. ``inputs`` is a dict holding the
    downloaded PIL image (``None`` if it could not be fetched or decoded), the
    row's ``group_id`` and a question built from ``entity_name``. ``target`` is
    ``entity_value`` in ``'train'`` mode and the CSV's ``index`` column otherwise.

    Args:
        csv_file: Path to the CSV. ``__getitem__`` reads the columns
            ``image_link``, ``group_id``, ``entity_name`` and either
            ``entity_value`` (train mode) or ``index`` (any other mode).
        mode: ``'train'`` to return labels; anything else returns the row index.
        limit: If not ``None``, keep only the first ``limit`` rows.
        image_size: Optional ``(width, height)``. When given, downloaded images
            are resized to it; ``None`` (default) keeps the original size.
    """

    def __init__(self, csv_file, mode='train', limit=None, image_size=None):
        self.data = pd.read_csv(csv_file)
        # `is not None` so that limit=0 really yields zero rows instead of
        # silently meaning "no limit".
        if limit is not None:
            self.data = self.data.head(limit)
        self.mode = mode
        self.image_size = image_size

    def download_image(self, url):
        """Fetch *url* and return it as a decoded PIL image, or ``None`` on failure."""
        try:
            response = requests.get(url, timeout=10)
            response.raise_for_status()  # turn 4xx/5xx responses into RequestException
            img = Image.open(BytesIO(response.content))
            # Image.open is lazy: force decoding here so truncated or corrupt
            # payloads are caught below instead of failing later in the model.
            img.load()
            return img
        except (requests.RequestException, OSError):
            # OSError also covers PIL's UnidentifiedImageError and decode errors.
            return None

    def resize_image(self, img, size=(224, 224)):
        """Return *img* resized to *size* using PIL's ``Image.resize``."""
        return img.resize(size)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(inputs, label)`` in train mode, ``(inputs, csv_index)`` otherwise."""
        row = self.data.iloc[idx]
        img = self.download_image(row['image_link'])

        # Optional resize; skipped when the download failed (img is None).
        if img is not None and self.image_size is not None:
            img = self.resize_image(img, self.image_size)

        inputs = {
            "image": img,
            "group_id": row['group_id'],
            "question": f"What is the {row['entity_name']} of the product in the image? Answer in the format: x unit"
        }

        if self.mode == 'train':
            return inputs, row['entity_value']
        # The CSV's own 'index' column, not the positional idx.
        return inputs, row['index']

# Run on the GPU when one is available, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load moondream2 pinned to a fixed revision. trust_remote_code=True lets the
# checkpoint's own modeling code be loaded; that code presumably provides the
# non-standard encode_image / answer_question methods used below.
model_id = "vikhyatk/moondream2"
revision = "2024-08-26"
model = AutoModelForCausalLM.from_pretrained(
    model_id, trust_remote_code=True, revision=revision
)
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)

model.to(DEVICE)
model.eval()  # inference only: disable dropout and similar training behaviour

test_dataset = ProductDataset('test.csv', mode='test')

# One {'index', 'prediction'} dict per test row.
results_list = []

# No gradients are needed; no_grad avoids building autograd state for the
# image encoding and text generation on every sample.
with torch.no_grad():
    for sample, index in tqdm(test_dataset):
        # Empty prediction when the image is missing or inference fails.
        md_answer = ""
        if sample['image'] is not None:
            try:
                md_answer = model.answer_question(
                    model.encode_image(sample['image']),
                    sample['question'],
                    tokenizer=tokenizer,
                )
            except Exception as e:
                # Best effort: one bad sample must not abort the whole run.
                print(f"Error processing image for index {index}: {str(e)}")
        results_list.append({'index': index, 'prediction': md_answer})

# Sort by the CSV index so the output order is deterministic, then save.
results_df = pd.DataFrame(results_list)
results_df = results_df.sort_values('index').reset_index(drop=True)
results_df.to_csv('test_output.csv', index=False)
print("Predictions saved to test_output.csv")

In [None]:
# ----------------------- POST PROCESSING FOR WEIGHT ------------------

import csv
import re

# Unit spellings the model may emit (compared after lower-casing) mapped to
# the canonical weight unit names used in the output.
unit_map = {
    'g': 'gram',
    'gm': 'gram',
    'gram': 'gram',
    'grams': 'gram',
    'kg': 'kilogram',
    'kilo': 'kilogram',
    'kilogram': 'kilogram',
    'μg': 'microgram',
    'ug': 'microgram',
    'mcg': 'microgram',
    'microgram': 'microgram',
    'mg': 'milligram',
    'milligram': 'milligram',
    'oz': 'ounce',
    'ounce': 'ounce',
    'ozs': 'ounce',
    'lb': 'pound',
    'lbs': 'pound',
    'lbs.': 'pound',
    'pound': 'pound',
    'pounds': 'pound',
    't': 'ton',
    'ton': 'ton',
    'tons': 'ton'
}

# Canonical weight units accepted in the output.
valid_units = {'gram', 'kilogram', 'microgram', 'milligram', 'ounce', 'pound', 'ton'}


def format_prediction(prediction):
    """Normalize a raw weight answer to ``"<number> <canonical unit>"``.

    Parenthesised text is dropped, thousands separators ("1,000") are removed,
    and trailing punctuation after the unit ("250 g.") is ignored. Returns
    ``''`` for empty input, when no leading number is found, or when the unit
    is not recognized.
    """
    if not prediction:
        return ''
    prediction = re.sub(r'\([^)]*\)', '', prediction)
    # "1,000 g" -> "1000 g"; otherwise the number would stop at "1".
    prediction = re.sub(r'(?<=\d),(?=\d{3}(?!\d))', '', prediction)

    # Number followed by the unit, with or without a space ("500g", "500 g").
    match = re.match(r'^([-+]?[0-9]*\.?[0-9]+)\s*(.*)$', prediction.strip())
    if not match:
        return ''
    value, unit = match.groups()

    # Case-insensitive lookup. Trailing punctuation is dropped, and the micro
    # sign (U+00B5) / Greek mu (U+03BC) become 'u' so both spellings of
    # "µg" hit the 'ug' key.
    unit = unit.lower().rstrip('.,;:! ')
    unit = unit.replace('\u00b5', 'u').replace('\u03bc', 'u')

    full_unit = unit_map.get(unit)
    if full_unit in valid_units:
        return f"{value} {full_unit}"
    return ''

# Rewrite test_output.csv in place with normalized predictions.
# All rows are read (and the file closed) BEFORE it is reopened for writing:
# open(..., 'w') truncates immediately, so reading and writing the same path in
# one `with` statement would leave only the header behind.
with open('test_output.csv', 'r', newline='', encoding='utf-8', errors='replace') as infile:
    rows = list(csv.DictReader(infile))

with open('test_output.csv', 'w', newline='', encoding='utf-8') as outfile:
    writer = csv.DictWriter(outfile, fieldnames=['index', 'prediction'])
    writer.writeheader()
    for row in rows:
        formatted_prediction = format_prediction(row['prediction'])
        writer.writerow({'index': row['index'], 'prediction': formatted_prediction})

In [None]:
# -------------- POST PROCESSING FOR VOLTAGE -------------

import csv
import re

# Unit spellings (after lower-casing and removing non-letters) mapped to the
# canonical voltage unit names used in the output.
unit_map = {
    'v': 'volt',
    'volt': 'volt',
    'volts': 'volt',
    'kv': 'kilovolt',
    'kilovolt': 'kilovolt',
    'kilovolts': 'kilovolt',
    'mv': 'millivolt',
    'millivolt': 'millivolt',
    'millivolts': 'millivolt'
}

# Canonical voltage units accepted in the output.
valid_units = {'volt', 'kilovolt', 'millivolt'}


def format_prediction(prediction):
    """Normalize a raw voltage answer to ``"<number> <canonical unit>"``.

    Parenthesised text is dropped and thousands separators are removed. Only
    the first token after the number is used as the unit. A bare number
    defaults to volts. Returns ``''`` for empty input, when no leading number
    is found, or when the unit is not recognized.
    """
    if not prediction:
        return ''
    prediction = re.sub(r'\([^)]*\)', '', prediction)
    # "1,200 V" -> "1200 V"; otherwise the number would stop at "1".
    prediction = re.sub(r'(?<=\d),(?=\d{3}(?!\d))', '', prediction)

    # Number, optional '-', space or '/' separator, then the first unit token.
    match = re.match(r'^([-+]?[0-9]*\.?[0-9]+)[-\s/]?([^\s/]+)?', prediction.strip())
    if not match:
        return ''
    value, unit = match.groups()

    # A bare number is assumed to be in volts.
    if not unit:
        return f"{value} volt"

    # Case-insensitive, letters only ("Volts." -> "volts").
    unit = re.sub(r'[^a-z]', '', unit.lower())

    full_unit = unit_map.get(unit)
    if full_unit in valid_units:
        return f"{value} {full_unit}"

    # Heuristic fallback kept from the original: any 'v…' token, or a lone
    # 'w', is reported as volts.
    # NOTE(review): 'w' -> volt looks like a workaround for model output; confirm.
    if unit.startswith('v') or unit == 'w':
        return f"{value} volt"

    return ''

# Rewrite test_output.csv in place with normalized predictions.
# All rows are read (and the file closed) BEFORE it is reopened for writing:
# open(..., 'w') truncates immediately, so reading and writing the same path in
# one `with` statement would leave only the header behind.
with open('test_output.csv', 'r', newline='', encoding='utf-8', errors='replace') as infile:
    rows = list(csv.DictReader(infile))

with open('test_output.csv', 'w', newline='', encoding='utf-8') as outfile:
    writer = csv.DictWriter(outfile, fieldnames=['index', 'prediction'])
    writer.writeheader()
    for row in rows:
        formatted_prediction = format_prediction(row['prediction'])
        writer.writerow({'index': row['index'], 'prediction': formatted_prediction})


In [None]:
# -------------------- POST PROCESSING FOR WATTAGE -------------------

import csv
import re

# Lower-cased, letters-only unit tokens mapped to canonical wattage units.
# Note that a lone 'k' is read as kilowatt.
unit_map = {
    'w': 'watt',
    'watt': 'watt',
    'watts': 'watt',
    'kw': 'kilowatt',
    'kilowatt': 'kilowatt',
    'kilowatts': 'kilowatt',
    'k': 'kilowatt'
}

# Units allowed to appear in a formatted wattage prediction.
valid_units = {'watt', 'kilowatt'}


def format_prediction(prediction):
    """Turn a raw wattage answer into ``"<number> <unit>"`` or ``''``.

    Steps: drop parenthesised text, delete every comma, read the leading
    number and the token right after it (an optional '-' or whitespace may
    separate them). A missing token means watts. Known tokens are mapped via
    ``unit_map``. Tokens starting with 'v', or equal to 'hr'/'btu', are also
    labelled watts (the number itself is not converted). Anything else gives ``''``.
    """
    cleaned = re.sub(r'\([^)]*\)', '', prediction).replace(',', '')

    found = re.match(r'^([-+]?[0-9]*\.?[0-9]+)[-\s]?([^\s-]+)?', cleaned.strip())
    if found is None:
        return ''

    number = found.group(1)
    raw_unit = found.group(2)

    # Nothing after the number: assume watts.
    if not raw_unit:
        return f"{number} watt"

    token = re.sub(r'[^a-z]', '', raw_unit.lower())

    if token in unit_map and unit_map[token] in valid_units:
        return f"{number} {unit_map[token]}"

    # Heuristic relabelling without conversion.
    # NOTE(review): BTU/hr and volt values are not watts; confirm this is intended.
    if token.startswith('v') or token in ('hr', 'btu'):
        return f"{number} watt"

    return ''

# Rewrite test_output.csv in place with normalized predictions.
# All rows are read (and the file closed) BEFORE it is reopened for writing:
# open(..., 'w') truncates immediately, so reading and writing the same path in
# one `with` statement would leave only the header behind.
with open('test_output.csv', 'r', newline='', encoding='utf-8', errors='replace') as infile:
    rows = list(csv.DictReader(infile))

with open('test_output.csv', 'w', newline='', encoding='utf-8') as outfile:
    writer = csv.DictWriter(outfile, fieldnames=['index', 'prediction'])
    writer.writeheader()
    for row in rows:
        formatted_prediction = format_prediction(row['prediction'])
        writer.writerow({'index': row['index'], 'prediction': formatted_prediction})

In [None]:
# ---------------------------- POST PROCESSING FOR WIDTH ------------------

import csv
import re

# Unit spellings (after lower-casing; letters and '"' kept) mapped to the
# canonical length unit names used in the output.
unit_map = {
    'cm': 'centimetre',
    'centimeter': 'centimetre',
    'centimeters': 'centimetre',
    'centimetre': 'centimetre',
    'centimetres': 'centimetre',
    'ft': 'foot',
    'feet': 'foot',
    'foot': 'foot',
    '"': 'inch',
    'in': 'inch',
    'inch': 'inch',
    'inches': 'inch',
    'm': 'metre',
    'meter': 'metre',
    'meters': 'metre',
    'metre': 'metre',
    'metres': 'metre',
    'mm': 'millimetre',
    'millimeter': 'millimetre',
    'millimeters': 'millimetre',
    'millimetre': 'millimetre',
    'millimetres': 'millimetre',
    'yd': 'yard',
    'yard': 'yard',
    'yards': 'yard'
}

# Canonical length units accepted in the output.
valid_units = {'centimetre', 'foot', 'inch', 'metre', 'millimetre', 'yard'}


def format_prediction(prediction):
    """Normalize a raw width answer to ``"<number> <canonical unit>"``.

    Parenthesised text is dropped and thousands separators are removed. Only
    the first token after the number is used as the unit. A bare number
    defaults to centimetres. Returns ``''`` for empty input, when no leading
    number is found, or when the unit is not recognized.
    """
    if not prediction:
        return ''
    prediction = re.sub(r'\([^)]*\)', '', prediction)
    # "1,200 mm" -> "1200 mm"; otherwise the number would stop at "1".
    prediction = re.sub(r'(?<=\d),(?=\d{3}(?!\d))', '', prediction)

    # Number followed by an optional unit token, with or without a space.
    match = re.match(r'^([-+]?[0-9]*\.?[0-9]+)\s*([^\s/]+)?', prediction.strip())
    if not match:
        return ''
    value, unit = match.groups()

    # A bare number is assumed to be in centimetres.
    if not unit:
        return f"{value} centimetre"

    # Case-insensitive; keep letters and '"' (the inch mark), drop the rest.
    unit = re.sub(r'[^a-z"]', '', unit.lower())

    full_unit = unit_map.get(unit)
    if full_unit in valid_units:
        return f"{value} {full_unit}"
    return ''

# Rewrite test_output.csv in place with normalized predictions.
# All rows are read (and the file closed) BEFORE it is reopened for writing:
# open(..., 'w') truncates immediately, so reading and writing the same path in
# one `with` statement would leave only the header behind.
with open('test_output.csv', 'r', newline='', encoding='utf-8', errors='replace') as infile:
    rows = list(csv.DictReader(infile))

with open('test_output.csv', 'w', newline='', encoding='utf-8') as outfile:
    writer = csv.DictWriter(outfile, fieldnames=['index', 'prediction'])
    writer.writeheader()
    for row in rows:
        formatted_prediction = format_prediction(row['prediction'])
        writer.writerow({'index': row['index'], 'prediction': formatted_prediction})


In [None]:
# ------------------- POST PROCESSING FOR ITEM VOLUME -----------------

import csv
import re

# Unit spellings mapped to canonical volume unit names. Keys are matched as
# PREFIXES of the cleaned unit text in format_prediction, so plurals such as
# "liters" resolve via 'l'.
unit_map = {
    'cl': 'centilitre',
    'centiliter': 'centilitre',
    'centilitre': 'centilitre',
    'cu ft': 'cubic foot',
    'cubic ft': 'cubic foot',
    'cubic foot': 'cubic foot',
    'cu in': 'cubic inch',
    'cubic in': 'cubic inch',
    'cubic inch': 'cubic inch',
    'cup': 'cup',
    'cups': 'cup',
    'dl': 'decilitre',
    'deciliter': 'decilitre',
    'decilitre': 'decilitre',
    'fl oz': 'fluid ounce',
    'fl. oz.': 'fluid ounce',
    'fluid oz': 'fluid ounce',
    'fluid ounce': 'fluid ounce',
    'gal': 'gallon',
    'gallon': 'gallon',
    'gallons': 'gallon',
    'imp gal': 'imperial gallon',
    'imperial gal': 'imperial gallon',
    'imperial gallon': 'imperial gallon',
    'l': 'litre',
    'liter': 'litre',
    'litre': 'litre',
    'μl': 'microlitre',
    'ul': 'microlitre',
    'microliter': 'microlitre',
    'microlitre': 'microlitre',
    'ml': 'millilitre',
    'milliliter': 'millilitre',
    'millilitre': 'millilitre',
    'pt': 'pint',
    'pint': 'pint',
    'pints': 'pint',
    'qt': 'quart',
    'quart': 'quart',
    'quarts': 'quart'
}

# Canonical volume units accepted in the output.
valid_units = {'centilitre', 'cubic foot', 'cubic inch', 'cup', 'decilitre', 'fluid ounce', 'gallon', 'imperial gallon', 'litre', 'microlitre', 'millilitre', 'pint', 'quart'}


def format_prediction(prediction):
    """Normalize a raw volume answer to ``"<number> <canonical unit>"``.

    Parenthesised text is dropped and thousands separators are removed. The
    rest of the text after the number (possibly several words, e.g.
    "fl. oz.") is cleaned and matched by prefix against ``unit_map``. A bare
    number defaults to fluid ounces. Returns ``''`` for empty input, when no
    leading number is found, or when no unit prefix matches.
    """
    if not prediction:
        return ''
    prediction = re.sub(r'\([^)]*\)', '', prediction)
    # "1,000 ml" -> "1000 ml"; otherwise the number would stop at "1".
    prediction = re.sub(r'(?<=\d),(?=\d{3}(?!\d))', '', prediction)

    # Number followed by optional unit words (stops at '/').
    match = re.match(r'^([-+]?[0-9]*\.?[0-9]+)\s*([^\s/]+(?:\s+[^\s/]+)*)?', prediction.strip())
    if not match:
        return ''
    value, unit = match.groups()

    # A bare number is assumed to be in fluid ounces.
    if not unit:
        return f"{value} fluid ounce"

    # Map the micro sign (U+00B5) and Greek mu (U+03BC) to 'u' BEFORE removing
    # non-ASCII letters. Otherwise "5 μl" collapses to "l" and is reported as
    # litres instead of microlitres.
    unit = unit.lower().replace('\u00b5', 'u').replace('\u03bc', 'u')
    # Keep only letters and spaces ("fl. oz." -> "fl oz").
    unit = re.sub(r'[^a-z\s]', '', unit)

    for key, full_unit in unit_map.items():
        if unit.startswith(key) and full_unit in valid_units:
            return f"{value} {full_unit}"
    return ''

# Rewrite test_output.csv in place with normalized predictions.
# All rows are read (and the file closed) BEFORE it is reopened for writing:
# open(..., 'w') truncates immediately, so reading and writing the same path in
# one `with` statement would leave only the header behind.
with open('test_output.csv', 'r', newline='', encoding='utf-8', errors='replace') as infile:
    rows = list(csv.DictReader(infile))

with open('test_output.csv', 'w', newline='', encoding='utf-8') as outfile:
    writer = csv.DictWriter(outfile, fieldnames=['index', 'prediction'])
    writer.writeheader()
    for row in rows:
        formatted_prediction = format_prediction(row['prediction'])
        writer.writerow({'index': row['index'], 'prediction': formatted_prediction})


In [None]:
# -------- POST PROCESSING FOR HEIGHT -----------------

import csv
import re

# Unit spellings (compared after lower-casing) mapped to the canonical length
# unit names used in the output.
unit_map = {
    'cm': 'centimetre',
    'centimeter': 'centimetre',
    'centimeters': 'centimetre',
    'centimetre': 'centimetre',
    'centimetres': 'centimetre',
    'ft': 'foot',
    'feet': 'foot',
    'foot': 'foot',
    '"': 'inch',
    'in': 'inch',
    'inch': 'inch',
    'inches': 'inch',
    'm': 'metre',
    'meter': 'metre',
    'meters': 'metre',
    'metre': 'metre',
    'metres': 'metre',
    'mm': 'millimetre',
    'millimeter': 'millimetre',
    'millimeters': 'millimetre',
    'millimetre': 'millimetre',
    'millimetres': 'millimetre',
    'yd': 'yard',
    'yard': 'yard',
    'yards': 'yard'
}

# Canonical length units accepted in the output.
valid_units = {'centimetre', 'foot', 'inch', 'metre', 'millimetre', 'yard'}


def format_prediction(prediction):
    """Normalize a raw height answer to ``"<number> <canonical unit>"``.

    Parenthesised text is dropped and thousands separators are removed. An
    inch mark right after the number ('12"') is recognized. Otherwise trailing
    punctuation is stripped and the whole remainder after the number must be a
    known unit. If the text does not start with a number (e.g. "x38 mm"), the
    first number found is used. A bare non-negative number defaults to
    centimetres. Returns ``''`` for empty input or an unrecognized unit.
    """
    if not prediction:
        return ''

    # Remove content within parentheses.
    prediction = re.sub(r'\([^)]*\)', '', prediction)
    # "1,200 mm" -> "1200 mm"; otherwise the number would stop at "1".
    prediction = re.sub(r'(?<=\d),(?=\d{3}(?!\d))', '', prediction)

    # Recognize an inch mark BEFORE stripping trailing punctuation below.
    # That strip removes the '"', and '12"' would otherwise fall through to
    # the centimetre default.
    inch = re.match(r'^([-+]?[0-9]*\.?[0-9]+)\s*"[^a-zA-Z0-9]*$', prediction.strip())
    if inch:
        return f"{inch.group(1)} inch"

    # Remove any non-digit, non-letter characters from the end.
    prediction = re.sub(r'[^a-zA-Z0-9]+$', '', prediction)

    # Number followed by the unit, with or without a space.
    match = re.match(r'^([-+]?[0-9]*\.?[0-9]+)\s*(.*)$', prediction.strip())
    if not match:
        # Text does not start with a number (e.g. "x38 mm"): use the first one.
        match = re.search(r'(\d+(?:\.\d+)?)\s*(.+)', prediction)
        if not match:
            return ''
    value, unit = match.groups()

    unit = unit.lower()

    full_unit = unit_map.get(unit)
    if full_unit in valid_units:
        return f"{value} {full_unit}"

    # No unit at all: assume centimetre (signed numbers are rejected).
    if unit == '' and value.replace('.', '').isdigit():
        return f"{value} centimetre"

    return ''

# Dimension separator: 'x' or '×'. An 'x' preceded by a letter other than
# m/n/t (the last letters of cm/mm/m, in, ft) is treated as part of a word,
# so "approx", "max" and "box" are not split.
_DIMENSION_SPLIT = re.compile(r'(?<![a-lo-su-z])[x×]')

# Rewrite test_output.csv in place with normalized predictions.
# All rows are read (and the file closed) BEFORE it is reopened for writing:
# open(..., 'w') truncates immediately, so reading and writing the same path in
# one `with` statement would leave only the header behind.
with open('test_output.csv', 'r', newline='', encoding='utf-8', errors='replace') as infile:
    rows = list(csv.DictReader(infile))

with open('test_output.csv', 'w', newline='', encoding='utf-8') as outfile:
    writer = csv.DictWriter(outfile, fieldnames=['index', 'prediction'])
    writer.writeheader()

    for row in rows:
        original_prediction = row['prediction']
        parts = _DIMENSION_SPLIT.split(original_prediction.lower())

        if len(parts) > 1:
            # Dimension string such as "10 x 20 x 30 cm": use the first value.
            # If it carries no unit, borrow the trailing unit of the last part,
            # so "10 x 20 inches" is not defaulted to centimetres.
            first = parts[0].strip()
            if not re.search(r'[a-z"]', first):
                unit_match = re.search(r'([a-z"]+)[^a-z"]*$', parts[-1])
                if unit_match:
                    first = f"{first} {unit_match.group(1)}"
            formatted_prediction = format_prediction(first)
        elif '-' in original_prediction:
            # Range such as "10-12 cm": use the last part.
            parts = original_prediction.split('-')
            formatted_prediction = format_prediction(parts[-1])
        else:
            formatted_prediction = format_prediction(original_prediction)

        writer.writerow({'index': row['index'], 'prediction': formatted_prediction})


In [None]:
# ------------- FOR MERGING ANSWER FILE WITH NEW FORMATTED FILE ------------

import pandas as pd

# Read both CSV files, keyed by the 'index' column.
df1 = pd.read_csv('test_output_final.csv', index_col='index')  # running final answers
df2 = pd.read_csv('test_output.csv', index_col='index')        # freshly formatted answers

# Give priority to df2: its non-null predictions overwrite df1, and df1 only
# fills rows where df2 has nothing. (df1.combine_first(df2) would do the
# opposite and keep any existing df1 value.) Empty CSV cells are read as NaN,
# so blank predictions in df2 never erase values already in df1.
combined_df = df2.combine_first(df1)

# Keep exactly file1's rows, in file1's order. Selecting by label (reindex),
# rather than truncating by position, guarantees that no row of df1 is dropped
# and no index present only in df2 is added.
final_df = (
    combined_df.reindex(df1.index)[['prediction']]
    .fillna('')
    .reset_index()
)

# Save the result back to the final file.
final_df.to_csv('test_output_final.csv', index=False)