<a href="https://colab.research.google.com/github/theboredman/CSE468/blob/main/Quiz_1/CNN/Using_CNN_CIFAR100_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Setup

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR100

import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import os
from tqdm import tqdm

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


## Prepare the data

In [2]:
num_classes = 100
input_shape = (3, 32, 32)  # PyTorch uses (C, H, W) format

from torchvision.transforms import AutoAugment, AutoAugmentPolicy, RandAugment

# Define transforms for training
train_transform = transforms.Compose([
    transforms.Resize((80, 80)),
    transforms.RandomCrop(72, padding=4),
    transforms.RandomHorizontalFlip(p=0.5),

    AutoAugment(policy=AutoAugmentPolicy.CIFAR10),

    RandAugment(num_ops=2, magnitude=9),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5071, 0.4867, 0.4408], std=[0.2675, 0.2565, 0.2761]),
    # Random Erasing
    transforms.RandomErasing(p=0.25, scale=(0.02, 0.33), ratio=(0.3, 3.3)),
])

test_transform = transforms.Compose([
    transforms.Resize((72, 72)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5071, 0.4867, 0.4408], std=[0.2675, 0.2565, 0.2761])
])

# Load CIFAR-100 dataset
train_dataset = CIFAR100(root='./data', train=True, download=True, transform=train_transform)
test_dataset = CIFAR100(root='./data', train=False, download=True, transform=test_transform)

print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")

100%|██████████| 169M/169M [00:13<00:00, 12.4MB/s]


Training samples: 50000
Test samples: 10000


## Configure the hyperparameters

In [3]:
learning_rate = 0.001
weight_decay = 0.05
batch_size = 256
num_epochs = 50
image_size = 72

mlp_head_units = [2048, 1024]

# Advanced hyperparameters optimized for speed
initial_learning_rate = 0.003
label_smoothing = 0.1
dropout_rate = 0.3
warmup_epochs = 3
min_lr = 1e-6
mixup_alpha = 0.4
cutmix_alpha = 1.0
mixup_prob = 0.8
stochastic_depth_rate = 0.3
gradient_clip_norm = 1.0
ema_decay = 0.9998
accumulation_steps = 2

## Implement multilayer perceptron (MLP)

In [4]:
class MLP(nn.Module):
    def __init__(self, input_dim, hidden_units, dropout_rate):
        super(MLP, self).__init__()
        layers = []
        prev_dim = input_dim

        for units in hidden_units:
            layers.extend([
                nn.Linear(prev_dim, units),
                nn.GELU(),
                nn.Dropout(dropout_rate)
            ])
            prev_dim = units

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)

## CNN Architecture

We'll build a CNN with convolutional layers for feature extraction followed by dense layers for classification.

In [5]:
class StochasticDepth(nn.Module):
    """Stochastic Depth (Drop Path) for better regularization"""
    def __init__(self, drop_prob=0.0):
        super(StochasticDepth, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if not self.training or self.drop_prob == 0.:
            return x
        keep_prob = 1 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()
        return x.div(keep_prob) * random_tensor

class SqueezeExcitationBlock(nn.Module):
    """Enhanced Squeeze-and-Excitation block"""
    def __init__(self, channels, ratio=16):
        super(SqueezeExcitationBlock, self).__init__()
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        reduced_channels = max(1, channels // ratio)
        self.fc1 = nn.Linear(channels, reduced_channels)
        self.fc2 = nn.Linear(reduced_channels, channels)
        self.swish = nn.SiLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        b, c, _, _ = x.size()
        se = self.global_pool(x).view(b, c)
        se = self.swish(self.fc1(se))
        se = self.sigmoid(self.fc2(se)).view(b, c, 1, 1)
        return x * se

class FusedMBConv(nn.Module):
    """Fused MBConv from EfficientNetV2 - faster and more efficient"""
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, expansion_factor=4, drop_path=0.0):
        super(FusedMBConv, self).__init__()
        self.stride = stride
        self.use_residual = stride == 1 and in_channels == out_channels
        expanded_channels = in_channels * expansion_factor

        # Fused expansion + depthwise
        self.fused_conv = nn.Sequential(
            nn.Conv2d(in_channels, expanded_channels, kernel_size, stride, padding=kernel_size//2, bias=False),
            nn.GroupNorm(16, expanded_channels),
            nn.SiLU()
        )

        # SE block
        self.se = SqueezeExcitationBlock(expanded_channels, ratio=4)

        # Projection
        self.project = nn.Sequential(
            nn.Conv2d(expanded_channels, out_channels, 1, bias=False),
            nn.GroupNorm(16, out_channels)
        )

        # Stochastic depth
        self.drop_path = StochasticDepth(drop_path) if drop_path > 0 else nn.Identity()

    def forward(self, x):
        identity = x
        x = self.fused_conv(x)
        x = self.se(x)
        x = self.project(x)

        if self.use_residual:
            x = self.drop_path(x)
            x = x + identity

        return x

class ImprovedResidualBlock(nn.Module):
    """Enhanced residual block with Stochastic Depth"""
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, use_se=True, drop_path=0.0):
        super(ImprovedResidualBlock, self).__init__()
        self.use_se = use_se
        self.use_residual = stride == 1 and in_channels == out_channels

        # First conv layer with Group Normalization
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=kernel_size//2, bias=False)
        self.gn1 = nn.GroupNorm(16, out_channels)

        # Second conv layer
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, 1, padding=kernel_size//2, bias=False)
        self.gn2 = nn.GroupNorm(16, out_channels)

        # Squeeze-and-Excitation
        if use_se:
            self.se = SqueezeExcitationBlock(out_channels, ratio=4)

        # Shortcut connection
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride, bias=False),
                nn.GroupNorm(16, out_channels)
            )
        else:
            self.shortcut = nn.Identity()

        self.swish = nn.SiLU()
        self.drop_path = StochasticDepth(drop_path) if drop_path > 0 and self.use_residual else nn.Identity()

    def forward(self, x):
        residual = self.shortcut(x)

        out = self.swish(self.gn1(self.conv1(x)))
        out = self.gn2(self.conv2(out))

        if self.use_se:
            out = self.se(out)

        out = self.drop_path(out)
        out = out + residual
        out = self.swish(out)
        return out

