In [None]:
from google.colab import userdata
# Confirm the secret exists without echoing it: a bare `userdata.get(...)` as the
# cell's last expression renders the credential into the saved notebook output.
print("KAGGLE_JSON secret configured:", bool(userdata.get('KAGGLE_JSON')))

In [None]:
import os
from google.colab import userdata

# Load Kaggle credentials from Colab secrets and write them as the kaggle.json
# file the `kaggle` package reads when it is imported.
kaggle_json_content = userdata.get('KAGGLE_JSON')
if not kaggle_json_content:
    raise RuntimeError(
        "Colab secret 'KAGGLE_JSON' is missing or empty; add it under Secrets "
        "with the contents of your kaggle.json."
    )

kaggle_config_dir = os.path.join(os.path.expanduser('~'), '.config', 'kaggle')
os.makedirs(kaggle_config_dir, exist_ok=True)
kaggle_json_path = os.path.join(kaggle_config_dir, 'kaggle.json')

# Create the file with 0o600 from the start so the token is never briefly
# readable by others (write-then-chmod leaves a window).
fd = os.open(kaggle_json_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
with os.fdopen(fd, 'w') as f:
    f.write(kaggle_json_content)
# O_CREAT's mode is ignored when the file already existed, so chmod explicitly too.
os.chmod(kaggle_json_path, 0o600)

# Point the kaggle client at this directory explicitly. Depending on its version
# and platform it otherwise defaults to ~/.kaggle, in which case the file above
# would be ignored.
os.environ['KAGGLE_CONFIG_DIR'] = kaggle_config_dir

print(f"Kaggle credentials loaded to {kaggle_config_dir}.")

In [None]:
import os

# Report where kaggle.json actually is. $KAGGLE_CONFIG_DIR takes precedence for
# the kaggle client; otherwise it uses ~/.kaggle or ~/.config/kaggle depending on
# version/platform. The credentials cell writes to ~/.config/kaggle, so checking
# only ~/.kaggle would report "NOT found" even when setup succeeded.
home_dir = os.path.expanduser('~')
candidate_paths = [
    os.path.join(config_dir, 'kaggle.json')
    for config_dir in (
        os.environ.get('KAGGLE_CONFIG_DIR'),
        os.path.join(home_dir, '.kaggle'),
        os.path.join(home_dir, '.config', 'kaggle'),
    )
    if config_dir
]
existing_paths = [path for path in candidate_paths if os.path.exists(path)]

if existing_paths:
    kaggle_file_path = existing_paths[0]
    print(f"kaggle.json found at: {kaggle_file_path}")
else:
    kaggle_file_path = candidate_paths[0]
    print(f"kaggle.json NOT found at any of: {', '.join(candidate_paths)}")

In [None]:
import os
import kaggle
import zipfile
import pandas as pd
from pathlib import Path

class AtlasDatasetDownloader:
    """Download the Atlas clothing dataset from Kaggle and sanity-check it on disk.

    Files go under ``<data_dir>/raw``; images are expected in ``<data_dir>/raw/images``.
    """

    # An item ID already ending in one of these is used as a file name unchanged.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, data_dir='data/'):
        """Create ``<data_dir>/raw`` (and any parents) if it does not exist."""
        self.data_dir = Path(data_dir)
        self.raw_dir = self.data_dir / 'raw'
        self.raw_dir.mkdir(parents=True, exist_ok=True)

    def download_from_kaggle(self):
        """Download and unzip the Kaggle dataset into ``self.raw_dir``.

        Kaggle API credentials must be configured before ``kaggle`` is imported.
        """
        print("Downloading Atlas dataset from Kaggle...")

        kaggle.api.dataset_download_files(
            'silverstone1903/atlas-e-commerce-clothing-product-categorization',
            path=str(self.raw_dir),
            unzip=True
        )

        print(f"Dataset downloaded to {self.raw_dir}")

    def load_metadata(self):
        """Load the metadata table, preferring ``atlas_data.csv`` over ``atlas_data.feather``.

        Returns:
            pandas DataFrame with one row per image record.
        Raises:
            FileNotFoundError: if neither metadata file exists in ``self.raw_dir``.
        """
        csv_path = self.raw_dir / 'atlas_data.csv'
        feather_path = self.raw_dir / 'atlas_data.feather'

        if csv_path.exists():
            df = pd.read_csv(csv_path)
        elif feather_path.exists():
            df = pd.read_feather(feather_path)
        else:
            raise FileNotFoundError(
                f"No metadata found: expected {csv_path} or {feather_path}"
            )

        print(f"Loaded {len(df)} image records")
        print(f"Columns: {df.columns.tolist()}")

        return df

    @classmethod
    def _image_filename(cls, item_id, image_ext='.jpg'):
        """Map an ``item_ID`` value to its image file name.

        The ID is converted with ``str()`` (``read_csv`` may parse IDs as ints)
        and gets ``image_ext`` appended unless it already has an image extension.
        """
        name = str(item_id)
        if Path(name).suffix.lower() in cls.IMAGE_EXTENSIONS:
            return name
        return name + image_ext

    def verify_dataset(self, df, image_ext='.jpg'):
        """Print how many rows of ``df`` have no matching file in the images directory.

        File names are built as ``item_ID + image_ext``, mirroring how AtlasDataset
        loads images (it appends '.jpg'). The raw ID alone has no extension and
        would never match a file.

        Args:
            df: metadata with an ``item_ID`` column.
            image_ext: extension appended to IDs that do not already carry one.
        Returns:
            ``df``, unchanged.
        Raises:
            FileNotFoundError: if the images directory is missing.
            ValueError: if ``df`` has no ``item_ID`` column.
        """
        images_dir = self.raw_dir / 'images'

        if not images_dir.exists():
            raise FileNotFoundError(f"Images directory not found at {images_dir}")

        # Count available images
        image_files = list(images_dir.glob('*.jpg')) + list(images_dir.glob('*.png'))
        print(f"Found {len(image_files)} image files")

        image_column_name = 'item_ID'
        if image_column_name not in df.columns:
            raise ValueError(f"Column '{image_column_name}' not found in the DataFrame.")

        # List the directory once and test membership, instead of one stat per row.
        existing_names = {path.name for path in images_dir.iterdir()}
        missing_count = sum(
            self._image_filename(item_id, image_ext) not in existing_names
            for item_id in df[image_column_name]
        )

        print(f"Missing images: {missing_count}/{len(df)}")

        return df

# Usage: download only when the metadata is not already on disk (re-running the
# notebook should not re-fetch the whole dataset), then load and verify it.
if __name__ == "__main__":
    downloader = AtlasDatasetDownloader()
    metadata_present = any(
        (downloader.raw_dir / name).exists()
        for name in ('atlas_data.csv', 'atlas_data.feather')
    )
    if metadata_present:
        print(f"Dataset already present in {downloader.raw_dir}; skipping download.")
    else:
        downloader.download_from_kaggle()
    df = downloader.load_metadata()
    df = downloader.verify_dataset(df)

# Data Exploration and Analysis

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter

class AtlasDatasetAnalyzer:
    """Analyze Atlas dataset structure and statistics.

    Note: methods add columns to the DataFrame passed in; it is not copied.
    """

    def __init__(self, df):
        self.df = df

    def parse_taxonomy(self):
        """Build ``category_path`` and split it into ``level1``..``level3``.

        ``category_path`` is ``"<gender> > <category> > <sub-category>"``. If any
        part is missing the path is null, and so are all three levels.

        Returns:
            ``self.df`` with the new columns.
        """
        self.df['category_path'] = (
            self.df['gender'] + ' > ' + self.df['category'] + ' > ' + self.df['sub-category']
        )

        # Split once, vectorised, instead of three per-row lambdas that each re-split.
        # Level 1 = gender, level 2 = clothing type, level 3 = specific category.
        parts = self.df['category_path'].str.split(' > ')
        for level in range(3):
            self.df[f'level{level + 1}'] = parts.str.get(level)

        return self.df

    def analyze_distribution(self):
        """Print value counts for each taxonomy level and the number of distinct paths."""
        print("\n=== Level 1 (Gender) Distribution ===")
        print(self.df['level1'].value_counts())

        print("\n=== Level 2 (Clothing Type) Distribution ===")
        print(self.df['level2'].value_counts())

        print("\n=== Level 3 (Specific Category) Distribution ===")
        print(self.df['level3'].value_counts().head(20))

        print(f"\nTotal unique category paths: {self.df['category_path'].nunique()}")

    def visualize_distribution(self, save_path='data/category_distribution.png'):
        """Plot level and path distributions in a 2x2 grid, save to ``save_path`` and show it."""
        from pathlib import Path

        fig, axes = plt.subplots(2, 2, figsize=(16, 12))

        # Level 1 distribution
        level1_counts = self.df['level1'].value_counts()
        axes[0, 0].bar(level1_counts.index, level1_counts.values, color='steelblue')
        axes[0, 0].set_title('Level 1 (Gender) Distribution', fontsize=14, fontweight='bold')
        axes[0, 0].set_xlabel('Gender')
        axes[0, 0].set_ylabel('Count')
        axes[0, 0].tick_params(axis='x', rotation=45)

        # Level 2 distribution (top 10); inverted so the largest bar is on top,
        # like the other horizontal bar charts.
        level2_counts = self.df['level2'].value_counts().head(10)
        axes[0, 1].barh(level2_counts.index, level2_counts.values, color='coral')
        axes[0, 1].set_title('Top 10 Level 2 (Clothing Type) Distribution', fontsize=14, fontweight='bold')
        axes[0, 1].set_xlabel('Count')
        axes[0, 1].invert_yaxis()

        # Level 3 distribution (top 20)
        level3_counts = self.df['level3'].value_counts().head(20)
        axes[1, 0].barh(level3_counts.index, level3_counts.values, color='mediumseagreen')
        axes[1, 0].set_title('Top 20 Level 3 (Specific Category) Distribution', fontsize=14, fontweight='bold')
        axes[1, 0].set_xlabel('Count')
        axes[1, 0].invert_yaxis()

        # Category path distribution (top 15); long labels truncated to 40 chars
        path_counts = self.df['category_path'].value_counts().head(15)
        axes[1, 1].barh(range(len(path_counts)), path_counts.values, color='mediumpurple')
        axes[1, 1].set_yticks(range(len(path_counts)))
        axes[1, 1].set_yticklabels([p[:40] + '...' if len(p) > 40 else p for p in path_counts.index], fontsize=8)
        axes[1, 1].set_title('Top 15 Complete Category Paths', fontsize=14, fontweight='bold')
        axes[1, 1].set_xlabel('Count')
        axes[1, 1].invert_yaxis()

        plt.tight_layout()
        # savefig does not create directories; make sure the target folder exists.
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()

    def check_class_imbalance(self, rare_threshold=100):
        """Print imbalance statistics for ``category_path`` and return its value counts.

        Args:
            rare_threshold: categories with fewer samples than this are counted as rare.
        Returns:
            Series of sample counts per category path, most frequent first. It is
            empty if there are no non-null paths.
        """
        category_counts = self.df['category_path'].value_counts()

        print("\n=== Class Imbalance Analysis ===")
        if category_counts.empty:
            print("No category paths available.")
            return category_counts

        most_frequent = category_counts.iloc[0]
        least_frequent = category_counts.iloc[-1]
        print(f"Most frequent category: {most_frequent} samples")
        print(f"Least frequent category: {least_frequent} samples")
        print(f"Imbalance ratio: {most_frequent / least_frequent:.2f}:1")

        rare_categories = category_counts[category_counts < rare_threshold]
        print(f"\nCategories with < {rare_threshold} samples: {len(rare_categories)}")

        return category_counts

# Usage: read the raw metadata, derive the taxonomy columns, then report on them
# (text summary, then figure) and keep the per-path counts for later cells.
raw_csv_path = 'data/raw/atlas_data.csv'
analyzer = AtlasDatasetAnalyzer(pd.read_csv(raw_csv_path))
df = analyzer.parse_taxonomy()  # same frame as analyzer.df, now with level columns
for report in (analyzer.analyze_distribution, analyzer.visualize_distribution):
    report()
category_counts = analyzer.check_class_imbalance()

# Data Preprocessing and Augmentation

In [None]:
# src/data/preprocessing.py
import cv2
import numpy as np
from PIL import Image
import torch
from torchvision import transforms

class AtlasImagePreprocessor:
    """Deterministic single-image preprocessing: letterbox to ``target_size``, then normalize.

    Args:
        target_size: ``(width, height)`` of the output image (PIL order).
        normalize: if True, standardize with ImageNet mean/std after scaling to [0, 1].
    """

    def __init__(self, target_size=(224, 224), normalize=True):
        self.target_size = target_size
        self.normalize = normalize

        # ImageNet statistics for transfer learning
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]

    def load_image(self, image_path):
        """Open ``image_path`` as an RGB PIL image and close the underlying file."""
        # convert() returns a fully loaded copy, so the result stays valid after close.
        with Image.open(image_path) as source:
            return source.convert('RGB')

    def resize_image(self, img):
        """Scale ``img`` to fit within ``target_size`` keeping aspect ratio, then pad with white."""
        width, height = img.size
        scale = min(self.target_size[0] / width, self.target_size[1] / height)

        # Clamp to at least 1 px: very thin images would otherwise round to a zero size.
        new_width = max(1, int(width * scale))
        new_height = max(1, int(height * scale))

        img = img.resize((new_width, new_height), Image.LANCZOS)

        # Center the resized image on a white canvas of the target size.
        padded_img = Image.new('RGB', self.target_size, (255, 255, 255))
        paste_x = (self.target_size[0] - new_width) // 2
        paste_y = (self.target_size[1] - new_height) // 2
        padded_img.paste(img, (paste_x, paste_y))

        return padded_img

    def normalize_image(self, img):
        """Return ``img`` as a float32 ``(H, W, C)`` array in [0, 1], optionally standardized."""
        img_array = np.asarray(img, dtype=np.float32) / 255.0

        if self.normalize:
            # Use float32 statistics: subtracting the plain Python lists upcasts the
            # result to float64, producing a DoubleTensor that float32 models reject.
            mean = np.asarray(self.mean, dtype=np.float32)
            std = np.asarray(self.std, dtype=np.float32)
            img_array = (img_array - mean) / std

        return img_array

    def preprocess(self, image_path):
        """Load, letterbox and normalize one image into a float32 ``(C, H, W)`` tensor."""
        img = self.load_image(image_path)
        img = self.resize_image(img)
        img_array = self.normalize_image(img)

        # Convert to tensor (C, H, W)
        img_tensor = torch.from_numpy(img_array).permute(2, 0, 1)

        return img_tensor

