<a href="https://colab.research.google.com/github/ztide-ad/AmazonMLChallenge/blob/main/AmazonML_challenge_adv.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install pytesseract
!sudo apt-get install tesseract-ocr

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
tesseract-ocr is already the newest version (4.1.1-2.1build1).
0 upgraded, 0 newly installed, 0 to remove and 49 not upgraded.


In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import pytesseract
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from sklearn.model_selection import train_test_split
from tqdm import tqdm

In [None]:
import re
import requests
import multiprocessing
import time
from time import time as timer
from pathlib import Path
from functools import partial
import urllib
from PIL import Image

In [None]:
# Mount Google Drive: the training CSV, the downloaded images and the saved model
# checkpoint used later in this notebook all live under /content/drive/MyDrive.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Valid units for each entity type; parse_string validates parsed units against the
# flattened `allowed_units` below. Every entity gets its own (distinct) set object.
entity_unit_map = dict(
    width={'centimetre', 'foot', 'inch', 'metre', 'millimetre', 'yard'},
    depth={'centimetre', 'foot', 'inch', 'metre', 'millimetre', 'yard'},
    height={'centimetre', 'foot', 'inch', 'metre', 'millimetre', 'yard'},
    item_weight={'gram', 'kilogram', 'microgram', 'milligram', 'ounce', 'pound', 'ton'},
    maximum_weight_recommendation={
        'gram', 'kilogram', 'microgram', 'milligram', 'ounce', 'pound', 'ton',
    },
    voltage={'kilovolt', 'millivolt', 'volt'},
    wattage={'kilowatt', 'watt'},
    item_volume={
        'centilitre', 'cubic foot', 'cubic inch', 'cup', 'decilitre', 'fluid ounce',
        'gallon', 'imperial gallon', 'litre', 'microlitre', 'millilitre', 'pint', 'quart',
    },
)

# Every unit that is valid for at least one entity type.
allowed_units = set().union(*entity_unit_map.values())

In [None]:
# Label spellings that differ from the canonical names in allowed_units ("meter" -> "metre").
_UNIT_SPELLING_FIXES = (('ter', 'tre'), ('feet', 'foot'))

# Accepted entity_value formats: "[low, high] unit" ranges and "number unit" single values.
_RANGE_VALUE_PATTERN = re.compile(r'^\[(\d+(\.\d+)?),\s*(\d+(\.\d+)?)\]\s+([a-zA-Z\s]+)$')
_SINGLE_VALUE_PATTERN = re.compile(r'^-?\d+(\.\d+)?\s+[a-zA-Z\s]+$')


def common_mistake(unit, allowed=None):
    """Return the canonical spelling of ``unit`` if it can be matched to an allowed unit.

    Candidates are ``unit`` as given, then a whitespace-collapsed, lower-cased copy.
    For each candidate it tries the exact string, then the "-tre" spelling
    ("meter" -> "metre"), then "feet" -> "foot". The first hit is returned; if
    nothing matches, ``unit`` is returned unchanged so the caller can report the
    original text.

    Args:
        unit: Raw unit text, e.g. ``"centimeter"``.
        allowed: Collection of valid unit names. Defaults to the module-level
            ``allowed_units``.
    """
    if allowed is None:
        allowed = allowed_units
    normalized = ' '.join(unit.split()).lower()
    for candidate in (unit, normalized):
        if candidate in allowed:
            return candidate
        for wrong, right in _UNIT_SPELLING_FIXES:
            fixed = candidate.replace(wrong, right)
            if fixed in allowed:
                return fixed
    return unit


