# ImageNet Training - Medium Recipe (Local) 🌶️🌶️

**Target Hardware:** MacBook M4 Pro (Apple Silicon - MPS)

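**Training Time:** Considerably longer than a multi-GPU run; practical mainly on a subset such as ImageNet-Mini  
**Target Accuracy:** 78.1% - 79.5% top-1 (reference range for full ImageNet; expect lower accuracy on ImageNet-Mini)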

## Implemented Techniques:

### Speed-Up Methods:
1. ✅ **BlurPool** - Antialiased downsampling for shift-invariance
2. ✅ **FixRes** - Final epochs at a higher resolution (288px) with test-style preprocessing, to reduce the train/test resolution discrepancy
3. ✅ **Label Smoothing** - Prevents overconfident predictions (0.1 smoothing)
4. ✅ **Progressive Resizing** - 128px → 224px → 288px
5. ✅ **MixUp** - Data augmentation mixing sample pairs
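
BlurPool is used here through the `antialiased_cnns` package. The idea fits in a few lines: blur each channel with a fixed low-pass filter, then subsample. The class below is a minimal illustration, not the package's implementation; `BlurPool2d` is a hypothetical name, and the 4-tap binomial filter `[1, 3, 3, 1]` with reflect padding is assumed to correspond to `filter_size=4`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BlurPool2d(nn.Module):
    """Minimal BlurPool sketch (hypothetical): low-pass filter, then subsample by `stride`.

    Uses a fixed 4-tap binomial filter [1, 3, 3, 1], applied depthwise so
    each channel is blurred independently.
    """

    def __init__(self, channels: int, stride: int = 2):
        super().__init__()
        taps = torch.tensor([1.0, 3.0, 3.0, 1.0])
        kernel_2d = torch.outer(taps, taps)
        kernel_2d = kernel_2d / kernel_2d.sum()  # normalize so a constant image stays constant
        # One copy of the kernel per channel for a depthwise (groups=channels) convolution
        self.register_buffer("kernel", kernel_2d.expand(channels, 1, 4, 4).clone())
        self.channels = channels
        self.stride = stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Asymmetric padding (1 before, 2 after) for the even-sized 4x4 kernel
        x = F.pad(x, (1, 2, 1, 2), mode="reflect")
        return F.conv2d(x, self.kernel, stride=self.stride, groups=self.channels)


pool = BlurPool2d(channels=3)
print(pool(torch.randn(2, 3, 32, 32)).shape)  # spatial size is halved
```

Because the filter averages neighboring pixels before dropping every other one, a one-pixel shift of the input changes the output far less than plain strided subsampling would.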

### Additional Optimizations:
1. ✅ **Channels Last** - NHWC memory layout for convolutions; can improve throughput on some GPUs, though the gain on MPS is not guaranteed
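
Channels-last only changes how a 4D tensor is laid out in memory (NHWC instead of NCHW); values and shapes stay the same. A minimal CPU-runnable sketch, using a small hypothetical conv stack in place of ResNet-50, shows that both the model and the input batch need the layout for convolutions to see NHWC data end to end.

```python
import torch
import torch.nn as nn

# Small hypothetical conv stack standing in for the notebook's ResNet-50
model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU())
model = model.to(memory_format=torch.channels_last)  # converts the 4D conv weights

x = torch.randn(4, 3, 32, 32)                            # default NCHW-contiguous batch
x_cl = x.contiguous(memory_format=torch.channels_last)   # same values, NHWC layout in memory

out = model(x_cl)
print(out.shape)
```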

### Hardware Configuration:
- **GPU:** 1x Apple Silicon (MPS)
- **CPU:** 4 data loading workers
- **Precision:** 32-bit (MPS compatibility)
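
The Trainer configuration hard-codes `accelerator='mps'`, which fails on machines without Apple Silicon. A small helper such as the hypothetical `pick_accelerator` (an assumption, not part of the original notebook) falls back to CUDA or CPU; Lightning's `accelerator='auto'` is the built-in alternative.

```python
import torch


def pick_accelerator() -> str:
    """Hypothetical helper: return a Lightning accelerator string, preferring MPS, then CUDA, then CPU."""
    if torch.backends.mps.is_available():
        return "mps"
    if torch.cuda.is_available():
        return "gpu"
    return "cpu"


accelerator = pick_accelerator()
print(accelerator)
```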

### Training Schedule:
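| Epochs (0-indexed) | Resolution | Augmentation |
|--------------------|------------|--------------|
| 0-9                | 128x128    | Train: RandomResizedCrop + flip, MixUp |
| 10-84              | 224x224    | Train: RandomResizedCrop + flip, MixUp |
| 85-89              | 288x288    | Test-style: resize + center crop (FixRes) |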

**Note:** The batch size is fixed at 64 for all stages.
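
Epochs are 0-indexed and training stops after `max_epochs=90`, so each schedule entry applies from its start epoch up to, but not including, the next start. A short pure-Python sketch (the helper `config_for_epoch` is hypothetical) resolves the entry in effect for any epoch and confirms the stage lengths: 10 epochs at 128px, 75 at 224px and 5 at 288px.

```python
# Schedule from the training table: {start_epoch: (resolution, use_train_augs)}
RES_SCHEDULE = {0: (128, True), 10: (224, True), 85: (288, False)}


def config_for_epoch(epoch: int, schedule: dict) -> tuple:
    """Return the (resolution, use_train_augs) entry in effect at `epoch`.

    Each entry applies from its start epoch until the next entry begins.
    """
    starts = [start for start in sorted(schedule) if start <= epoch]
    if not starts:
        raise ValueError(f"No schedule entry covers epoch {epoch}")
    return schedule[starts[-1]]


for epoch in (0, 9, 10, 84, 85, 89):
    print(epoch, config_for_epoch(epoch, RES_SCHEDULE))
```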


In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as T
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import pytorch_lightning as pl
import torchmetrics
import antialiased_cnns  # For BlurPool

In [2]:
class ImageNetDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str, batch_size: int, num_workers: int):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.image_size = 128  # Start with the smallest resolution
        self.use_train_augs = True
        
        # Standard ImageNet normalization
        self.normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def update_resolution(self, image_size: int, use_train_augs: bool, batch_size: int = None):
        """Called by the callback to update the transforms and optionally batch size."""
        self.image_size = image_size
        self.use_train_augs = use_train_augs
        if batch_size is not None:
            self.batch_size = batch_size

    def train_dataloader(self):
        """This is called by the Trainer to get the dataloader."""
        if self.use_train_augs:
            # Standard training augmentations
            transform = T.Compose([
                T.RandomResizedCrop(self.image_size),
                T.RandomHorizontalFlip(),
                T.ToTensor(),
                self.normalize,
            ])
        else:
            # Test-style augmentations for the FixRes phase
            transform = T.Compose([
                T.Resize(int(self.image_size * 256 / 224)), # Standard practice for validation
                T.CenterCrop(self.image_size),
                T.ToTensor(),
                self.normalize,
            ])
        
        train_dataset = ImageFolder(root=f"{self.data_dir}/train", transform=transform)
        return DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            pin_memory=False,  # MPS doesn't support pinned memory
            persistent_workers=self.num_workers > 0,  # Only valid when worker processes are used
        )

    def val_dataloader(self):
        # Use the current image size for validation to match training resolution
        transform = T.Compose([
            T.Resize(int(self.image_size * 256 / 224)),
            T.CenterCrop(self.image_size),
            T.ToTensor(),
            self.normalize,
        ])
        val_dataset = ImageFolder(root=f"{self.data_dir}/val", transform=transform)
        return DataLoader(
            val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=False,  # MPS doesn't support pinned memory
            persistent_workers=self.num_workers > 0,  # Only valid when worker processes are used
        )

In [3]:
class ResolutionScheduleCallback(pl.Callback):
    """Applies a progressive-resizing / FixRes schedule to the datamodule.

    Format: {epoch: (resolution, use_train_augs)}
        or  {epoch: (resolution, use_train_augs, batch_size)}
    e.g. {0: (128, True), 10: (224, True), 85: (288, False)}
    """

    def __init__(self, schedule):
        super().__init__()
        self.schedule = schedule

    def setup(self, trainer, pl_module, stage):
        # Rebuild the train/val dataloaders every epoch so schedule changes take effect
        # (same effect as pl.Trainer(reload_dataloaders_every_n_epochs=1)).
        trainer.reload_dataloaders_every_n_epochs = 1

    def _apply(self, trainer, epoch):
        config = self.schedule[epoch]
        size, use_train_augs = config[0], config[1]
        batch_size = config[2] if len(config) > 2 else None

        message = f"\nEpoch {epoch}: Adjusting resolution to {size}x{size}"
        if batch_size is not None:
            message += f", batch_size to {batch_size}"
        print(message)

        # Only update state here; Lightning rebuilds the dataloaders from it
        trainer.datamodule.update_resolution(size, use_train_augs, batch_size)

    def on_fit_start(self, trainer, pl_module):
        # Configure epoch 0 before the first dataloaders are built
        if 0 in self.schedule:
            self._apply(trainer, 0)

    def on_train_epoch_end(self, trainer, pl_module):
        # The fit loop builds an epoch's dataloaders *before* on_train_epoch_start runs,
        # so the settings for epoch N+1 are applied at the end of epoch N
        # (after epoch N's validation, which keeps its own resolution).
        next_epoch = trainer.current_epoch + 1
        if next_epoch in self.schedule:
            self._apply(trainer, next_epoch)

In [None]:
class ImageNetLitModule(pl.LightningModule):
    def __init__(self, learning_rate=0.1, momentum=0.9, weight_decay=1e-4, mixup_alpha=0.2):
        super().__init__()
        self.save_hyperparameters()
        
        ## Technique: BlurPool (antialiased downsampling)
        self.model = antialiased_cnns.resnet50(pretrained=False, filter_size=4)
        
        ## Technique: Label Smoothing
        self.loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)
        
        # Initialize accuracy metric
        self.val_accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=1000)

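    def forward(self, x):
        # Match the channels_last layout applied to the model in the main cell
        x = x.contiguous(memory_format=torch.channels_last)
        return self.model(x)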

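    def training_step(self, batch, batch_idx):
        images, labels = batch

        # Skip MixUp in the FixRes phase, when the datamodule switches to test-style augmentations
        datamodule = getattr(self.trainer, "datamodule", None)
        use_mixup = datamodule is None or datamodule.use_train_augs

        if use_mixup:
            ## Technique: MixUp - Data augmentation that mixes pairs of samples
            lam = torch.distributions.Beta(self.hparams.mixup_alpha, self.hparams.mixup_alpha).sample().item()
            # Build the permutation on CPU, then move it to the batch's device (MPS)
            shuffled_indices = torch.randperm(images.size(0)).to(images.device)
            mixed_images = lam * images + (1 - lam) * images[shuffled_indices]
            labels_a, labels_b = labels, labels[shuffled_indices]
            outputs = self(mixed_images)
            loss = lam * self.loss_fn(outputs, labels_a) + (1 - lam) * self.loss_fn(outputs, labels_b)
        else:
            outputs = self(images)
            loss = self.loss_fn(outputs, labels)

        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss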

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = self.loss_fn(outputs, labels)
        
        # Update accuracy metric
        preds = torch.argmax(outputs, dim=1)
        self.val_accuracy(preds, labels)
        
        self.log("val_loss", loss, prog_bar=True, on_epoch=True)
        self.log("val_acc", self.val_accuracy, prog_bar=True, on_epoch=True)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(
            self.parameters(),
            lr=self.hparams.learning_rate,
            momentum=self.hparams.momentum,
            weight_decay=self.hparams.weight_decay
        )
        return optimizer
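
`training_step` computes the MixUp loss as `lam * loss(outputs, labels_a) + (1 - lam) * loss(outputs, labels_b)`. Because cross-entropy, including PyTorch's label smoothing, is linear in the target distribution, this equals a single loss against the mixed soft target. The standalone check below uses random logits and an arbitrary `lam = 0.7`, and assumes PyTorch 1.10 or later (probability targets and `label_smoothing` in `nn.CrossEntropyLoss`).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
num_classes, batch = 10, 8
loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)

logits = torch.randn(batch, num_classes)
labels_a = torch.randint(0, num_classes, (batch,))
labels_b = labels_a[torch.randperm(batch)]
lam = 0.7

# Form used in training_step: weighted sum of two hard-label losses
pairwise = lam * loss_fn(logits, labels_a) + (1 - lam) * loss_fn(logits, labels_b)

# Equivalent form: one loss against the mixed soft target distribution
soft_targets = (lam * F.one_hot(labels_a, num_classes).float()
                + (1 - lam) * F.one_hot(labels_b, num_classes).float())
soft = loss_fn(logits, soft_targets)

print(pairwise.item(), soft.item())
```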

In [None]:
if __name__ == '__main__':
    # Local ImageNet-Mini path
    IMAGENET_PATH = 'imagenet-mini'  # Should contain 'train/' and 'val/' subfolders

    # Instantiate the DataModule
    # MacBook M4 Pro - optimized batch size for single MPS device
    datamodule = ImageNetDataModule(
        data_dir=IMAGENET_PATH,
        batch_size=64,   # Fixed batch size throughout training
        num_workers=4    # Optimized for MacBook M4 Pro
    )
    
    # Instantiate the Model
    model = ImageNetLitModule()

    ## Technique: Channels Last (optimized memory format for GPU)
    model = model.to(memory_format=torch.channels_last)
    
    ## Technique: Progressive Resizing + FixRes
    # Dynamically adjust resolution and augmentation strategy
    # Format: {epoch: (resolution, use_train_augs)}
    res_schedule = {
        0: (128, True),    # Small images, train augmentations
        10: (224, True),   # Medium images, train augmentations
        85: (288, False)   # Large images, test augmentations (FixRes)
    }
    schedule_callback = ResolutionScheduleCallback(schedule=res_schedule)

    # Configure the Trainer for MacBook M4 Pro (MPS)
    trainer = pl.Trainer(
        max_epochs=90,
        accelerator='mps',           # Apple Silicon MPS acceleration
        devices=1,                   # Single MPS device
        precision='32',              # 32-bit precision for MPS compatibility
        reload_dataloaders_every_n_epochs=1,  # Rebuild dataloaders so resolution changes take effect
        callbacks=[schedule_callback],
        log_every_n_steps=50,
    )

    # Start training!
    trainer.fit(model, datamodule=datamodule)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/yash/Documents/ERA/mini-capstone/init/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


Training: |          | 0/? [00:00<?, ?it/s]                                
Epoch 0: Adjusting resolution to 128x128, batch_size to 128 per GPU
Epoch 0:   9%|▉         | 25/272 [00:26<04:22,  0.94it/s, v_num=7, train_loss_step=7.980]


Detected KeyboardInterrupt, attempting graceful shutdown ...



In [None]:
# Additional analysis or experimentation can go here
