CIFAR-100 CNN 訓練系統
計算智慧期末專案

學生：李泓斌
學號：C111110141
任課教授：曾建誠

訓練成果：78.40% 驗證準確率

In [None]:
#!/usr/bin/env python3
"""
CIFAR-100 CNN Training System
Production-ready implementation with focus on performance and maintainability
"""

import os
import sys
import time
import json
import warnings
from dataclasses import dataclass, asdict
from typing import Dict, Tuple, Optional, List
from pathlib import Path
import logging
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR100

import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# Silence every Python warning process-wide so notebook output stays readable.
# NOTE(review): with no category filter this also hides deprecation notices
# raised by the imported libraries — narrow it if those become relevant.
warnings.filterwarnings('ignore')

In [None]:
# =============================================================================
# System Configuration and Environment Check
# =============================================================================

class SystemChecker:
    """System environment verification and GPU configuration."""

    @staticmethod
    def check_environment() -> Dict[str, object]:
        """Inspect the runtime and enable GPU performance switches.

        Returns:
            A dict with the keys ``cuda_available``, ``cuda_version``,
            ``pytorch_version``, ``device_count``, ``device_name`` and
            ``total_memory`` (whole GiB of device 0, ``0`` without CUDA).

        Side effects:
            When CUDA is available, turns on cuDNN autotuning and TF32
            matmul/convolution kernels for the whole process.
        """
        # Query once so every field is derived from the same answer.
        cuda_available = torch.cuda.is_available()

        if cuda_available:
            total_bytes = torch.cuda.get_device_properties(0).total_memory
            cuda_version = torch.version.cuda
            device_count = torch.cuda.device_count()
            device_name = torch.cuda.get_device_name(0)
            total_memory = total_bytes // (1024**3)
        else:
            cuda_version = None
            device_count = 0
            device_name = 'CPU'
            total_memory = 0

        # Key order is preserved because callers print the dict in order.
        config = {
            'cuda_available': cuda_available,
            'cuda_version': cuda_version,
            'pytorch_version': torch.__version__,
            'device_count': device_count,
            'device_name': device_name,
            'total_memory': total_memory,
        }

        # Memory optimization settings
        if cuda_available:
            torch.backends.cudnn.benchmark = True
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True

        return config

    @staticmethod
    def get_optimal_workers() -> int:
        """Return the number of data loader workers to use (at most 4).

        ``os.cpu_count()`` yields ``None`` when the CPU count cannot be
        determined; in that case ``0`` is returned so loading happens in the
        main process instead of swallowing arbitrary exceptions.
        """
        cpu_count = os.cpu_count()
        if cpu_count is None:
            return 0  # Fallback for environments that hide the CPU count
        return min(4, cpu_count)

In [None]:
# =============================================================================
# Configuration Management
# =============================================================================

@dataclass
class TrainingConfig:
    """Training configuration with sensible defaults.

    ``lr_milestones`` defaults to ``None`` and is replaced by a fresh
    ``[30, 60, 80]`` list per instance in ``__post_init__``, so instances
    never share the same list object.
    """
    # Model parameters
    num_classes: int = 100
    input_channels: int = 3
    image_size: int = 32

    # Training parameters
    batch_size: int = 128
    epochs: int = 100
    learning_rate: float = 0.1
    momentum: float = 0.9
    weight_decay: float = 5e-4

    # Optimization
    gradient_clip: float = 1.0
    mixed_precision: bool = True
    gradient_accumulation_steps: int = 1

    # Regularization
    dropout_rate: float = 0.3
    label_smoothing: float = 0.1

    # Learning rate schedule
    lr_scheduler: str = 'cosine'  # 'cosine' or 'step'
    lr_milestones: Optional[List[int]] = None  # filled in by __post_init__
    lr_gamma: float = 0.1
    warmup_epochs: int = 5

    # Early stopping
    patience: int = 15
    min_delta: float = 0.001

    # Data augmentation
    augmentation_strength: str = 'moderate'  # 'light', 'moderate', 'strong'

    # System
    num_workers: int = 4
    pin_memory: bool = True
    checkpoint_interval: int = 5

    def __post_init__(self):
        """Fill in derived defaults and cap the batch size on small GPUs."""
        if self.lr_milestones is None:
            self.lr_milestones = [30, 60, 80]

        # Adjust batch size based on available GPU memory
        if torch.cuda.is_available():
            gpu_memory = torch.cuda.get_device_properties(0).total_memory // (1024**3)
            if gpu_memory < 8:
                # Only ever lowers the batch size; smaller requests are kept.
                self.batch_size = min(self.batch_size, 64)
                print(f"Adjusted batch size to {self.batch_size} based on GPU memory")

In [None]:
# =============================================================================
# Data Pipeline
# =============================================================================

