<a href="https://colab.research.google.com/github/ztide-ad/AmazonMLChallenge/blob/main/AmazonML_challenge_lr.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install pytesseract
!sudo apt-get install tesseract-ocr

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
tesseract-ocr is already the newest version (4.1.1-2.1build1).
0 upgraded, 0 newly installed, 0 to remove and 49 not upgraded.


In [2]:
import multiprocessing
import os
import re
import time
import urllib
import urllib.request
import warnings
from functools import partial
from pathlib import Path
from time import time as timer

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytesseract
import requests
import torch
import torchvision
from PIL import Image
from PIL import Image
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score, classification_report
from sklearn.model_selection import train_test_split
from sklearn.multiclass import OneVsRestClassifier
from tqdm import tqdm

In [3]:
# Mount Google Drive into the Colab runtime; the dataset paths used below all
# live under this mount point.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
# Canonical unit vocabulary, grouped by physical dimension.
# ORDER MATTERS: a unit's position in `allowed_units` is used as its class label
# (see `allowed_units.index(...)` in the OCR loop), so only append new units.
_UNITS_BY_DIMENSION = {
    'length': ('centimetre', 'foot', 'inch', 'metre', 'millimetre', 'yard'),
    'weight': ('gram', 'kilogram', 'microgram', 'milligram', 'ounce', 'pound', 'ton'),
    'voltage': ('kilovolt', 'millivolt', 'volt'),
    'wattage': ('kilowatt', 'watt'),
    'volume': (
        'centilitre', 'cubic foot', 'cubic inch', 'cup', 'decilitre', 'fluid ounce',
        'gallon', 'imperial gallon', 'litre', 'microlitre', 'millilitre', 'pint', 'quart',
    ),
}
allowed_units = [unit for group in _UNITS_BY_DIMENSION.values() for unit in group]

In [5]:
def common_mistake(unit):
    """Map a common spelling variant of a unit name onto its canonical form.

    Candidates are tried in order: the unit as given, the "-ter" spelling
    rewritten to "-tre" (e.g. "meter" -> "metre"), and "feet" rewritten to
    "foot" (e.g. "cubic feet" -> "cubic foot"). The first candidate present in
    ``allowed_units`` is returned; if none match, ``unit`` is returned unchanged
    so the caller can report it.
    """
    for candidate in (unit, unit.replace('ter', 'tre'), unit.replace('feet', 'foot')):
        if candidate in allowed_units:
            return candidate
    return unit


# Compiled once here rather than on every parse_string() call.
# "[<num>, <num>] <unit>", e.g. "[2, 4] gram"
_RANGE_VALUE_PATTERN = re.compile(r'^\[(\d+(\.\d+)?),\s*(\d+(\.\d+)?)\]\s+([a-zA-Z\s]+)$')
# "<num> <unit>", e.g. "-1.5 fluid ounce"
_SINGLE_VALUE_PATTERN = re.compile(r'^-?\d+(\.\d+)?\s+[a-zA-Z\s]+$')


def parse_string(s):
    """Parse an entity value such as "10 gram" or "[2, 4] gram" into (number, unit).

    Args:
        s: The raw value. ``None``, NaN (anything whose ``str()`` is ``'nan'``)
            and blank strings are treated as missing.

    Returns:
        ``(float, str)`` with the unit normalised by ``common_mistake``. A range
        is reduced to its midpoint. Returns ``(None, None)`` for missing values.

    Raises:
        ValueError: if ``s`` matches neither accepted format, or its unit is not
            in ``allowed_units`` after normalisation.
    """
    s_stripped = "" if s is None or str(s) == 'nan' else s.strip()
    if s_stripped == "":
        return None, None

    range_match = _RANGE_VALUE_PATTERN.match(s_stripped)
    if range_match:
        # Ranges are summarised by the average of their two endpoints.
        start, end = float(range_match.group(1)), float(range_match.group(3))
        number = (start + end) / 2
        unit = range_match.group(5)
    else:
        if not _SINGLE_VALUE_PATTERN.match(s_stripped):
            raise ValueError(f"Invalid format in {s}")
        number_text, unit = s_stripped.split(maxsplit=1)
        number = float(number_text)

    unit = common_mistake(unit)
    if unit not in allowed_units:
        raise ValueError(f"Invalid unit [{unit}] found in {s}. Allowed units: {allowed_units}")

    return number, unit

