# ImageNet Training with ResNet-50

This notebook provides a complete pipeline for training ResNet-50 on the ImageNet-1k dataset, with options for:
- Training on a small subset for quick experiments
- Finding optimal learning rate using LR Finder
- Full dataset training
- Using pretrained weights
- Replacing MaxPool with strided convolution

## 1. Setup Environment

In [None]:
# Detect whether this kernel is hosted on Google Colab
import sys

# Colab pre-loads the `google.colab` package into the interpreter
IN_COLAB = 'google.colab' in sys.modules

if not IN_COLAB:
    print("Running locally")
else:
    print("Running on Google Colab")
    # Mount Google Drive at /content/drive
    from google.colab import drive
    drive.mount('/content/drive')

In [None]:
# Install required packages
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install matplotlib tqdm scipy

In [None]:
# Clone or setup the repository
import os

if IN_COLAB:
    # Option 1: Clone from GitHub (if you have uploaded the code)
    # !git clone https://github.com/yourusername/S9_Assignment.git
    # os.chdir('S9_Assignment')

    # Option 2: Copy from Google Drive (if you have uploaded the code)
    # !cp -r /content/drive/MyDrive/S9_Assignment .
    # os.chdir('S9_Assignment')

    # Option 3: Upload the files directly
    print("Please upload the S9_Assignment folder to Colab or Google Drive")

# Add the parent directory to the import path.
# The membership check keeps the cell idempotent: re-running it no longer
# appends a duplicate '..' entry to sys.path each time.
if '..' not in sys.path:
    sys.path.append('..')

## 2. Import Required Modules

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR, CosineAnnealingLR
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import json

# Import custom modules
from models.resnet50_imagenet import resnet50
from dataset.imagenet_loader import create_imagenet_loaders
from utils.lr_finder import LRFinder
from utils.train_test import train_epoch, test_epoch

# Pick the compute device: CUDA when a GPU is visible, otherwise the CPU
cuda_ok = torch.cuda.is_available()
device = torch.device('cuda' if cuda_ok else 'cpu')
print(f"Using device: {device}")

# Report GPU details only when CUDA is actually usable
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_total = torch.cuda.device_count()
    print(f"GPU: {gpu_name}")
    print(f"Number of GPUs: {gpu_total}")

## 3. Configuration

In [None]:
# Training configuration: every tunable knob of the run lives in this dict.
config = dict(
    # --- Dataset ---
    dataset_type='small',  # Options: 'small', 'medium', 'full', 'tiny_imagenet'
    data_dir='/content/imagenet',  # Update this path
    batch_size=128,  # Reduce if GPU memory is limited
    num_workers=4,

    # --- Model ---
    pretrained=False,  # Use pretrained weights
    replace_maxpool_with_conv=True,  # Replace MaxPool with Conv (default: True)

    # --- Training ---
    epochs=10,  # Increase for full training
    learning_rate=0.1,
    momentum=0.9,
    weight_decay=1e-4,
    scheduler='onecycle',  # Options: 'onecycle', 'cosine', 'step', None

    # --- LR Finder ---
    find_lr=True,  # Run LR finder before training
    lr_finder_iterations=100,

    # --- Output paths ---
    checkpoint_dir='./checkpoints',
    log_dir='./logs',
    plot_dir='./plots',
)

# Dataset presets as (name, fraction of the data to use, is Tiny ImageNet).
# 0.01 -> 1% of ImageNet, 0.1 -> 10%, None -> no subsetting.
dataset_configs = {
    preset: {'subset_percent': fraction, 'tiny_imagenet': is_tiny}
    for preset, fraction, is_tiny in (
        ('small', 0.01, False),
        ('medium', 0.1, False),
        ('full', None, False),
        ('tiny_imagenet', None, True),
    )
}

# Merge the preset selected by `dataset_type` into the main config
dataset_config = dataset_configs[config['dataset_type']]
config.update(dataset_config)