# Data augmentation for training
class AtlasAugmentation:
    """torchvision pipelines: randomized augmentation for training, deterministic for eval.

    Args:
        target_size: crop size for ``RandomCrop``/``CenterCrop``. torchvision reads
            a tuple as ``(height, width)`` and an int as a square.
        resize_size: length the shorter image side is resized to before cropping.
            ``None`` (default) keeps the usual 256-for-224 ratio. That gives exactly
            256 for the default 224 crop, scaled for other crop sizes so the crop
            always fits inside the resized image.
    """

    # ImageNet channel statistics, matching the pretrained backbones.
    IMAGENET_MEAN = [0.485, 0.456, 0.406]
    IMAGENET_STD = [0.229, 0.224, 0.225]

    def __init__(self, target_size=(224, 224), resize_size=None):
        if resize_size is None:
            # A fixed 256 would make RandomCrop fail for crops larger than 256.
            if isinstance(target_size, (tuple, list)):
                longest_crop_side = max(target_size)
            else:
                longest_crop_side = target_size
            resize_size = int(round(longest_crop_side * 256 / 224))
        self.resize_size = resize_size

        self.train_transform = transforms.Compose([
            transforms.Resize(resize_size),
            transforms.RandomCrop(target_size),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(degrees=15),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
        ])

        self.val_transform = transforms.Compose([
            transforms.Resize(resize_size),
            transforms.CenterCrop(target_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
        ])

    def get_train_transform(self):
        """Return the randomized training pipeline."""
        return self.train_transform

    def get_val_transform(self):
        """Return the deterministic validation/test pipeline."""
        return self.val_transform


In [None]:
!pip install opencv-python

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import pandas as pd
from pathlib import Path
from PIL import Image # Import Image here

class AtlasDataset(Dataset):
    """Map-style dataset yielding ``(image, label_index)`` pairs for Atlas products.

    Args:
        df: metadata with an ``item_ID`` column and the label column
            (``category_path`` when ``target_type == 'path'``, else ``level3``).
        images_dir: directory containing the image files.
        transform: optional callable applied to the loaded RGB PIL image.
        target_type: ``'path'`` to predict the full category path; any other
            value predicts ``level3``.
        classes: optional ordered list of class names. Pass the training set's
            ``classes`` when building val/test datasets so label indices agree
            across splits. By default the classes are the sorted unique labels of
            ``df``, which differ between splits whenever a split lacks a class.
        image_ext: extension appended to ``item_ID`` to form the file name,
            unless the ID already ends in a known image extension.
    """

    _KNOWN_IMAGE_EXTS = ('.jpg', '.jpeg', '.png')

    def __init__(self, df, images_dir, transform=None, target_type='path',
                 classes=None, image_ext='.jpg'):
        self.df = df.reset_index(drop=True)
        self.images_dir = Path(images_dir)
        self.transform = transform
        self.target_type = target_type
        self.image_ext = image_ext
        self.label_column = 'category_path' if target_type == 'path' else 'level3'

        # Create label encodings
        if classes is None:
            classes = sorted(df[self.label_column].unique())
        self.classes = list(classes)

        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
        self.idx_to_class = {idx: cls for cls, idx in self.class_to_idx.items()}

    def __len__(self):
        return len(self.df)

    def _image_filename(self, item_id):
        """Return the image file name for ``item_id``.

        ``str()`` handles numeric IDs parsed by ``read_csv``.
        """
        name = str(item_id)
        if Path(name).suffix.lower() in self._KNOWN_IMAGE_EXTS:
            return name
        return name + self.image_ext

    def __getitem__(self, idx):
        """Return ``(transformed image, class index)`` for row ``idx``."""
        row = self.df.iloc[idx]

        img_path = self.images_dir / self._image_filename(row['item_ID'])
        # Close the file handle; convert() returns a fully loaded copy. Leaked
        # handles add up quickly across many DataLoader workers.
        with Image.open(img_path) as source:
            img = source.convert('RGB')

        if self.transform:
            img = self.transform(img)

        label = self.class_to_idx[row[self.label_column]]

        return img, label

def create_data_splits(df, test_size=0.30, val_size=0.05, random_state=42):
    """Split ``df`` into train/val/test sets, stratified on ``category_path``.

    ``test_size`` and ``val_size`` are fractions of the full ``df``. The val
    fraction is rescaled so it is taken relative to the train+val pool left after
    the test split. Both stages use ``random_state``.

    Returns:
        ``(train_df, val_df, test_df)``.
    """
    stratify_col = 'category_path'

    # Stage 1: carve the test set off the full frame.
    pool_df, test_df = train_test_split(
        df,
        test_size=test_size,
        stratify=df[stratify_col],
        random_state=random_state,
    )

    # Stage 2: split the remaining pool, expressing val_size relative to it.
    train_df, val_df = train_test_split(
        pool_df,
        test_size=val_size / (1 - test_size),
        stratify=pool_df[stratify_col],
        random_state=random_state,
    )

    total = len(df)
    for label, part in (("Train", train_df), ("Val", val_df), ("Test", test_df)):
        print(f"{label} set: {len(part)} samples ({len(part)/total*100:.1f}%)")

    return train_df, val_df, test_df

# Save splits
def save_splits(train_df, val_df, test_df, output_dir='data/splits'):
    """Write the three splits as train.csv / val.csv / test.csv (no index column).

    ``output_dir`` and any missing parents are created first.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    for split_name, split_df in (('train', train_df), ('val', val_df), ('test', test_df)):
        split_df.to_csv(output_dir / f'{split_name}.csv', index=False)

    print(f"Splits saved to {output_dir}")

**Model 1: ResNet34 Image Classifier**

In [None]:
# src/models/resnet_classifier.py
import torch
import torch.nn as nn
from torchvision import models

class ResNet34Classifier(nn.Module):
    """ResNet34 backbone with an MLP classification head.

    The stock ``fc`` layer is replaced by
    Dropout -> Linear(in, 512) -> ReLU -> BatchNorm1d(512) -> Dropout -> Linear(512, num_classes).
    Because of the BatchNorm1d, training-mode batches must hold more than one sample.

    Args:
        num_classes: number of output logits.
        pretrained: load ImageNet weights for the backbone.
        dropout: dropout probability used in both dropout layers of the head.
    """

    def __init__(self, num_classes=52, pretrained=True, dropout=0.5):
        super(ResNet34Classifier, self).__init__()

        # torchvision >= 0.13 replaced the deprecated `pretrained=` flag with
        # `weights=`. IMAGENET1K_V1 is the checkpoint `pretrained=True` loaded.
        weights_enum = getattr(models, 'ResNet34_Weights', None)
        if weights_enum is None:
            self.resnet = models.resnet34(pretrained=pretrained)
        else:
            self.resnet = models.resnet34(
                weights=weights_enum.IMAGENET1K_V1 if pretrained else None
            )

        num_features = self.resnet.fc.in_features

        # Replace final fully connected layer
        self.resnet.fc = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(num_features, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(dropout),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        """Return class logits for an image batch."""
        return self.resnet(x)

    def freeze_backbone(self):
        """Disable gradients for everything except the ``fc`` head.

        Only gradients are disabled: backbone BatchNorm running statistics still
        update whenever the model is in train mode.
        """
        for param in self.resnet.parameters():
            param.requires_grad = False

        for param in self.resnet.fc.parameters():
            param.requires_grad = True

    def unfreeze_backbone(self):
        """Re-enable gradients for all parameters."""
        for param in self.resnet.parameters():
            param.requires_grad = True

# Usage: build the classifier and report how many parameters it has, and how many train.
model = ResNet34Classifier(num_classes=52, pretrained=True)
all_params = list(model.parameters())
print(f"Total parameters: {sum(p.numel() for p in all_params):,}")
print(f"Trainable parameters: {sum(p.numel() for p in all_params if p.requires_grad):,}")


**Model 2: Attention-Based Sequence-to-Sequence Model (ResNet-101 image encoder + LSTM decoder with soft attention)**

In [None]:
import torch
import torch.nn as nn
from torchvision import models

class Encoder(nn.Module):
    """ResNet-101 feature extractor returning a spatial grid of 2048-d features.

    Args:
        encoded_image_size: side length of the output grid after adaptive pooling.
        pretrained: load ImageNet weights (default True, as before).
    """

    def __init__(self, encoded_image_size=14, pretrained=True):
        super(Encoder, self).__init__()

        # Prefer the `weights=` API (torchvision >= 0.13); `pretrained=` is deprecated.
        # IMAGENET1K_V1 is what `pretrained=True` loaded (DEFAULT now points to V2).
        weights_enum = getattr(models, 'ResNet101_Weights', None)
        if weights_enum is None:
            resnet = models.resnet101(pretrained=pretrained)
        else:
            resnet = models.resnet101(
                weights=weights_enum.IMAGENET1K_V1 if pretrained else None
            )

        # Remove linear and pool layers
        modules = list(resnet.children())[:-2]
        self.resnet = nn.Sequential(*modules)

        # Adaptive pooling to fixed spatial size
        self.adaptive_pool = nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size))

        # Freeze the backbone except stages layer2-layer4
        self.fine_tune()

    def forward(self, images):
        """Encode images into a channels-last feature grid."""
        out = self.resnet(images)  # (batch_size, 2048, image_size/32, image_size/32)
        out = self.adaptive_pool(out)  # (batch_size, 2048, encoded_image_size, encoded_image_size)
        out = out.permute(0, 2, 3, 1)  # (batch_size, encoded_image_size, encoded_image_size, 2048)
        return out

    def fine_tune(self, fine_tune=True):
        """Freeze the whole backbone, then optionally unfreeze stages layer2-layer4.

        ``self.resnet`` children are [conv1, bn1, relu, maxpool, layer1, layer2,
        layer3, layer4], so index 5 onward is layer2..layer4: three stages, not two.
        """
        for p in self.resnet.parameters():
            p.requires_grad = False

        if fine_tune:
            for c in list(self.resnet.children())[5:]:
                for p in c.parameters():
                    p.requires_grad = True


class Attention(nn.Module):
    """Additive soft attention: scores encoder pixels against the decoder state.

    Args:
        encoder_dim: feature size of each encoder pixel.
        decoder_dim: size of the decoder hidden state.
        attention_dim: size of the shared projection space.
    """

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super(Attention, self).__init__()

        # Construction order is kept so parameter initialization is unchanged.
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)
        self.full_att = nn.Linear(attention_dim, 1)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, encoder_out, decoder_hidden):
        """Attend over pixels.

        Args:
            encoder_out: (batch_size, num_pixels, encoder_dim)
            decoder_hidden: (batch_size, decoder_dim)
        Returns:
            (context, alpha): context is (batch_size, encoder_dim) and alpha is
            (batch_size, num_pixels), the softmax weights over pixels.
        """
        projected_pixels = self.encoder_att(encoder_out)
        projected_state = self.decoder_att(decoder_hidden)[:, None, :]
        scores = self.full_att(self.relu(projected_pixels + projected_state)).squeeze(2)
        alpha = self.softmax(scores)
        context = torch.sum(encoder_out * alpha[:, :, None], dim=1)
        return context, alpha



class DecoderWithAttention(nn.Module):
    """LSTM decoder that attends over encoder features at every step.

    At step t it computes an attention-weighted encoder context, scales it by a
    sigmoid gate computed from the hidden state (``f_beta``), concatenates it
    with the embedding of token t, runs one ``LSTMCell`` step, and projects the
    new hidden state to vocabulary logits.
    """


    def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, encoder_dim=2048, dropout=0.5):
        """
        Args:
            attention_dim: size of the attention projection space.
            embed_dim: token embedding size.
            decoder_dim: LSTM hidden/cell size.
            vocab_size: number of output tokens.
            encoder_dim: feature size of each encoder pixel (default 2048).
            dropout: dropout probability applied before the output projection.
        """
        super(DecoderWithAttention, self).__init__()

        self.encoder_dim = encoder_dim
        self.attention_dim = attention_dim
        self.embed_dim = embed_dim
        self.decoder_dim = decoder_dim
        self.vocab_size = vocab_size
        self.dropout = dropout  # float probability; the layer itself is `dropout_layer`

        self.attention = Attention(encoder_dim, decoder_dim, attention_dim)

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.dropout_layer = nn.Dropout(p=self.dropout)
        # Input per step = token embedding concatenated with the gated attention context.
        self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, bias=True)
        # Initial h / c are linear functions of the mean encoder feature.
        self.init_h = nn.Linear(encoder_dim, decoder_dim)
        self.init_c = nn.Linear(encoder_dim, decoder_dim)
        # Produces the sigmoid gate that scales the attention context.
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)
        self.sigmoid = nn.Sigmoid()
        self.fc = nn.Linear(decoder_dim, vocab_size)
        self.init_weights()

    def init_weights(self):
        """Uniform(-0.1, 0.1) init for the embedding and output weights; output bias set to 0."""

        self.embedding.weight.data.uniform_(-0.1, 0.1)
        self.fc.bias.data.fill_(0)
        self.fc.weight.data.uniform_(-0.1, 0.1)

    def init_hidden_state(self, encoder_out):
        """Return initial LSTM ``(h, c)`` computed from the mean over pixels.

        Args:
            encoder_out: (batch_size, num_pixels, encoder_dim)
        Returns:
            h, c: each (batch_size, decoder_dim)
        """

        mean_encoder_out = encoder_out.mean(dim=1)
        h = self.init_h(mean_encoder_out)
        c = self.init_c(mean_encoder_out)
        return h, c

    def forward(self, encoder_out, encoded_captions, caption_lengths):
        """Teacher-forced decoding of a batch.

        Args:
            encoder_out: encoder features with the feature dimension last; reshaped
                here to (batch_size, num_pixels, encoder_dim).
            encoded_captions: (batch_size, max_caption_length) token ids.
            caption_lengths: caption lengths with a trailing singleton dimension
                (``squeeze(1)`` is applied), presumably (batch_size, 1) — confirm with caller.
        Returns:
            predictions: (batch_size, max(decode_lengths), vocab_size) logits.
            encoded_captions: the captions, re-ordered by descending length.
            decode_lengths: list of ``length - 1`` per caption, in sorted order.
            alphas: (batch_size, max(decode_lengths), num_pixels) attention weights.
            sort_ind: permutation applied to the batch; use it to map outputs back.
        Note: ``max(decode_lengths)`` raises on an empty batch.
        """

        batch_size = encoder_out.size(0)
        encoder_dim = encoder_out.size(-1)
        vocab_size = self.vocab_size

        # Flatten image
        encoder_out = encoder_out.view(batch_size, -1, encoder_dim)  # (batch_size, num_pixels, encoder_dim)
        num_pixels = encoder_out.size(1)

        # Sort by decreasing caption length so that, at every step t, the
        # sequences still being decoded are exactly the first batch_size_t rows.
        caption_lengths, sort_ind = caption_lengths.squeeze(1).sort(dim=0, descending=True)
        encoder_out = encoder_out[sort_ind]
        encoded_captions = encoded_captions[sort_ind]

        # Embedding
        embeddings = self.embedding(encoded_captions)  # (batch_size, max_caption_length, embed_dim)

        # Initialize LSTM state
        h, c = self.init_hidden_state(encoder_out)

        # No prediction is made after the last token of each caption (presumably
        # an end token — confirm with the vocabulary), hence length - 1.
        decode_lengths = (caption_lengths - 1).tolist()

        # Create tensors for predictions and alphas; positions past a caption's
        # decode length stay zero.
        predictions = torch.zeros(batch_size, max(decode_lengths), vocab_size).to(encoder_out.device)
        alphas = torch.zeros(batch_size, max(decode_lengths), num_pixels).to(encoder_out.device)

        # Decode step by step
        for t in range(max(decode_lengths)):
            # Number of sequences that still need a prediction at step t.
            batch_size_t = sum([l > t for l in decode_lengths])
            attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t], h[:batch_size_t])
            gate = self.sigmoid(self.f_beta(h[:batch_size_t]))
            attention_weighted_encoding = gate * attention_weighted_encoding
            # h / c shrink to batch_size_t rows here; safe because the batch is
            # sorted by length, so batch_size_t never grows between steps.
            h, c = self.decode_step(
                torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1),
                (h[:batch_size_t], c[:batch_size_t])
            )
            preds = self.fc(self.dropout_layer(h))
            predictions[:batch_size_t, t, :] = preds
            alphas[:batch_size_t, t, :] = alpha

        return predictions, encoded_captions, decode_lengths, alphas, sort_ind



class Seq2SeqModel(nn.Module):
    """Chains an image encoder into an attention decoder and returns the decoder's outputs."""

    def __init__(self, encoder, decoder):
        super(Seq2SeqModel, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, images, captions, caption_lengths):
        """Encode ``images``, then decode ``captions`` with teacher forcing.

        Returns the decoder's 5-tuple
        ``(predictions, encoded_captions, decode_lengths, alphas, sort_ind)``.
        """
        (predictions, encoded_captions, decode_lengths,
         alphas, sort_ind) = self.decoder(self.encoder(images), captions, caption_lengths)
        return predictions, encoded_captions, decode_lengths, alphas, sort_ind


In [None]:
import os

# Write the training configuration read by ClassifierTrainer.
config_path = os.path.join('configs', 'training_config.yaml')
config_dir, config_filename = os.path.split(config_path)

# makedirs is a no-op when the directory already exists.
os.makedirs(config_dir, exist_ok=True)

# YAML content as a string
yaml_content = """# configs/training_config.yaml
model_type: 'resnet34'  # or 'seq2seq'

# Data configuration
data:
  images_dir: 'data/raw/images'
  train_csv: 'data/splits/train.csv'
  val_csv: 'data/splits/val.csv'
  test_csv: 'data/splits/test.csv'
  image_size: 224
  num_workers: 4

# Model configuration
model:
  num_classes: 52
  pretrained: true
  dropout: 0.5

# Training hyperparameters
training:
  batch_size: 64
  epochs: 100 # Increased epochs
  learning_rate: 0.001
  weight_decay: 0.0001
  optimizer: 'adam'  # or 'sgd'
  scheduler: 'reduce_on_plateau'
  patience: 10 # Increased patience

# Seq2Seq specific
seq2seq:
  encoder_lr: 0.0001
  decoder_lr: 0.0004
  attention_dim: 512
  embed_dim: 512
  decoder_dim: 512
  grad_clip: 5.0
  alpha_c: 1.0  # regularization for doubly stochastic attention
  beam_width: 5
"""

with open(config_path, 'w') as config_file:
    config_file.write(yaml_content)

print(f"Created {config_path}")

In [None]:
import pandas as pd
# from src.data.dataset import create_data_splits, save_splits # Assuming AtlasDataset and other related functions are in src.data.dataset

# Load the full dataset unless an earlier cell already produced `df`.
# NOTE(review): reusing a leftover kernel variable is hidden state; on a fresh
# kernel this falls back to the raw CSV.
try:
    df
except NameError:
    print("Loading data from data/raw/atlas_data.csv...")
    df = pd.read_csv('data/raw/atlas_data.csv')

# df usually already has 'category_path' from the analysis cell
if 'category_path' not in df.columns:
    print("Creating 'category_path' column...")
    df['category_path'] = df['gender'] + ' > ' + df['category'] + ' > ' + df['sub-category']

# create_data_splits stratifies twice. First it splits off the test set (30% by
# default), then it splits the validation set out of the remaining train+val pool.
# sklearn needs >= 2 members per class in *each* stratified split and rounds
# per-class allocations to floor/ceil. So a 2-sample class can leave only one
# sample in the pool, which makes the second split raise. With >= 3 samples and
# the default 30% test fraction, the pool keeps at least 2.
MIN_SAMPLES_PER_CLASS = 3

category_counts = df['category_path'].value_counts()
rare_categories = category_counts[category_counts < MIN_SAMPLES_PER_CLASS].index

# Rows without a category path can be neither stratified nor labelled.
keep_mask = df['category_path'].notna() & ~df['category_path'].isin(rare_categories)
df_filtered = df[keep_mask].copy()

print(f"Original dataset size: {len(df)}")
print(f"Number of categories with fewer than {MIN_SAMPLES_PER_CLASS} samples: {len(rare_categories)}")
print(f"Rows without a category path: {int(df['category_path'].isna().sum())}")
print(f"Filtered dataset size: {len(df_filtered)}")

# Use the filtered dataframe for splitting
df_to_split = df_filtered

train_df, val_df, test_df = create_data_splits(df_to_split)

save_splits(train_df, val_df, test_df)

print("\nData splits created and saved successfully.")

In [None]:
import os

# Peek at the images directory: first 10 entries plus a total count.
images_dir = 'data/raw/images'
if not os.path.exists(images_dir):
    print(f"Images directory not found at {images_dir}")
else:
    print(f"Listing files in {images_dir}:")
    files = os.listdir(images_dir)
    preview = files[:10]
    for file_name in preview:
        print(file_name)
    if len(files) > len(preview):
        print("...")
    print(f"Total files found: {len(files)}")

In [None]:
import pandas as pd

# Reload the training split only when this kernel has no `train_df` yet.
if 'train_df' not in globals():
    print("Loading train_df from data/splits/train.csv...")
    train_df = pd.read_csv('data/splits/train.csv')

print("\nFirst 10 item_ID values in train_df:")
first_ids = train_df['item_ID'].iloc[:10].tolist()
print(first_ids)

**Training Script for ResNet34 Classifier**

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import yaml
from pathlib import Path

class ClassifierTrainer:
    """Train ``ResNet34Classifier`` on the Atlas splits described by a YAML config.

    Schedule: the backbone starts frozen (head-only training). After epoch
    ``training.unfreeze_epoch`` (0-based, default 10) the backbone is unfrozen,
    and the optimizer *and* the LR scheduler are rebuilt at one tenth of the base
    learning rate. The checkpoint with the lowest validation loss is written to
    ``models/checkpoints/best_resnet34.pth``.
    """

    def __init__(self, config_path='configs/training_config.yaml'):
        """Load the YAML config, pick a device, and build data, model and optimizer."""
        with open(config_path, 'r') as f:
            self.config = yaml.safe_load(f)

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.setup_data()
        self.setup_model()
        self.setup_training()

    def setup_data(self):
        """Build the train/val datasets and DataLoaders.

        The validation dataset reuses the training set's class -> index mapping.
        AtlasDataset derives classes from the frame it is given, so a validation
        split missing any class would otherwise assign different indices to the
        same label, and accuracy/loss would silently be wrong.
        """
        data_cfg = self.config['data']
        batch_size = self.config['training']['batch_size']

        train_df = pd.read_csv(data_cfg['train_csv'])
        val_df = pd.read_csv(data_cfg['val_csv'])

        image_size = data_cfg['image_size']
        augmentation = AtlasAugmentation(target_size=(image_size, image_size))

        train_dataset = AtlasDataset(
            train_df,
            data_cfg['images_dir'],
            transform=augmentation.get_train_transform()
        )

        # A val label never seen in training has no output index. Drop such rows
        # with a warning instead of hitting a KeyError inside a DataLoader worker.
        known_label = val_df['category_path'].isin(train_dataset.classes)
        if not known_label.all():
            print(f"Warning: dropping {int((~known_label).sum())} val rows "
                  f"whose category is absent from the training split")
            val_df = val_df[known_label]

        val_dataset = AtlasDataset(
            val_df,
            data_cfg['images_dir'],
            transform=augmentation.get_val_transform()
        )
        # Replace the mapping derived from val_df with the training one.
        val_dataset.classes = train_dataset.classes
        val_dataset.class_to_idx = dict(train_dataset.class_to_idx)
        val_dataset.idx_to_class = dict(train_dataset.idx_to_class)

        self.class_names = train_dataset.classes
        # Historical attribute: despite its name it holds the class-name list.
        self.num_classes = self.class_names
        print(f"Number of classes: {len(self.class_names)}")

        loader_kwargs = dict(
            batch_size=batch_size,
            num_workers=data_cfg['num_workers'],
            # Pinned memory only speeds up host-to-GPU copies.
            pin_memory=self.device.type == 'cuda',
        )

        self.train_loader = DataLoader(
            train_dataset,
            shuffle=True,
            # The classifier head has BatchNorm1d, which raises on a training batch
            # of one sample. Drop the ragged last batch whenever a full batch exists.
            drop_last=len(train_dataset) >= batch_size,
            **loader_kwargs
        )

        self.val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    def setup_model(self):
        """Create the classifier sized to the training classes and freeze its backbone."""
        self.model = ResNet34Classifier(
            num_classes=len(self.class_names),
            pretrained=self.config['model']['pretrained'],
            dropout=self.config['model']['dropout']
        ).to(self.device)

        # Freeze backbone initially for transfer learning
        self.model.freeze_backbone()

    def _build_optimizer(self, lr):
        """Create the configured optimizer ('adam', otherwise SGD) over trainable parameters."""
        train_cfg = self.config['training']
        params = [p for p in self.model.parameters() if p.requires_grad]
        if train_cfg['optimizer'] == 'adam':
            return optim.Adam(params, lr=lr, weight_decay=train_cfg['weight_decay'])
        return optim.SGD(params, lr=lr, momentum=0.9, weight_decay=train_cfg['weight_decay'])

    def _build_scheduler(self):
        """Create a ReduceLROnPlateau (validation loss, halve LR) bound to ``self.optimizer``."""
        return optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            factor=0.5,
            patience=self.config['training']['patience'],
        )

    def setup_training(self):
        """Create loss, optimizer, scheduler and the best-loss / patience trackers."""
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = self._build_optimizer(self.config['training']['learning_rate'])
        self.scheduler = self._build_scheduler()

        self.best_val_loss = float('inf')
        self.patience_counter = 0

    def _unfreeze_for_fine_tuning(self):
        """Unfreeze the backbone and rebuild the optimizer and scheduler at lr / 10.

        The scheduler must be rebuilt as well: ReduceLROnPlateau keeps a reference
        to the optimizer it was created with. Keeping the old scheduler would keep
        adjusting the discarded optimizer's learning rate. The configured optimizer
        type and weight decay are kept.
        """
        self.model.unfreeze_backbone()
        self.optimizer = self._build_optimizer(self.config['training']['learning_rate'] / 10)
        self.scheduler = self._build_scheduler()

    def train_epoch(self, epoch):
        """Run one training epoch and return ``(mean loss, accuracy %)``."""
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        pbar = tqdm(self.train_loader, desc=f'Epoch {epoch+1} [Train]')
        for images, labels in pbar:
            images = images.to(self.device)
            labels = labels.to(self.device)

            # Forward pass
            self.optimizer.zero_grad()
            outputs = self.model(images)
            loss = self.criterion(outputs, labels)

            # Backward pass
            loss.backward()
            self.optimizer.step()

            # Statistics; the loss is weighted by batch size for a per-sample mean
            running_loss += loss.item() * images.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'acc': f'{100.*correct/total:.2f}%'
            })

        epoch_loss = running_loss / total
        epoch_acc = 100. * correct / total

        return epoch_loss, epoch_acc

    def validate(self, epoch):
        """Evaluate on the validation loader without gradients; return ``(mean loss, accuracy %)``."""
        self.model.eval()
        running_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            pbar = tqdm(self.val_loader, desc=f'Epoch {epoch+1} [Val]')
            for images, labels in pbar:
                images = images.to(self.device)
                labels = labels.to(self.device)

                outputs = self.model(images)
                loss = self.criterion(outputs, labels)

                running_loss += loss.item() * images.size(0)
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

                pbar.set_postfix({
                    'loss': f'{loss.item():.4f}',
                    'acc': f'{100.*correct/total:.2f}%'
                })

        epoch_loss = running_loss / total
        epoch_acc = 100. * correct / total

        return epoch_loss, epoch_acc

    def train(self):
        """Run the full loop: LR scheduling, best-checkpointing, early stopping, staged unfreezing."""
        train_cfg = self.config['training']
        epochs = train_cfg['epochs']
        unfreeze_epoch = train_cfg.get('unfreeze_epoch', 10)

        print(f"Training on {self.device}")
        print(f"Total epochs: {epochs}")

        checkpoint_dir = Path('models/checkpoints')
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        for epoch in range(epochs):
            train_loss, train_acc = self.train_epoch(epoch)
            val_loss, val_acc = self.validate(epoch)

            print(f"\nEpoch {epoch+1}/{epochs}")
            print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
            print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

            # Learning rate scheduling
            self.scheduler.step(val_loss)

            # Save best model
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                self.patience_counter = 0
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'val_loss': val_loss,
                    'val_acc': val_acc,
                    # Class order defines the output indices; needed to decode predictions.
                    'classes': list(self.class_names),
                }, checkpoint_dir / 'best_resnet34.pth')
                print("✓ Saved best model")
            else:
                self.patience_counter += 1

            # Early stopping after twice the scheduler's patience without improvement
            if self.patience_counter >= train_cfg['patience'] * 2:
                print(f"\nEarly stopping triggered after {epoch+1} epochs")
                break

            # Unfreeze backbone after initial epochs
            if epoch == unfreeze_epoch:
                print("\nUnfreezing backbone for fine-tuning...")
                self._unfreeze_for_fine_tuning()

        print("\n✓ Training completed!")

