In [1]:
# Cell 1: Import Required Libraries
import inspect
import os
import random
import re

import cv2  # For image processing
import matplotlib.pyplot as plt  # For visualization
import numpy as np
import pandas as pd
import pytesseract  # For OCR
import torch  # For BERT and PyTorch
from PIL import Image  # For image processing
from sklearn.metrics import classification_report, accuracy_score  # Import accuracy_score
from sklearn.model_selection import train_test_split
from transformers import BertTokenizer, BertForTokenClassification  # For BERT model
from transformers import EarlyStoppingCallback  # For callbacks
from transformers import Trainer, TrainingArguments  # For training the model

from src.constants import entity_unit_map, allowed_units  # Importing the unit map
from src.utils import download_images  # Assuming this function is defined in utils.py

# Set random seeds for reproducibility. NumPy is seeded as well because
# DataFrame.sample (used to pick images to display) draws from NumPy's global
# RNG when no random_state is passed.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

ModuleNotFoundError: No module named 'constants'

In [None]:

# Cell 2: Load and Clean the Dataset
# Load the training and test datasets
train_df = pd.read_csv('dataset/train.csv')
test_df = pd.read_csv('dataset/test.csv')

# Display total null values in the dataset
print("Total null values in training dataset:")
print(train_df.isnull().sum())

# Replace missing entity_value entries with '' so tokenization always gets strings.
# Assign the result back rather than calling fillna(..., inplace=True) on the
# column selection: that chained in-place pattern is deprecated and, under pandas
# Copy-on-Write, leaves train_df unchanged.
train_df['entity_value'] = train_df['entity_value'].fillna('')

In [None]:
# Cell 3: Download Images
# Fetch every image referenced by each split into that split's own folder
# (train first, then test).
for split_df, target_folder in (
    (train_df, 'dataset/train_images'),
    (test_df, 'dataset/test_images'),
):
    download_images(split_df['image_link'].tolist(), download_folder=target_folder)

# Cell 4: Display Random Images with Entity Name and Value
# Display 5 random images with their entity names and values
def display_random_images(df, num_images=5, image_dir='dataset/train_images', random_state=None):
    """Show a random sample of product images titled "<entity_name>: <entity_value>".

    Args:
        df: frame with 'image_link', 'entity_name' and 'entity_value' columns.
        num_images: how many images to show; capped at len(df) so small frames
            don't make DataFrame.sample raise.
        image_dir: folder holding the downloaded images. Assumes each file is
            stored under the last path segment of its image_link (TODO confirm
            against download_images).
        random_state: forwarded to DataFrame.sample for a reproducible pick.
    """
    num_images = min(num_images, len(df))
    if num_images == 0:
        return  # nothing to draw; plt.subplots(1, 0) would raise

    samples = df.sample(num_images, random_state=random_state)
    # squeeze=False keeps `axes` 2-D even when num_images == 1.
    fig, axes = plt.subplots(1, num_images, figsize=(15, 10), squeeze=False)
    for ax, (_, row) in zip(axes[0], samples.iterrows()):
        img_path = os.path.join(image_dir, row['image_link'].split('/')[-1])
        ax.imshow(plt.imread(img_path))
        ax.set_title(f"{row['entity_name']}: {row['entity_value']}")
        ax.axis('off')

    plt.show()

display_random_images(train_df, num_images=5)

In [None]:
# Cell 5: Group Data by Entity Name
# Group the training data by entity_name
grouped_train_df = train_df.groupby('entity_name')

# Report how many training rows fall under each entity_name (optional)
for entity, n_rows in grouped_train_df.size().items():
    print(f"Entity Name: {entity}, Number of Samples: {n_rows}")