print("Configuration:")
for name in config:
    print(f"{name}: {config[name]}")

## 4. Download ImageNet Data (Instructions)

### Option 1: Tiny ImageNet (Recommended for testing)
Tiny ImageNet is a subset with 200 classes and smaller images (64x64).

In [None]:
# Download Tiny ImageNet (only if using tiny_imagenet)
if config.get('tiny_imagenet', False):
    !wget http://cs231n.stanford.edu/tiny-imagenet-200.zip
    !unzip -q tiny-imagenet-200.zip
    !mv tiny-imagenet-200 /content/tiny-imagenet
    config['data_dir'] = '/content/tiny-imagenet'

### Option 2: Full ImageNet-1k

For full ImageNet-1k dataset:

1. **Register and download from official source**: https://image-net.org/download.php
2. **Required files**:
   - Training images: `ILSVRC2012_img_train.tar` (~138GB)
   - Validation images: `ILSVRC2012_img_val.tar` (~6.3GB)
   - Development kit: `ILSVRC2012_devkit_t12.tar.gz`

3. **Extract the data**:

In [None]:
# Commands to extract ImageNet (run in terminal or adapt for Colab)
# Note: This requires significant storage space (~150GB)

# # Create directories
# !mkdir -p /content/imagenet/train /content/imagenet/val

# # Extract training data
# !tar -xf ILSVRC2012_img_train.tar -C /content/imagenet/train/
# !cd /content/imagenet/train && for f in *.tar; do mkdir -p "${f%.tar}" && tar -xf "$f" -C "${f%.tar}" && rm "$f"; done

# # Extract validation data
# !tar -xf ILSVRC2012_img_val.tar -C /content/imagenet/val/
# # Use the validation ground truth to organize val images into folders
# # You'll need the ILSVRC2012_validation_ground_truth.txt file

### Option 3: Use a subset for quick experiments

If you have limited resources, you can:
1. Use Tiny ImageNet (recommended)
2. Use a small subset of ImageNet by setting `dataset_type` to `'small'` or `'medium'` in the config (this selects the `subset_percent` passed to the data loader)
3. Create your own small dataset with a few classes

## 5. Create Data Loaders

In [None]:
# Gather the loader arguments from the config, then build both loaders
loader_kwargs = {
    'data_dir': config['data_dir'],
    'batch_size': config['batch_size'],
    'num_workers': config['num_workers'],
    'subset_percent': config.get('subset_percent'),
    'tiny_imagenet': config.get('tiny_imagenet', False),
    'augment_train': True,
}
train_loader, val_loader, dataset_stats = create_imagenet_loaders(**loader_kwargs)

# Report whatever statistics the loader factory returned
print("\nDataset Statistics:")
for stat_name, stat_value in dataset_stats.items():
    print(f"{stat_name}: {stat_value}")

## 6. Create Model

In [None]:
# Build the network sized for the dataset's class count and place it on the
# selected device
model = resnet50(
    num_classes=dataset_stats['num_classes'],
    pretrained=config['pretrained'],
    replace_maxpool_with_conv=config['replace_maxpool_with_conv']
).to(device)

# Wrap in DataParallel when more than one GPU is visible
gpu_count = torch.cuda.device_count()
if gpu_count > 1:
    print(f"Using {gpu_count} GPUs")
    model = nn.DataParallel(model)

# Tally parameters in a single pass: all of them, and those that receive gradients
total_params = 0
trainable_params = 0
for param in model.parameters():
    param_size = param.numel()
    total_params += param_size
    if param.requires_grad:
        trainable_params += param_size

print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

## 7. Learning Rate Finder