class CIFAR100DataModule:
    """Efficient data loading and augmentation pipeline.

    Builds the torchvision transform stacks and the train/validation
    ``DataLoader`` pair for CIFAR-100 from a ``TrainingConfig``.
    """

    def __init__(self, config: TrainingConfig):
        self.config = config
        # Per-channel normalization statistics applied to every image.
        self.mean = (0.5071, 0.4867, 0.4408)
        self.std = (0.2675, 0.2565, 0.2761)

    def get_transforms(self, train: bool = True) -> transforms.Compose:
        """Get data transformation pipeline.

        Args:
            train: When True, return the augmenting pipeline selected by
                ``config.augmentation_strength``; otherwise only tensor
                conversion and normalization.

        Raises:
            KeyError: If ``config.augmentation_strength`` is not one of
                ``'light'``, ``'moderate'`` or ``'strong'``.
        """
        if train:
            augmentation_configs = {
                'light': {
                    'rotation': 5,
                    'translate': (0.05, 0.05),
                    'scale': (0.95, 1.05)
                },
                'moderate': {
                    'rotation': 10,
                    'translate': (0.1, 0.1),
                    'scale': (0.9, 1.1)
                },
                'strong': {
                    'rotation': 15,
                    'translate': (0.15, 0.15),
                    'scale': (0.85, 1.15)
                }
            }

            aug_config = augmentation_configs[self.config.augmentation_strength]

            transform_list = [
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomAffine(
                    degrees=aug_config['rotation'],
                    translate=aug_config['translate'],
                    scale=aug_config['scale']
                ),
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
                transforms.ToTensor(),
                transforms.Normalize(self.mean, self.std),
            ]

            # Add Cutout augmentation for strong augmentation
            # (applied last because it operates on the normalized tensor).
            if self.config.augmentation_strength == 'strong':
                transform_list.append(CutoutTransform(n_holes=1, length=8))
        else:
            transform_list = [
                transforms.ToTensor(),
                transforms.Normalize(self.mean, self.std),
            ]

        return transforms.Compose(transform_list)

    def _loader_kwargs(self) -> Dict[str, object]:
        """Build the DataLoader keyword arguments shared by both loaders.

        ``persistent_workers`` and ``prefetch_factor`` are only meaningful
        with worker processes. PyTorch releases before 2.0 raise
        ``ValueError`` when ``prefetch_factor=None`` is passed together with
        ``num_workers=0``, so these options are omitted entirely in that case
        and the library defaults apply.
        """
        kwargs: Dict[str, object] = {
            'num_workers': self.config.num_workers,
            'pin_memory': self.config.pin_memory,
        }
        if self.config.num_workers > 0:
            kwargs['persistent_workers'] = True
            kwargs['prefetch_factor'] = 2
        return kwargs

    def get_dataloaders(self) -> Tuple[DataLoader, DataLoader]:
        """Create training and validation dataloaders.

        Downloads CIFAR-100 into ``./data`` when it is not present. The
        validation loader uses the official test split, unshuffled, with a
        batch twice as large as the training batch.
        """
        train_dataset = CIFAR100(
            root='./data',
            train=True,
            download=True,
            transform=self.get_transforms(train=True)
        )

        val_dataset = CIFAR100(
            root='./data',
            train=False,
            download=True,
            transform=self.get_transforms(train=False)
        )

        loader_kwargs = self._loader_kwargs()

        train_loader = DataLoader(
            train_dataset,
            batch_size=self.config.batch_size,
            shuffle=True,
            **loader_kwargs
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=self.config.batch_size * 2,  # Larger batch for validation
            shuffle=False,
            **loader_kwargs
        )

        return train_loader, val_loader

class CutoutTransform:
    """Cutout augmentation: zero out random square patches of a tensor image.

    Args:
        n_holes: Number of patches erased per call.
        length: Nominal side length of each patch (clipped at the borders).
    """

    def __init__(self, n_holes: int = 1, length: int = 8):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        """Return ``img`` multiplied by a 0/1 mask shared by all channels."""
        height = img.size(1)
        width = img.size(2)
        half = self.length // 2
        keep = torch.ones((height, width), dtype=torch.float32)

        for _hole in range(self.n_holes):
            # Row is drawn before column so the random stream is consumed
            # in a fixed order.
            centre_y = torch.randint(high=height, size=(1,)).item()
            centre_x = torch.randint(high=width, size=(1,)).item()

            top, bottom = max(0, centre_y - half), min(height, centre_y + half)
            left, right = max(0, centre_x - half), min(width, centre_x + half)

            keep[top:bottom, left:right] = 0.

        return img * keep.expand_as(img)

In [None]:
# =============================================================================
# Model Architecture
# =============================================================================