def parse_string(s, allowed=None):
    """Parse an ``entity_value`` into ``(number, unit)``.

    Accepts ``"<number> <unit>"`` (e.g. ``"12.5 gram"``) and ``"[<low>, <high>] <unit>"``
    ranges; a range is reduced to its midpoint. The unit is normalised with
    ``common_mistake`` before being validated.

    Args:
        s: Raw value. ``None``, NaN and blank strings mean "no value".
            Non-string values are converted with ``str`` first, so bad input always
            surfaces as ``ValueError`` (which the training loop catches).
        allowed: Collection of valid unit names. Defaults to ``allowed_units``.

    Returns:
        ``(float, str)``, or ``(None, None)`` when there is no value.

    Raises:
        ValueError: If the text matches neither format or the unit is not allowed.
    """
    if allowed is None:
        allowed = allowed_units
    s_stripped = "" if s is None or str(s) == 'nan' else str(s).strip()
    if s_stripped == "":
        return None, None

    range_match = _RANGE_VALUE_PATTERN.match(s_stripped)
    if range_match:
        # Ranges are collapsed to their midpoint.
        start, end = float(range_match.group(1)), float(range_match.group(3))
        number = (start + end) / 2
        unit = range_match.group(5)
    else:
        if not _SINGLE_VALUE_PATTERN.match(s_stripped):
            raise ValueError(f"Invalid format in {s}")
        number_text, unit = s_stripped.split(maxsplit=1)
        number = float(number_text)

    unit = common_mistake(unit, allowed)
    if unit not in allowed:
        raise ValueError(f"Invalid unit [{unit}] found in {s}. Allowed units: {allowed}")

    return number, unit

def create_placeholder_image(image_save_path, size=(100, 100), color='black'):
    """Write a solid-colour RGB image to ``image_save_path``.

    Used by ``download_image`` as a stand-in when an image cannot be fetched, so
    every row still has a file on disk. Best-effort: any failure is reported as a
    warning instead of being raised, so one bad path cannot abort a bulk download.

    Args:
        image_save_path: Destination file path.
        size: ``(width, height)`` of the placeholder in pixels.
        color: Fill colour passed to ``PIL.Image.new``.
    """
    import warnings

    try:
        Image.new('RGB', size, color=color).save(image_save_path)
    except Exception as exc:
        # Previously swallowed silently; surface it without interrupting the caller.
        warnings.warn(f"Could not create placeholder image at {image_save_path}: {exc}")

def download_image(image_link, save_folder, retries=3, delay=3):
    """Download one image into ``save_folder``, falling back to a black placeholder.

    The file is named after the last path component of ``image_link``. An existing
    file is left untouched, so re-runs skip completed downloads. Each attempt writes
    to a ``.part`` file that is renamed into place only after a successful download,
    so an interrupted run never leaves a truncated image that the existence check
    would later treat as complete.

    Args:
        image_link: Image URL; non-string values (e.g. NaN from pandas) are ignored.
        save_folder: Directory to write into (expected to exist already).
        retries: Number of download attempts before giving up.
        delay: Seconds to wait between failed attempts.
    """
    # `import urllib` alone does not guarantee the `request` submodule is loaded;
    # without this, every attempt could fail with AttributeError and silently
    # produce a placeholder.
    import urllib.request

    if not isinstance(image_link, str):
        return

    filename = Path(image_link).name
    image_save_path = os.path.join(save_folder, filename)
    if os.path.exists(image_save_path):
        return

    partial_path = image_save_path + '.part'
    for attempt in range(retries):
        try:
            urllib.request.urlretrieve(image_link, partial_path)
            os.replace(partial_path, image_save_path)
            return
        except Exception:
            # Network/HTTP/filesystem errors are retried; KeyboardInterrupt still propagates.
            if attempt < retries - 1:
                time.sleep(delay)

    if os.path.exists(partial_path):
        os.remove(partial_path)
    create_placeholder_image(image_save_path)  # Create a black placeholder image for invalid links/images

def download_images(image_links, download_folder, allow_multiprocessing=True,
                    processes=64, retries=3, delay=3):
    """Download every link in ``image_links`` into ``download_folder``.

    Args:
        image_links: Iterable of image URLs with a ``len`` (e.g. a pandas Series).
        download_folder: Target directory; created if missing.
        allow_multiprocessing: Use a process pool when True, otherwise download serially.
        processes: Pool size when ``allow_multiprocessing`` is True.
        retries: Attempts per image, forwarded to ``download_image``.
        delay: Seconds between attempts, forwarded to ``download_image``.
    """
    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(download_folder, exist_ok=True)

    download_one = partial(download_image, save_folder=download_folder, retries=retries, delay=delay)

    if allow_multiprocessing:
        with multiprocessing.Pool(processes) as pool:
            # Results are discarded, so completion order is irrelevant; imap_unordered
            # lets the progress bar advance as soon as any worker finishes.
            for _ in tqdm(pool.imap_unordered(download_one, image_links), total=len(image_links)):
                pass
            pool.close()
            pool.join()
    else:
        for image_link in tqdm(image_links, total=len(image_links)):
            download_one(image_link)

