# ImageNet 1000-Class Training with ResNet50
## Target: 82%+ Validation Accuracy

This notebook trains ResNet50 on **full ImageNet-1K** (1000 classes) with:
- **1,281,167 training images** (732-1,300 per class, ~1,281 on average)
- **~50,000 validation images** (~50 per class)
- **Medium data augmentation** for stable training
- **Training from scratch** using LR finder + OneCycleLR
- **Batch size 256** optimized for AWS GPU instances

**Recommended Hardware (on-demand prices as quoted; they vary by region and over time):**
- AWS g5.xlarge: 1x NVIDIA A10G (24GB VRAM), 4 vCPUs - $1.006/hour
- AWS g5.2xlarge: 1x NVIDIA A10G (24GB VRAM), 8 vCPUs - $1.212/hour
- AWS p3.2xlarge: 1x NVIDIA V100 (16GB VRAM), 8 vCPUs - $3.06/hour (faster)

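**Expected Training Time (rough estimates; single-GPU ImageNet training is often limited by CPU-side image decoding, so time the first epoch before relying on these):**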
- g5.xlarge: ~8-10 min/epoch × 90 epochs = ~12-15 hours
- p3.2xlarge: ~5-6 min/epoch × 90 epochs = ~7-9 hours

**Expected Cost:**
- g5.xlarge: ~$12-15 (on-demand) or ~$4-5 (spot)
- p3.2xlarge: ~$21-27 (on-demand) or ~$6-8 (spot)
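The time and cost figures above are simple products of minutes per epoch, epoch count, and hourly price. A minimal sketch for redoing that arithmetic with your own measured epoch time or a current spot price (the function name is illustrative, and the price used is the one quoted above, which may be out of date):

```python
def estimate_run(minutes_per_epoch: float, num_epochs: int, price_per_hour: float) -> tuple[float, float]:
    """Return (hours, dollars) for a training run; every input is a user-supplied estimate."""
    hours = minutes_per_epoch * num_epochs / 60
    return hours, hours * price_per_hour


# g5.xlarge at the quoted on-demand price, for the 8 and 10 min/epoch bounds
for minutes in (8, 10):
    hours, cost = estimate_run(minutes, 90, 1.006)
    print(f"{minutes} min/epoch -> {hours:.1f} h, ${cost:.2f}")
```

Swapping in the spot price (or the epoch time measured after the first epoch of Cell 9) gives a more realistic budget than the table.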

## Cell 1: Imports

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import os

# Local imports
from model import create_resnet50, get_model_stats
from data_loader_full import get_full_dataloaders
from train import Trainer

print("✓ Imports successful!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"CUDA device count: {torch.cuda.device_count()}")

## Cell 2: Device Configuration

In [None]:
# Automatic device selection
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✓ Using CUDA: {torch.cuda.get_device_name(0)}")
    print(f"  CUDA Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = torch.device('mps')
    print("✓ Using Apple Metal Performance Shaders (MPS)")
    print("⚠️  Warning: MPS training will be very slow for 1000 classes!")
    print("⚠️  Recommended: use an AWS GPU instance instead.")
else:
    device = torch.device('cpu')
    print("⚠️  Using CPU (NOT RECOMMENDED for 1000-class training!)")
    print("⚠️  This will take days or weeks. Use an AWS GPU instance.")

print(f"Device: {device}")

## Cell 3: Training Configuration

**Optimized for 82%+ validation accuracy on 1000 classes (training from scratch)**

**Key differences from 100-class training:**
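- More classes (1000 vs 100) → harder task, lower attainable accuracy
- More data (1.28M vs 130k images) → each epoch takes roughly 10x longer
- Larger batch size (256 vs 128) → higher GPU utilization; half as many optimizer steps for the same number of images
- Medium augmentation (not heavy) → ~10x more real images already regularize the model, so heavy augmentation is less necessary and can slow convergence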
- More epochs may be needed (90-120) for convergence
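When the batch size changes, a common starting heuristic is the linear scaling rule from large-batch SGD work (Goyal et al., 2017): scale the learning rate by the same factor as the batch size. This notebook relies on the LR finder instead, so treat the sketch below only as a sanity check on its suggestion; the helper name and the example numbers are hypothetical.

```python
def scale_lr(base_lr: float, base_batch_size: int, new_batch_size: int) -> float:
    """Linear scaling heuristic: the learning rate grows in proportion to the batch size."""
    if base_batch_size <= 0 or new_batch_size <= 0:
        raise ValueError("batch sizes must be positive")
    return base_lr * new_batch_size / base_batch_size


# Hypothetical example: a max_lr that worked at batch 128, moved to batch 256
print(scale_lr(0.05, 128, 256))
```

If the LR finder's suggestion is far from this scaled value, it is worth re-checking the finder plot before committing to a 90-epoch run.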

In [None]:
config = {
    # Data Configuration
    'data_dir': './imagenet_1000class_data',  # CHANGED: Full ImageNet path
    'num_classes': 1000,  # CHANGED: Full 1000 classes
    
    # Training Configuration
    'num_epochs': 90,  # 90-120 epochs for 82%+ accuracy from scratch
    'batch_size': 256,  # CHANGED: Larger batch for AWS GPU; halve it if you hit CUDA out-of-memory errors (e.g. on a 16GB V100)
    'num_workers': 8,  # CHANGED: Roughly one worker per vCPU (g5.xlarge has only 4)
    'pin_memory': True,
    
    # Augmentation (CHANGED: medium for stability with 1000 classes)
    'augmentation_strength': 'medium',  # Medium augmentation for 1.28M samples
    
    # Learning Rate Configuration
    'find_lr': True,  # Auto-find optimal learning rate
    'initial_lr': 0.05,  # Reference only: OneCycleLR starts at max_lr / div_factor (0.3 / 25 = 0.012 here)
    'max_lr': 0.3,      # Max LR (will be overridden by LR finder)
    'lr_finder_iterations': 200,  # Iterations for LR finder
    
    # Regularization
    'weight_decay': 1e-4,  # L2 regularization
    'label_smoothing': 0.1,  # Label smoothing (0.1 is standard for ImageNet)
    'max_grad_norm': 1.0,  # Gradient clipping
    
    # Model Configuration
    'zero_init_residual': True,  # Zero-init residual connections
    
    # OneCycleLR Configuration (based on 100-class learnings)
    'pct_start': 0.3,  # 30% warmup (can adjust to 0.5-0.6 for longer warmup)
    'div_factor': 25.0,  # initial_lr = max_lr / 25
    'final_div_factor': 1e4,  # final_lr = initial_lr / 10000
    
    # Checkpoint Configuration
    'checkpoint_dir': './checkpoints_1000class',
    'save_frequency': 5,  # Save every 5 epochs
}

print("="*70)
print("TRAINING CONFIGURATION - ImageNet 1000 Classes")
print("="*70)
for key, value in config.items():
    print(f"{key:30s}: {value}")
print("="*70)
print("\n⚠️  IMPORTANT: This will train from scratch on 1.28M images!")
print("Expected time: 12-15 hours (g5.xlarge) or 7-9 hours (p3.2xlarge)")
print("Expected cost: $12-15 (g5.xlarge) or $21-27 (p3.2xlarge)")
print("="*70)

## Cell 4: Load Full ImageNet-1K Dataset

**⚠️  DATASET DOWNLOAD REQUIRED ⚠️**

This notebook assumes the full ImageNet-1K dataset is available at `./imagenet_1000class_data/`

**Option 1: Download using Hugging Face (RECOMMENDED for AWS)**
```python
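# Run this ONCE on AWS to download the full dataset (~150 GB; 2-3 hours depending on bandwidth).
# ILSVRC/imagenet-1k is a gated dataset: accept its terms on the Hugging Face Hub
# and authenticate (e.g. `huggingface-cli login`) before running this.
from datasets import load_dataset

train_ds = load_dataset("ILSVRC/imagenet-1k", split="train")
val_ds = load_dataset("ILSVRC/imagenet-1k", split="validation")

# These are Arrow-backed Dataset objects, not the folder layout shown below;
# follow the data_loader_full.py documentation for the complete setup.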
```

**Option 2: Download from Official ImageNet**
- Register at https://image-net.org/
- Download ILSVRC2012 training and validation sets
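- Extract to `./imagenet_1000class_data/train/` and `./imagenet_1000class_data/val/`
- The training archive contains one tar file per class; the validation archive unpacks to a single flat folder of 50,000 JPEGs, which must be sorted into per-class subfolders to match the structure below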
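Sorting the flat validation folder is a simple move-into-subfolder loop once you know each file's class. A minimal sketch, assuming you have already built a mapping from validation file name to synset ID (for example from the ILSVRC2012 devkit's ground-truth labels); the function name and mapping format are hypothetical:

```python
import shutil
from pathlib import Path


def sort_val_images(val_dir: str, file_to_synset: dict[str, str]) -> int:
    """Move flat validation JPEGs into val_dir/<synset>/ subfolders.

    file_to_synset maps file names (e.g. 'ILSVRC2012_val_00000001.JPEG')
    to synset IDs, which become the folder names (e.g. 'n01440764').
    Returns the number of files moved.
    """
    root = Path(val_dir)
    moved = 0
    for file_name, synset in file_to_synset.items():
        src = root / file_name
        if not src.is_file():
            continue  # already moved or missing: skip rather than fail
        dst_dir = root / synset
        dst_dir.mkdir(exist_ok=True)
        shutil.move(str(src), str(dst_dir / file_name))
        moved += 1
    return moved
```

Because already-moved files are skipped, the function can safely be re-run after an interruption.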

**Expected Structure:**
```
imagenet_1000class_data/
├── train/
│   ├── n01440764/  (1,300 images)
│   ├── n01443537/  (1,300 images)
│   └── ... (1000 classes total)
└── val/
    ├── n01440764/  (50 images)
    ├── n01443537/  (50 images)
    └── ... (1000 classes total)
```
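Before launching a multi-hour run, it is cheap to confirm the folders on disk match this layout. A minimal standard-library sketch; the function names are hypothetical and the accepted extensions are an assumption (ImageNet files use `.JPEG`):

```python
from pathlib import Path

IMAGE_SUFFIXES = {".jpeg", ".jpg", ".png"}  # compared case-insensitively


def summarize_split(split_dir: str) -> dict[str, int]:
    """Count image files in each class folder under split_dir (e.g. .../train)."""
    counts = {}
    for class_dir in sorted(p for p in Path(split_dir).iterdir() if p.is_dir()):
        counts[class_dir.name] = sum(
            1 for f in class_dir.iterdir() if f.suffix.lower() in IMAGE_SUFFIXES
        )
    return counts


def check_split(split_dir: str, expected_classes: int = 1000) -> None:
    """Print class and image totals, flagging a wrong class count or empty folders."""
    counts = summarize_split(split_dir)
    total = sum(counts.values())
    empty = [name for name, n in counts.items() if n == 0]
    print(f"{split_dir}: {len(counts)} classes, {total:,} images")
    if len(counts) != expected_classes:
        print(f"  expected {expected_classes} class folders")
    if empty:
        print(f"  {len(empty)} empty class folders, e.g. {empty[:3]}")
```

For example, `check_split('./imagenet_1000class_data/val')` should report 1000 classes and 50,000 images once the validation set is sorted.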

In [None]:
print("\n" + "="*70)
print("LOADING FULL IMAGENET-1K DATASET (1000 CLASSES)")
print("="*70)
print("Expected dataset size:")
print("  Training:   1,281,167 images (~1,281 per class on average)")
print("  Validation: ~50,000 images (~50 per class)")
print("  Total disk space: ~140-150 GB")
print("  Number of classes: 1000")
print("="*70 + "\n")

try:
    train_loader, val_loader, num_classes, class_names = get_full_dataloaders(
        data_dir=config['data_dir'],
        batch_size=config['batch_size'],
        num_workers=config['num_workers'],
        advanced_augmentation=True,
        augmentation_strength=config['augmentation_strength'],
        pin_memory=config['pin_memory'],
        distributed=False,  # Set to True for multi-GPU training
        auto_download=True  # Auto-download if dataset not found (WARNING: ~150GB!)
    )

    print(f"\n✓ Full ImageNet data loaded successfully!")
    print(f"  Number of classes: {num_classes}")
    print(f"  Training batches: {len(train_loader):,}")
    print(f"  Validation batches: {len(val_loader):,}")
    print(f"  Augmentation strength: {config['augmentation_strength']}")
    print(f"\n  With batch size {config['batch_size']}:")
    print(f"    ~{len(train_loader)} iterations per epoch")
    print(f"    ~{len(train_loader) * config['num_epochs']:,} total iterations")
    
except FileNotFoundError as e:
    print("\n" + "="*70)
    print("❌ DATASET DOWNLOAD FAILED")
    print("="*70)
    print(str(e))
    print("\n" + "="*70)
    print("ALTERNATIVE METHODS TO DOWNLOAD:")
    print("="*70)
    print("\n1. Manually run download script:")
    print("   python data_loader_full.py --download")
    print("\n2. Or download from official ImageNet website")
    print("\n3. Ensure dataset is extracted to:")
    print(f"   {config['data_dir']}/train/  (1000 class folders)")
    print(f"   {config['data_dir']}/val/    (1000 class folders)")
    print("="*70)
    raise
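The batch counts printed above follow directly from the dataset size. With 1,281,167 training images and batch size 256, an epoch is 5,005 iterations if the last partial batch is kept (`drop_last=False`, the PyTorch `DataLoader` default) or 5,004 if it is dropped; which applies depends on how `get_full_dataloaders` builds its loaders, which this notebook does not show. A small sketch of the arithmetic (the helper name is illustrative):

```python
import math


def batches_per_epoch(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of iterations a DataLoader yields per epoch."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


train_iters = batches_per_epoch(1_281_167, 256)
print(train_iters, train_iters * 90)   # iterations per epoch, total for 90 epochs
print(batches_per_epoch(50_000, 256))  # validation batches
```

If the printed `len(train_loader)` differs from these numbers, the loader is probably not seeing the full dataset.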

## Cell 5: Visualize Sample Images

Verify data loading and augmentation

In [None]:
from data_normalisation import denormalize_image

# Get a batch of training images
images, labels = next(iter(train_loader))

# Denormalize for visualization
images_denorm = denormalize_image(images)

# Plot first 8 images
fig, axes = plt.subplots(2, 4, figsize=(14, 7))
axes = axes.flatten()

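for i in range(8):
    # Clamp to [0, 1]: denormalized values can fall slightly outside imshow's valid range
    img = images_denorm[i].permute(1, 2, 0).clamp(0, 1).cpu().numpy()
    axes[i].imshow(img)
    label = labels[i].item()  # plain int, so the title shows "3" rather than "tensor(3)"
    axes[i].set_title(f"Class: {class_names[label]}\nLabel: {label}", fontsize=8)
    axes[i].axis('off')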

plt.suptitle(f'Sample Training Images (augmentation: {config["augmentation_strength"]})', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig('sample_images_1000class.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"Batch shape: {images.shape}")
print(f"Labels shape: {labels.shape}")
print(f"Image value range: [{images.min():.3f}, {images.max():.3f}]")
print(f"Label range: [{labels.min()}, {labels.max()}]")
print(f"Unique labels in batch: {len(torch.unique(labels))}")

## Cell 6: Create ResNet50 Model for 1000 Classes

In [None]:
print("\nCreating ResNet50 model for 1000 classes (training from scratch)...")

model = create_resnet50(
    num_classes=config['num_classes'],
    zero_init_residual=config['zero_init_residual']
)

model = model.to(device)

# Get model statistics
stats = get_model_stats(model)

print("\n" + "="*70)
print("MODEL STATISTICS")
print("="*70)
print(f"Architecture:          ResNet50")
print(f"Number of classes:     {config['num_classes']}")
print(f"Total parameters:      {stats['total_parameters']:,}")
print(f"Trainable parameters:  {stats['trainable_parameters']:,}")
print(f"Model size:            {stats['model_size_mb']:.2f} MB")
print(f"Training from:         Scratch (random initialization)")
print("="*70)

print(f"\n✓ Model created and moved to {device}")
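As a sanity check on the printed statistics: assuming `create_resnet50` follows the standard torchvision ResNet-50 layout, a 1000-class model has 25,557,032 parameters, of which the final 2048→1000 fully connected layer contributes 2048 × 1000 weights plus 1000 biases. The sketch below redoes that arithmetic (constant and helper names are illustrative); at 4 bytes per float32 parameter the weights take about 102 MB, and during training the gradients and the SGD momentum buffer each add the same amount again.

```python
FEATURE_DIM = 2048              # ResNet-50 penultimate feature width
TOTAL_PARAMS_1000 = 25_557_032  # torchvision ResNet-50 with a 1000-way head


def fc_params(num_classes: int, in_features: int = FEATURE_DIM) -> int:
    """Weights plus biases of the final nn.Linear(in_features, num_classes)."""
    return in_features * num_classes + num_classes


def resnet50_params(num_classes: int) -> int:
    """Total parameters when only the classifier head changes size."""
    backbone = TOTAL_PARAMS_1000 - fc_params(1000)
    return backbone + fc_params(num_classes)


def float32_megabytes(num_params: int) -> float:
    """Storage for float32 parameters, in decimal megabytes."""
    return num_params * 4 / 1e6


print(fc_params(1000))
print(resnet50_params(100))  # the earlier 100-class variant
print(f"{float32_megabytes(resnet50_params(1000)):.1f} MB")
```

A large mismatch with `stats['total_parameters']` would suggest the local `model.py` differs from the standard architecture.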

## Cell 7: Learning Rate Finder

**Automatically finds optimal learning rate range for 1000-class training**

**Note:** LR finder will take ~15-20 minutes on AWS GPU

**Based on 100-class learnings:**
- Lower learning rates work better with medium/heavy augmentation
- Expected max_lr range: 0.01 - 0.1
- If suggested max_lr > 0.3, reduce it (too aggressive)
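For context, an LR range test raises the learning rate exponentially from `init_lr` to `end_lr` over `num_iter` mini-batches while recording the loss, then picks an LR where the loss falls fastest. The sketch below shows that generic idea on already-recorded losses; it is not necessarily the selection rule inside the local `lr_finder` module (which this notebook does not show), and the smoothing factor `beta` is an assumption.

```python
import math


def lr_schedule(init_lr: float, end_lr: float, num_iter: int) -> list[float]:
    """Exponentially spaced learning rates from init_lr to end_lr, both included."""
    if num_iter < 2:
        raise ValueError("num_iter must be at least 2")
    ratio = end_lr / init_lr
    return [init_lr * ratio ** (i / (num_iter - 1)) for i in range(num_iter)]


def steepest_descent_lr(lrs: list[float], losses: list[float], beta: float = 0.98) -> float:
    """Return the LR at which the smoothed loss falls fastest per decade of LR."""
    smoothed, avg = [], 0.0
    for i, loss in enumerate(losses, start=1):
        avg = beta * avg + (1 - beta) * loss    # exponential moving average
        smoothed.append(avg / (1 - beta ** i))  # bias correction for early steps
    best_i, best_slope = 1, math.inf
    for i in range(1, len(smoothed)):
        slope = (smoothed[i] - smoothed[i - 1]) / (math.log10(lrs[i]) - math.log10(lrs[i - 1]))
        if slope < best_slope:
            best_i, best_slope = i, slope
    return lrs[best_i]


lrs = lr_schedule(1e-8, 10, 200)  # same range and count as the cell below
print(f"{lrs[0]:.1e} ... {lrs[-1]:.1e}, {len(lrs)} values")
```

With only 200 points spread over nine decades, each step multiplies the LR by about 1.11, so the curve is coarse; smoothing matters more than the exact pick.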

In [None]:
if config['find_lr']:
    print("\n" + "="*70)
    print("LEARNING RATE FINDER - 1000 Classes")
    print("="*70)
    print("Running LR range test to find optimal learning rates...")
    print(f"This will take ~15-20 minutes on AWS GPU.")
    print(f"Testing {config['lr_finder_iterations']} learning rate values.\n")
    
    from lr_finder import LRFinder
    
    lr_finder = LRFinder(
        model=model,
        optimizer=optim.SGD(model.parameters(), lr=1e-7, momentum=0.9, weight_decay=config['weight_decay']),
        criterion=nn.CrossEntropyLoss(label_smoothing=config['label_smoothing']),
        device=device
    )
    
    # Run LR finder
    lrs, losses, suggested_initial_lr, suggested_max_lr = lr_finder.find(
        train_loader,
        init_lr=1e-8,
        end_lr=10,
        num_iter=config['lr_finder_iterations']
    )
    
    # Plot results
    lr_finder.plot(lrs, losses, initial_lr=suggested_initial_lr, max_lr=suggested_max_lr)
    plt.savefig('lr_finder_1000class.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    # Validate and potentially adjust suggested LRs
    print(f"\n📊 LR Finder Results:")
    print(f"   Suggested initial_lr: {suggested_initial_lr:.6f}")
    print(f"   Suggested max_lr:     {suggested_max_lr:.6f}")
    
    # Safety checks based on 100-class learnings
    if suggested_max_lr > 0.3:
        print(f"\n⚠️  WARNING: Suggested max_lr ({suggested_max_lr:.2e}) is very high!")
        print("   Based on 100-class training, capping it at 0.3...")
        suggested_max_lr = 0.3
        suggested_initial_lr = suggested_max_lr / config['div_factor']
        print(f"   Adjusted max_lr: {suggested_max_lr:.6f}")
        print(f"   Adjusted initial_lr: {suggested_initial_lr:.6f}")
    
    if suggested_max_lr < 0.01:
        print(f"\n⚠️  WARNING: Suggested max_lr ({suggested_max_lr:.2e}) is very low!")
        print(f"   Training might be very slow. Consider checking:")
        print(f"   - Dataset is loading correctly")
        print(f"   - Augmentation is not too aggressive")
        print(f"   Using suggested values anyway...")
    
    # Update config
    config['initial_lr'] = suggested_initial_lr
    config['max_lr'] = suggested_max_lr
    
    print(f"\n" + "="*70)
    print(f"FINAL LEARNING RATES:")
    print(f"  Initial LR: {config['initial_lr']:.6f}")
    print(f"  Max LR:     {config['max_lr']:.6f}")
    print(f"  Ratio:      {config['max_lr']/config['initial_lr']:.1f}x")
    print("="*70)
    
    # Save LR values for reference
    with open('lr_config_1000class.txt', 'w') as f:
        f.write(f"FOUND_INITIAL_LR={config['initial_lr']}\n")
        f.write(f"FOUND_MAX_LR={config['max_lr']}\n")
    print(f"\n✓ Learning rates saved to lr_config_1000class.txt")
    
    # Reload model to reset weights after LR finder
    print("\nReloading model with fresh weights...")
    model = create_resnet50(
        num_classes=config['num_classes'],
        zero_init_residual=config['zero_init_residual']
    )
    model = model.to(device)
    print("✓ Model reloaded with random initialization")
    
else:
    print("\n⚠️  Skipping LR finder (using manual LR values)")
    print(f"  Initial LR: {config['initial_lr']}")
    print(f"  Max LR:     {config['max_lr']}")
    print("\n⚠️  WARNING: For 1000-class training, LR finder is HIGHLY RECOMMENDED!")
    print("   Set 'find_lr': True in config for better results.")

## Cell 8: Create Optimizer, Scheduler, and Trainer

**Using OneCycleLR based on successful 100-class training**

**Key learnings from 100-class training:**
- Lower LR often works better than suggested max_lr
- pct_start=0.3-0.6 depending on convergence pattern
- SGD with Nesterov momentum is reliable
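OneCycleLR ignores the `lr` given to the optimizer: it starts at `max_lr / div_factor`, rises to `max_lr` over the first `pct_start` fraction of steps, then anneals to `(max_lr / div_factor) / final_div_factor`. With the default config values (max_lr=0.3, div_factor=25) the run would therefore start at 0.012, not the 0.05 listed as `initial_lr`. By default (`cycle_momentum=True`) it also cycles SGD momentum between 0.85 and 0.95, overriding `momentum=0.9`; the pure-Python sketch below ignores momentum and mirrors only the default two-phase cosine LR curve, so you can preview the LR at any step without a GPU.

```python
import math


def _cos_anneal(start: float, end: float, pct: float) -> float:
    """Cosine interpolation: returns start at pct=0 and end at pct=1."""
    return end + (start - end) / 2.0 * (math.cos(math.pi * pct) + 1)


def one_cycle_lr(step: int, total_steps: int, max_lr: float, pct_start: float = 0.3,
                 div_factor: float = 25.0, final_div_factor: float = 1e4) -> float:
    """LR at `step` (0-based) of a two-phase cosine one-cycle schedule."""
    initial_lr = max_lr / div_factor
    min_lr = initial_lr / final_div_factor
    warmup_end = pct_start * total_steps - 1  # step at which the LR peaks
    last_step = total_steps - 1
    if step <= warmup_end:
        return _cos_anneal(initial_lr, max_lr, step / warmup_end)
    return _cos_anneal(max_lr, min_lr, (step - warmup_end) / (last_step - warmup_end))


total = 5005 * 90  # assumes 5,005 iterations per epoch for 90 epochs
for step in (0, total // 4, total // 2, total - 1):
    print(f"step {step:>7,}: lr = {one_cycle_lr(step, total, max_lr=0.3):.2e}")
```

Comparing this curve with `history['learning_rate']` in Cell 10 is a quick way to confirm the scheduler was stepped once per batch rather than once per epoch.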

In [None]:
print("\nCreating optimizer and scheduler...")

# Optimizer: SGD with Nesterov momentum
optimizer = optim.SGD(
    model.parameters(),
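    lr=config['initial_lr'],  # Overwritten by OneCycleLR, which starts at max_lr / div_factor
    momentum=0.9,  # OneCycleLR cycles momentum between 0.85 and 0.95 by default (cycle_momentum=True)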
    weight_decay=config['weight_decay'],
    nesterov=True
)

# Scheduler: OneCycleLR for optimal convergence
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=config['max_lr'],
    epochs=config['num_epochs'],
    steps_per_epoch=len(train_loader),
    pct_start=config['pct_start'],  # Warmup percentage
    anneal_strategy='cos',
    div_factor=config['div_factor'],  # initial_lr = max_lr / div_factor
    final_div_factor=config['final_div_factor']  # final_lr = initial_lr / final_div_factor
)

# Loss function with label smoothing
criterion = nn.CrossEntropyLoss(label_smoothing=config['label_smoothing'])

# Create trainer
trainer = Trainer(
    model=model,
    device=device,
    checkpoint_dir=config['checkpoint_dir'],
    max_grad_norm=config['max_grad_norm']
)

print("\n" + "="*70)
print("TRAINING SETUP COMPLETE")
print("="*70)
print(f"Optimizer:           SGD with Nesterov momentum")
print(f"Scheduler:           OneCycleLR")
print(f"Starting LR:         {optimizer.param_groups[0]['lr']:.6f} (max_lr / div_factor, set by OneCycleLR)")
print(f"Max LR:              {config['max_lr']:.6f}")
print(f"Warmup:              {config['pct_start']*100:.0f}% of epochs (first {int(config['num_epochs']*config['pct_start'])} epochs)")
print(f"Loss function:       CrossEntropyLoss (label_smoothing={config['label_smoothing']})")
print(f"Gradient clipping:   {config['max_grad_norm']}")
print(f"Weight decay:        {config['weight_decay']}")
print("="*70)
print(f"\n✓ Ready to start training!")

## Cell 9: Start Training

**Target: 82%+ validation accuracy in 90 epochs**

**Expected Timeline (training from scratch on 1000 classes):**
- Epochs 1-30:  Warmup and initial learning (15-35% val acc)
- Epochs 31-60: Steady improvement (35-55% val acc)
- Epochs 61-90: Continued learning (55-75% val acc)
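- Reaching 80%+ may take 100-120 epochs or more; for reference, the torchvision ResNet50 V1 weights reach 76.13% top-1 with a 90-epoch recipe, so treat 82% as a stretch goal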

**Hardware Performance:**
- **g5.xlarge (A10G 24GB):** ~8-10 min/epoch
  - 90 epochs: ~12-15 hours
  - Cost: ~$12-15 (on-demand) or ~$4-5 (spot)

- **g5.2xlarge (A10G 24GB + more CPU):** ~6-8 min/epoch  
  - 90 epochs: ~9-12 hours
  - Cost: ~$11-15 (on-demand) or ~$3-5 (spot)

- **p3.2xlarge (V100 16GB):** ~5-6 min/epoch
  - 90 epochs: ~7-9 hours
  - Cost: ~$21-27 (on-demand) or ~$6-8 (spot)

**Checkpoints:**
- Best model saved automatically
- Checkpoints saved every 5 epochs
- Training can resume from last checkpoint if interrupted
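Resuming a OneCycleLR run correctly means restoring the model, optimizer, and scheduler `state_dict`s together and continuing from the saved epoch; recreating the scheduler and loading only the weights restarts the LR cycle from the beginning. This notebook does not show the `Trainer`'s checkpoint file names, so the helper below assumes a hypothetical `checkpoint_epoch_<N>.pth` pattern; adapt the regex to whatever the `Trainer` actually writes.

```python
import re
from pathlib import Path
from typing import Optional, Tuple

# Hypothetical naming scheme; change it to match the files in checkpoint_dir.
CHECKPOINT_PATTERN = re.compile(r"checkpoint_epoch_(\d+)\.pth$")


def latest_checkpoint(checkpoint_dir: str) -> Optional[Tuple[int, Path]]:
    """Return (epoch, path) of the highest-epoch checkpoint, or None if there is none."""
    found = []
    for path in Path(checkpoint_dir).glob("*.pth"):
        match = CHECKPOINT_PATTERN.search(path.name)
        if match:
            found.append((int(match.group(1)), path))
    return max(found, key=lambda item: item[0], default=None)
```

Once found, load the file with `torch.load(path, map_location=device)` and call `load_state_dict` on the model, optimizer, and scheduler using whatever keys the `Trainer` saved (only `'model_state_dict'` is confirmed by Cell 11).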

In [None]:
print("\n" + "="*70)
print("STARTING TRAINING - ImageNet 1000 Classes")
print("="*70)
print(f"Target epochs:              {config['num_epochs']}")
print(f"Goal:                       82%+ validation accuracy")
print(f"Training samples:           ~1.28M images")
print(f"Validation samples:         ~50K images")
print(f"Batch size:                 {config['batch_size']}")
print(f"Iterations per epoch:       {len(train_loader):,}")
print(f"Total iterations:           {len(train_loader) * config['num_epochs']:,}")
print(f"Augmentation:               {config['augmentation_strength']}")
print(f"Device:                     {device}")
print("="*70)
print("\n⏰ Training will take 7-15 hours depending on GPU.")
print("📊 Progress will be displayed after each epoch.")
print("💾 Checkpoints saved every 5 epochs + best model.")
print("\n🚀 Starting training now...\n")
print("="*70 + "\n")

import time
start_time = time.time()

# Train the model
history = trainer.train(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=config['num_epochs'],
    device=device,
    save_frequency=config['save_frequency']
)

end_time = time.time()
training_time_hours = (end_time - start_time) / 3600

print("\n" + "="*70)
print("TRAINING COMPLETE!")
print("="*70)
print(f"Total training time: {training_time_hours:.2f} hours")
print(f"Average time per epoch: {training_time_hours * 60 / config['num_epochs']:.1f} minutes")
print("="*70)

## Cell 10: Plot Training History

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(16, 10))

# Plot 1: Training and Validation Loss
axes[0, 0].plot(history['train_loss'], label='Train Loss', linewidth=2, color='blue')
axes[0, 0].plot(history['val_loss'], label='Val Loss', linewidth=2, color='orange')
axes[0, 0].set_xlabel('Epoch', fontsize=12)
axes[0, 0].set_ylabel('Loss', fontsize=12)
axes[0, 0].set_title('Loss over Epochs', fontsize=14, fontweight='bold')
axes[0, 0].legend(fontsize=11)
axes[0, 0].grid(True, alpha=0.3)

# Plot 2: Training and Validation Accuracy
axes[0, 1].plot(history['train_acc'], label='Train Accuracy', linewidth=2, color='blue')
axes[0, 1].plot(history['val_acc'], label='Val Accuracy', linewidth=2, color='orange')
axes[0, 1].axhline(y=82, color='red', linestyle='--', label='82% Target', alpha=0.7, linewidth=2)
axes[0, 1].set_xlabel('Epoch', fontsize=12)
axes[0, 1].set_ylabel('Accuracy (%)', fontsize=12)
axes[0, 1].set_title('Accuracy over Epochs', fontsize=14, fontweight='bold')
axes[0, 1].legend(fontsize=11)
axes[0, 1].grid(True, alpha=0.3)

# Plot 3: Learning Rate Schedule
axes[1, 0].plot(history['learning_rate'], linewidth=2, color='green')
axes[1, 0].set_xlabel('Epoch', fontsize=12)
axes[1, 0].set_ylabel('Learning Rate', fontsize=12)
axes[1, 0].set_title('Learning Rate Schedule (OneCycleLR)', fontsize=14, fontweight='bold')
axes[1, 0].set_yscale('log')
axes[1, 0].grid(True, alpha=0.3)

# Plot 4: Validation accuracy on its own, against the 82% target
axes[1, 1].plot(history['val_acc'], label='Val Accuracy', linewidth=3, color='green')
axes[1, 1].axhline(y=82, color='red', linestyle='--', label='82% Target', alpha=0.7, linewidth=2)
axes[1, 1].set_xlabel('Epoch', fontsize=12)
axes[1, 1].set_ylabel('Validation Accuracy (%)', fontsize=12)
axes[1, 1].set_title('Validation Accuracy Progress', fontsize=14, fontweight='bold')
axes[1, 1].legend(fontsize=11)
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('training_history_1000classes.png', dpi=150, bbox_inches='tight')
plt.show()

# Print final statistics
print("\n" + "="*70)
print("FINAL TRAINING RESULTS")
print("="*70)
print(f"Best Validation Accuracy:  {max(history['val_acc']):.2f}%")
print(f"Final Training Accuracy:   {history['train_acc'][-1]:.2f}%")
print(f"Final Validation Accuracy: {history['val_acc'][-1]:.2f}%")
print(f"Final Training Loss:       {history['train_loss'][-1]:.4f}")
print(f"Final Validation Loss:     {history['val_loss'][-1]:.4f}")
print(f"Train/Val Accuracy Gap:    {history['train_acc'][-1] - history['val_acc'][-1]:.2f}%")
print("="*70)

if max(history['val_acc']) >= 82:
    print("\n🎉 TARGET ACHIEVED: 82%+ validation accuracy!")
    print("✓ Successfully trained ResNet50 on full ImageNet-1K!")
    print("✓ Model ready for deployment and inference.")
elif max(history['val_acc']) >= 75:
    print(f"\n✓ Good progress! Reached {max(history['val_acc']):.2f}% validation accuracy.")
    print("\nTo reach 82%+ accuracy, consider:")
    print("  - Train for more epochs (100-120 total)")
    print("  - Adjust learning rate (try lower max_lr)")
    print("  - Fine-tune from current checkpoint")
else:
    print(f"\n⚠️  Current best: {max(history['val_acc']):.2f}% (target: 82%+)")
    print("\nRecommendations:")
    print("  - Train for more epochs (need 100-120 for full convergence)")
    print("  - Check learning rate (may need adjustment)")
    print("  - Verify data augmentation is appropriate")
    print("  - Consider starting from pretrained weights")

print("="*70)

## Cell 11: Evaluate Best Model

Load the best checkpoint and evaluate on validation set

In [None]:
print("\nLoading best checkpoint for final evaluation...")

# Load best model
best_checkpoint_path = Path(config['checkpoint_dir']) / 'best_model.pth'
if best_checkpoint_path.exists():
    checkpoint = torch.load(best_checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"✓ Loaded best model from epoch {checkpoint['epoch']}")
    print(f"  Best validation accuracy: {checkpoint['best_acc']:.2f}%")
else:
    print("⚠️  Best checkpoint not found, using current model")

# Evaluate on validation set
print("\nRunning final evaluation on full validation set (50,000 images)...")
print("This will take ~5-10 minutes...\n")
model.eval()

correct = 0
total = 0
top5_correct = 0

from tqdm import tqdm
with torch.no_grad():
    for images, labels in tqdm(val_loader, desc='Evaluating'):
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        
        # Top-1 accuracy
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        
        # Top-5 accuracy
        _, top5_pred = outputs.topk(5, 1, largest=True, sorted=True)
        top5_correct += top5_pred.eq(labels.view(-1, 1).expand_as(top5_pred)).sum().item()

top1_acc = 100. * correct / total
top5_acc = 100. * top5_correct / total

print("\n" + "="*70)
print("FINAL VALIDATION RESULTS - ImageNet 1000 Classes")
print("="*70)
print(f"Top-1 Accuracy:          {top1_acc:.2f}%")
print(f"Top-5 Accuracy:          {top5_acc:.2f}%")
print(f"Total samples evaluated: {total:,}")
print(f"Number of classes:       1000")
print("="*70)

if top1_acc >= 82:
    print("\n🎉 EXCELLENT! Achieved 82%+ top-1 accuracy!")
    print("✓ Exceeds the torchvision ResNet50 V1 baseline (76.13% top-1).")
elif top1_acc >= 76:
    print("\n✓ GREAT! Achieved competitive ImageNet accuracy.")
    print(f"  Top-1: {top1_acc:.2f}% (torchvision ResNet50 V1 baseline: 76.13%)")
    print(f"  Top-5: {top5_acc:.2f}%")
elif top1_acc >= 70:
    print("\n✓ GOOD! Solid performance for training from scratch.")
    print(f"  Consider training more epochs for further improvement.")
else:
    print(f"\n⚠️  Top-1: {top1_acc:.2f}% (target: 82%+)")
    print("  Model may benefit from:")
    print("  - More training epochs")
    print("  - Learning rate adjustment")
    print("  - Transfer learning from pretrained weights")

## Cell 12: Save Final Model for Deployment

In [None]:
# Save model for deployment
deployment_dir = Path('./deployment')
deployment_dir.mkdir(exist_ok=True)

# Save complete model with metadata
torch.save({
    'model_state_dict': model.state_dict(),
    'num_classes': config['num_classes'],
    'final_val_acc': top1_acc,
    'final_top5_acc': top5_acc,
    'config': config,
    'training_history': history,
    'pytorch_version': torch.__version__,
}, deployment_dir / 'resnet50_1000class_final.pth')

print("\n" + "="*70)
print("MODEL SAVED FOR DEPLOYMENT")
print("="*70)
print(f"Location:            {deployment_dir / 'resnet50_1000class_final.pth'}")
print(f"Model size:          ~{Path(deployment_dir / 'resnet50_1000class_final.pth').stat().st_size / 1e6:.1f} MB")
print(f"Top-1 Accuracy:      {top1_acc:.2f}%")
print(f"Top-5 Accuracy:      {top5_acc:.2f}%")
print(f"Number of classes:   {config['num_classes']}")
print(f"PyTorch version:     {torch.__version__}")
print("="*70)
print("\n✓ Model ready for inference and deployment!")
print("\nTo use this model:")
print("```python")
print("import torch")
print("from model import create_resnet50")
print("checkpoint = torch.load('deployment/resnet50_1000class_final.pth', map_location='cpu')")
print("model = create_resnet50(num_classes=checkpoint['num_classes'])")
print("model.load_state_dict(checkpoint['model_state_dict'])")
print("model.eval()")
print("```")

## Summary & Next Steps

### Training Complete!

You've successfully trained ResNet50 on the full ImageNet-1K dataset (1000 classes, 1.28M images).

### Key Achievements:
- ✅ Trained from scratch on 1000 classes
- ✅ Used LR finder to optimize learning rates
- ✅ Applied OneCycleLR scheduler
- ✅ Medium augmentation for stable training
- ✅ Model checkpoints saved throughout training

### Results Summary:
- Best validation accuracy achieved
- Training time recorded (multiply by the instance's hourly price to estimate cost)
- Final model saved for deployment

### If You Want to Improve Further:

1. **Train Longer:**
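   - Continue for 10-20 more epochs with a new, short schedule: OneCycleLR ends at a very small LR and raises an error if stepped past its total steps, so the old scheduler cannot simply be resumed
   - Accuracy often improves with extended training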

2. **Adjust Learning Rate:**
   - Try lower max_lr (e.g., current_max_lr / 2)
   - Experiment with pct_start (0.4-0.6)

3. **Try Transfer Learning:**
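   - Start with official pretrained weights (these were already trained on ImageNet-1K, so the result no longer reflects training from scratch)
   - Fine-tune for 10-20 epochs
   - May reach 80-82% much faster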

4. **Experiment with Augmentation:**
   - Try 'light' or 'heavy' augmentation
   - Add MixUp or CutMix

### Cost Optimization Tips:

1. **Use Spot Instances:**
   - Typically 60-70% cheaper than on-demand, though spot prices vary by region and over time
   - g5.xlarge spot: ~$0.30/hour (vs $1.006/hour on-demand)

2. **Checkpoint Frequently:**
   - Save every 5 epochs (already configured)
   - Spot instances can be reclaimed with a two-minute warning; work since the last checkpoint is lost, so resume from that checkpoint

3. **Monitor Training:**
   - Stop early if val accuracy plateaus
   - No need to waste compute on unhelpful epochs
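One way to make "stop early if val accuracy plateaus" concrete is a patience rule: stop when the best validation accuracy of the last `patience` epochs beats the earlier best by less than `min_delta` percentage points. A minimal sketch over a list like `history['val_acc']`; the function name, patience, and threshold are assumptions to tune. With OneCycleLR the largest gains often arrive late, as the learning rate anneals, so a short patience can stop a run too early.

```python
def has_plateaued(val_acc: list[float], patience: int = 10, min_delta: float = 0.1) -> bool:
    """True if the last `patience` epochs improved on the earlier best by less than min_delta."""
    if len(val_acc) <= patience:
        return False  # not enough history to judge
    best_before = max(val_acc[:-patience])
    best_recent = max(val_acc[-patience:])
    return best_recent - best_before < min_delta
```

For example, checking `has_plateaued(history['val_acc'])` after each epoch and saving a final checkpoint before stopping keeps the run resumable.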

### Download & Cleanup:

```bash
# Download results to local machine
scp -i your-key.pem ubuntu@<instance-ip>:~/Image_net_training_model/deployment/* ./
scp -i your-key.pem ubuntu@<instance-ip>:~/Image_net_training_model/*.png ./

# After downloading, cleanup AWS:
# 1. Terminate EC2 instance
# 2. Delete or snapshot EBS volume
# 3. Verify no running resources in AWS Console
```

### Congratulations on completing ImageNet-1K training! 🎉