# ImageNet-1K Training - Medium Recipe 🌶️🌶️

**Target Hardware:** AWS p3.16xlarge (8x NVIDIA V100 GPUs with 16GB VRAM each, 64 vCPUs)

**Training Time:** 300-700 minutes for full ImageNet-1K (~1.28M training images)  
**Target Accuracy:** 78.1% - 79.5% top-1 on the ImageNet-1K validation set

## Implemented Techniques:

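### Speed-Up and Accuracy Methods:
1. ✅ **BlurPool** - Antialiased downsampling for better shift-invariance
2. ✅ **FixRes** - Final epochs at a higher resolution (288px) with test-style preprocessing, to close the train/test resolution gap
3. ✅ **Label Smoothing** - Prevents overconfident predictions (0.1 smoothing)
4. ✅ **Progressive Resizing** - 128px → 224px → 288px
5. ✅ **MixUp** - Data augmentation that blends pairs of images and their labels (α = 0.2)
6. ✅ **SAM** - Sharpness-Aware Minimization wrapped around SGD (two forward/backward passes per step, so roughly 2× compute per step)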
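The training loss combines label smoothing and MixUp. Cross-entropy is linear in the target distribution, so `lam * CE(logits, a) + (1 - lam) * CE(logits, b)` equals a single cross-entropy against the MixUp-blended soft target. The pure-Python sketch below checks this on a toy 4-class example. The helper names are hypothetical. The smoothing formula follows the documented behavior of `nn.CrossEntropyLoss(label_smoothing=eps)`: the true class gets `1 - eps + eps/K` and every other class gets `eps/K`.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax for a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def smoothed_target(label, num_classes, eps=0.1):
    """Soft target equivalent to CrossEntropyLoss(label_smoothing=eps) (hypothetical helper)."""
    return [(1 - eps) * (1.0 if k == label else 0.0) + eps / num_classes
            for k in range(num_classes)]


def cross_entropy(logits, target):
    """Cross-entropy between logits and a (soft) target distribution."""
    return -sum(t * lp for t, lp in zip(target, log_softmax(logits)))


logits = [2.0, -1.0, 0.5, 0.0]     # toy model output for one mixed image
label_a, label_b, lam = 0, 2, 0.7  # labels of the two blended images and the MixUp weight
K = len(logits)

# Weighted sum of two label-smoothed losses (the form used in training_step)
loss_two_terms = (lam * cross_entropy(logits, smoothed_target(label_a, K))
                  + (1 - lam) * cross_entropy(logits, smoothed_target(label_b, K)))

# Equivalent: one loss against the MixUp-blended soft target
mixed = [lam * ta + (1 - lam) * tb
         for ta, tb in zip(smoothed_target(label_a, K), smoothed_target(label_b, K))]
loss_mixed_target = cross_entropy(logits, mixed)

print(round(sum(mixed), 6))                             # 1.0: still a probability distribution
print(math.isclose(loss_two_terms, loss_mixed_target))  # True
```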

### Additional Optimizations:
1. ✅ **Channels Last** - NHWC memory format, which lets FP16 convolutions use Tensor Core kernels more efficiently
2. ✅ **Mixed Precision (FP16)** - 16-bit compute with dynamic loss scaling
3. ✅ **Distributed Data Parallel** - One training process per GPU across all 8 GPUs
4. ✅ **Synchronized Batch Normalization** - Batch statistics computed across all GPUs
5. ✅ **Dynamic Batch Sizing** - Per-GPU batch size shrinks as the resolution grows
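Mixed precision relies on loss scaling. Small gradients underflow to zero in FP16, so the loss (and therefore every gradient) is multiplied by a scale factor before the backward pass. The gradients are then divided by that factor again before the weights are updated. The sketch below emulates FP16 rounding with Python's `struct` half-precision format (`'e'`). The scale `2**16` matches the default initial scale of PyTorch's `GradScaler`, and the gradient value is made up for illustration.

```python
import struct


def to_fp16(x):
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


grad = 1e-8          # a small, made-up gradient value
scale = 2.0 ** 16    # default initial loss scale of torch's GradScaler

unscaled_fp16 = to_fp16(grad)          # underflows: the smallest FP16 subnormal is ~6e-8
scaled_fp16 = to_fp16(grad * scale)    # ~6.55e-4, comfortably representable in FP16
recovered = scaled_fp16 / scale        # unscale in higher precision before the update

print(unscaled_fp16)  # 0.0
print(recovered)      # close to 1e-8 (FP16 keeps roughly 3 significant digits)
```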

### Hardware Utilization:
- **GPU:** All 8 V100s with DDP + mixed precision
- **CPU:** 8 DataLoader workers per GPU process (8 × 8 = 64 workers for 64 vCPUs)
- **VRAM:** Per-GPU batch size is set per resolution (512 → 320 → 256)
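The CPU budget above is simple arithmetic: each of the 8 DDP processes starts `num_workers` DataLoader worker processes. The helper below is hypothetical. It derives a per-process worker count from the vCPU count. It can optionally reserve cores for the 8 training processes themselves, which is an assumption and not something this notebook does.

```python
def workers_per_gpu(num_vcpus, num_gpus, reserve_per_process=0):
    """Split vCPUs evenly across DDP processes, keeping `reserve_per_process`
    cores for each training process itself (hypothetical helper)."""
    per_process = num_vcpus // num_gpus - reserve_per_process
    return max(1, per_process)


# p3.16xlarge: 64 vCPUs, 8 GPUs
print(workers_per_gpu(64, 8))     # 8 -> 8 x 8 = 64 workers, the setting used in this notebook
print(workers_per_gpu(64, 8, 1))  # 7 -> 56 workers, leaving 8 cores for the training processes
```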

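### Training Schedule (epochs are 0-indexed):
| Epochs | Resolution | Batch/GPU | Total Batch | Train Augmentation |
|--------|------------|-----------|-------------|--------------------|
| 0-9    | 128x128    | 512       | 4,096       | RandomResizedCrop + flip, MixUp |
| 10-84  | 224x224    | 320       | 2,560       | RandomResizedCrop + flip, MixUp |
| 85-89  | 288x288    | 256       | 2,048       | Resize + CenterCrop (FixRes), MixUp |

Validation always uses Resize + CenterCrop at the current training resolution.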
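The schedule table is a piecewise-constant function of the epoch. Each phase lasts until the next key in the `{epoch: (resolution, use_train_augs, batch_size_per_gpu)}` dictionary that the notebook passes to its callback. The sketch below looks up the active phase for any epoch and derives the global batch and optimizer steps per epoch. It assumes 1,281,167 ImageNet-1K training images and 8 GPUs.

The last printed value applies the linear learning-rate scaling heuristic for large-batch SGD (lr = 0.1 × global_batch / 256, Goyal et al., 2017). This notebook keeps a fixed `learning_rate=0.1`, so treat that value as an assumption to experiment with, not part of the recipe.

```python
import math

SCHEDULE = {0: (128, True, 512), 10: (224, True, 320), 85: (288, False, 256)}
NUM_GPUS = 8
TRAIN_IMAGES = 1_281_167        # ImageNet-1K training set size
BASE_LR, BASE_BATCH = 0.1, 256  # assumption: reference point for linear LR scaling


def phase_for_epoch(schedule, epoch):
    """Return the (resolution, use_train_augs, batch_per_gpu) entry active at `epoch`."""
    start = max(e for e in schedule if e <= epoch)
    return schedule[start]


for epoch in (0, 9, 10, 84, 85, 89):
    res, train_augs, per_gpu = phase_for_epoch(SCHEDULE, epoch)
    global_batch = per_gpu * NUM_GPUS
    steps = math.ceil(TRAIN_IMAGES / global_batch)  # one SAM update per global batch
    scaled_lr = BASE_LR * global_batch / BASE_BATCH  # heuristic only, not used by the notebook
    print(f"epoch {epoch:2d}: {res}px, train_augs={train_augs}, "
          f"global batch {global_batch}, {steps} steps/epoch, linear-scaled lr {scaled_lr:.2f}")
```

Under this heuristic, the learning rate would step down at each phase boundary as the global batch shrinks.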


In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as T
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_optimizer import SAM  # pip install pytorch-optimizer
import torchmetrics
import antialiased_cnns  # For BlurPool (pip install antialiased-cnns)

In [None]:
class ImageNetDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str, batch_size: int, num_workers: int):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size    # Per-GPU batch size
        self.num_workers = num_workers  # DataLoader workers per GPU process
        self.image_size = 128  # Start with the smallest resolution
        self.use_train_augs = True
        self.train_dataset = None
        self.val_dataset = None

        # Standard ImageNet normalization
        self.normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def setup(self, stage=None):
        """Scan the image folders once; the dataloaders below only swap in new transforms."""
        if self.train_dataset is None:
            self.train_dataset = ImageFolder(root=f"{self.data_dir}/train")
            self.val_dataset = ImageFolder(root=f"{self.data_dir}/val")

    def update_resolution(self, image_size: int, use_train_augs: bool, batch_size: int = None):
        """Called by the callback to update the transforms and optionally the per-GPU batch size."""
        self.image_size = image_size
        self.use_train_augs = use_train_augs
        if batch_size is not None:
            self.batch_size = batch_size

    def _eval_transform(self):
        """Test-style preprocessing: resize the short side to image_size * 256/224, then center-crop."""
        return T.Compose([
            T.Resize(int(self.image_size * 256 / 224)),
            T.CenterCrop(self.image_size),
            T.ToTensor(),
            self.normalize,
        ])

    def train_dataloader(self):
        """Called by the Trainer whenever it (re)builds the training dataloader."""
        if self.use_train_augs:
            # Standard training augmentations
            self.train_dataset.transform = T.Compose([
                T.RandomResizedCrop(self.image_size),
                T.RandomHorizontalFlip(),
                T.ToTensor(),
                self.normalize,
            ])
        else:
            # Test-style preprocessing for the FixRes phase
            self.train_dataset.transform = self._eval_transform()

        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,     # Lightning swaps in a DistributedSampler under DDP
            pin_memory=True,  # Faster host-to-GPU copies
            persistent_workers=True,
        )

    def val_dataloader(self):
        # Validate at the current training resolution
        self.val_dataset.transform = self._eval_transform()
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,  # Faster host-to-GPU copies
            persistent_workers=True,
        )

In [None]:
class ResolutionScheduleCallback(pl.Callback):
    """Progressive resizing + FixRes + dynamic batch sizing.

    Format: {epoch: (resolution, use_train_augs)} or
            {epoch: (resolution, use_train_augs, batch_size_per_gpu)}
    e.g., {0: (128, True, 512), 10: (224, True, 320), 85: (288, False, 256)}
    """

    def __init__(self, schedule):
        super().__init__()
        self.schedule = schedule

    def _apply(self, trainer, epoch):
        if epoch not in self.schedule:
            return
        config = self.schedule[epoch]
        size, use_train_augs = config[0], config[1]
        batch_size = config[2] if len(config) > 2 else None

        if trainer.is_global_zero:  # Print once, not once per GPU process
            msg = f"\nEpoch {epoch}: adjusting resolution to {size}x{size}"
            if batch_size is not None:
                msg += f", batch_size to {batch_size} per GPU"
            print(msg)

        # Update the datamodule's parameters; they are read when the dataloaders are rebuilt
        trainer.datamodule.update_resolution(size, use_train_augs, batch_size)

    def on_fit_start(self, trainer, pl_module):
        # Make Lightning rebuild the train/val dataloaders from the DataModule every epoch
        # (the Trainer attribute behind Trainer(reload_dataloaders_every_n_epochs=1)), so
        # the updated transforms and batch size actually take effect.
        trainer.reload_dataloaders_every_n_epochs = 1
        # Apply the epoch-0 entry before the first dataloaders are built
        self._apply(trainer, trainer.current_epoch)

    def on_train_epoch_end(self, trainer, pl_module):
        # Runs after this epoch's validation, so the new settings are used from the next
        # epoch's training dataloader (and its validation) onward.
        self._apply(trainer, trainer.current_epoch + 1)

In [None]:
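class ImageNetLitModule(pl.LightningModule):
    def __init__(self, learning_rate=0.1, momentum=0.9, weight_decay=1e-4, mixup_alpha=0.2):
        super().__init__()
        self.save_hyperparameters()

        ## Technique: BlurPool (antialiased downsampling)
        self.model = antialiased_cnns.resnet50(pretrained=False, filter_size=4)

        ## Technique: Label Smoothing
        self.loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)

        # SAM needs two forward/backward passes per step, so the optimizer is driven manually
        self.automatic_optimization = False

        # Top-1 accuracy on the validation set
        self.val_accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=1000)

    def forward(self, x):
        return self.model(x)

    def _unscale_grads_(self) -> bool:
        """Undo AMP loss scaling by hand and report whether all gradients are finite.

        With precision='16-mixed', manual_backward() multiplies the loss by the GradScaler's
        scale (2**16 by default), but SAM's first_step()/second_step() bypass the scaler.
        Without this, SGD would apply gradients that are thousands of times too large.
        """
        scaler = getattr(self.trainer.precision_plugin, "scaler", None)
        if scaler is None:  # fp32 / bf16: gradients are not scaled
            return True
        inv_scale = 1.0 / scaler.get_scale()
        grads = [p.grad for p in self.parameters() if p.grad is not None]
        for g in grads:
            g.mul_(inv_scale)
        finite = bool(torch.stack([torch.isfinite(g).all() for g in grads]).all())
        if not finite:
            # FP16 overflow: halve the loss scale (it only ever shrinks in this scheme)
            scaler.update(scaler.get_scale() / 2.0)
        return finite

    def training_step(self, batch, batch_idx):
        optimizer = self.optimizers()
        images, labels = batch

        ## Technique: MixUp - Data augmentation that mixes pairs of samples
        lam = torch.distributions.beta.Beta(self.hparams.mixup_alpha, self.hparams.mixup_alpha).sample().item()
        shuffled_indices = torch.randperm(images.size(0), device=images.device)
        mixed_images = lam * images + (1 - lam) * images[shuffled_indices]
        mixed_images = mixed_images.contiguous(memory_format=torch.channels_last)  # Match channels-last weights
        labels_a, labels_b = labels, labels[shuffled_indices]

        def mixup_loss():
            outputs = self(mixed_images)
            return lam * self.loss_fn(outputs, labels_a) + (1 - lam) * self.loss_fn(outputs, labels_b)

        ## Technique: SAM (Sharpness Aware Minimization) - Two-Step Update
        # Step 1: gradient at the current weights w, then move to w + e(w)
        loss = mixup_loss()
        self.manual_backward(loss)
        if not self._unscale_grads_():
            optimizer.zero_grad()  # Overflow: skip this batch
            return None
        optimizer.first_step(zero_grad=True)

        # Step 2: gradient at w + e(w), restore w, then apply the SGD update with it
        loss_2 = mixup_loss()
        self.manual_backward(loss_2)
        if not self._unscale_grads_():
            # Zero (not None) grads so second_step() still restores w; SGD then only
            # applies momentum and weight decay for this batch
            optimizer.zero_grad(set_to_none=False)
        optimizer.second_step(zero_grad=True)

        # Log the loss at the unperturbed weights w
        self.log("train_loss", loss.detach(), on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss.detach()

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        images = images.contiguous(memory_format=torch.channels_last)
        outputs = self(images)
        loss = self.loss_fn(outputs, labels)

        # Update top-1 accuracy
        preds = torch.argmax(outputs, dim=1)
        self.val_accuracy(preds, labels)

        # sync_dist averages the epoch-level loss across the 8 DDP processes
        self.log("val_loss", loss, prog_bar=True, on_epoch=True, sync_dist=True)
        self.log("val_acc", self.val_accuracy, prog_bar=True, on_epoch=True)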
    def configure_optimizers(self):
        base_optimizer = torch.optim.SGD
        optimizer = SAM(
            self.parameters(),
            base_optimizer,
            lr=self.hparams.learning_rate,
            momentum=self.hparams.momentum,
            weight_decay=self.hparams.weight_decay
        )
        return optimizer

In [None]:
if __name__ == '__main__':
    # ImageNet-1K on the EBS volume. Must contain 'train/' and 'val/' subfolders; val/ needs the
    # same one-subfolder-per-class layout as train/, because ImageFolder infers labels from folders
    IMAGENET_PATH = '/home/ec2-user/imagenet1k'

    # ImageNet-1K-mini for quick local tests (same layout)
    # IMAGENET_PATH = 'imagenet-mini'

    # Instantiate the DataModule
    # p3.16xlarge has 8 V100 GPUs (16GB each), dynamic batch sizing for max utilization
    datamodule = ImageNetDataModule(
        data_dir=IMAGENET_PATH,
        batch_size=512,  # Initial batch size (will be dynamically adjusted)
        num_workers=8    # Per GPU workers (8 workers * 8 GPUs = 64 total, uses all 64 vCPUs)
    )
    
    # Instantiate the Model
    model = ImageNetLitModule()

    ## Technique: Channels Last (optimized memory format for GPU)
    model = model.to(memory_format=torch.channels_last)
    
    ## Technique: Progressive Resizing + FixRes + Dynamic Batch Sizing
    # Dynamically adjust resolution AND batch size to maximize GPU utilization
    # Format: {epoch: (resolution, use_train_augs, batch_size_per_gpu)}
    # Total batch: 512*8=4096 → 320*8=2560 → 256*8=2048
    res_schedule = {
        0: (128, True, 512),   # Small images = larger batches
        10: (224, True, 320),  # Medium images = medium batches
        85: (288, False, 256)  # Large images + FixRes = smaller batches
    }
    schedule_callback = ResolutionScheduleCallback(schedule=res_schedule)

    # Configure the Trainer for p3.16xlarge (8 x V100 GPUs)
    trainer = pl.Trainer(
        max_epochs=90,
        accelerator='gpu',           # GPU acceleration
        devices=8,                   # Use all 8 GPUs
        strategy='ddp',              # Distributed Data Parallel (run as a script; in Jupyter use 'ddp_notebook')
        precision='16-mixed',        # Mixed precision training for faster computation
        callbacks=[schedule_callback],
        log_every_n_steps=50,
        sync_batchnorm=True,         # Sync batch norm across GPUs
    )

    # Start training!
    trainer.fit(model, datamodule=datamodule)

In [None]:
# Additional analysis or experimentation can go here
