In [6]:
import boto3

# Initialize Textract client with region
textract = boto3.client('textract', region_name='us-east-1')

def extract_text_from_image(image_path):
    with open(image_path, 'rb') as image_file:
        # Call Amazon Textract
        response = textract.detect_document_text(
            Document={'Bytes': image_file.read()}
        )
    
    # Extract text
    detected_text = []
    for item in response["Blocks"]:
        if item["BlockType"] == "LINE":
            detected_text.append(item["Text"])
    
    return " ".join(detected_text)

# Example usage
image_path = "path_to_your_image.jpg"
extracted_text = extract_text_from_image(image_path)
print(extracted_text)

FileNotFoundError: [Errno 2] No such file or directory: 'path_to_your_image.jpg'

In [7]:
import pandas as pd
from utils import download_images

# Load your dataset (assumes 'train.csv' has a column 'image_link')
df = pd.read_csv('dataset/train.csv')

# Extract image URLs from the 'image_link' column
image_links = df['image_link'].tolist()

# Define the folder where you want to download the images
download_folder = "downloaded_images"

# Call the download_images function to download all images
download_images(image_links=image_links, download_folder=download_folder, allow_multiprocessing=True)

ModuleNotFoundError: No module named 'utils'

In [15]:
# main.py

import os
import sys

# Check if '__file__' is defined, otherwise use the current working directory
if '__file__' in globals():
    project_root = os.path.dirname(os.path.abspath(__file__))
else:
    project_root = os.getcwd()  # Use current working directory in interactive environments

sys.path.append(project_root)

import pandas as pd
import numpy as np
from PIL import Image
import torch
from torch import nn
from torchvision import models, transforms

# Now these imports should work
from src.utils import download_images
from src.constants import ALLOWED_UNITS

# Define a custom dataset class
class ProductImageDataset(torch.utils.data.Dataset):
    def __init__(self, csv_file, transform=None):
        self.data = pd.read_csv(csv_file)
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_url = self.data.loc[idx, 'image_link']
        image = download_images([img_url])[0]  # Assuming this returns a PIL Image
        
        if self.transform:
            image = self.transform(image)
        
        # For training data
        if 'entity_value' in self.data.columns:
            label = self.data.loc[idx, 'entity_value']
            return image, label
        
        # For test data
        return image, self.data.loc[idx, 'index']

# Define the model
class EntityExtractionModel(nn.Module):
    def __init__(self, num_classes):
        super(EntityExtractionModel, self).__init__()
        self.resnet = models.resnet50(pretrained=True)
        num_ftrs = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(num_ftrs, num_classes)

    def forward(self, x):
        return self.resnet(x)

# Main training function
def train_model(train_csv, val_csv, num_epochs=10):
    # Set up datasets and dataloaders
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    train_dataset = ProductImageDataset(train_csv, transform=transform)
    val_dataset = ProductImageDataset(val_csv, transform=transform)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False)

    # Initialize model, loss function, and optimizer
    model = EntityExtractionModel(num_classes=len(ALLOWED_UNITS))
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Training loop
    for epoch in range(num_epochs):
        model.train()
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        # Validation
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        print(f'Epoch {epoch+1}/{num_epochs}, '
              f'Train Loss: {loss.item():.4f}, '
              f'Val Loss: {val_loss/len(val_loader):.4f}, '
              f'Val Accuracy: {100 * correct / total:.2f}%')

    return model

# Function to generate predictions
def generate_predictions(model, test_csv, output_csv):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    test_dataset = ProductImageDataset(test_csv, transform=transform)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)

    predictions = []
    model.eval()
    with torch.no_grad():
        for inputs, indices in test_loader:
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            for idx, pred in zip(indices, predicted):
                predictions.append([idx, f"{pred.item()} {ALLOWED_UNITS[pred.item()]}"])

    pd.DataFrame(predictions, columns=['index', 'prediction']).to_csv(output_csv, index=False)

# Main execution
if __name__ == "__main__":
    train_csv = 'dataset/train.csv'
    val_csv = 'dataset/val.csv'  # You might need to create this from train.csv
    test_csv = 'dataset/test.csv'
    output_csv = 'test_out.csv'

    model = train_model(train_csv, val_csv)
    generate_predictions(model, test_csv, output_csv)

    # Run sanity check
    from src.sanity import check_output_format
    check_output_format(output_csv)

ModuleNotFoundError: No module named 'src'