class ResidualBlock(nn.Module):
    """Two 3x3 conv/batch-norm stages plus a skip connection.

    The skip path is an empty ``nn.Sequential`` (identity) unless the block
    changes resolution or channel count, in which case it is a strided 1x1
    convolution followed by batch norm.
    """

    def __init__(self, in_channels: int, out_channels: int, stride: int = 1, dropout_rate: float = 0.0):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        if dropout_rate > 0:
            self.dropout = nn.Dropout2d(dropout_rate)
        else:
            self.dropout = nn.Identity()

        # Shortcut connection
        shape_changes = stride != 1 or in_channels != out_channels
        if shape_changes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Apply conv-bn-relu-dropout-conv-bn, add the skip path, then ReLU."""
        hidden = self.conv1(x)
        hidden = self.bn1(hidden)
        hidden = F.relu(hidden)
        hidden = self.dropout(hidden)
        hidden = self.conv2(hidden)
        hidden = self.bn2(hidden)
        hidden += self.shortcut(x)
        return F.relu(hidden)

class CIFAR100CNN(nn.Module):
    """
    Efficient CNN for CIFAR-100 classification
    Architecture: Conv blocks with residual connections -> Global pooling -> FC

    The stem convolution takes ``config.input_channels`` input channels and
    the classifier emits ``config.num_classes`` logits.
    """

    def __init__(self, config: TrainingConfig):
        super().__init__()
        self.config = config

        # Initial convolution (channel count comes from the config instead of
        # a hard-coded 3, which is the config default)
        self.conv1 = nn.Conv2d(config.input_channels, 64, 3, 1, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # Residual blocks with increasing channels
        self.layer1 = self._make_layer(64, 128, 2, stride=1)
        self.layer2 = self._make_layer(128, 256, 2, stride=2)
        self.layer3 = self._make_layer(256, 512, 2, stride=2)

        # Global average pooling and classifier
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.fc = nn.Linear(512, config.num_classes)

        # Weight initialization
        self._initialize_weights()

    def _make_layer(self, in_channels: int, out_channels: int, num_blocks: int, stride: int):
        """Stack ``num_blocks`` residual blocks; only the first one strides."""
        layers = [ResidualBlock(in_channels, out_channels, stride, self.config.dropout_rate)]
        layers.extend(
            ResidualBlock(out_channels, out_channels, 1, self.config.dropout_rate)
            for _ in range(1, num_blocks)
        )
        return nn.Sequential(*layers)

    def _initialize_weights(self):
        """Kaiming-normal convs, unit/zero batch norm, small-normal linear."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Map a batch of images to a batch of class logits."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.global_pool(out)
        out = out.view(out.size(0), -1)
        out = self.dropout(out)
        out = self.fc(out)
        return out

    def get_model_stats(self) -> Dict[str, int]:
        """Calculate model complexity statistics.

        Convolution FLOPs are measured rather than guessed: a single zero
        image of shape ``(1, input_channels, image_size, image_size)`` is
        pushed through the network with forward hooks on every ``nn.Conv2d``
        so each layer is counted at the spatial size it actually produces.
        (Assuming a 32x32 input for every layer over-counts all convolutions
        that run after a stride-2 stage.)

        The probe runs in eval mode under ``no_grad`` so batch-norm running
        statistics are untouched; the previous train/eval mode of the model
        is restored afterwards.

        Returns:
            Dict with ``total_params``, ``trainable_params``,
            ``total_params_mb`` (float32 size in MiB) and
            ``estimated_flops_m`` (convolution FLOPs in millions, counting a
            multiply-add as 2 FLOPs).
        """
        params = list(self.parameters())
        total_params = sum(p.numel() for p in params)
        trainable_params = sum(p.numel() for p in params if p.requires_grad)

        conv_flops = []

        def record_conv_flops(module, _inputs, output):
            # FLOPs = 2 * (in_channels / groups) * out_channels
            #           * kernel_h * kernel_w * output_h * output_w
            kernel_h, kernel_w = module.kernel_size
            out_h, out_w = output.shape[-2:]
            conv_flops.append(
                2 * (module.in_channels // module.groups) * module.out_channels
                * kernel_h * kernel_w * out_h * out_w
            )

        hooks = [
            m.register_forward_hook(record_conv_flops)
            for m in self.modules() if isinstance(m, nn.Conv2d)
        ]
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                # Match the device/dtype of the weights so the probe runs
                # wherever the model currently lives.
                probe = torch.zeros(
                    1, self.config.input_channels,
                    self.config.image_size, self.config.image_size,
                    device=params[0].device, dtype=params[0].dtype,
                )
                self(probe)
        finally:
            for hook in hooks:
                hook.remove()
            self.train(was_training)

        total_flops = sum(conv_flops)

        return {
            'total_params': total_params,
            'trainable_params': trainable_params,
            'total_params_mb': total_params * 4 / (1024 * 1024),  # Assuming float32
            'estimated_flops_m': total_flops / 1e6
        }

In [None]:
# =============================================================================
# Training Components
# =============================================================================

class LabelSmoothingCrossEntropy(nn.Module):
    """Cross-entropy against a smoothed target distribution.

    The true class receives probability ``1 - smoothing``; the remaining
    ``smoothing`` mass is split evenly over the other ``C - 1`` classes.
    """

    def __init__(self, smoothing: float = 0.1):
        super().__init__()
        self.smoothing = smoothing
        self.confidence = 1.0 - smoothing

    def forward(self, pred, target):
        """Return the batch-mean loss for logits ``pred`` and class indices ``target``."""
        log_probs = pred.log_softmax(dim=-1)

        # The target distribution is a constant; keep it out of autograd.
        with torch.no_grad():
            off_class_mass = self.smoothing / (log_probs.size(-1) - 1)
            soft_targets = torch.full_like(log_probs, off_class_mass)
            soft_targets.scatter_(1, target.unsqueeze(1), self.confidence)

        per_sample = torch.sum(-soft_targets * log_probs, dim=-1)
        return per_sample.mean()

class EarlyStopping:
    """Signal a stop once validation loss has stalled for ``patience`` calls.

    A call counts as an improvement only when the loss beats the best value
    seen so far by at least ``min_delta``.
    """

    def __init__(self, patience: int = 10, min_delta: float = 0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def __call__(self, val_loss):
        """Record ``val_loss`` and return True once training should stop."""
        # Work with a "higher is better" score.
        score = -val_loss

        if self.best_score is None:
            # First observation only establishes the baseline.
            self.best_score = score
            return self.early_stop

        stalled = score < self.best_score + self.min_delta
        if not stalled:
            self.best_score = score
            self.counter = 0
            return self.early_stop

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True
        return self.early_stop

class MetricTracker:
    """Track and analyze training metrics.

    Metrics are stored as one list per name in ``self.metrics`` (a
    ``defaultdict(list)``), appended to once per ``update`` call.
    """

    def __init__(self):
        self.metrics = defaultdict(list)
        self.current_epoch = 0

    def update(self, metric_dict: Dict[str, float], epoch: Optional[int] = None):
        """Append each value in ``metric_dict`` to its metric's history.

        Args:
            metric_dict: Mapping of metric name to the new value.
            epoch: When given, stored as ``current_epoch``.
        """
        if epoch is not None:
            self.current_epoch = epoch

        for key, value in metric_dict.items():
            self.metrics[key].append(value)

    def get_best(self, metric: str, mode: str = 'max') -> Tuple[Optional[float], Optional[int]]:
        """Get best value and its position in the history for a metric.

        Args:
            metric: Name of the metric to look up.
            mode: ``'max'`` picks the largest value; anything else picks the
                smallest.

        Returns:
            ``(value, index)`` where ``index`` is a plain ``int``, or
            ``(None, None)`` when nothing has been recorded for ``metric``.
        """
        # .get() avoids the defaultdict side effect of creating an empty
        # entry for a metric that was merely queried.
        values = self.metrics.get(metric)
        if not values:
            return None, None

        if mode == 'max':
            best_idx = int(np.argmax(values))
        else:
            best_idx = int(np.argmin(values))

        return values[best_idx], best_idx

    def plot_metrics(self, save_path: Optional[str] = None):
        """Plot training curves on a 2x2 grid and optionally save the figure.

        Args:
            save_path: When truthy, the figure is written there before being
                shown.
        """
        fig, axes = plt.subplots(2, 2, figsize=(12, 10))

        def history(name):
            # Read-only lookup: plotting must not register empty metrics,
            # otherwise the top-5 membership check below becomes meaningless.
            return self.metrics.get(name, [])

        # Loss curves
        axes[0, 0].plot(history('train_loss'), label='Train Loss')
        axes[0, 0].plot(history('val_loss'), label='Val Loss')
        axes[0, 0].set_xlabel('Epoch')
        axes[0, 0].set_ylabel('Loss')
        axes[0, 0].set_title('Loss Curves')
        axes[0, 0].legend()
        axes[0, 0].grid(True)

        # Accuracy curves
        axes[0, 1].plot(history('train_acc'), label='Train Acc')
        axes[0, 1].plot(history('val_acc'), label='Val Acc')
        axes[0, 1].set_xlabel('Epoch')
        axes[0, 1].set_ylabel('Accuracy (%)')
        axes[0, 1].set_title('Accuracy Curves')
        axes[0, 1].legend()
        axes[0, 1].grid(True)

        # Learning rate
        axes[1, 0].plot(history('learning_rate'))
        axes[1, 0].set_xlabel('Epoch')
        axes[1, 0].set_ylabel('Learning Rate')
        axes[1, 0].set_title('Learning Rate Schedule')
        axes[1, 0].set_yscale('log')
        axes[1, 0].grid(True)

        # Top-5 Accuracy
        if 'val_acc_top5' in self.metrics:
            axes[1, 1].plot(self.metrics['val_acc_top5'], label='Val Top-5')
            axes[1, 1].set_xlabel('Epoch')
            axes[1, 1].set_ylabel('Top-5 Accuracy (%)')
            axes[1, 1].set_title('Top-5 Accuracy')
            axes[1, 1].legend()
            axes[1, 1].grid(True)

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=100, bbox_inches='tight')
        plt.show()

In [None]:
# =============================================================================
# Training Engine
# =============================================================================

class Trainer:
    """Main training orchestrator with optimizations.

    Owns the loss, SGD optimizer, LR scheduler, optional AMP grad scaler,
    metric history and early stopping for one model, and drives the
    train/validate loop.
    """

    def __init__(self, model: nn.Module, config: TrainingConfig, device: torch.device):
        self.model = model
        self.config = config
        self.device = device

        # Move model to device
        self.model = self.model.to(device)

        # Loss function
        if config.label_smoothing > 0:
            self.criterion = LabelSmoothingCrossEntropy(config.label_smoothing)
        else:
            self.criterion = nn.CrossEntropyLoss()

        # Optimizer
        self.optimizer = optim.SGD(
            self.model.parameters(),
            lr=config.learning_rate,
            momentum=config.momentum,
            weight_decay=config.weight_decay,
            nesterov=True
        )

        # Learning rate scheduler
        if config.lr_scheduler == 'cosine':
            self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer,
                T_max=config.epochs
            )
        else:
            self.scheduler = optim.lr_scheduler.MultiStepLR(
                self.optimizer,
                milestones=config.lr_milestones,
                gamma=config.lr_gamma
            )

        # Mixed precision training
        self.scaler = GradScaler() if config.mixed_precision else None

        # Metrics and tracking
        self.metric_tracker = MetricTracker()
        self.early_stopping = EarlyStopping(config.patience, config.min_delta)

        # Best model tracking
        self.best_val_acc = 0.0
        self.best_model_state = None

    @staticmethod
    def _count_topk_hits(outputs, targets, k: int = 5) -> int:
        """Count samples whose target is among the ``k`` highest logits.

        ``k`` is capped at the number of classes so models with fewer than
        ``k`` outputs do not make ``topk`` fail.
        """
        k = min(k, outputs.size(1))
        _, top_idx = outputs.topk(k, 1, True, True)
        return top_idx.eq(targets.view(-1, 1)).sum().item()

    def _optimizer_step(self):
        """Clip gradients, apply one optimizer update and clear gradients."""
        if self.config.mixed_precision:
            # Gradients must be unscaled before their norm is clipped.
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.gradient_clip)
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.gradient_clip)
            self.optimizer.step()
        self.optimizer.zero_grad()

    def train_epoch(self, train_loader: DataLoader, epoch: int) -> Dict[str, float]:
        """Train for one epoch.

        Args:
            train_loader: Yields ``(inputs, targets)`` batches.
            epoch: Zero-based epoch index (used for the progress label).

        Returns:
            ``{'train_loss': mean batch loss, 'train_acc': accuracy in %}``.
        """
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        accum_steps = self.config.gradient_accumulation_steps
        num_batches = len(train_loader)

        # Start from clean gradients so nothing left over from a previous
        # epoch (or an earlier backward call) is mixed into the first update.
        self.optimizer.zero_grad()

        # Progress bar
        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{self.config.epochs} [Train]')

        for batch_idx, (inputs, targets) in enumerate(pbar):
            inputs, targets = inputs.to(self.device), targets.to(self.device)

            # Mixed precision training
            if self.config.mixed_precision:
                with autocast():
                    outputs = self.model(inputs)
                    loss = self.criterion(outputs, targets)
                    loss = loss / accum_steps

                self.scaler.scale(loss).backward()
            else:
                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets)
                loss = loss / accum_steps
                loss.backward()

            # Update on every accumulation boundary, and also on the final
            # batch so a trailing partial group is applied rather than
            # carried into the next epoch.
            step_number = batch_idx + 1
            if step_number % accum_steps == 0 or step_number == num_batches:
                self._optimizer_step()

            # Metrics (undo the accumulation scaling for reporting)
            running_loss += loss.item() * accum_steps
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            # Update progress bar
            current_acc = 100. * correct / total
            pbar.set_postfix({
                'loss': f'{running_loss/(batch_idx+1):.4f}',
                'acc': f'{current_acc:.2f}%'
            })

        return {
            'train_loss': running_loss / len(train_loader),
            'train_acc': 100. * correct / total
        }

    def validate(self, val_loader: DataLoader, epoch: int) -> Dict[str, float]:
        """Validate the model.

        Returns:
            ``{'val_loss', 'val_acc', 'val_acc_top5'}`` with the loss as the
            mean over batches and accuracies in percent.
        """
        self.model.eval()
        running_loss = 0.0
        correct = 0
        correct_top5 = 0
        total = 0

        with torch.no_grad():
            pbar = tqdm(val_loader, desc=f'Epoch {epoch+1}/{self.config.epochs} [Val]')

            for inputs, targets in pbar:
                inputs, targets = inputs.to(self.device), targets.to(self.device)

                if self.config.mixed_precision:
                    with autocast():
                        outputs = self.model(inputs)
                        loss = self.criterion(outputs, targets)
                else:
                    outputs = self.model(inputs)
                    loss = self.criterion(outputs, targets)

                running_loss += loss.item()

                # Top-1 accuracy
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

                # Top-5 accuracy
                correct_top5 += self._count_topk_hits(outputs, targets, 5)

                # Update progress bar
                current_acc = 100. * correct / total
                pbar.set_postfix({
                    'loss': f'{running_loss/(pbar.n+1):.4f}',
                    'acc': f'{current_acc:.2f}%'
                })

        return {
            'val_loss': running_loss / len(val_loader),
            'val_acc': 100. * correct / total,
            'val_acc_top5': 100. * correct_top5 / total
        }

    def train(self, train_loader: DataLoader, val_loader: DataLoader):
        """Full training loop.

        Runs up to ``config.epochs`` epochs with linear LR warm-up, per-epoch
        validation, best-weights tracking, periodic checkpoints and early
        stopping on validation loss. The best weights are loaded back into
        the model before returning.

        Returns:
            The ``MetricTracker`` holding the per-epoch history.
        """
        print(f"\nStarting training on {self.device}")
        print(f"Model architecture: {self.model.__class__.__name__}")

        # Model statistics (only for models that provide them)
        if hasattr(self.model, 'get_model_stats'):
            stats = self.model.get_model_stats()
            print(f"Total parameters: {stats['total_params']:,}")
            print(f"Trainable parameters: {stats['trainable_params']:,}")
            print(f"Model size: {stats['total_params_mb']:.2f} MB")
            print(f"Estimated FLOPs: {stats['estimated_flops_m']:.2f}M")
        print("-" * 50)

        start_time = time.time()

        for epoch in range(self.config.epochs):
            epoch_start = time.time()

            # Warm-up learning rate: applied BEFORE the epoch runs so the
            # first epoch really trains at base_lr / warmup_epochs instead of
            # the full learning rate.
            if epoch < self.config.warmup_epochs:
                warmup_lr = self.config.learning_rate * (epoch + 1) / self.config.warmup_epochs
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = warmup_lr

            # The rate this epoch is actually trained with.
            current_lr = self.optimizer.param_groups[0]['lr']

            # Training
            train_metrics = self.train_epoch(train_loader, epoch)

            # Validation
            val_metrics = self.validate(val_loader, epoch)

            # Learning rate scheduling
            self.scheduler.step()

            # Track metrics
            self.metric_tracker.update({
                **train_metrics,
                **val_metrics,
                'learning_rate': current_lr
            }, epoch)

            # Print epoch summary
            epoch_time = time.time() - epoch_start
            print(f"\nEpoch {epoch+1}/{self.config.epochs} - {epoch_time:.1f}s")
            print(f"Train Loss: {train_metrics['train_loss']:.4f}, Train Acc: {train_metrics['train_acc']:.2f}%")
            print(f"Val Loss: {val_metrics['val_loss']:.4f}, Val Acc: {val_metrics['val_acc']:.2f}%, Top-5: {val_metrics['val_acc_top5']:.2f}%")
            print(f"Learning Rate: {current_lr:.6f}")

            # Save best model
            if val_metrics['val_acc'] > self.best_val_acc:
                self.best_val_acc = val_metrics['val_acc']
                # state_dict() tensors share storage with the live weights
                # and a dict .copy() is shallow, so each tensor is cloned;
                # otherwise later updates would overwrite the snapshot.
                self.best_model_state = {
                    name: tensor.detach().clone()
                    for name, tensor in self.model.state_dict().items()
                }
                print(f"New best model! Val Acc: {self.best_val_acc:.2f}%")

            # Checkpoint
            if (epoch + 1) % self.config.checkpoint_interval == 0:
                self.save_checkpoint(epoch, val_metrics['val_acc'])

            # Early stopping
            if self.early_stopping(val_metrics['val_loss']):
                print(f"\nEarly stopping triggered at epoch {epoch+1}")
                break

            # Memory cleanup
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        total_time = time.time() - start_time
        print(f"\nTraining completed in {total_time/60:.1f} minutes")
        print(f"Best validation accuracy: {self.best_val_acc:.2f}%")

        # Load best model
        if self.best_model_state is not None:
            self.model.load_state_dict(self.best_model_state)

        return self.metric_tracker

    def save_checkpoint(self, epoch: int, val_acc: float):
        """Save model checkpoint to the current working directory.

        The file holds model/optimizer/scheduler state, the grad-scaler state
        (``None`` without mixed precision), the validation accuracy and the
        config as a plain dict.
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'scaler_state_dict': self.scaler.state_dict() if self.scaler is not None else None,
            'val_acc': val_acc,
            'config': asdict(self.config)
        }

        checkpoint_path = f'checkpoint_epoch_{epoch+1}_acc_{val_acc:.2f}.pth'
        torch.save(checkpoint, checkpoint_path)
        print(f"Checkpoint saved: {checkpoint_path}")

    def test_model(self, test_loader: DataLoader) -> Dict[str, float]:
        """Final model evaluation.

        Returns:
            Dict with ``test_acc`` and ``test_acc_top5`` in percent and
            ``class_accuracies``, a list of per-class accuracies (percent)
            for every class that occurs in ``test_loader``.
        """
        self.model.eval()
        num_classes = self.config.num_classes
        correct = 0
        correct_top5 = 0
        total = 0
        class_correct = torch.zeros(num_classes, dtype=torch.long)
        class_total = torch.zeros(num_classes, dtype=torch.long)

        with torch.no_grad():
            for inputs, targets in tqdm(test_loader, desc='Testing'):
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                outputs = self.model(inputs)

                # Top-1
                _, predicted = outputs.max(1)
                hits = predicted.eq(targets)
                total += targets.size(0)
                correct += hits.sum().item()

                # Top-5
                correct_top5 += self._count_topk_hits(outputs, targets, 5)

                # Per-class accuracy: histogram the labels instead of looping
                # over samples (also safe for a batch holding one sample).
                class_total += torch.bincount(targets, minlength=num_classes).cpu()
                class_correct += torch.bincount(targets[hits], minlength=num_classes).cpu()

        # Calculate metrics
        test_acc = 100. * correct / total
        test_acc_top5 = 100. * correct_top5 / total

        print(f"\nTest Results:")
        print(f"Top-1 Accuracy: {test_acc:.2f}%")
        print(f"Top-5 Accuracy: {test_acc_top5:.2f}%")

        # Per-class accuracy analysis (classes absent from the data are skipped)
        seen = class_total > 0
        class_accuracies = (
            100.0 * class_correct[seen].double() / class_total[seen].double()
        ).tolist()

        print(f"Mean class accuracy: {np.mean(class_accuracies):.2f}%")
        print(f"Std class accuracy: {np.std(class_accuracies):.2f}%")
        print(f"Min class accuracy: {np.min(class_accuracies):.2f}%")
        print(f"Max class accuracy: {np.max(class_accuracies):.2f}%")

        return {
            'test_acc': test_acc,
            'test_acc_top5': test_acc_top5,
            'class_accuracies': class_accuracies
        }

In [None]:
# =============================================================================
# Performance Monitoring
# =============================================================================

class PerformanceMonitor:
    """Monitor GPU usage and training performance."""

    def __init__(self):
        self.gpu_available = torch.cuda.is_available()

    def get_gpu_memory_usage(self) -> Dict[str, float]:
        """Get current GPU memory statistics in GiB.

        Returns:
            An empty dict without CUDA; otherwise ``allocated_gb``,
            ``reserved_gb``, ``max_allocated_gb`` and ``free_gb`` (memory
            reserved by the caching allocator but not currently allocated).
        """
        if not self.gpu_available:
            return {}

        gib = 1024**3
        allocated = torch.cuda.memory_allocated() / gib
        reserved = torch.cuda.memory_reserved() / gib
        max_allocated = torch.cuda.max_memory_allocated() / gib

        return {
            'allocated_gb': allocated,
            'reserved_gb': reserved,
            'max_allocated_gb': max_allocated,
            'free_gb': reserved - allocated
        }

    def profile_dataloader(self, dataloader: DataLoader, num_batches: int = 10) -> Dict[str, float]:
        """Profile data loading performance.

        Each timing covers fetching one batch from ``dataloader`` plus (with
        CUDA) copying it to the GPU. The timer starts before the batch is
        requested; starting it afterwards would measure only the device copy
        and hide the loading cost this method reports.

        Args:
            dataloader: Iterable yielding ``(data, target)`` pairs.
            num_batches: Maximum number of batches to time.

        Returns:
            ``mean_load_time``, ``std_load_time``, ``min_load_time`` and
            ``max_load_time`` in seconds; all ``0.0`` when no batch was timed.
        """
        load_times = []
        batches = iter(dataloader)

        for _ in range(num_batches):
            start = time.perf_counter()
            try:
                data, target = next(batches)
            except StopIteration:
                break

            if self.gpu_available:
                data = data.cuda(non_blocking=True)
                target = target.cuda(non_blocking=True)
                # Wait for the asynchronous copies so they are included.
                torch.cuda.synchronize()
            load_times.append(time.perf_counter() - start)

        if not load_times:
            # Empty loader or num_batches <= 0: nothing to aggregate.
            return {
                'mean_load_time': 0.0,
                'std_load_time': 0.0,
                'min_load_time': 0.0,
                'max_load_time': 0.0
            }

        return {
            'mean_load_time': float(np.mean(load_times)),
            'std_load_time': float(np.std(load_times)),
            'min_load_time': float(np.min(load_times)),
            'max_load_time': float(np.max(load_times))
        }

    def benchmark_model(self, model: nn.Module, input_size: Tuple[int, ...],
                       batch_sizes: Optional[List[int]] = None) -> Dict[int, Dict[str, float]]:
        """Benchmark model inference speed.

        Args:
            model: Module to time; it is switched to eval mode and, with
                CUDA, moved to the GPU.
            input_size: Shape of one sample, without the batch dimension.
            batch_sizes: Batch sizes to try; defaults to
                ``[16, 32, 64, 128]``.

        Returns:
            Per batch size either ``{'error': 'OOM'}`` or a dict with
            ``mean_time``/``std_time`` (seconds per forward pass),
            ``throughput`` (samples per second) and ``memory_gb``.

        Raises:
            RuntimeError: Any runtime error other than out-of-memory.
        """
        if batch_sizes is None:
            batch_sizes = [16, 32, 64, 128]

        model.eval()
        if self.gpu_available:
            model = model.cuda()
        results = {}

        for batch_size in batch_sizes:
            try:
                dummy_input = torch.randn(batch_size, *input_size)
                if self.gpu_available:
                    dummy_input = dummy_input.cuda()

                with torch.no_grad():
                    # Warm-up
                    for _ in range(10):
                        _ = model(dummy_input)

                    if self.gpu_available:
                        torch.cuda.synchronize()

                    # Benchmark
                    times = []
                    for _ in range(100):
                        start = time.perf_counter()
                        _ = model(dummy_input)
                        if self.gpu_available:
                            # Kernels launch asynchronously; wait for them.
                            torch.cuda.synchronize()
                        times.append(time.perf_counter() - start)

                mean_time = float(np.mean(times))
                results[batch_size] = {
                    'mean_time': mean_time,
                    'std_time': float(np.std(times)),
                    'throughput': batch_size / mean_time,
                    'memory_gb': self.get_gpu_memory_usage().get('allocated_gb', 0)
                }

                # Clear cache
                if self.gpu_available:
                    torch.cuda.empty_cache()

            except RuntimeError as e:
                if 'out of memory' not in str(e):
                    raise
                results[batch_size] = {'error': 'OOM'}
                if self.gpu_available:
                    torch.cuda.empty_cache()

        return results

In [None]:
# =============================================================================
# Main Execution Pipeline
# =============================================================================

class CIFAR100Pipeline:
    """Orchestrates data setup, profiling, training, evaluation and saving.

    Attributes set by ``__init__``:
        config: The ``TrainingConfig`` in use (the one passed in, or a fresh
            default). It is adjusted in place to match the detected hardware.
        device: ``torch.device('cuda')`` when CUDA is available, else CPU.
        system_info: The dict returned by ``SystemChecker.check_environment()``.
    """

    def __init__(self, config: Optional[TrainingConfig] = None):
        """Store the config, report the environment, and adapt the config.

        Args:
            config: Training configuration. A default ``TrainingConfig()`` is
                created when a falsy value (such as ``None``) is given.
        """
        self.config = config if config else TrainingConfig()
        backend = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(backend)

        # The environment report is printed before the config is adjusted.
        self.system_info = SystemChecker.check_environment()
        self._print_system_info()

        if not self.system_info['cuda_available']:
            # CPU fallback: mixed precision off and num_workers forced to 0.
            self.config.mixed_precision = False
            self.config.num_workers = 0
            print("Warning: CUDA not available, falling back to CPU training")
        else:
            self.config.num_workers = SystemChecker.get_optimal_workers()

    def _print_system_info(self):
        """Print every ``system_info`` entry as ``key: value`` between rules."""
        rule = "=" * 50
        print(rule, "System Configuration", rule, sep="\n")
        for name, setting in self.system_info.items():
            print(f"{name}: {setting}")
        print(rule)

    @staticmethod
    def _banner(title):
        """Print ``title`` framed by 50-character rules after a blank line."""
        rule = "=" * 50
        print("\n" + rule)
        print(title)
        print(rule)

    @staticmethod
    def _report_benchmark(benchmark_results):
        """Print one line per benchmarked batch size (throughput or error)."""
        divider = "-" * 50
        print("\nInference Benchmark Results:")
        print(divider)
        for batch_size, stats in benchmark_results.items():
            if 'error' not in stats:
                print(f"Batch {batch_size}: {stats['throughput']:.1f} img/s, "
                      f"Memory: {stats['memory_gb']:.2f}GB")
            else:
                print(f"Batch {batch_size}: {stats['error']}")
        print(divider)

    def run(self):
        """Run data setup, profiling, benchmarking, training and evaluation.

        Returns:
            A 3-tuple ``(model, metric_tracker, test_results)``. If a
            ``KeyboardInterrupt`` arrives during the training/evaluation
            phase, ``(trainer.model, trainer.metric_tracker, None)`` is
            returned instead.

        Raises:
            Exception: Any other error from the training/evaluation phase is
                reported on stdout and then re-raised.
        """
        print("\nInitializing data pipeline...")
        train_loader, val_loader = CIFAR100DataModule(self.config).get_dataloaders()
        print(f"Training samples: {len(train_loader.dataset)}")
        print(f"Validation samples: {len(val_loader.dataset)}")

        print("\nInitializing model...")
        model = CIFAR100CNN(self.config)

        monitor = PerformanceMonitor()

        # Data-loading cost, measured over the first 10 batches.
        print("\nProfiling data loader...")
        loader_stats = monitor.profile_dataloader(train_loader, num_batches=10)
        mean_load_ms = loader_stats['mean_load_time'] * 1000
        print(f"Mean batch load time: {mean_load_ms:.2f}ms")

        # Inference speed on 3x32x32 inputs at several batch sizes.
        print("\nBenchmarking model performance...")
        benchmark_results = monitor.benchmark_model(
            model,
            input_size=(3, 32, 32),
            batch_sizes=[32, 64, 128, 256],
        )
        self._report_benchmark(benchmark_results)

        print("\nInitializing trainer...")
        trainer = Trainer(model, self.config, self.device)

        self._banner("Starting Training")

        try:
            metric_tracker = trainer.train(train_loader, val_loader)

            # The final evaluation is run on the validation loader.
            self._banner("Final Evaluation")
            test_results = trainer.test_model(val_loader)

            print("\nGenerating training curves...")
            metric_tracker.plot_metrics(save_path='training_curves.png')

            if monitor.gpu_available:
                mem_stats = monitor.get_gpu_memory_usage()
                print("\nGPU Memory Usage:")
                print(f"Peak allocated: {mem_stats['max_allocated_gb']:.2f}GB")
                print(f"Current allocated: {mem_stats['allocated_gb']:.2f}GB")

            top_val_acc, top_epoch = metric_tracker.get_best('val_acc', mode='max')
            self._banner("Training Summary")
            # The epoch is shown one higher than the value get_best returns.
            print(f"Best validation accuracy: {top_val_acc:.2f}% (Epoch {top_epoch + 1})")
            print(f"Final test accuracy: {test_results['test_acc']:.2f}%")
            print(f"Final test top-5 accuracy: {test_results['test_acc_top5']:.2f}%")

            self._save_final_model(trainer.model, test_results)

            return trainer.model, metric_tracker, test_results

        except KeyboardInterrupt:
            # Hand back whatever the trainer holds so partial work is usable.
            print("\nTraining interrupted by user")
            return trainer.model, trainer.metric_tracker, None
        except Exception as exc:
            print(f"\nError during training: {exc!s}")
            raise
        finally:
            # Runs on success, interruption and failure alike.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    def _save_final_model(self, model: nn.Module, test_results: Dict):
        """Write weights, config, results and model stats to one file.

        Args:
            model: Trained model; must also provide ``get_model_stats()``.
            test_results: Evaluation results stored alongside the weights.
        """
        destination = 'cifar100_final_model.pth'

        checkpoint = {}
        checkpoint['model_state_dict'] = model.state_dict()
        checkpoint['config'] = asdict(self.config)
        checkpoint['test_results'] = test_results
        checkpoint['model_stats'] = model.get_model_stats()

        torch.save(checkpoint, destination)
        print(f"\nFinal model saved to: {destination}")

In [None]:
# =============================================================================
# Quick Start Function
# =============================================================================

def train_cifar100(
    epochs: int = 100,
    batch_size: int = 128,
    learning_rate: float = 0.1,
    mixed_precision: bool = True,
    augmentation: str = 'moderate'
):
    """
    Build a TrainingConfig from the arguments and run the full pipeline.

    Args:
        epochs: Number of training epochs
        batch_size: Batch size for training
        learning_rate: Initial learning rate
        mixed_precision: Use mixed precision training
        augmentation: Data augmentation strength ('light', 'moderate', 'strong');
            passed to the config as ``augmentation_strength``

    Returns:
        Whatever ``CIFAR100Pipeline.run()`` returns: the trained model, the
        metrics tracker, and the test results (``None`` in the last slot if
        the run was interrupted).
    """
    settings = {
        'epochs': epochs,
        'batch_size': batch_size,
        'learning_rate': learning_rate,
        'mixed_precision': mixed_precision,
        # The config field name is longer than this function's parameter name.
        'augmentation_strength': augmentation,
    }
    return CIFAR100Pipeline(TrainingConfig(**settings)).run()

In [None]:
# =============================================================================
# Colab Execution
# =============================================================================

if __name__ == "__main__":
    # Environment detection: an ImportError anywhere in this try block is
    # treated as "not running in Colab". Statement order matters, because the
    # Colab message is printed before the Drive import is attempted.
    try:
        import google.colab
        IN_COLAB = True
        print("Running in Google Colab environment")

        # Mount Google Drive at /content/drive.
        from google.colab import drive
        drive.mount('/content/drive')

    except ImportError:
        IN_COLAB = False
        print("Running in local environment")

    # Opening banner.
    print(
        "\n" + "=" * 60,
        "CIFAR-100 CNN Training System",
        "Production-Ready Implementation",
        "=" * 60,
        sep="\n",
    )

    # Hyperparameters for this run are given explicitly; they are not the
    # defaults of train_cifar100.
    model, metrics, results = train_cifar100(
        augmentation='strong',
        mixed_precision=True,
        learning_rate=0.15,
        batch_size=256,
        epochs=150,
    )

    # Closing banner.
    print("\n" + "=" * 60, "Training Complete!", "=" * 60, sep="\n")