In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import AutoImageProcessor, ResNetForImageClassification
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import pandas as pd
from PIL import Image
import numpy as np
from tqdm import tqdm
import os
from sklearn.metrics import classification_report, confusion_matrix

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class TrademarkDataset(Dataset):
    """Image-classification dataset backed by a DataFrame of trademark images.

    Each row supplies an ``image_name`` (joined onto ``img_dir``) and a label string in
    ``target_column``. Only the first comma-separated entry of that string is used as the
    class, e.g. ``"a,b"`` -> ``"a"``.

    Args:
        df: DataFrame with an ``image_name`` column and ``target_column``.
        target_column: Name of the label column.
        img_dir: Directory that the image names are relative to.
        transform: Optional callable applied to the RGB ``PIL.Image`` before returning it.
        label_encoder: Optional, already-fitted encoder exposing ``transform``. Pass the
            same encoder to every split (train/val/test) so that a given integer id
            denotes the same class everywhere. When omitted, a fresh ``LabelEncoder`` is
            fitted on this DataFrame's labels (the previous behaviour); the resulting ids
            are then only meaningful within this one dataset.

    Raises:
        ValueError: from ``label_encoder.transform`` when ``df`` has a label the supplied
            encoder was never fitted on.
    """

    def __init__(self, df, target_column, img_dir, transform=None, label_encoder=None):
        self.df = df
        self.img_dir = img_dir
        self.transform = transform
        self.target_column = target_column

        primary = self.primary_labels(df[target_column])
        if label_encoder is None:
            self.label_encoder = LabelEncoder()
            self.labels = self.label_encoder.fit_transform(primary)
        else:
            # Reuse a shared class->id mapping instead of deriving one from this split only.
            self.label_encoder = label_encoder
            self.labels = label_encoder.transform(primary)

    @staticmethod
    def primary_labels(labels):
        """Return the first comma-separated entry of every label string in ``labels`` (a Series)."""
        return labels.str.split(',').str[0]

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``; ``label`` is a 0-d ``torch.long`` tensor."""
        # Read the name up front so the error handler below cannot itself fail while reporting.
        image_name = self.df.iloc[idx]['image_name']
        try:
            img_path = os.path.join(self.img_dir, image_name)
            # The context manager closes the file; convert() returns an independent, loaded image.
            with Image.open(img_path) as img:
                image = img.convert('RGB')

            if self.transform:
                image = self.transform(image)
        except Exception as e:
            print(f"Error loading image {image_name}: {str(e)}")
            raise

        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return image, label

In [3]:
def validate_image_existence(df, img_dir, full_decode=False):
    """
    Check and filter the DataFrame to keep only rows where images exist and can be opened.

    Each ``image_name`` is joined onto ``img_dir`` and checked with ``PIL.Image.verify()``.
    A failing row is reported as *missing* when no file exists at that path (or the name
    is not a non-empty string), and as *corrupted* otherwise.

    Args:
        df: DataFrame with an ``image_name`` column.
        img_dir: Directory that the image names are relative to.
        full_decode: Also decode the pixel data. ``verify()`` only inspects file
            structure, so truncated images can pass it and fail later inside the training
            DataLoader. Decoding catches them, at the cost of a slower scan. Off by default.

    Returns:
        ``(valid_df, missing_images, corrupted_images)``: a copy of the passing rows (same
        columns and index labels as ``df``, even when ``df`` is empty), plus the names of
        the images that were missing or unreadable.
    """
    missing_images = []
    corrupted_images = []
    # Positional boolean mask: selects rows (never columns) regardless of index labels.
    keep = np.zeros(len(df), dtype=bool)

    names = tqdm(df['image_name'], total=len(df), desc="Checking images")
    for position, image_name in enumerate(names):
        if not isinstance(image_name, str) or not image_name:
            # NaN would make os.path.join raise; '' would resolve to img_dir itself.
            missing_images.append(image_name)
            continue

        img_path = os.path.join(img_dir, image_name)
        try:
            with Image.open(img_path) as img:
                img.verify()  # Verify it's actually an image
            if full_decode:
                # An image cannot be used after verify(), so reopen it before decoding.
                with Image.open(img_path) as img:
                    img.load()
        except Exception:
            # Pillow only promises that verify() raises "suitable exceptions"; they are not
            # all OSError subclasses, so classify any failure instead of crashing the scan.
            if os.path.exists(img_path):
                corrupted_images.append(image_name)
            else:
                missing_images.append(image_name)
        else:
            keep[position] = True

    valid_df = df.loc[keep].copy()

    print(f"\nTotal images in CSV: {len(df)}")
    print(f"Missing images: {len(missing_images)}")
    print(f"Corrupted images: {len(corrupted_images)}")
    print(f"Valid images: {len(valid_df)}")

    return valid_df, missing_images, corrupted_images

In [4]:
def _primary_label(frame, target_column):
    """Return the first comma-separated label of ``frame[target_column]`` for every row."""
    return frame[target_column].str.split(',').str[0]


def _run_epoch(model, loader, criterion, device, optimizer=None, desc=None):
    """Make one pass over ``loader``: trains when ``optimizer`` is given, else evaluates without grads.

    Returns ``(mean batch loss, accuracy in percent)``; both are 0.0 for an empty loader.
    A tqdm progress bar is shown only when ``desc`` is given.
    """
    training = optimizer is not None
    model.train(training)
    total_loss, correct, seen = 0.0, 0, 0
    batches = tqdm(loader, desc=desc) if desc else loader

    with torch.set_grad_enabled(training):
        for images, labels in batches:
            images, labels = images.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()
            logits = model(images).logits
            loss = criterion(logits, labels)
            if training:
                loss.backward()
                optimizer.step()

            total_loss += loss.item()
            correct += logits.argmax(dim=1).eq(labels).sum().item()
            seen += labels.size(0)

    mean_loss = total_loss / len(loader) if len(loader) else 0.0
    accuracy = 100. * correct / seen if seen else 0.0
    return mean_loss, accuracy


def _predict_ids(model, loader, device):
    """Return ``(predicted class ids, true class ids)`` as int lists for every sample in ``loader``."""
    model.eval()
    predictions, truths = [], []
    with torch.no_grad():
        for images, labels in loader:
            logits = model(images.to(device)).logits
            predictions.extend(logits.argmax(dim=1).cpu().tolist())
            truths.extend(labels.tolist())
    return predictions, truths


def train_trademark_classifier(df, target_column, img_dir, num_epochs=10, batch_size=32,
                               learning_rate=2e-5, num_workers=0, seed=None):
    """Fine-tune a pretrained ResNet-50 on one label column and print a test-set report.

    The rows are split 70/15/15 into train/val/test (``random_state=42``). A single
    ``LabelEncoder`` is fitted on the first comma-separated label of *all* rows and shared
    by every split, so an integer id means the same class in train, val and test. After
    each epoch where validation accuracy improves, the weights are saved to
    ``best_model_{target_column}.pth``. At the end those weights are reloaded and scored
    on the test split.

    Args:
        df: DataFrame with ``image_name`` and ``target_column`` columns.
        target_column: Label column to train on.
        img_dir: Directory containing the images.
        num_epochs: Number of training epochs.
        batch_size: Mini-batch size used by every DataLoader.
        learning_rate: AdamW learning rate (default equals the former hard-coded value).
        num_workers: DataLoader worker processes; 0 loads images in the main process.
        seed: When given, seeds torch's global RNG (shuffling, new head init) for reproducibility.

    Returns:
        ``(model, label_encoder)``: the model holding the best-validation weights of this
        run, and the shared encoder mapping class ids back to label strings.
    """
    if seed is not None:
        torch.manual_seed(seed)

    train_df, temp_df = train_test_split(df, test_size=0.3, random_state=42)
    val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42)

    # Resize + ImageNet normalisation statistics
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # One class->id mapping for all splits; fitting on every row means val/test labels
    # absent from the training split still get an id instead of raising.
    label_encoder = LabelEncoder().fit(_primary_label(df, target_column))

    loaders = {}
    for split, frame in (('train', train_df), ('val', val_df), ('test', test_df)):
        dataset = TrademarkDataset(frame, target_column, img_dir, transform)
        # By default TrademarkDataset derives ids from its own rows only; overwrite them
        # with the shared encoder so the splits agree on what each id means.
        dataset.label_encoder = label_encoder
        dataset.labels = label_encoder.transform(_primary_label(frame, target_column))
        loaders[split] = DataLoader(dataset, batch_size=batch_size, shuffle=(split == 'train'),
                                    num_workers=num_workers, pin_memory=(device.type == 'cuda'))

    num_classes = len(label_encoder.classes_)
    model = ResNetForImageClassification.from_pretrained(
        "microsoft/resnet-50",
        num_labels=num_classes,
        ignore_mismatched_sizes=True  # replace the 1000-way ImageNet head
    ).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    criterion = torch.nn.CrossEntropyLoss()

    checkpoint_path = f'best_model_{target_column}.pth'
    # -inf (not 0) so the first epoch always writes a checkpoint, even at 0% val accuracy.
    best_val_acc = float('-inf')
    for epoch in range(num_epochs):
        train_loss, train_acc = _run_epoch(model, loaders['train'], criterion, device,
                                           optimizer=optimizer,
                                           desc=f'Epoch {epoch+1}/{num_epochs}')
        val_loss, val_acc = _run_epoch(model, loaders['val'], criterion, device)

        print(f'Epoch {epoch+1}/{num_epochs}:')
        print(f'Train Loss: {train_loss:.4f}, '
              f'Train Acc: {train_acc:.2f}%')
        print(f'Val Loss: {val_loss:.4f}, '
              f'Val Acc: {val_acc:.2f}%')

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), checkpoint_path)

    if best_val_acc > float('-inf'):
        # Restore the best weights saved during *this* run; never pick up a stale file.
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    test_predictions, test_labels = _predict_ids(model, loaders['test'], device)

    # Convert numeric ids back to the original label strings
    pred_labels = label_encoder.inverse_transform(test_predictions)
    true_labels = label_encoder.inverse_transform(test_labels)

    print("\nClassification Report:")
    print(classification_report(true_labels, pred_labels))

    return model, label_encoder

In [5]:
def main(csv_path='./data/csv/pretrain_fill.csv', img_dir='./data/img/',
         hierarchical_levels=('target_h1',), min_valid_images=10):
    """Run the pipeline: validate the images listed in the CSV, then train one model per label level.

    Args:
        csv_path: CSV with an ``image_name`` column and one column per label level.
        img_dir: Directory holding the images referenced by ``image_name``.
        hierarchical_levels: Label columns to train on, one model each. The full set this
            notebook was written for is ``('target', 'target_h1', 'target_h2', 'target_h3')``.
        min_valid_images: Skip training when fewer usable images than this remain
            (overall, or for a given level after dropping rows without that label).

    Side effects: when images are missing or corrupted, writes ``missing_images.txt`` and
    ``pretrain_fill_valid.csv``. For every level that trains successfully, writes
    ``label_encoder_{level}.pkl``; ``train_trademark_classifier`` writes the model checkpoint.
    Errors for one level are printed with their traceback and do not stop the other levels.
    """
    import traceback

    import joblib

    print("Starting trademark classification pipeline...")

    df = pd.read_csv(csv_path)

    print("\nValidating image files...")
    valid_df, missing_images, corrupted_images = validate_image_existence(df, img_dir)

    if missing_images or corrupted_images:
        print("\nWarning: Some images are missing or corrupted!")

        # Explicit UTF-8 so non-ASCII file names cannot fail on platforms with another default.
        with open('missing_images.txt', 'w', encoding='utf-8') as f:
            f.write("Missing images:\n")
            f.writelines(f"{img}\n" for img in missing_images)
            f.write("\nCorrupted images:\n")
            f.writelines(f"{img}\n" for img in corrupted_images)
        print("Image issues list saved to 'missing_images.txt'")

        valid_df.to_csv('pretrain_fill_valid.csv', index=False)
        print("Valid dataset saved to 'pretrain_fill_valid.csv'")

    if len(valid_df) < min_valid_images:
        print("Error: Not enough valid images to train. Please check your dataset.")
        return

    for level in hierarchical_levels:
        print(f"\n{'='*50}")
        print(f"Training model for {level}")
        print(f"{'='*50}")

        if level not in valid_df.columns:
            print(f"\nError training model for {level}: column not found in {csv_path}")
            continue

        # Rows with no label at this level cannot be label-encoded; drop them for this level only.
        level_df = valid_df.dropna(subset=[level])
        if len(level_df) < min_valid_images:
            print(f"\nError training model for {level}: only {len(level_df)} labelled images")
            continue

        try:
            _, label_encoder = train_trademark_classifier(level_df, level, img_dir)

            # Save label encoder for later use
            joblib.dump(label_encoder, f'label_encoder_{level}.pkl')
            print(f"\nModel and label encoder for {level} saved successfully")
        except Exception as e:
            # Move on to the remaining levels, but keep the full traceback for debugging.
            print(f"\nError training model for {level}: {str(e)}")
            traceback.print_exc()

In [None]:
# Inside a Jupyter kernel __name__ is '__main__', so running this cell starts the whole pipeline.
if __name__ == '__main__':
    main()

Starting trademark classification pipeline...

Validating image files...


Checking images: 100%|██████████| 158511/158511 [03:26<00:00, 766.65it/s]  



Total images in CSV: 158511
Missing images: 55178
Corrupted images: 12
Valid images: 103321

Image issues list saved to 'missing_images.txt'
Valid dataset saved to 'pretrain_fill_valid.csv'

Training model for target_h1


Some weights of ResNetForImageClassification were not initialized from the model checkpoint at microsoft/resnet-50 and are newly initialized because the shapes did not match:
- classifier.1.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([30]) in the model instantiated
- classifier.1.weight: found shape torch.Size([1000, 2048]) in the checkpoint and torch.Size([30, 2048]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Epoch 1/10:  24%|██▍       | 541/2261 [10:09:18<32:17:10, 67.58s/it]    