In [None]:
# Cell 6: Prepare Data for Model Training
# Lookup table: canonical unit name -> textual forms (abbreviations, plurals,
# alternate spellings) of that unit. Consumed by create_labels and
# find_units_in_text.
unit_abbreviations = dict(
    # Mass: 'item_weight' and 'maximum_weight_recommendation'
    gram=['g', 'gr', 'gm', 'grams', 'grm'],
    kilogram=['kg', 'kilograms', 'kgs'],
    milligram=['mg', 'milligrams', 'mgs'],
    microgram=['µg', 'mcg', 'micrograms'],
    ounce=['oz', 'ounces', 'ozs'],
    pound=['lb', 'lbs', 'pounds'],
    ton=['t', 'tons', 'tonne', 'tonnes'],

    # Volume: 'item_volume'
    millilitre=['ml', 'milliliters', 'millilitres'],
    litre=['l', 'lit', 'liters', 'litres'],
    cubic_centimetre=['cc', 'cm³', 'cubic centimeters', 'cubic centimetres'],
    cubic_metre=['m³', 'cubic meters', 'cubic metres'],
    gallon=['gal', 'gallons'],
    quart=['qt', 'quarts'],
    pint=['pt', 'pints'],
    cup=['c', 'cups'],

    # Electric potential: 'voltage'
    volt=['v', 'volts'],
    kilovolt=['kv', 'kilovolts'],
    millivolt=['mv', 'millivolts'],

    # Power: 'wattage'
    watt=['w', 'watts'],
    kilowatt=['kw', 'kilowatts'],
    megawatt=['mw', 'megawatts'],
    gigawatt=['gw', 'gigawatts'],

    # Length: 'height', 'depth' and 'width'
    millimetre=['mm', 'millimeters', 'millimetres'],
    centimetre=['cm', 'centimeters', 'centimetres'],
    metre=['m', 'meters', 'metres'],
    kilometre=['km', 'kilometers', 'kilometres'],
    inch=['in', 'inches'],
    foot=['ft', 'feet'],
    yard=['yd', 'yards'],
    mile=['mi', 'miles'],

    # Other common units
    degree_celsius=['°C', 'C', 'degrees Celsius'],
    degree_fahrenheit=['°F', 'F', 'degrees Fahrenheit'],
    calorie=['cal', 'calories'],
    kilocalorie=['kcal', 'kcals'],
    joule=['j', 'joules'],
    pascal=['Pa', 'pascals'],
    bar=['bar', 'bars'],
    psi=['psi', 'pounds per square inch'],
    newton=['N', 'newtons'],
    fluid_ounce=['fl oz', 'fluid ounces'],
)

# Load the BERT tokenizer once; create_labels and encode_data both use this instance.
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path='bert-base-uncased')

def _unit_tokens_for_entity(entity_name):
    """Return the lower-cased words naming any unit allowed for ``entity_name``.

    For each unit listed for the entity in ``entity_unit_map``, collect the unit
    name itself plus its forms from ``unit_abbreviations`` (whose keys use
    underscores, e.g. 'fluid_ounce'). Every form is split on whitespace so that
    multi-word forms such as 'fl oz' contribute each word.
    """
    # NOTE(review): assumes entity_unit_map is a dict mapping an entity name to an
    # iterable of unit-name strings (possibly space-separated); confirm against
    # src/constants.py.
    words = set()
    for unit in entity_unit_map.get(entity_name, ()):
        surface_forms = unit_abbreviations.get(
            unit.replace(' ', '_'), unit_abbreviations.get(unit, [])
        )
        for form in [unit, *surface_forms]:
            words.update(form.replace('_', ' ').lower().split())
    return words