In [None]:
# Optional learning-rate range test. When enabled, the suggested LR replaces
# config['learning_rate'] and is also stored as config['max_lr'].
if config['find_lr']:
    print("Running LR Finder...")

    # Single source of truth for the sweep's starting LR (used by both the
    # temporary optimizer and the range test)
    finder_start_lr = 1e-7

    # Create criterion and temporary optimizer.
    # Momentum comes from the config so the sweep uses the same value as the
    # optimizer built for the real training run.
    criterion = nn.CrossEntropyLoss()
    temp_optimizer = optim.SGD(
        model.parameters(),
        lr=finder_start_lr,
        momentum=config['momentum']
    )

    # Create LR finder
    lr_finder = LRFinder(model, temp_optimizer, criterion, device)

    # Run range test: exponential sweep from finder_start_lr up to 10
    lr_finder.range_test(
        train_loader,
        start_lr=finder_start_lr,
        end_lr=10,
        num_iter=config['lr_finder_iterations'],
        step_mode='exp'
    )

    # Plot and find optimal LR
    suggested_lr, min_loss_lr = lr_finder.plot_with_suggestion()

    print(f"\nSuggested LR: {suggested_lr:.2e}")
    print(f"Min Loss LR: {min_loss_lr:.2e}")

    # Update config with suggested LR.
    # Coerce to a builtin float: config is later written with json.dump and
    # stored in the checkpoint, so it should only hold plain Python values.
    suggested_lr = float(suggested_lr)
    config['learning_rate'] = suggested_lr
    config['max_lr'] = suggested_lr

    # Reset model
    lr_finder.reset()

## 8. Setup Training

In [None]:
# Loss function
criterion = nn.CrossEntropyLoss()

# Optimizer
optimizer = optim.SGD(
    model.parameters(),
    lr=config['learning_rate'],
    momentum=config['momentum'],
    weight_decay=config['weight_decay']
)

# Learning rate scheduler.
# Supported values of config['scheduler']: 'onecycle', 'cosine', 'step', None.
scheduler_name = config['scheduler']
scheduler = None
if scheduler_name == 'onecycle':
    # Configured per optimizer step: total steps = epochs * batches per epoch
    scheduler = OneCycleLR(
        optimizer,
        max_lr=config.get('max_lr', config['learning_rate']),
        epochs=config['epochs'],
        steps_per_epoch=len(train_loader),
        pct_start=0.3,
        anneal_strategy='cos',
        div_factor=25.0,
        final_div_factor=10000.0
    )
elif scheduler_name == 'cosine':
    scheduler = CosineAnnealingLR(
        optimizer,
        T_max=config['epochs'],
        eta_min=1e-6
    )