class CNNClassifier(nn.Module):
    def __init__(self, num_classes=100, mlp_head_units=[2048, 1024], drop_path_rate=0.2):
        super(CNNClassifier, self).__init__()

        # Calculate drop path rates (linearly increasing)
        total_blocks = 16
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, total_blocks)]
        block_idx = 0

        # Enhanced Stem with more capacity
        self.stem = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1, bias=False),
            nn.GroupNorm(8, 32),
            nn.SiLU(),
            nn.Conv2d(32, 48, 3, padding=1, bias=False),
            nn.GroupNorm(8, 48),
            nn.SiLU()
        )

        # Stage 1: FusedMBConv blocks (efficient early stages)
        self.stage1 = nn.Sequential(
            FusedMBConv(48, 64, stride=1, expansion_factor=4, drop_path=dpr[block_idx]),
            FusedMBConv(64, 64, stride=1, expansion_factor=4, drop_path=dpr[block_idx+1]),
            FusedMBConv(64, 64, stride=2, expansion_factor=4, drop_path=dpr[block_idx+2]),
            nn.Dropout2d(0.1)
        )
        block_idx += 3

        # Stage 2: Mixed block types
        self.stage2 = nn.Sequential(
            FusedMBConv(64, 96, stride=1, expansion_factor=4, drop_path=dpr[block_idx]),
            ImprovedResidualBlock(96, 96, stride=1, use_se=True, drop_path=dpr[block_idx+1]),
            FusedMBConv(96, 96, stride=2, expansion_factor=4, drop_path=dpr[block_idx+2]),
            nn.Dropout2d(0.15)
        )
        block_idx += 3

        # Stage 3: Deeper feature extraction
        self.stage3 = nn.Sequential(
            ImprovedResidualBlock(96, 144, stride=1, use_se=True, drop_path=dpr[block_idx]),
            FusedMBConv(144, 144, stride=1, expansion_factor=6, drop_path=dpr[block_idx+1]),
            ImprovedResidualBlock(144, 144, stride=1, use_se=True, drop_path=dpr[block_idx+2]),
            nn.AvgPool2d(2),
            nn.Dropout2d(0.2)
        )
        block_idx += 3

        # Stage 4: High-level features
        self.stage4 = nn.Sequential(
            FusedMBConv(144, 192, stride=1, expansion_factor=6, drop_path=dpr[block_idx]),
            ImprovedResidualBlock(192, 192, stride=1, use_se=True, drop_path=dpr[block_idx+1]),
            FusedMBConv(192, 192, stride=1, expansion_factor=6, drop_path=dpr[block_idx+2]),
            ImprovedResidualBlock(192, 192, stride=1, use_se=True, drop_path=dpr[block_idx+3]),
            nn.Dropout2d(0.25)
        )
        block_idx += 4

        # Stage 5: Final high-capacity features
        self.stage5 = nn.Sequential(
            ImprovedResidualBlock(192, 256, stride=1, use_se=True, drop_path=dpr[block_idx]),
            FusedMBConv(256, 256, stride=1, expansion_factor=8, drop_path=dpr[block_idx+1]),
            ImprovedResidualBlock(256, 256, stride=1, use_se=True, drop_path=dpr[block_idx+2])
        )

        # Multi-scale feature aggregation
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.global_max_pool = nn.AdaptiveMaxPool2d(1)

        # Enhanced classifier with GeM pooling information
        self.classifier_head = nn.Sequential(
            nn.Linear(512, 896, bias=False),  # 256*2 from pooling + extra capacity
            nn.GroupNorm(1, 896),
            nn.SiLU(),
            nn.Dropout(0.5),

            nn.Linear(896, 512, bias=False),
            nn.GroupNorm(1, 512),
            nn.SiLU(),
            nn.Dropout(0.4)
        )

        # MLP classification head
        self.mlp = MLP(512, mlp_head_units, 0.5)

        # Final classification layer
        self.final_classifier = nn.Linear(mlp_head_units[-1], num_classes)

        # Initialize weights with better strategy
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.GroupNorm, nn.BatchNorm2d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # Stem
        x = self.stem(x)

        # Progressive stages
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.stage5(x)

        # Multi-scale global pooling
        gap = self.global_avg_pool(x).flatten(1)
        gmp = self.global_max_pool(x).flatten(1)
        x = torch.cat([gap, gmp], dim=1)

        # Classification head
        x = self.classifier_head(x)
        x = self.mlp(x)
        x = self.final_classifier(x)

        return x

## Training and Evaluation Functions

In [6]:
class LabelSmoothingCrossEntropy(nn.Module):
    """Label smoothing cross entropy loss"""
    def __init__(self, smoothing=0.1):
        super(LabelSmoothingCrossEntropy, self).__init__()
        self.smoothing = smoothing

    def forward(self, pred, target):
        log_prob = F.log_softmax(pred, dim=-1)
        weight = pred.new_ones(pred.size()) * self.smoothing / (pred.size(-1) - 1.)
        weight.scatter_(-1, target.unsqueeze(-1), (1. - self.smoothing))
        loss = (-weight * log_prob).sum(dim=-1).mean()
        return loss

class ModelEMA:
    """Exponential Moving Average of model parameters"""
    def __init__(self, model, decay=0.9999):
        self.module = type(model)(num_classes=model.final_classifier.out_features,
                                   mlp_head_units=[2048, 1024],
                                   drop_path_rate=stochastic_depth_rate).to(next(model.parameters()).device)
        self.module.load_state_dict(model.state_dict())
        self.module.eval()
        self.decay = decay
        self.device = next(model.parameters()).device

    def update(self, model):
        with torch.no_grad():
            for ema_param, model_param in zip(self.module.parameters(), model.parameters()):
                ema_param.data.mul_(self.decay).add_(model_param.data, alpha=1 - self.decay)