def create_labels(df):
    """Build one fixed-length 0/1 token-label list per row of ``df``.

    Each row's ``entity_value`` is tokenized with the module-level ``tokenizer``.
    A token is labelled 1 when it names a unit allowed for the row's
    ``entity_name``; otherwise it is labelled 0. Position 0 is reserved for the
    [CLS] token that ``tokenizer(...)`` prepends in encode_data, so label i lines
    up with input_ids[i].

    Args:
        df: frame with ``entity_name`` and ``entity_value`` columns.

    Returns:
        A list with one entry per row. Each entry is a list[int] right-padded
        with 0 (or truncated) to ``tokenizer.model_max_length``.
    """
    max_length = tokenizer.model_max_length
    # Maps entity_name -> allowed unit words, so each entity's set is built once.
    unit_words_by_entity = {}
    labels = []
    for _, row in df.iterrows():
        entity_name = row['entity_name']
        if entity_name not in unit_words_by_entity:
            unit_words_by_entity[entity_name] = _unit_tokens_for_entity(entity_name)
        unit_words = unit_words_by_entity[entity_name]

        entity_value = row['entity_value']
        # A NaN value would make the tokenizer raise; treat it as empty text.
        text = '' if pd.isna(entity_value) else str(entity_value)

        token_labels = []
        for token in tokenizer.tokenize(text):
            # Strip the WordPiece continuation marker so '##gram' still matches 'gram'.
            word = token[2:] if token.startswith('##') else token
            token_labels.append(1 if word.lower() in unit_words else 0)

        # Prepend the [CLS] slot, then truncate / right-pad with 0 to the fixed length.
        label = ([0] + token_labels)[:max_length]
        label += [0] * (max_length - len(label))
        labels.append(label)

    return labels

# Generate the training labels
train_labels = create_labels(train_df)

# Prepare the encodings for training
def encode_data(df):
    """Tokenize ``df['entity_value']`` into padded PyTorch tensors.

    Missing values become '' and everything is cast to str, because the
    tokenizer rejects non-string inputs such as NaN floats (e.g. when a frame
    was not cleaned beforehand). Sequences are padded to the longest one in
    ``df`` and truncated to the tokenizer's maximum length.
    """
    texts = df['entity_value'].fillna('').astype(str).tolist()
    return tokenizer(texts, padding=True, truncation=True, return_tensors='pt')

train_encodings = encode_data(train_df)

In [None]:
# Cell 7: Define the Dataset Class
from torch.utils.data import Dataset

class EntityDataset(Dataset):
    """Map-style dataset pairing tokenizer encodings with per-token labels.

    Args:
        encodings: mapping of field name -> per-example indexable values. It must
            include 'input_ids', e.g. the BatchEncoding returned by encode_data.
        labels: one list of integer token labels per example.
        pad_label: filler used when a label list is shorter than its example's
            input_ids. Defaults to 0, the same filler create_labels pads with.
    """

    def __init__(self, encodings, labels, pad_label=0):
        self.encodings = encodings
        self.labels = labels
        self.pad_label = pad_label

    def __getitem__(self, idx):
        item = {key: val[idx] for key, val in self.encodings.items()}
        # create_labels pads every label list to tokenizer.model_max_length, but
        # encode_data pads inputs only to the longest example. Trim or pad the
        # labels to this example's sequence length so logits and labels line up
        # in the token-classification loss.
        seq_len = len(item['input_ids'])
        label = list(self.labels[idx][:seq_len])
        label += [self.pad_label] * (seq_len - len(label))
        item['labels'] = torch.tensor(label, dtype=torch.long)
        return item

    def __len__(self):
        return len(self.labels)

# Wrap the training encodings and their token labels for the Trainer.
train_dataset = EntityDataset(encodings=train_encodings, labels=train_labels)


In [None]:
# Cell 8: Define the Model
# Load BERT for token classification. create_labels only emits 0 (non-unit) and
# 1 (unit), so the classification head needs exactly two outputs.
model = BertForTokenClassification.from_pretrained('bert-base-uncased', num_labels=2)

# Hold out 10% of the training rows. Per-epoch evaluation (and therefore early
# stopping) needs an eval_dataset to score.
train_indices, val_indices = train_test_split(
    np.arange(len(train_dataset)), test_size=0.1, random_state=42
)
train_split = torch.utils.data.Subset(train_dataset, train_indices.tolist())
val_split = torch.utils.data.Subset(train_dataset, val_indices.tolist())