In [None]:
def load_data(csv_path, num_samples=1000, random_state=42):
    """Read ``csv_path`` and return a reproducible random sample of its rows.

    Args:
        csv_path: Path to the CSV file.
        num_samples: Number of rows to sample. ``None``, or a value larger than the
            file, means "every row" (still shuffled) instead of raising.
        random_state: Seed for ``DataFrame.sample`` so reruns select the same rows.

    Returns:
        A DataFrame with ``min(num_samples, len(df))`` rows (all rows for ``None``).
    """
    df = pd.read_csv(csv_path)
    n_rows = len(df) if num_samples is None else min(num_samples, len(df))
    return df.sample(n=n_rows, random_state=random_state)

# Work on a 1,000-row sample of the training CSV stored on Drive.
TRAIN_CSV_PATH = '/content/drive/MyDrive/AmazonML/dataset/train.csv'
train_df = load_data(TRAIN_CSV_PATH, num_samples=1000)

In [None]:
# Product images are stored next to the dataset on Drive, one file per image link.
image_folder = os.path.join('/content/drive/MyDrive/AmazonML/dataset', 'images')
os.makedirs(name=image_folder, exist_ok=True)
download_images(image_links=train_df['image_link'], download_folder=image_folder)

100%|██████████| 1000/1000 [00:00<00:00, 1182.43it/s]


In [None]:
class ProductDataset(Dataset):
    """Map-style dataset yielding ``(image, ocr_text, entity_name, entity_value)`` per row.

    Each row of ``dataframe`` must have ``image_link``, ``entity_name`` and
    ``entity_value`` columns. The image is read from ``image_folder`` under the last
    path component of its link, the same name ``download_image`` saves it under.
    """

    def __init__(self, dataframe, image_folder, transform=None):
        self.dataframe = dataframe
        self.image_folder = image_folder
        self.transform = transform  # optional callable applied to the PIL image

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        # Path(...).name matches download_image's naming, so the two never disagree.
        img_name = os.path.join(self.image_folder, Path(row['image_link']).name)
        with Image.open(img_name) as raw_image:
            image = raw_image.convert('RGB')

        ocr_text = perform_ocr(img_name)

        if self.transform is not None:
            image = self.transform(image)

        entity_value = row['entity_value']
        # A missing value would arrive as a float NaN, which may break batch collation
        # when mixed with strings. parse_string treats '' as "no value", and the
        # training loop then masks that sample out.
        if not isinstance(entity_value, str) and pd.isna(entity_value):
            entity_value = ''

        return image, ocr_text, row['entity_name'], entity_value

In [None]:
# Image preprocessing: resize to 224x224, convert to a tensor, then normalise each
# channel (the mean/std values look like the standard ImageNet statistics).
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [None]:
# Seeded generator so the shuffle order, and therefore each epoch's batches, is reproducible.
dataset = ProductDataset(train_df, image_folder, transform=transform)
dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    generator=torch.Generator().manual_seed(42),
)

In [None]:
def perform_ocr(image_path):
    """Run Tesseract OCR on an image and return the recognised text.

    Args:
        image_path: Path of an image file, or an already-loaded ``PIL.Image.Image``.

    Returns:
        The string produced by ``pytesseract.image_to_string`` on the RGB image.
    """
    if isinstance(image_path, Image.Image):
        return pytesseract.image_to_string(image_path.convert('RGB'))
    # The context manager closes the file handle once the RGB copy has been made.
    with Image.open(image_path) as image:
        rgb_image = image.convert('RGB')
    return pytesseract.image_to_string(rgb_image)