def create_placeholder_image(image_save_path, size=(100, 100), color='black'):
    """Write a solid-colour RGB placeholder image to ``image_save_path``.

    Used as a stand-in when an image could not be downloaded, so a file still
    exists at the expected path. This is best-effort: a failure to create or
    save the image is reported with ``warnings.warn`` rather than propagated,
    so one bad path does not stop a batch download.

    Args:
        image_save_path: Destination file path; the extension selects the format.
        size: ``(width, height)`` in pixels (default 100x100, as before).
        color: Fill colour accepted by ``PIL.Image.new`` (default black).
    """
    try:
        placeholder_image = Image.new('RGB', size, color=color)
        placeholder_image.save(image_save_path)
    except Exception as exc:
        # Previously swallowed silently; surface it so missing images are explainable.
        warnings.warn(f"Could not create placeholder image at {image_save_path}: {exc}")

def download_image(image_link, save_folder, retries=3, delay=3):
    """Download one image into ``save_folder``, named after the URL's last path segment.

    Non-string links (e.g. NaN from a pandas column) are ignored, and a file
    that already exists at the target path is left untouched. Up to
    ``retries`` attempts are made, sleeping ``delay`` seconds between attempts
    (not after the last one). If every attempt fails, a black placeholder image
    is written in its place.

    Args:
        image_link: Image URL.
        save_folder: Destination directory (expected to exist already).
        retries: Maximum number of download attempts.
        delay: Seconds to wait between attempts.
    """
    if not isinstance(image_link, str):
        return

    image_save_path = os.path.join(save_folder, Path(image_link).name)
    if os.path.exists(image_save_path):
        return

    for attempt in range(retries):
        try:
            urllib.request.urlretrieve(image_link, image_save_path)
            return
        except Exception:
            # `except Exception` (not a bare except) so Ctrl-C still interrupts.
            # urlretrieve may leave a truncated file behind; remove it so the
            # exists() check above doesn't mistake it for a finished download.
            try:
                os.remove(image_save_path)
            except OSError:
                pass
            if attempt < retries - 1:
                time.sleep(delay)

    create_placeholder_image(image_save_path)  # black placeholder for invalid links/images