In [15]:
import os
import sys

# Set the directory where 'src' is located
project_root = os.getcwd()  # Or replace this with the path to your project root
src_directory = os.path.join(project_root, 'src')

# Add 'src' directory to sys.path
sys.path.append(src_directory)

import pandas as pd
import numpy as np
from PIL import Image
import torch
from torch import nn
from torchvision import models, transforms

# Now these imports should work
from src.utils import download_images
from src.constants import entity_unit_map

import os
import tempfile
from PIL import Image
from src.utils import download_images

from sklearn.preprocessing import LabelEncoder

class ProductImageDataset(torch.utils.data.Dataset):
    def __init__(self, csv_file, transform=None):
        self.data = pd.read_csv(csv_file)
        self.transform = transform
        self.temp_dir = tempfile.mkdtemp()

        # Initialize LabelEncoder to convert string labels to numerical labels
        if 'entity_value' in self.data.columns:
            self.label_encoder = LabelEncoder()
            self.data['entity_value'] = self.label_encoder.fit_transform(self.data['entity_value'])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_url = self.data.loc[idx, 'image_link']
        download_images([img_url], self.temp_dir)
        filename = os.path.basename(img_url)
        image_path = os.path.join(self.temp_dir, filename)
        image = Image.open(image_path).convert('RGB')  # Ensure image is in RGB mode
        if self.transform:
            image = self.transform(image)

        if 'entity_value' in self.data.columns:
            label = torch.tensor(self.data.loc[idx, 'entity_value'], dtype=torch.long)  # Ensure label is a tensor and long type
            return image, label
        
        return image, self.data.loc[idx, 'index']
# Define the model
class EntityExtractionModel(nn.Module):
    def __init__(self, num_classes):
        super(EntityExtractionModel, self).__init__()
        self.resnet = models.resnet50(pretrained=True)
        num_ftrs = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(num_ftrs, num_classes)

    def forward(self, x):
        return self.resnet(x)

# Main training function
def train_model(train_csv, val_csv, num_epochs=10):
    # Your existing transformations
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    train_dataset = ProductImageDataset(train_csv, transform=transform)
    val_dataset = ProductImageDataset(val_csv, transform=transform)

    num_classes = train_dataset.data['entity_value'].nunique()  # Dynamically set num_classes
    model = EntityExtractionModel(num_classes=num_classes)  # Update the model to use the correct num_classes

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # The rest of your training loop remains the same...

    for epoch in range(num_epochs):
        model.train()
        for inputs, labels in train_loader:
            print(f'Inputs shape: {inputs.shape}, Labels shape: {labels.shape}')  # Debug line
            optimizer.zero_grad()
            outputs = model(inputs)
            labels = labels.long()  # Ensure labels are of type long
            #print(f"Number of unique labels: {self.data['entity_value'].nunique()}")
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                outputs = model(inputs)
                labels = labels.long()  # Ensure labels are of type long
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        print(f'Epoch {epoch+1}/{num_epochs}, '
              f'Train Loss: {loss.item():.4f}, '
              f'Val Loss: {val_loss/len(val_loader):.4f}, '
              f'Val Accuracy: {100 * correct / total:.2f}%')

    return model

# Function to generate predictions
def generate_predictions(model, test_csv, output_csv):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    test_dataset = ProductImageDataset(test_csv, transform=transform)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)

    predictions = []
    model.eval()
    with torch.no_grad():
        for inputs, indices in test_loader:
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            for idx, pred in zip(indices, predicted):
                predictions.append([idx, f"{pred.item()} {ALLOWED_UNITS[pred.item()]}"])

    pd.DataFrame(predictions, columns=['index', 'prediction']).to_csv(output_csv, index=False)

# Main execution
if __name__ == "__main__":
    # Update these paths to use the correct directory structure
    train_csv = os.path.join('dataset', 'train.csv')
    val_csv = os.path.join('dataset', 'test.csv')  # You might need to create this from train.csv
    test_csv = os.path.join('dataset', 'sample_test.csv')
    output_csv = os.path.join('dataset','sample_test_out.csv')

    model = train_model(train_csv, val_csv)
    generate_predictions(model, test_csv, output_csv)

    # Run sanity check
    from src.sanity import check_output_format
    check_output_format(output_csv)