def mixup_data(x, y, alpha=0.2):
    """Apply Mixup augmentation"""
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size()[0]
    index = torch.randperm(batch_size).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def cutmix_data(x, y, alpha=1.0):
    """Apply CutMix augmentation"""
    lam = np.random.beta(alpha, alpha)
    batch_size = x.size()[0]
    index = torch.randperm(batch_size).to(x.device)

    _, _, H, W = x.size()
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # Uniform sampling
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    x[:, :, bby1:bby2, bbx1:bbx2] = x[index, :, bby1:bby2, bbx1:bbx2]

    # Adjust lambda to match pixel ratio
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (W * H))
    y_a, y_b = y, y[index]
    return x, y_a, y_b, lam

def calculate_accuracy(outputs, targets, topk=(1, 5, 10)):
    """Calculate top-k accuracy"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = targets.size(0)

        _, pred = outputs.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(targets.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

def train_epoch(model, train_loader, criterion, optimizer, device, ema=None, use_mixup=True, scaler=None):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    optimizer.zero_grad()

    pbar = tqdm(train_loader, desc='Training')
    for batch_idx, (data, target) in enumerate(pbar):
        data, target = data.to(device), target.to(device)

        # Mixed precision training
        with torch.amp.autocast('cuda', enabled=(scaler is not None)):
            # Apply Mixup or CutMix
            if use_mixup and np.random.rand() < mixup_prob:
                if np.random.rand() < 0.5:
                    data, target_a, target_b, lam = mixup_data(data, target, mixup_alpha)
                else:
                    data, target_a, target_b, lam = cutmix_data(data, target, cutmix_alpha)

                output = model(data)
                loss = lam * criterion(output, target_a) + (1 - lam) * criterion(output, target_b)
            else:
                output = model(data)
                loss = criterion(output, target)

            # Gradient accumulation
            loss = loss / accumulation_steps

        # Backward pass with mixed precision
        if scaler is not None:
            scaler.scale(loss).backward()
        else:
            loss.backward()

        # Update weights every accumulation_steps
        if (batch_idx + 1) % accumulation_steps == 0:
            if scaler is not None:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip_norm)
                scaler.step(optimizer)
                scaler.update()
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip_norm)
                optimizer.step()

            optimizer.zero_grad()

            # Update EMA
            if ema is not None:
                ema.update(model)

        running_loss += loss.item() * accumulation_steps
        _, predicted = output.max(1)
        total += target.size(0)
        correct += predicted.eq(target).sum().item()

        pbar.set_postfix({
            'Loss': f'{running_loss/(batch_idx+1):.3f}',
            'Acc': f'{100.*correct/total:.2f}%'
        })

    return running_loss / len(train_loader), 100. * correct / total

def validate(model, val_loader, criterion, device):
    model.eval()
    val_loss = 0.0
    top1_correct = 0
    top5_correct = 0
    top10_correct = 0
    total = 0

    with torch.no_grad():
        pbar = tqdm(val_loader, desc='Validation')
        for data, target in pbar:
            data, target = data.to(device), target.to(device)
            output = model(data)
            val_loss += criterion(output, target).item()

            # Calculate top-k accuracies
            acc1, acc5, acc10 = calculate_accuracy(output, target, topk=(1, 5, 10))
            top1_correct += acc1.item() * target.size(0) / 100
            top5_correct += acc5.item() * target.size(0) / 100
            top10_correct += acc10.item() * target.size(0) / 100
            total += target.size(0)

            pbar.set_postfix({
                'Loss': f'{val_loss/total:.3f}',
                'Top1': f'{100.*top1_correct/total:.2f}%',
                'Top5': f'{100.*top5_correct/total:.2f}%'
            })

    return (val_loss / len(val_loader),
            100. * top1_correct / total,
            100. * top5_correct / total,
            100. * top10_correct / total)

def warmup_cosine_annealing_lr(epoch, warmup_epochs, total_epochs, initial_lr, min_lr):
    """Warmup + Cosine annealing learning rate schedule"""
    if epoch < warmup_epochs:
        return initial_lr * (epoch + 1) / warmup_epochs
    else:
        progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
        return min_lr + (initial_lr - min_lr) * 0.5 * (1 + np.cos(np.pi * progress))

def train_model(model, train_loader, val_loader, num_epochs, device):
    # Loss function with label smoothing
    criterion = LabelSmoothingCrossEntropy(smoothing=label_smoothing)

    # Optimizer with higher LR for faster convergence
    optimizer = optim.AdamW(model.parameters(),
                           lr=initial_learning_rate,
                           weight_decay=weight_decay,
                           betas=(0.9, 0.999),
                           eps=1e-8)

    # OneCycleLR for faster convergence - best scheduler for limited epochs
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=initial_learning_rate,
        epochs=num_epochs,
        steps_per_epoch=len(train_loader) // accumulation_steps,
        pct_start=0.1,  # 10% warmup
        anneal_strategy='cos',
        div_factor=25.0,  # initial_lr = max_lr / 25
        final_div_factor=1e4  # min_lr = initial_lr / 1e4
    )

    # Initialize EMA
    ema = ModelEMA(model, decay=ema_decay)

    # Mixed precision training for faster computation
    scaler = torch.cuda.amp.GradScaler() if device.type == 'cuda' else None

    # Training history
    history = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': [], 'val_top5_acc': [], 'val_top10_acc': [],
        'lr': []
    }

    best_val_acc = 0.0
    patience_counter = 0
    patience = 15  # Early stopping patience

    print(f"\n⚡ Fast Training Mode: {num_epochs} epochs with aggressive optimization")
    print(f"Effective batch size: {batch_size * accumulation_steps}")
    print(f"Using {'mixed precision' if scaler else 'full precision'} training\n")

    for epoch in range(num_epochs):
        print(f'\nEpoch {epoch+1}/{num_epochs}')
        print('-' * 50)

        # Training
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer,
                                           device, ema=ema, use_mixup=True, scaler=scaler)

        # Validation with EMA model
        val_loss, val_acc, val_top5_acc, val_top10_acc = validate(ema.module, val_loader, criterion, device)

        # Update learning rate (OneCycleLR steps per batch, not per epoch)
        current_lr = optimizer.param_groups[0]['lr']

        # Save history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        history['val_top5_acc'].append(val_top5_acc)
        history['val_top10_acc'].append(val_top10_acc)
        history['lr'].append(current_lr)

        print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
        print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%, Val Top5: {val_top5_acc:.2f}%, Val Top10: {val_top10_acc:.2f}%')
        print(f'Learning Rate: {current_lr:.6f}')

        # Early stopping and model checkpoint
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'ema_state_dict': ema.module.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
            }, 'best_model.pth')
            print(f'✓ New best validation accuracy: {best_val_acc:.2f}%')
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f'⚠ Early stopping at epoch {epoch+1}')
            break

    # Load best EMA model
    checkpoint = torch.load('best_model.pth')
    ema.module.load_state_dict(checkpoint['ema_state_dict'])

    return ema.module, history

## Compile, train, and evaluate the model

In [None]:
# Create data loaders with optimized settings
# Split training data for validation
train_size = int(0.9 * len(train_dataset))  # Use more training data
val_size = len(train_dataset) - train_size
train_subset, val_subset = torch.utils.data.random_split(train_dataset, [train_size, val_size])

# Optimized data loaders
train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True,
                         num_workers=4, pin_memory=True, persistent_workers=True,
                         prefetch_factor=2)
val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False,
                        num_workers=4, pin_memory=True, persistent_workers=True,
                        prefetch_factor=2)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                         num_workers=4, pin_memory=True, persistent_workers=True,
                         prefetch_factor=2)

# Create model with Stochastic Depth
model = CNNClassifier(num_classes=num_classes, mlp_head_units=mlp_head_units,
                     drop_path_rate=stochastic_depth_rate).to(device)

# Print model summary
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print("=" * 70)
print("CNN MODEL FOR CIFAR-100")
print("=" * 70)
print(f"\nModel Configuration:")
print(f"  • Total trainable parameters: {count_parameters(model):,}")
print(f"  • Batch size: {batch_size} (effective: {batch_size * accumulation_steps} with gradient accumulation)")
print(f"  • Training epochs: {num_epochs}")
print(f"  • Initial learning rate: {initial_learning_rate}")
print(f"  • Weight decay: {weight_decay}")
print(f"  • Gradient accumulation steps: {accumulation_steps}")
print(f"\nFast Training Optimizations:")
print(f"  ⚡ OneCycleLR scheduler (superconvergence)")
print(f"  ⚡ Mixed precision training (FP16)")
print(f"  ⚡ Gradient accumulation (2x effective batch)")
print(f"  ⚡ Larger batch size for stability")
print(f"  ⚡ Higher learning rate (0.003)")
print(f"  ⚡ Aggressive augmentation (mixup_prob={mixup_prob})")
print(f"\nRegularization:")
print(f"  • Stochastic depth: {stochastic_depth_rate}")
print(f"  • EMA decay: {ema_decay}")
print(f"  • Label smoothing: {label_smoothing}")
print(f"  • Mixup alpha: {mixup_alpha}")
print(f"  • CutMix alpha: {cutmix_alpha}")
print(f"\nAugmentation Strategy:")
print(f"  • AutoAugment (CIFAR10 policy)")
print(f"  • RandAugment (2 ops, magnitude 9)")
print(f"  • Random Erasing (p=0.25)")
print(f"  • Mixup/CutMix (p={mixup_prob})")
print(f"\nArchitecture Features:")
print(f"  • FusedMBConv blocks (EfficientNetV2)")
print(f"  • Stochastic Depth regularization")
print(f"  • Squeeze-and-Excitation attention")
print(f"  • Multi-scale global pooling")
print(f"  • Exponential Moving Average")
print("=" * 70)

# Train the model
print("Expected completion: ~50 epochs with early stopping")
model, history = train_model(model, train_loader, val_loader, num_epochs, device)

# Evaluate on test set
print("\n" + "=" * 70)
print("FINAL EVALUATION ON TEST SET")
print("=" * 70)
criterion = LabelSmoothingCrossEntropy(smoothing=0.0)  # No smoothing for evaluation
test_loss, test_acc, test_top5_acc, test_top10_acc = validate(model, test_loader, criterion, device)

print(f"\n📊 Final Test Results:")
print(f"  • Test accuracy (Top-1): {test_acc:.2f}%")
print(f"  • Test top-5 accuracy:   {test_top5_acc:.2f}%")
print(f"  • Test top-10 accuracy:  {test_top10_acc:.2f}%")
print(f"  • Test loss:             {test_loss:.4f}")
print("=" * 70)

# Save final model
torch.save({
    'model_state_dict': model.state_dict(),
    'test_acc': test_acc,
    'test_top5_acc': test_top5_acc,
    'test_top10_acc': test_top10_acc,
    'history': history
}, 'final_sota_model.pth')
print("\n✓ Model saved as 'final_sota_model.pth'")

CNN MODEL FOR CIFAR-100

Model Configuration:
  • Total trainable parameters: 23,442,308
  • Batch size: 256 (effective: 512 with gradient accumulation)
  • Training epochs: 50
  • Initial learning rate: 0.003
  • Weight decay: 0.05
  • Gradient accumulation steps: 2

Fast Training Optimizations:
  ⚡ OneCycleLR scheduler (superconvergence)
  ⚡ Mixed precision training (FP16)
  ⚡ Gradient accumulation (2x effective batch)
  ⚡ Larger batch size for stability
  ⚡ Higher learning rate (0.003)
  ⚡ Aggressive augmentation (mixup_prob=0.8)

Regularization:
  • Stochastic depth: 0.3
  • EMA decay: 0.9998
  • Label smoothing: 0.1
  • Mixup alpha: 0.4
  • CutMix alpha: 1.0

Augmentation Strategy:
  • AutoAugment (CIFAR10 policy)
  • RandAugment (2 ops, magnitude 9)
  • Random Erasing (p=0.25)
  • Mixup/CutMix (p=0.8)

Architecture Features:
  • FusedMBConv blocks (EfficientNetV2)
  • Stochastic Depth regularization
  • Squeeze-and-Excitation attention
  • Multi-scale global pooling
  • Exponenti

  scaler = torch.cuda.amp.GradScaler() if device.type == 'cuda' else None



⚡ Fast Training Mode: 50 epochs with aggressive optimization
Effective batch size: 512
Using mixed precision training


Epoch 1/50
--------------------------------------------------


Training: 100%|██████████| 176/176 [02:48<00:00,  1.04it/s, Loss=4.611, Acc=0.95%]
Validation: 100%|██████████| 20/20 [00:06<00:00,  3.18it/s, Loss=0.018, Top1=0.88%, Top5=5.06%]


Train Loss: 4.6113, Train Acc: 0.95%
Val Loss: 4.6065, Val Acc: 0.88%, Val Top5: 5.06%, Val Top10: 9.78%
Learning Rate: 0.000120
✓ New best validation accuracy: 0.88%

Epoch 2/50
--------------------------------------------------


Training: 100%|██████████| 176/176 [02:46<00:00,  1.06it/s, Loss=4.607, Acc=1.06%]
Validation: 100%|██████████| 20/20 [00:06<00:00,  3.28it/s, Loss=0.018, Top1=1.04%, Top5=5.04%]


Train Loss: 4.6074, Train Acc: 1.06%
Val Loss: 4.6062, Val Acc: 1.04%, Val Top5: 5.04%, Val Top10: 9.54%
Learning Rate: 0.000120
✓ New best validation accuracy: 1.04%

Epoch 3/50
--------------------------------------------------


Training: 100%|██████████| 176/176 [02:46<00:00,  1.06it/s, Loss=4.607, Acc=1.00%]
Validation: 100%|██████████| 20/20 [00:06<00:00,  3.29it/s, Loss=0.018, Top1=0.96%, Top5=4.74%]


Train Loss: 4.6069, Train Acc: 1.00%
Val Loss: 4.6063, Val Acc: 0.96%, Val Top5: 4.74%, Val Top10: 9.58%
Learning Rate: 0.000120

Epoch 4/50
--------------------------------------------------


Training: 100%|██████████| 176/176 [02:46<00:00,  1.06it/s, Loss=4.606, Acc=0.94%]
Validation: 100%|██████████| 20/20 [00:06<00:00,  3.30it/s, Loss=0.018, Top1=0.84%, Top5=4.52%]


Train Loss: 4.6064, Train Acc: 0.94%
Val Loss: 4.6063, Val Acc: 0.84%, Val Top5: 4.52%, Val Top10: 9.28%
Learning Rate: 0.000120

Epoch 5/50
--------------------------------------------------


Training: 100%|██████████| 176/176 [02:46<00:00,  1.06it/s, Loss=4.606, Acc=0.98%]
Validation: 100%|██████████| 20/20 [00:06<00:00,  3.29it/s, Loss=0.018, Top1=1.02%, Top5=4.50%]


Train Loss: 4.6057, Train Acc: 0.98%
Val Loss: 4.6063, Val Acc: 1.02%, Val Top5: 4.50%, Val Top10: 8.96%
Learning Rate: 0.000120

Epoch 6/50
--------------------------------------------------


Training: 100%|██████████| 176/176 [02:46<00:00,  1.06it/s, Loss=4.602, Acc=1.16%]
Validation: 100%|██████████| 20/20 [00:06<00:00,  3.28it/s, Loss=0.018, Top1=0.86%, Top5=4.58%]


Train Loss: 4.6019, Train Acc: 1.16%
Val Loss: 4.6060, Val Acc: 0.86%, Val Top5: 4.58%, Val Top10: 9.60%
Learning Rate: 0.000120

Epoch 7/50
--------------------------------------------------


Training: 100%|██████████| 176/176 [02:46<00:00,  1.06it/s, Loss=4.566, Acc=1.53%]
Validation: 100%|██████████| 20/20 [00:06<00:00,  3.29it/s, Loss=0.018, Top1=0.64%, Top5=5.06%]


Train Loss: 4.5664, Train Acc: 1.53%
Val Loss: 4.6063, Val Acc: 0.64%, Val Top5: 5.06%, Val Top10: 9.20%
Learning Rate: 0.000120

Epoch 8/50
--------------------------------------------------


Training: 100%|██████████| 176/176 [02:46<00:00,  1.06it/s, Loss=4.532, Acc=1.54%]
Validation: 100%|██████████| 20/20 [00:06<00:00,  3.30it/s, Loss=0.018, Top1=1.04%, Top5=4.46%]


Train Loss: 4.5317, Train Acc: 1.54%
Val Loss: 4.6064, Val Acc: 1.04%, Val Top5: 4.46%, Val Top10: 9.54%
Learning Rate: 0.000120
✓ New best validation accuracy: 1.04%

Epoch 9/50
--------------------------------------------------


Training: 100%|██████████| 176/176 [02:46<00:00,  1.06it/s, Loss=4.511, Acc=1.87%]
Validation: 100%|██████████| 20/20 [00:06<00:00,  3.28it/s, Loss=0.018, Top1=0.96%, Top5=4.72%]


Train Loss: 4.5110, Train Acc: 1.87%
Val Loss: 4.6060, Val Acc: 0.96%, Val Top5: 4.72%, Val Top10: 9.08%
Learning Rate: 0.000120

Epoch 10/50
--------------------------------------------------


Training: 100%|██████████| 176/176 [02:46<00:00,  1.06it/s, Loss=4.487, Acc=1.92%]
Validation: 100%|██████████| 20/20 [00:06<00:00,  3.30it/s, Loss=0.018, Top1=0.92%, Top5=4.70%]


Train Loss: 4.4871, Train Acc: 1.92%
Val Loss: 4.6058, Val Acc: 0.92%, Val Top5: 4.70%, Val Top10: 9.34%
Learning Rate: 0.000120

Epoch 11/50
--------------------------------------------------


Training: 100%|██████████| 176/176 [02:46<00:00,  1.06it/s, Loss=4.478, Acc=2.17%]
Validation: 100%|██████████| 20/20 [00:06<00:00,  3.29it/s, Loss=0.018, Top1=1.02%, Top5=5.14%]


Train Loss: 4.4781, Train Acc: 2.17%
Val Loss: 4.6053, Val Acc: 1.02%, Val Top5: 5.14%, Val Top10: 9.22%
Learning Rate: 0.000120

Epoch 12/50
--------------------------------------------------


Training: 100%|██████████| 176/176 [02:46<00:00,  1.06it/s, Loss=4.456, Acc=2.35%]
Validation: 100%|██████████| 20/20 [00:06<00:00,  3.29it/s, Loss=0.018, Top1=1.04%, Top5=4.96%]


Train Loss: 4.4561, Train Acc: 2.35%
Val Loss: 4.6052, Val Acc: 1.04%, Val Top5: 4.96%, Val Top10: 9.64%
Learning Rate: 0.000120

Epoch 13/50
--------------------------------------------------


Training: 100%|██████████| 176/176 [02:46<00:00,  1.06it/s, Loss=4.440, Acc=2.53%]
Validation: 100%|██████████| 20/20 [00:06<00:00,  3.29it/s, Loss=0.018, Top1=1.14%, Top5=4.84%]


Train Loss: 4.4396, Train Acc: 2.53%
Val Loss: 4.6046, Val Acc: 1.14%, Val Top5: 4.84%, Val Top10: 10.02%
Learning Rate: 0.000120
✓ New best validation accuracy: 1.14%

Epoch 14/50
--------------------------------------------------


Training: 100%|██████████| 176/176 [02:46<00:00,  1.06it/s, Loss=4.416, Acc=2.78%]
Validation: 100%|██████████| 20/20 [00:06<00:00,  3.30it/s, Loss=0.018, Top1=1.18%, Top5=5.16%]


Train Loss: 4.4164, Train Acc: 2.78%
Val Loss: 4.6042, Val Acc: 1.18%, Val Top5: 5.16%, Val Top10: 10.58%
Learning Rate: 0.000120
✓ New best validation accuracy: 1.18%

Epoch 15/50
--------------------------------------------------


Training: 100%|██████████| 176/176 [02:46<00:00,  1.06it/s, Loss=4.385, Acc=3.05%]
Validation: 100%|██████████| 20/20 [00:06<00:00,  3.29it/s, Loss=0.018, Top1=1.16%, Top5=4.96%]


Train Loss: 4.3846, Train Acc: 3.05%
Val Loss: 4.6033, Val Acc: 1.16%, Val Top5: 4.96%, Val Top10: 10.20%
Learning Rate: 0.000120

Epoch 16/50
--------------------------------------------------


Training: 100%|██████████| 176/176 [02:46<00:00,  1.06it/s, Loss=4.371, Acc=3.42%]
Validation: 100%|██████████| 20/20 [00:06<00:00,  3.30it/s, Loss=0.018, Top1=1.18%, Top5=5.34%]


Train Loss: 4.3708, Train Acc: 3.42%
Val Loss: 4.6028, Val Acc: 1.18%, Val Top5: 5.34%, Val Top10: 10.68%
Learning Rate: 0.000120
✓ New best validation accuracy: 1.18%

Epoch 17/50
--------------------------------------------------


Training: 100%|██████████| 176/176 [02:46<00:00,  1.06it/s, Loss=4.363, Acc=3.46%]
Validation: 100%|██████████| 20/20 [00:06<00:00,  3.29it/s, Loss=0.018, Top1=1.26%, Top5=5.68%]


Train Loss: 4.3626, Train Acc: 3.46%
Val Loss: 4.6016, Val Acc: 1.26%, Val Top5: 5.68%, Val Top10: 10.60%
Learning Rate: 0.000120
✓ New best validation accuracy: 1.26%

Epoch 18/50
--------------------------------------------------


Training: 100%|██████████| 176/176 [02:46<00:00,  1.06it/s, Loss=4.355, Acc=3.75%]
Validation: 100%|██████████| 20/20 [00:06<00:00,  3.30it/s, Loss=0.018, Top1=1.18%, Top5=5.32%]


Train Loss: 4.3550, Train Acc: 3.75%
Val Loss: 4.5999, Val Acc: 1.18%, Val Top5: 5.32%, Val Top10: 11.16%
Learning Rate: 0.000120

Epoch 19/50
--------------------------------------------------


Training: 100%|██████████| 176/176 [02:46<00:00,  1.06it/s, Loss=4.310, Acc=4.51%]
Validation: 100%|██████████| 20/20 [00:06<00:00,  3.30it/s, Loss=0.018, Top1=1.06%, Top5=5.86%]


Train Loss: 4.3104, Train Acc: 4.51%
Val Loss: 4.5993, Val Acc: 1.06%, Val Top5: 5.86%, Val Top10: 11.34%
Learning Rate: 0.000120

Epoch 20/50
--------------------------------------------------


Training: 100%|██████████| 176/176 [02:46<00:00,  1.06it/s, Loss=4.319, Acc=4.01%]
Validation: 100%|██████████| 20/20 [00:06<00:00,  3.30it/s, Loss=0.018, Top1=1.08%, Top5=5.96%]


Train Loss: 4.3188, Train Acc: 4.01%
Val Loss: 4.5983, Val Acc: 1.08%, Val Top5: 5.96%, Val Top10: 11.20%
Learning Rate: 0.000120

Epoch 21/50
--------------------------------------------------


Training: 100%|██████████| 176/176 [02:46<00:00,  1.06it/s, Loss=4.295, Acc=4.46%]
Validation: 100%|██████████| 20/20 [00:06<00:00,  3.29it/s, Loss=0.018, Top1=1.10%, Top5=6.24%]


Train Loss: 4.2945, Train Acc: 4.46%
Val Loss: 4.5963, Val Acc: 1.10%, Val Top5: 6.24%, Val Top10: 11.96%
Learning Rate: 0.000120

Epoch 22/50
--------------------------------------------------


Training: 100%|██████████| 176/176 [02:46<00:00,  1.06it/s, Loss=4.270, Acc=4.66%]
Validation: 100%|██████████| 20/20 [00:06<00:00,  3.29it/s, Loss=0.018, Top1=1.02%, Top5=5.98%]


Train Loss: 4.2696, Train Acc: 4.66%
Val Loss: 4.5944, Val Acc: 1.02%, Val Top5: 5.98%, Val Top10: 12.30%
Learning Rate: 0.000120

Epoch 23/50
--------------------------------------------------


Training: 100%|██████████| 176/176 [02:46<00:00,  1.06it/s, Loss=4.239, Acc=5.20%]
Validation: 100%|██████████| 20/20 [00:06<00:00,  3.28it/s, Loss=0.018, Top1=1.08%, Top5=6.10%]


Train Loss: 4.2389, Train Acc: 5.20%
Val Loss: 4.5921, Val Acc: 1.08%, Val Top5: 6.10%, Val Top10: 12.22%
Learning Rate: 0.000120

Epoch 24/50
--------------------------------------------------


Training: 100%|██████████| 176/176 [02:46<00:00,  1.06it/s, Loss=4.214, Acc=5.54%]
Validation: 100%|██████████| 20/20 [00:06<00:00,  3.30it/s, Loss=0.018, Top1=1.18%, Top5=6.32%]


Train Loss: 4.2141, Train Acc: 5.54%
Val Loss: 4.5891, Val Acc: 1.18%, Val Top5: 6.32%, Val Top10: 12.14%
Learning Rate: 0.000120

Epoch 25/50
--------------------------------------------------


Training: 100%|██████████| 176/176 [02:46<00:00,  1.06it/s, Loss=4.212, Acc=5.86%]
Validation: 100%|██████████| 20/20 [00:06<00:00,  3.30it/s, Loss=0.018, Top1=1.46%, Top5=6.56%]


Train Loss: 4.2116, Train Acc: 5.86%
Val Loss: 4.5875, Val Acc: 1.46%, Val Top5: 6.56%, Val Top10: 12.84%
Learning Rate: 0.000120
✓ New best validation accuracy: 1.46%

Epoch 26/50
--------------------------------------------------


Training: 100%|██████████| 176/176 [02:46<00:00,  1.05it/s, Loss=4.159, Acc=6.02%]
Validation: 100%|██████████| 20/20 [00:06<00:00,  3.29it/s, Loss=0.018, Top1=1.44%, Top5=6.40%]


Train Loss: 4.1595, Train Acc: 6.02%
Val Loss: 4.5858, Val Acc: 1.44%, Val Top5: 6.40%, Val Top10: 13.16%
Learning Rate: 0.000120

Epoch 27/50
--------------------------------------------------


Training:  63%|██████▎   | 111/176 [01:45<01:01,  1.06it/s, Loss=4.186, Acc=6.19%]

## Plot Training and Testing Results

In [None]:
def plot_training_and_testing_results(history, test_acc, test_top5_acc, test_top10_acc, test_loss):
    """Plot comprehensive training, validation, and testing results"""
    fig = plt.figure(figsize=(20, 12))
    gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3)

    # 1. Training and Validation Loss
    ax1 = fig.add_subplot(gs[0, 0])
    epochs = range(1, len(history['train_loss']) + 1)
    ax1.plot(epochs, history['train_loss'], 'b-o', label='Training Loss', linewidth=2, markersize=4)
    ax1.plot(epochs, history['val_loss'], 'r-s', label='Validation Loss', linewidth=2, markersize=4)
    ax1.set_title('Training vs Validation Loss', fontsize=14, fontweight='bold')
    ax1.set_xlabel('Epoch', fontsize=12)
    ax1.set_ylabel('Loss', fontsize=12)
    ax1.legend(fontsize=10)
    ax1.grid(True, alpha=0.3)

    # 2. Training and Validation Accuracy
    ax2 = fig.add_subplot(gs[0, 1])
    ax2.plot(epochs, history['train_acc'], 'b-o', label='Training Accuracy', linewidth=2, markersize=4)
    ax2.plot(epochs, history['val_acc'], 'r-s', label='Validation Accuracy', linewidth=2, markersize=4)
    ax2.set_title('Training vs Validation Accuracy', fontsize=14, fontweight='bold')
    ax2.set_xlabel('Epoch', fontsize=12)
    ax2.set_ylabel('Accuracy (%)', fontsize=12)
    ax2.legend(fontsize=10)
    ax2.grid(True, alpha=0.3)

    # 3. Top-K Validation Accuracies
    ax3 = fig.add_subplot(gs[0, 2])
    ax3.plot(epochs, history['val_acc'], 'g-o', label='Top-1', linewidth=2, markersize=4)
    ax3.plot(epochs, history['val_top5_acc'], 'orange', marker='s', label='Top-5', linewidth=2, markersize=4)
    ax3.plot(epochs, history['val_top10_acc'], 'purple', marker='^', label='Top-10', linewidth=2, markersize=4)
    ax3.set_title('Validation Top-K Accuracy', fontsize=14, fontweight='bold')
    ax3.set_xlabel('Epoch', fontsize=12)
    ax3.set_ylabel('Accuracy (%)', fontsize=12)
    ax3.legend(fontsize=10)
    ax3.grid(True, alpha=0.3)

    # 4. Learning Rate Schedule
    ax4 = fig.add_subplot(gs[1, 0])
    ax4.plot(epochs, history['lr'], 'c-', linewidth=2)
    ax4.set_title('Learning Rate Schedule (OneCycleLR)', fontsize=14, fontweight='bold')
    ax4.set_xlabel('Epoch', fontsize=12)
    ax4.set_ylabel('Learning Rate', fontsize=12)
    ax4.set_yscale('log')
    ax4.grid(True, alpha=0.3)

    # 5. Final Performance Comparison - Bar Chart
    ax5 = fig.add_subplot(gs[1, 1])
    final_train_acc = history['train_acc'][-1]
    final_val_acc = history['val_acc'][-1]

    categories = ['Training', 'Validation', 'Test']
    accuracies = [final_train_acc, final_val_acc, test_acc]
    colors = ['#3498db', '#e74c3c', '#2ecc71']

    bars = ax5.bar(categories, accuracies, color=colors, alpha=0.7, edgecolor='black', linewidth=2)
    ax5.set_title('Final Top-1 Accuracy Comparison', fontsize=14, fontweight='bold')
    ax5.set_ylabel('Accuracy (%)', fontsize=12)
    ax5.set_ylim([0, 100])

    # Add value labels on bars
    for bar, acc in zip(bars, accuracies):
        height = bar.get_height()
        ax5.text(bar.get_x() + bar.get_width()/2., height,
                f'{acc:.2f}%',
                ha='center', va='bottom', fontsize=11, fontweight='bold')
    ax5.grid(True, alpha=0.3, axis='y')

    # 6. Top-K Test Accuracy Comparison
    ax6 = fig.add_subplot(gs[1, 2])
    topk_categories = ['Top-1', 'Top-5', 'Top-10']
    topk_accuracies = [test_acc, test_top5_acc, test_top10_acc]
    topk_colors = ['#e74c3c', '#f39c12', '#9b59b6']

    bars2 = ax6.bar(topk_categories, topk_accuracies, color=topk_colors, alpha=0.7, edgecolor='black', linewidth=2)
    ax6.set_title('Test Set Top-K Accuracy', fontsize=14, fontweight='bold')
    ax6.set_ylabel('Accuracy (%)', fontsize=12)
    ax6.set_ylim([0, 100])

    # Add value labels
    for bar, acc in zip(bars2, topk_accuracies):
        height = bar.get_height()
        ax6.text(bar.get_x() + bar.get_width()/2., height,
                f'{acc:.2f}%',
                ha='center', va='bottom', fontsize=11, fontweight='bold')
    ax6.grid(True, alpha=0.3, axis='y')

    # 7. Accuracy Progression (All Sets)
    ax7 = fig.add_subplot(gs[2, :2])
    ax7.plot(epochs, history['train_acc'], 'b-o', label='Training', linewidth=2.5, markersize=5, alpha=0.8)
    ax7.plot(epochs, history['val_acc'], 'r-s', label='Validation', linewidth=2.5, markersize=5, alpha=0.8)
    ax7.axhline(y=test_acc, color='g', linestyle='--', linewidth=3, label=f'Test (Final: {test_acc:.2f}%)')
    ax7.fill_between(epochs, history['train_acc'], alpha=0.2, color='blue')
    ax7.fill_between(epochs, history['val_acc'], alpha=0.2, color='red')
    ax7.set_title('Accuracy Progression: Train vs Val vs Test', fontsize=14, fontweight='bold')
    ax7.set_xlabel('Epoch', fontsize=12)
    ax7.set_ylabel('Accuracy (%)', fontsize=12)
    ax7.legend(fontsize=11, loc='lower right')
    ax7.grid(True, alpha=0.3)

    # 8. Performance Summary Table
    ax8 = fig.add_subplot(gs[2, 2])
    ax8.axis('off')

    summary_data = [
        ['Metric', 'Train', 'Val', 'Test'],
        ['Top-1 Acc', f'{final_train_acc:.2f}%', f'{final_val_acc:.2f}%', f'{test_acc:.2f}%'],
        ['Top-5 Acc', '-', f'{history["val_top5_acc"][-1]:.2f}%', f'{test_top5_acc:.2f}%'],
        ['Top-10 Acc', '-', f'{history["val_top10_acc"][-1]:.2f}%', f'{test_top10_acc:.2f}%'],
        ['Loss', f'{history["train_loss"][-1]:.4f}', f'{history["val_loss"][-1]:.4f}', f'{test_loss:.4f}'],
        ['Epochs', str(len(epochs)), '-', '-']
    ]

    table = ax8.table(cellText=summary_data, cellLoc='center', loc='center',
                     colWidths=[0.25, 0.25, 0.25, 0.25])
    table.auto_set_font_size(False)
    table.set_fontsize(10)
    table.scale(1, 2.5)

    # Style the header row
    for i in range(4):
        table[(0, i)].set_facecolor('#34495e')
        table[(0, i)].set_text_props(weight='bold', color='white')

    # Alternate row colors
    for i in range(1, len(summary_data)):
        for j in range(4):
            if i % 2 == 0:
                table[(i, j)].set_facecolor('#ecf0f1')
            else:
                table[(i, j)].set_facecolor('#ffffff')

    ax8.set_title('Performance Summary', fontsize=14, fontweight='bold', pad=20)

    plt.suptitle('Complete Training and Testing Analysis', fontsize=18, fontweight='bold', y=0.995)
    plt.savefig('training_testing_results.png', dpi=300, bbox_inches='tight')
    print("\n✓ Plot saved as 'training_testing_results.png'")
    plt.show()

# Plot comprehensive results
plot_training_and_testing_results(history, test_acc, test_top5_acc, test_top10_acc, test_loss)