def download_images(image_links, download_folder, allow_multiprocessing=True,
                    num_workers=64, retries=3, delay=3):
    """Download every link in ``image_links`` into ``download_folder``.

    Args:
        image_links: Sized iterable of image URLs (list, pandas Series, ...).
        download_folder: Destination directory; created (with parents) if missing.
        allow_multiprocessing: Download with a process pool instead of sequentially.
        num_workers: Pool size when multiprocessing (default keeps the original 64).
        retries: Attempts per image, forwarded to ``download_image``.
        delay: Seconds between attempts, forwarded to ``download_image``.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(download_folder, exist_ok=True)

    download_one = partial(download_image, save_folder=download_folder,
                           retries=retries, delay=delay)

    if allow_multiprocessing:
        with multiprocessing.Pool(num_workers) as pool:
            # Drain imap through tqdm so the bar advances as each download finishes.
            for _ in tqdm(pool.imap(download_one, image_links), total=len(image_links)):
                pass
            pool.close()
            pool.join()
    else:
        for image_link in tqdm(image_links, total=len(image_links)):
            download_one(image_link)

In [6]:
def load_data(csv_path, num_samples=1000, random_state=42):
    """Read ``csv_path`` and return a reproducible random sample of its rows.

    Args:
        csv_path: Path to the CSV file.
        num_samples: Number of rows to sample. ``None`` returns every row in
            file order. A value larger than the file is capped at the row count
            (``DataFrame.sample`` would otherwise raise ValueError).
        random_state: Seed for the sampling, so reruns pick the same rows.

    Returns:
        A pandas DataFrame.
    """
    df = pd.read_csv(csv_path)
    if num_samples is None:
        return df
    return df.sample(n=min(num_samples, len(df)), random_state=random_state)

# Work on a fixed 1,000-row sample of the training CSV to keep OCR time manageable.
TRAIN_CSV_PATH = '/content/drive/MyDrive/AmazonML/dataset/train.csv'
train_df = load_data(TRAIN_CSV_PATH, num_samples=1000)

In [7]:
# Fetch the sampled rows' images into the Drive-backed image folder.
image_folder = '/content/drive/MyDrive/AmazonML/dataset/images'
Path(image_folder).mkdir(parents=True, exist_ok=True)
download_images(train_df['image_link'], image_folder)

100%|██████████| 1000/1000 [00:01<00:00, 868.90it/s]


In [8]:
# Parallel accumulators: OCR documents and their unit-label indices.
texts, labels = [], []

In [9]:
def perform_ocr(image_path):
    """Return the text Tesseract recognises in the image at ``image_path``.

    The image is converted to RGB before OCR. It is opened in a ``with`` block
    so the underlying file handle is closed as soon as the pixels are loaded,
    instead of waiting for garbage collection while looping over many images.
    """
    with Image.open(image_path) as image:
        rgb_image = image.convert('RGB')
    return pytesseract.image_to_string(rgb_image)

In [None]:
# Build one "<OCR text> <entity name>" document per downloaded image, labelled with
# the index of its unit in `allowed_units` (-1 when the unit cannot be determined).
# The lists are (re)initialised here so re-running this cell doesn't append duplicates.
texts = []
labels = []
for _, row in tqdm(train_df.iterrows(), total=len(train_df)):
    image_link = row['image_link']
    if not isinstance(image_link, str):  # missing link: nothing was downloaded
        continue
    # Same file name download_image() saved the image under.
    img_path = os.path.join(image_folder, Path(image_link).name)
    if not os.path.exists(img_path):
        continue

    ocr_text = perform_ocr(img_path)

    # parse_string handles multi-word units ("fluid ounce", "cubic foot"),
    # spelling variants ("meter" -> "metre"), ranges and missing values.
    # The previous split()[-1] labelled "fluid ounce" as "ounce" and crashed on NaN.
    try:
        _, unit = parse_string(row['entity_value'])
    except ValueError:
        unit = None

    # Append both together so texts and labels always stay aligned.
    texts.append(f"{ocr_text} {row['entity_name']}")
    labels.append(allowed_units.index(unit) if unit in allowed_units else -1)

 40%|████      | 404/1000 [08:25<14:49,  1.49s/it]

In [None]:
# Keep only samples whose unit was recognised (label -1 means unknown).
# texts[i] and labels[i] must describe the same image; if the lists ever diverge,
# zip() would silently pair documents with the wrong labels, so fail loudly.
assert len(texts) == len(labels), (
    f"texts ({len(texts)}) and labels ({len(labels)}) are out of sync; re-run the OCR cell"
)
labeled_pairs = [(text, label) for text, label in zip(texts, labels) if label != -1]
X = [text for text, _ in labeled_pairs]
y = [label for _, label in labeled_pairs]

In [None]:
# Hold out 20% of the labelled samples for evaluation (fixed seed for reproducibility).
split_parts = train_test_split(X, y, test_size=0.2, random_state=42)
X_train, X_test, y_train, y_test = split_parts

In [None]:
# Learn the TF-IDF vocabulary on the training split only, then apply it to the test split.
vectorizer = TfidfVectorizer(max_features=5000)
X_train_tfidf, X_test_tfidf = vectorizer.fit_transform(X_train), vectorizer.transform(X_test)

In [None]:
# One-vs-rest logistic regression over the unit labels.
# `LogisticRegression(multi_class='ovr')` is deprecated since scikit-learn 1.5 and
# scheduled for removal; wrapping the estimator in OneVsRestClassifier is the
# documented replacement and fits one binary classifier per unit.
model = OneVsRestClassifier(LogisticRegression(max_iter=2500))
model.fit(X_train_tfidf, y_train)

In [None]:
# Predict on both splits (train predictions are used to gauge overfitting).
y_train_pred, y_test_pred = (model.predict(features) for features in (X_train_tfidf, X_test_tfidf))

In [None]:
# Support-weighted F1 on each split.
train_f1, test_f1 = (
    f1_score(y_true, y_pred, average='weighted')
    for y_true, y_pred in ((y_train, y_train_pred), (y_test, y_test_pred))
)

for split_name, score in (("Train", train_f1), ("Test", test_f1)):
    print(f"{split_name} F1 score: {score:.4f}")

In [None]:
print("Classification Report for Test Data:")
# Only some of the 31 units occur in a small test split. Passing all of
# `allowed_units` as target_names raises a size-mismatch ValueError (or mislabels
# rows if the counts happen to agree), so report exactly the label indices present
# and name each one by its own index.
report_labels = sorted(set(y_test) | set(y_test_pred))
print(classification_report(
    y_test,
    y_test_pred,
    labels=report_labels,
    target_names=[allowed_units[i] for i in report_labels],
    zero_division=0,
))

In [None]:
# Function to predict unit for a new image
def predict_unit(image_path, entity_name):
    """Predict the unit name for ``entity_name`` from the text in an image.

    Builds the same "<OCR text> <entity name>" document the training loop uses,
    vectorises it with the fitted ``vectorizer``, and maps the label index
    returned by ``model`` back to its name in ``allowed_units``.
    """
    document = f"{perform_ocr(image_path)} {entity_name}"
    label_index = model.predict(vectorizer.transform([document]))[0]
    return allowed_units[label_index]

In [None]:
# Test the model on a sample image
# Smoke-test the trained pipeline on a single image from the images folder.
sample_entity_name = 'item_weight'
sample_image_path = '/content/drive/MyDrive/AmazonML/dataset/images/7185b+0uzML.jpg'
predicted_unit = predict_unit(sample_image_path, sample_entity_name)
print(f"Predicted unit for {sample_entity_name}: {predicted_unit}")