# The `evaluation_strategy` argument is called `eval_strategy` in newer
# transformers releases; use whichever keyword the installed version accepts.
_eval_strategy_kw = (
    'eval_strategy'
    if 'eval_strategy' in inspect.signature(TrainingArguments).parameters
    else 'evaluation_strategy'
)

# Define training arguments
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=16,
    save_strategy="epoch",            # must match the eval strategy for load_best_model_at_end
    save_total_limit=2,
    load_best_model_at_end=True,      # required by EarlyStoppingCallback
    metric_for_best_model="eval_loss",
    greater_is_better=False,          # lower validation loss is better
    **{_eval_strategy_kw: "epoch"},   # Evaluate at the end of each epoch
)

# Create a Trainer instance with EarlyStoppingCallback
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_split,
    eval_dataset=val_split,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],  # Stop training if no improvement
)


In [None]:
# Cell 9: Train the Model
trainer.train()

# Save the fine-tuned weights together with the tokenizer. The Trainer was
# built without a tokenizer, so save_model alone writes no tokenizer files and
# ./saved_model could not be reloaded for inference by itself.
trainer.save_model('./saved_model')
tokenizer.save_pretrained('./saved_model')

In [None]:
# Cell 10: Evaluate the Model
# Evaluate the model on the training set. NOTE: the model was fit on these rows,
# so the scores are optimistic.
train_predictions = trainer.predict(train_dataset)
train_preds = np.argmax(train_predictions.predictions, axis=2)

# Predictions are per token, so score individual token positions. Flatten to 1-D
# and keep only real tokens: -100 marks positions the Trainer ignores or pads,
# and attention_mask == 0 marks padding added by encode_data.
train_label_ids = train_predictions.label_ids
valid_positions = train_label_ids != -100
attention_mask = np.asarray(train_encodings['attention_mask']).astype(bool)
if attention_mask.shape == valid_positions.shape:
    valid_positions &= attention_mask
y_true = train_label_ids[valid_positions]
y_pred = train_preds[valid_positions]

# Generate classification report
print("Classification Report on Training Set:")
print(classification_report(y_true, y_pred, zero_division=0))

# Calculate and print accuracy
accuracy = accuracy_score(y_true, y_pred)
print(f"Accuracy on Training Set: {accuracy:.4f}")

In [None]:
# Cell 11: Make Predictions on Test Set
# Tokenize the test rows with the same helper used for training.
test_encodings = encode_data(test_df)

# EntityDataset needs a labels list (for __len__/__getitem__). These are
# placeholders built the same way as the training labels; real test labels
# still need to be defined.
test_labels = create_labels(test_df)
test_dataset = EntityDataset(test_encodings, test_labels)

# Run inference and keep the highest-scoring class at every token position.
test_predictions = trainer.predict(test_dataset)
test_logits = test_predictions.predictions
test_preds = np.argmax(test_logits, axis=2)

# Integrate unit extraction into the prediction process
def extract_text_from_image(image_path):
    """Run Tesseract OCR on the image at ``image_path`` and return the raw text.

    The image is opened in a ``with`` block so its file handle is released as
    soon as OCR finishes. Image.open is lazy and would otherwise keep the file
    open until the object is garbage-collected, which adds up when every test
    image is processed.
    """
    with Image.open(image_path) as img:
        return pytesseract.image_to_string(img)