100%|██████████| 1/1 [00:00<00:00,  1.14it/s]
100%|██████████| 1/1 [00:00<00:00,  1.31it/s]
100%|██████████| 1/1 [00:00<00:00,  1.38it/s]
100%|██████████| 1/1 [00:00<00:00,  1.97it/s]
100%|██████████| 1/1 [00:00<00:00,  1.49it/s]
100%|██████████| 1/1 [00:00<00:00,  1.26it/s]
100%|██████████| 1/1 [00:00<00:00,  2.35it/s]
100%|██████████| 1/1 [00:00<00:00,  1.22it/s]
100%|██████████| 1/1 [00:00<00:00,  1.11it/s]
100%|██████████| 1/1 [00:00<00:00,  2.23it/s]
100%|██████████| 1/1 [00:00<00:00,  1.42it/s]
100%|██████████| 1/1 [00:00<00:00,  2.09it/s]
100%|██████████| 1/1 [00:00<00:00,  2.11it/s]
100%|██████████| 1/1 [00:00<00:00,  1.44it/s]
100%|██████████| 1/1 [00:00<00:00,  1.78it/s]
100%|██████████| 1/1 [00:00<00:00,  2.17it/s]
100%|██████████| 1/1 [00:00<00:00,  2.25it/s]
100%|██████████| 1/1 [00:00<00:00,  2.11it/s]
100%|██████████| 1/1 [00:00<00:00,  1.11it/s]
100%|██████████| 1/1 [00:00<00:00,  2.27it/s]
100%|██████████| 1/1 [00:00<00:00,  2.23it/s]
100%|██████████| 1/1 [00:00<00:00,

Inputs shape: torch.Size([32, 3, 224, 224]), Labels shape: torch.Size([32])


100%|██████████| 1/1 [00:00<00:00,  1.02it/s]
100%|██████████| 1/1 [00:00<00:00,  2.12it/s]
100%|██████████| 1/1 [00:00<00:00,  2.13it/s]
100%|██████████| 1/1 [00:00<00:00,  2.03it/s]
100%|██████████| 1/1 [00:00<00:00,  2.41it/s]
100%|██████████| 1/1 [00:00<00:00,  2.27it/s]
100%|██████████| 1/1 [00:00<00:00,  2.26it/s]
100%|██████████| 1/1 [00:00<00:00,  2.31it/s]
100%|██████████| 1/1 [00:00<00:00,  1.36it/s]
100%|██████████| 1/1 [00:00<00:00,  2.15it/s]
100%|██████████| 1/1 [00:00<00:00,  2.08it/s]
100%|██████████| 1/1 [00:00<00:00,  2.48it/s]
100%|██████████| 1/1 [00:00<00:00,  1.70it/s]
100%|██████████| 1/1 [00:00<00:00,  2.31it/s]
100%|██████████| 1/1 [00:00<00:00,  2.38it/s]
100%|██████████| 1/1 [00:00<00:00,  1.93it/s]
100%|██████████| 1/1 [00:00<00:00,  1.84it/s]
100%|██████████| 1/1 [00:00<00:00,  2.39it/s]
100%|██████████| 1/1 [00:00<00:00,  2.04it/s]
100%|██████████| 1/1 [00:00<00:00,  2.27it/s]
100%|██████████| 1/1 [00:00<00:00,  2.24it/s]
100%|██████████| 1/1 [00:00<00:00,

Inputs shape: torch.Size([32, 3, 224, 224]), Labels shape: torch.Size([32])


100%|██████████| 1/1 [00:00<00:00,  1.84it/s]
100%|██████████| 1/1 [00:00<00:00,  2.25it/s]
100%|██████████| 1/1 [00:00<00:00,  2.37it/s]
100%|██████████| 1/1 [00:00<00:00,  2.28it/s]
100%|██████████| 1/1 [00:00<00:00,  2.18it/s]
100%|██████████| 1/1 [00:00<00:00,  2.31it/s]
100%|██████████| 1/1 [00:00<00:00,  2.22it/s]
100%|██████████| 1/1 [00:00<00:00,  2.20it/s]
100%|██████████| 1/1 [00:00<00:00,  2.14it/s]
100%|██████████| 1/1 [00:00<00:00,  2.12it/s]
100%|██████████| 1/1 [00:00<00:00,  2.32it/s]
100%|██████████| 1/1 [00:00<00:00,  2.27it/s]
100%|██████████| 1/1 [00:00<00:00,  2.10it/s]
100%|██████████| 1/1 [00:00<00:00,  1.96it/s]
100%|██████████| 1/1 [00:00<00:00,  2.16it/s]
100%|██████████| 1/1 [00:00<00:00,  2.15it/s]
100%|██████████| 1/1 [00:00<00:00,  2.08it/s]
100%|██████████| 1/1 [00:00<00:00,  2.25it/s]
100%|██████████| 1/1 [00:00<00:00,  2.09it/s]
100%|██████████| 1/1 [00:00<00:00,  2.22it/s]
100%|██████████| 1/1 [00:00<00:00,  2.02it/s]
100%|██████████| 1/1 [00:00<00:00,

Inputs shape: torch.Size([32, 3, 224, 224]), Labels shape: torch.Size([32])


100%|██████████| 1/1 [00:00<00:00,  2.19it/s]
100%|██████████| 1/1 [00:00<00:00,  1.98it/s]
100%|██████████| 1/1 [00:00<00:00,  1.84it/s]
100%|██████████| 1/1 [00:00<00:00,  2.16it/s]
100%|██████████| 1/1 [00:00<00:00,  2.07it/s]
100%|██████████| 1/1 [00:00<00:00,  1.79it/s]
100%|██████████| 1/1 [00:00<00:00,  2.29it/s]
100%|██████████| 1/1 [00:00<00:00,  2.23it/s]
100%|██████████| 1/1 [00:00<00:00,  2.29it/s]
100%|██████████| 1/1 [00:00<00:00,  2.02it/s]
100%|██████████| 1/1 [00:00<00:00,  1.83it/s]
100%|██████████| 1/1 [00:00<00:00,  2.32it/s]
100%|██████████| 1/1 [00:00<00:00,  2.26it/s]
100%|██████████| 1/1 [00:00<00:00,  1.34it/s]
100%|██████████| 1/1 [00:00<00:00,  2.17it/s]
100%|██████████| 1/1 [00:00<00:00,  2.20it/s]
100%|██████████| 1/1 [00:00<00:00,  1.35it/s]
100%|██████████| 1/1 [00:00<00:00,  1.20it/s]
100%|██████████| 1/1 [00:00<00:00,  2.09it/s]
100%|██████████| 1/1 [00:00<00:00,  2.14it/s]
100%|██████████| 1/1 [00:00<00:00,  2.04it/s]
100%|██████████| 1/1 [00:00<00:00,

Inputs shape: torch.Size([32, 3, 224, 224]), Labels shape: torch.Size([32])


100%|██████████| 1/1 [00:00<00:00,  1.55it/s]
100%|██████████| 1/1 [00:00<00:00,  2.20it/s]
100%|██████████| 1/1 [00:00<00:00,  2.41it/s]
100%|██████████| 1/1 [00:00<00:00,  2.13it/s]
100%|██████████| 1/1 [00:00<00:00,  2.28it/s]
100%|██████████| 1/1 [00:02<00:00,  2.11s/it]
100%|██████████| 1/1 [00:00<00:00,  1.34it/s]
100%|██████████| 1/1 [00:00<00:00,  2.18it/s]
100%|██████████| 1/1 [00:00<00:00,  2.27it/s]
100%|██████████| 1/1 [00:00<00:00,  2.40it/s]
100%|██████████| 1/1 [00:00<00:00,  2.02it/s]
100%|██████████| 1/1 [00:00<00:00,  1.91it/s]
100%|██████████| 1/1 [00:00<00:00,  2.41it/s]
100%|██████████| 1/1 [00:00<00:00,  1.13it/s]
100%|██████████| 1/1 [00:00<00:00,  2.42it/s]
100%|██████████| 1/1 [00:00<00:00,  2.37it/s]
100%|██████████| 1/1 [00:00<00:00,  2.49it/s]
100%|██████████| 1/1 [00:00<00:00,  2.36it/s]
100%|██████████| 1/1 [00:00<00:00,  2.23it/s]
100%|██████████| 1/1 [00:00<00:00,  2.29it/s]
100%|██████████| 1/1 [00:00<00:00,  2.06it/s]
100%|██████████| 1/1 [00:00<00:00,

Inputs shape: torch.Size([32, 3, 224, 224]), Labels shape: torch.Size([32])


100%|██████████| 1/1 [00:00<00:00,  2.34it/s]
100%|██████████| 1/1 [00:00<00:00,  2.22it/s]
100%|██████████| 1/1 [00:00<00:00,  2.31it/s]
100%|██████████| 1/1 [00:00<00:00,  2.41it/s]
100%|██████████| 1/1 [00:00<00:00,  2.30it/s]
100%|██████████| 1/1 [00:00<00:00,  1.88it/s]
100%|██████████| 1/1 [00:00<00:00,  2.13it/s]
100%|██████████| 1/1 [00:00<00:00,  2.25it/s]
100%|██████████| 1/1 [00:00<00:00,  2.21it/s]
100%|██████████| 1/1 [00:00<00:00,  2.05it/s]
100%|██████████| 1/1 [00:00<00:00,  2.18it/s]
100%|██████████| 1/1 [00:00<00:00,  1.88it/s]
100%|██████████| 1/1 [00:00<00:00,  2.24it/s]
100%|██████████| 1/1 [00:00<00:00,  1.95it/s]
100%|██████████| 1/1 [00:00<00:00,  2.37it/s]
100%|██████████| 1/1 [00:00<00:00,  2.19it/s]
100%|██████████| 1/1 [00:00<00:00,  1.34it/s]
100%|██████████| 1/1 [00:00<00:00,  2.00it/s]
100%|██████████| 1/1 [00:00<00:00,  2.21it/s]
100%|██████████| 1/1 [00:00<00:00,  2.15it/s]
100%|██████████| 1/1 [00:00<00:00,  2.33it/s]
100%|██████████| 1/1 [00:00<00:00,

Inputs shape: torch.Size([32, 3, 224, 224]), Labels shape: torch.Size([32])


100%|██████████| 1/1 [00:00<00:00,  2.33it/s]
100%|██████████| 1/1 [00:00<00:00,  2.28it/s]
100%|██████████| 1/1 [00:00<00:00,  2.13it/s]
100%|██████████| 1/1 [00:01<00:00,  1.50s/it]
100%|██████████| 1/1 [00:00<00:00,  2.17it/s]
100%|██████████| 1/1 [00:00<00:00,  2.38it/s]
100%|██████████| 1/1 [00:00<00:00,  2.23it/s]
100%|██████████| 1/1 [00:00<00:00,  1.57it/s]
100%|██████████| 1/1 [00:00<00:00,  2.26it/s]
100%|██████████| 1/1 [00:00<00:00,  2.43it/s]
100%|██████████| 1/1 [00:00<00:00,  2.20it/s]
100%|██████████| 1/1 [00:00<00:00,  2.44it/s]
100%|██████████| 1/1 [00:00<00:00,  1.94it/s]
100%|██████████| 1/1 [00:00<00:00,  2.51it/s]
100%|██████████| 1/1 [00:00<00:00,  1.27it/s]
100%|██████████| 1/1 [00:00<00:00,  2.41it/s]
100%|██████████| 1/1 [00:00<00:00,  2.21it/s]
100%|██████████| 1/1 [00:00<00:00,  2.29it/s]
100%|██████████| 1/1 [00:00<00:00,  2.32it/s]
100%|██████████| 1/1 [00:00<00:00,  2.12it/s]
100%|██████████| 1/1 [00:00<00:00,  2.36it/s]
100%|██████████| 1/1 [00:00<00:00,

Inputs shape: torch.Size([32, 3, 224, 224]), Labels shape: torch.Size([32])


100%|██████████| 1/1 [00:00<00:00,  2.26it/s]
100%|██████████| 1/1 [00:00<00:00,  1.97it/s]
100%|██████████| 1/1 [00:00<00:00,  2.10it/s]
100%|██████████| 1/1 [00:00<00:00,  2.37it/s]
100%|██████████| 1/1 [00:00<00:00,  2.53it/s]
100%|██████████| 1/1 [00:00<00:00,  2.45it/s]
100%|██████████| 1/1 [00:00<00:00,  2.49it/s]
100%|██████████| 1/1 [00:00<00:00,  2.35it/s]
100%|██████████| 1/1 [00:00<00:00,  2.12it/s]
100%|██████████| 1/1 [00:00<00:00,  2.44it/s]
100%|██████████| 1/1 [00:00<00:00,  2.35it/s]
100%|██████████| 1/1 [00:00<00:00,  2.39it/s]
100%|██████████| 1/1 [00:00<00:00,  1.24it/s]
100%|██████████| 1/1 [00:00<00:00,  1.95it/s]
100%|██████████| 1/1 [00:00<00:00,  2.17it/s]
100%|██████████| 1/1 [00:00<00:00,  2.24it/s]
100%|██████████| 1/1 [00:00<00:00,  1.96it/s]
100%|██████████| 1/1 [00:00<00:00,  1.41it/s]
100%|██████████| 1/1 [00:00<00:00,  2.31it/s]
100%|██████████| 1/1 [00:00<00:00,  1.22it/s]
100%|██████████| 1/1 [00:00<00:00,  2.17it/s]
100%|██████████| 1/1 [00:00<00:00,

Inputs shape: torch.Size([32, 3, 224, 224]), Labels shape: torch.Size([32])


100%|██████████| 1/1 [00:00<00:00,  1.41it/s]
100%|██████████| 1/1 [00:00<00:00,  2.06it/s]
100%|██████████| 1/1 [00:00<00:00,  2.23it/s]
100%|██████████| 1/1 [00:00<00:00,  2.26it/s]
100%|██████████| 1/1 [00:00<00:00,  2.38it/s]
100%|██████████| 1/1 [00:00<00:00,  1.25it/s]
100%|██████████| 1/1 [00:00<00:00,  2.39it/s]
100%|██████████| 1/1 [00:00<00:00,  1.53it/s]
100%|██████████| 1/1 [00:00<00:00,  1.53it/s]
100%|██████████| 1/1 [00:00<00:00,  1.33it/s]
100%|██████████| 1/1 [00:00<00:00,  2.31it/s]
100%|██████████| 1/1 [00:00<00:00,  2.16it/s]
100%|██████████| 1/1 [00:00<00:00,  2.15it/s]
100%|██████████| 1/1 [00:00<00:00,  2.20it/s]
100%|██████████| 1/1 [00:00<00:00,  2.29it/s]
100%|██████████| 1/1 [00:00<00:00,  2.19it/s]
100%|██████████| 1/1 [00:00<00:00,  1.83it/s]
100%|██████████| 1/1 [00:00<00:00,  2.13it/s]
100%|██████████| 1/1 [00:00<00:00,  2.24it/s]
100%|██████████| 1/1 [00:00<00:00,  2.21it/s]
100%|██████████| 1/1 [00:00<00:00,  2.19it/s]
100%|██████████| 1/1 [00:00<00:00,

Inputs shape: torch.Size([32, 3, 224, 224]), Labels shape: torch.Size([32])


100%|██████████| 1/1 [00:00<00:00,  2.26it/s]
100%|██████████| 1/1 [00:00<00:00,  2.36it/s]
100%|██████████| 1/1 [00:00<00:00,  2.13it/s]
100%|██████████| 1/1 [00:00<00:00,  2.19it/s]
100%|██████████| 1/1 [00:00<00:00,  2.27it/s]
100%|██████████| 1/1 [00:00<00:00,  2.20it/s]
100%|██████████| 1/1 [00:00<00:00,  2.11it/s]
100%|██████████| 1/1 [00:00<00:00,  2.28it/s]
100%|██████████| 1/1 [00:00<00:00,  2.04it/s]
100%|██████████| 1/1 [00:00<00:00,  2.43it/s]
100%|██████████| 1/1 [00:00<00:00,  2.18it/s]
100%|██████████| 1/1 [00:00<00:00,  2.02it/s]
100%|██████████| 1/1 [00:00<00:00,  2.46it/s]
100%|██████████| 1/1 [00:00<00:00,  2.43it/s]
100%|██████████| 1/1 [00:00<00:00,  2.40it/s]
100%|██████████| 1/1 [00:00<00:00,  2.45it/s]
100%|██████████| 1/1 [00:00<00:00,  2.52it/s]
100%|██████████| 1/1 [00:00<00:00,  1.43it/s]
100%|██████████| 1/1 [00:00<00:00,  2.56it/s]
100%|██████████| 1/1 [00:00<00:00,  2.39it/s]
100%|██████████| 1/1 [00:00<00:00,  2.40it/s]
100%|██████████| 1/1 [00:00<00:00,

Inputs shape: torch.Size([32, 3, 224, 224]), Labels shape: torch.Size([32])


100%|██████████| 1/1 [00:00<00:00,  2.26it/s]
100%|██████████| 1/1 [00:00<00:00,  2.13it/s]
100%|██████████| 1/1 [00:00<00:00,  2.33it/s]
100%|██████████| 1/1 [00:00<00:00,  2.28it/s]
100%|██████████| 1/1 [00:00<00:00,  2.38it/s]
100%|██████████| 1/1 [00:00<00:00,  2.25it/s]
100%|██████████| 1/1 [00:00<00:00,  2.31it/s]
100%|██████████| 1/1 [00:00<00:00,  2.39it/s]
100%|██████████| 1/1 [00:00<00:00,  2.34it/s]
100%|██████████| 1/1 [00:00<00:00,  2.36it/s]
100%|██████████| 1/1 [00:00<00:00,  2.38it/s]
100%|██████████| 1/1 [00:00<00:00,  2.25it/s]
100%|██████████| 1/1 [00:00<00:00,  2.39it/s]
100%|██████████| 1/1 [00:00<00:00,  2.11it/s]
100%|██████████| 1/1 [00:00<00:00,  2.25it/s]
100%|██████████| 1/1 [00:00<00:00,  1.30it/s]
100%|██████████| 1/1 [00:00<00:00,  1.03it/s]
100%|██████████| 1/1 [00:00<00:00,  1.85it/s]
100%|██████████| 1/1 [00:00<00:00,  1.04it/s]
100%|██████████| 1/1 [00:00<00:00,  2.28it/s]
100%|██████████| 1/1 [00:00<00:00,  1.93it/s]
100%|██████████| 1/1 [00:00<00:00,

Inputs shape: torch.Size([32, 3, 224, 224]), Labels shape: torch.Size([32])


100%|██████████| 1/1 [00:00<00:00,  2.01it/s]
100%|██████████| 1/1 [00:01<00:00,  1.04s/it]
100%|██████████| 1/1 [00:00<00:00,  1.67it/s]
100%|██████████| 1/1 [00:00<00:00,  1.59it/s]
100%|██████████| 1/1 [00:00<00:00,  1.74it/s]
100%|██████████| 1/1 [00:00<00:00,  2.02it/s]
100%|██████████| 1/1 [00:00<00:00,  1.51it/s]
100%|██████████| 1/1 [00:00<00:00,  2.32it/s]
100%|██████████| 1/1 [00:00<00:00,  1.09it/s]
100%|██████████| 1/1 [00:00<00:00,  1.45it/s]
100%|██████████| 1/1 [00:00<00:00,  1.14it/s]
100%|██████████| 1/1 [00:00<00:00,  2.06it/s]
100%|██████████| 1/1 [00:00<00:00,  2.18it/s]
100%|██████████| 1/1 [00:00<00:00,  1.38it/s]
100%|██████████| 1/1 [00:00<00:00,  2.10it/s]
100%|██████████| 1/1 [00:00<00:00,  2.21it/s]
100%|██████████| 1/1 [00:00<00:00,  2.42it/s]
100%|██████████| 1/1 [00:00<00:00,  1.33it/s]
100%|██████████| 1/1 [00:00<00:00,  1.85it/s]
100%|██████████| 1/1 [00:00<00:00,  1.13it/s]
100%|██████████| 1/1 [00:00<00:00,  2.18it/s]
100%|██████████| 1/1 [00:00<00:00,

Inputs shape: torch.Size([32, 3, 224, 224]), Labels shape: torch.Size([32])


100%|██████████| 1/1 [00:00<00:00,  1.22it/s]
100%|██████████| 1/1 [00:00<00:00,  1.07it/s]
100%|██████████| 1/1 [00:00<00:00,  1.31it/s]
100%|██████████| 1/1 [00:00<00:00,  1.41it/s]
100%|██████████| 1/1 [00:00<00:00,  1.34it/s]
100%|██████████| 1/1 [00:01<00:00,  1.12s/it]
100%|██████████| 1/1 [00:00<00:00,  1.04it/s]
100%|██████████| 1/1 [00:01<00:00,  1.10s/it]
100%|██████████| 1/1 [00:01<00:00,  1.43s/it]
100%|██████████| 1/1 [00:01<00:00,  1.58s/it]
100%|██████████| 1/1 [00:01<00:00,  1.54s/it]
100%|██████████| 1/1 [00:01<00:00,  1.41s/it]
100%|██████████| 1/1 [00:01<00:00,  1.28s/it]
100%|██████████| 1/1 [00:01<00:00,  1.23s/it]
100%|██████████| 1/1 [00:01<00:00,  1.70s/it]
100%|██████████| 1/1 [00:01<00:00,  1.40s/it]
100%|██████████| 1/1 [00:01<00:00,  1.54s/it]
100%|██████████| 1/1 [00:01<00:00,  1.74s/it]
100%|██████████| 1/1 [00:01<00:00,  1.49s/it]
100%|██████████| 1/1 [00:01<00:00,  1.35s/it]
100%|██████████| 1/1 [00:01<00:00,  1.39s/it]
100%|██████████| 1/1 [00:01<00:00,

In [15]:
import os
import sys
import tempfile
import pandas as pd
import numpy as np
from PIL import Image
import torch
from torch import nn
from torchvision import models, transforms

# Set the directory where 'src' is located
project_root = os.getcwd()  # Or replace this with the path to your project root
src_directory = os.path.join(project_root, 'src')

# Add 'src' directory to sys.path
sys.path.append(src_directory)

# Now these imports should work
from src.utils import download_images
from src.constants import entity_unit_map

class ProductImageDataset(torch.utils.data.Dataset):
    def __init__(self, csv_file, transform=None):
        self.data = pd.read_csv(csv_file)
        self.transform = transform
        self.temp_dir = tempfile.mkdtemp()
        
        # Create a mapping of entity values to numerical labels
        self.entity_to_label = {entity: idx for idx, entity in enumerate(entity_unit_map.keys())}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_url = self.data.loc[idx, 'image_link']
        
        # Download the image
        download_images([img_url], self.temp_dir)

        # Construct the image path
        filename = os.path.basename(img_url)
        image_path = os.path.join(self.temp_dir, filename)
        
        # Open the image
        image = Image.open(image_path).convert('RGB')  # Ensure 3-channel images
        
        if self.transform:
            image = self.transform(image)
        
        # For training data
        if 'entity_value' in self.data.columns:
            entity_value = self.data.loc[idx, 'entity_value']
            label = self.entity_to_label[entity_value]
            return image, torch.tensor(label, dtype=torch.long)
        
        # For test data
        return image, self.data.loc[idx, 'index']

# Define the model
class EntityExtractionModel(nn.Module):
    def __init__(self, num_classes):
        super(EntityExtractionModel, self).__init__()
        self.resnet = models.resnet50(pretrained=True)
        num_ftrs = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(num_ftrs, num_classes)

    def forward(self, x):
        return self.resnet(x)

# Main training function
def train_model(train_csv, val_csv, num_epochs=10):
    # Set up datasets and dataloaders
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    train_dataset = ProductImageDataset(train_csv, transform=transform)
    val_dataset = ProductImageDataset(val_csv, transform=transform)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False)

    # Initialize model, loss function, and optimizer
    model = EntityExtractionModel(num_classes=len(entity_unit_map))
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Training loop
    for epoch in range(num_epochs):
        model.train()
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            labels = labels.squeeze()  # Ensure labels are the correct shape
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        # Validation
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                outputs = model(inputs)
                labels = labels.squeeze()  # Ensure labels are the correct shape
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        print(f'Epoch {epoch+1}/{num_epochs}, '
              f'Train Loss: {loss.item():.4f}, '
              f'Val Loss: {val_loss/len(val_loader):.4f}, '
              f'Val Accuracy: {100 * correct / total:.2f}%')

    return model

# Function to generate predictions
def generate_predictions(model, test_csv, output_csv):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    test_dataset = ProductImageDataset(test_csv, transform=transform)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)

    predictions = []
    model.eval()
    with torch.no_grad():
        for inputs, indices in test_loader:
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            for idx, pred in zip(indices, predicted):
                entity = list(entity_unit_map.keys())[pred.item()]
                unit = entity_unit_map[entity]
                predictions.append([idx, f"{entity} {unit}"])

    pd.DataFrame(predictions, columns=['index', 'prediction']).to_csv(output_csv, index=False)

# Main execution
if __name__ == "__main__":
    # Update these paths to use the correct directory structure
    train_csv = os.path.join(project_root, 'dataset', 'train.csv')
    val_csv = os.path.join(project_root, 'dataset', 'test.csv')  # Using test.csv as validation for now
    test_csv = os.path.join(project_root, 'dataset', 'sample_test.csv')
    output_csv = os.path.join(project_root, 'dataset', 'sample_test_out.csv')

    model = train_model(train_csv, val_csv)
    generate_predictions(model, test_csv, output_csv)

    # Run sanity check
    from src.sanity import check_output_format
    check_output_format(output_csv)

100%|██████████| 1/1 [00:00<00:00,  1.86it/s]


KeyError: '21.0 centimetre'