elif scheduler_name == 'step':
    # The training loop steps non-OneCycle schedulers once per epoch, so
    # step_size is measured in epochs. Defaults (overridable through optional
    # 'step_size' / 'gamma' config keys): decay by 10x every third of the run.
    scheduler = optim.lr_scheduler.StepLR(
        optimizer,
        step_size=config.get('step_size', max(1, config['epochs'] // 3)),
        gamma=config.get('gamma', 0.1)
    )
elif scheduler_name is not None:
    # Surface typos instead of silently training with a constant LR
    print(f"Warning: unknown scheduler {scheduler_name!r}; training with a constant learning rate")

print(f"Optimizer: {optimizer.__class__.__name__}")
print(f"Scheduler: {scheduler.__class__.__name__ if scheduler else 'None'}")
print(f"Initial LR: {optimizer.param_groups[0]['lr']:.2e}")

## 9. Training Loop

In [None]:
# Training history: one entry per epoch is appended to each list
history = {
    'train_loss': [],
    'train_acc': [],
    'val_loss': [],
    'val_acc': [],
    'val_acc_top5': [],
    'lr': []
}

# Best model tracking.
# Start below any reachable accuracy so the first epoch always counts as an
# improvement and writes a checkpoint. Starting at 0 with a strict `>`
# comparison would never save a model whose validation accuracy stays at 0%.
best_val_acc = float('-inf')
best_epoch = 0

# Create output directories (no error if they already exist)
Path(config['checkpoint_dir']).mkdir(parents=True, exist_ok=True)
Path(config['log_dir']).mkdir(parents=True, exist_ok=True)
Path(config['plot_dir']).mkdir(parents=True, exist_ok=True)

In [None]:
# Main training loop: train, validate, step the scheduler, log, checkpoint
print("\nStarting Training...")
print("=" * 50)

total_epochs = config['epochs']
for epoch in range(total_epochs):
    # OneCycle is stepped inside train_epoch; other schedulers step per epoch below
    uses_onecycle = config['scheduler'] == 'onecycle'
    batch_scheduler = scheduler if uses_onecycle else None

    # Training phase
    train_loss, train_acc = train_epoch(
        model, device, train_loader, optimizer, criterion,
        batch_scheduler,
        epoch, accumulation_steps=1, clip_grad_norm=None, verbose=True
    )

    # Validation phase
    val_loss, val_acc, val_acc_top5, _ = test_epoch(
        model, device, val_loader, criterion,
        epoch, verbose=True, calc_top5=True
    )

    # Per-epoch scheduler update (everything except OneCycle)
    if scheduler and not uses_onecycle:
        scheduler.step()

    # LR currently set on the first parameter group
    current_lr = optimizer.param_groups[0]['lr']

    # Record this epoch's metrics
    epoch_metrics = (
        ('train_loss', train_loss),
        ('train_acc', train_acc),
        ('val_loss', val_loss),
        ('val_acc', val_acc),
        ('val_acc_top5', val_acc_top5),
        ('lr', current_lr),
    )
    for metric_name, metric_value in epoch_metrics:
        history[metric_name].append(metric_value)

    # Checkpoint whenever validation accuracy improves
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_epoch = epoch
        best_checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'best_val_acc': best_val_acc,
            'config': config
        }
        torch.save(best_checkpoint, f"{config['checkpoint_dir']}/best_model.pth")

    # Epoch summary
    print(f"\nEpoch {epoch+1}/{total_epochs} Summary:")
    print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
    print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
    print(f"  Val Top-5 Acc: {val_acc_top5:.2f}%")
    print(f"  Learning Rate: {current_lr:.2e}")
    print(f"  Best Val Acc: {best_val_acc:.2f}% (Epoch {best_epoch+1})")
    print("=" * 50)

print("\nTraining Completed!")
print(f"Best Validation Accuracy: {best_val_acc:.2f}% at Epoch {best_epoch+1}")

## 10. Plot Training Progress

In [None]:
# 2x2 dashboard: loss, accuracy, learning rate, and a text summary
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
(loss_ax, acc_ax), (lr_ax, summary_ax) = axes

epochs = range(1, len(history['train_loss']) + 1)

# Top-left: train vs. validation loss
for series_key, line_style, series_label in (
    ('train_loss', 'b-', 'Train'),
    ('val_loss', 'r-', 'Val'),
):
    loss_ax.plot(epochs, history[series_key], line_style, label=series_label)
loss_ax.set(title='Loss', xlabel='Epoch', ylabel='Loss')
loss_ax.legend()
loss_ax.grid(True)

# Top-right: train accuracy plus validation top-1 / top-5
for series_key, line_style, series_label in (
    ('train_acc', 'b-', 'Train'),
    ('val_acc', 'r-', 'Val Top-1'),
    ('val_acc_top5', 'g-', 'Val Top-5'),
):
    acc_ax.plot(epochs, history[series_key], line_style, label=series_label)
acc_ax.set(title='Accuracy', xlabel='Epoch', ylabel='Accuracy (%)')
acc_ax.legend()
acc_ax.grid(True)

# Bottom-left: learning rate recorded at the end of each epoch (log scale)
lr_ax.plot(epochs, history['lr'], 'orange')
lr_ax.set(title='Learning Rate', xlabel='Epoch', ylabel='LR')
lr_ax.set_yscale('log')
lr_ax.grid(True)

# Bottom-right: headline numbers as monospace text on a blank panel
summary_ax.axis('off')
summary_text = f"""Training Summary:

Dataset: {config['dataset_type']}
Epochs: {config['epochs']}
Batch Size: {config['batch_size']}

Best Val Acc: {best_val_acc:.2f}%
Best Epoch: {best_epoch + 1}

Final Train Acc: {history['train_acc'][-1]:.2f}%
Final Val Acc: {history['val_acc'][-1]:.2f}%
Final Val Top-5: {history['val_acc_top5'][-1]:.2f}%
"""
summary_ax.text(
    0.1, 0.5, summary_text,
    fontsize=12,
    verticalalignment='center',
    fontfamily='monospace',
)

plt.suptitle('Training Progress', fontsize=16)
plt.tight_layout()
plt.savefig(f"{config['plot_dir']}/training_curves.png", dpi=100)
plt.show()

## 11. Save Training History

In [None]:
# Persist the per-epoch metrics as JSON in the log directory
history_path = "{}/training_history.json".format(config['log_dir'])
with open(history_path, 'w') as history_file:
    json.dump(history, history_file, indent=4)
print(f"Training history saved to {history_path}")

# Persist the run configuration next to the history
config_path = "{}/config.json".format(config['log_dir'])
with open(config_path, 'w') as config_file:
    json.dump(config, config_file, indent=4)
print(f"Configuration saved to {config_path}")

## 12. Test the Best Model

In [None]:
# Load best model.
# map_location remaps the stored tensors onto the current device, so a
# checkpoint saved on a GPU can still be loaded in a CPU-only session.
checkpoint = torch.load(
    f"{config['checkpoint_dir']}/best_model.pth",
    map_location=device
)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"Loaded best model from epoch {checkpoint['epoch'] + 1}")