def find_units_in_text(text, unit_abbreviations, ignore_case=False):
    """Find which units are mentioned in ``text``.

    An abbreviation only counts when it is not directly preceded or followed by
    a letter. So 'g' matches in '500g' or '500 g' but not inside 'weight'.
    Plain substring checks let one-letter forms ('g', 'm', 'l', 't', 'c') match
    almost any text.

    Args:
        text: OCR output to search; empty or None yields no matches.
        unit_abbreviations: mapping unit name -> list of textual forms, tried in order.
        ignore_case: match case-insensitively (default keeps exact-case matching).

    Returns:
        dict mapping each found unit to the first of its forms that matched.
    """
    found_units = {}
    if not text:
        return found_units
    flags = re.IGNORECASE if ignore_case else 0
    for unit, abbreviations in unit_abbreviations.items():
        for abbreviation in abbreviations:
            # [^\W\d_] is "a letter": a word character that is not a digit or '_'.
            pattern = rf'(?<![^\W\d_]){re.escape(abbreviation)}(?![^\W\d_])'
            if re.search(pattern, text, flags):
                found_units[unit] = abbreviation
                break  # Stop checking once we find a match for this unit
    return found_units

In [None]:
# Cell 12: Create Submission File
def _decode_unit_tokens(input_ids, attention_mask, token_preds):
    """Return the text of the tokens the model tagged as units (label 1).

    Padding positions (attention_mask == 0) and special tokens are skipped. The
    three sequences are compared over their common length.
    """
    ids = np.asarray(input_ids)
    mask = np.asarray(attention_mask)
    preds = np.asarray(token_preds)
    n = min(len(ids), len(mask), len(preds))
    keep = (mask[:n] == 1) & (preds[:n] == 1)
    return tokenizer.decode(ids[:n][keep].tolist(), skip_special_tokens=True)


# Build the submission columns row by row. test_df.iterrows() yields
# (index, row) pairs, so each pair is unpacked before use. The model is trained
# on 0/1 (non-unit/unit) token labels from create_labels, so its per-token
# predictions are decoded back to the tagged text rather than used as an index
# into entity_unit_map.
predictions, extracted_units = [], []
for position, (_, row) in enumerate(test_df.iterrows()):
    predictions.append(
        _decode_unit_tokens(
            test_encodings['input_ids'][position],
            test_encodings['attention_mask'][position],
            test_preds[position],
        )
    )
    image_path = os.path.join('dataset/test_images', row['image_link'].split('/')[-1])
    extracted_units.append(
        find_units_in_text(extract_text_from_image(image_path), unit_abbreviations)
    )

# Create a DataFrame for submission
submission_df = pd.DataFrame({
    'index': test_df['index'],
    'prediction': predictions,
    'extracted_units': extracted_units,
})

# Save the DataFrame to a CSV file
submission_df.to_csv('submission.csv', index=False)

print("Submission file 'submission.csv' created successfully.")


In [None]:


# Function to process a single image
def process_single_image(index):
    """OCR one test image and print its text, detected units and model prediction.

    Args:
        index: positional row index into test_df. Also used to index
            test_encodings and test_preds, which follow test_df's row order.
    """
    row = test_df.iloc[index]
    img_path = os.path.join('dataset/test_images', row['image_link'].split('/')[-1])

    # Extract text from the image
    extracted_text = extract_text_from_image(img_path)

    # Find units in the extracted text
    found_units = find_units_in_text(extracted_text, unit_abbreviations)

    # The model is trained on 0/1 (non-unit/unit) token labels from
    # create_labels. Decode the tokens it tagged as units (ignoring padding and
    # special tokens) instead of using those labels to index entity_unit_map.
    input_ids = np.asarray(test_encodings['input_ids'][index])
    attention_mask = np.asarray(test_encodings['attention_mask'][index])
    token_preds = np.asarray(test_preds[index])
    n = min(len(input_ids), len(attention_mask), len(token_preds))
    keep = (attention_mask[:n] == 1) & (token_preds[:n] == 1)
    prediction = tokenizer.decode(input_ids[:n][keep].tolist(), skip_special_tokens=True)

    # Print the results
    print(f"Image: {row['image_link']}")
    print(f"Extracted Text: {extracted_text}")
    print(f"Found Units: {found_units}")
    print(f"Prediction: {prediction}")

# Example usage for a single image (e.g., the first image in the test set)
process_single_image(0)  # Change the index to process a different image