In [None]:
# Text classifier: pretrained BERT with a newly initialised classification head
# producing one logit per allowed unit.
PRETRAINED_MODEL_NAME = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(
    PRETRAINED_MODEL_NAME,
    num_labels=len(allowed_units),
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Put the model on the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Assigning the result keeps the cell from echoing the entire BERT module repr.
model = model.to(device)

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

In [None]:
# Freeze the label vocabulary as a list so each unit has a fixed class index
# (the training loop uses allowed_units.index(unit) as the label).
# Sorted, because iterating a set of strings can yield a different order in each
# Python process (string hashing is randomised), which would silently remap class
# indices between a training session and a later inference session.
allowed_units = sorted(allowed_units)

In [None]:
# Fine-tuning learning rate in the 2e-5..5e-5 range recommended for BERT; the
# previous 5e-1 is orders of magnitude too large and makes training diverge.
LEARNING_RATE = 5e-5
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
num_epochs = 1

In [None]:
def _unit_label_ids(entity_values):
    """Return one class index into ``allowed_units`` per value, or -1 when unusable.

    A value is unusable when ``parse_string`` raises ``ValueError`` or yields a unit
    that is not in ``allowed_units`` (including the ``None`` unit it returns for
    empty values).
    """
    label_ids = []
    for value in entity_values:
        try:
            _, unit = parse_string(value)
        except ValueError:
            label_ids.append(-1)  # unparsable value
            continue
        label_ids.append(allowed_units.index(unit) if unit in allowed_units else -1)
    return label_ids


for epoch in range(num_epochs):
    model.train()
    epoch_loss_total = 0.0
    epoch_steps = 0
    # The image tensors are not used: the classifier only sees OCR text plus the
    # entity name, so they are left on the CPU instead of being copied to `device`.
    for _images, ocr_texts, entity_names, entity_values in tqdm(dataloader):
        # One input string per sample: OCR text followed by the entity name.
        input_texts = [f"{ocr} {name}" for ocr, name in zip(ocr_texts, entity_names)]
        inputs = tokenizer(input_texts, padding=True, truncation=True, return_tensors="pt").to(device)
        labels = torch.tensor(_unit_label_ids(entity_values), device=device)

        # Drop samples whose unit could not be resolved; -1 is not a valid class index.
        valid_samples = labels != -1
        if not valid_samples.any():
            print("No valid samples in this batch")
            continue
        inputs = {k: v[valid_samples] for k, v in inputs.items()}
        labels = labels[valid_samples]

        outputs = model(**inputs, labels=labels)
        loss = outputs.loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss_total += loss.item()
        epoch_steps += 1

    # Report the epoch mean. Printing only the last batch's `loss` was noisy and
    # raised NameError when no batch in the first epoch had a valid sample.
    if epoch_steps:
        print(f"Epoch {epoch+1}/{num_epochs}, Mean loss: {epoch_loss_total / epoch_steps:.4f}")
    else:
        print(f"Epoch {epoch+1}/{num_epochs}, no valid batches")

  0%|          | 0/32 [00:00<?, ?it/s]

In [None]:
# Persist only the fine-tuned weights (state_dict) to Drive.
MODEL_CHECKPOINT_PATH = '/content/drive/MyDrive/AmazonML/model.pth'
torch.save(obj=model.state_dict(), f=MODEL_CHECKPOINT_PATH)

In [None]:
def predict(image_path, entity_name):
    """Predict the unit for ``entity_name`` from the text OCR'd off ``image_path``.

    Uses the notebook-level ``model``, ``tokenizer``, ``device`` and ``allowed_units``.
    The classifier is text-only, so only the OCR text and the entity name are fed to it.

    Args:
        image_path: Path to the product image file.
        entity_name: Entity being predicted, e.g. ``'item_weight'``.

    Returns:
        The predicted unit name, or ``"Unknown unit"`` if the predicted index falls
        outside ``allowed_units``.
    """
    # perform_ocr opens its argument with Image.open, so it needs the file path,
    # not a normalised tensor converted to a numpy array.
    ocr_text = perform_ocr(image_path)
    input_text = f"{ocr_text} {entity_name}"

    inputs = tokenizer(input_text, padding=True, truncation=True, return_tensors="pt").to(device)

    # Evaluation mode disables the model's dropout layers so predictions are deterministic.
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)

    predicted_unit_index = outputs.logits.argmax().item()
    if predicted_unit_index < len(allowed_units):
        return list(allowed_units)[predicted_unit_index]
    return "Unknown unit"

In [None]:
# Smoke-test the trained model on one downloaded product image.
sample_image_path = os.path.join('/content/drive/MyDrive/AmazonML/dataset/images', '81u23a-tF-L.jpg')
sample_entity_name = 'item_weight'
predicted_unit = predict(sample_image_path, sample_entity_name)
print(f"Predicted unit for {sample_entity_name}: {predicted_unit}")