# Usage: train with the config file written by the configuration cell above.
if __name__ == "__main__":
    trainer = ClassifierTrainer(config_path='configs/training_config.yaml')
    trainer.train()

**Model Evaluation**

In [None]:

import torch
import numpy as np
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    f1_score,
    precision_recall_fscore_support
)
import matplotlib.pyplot as plt
import seaborn as sns
import os # Import os module

class ModelEvaluator:
    """Evaluate a trained classifier on a held-out loader and report metrics.

    Produces overall and per-class metrics, a confusion-matrix heatmap and a
    text classification report. Integer label ``i`` is treated as
    ``class_names[i]`` (the names are used directly as tick labels and report
    target names) -- assumes the dataset encodes labels that way; TODO confirm.
    """

    def __init__(self, model, test_loader, class_names, device):
        """
        Args:
            model: torch module returning one logit per class for a batch of images.
            test_loader: iterable yielding ``(images, labels)`` batches.
            class_names: sequence of display names, one per class index.
            device: torch device the images are moved to before inference.
        """
        self.model = model
        self.test_loader = test_loader
        self.class_names = class_names
        self.device = device

    @staticmethod
    def _ensure_parent_dir(save_path):
        """Create the directory that will contain ``save_path``, if it has one.

        ``os.path.dirname`` returns ``''`` for a bare filename and
        ``os.makedirs('')`` raises ``FileNotFoundError``, so that case is skipped.
        """
        output_dir = os.path.dirname(save_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

    def _label_indices(self):
        """Return every class index ``0 .. len(class_names) - 1``.

        Passed to sklearn as ``labels=`` so per-class outputs always contain one
        entry per class name, even when a class never appears in the labels or
        the predictions of this split.
        """
        return np.arange(len(self.class_names))

    def predict(self):
        """Run the model over ``test_loader`` in eval mode without gradients.

        Returns:
            ``(preds, labels, probs)`` as numpy arrays: the argmax class index per
            sample, the ground-truth labels as yielded by the loader, and the
            softmax probabilities over dimension 1 of the model output.
        """
        self.model.eval()
        all_preds = []
        all_labels = []
        all_probs = []

        with torch.no_grad():
            pbar = tqdm(self.test_loader, desc='Predicting')
            for images, labels in pbar:
                images = images.to(self.device)
                outputs = self.model(images)
                probs = torch.softmax(outputs, dim=1)
                # argmax of the logits equals argmax of the softmax probabilities
                _, preds = outputs.max(1)

                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
                all_probs.extend(probs.cpu().numpy())

        return np.array(all_preds), np.array(all_labels), np.array(all_probs)

    def calculate_metrics(self, preds, labels):
        """Compute, print and return overall and per-class metrics.

        Args:
            preds: predicted class indices, one per sample.
            labels: ground-truth class indices, aligned with ``preds``.

        Returns:
            dict with ``accuracy``, ``micro_f1``, ``macro_f1`` and per-class
            ``precision``/``recall``/``f1``/``support`` arrays whose position ``i``
            corresponds to ``class_names[i]``.
        """
        # Overall metrics
        accuracy = (preds == labels).mean()

        # Per-class metrics. The keyword ``labels=`` is sklearn's list of class
        # indices to report; the positional ``labels`` is y_true.
        precision, recall, f1, support = precision_recall_fscore_support(
            labels, preds, labels=self._label_indices(), average=None, zero_division=0
        )

        # Micro and macro averages (macro is taken over the classes present in
        # labels or preds, sklearn's default)
        micro_f1 = f1_score(labels, preds, average='micro', zero_division=0)
        macro_f1 = f1_score(labels, preds, average='macro', zero_division=0)

        print(f"\n=== Overall Metrics ===")
        print(f"Accuracy: {accuracy*100:.2f}%")
        print(f"Micro F1-Score: {micro_f1:.4f}")
        print(f"Macro F1-Score: {macro_f1:.4f}")

        return {
            'accuracy': accuracy,
            'micro_f1': micro_f1,
            'macro_f1': macro_f1,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'support': support
        }

    def plot_confusion_matrix(self, preds, labels, save_path='results/confusion_matrix.png'):
        """Plot the confusion matrix as a heatmap, save it to ``save_path`` and show it.

        Rows are true labels and columns are predicted labels. The matrix always
        has ``len(class_names)`` rows and columns, so tick labels stay aligned
        even when some classes are absent from this split.
        """
        self._ensure_parent_dir(save_path)

        cm = confusion_matrix(labels, preds, labels=self._label_indices())

        plt.figure(figsize=(20, 18))
        sns.heatmap(
            cm,
            annot=False,
            fmt='d',
            cmap='Blues',
            xticklabels=self.class_names,
            yticklabels=self.class_names,
            cbar_kws={'label': 'Count'}
        )
        plt.title('Confusion Matrix - Atlas Clothing Classifier', fontsize=16, fontweight='bold')
        plt.ylabel('True Label', fontsize=14)
        plt.xlabel('Predicted Label', fontsize=14)
        plt.xticks(rotation=90, fontsize=8)
        plt.yticks(rotation=0, fontsize=8)
        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()

        print(f"✓ Confusion matrix saved to {save_path}")

    def generate_classification_report(self, preds, labels, save_path='results/classification_report.txt'):
        """Print sklearn's per-class classification report and write it to ``save_path``.

        ``labels=`` is passed explicitly. Without it sklearn raises ``ValueError``
        when the number of classes observed in this split differs from
        ``len(class_names)``.
        """
        self._ensure_parent_dir(save_path)

        report = classification_report(
            labels,
            preds,
            labels=self._label_indices(),
            target_names=self.class_names,
            digits=4,
            zero_division=0,
        )

        print("\n=== Classification Report ===")
        print(report)

        # Save to file
        with open(save_path, 'w', encoding='utf-8') as f:
            f.write(report)

        print(f"✓ Classification report saved to {save_path}")

    def evaluate(self):
        """Run prediction, metrics, confusion matrix and report on the test loader.

        Returns:
            ``(metrics, preds, labels, probs)`` -- the dict from
            ``calculate_metrics`` plus the arrays from ``predict``.
        """
        print("Evaluating model on test set...")

        preds, labels, probs = self.predict()
        metrics = self.calculate_metrics(preds, labels)
        self.plot_confusion_matrix(preds, labels)
        self.generate_classification_report(preds, labels)

        return metrics, preds, labels, probs

# Usage
# Imported here so the cell also runs on a fresh kernel.
from torch.utils.data import DataLoader

test_df = pd.read_csv('data/splits/test.csv')

# Ensure AtlasAugmentation is available and get the validation transform.
# exit() inside a Jupyter kernel shuts the kernel down, so fail the cell instead.
try:
    augmentation = AtlasAugmentation()
except NameError as err:
    raise RuntimeError(
        "AtlasAugmentation class not found. Please ensure the cell defining it is executed."
    ) from err
val_transform = augmentation.get_val_transform()

test_dataset = AtlasDataset(test_df, 'data/raw/images', transform=val_transform)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Resolve the device before loading. map_location lets a checkpoint saved from
# GPU tensors load on a CPU-only runtime.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load best model
checkpoint = torch.load('models/checkpoints/best_resnet34.pth', map_location=device)
# NOTE(review): num_classes is taken from the test split; it must match the class
# count used at training time or load_state_dict will reject the head -- confirm.
model = ResNet34Classifier(num_classes=len(test_dataset.classes))
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)

evaluator = ModelEvaluator(model, test_loader, test_dataset.classes, device)
metrics, preds, labels, probs = evaluator.evaluate()