# Final evaluation of the restored weights on the validation set
model.eval()
val_loss, val_acc, val_acc_top5, _ = test_epoch(
    model, device, val_loader, criterion,
    epoch=checkpoint['epoch'], verbose=True, calc_top5=True
)

print("\nBest Model Performance:")
print(f"  Validation Loss: {val_loss:.4f}")
print(f"  Validation Top-1 Accuracy: {val_acc:.2f}%")
print(f"  Validation Top-5 Accuracy: {val_acc_top5:.2f}%")

## 13. Inference Example

In [None]:
# Get a batch of validation images
model.eval()
images, labels = next(iter(val_loader))

# Show at most 8 samples; a batch smaller than 8 would otherwise raise an
# IndexError in the display loop below.
num_samples = min(8, images.size(0))
images, labels = images[:num_samples].to(device), labels[:num_samples].to(device)

# Make predictions
with torch.no_grad():
    outputs = model(images)
    # topk raises if k exceeds the number of classes, so cap it
    top_k = min(5, outputs.size(1))
    _, predicted = outputs.topk(top_k, 1, largest=True, sorted=True)

# Per-channel statistics used to undo normalization for display; built once
# rather than on every loop iteration.
# NOTE(review): these are the standard ImageNet mean/std — confirm they match
# the normalization applied inside create_imagenet_loaders.
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

# Display results
fig, axes = plt.subplots(2, 4, figsize=(12, 6))
axes = axes.ravel()

for i, ax in enumerate(axes):
    ax.axis('off')
    if i >= num_samples:
        # Leave unused panels blank when fewer than 8 images are available
        continue

    # Denormalize image for display
    img = images[i].cpu()
    img = img * std + mean
    img = torch.clamp(img, 0, 1)

    # Display image (channels-last for imshow) with true and top-1 predicted class index
    ax.imshow(img.permute(1, 2, 0))
    ax.set_title(f'True: {labels[i].item()}\nPred: {predicted[i, 0].item()}')

plt.suptitle('Sample Predictions', fontsize=16)
plt.tight_layout()
plt.show()