# In [None]:  (Jupyter notebook cell marker kept as a comment so the file parses as Python)
import math
from typing import Tuple, Dict, Any, Optional, List
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, Dataset
from torchvision import datasets, transforms
from torch.amp import autocast, GradScaler
import random
import time
from datetime import datetime
import copy
from collections import OrderedDict
import os
import json
import pickle
from pathlib import Path

# ============================================================================
# GPU DEVICE SETUP FOR T4
# ============================================================================

# Pick the compute device once at import time: CUDA when PyTorch can see a GPU,
# otherwise the CPU. The rest of this file reads this module-level name.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")
if device.type == 'cuda':
    # Report the name and total memory (in GiB) of CUDA device 0.
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

def print_gpu_utilization():
    """Print allocated and reserved memory (GiB) of CUDA device 0.

    Does nothing when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return

    print(f"GPU Memory Used: {torch.cuda.memory_allocated(0) / 1024**3:.1f} GB")
    print(f"GPU Memory Cached: {torch.cuda.memory_reserved(0) / 1024**3:.1f} GB")

# ============================================================================
# ENHANCED CONFIGURATION FOR 99% TARGET
# ============================================================================

# 'FashionMNIST' is the only name get_advanced_data_loaders() handles; the
# consumers of this list are outside this section of the file.
DATASET_NAMES: List[str] = ['FashionMNIST']
# The hyperparameters below are not read by the code visible in this section;
# presumably the driver code further down passes them to the model/optimizer
# constructors and training functions (TODO confirm against the callers).
HIDDEN_DIM: int = 3072
LEARNING_RATE: float = 8e-4
weight_decay_: float = 2e-4
DROPOUT_RATE: float = 0.3
EPOCHS_PER_DATASET: int = 200
BATCH_SIZE: int = 512
ENSEMBLE_SIZE: int = 7
# Number of predictions averaged per batch when evaluate_advanced_model() is
# called with use_tta=True (forwarded to test_time_augmentation()).
TTA_AUGMENTS: int = 8

# Checkpoint configuration
# Default directory and retention count for ModelCheckpoint.
CHECKPOINT_DIR: str = "./checkpoints"
SAVE_TOP_K: int = 3  # Keep top 3 best models
# train_advanced_model() evaluates on the test loader every N epochs (and on
# the last epoch); a checkpoint is written when accuracy improves, plus a
# "periodic" one every 4 * N epochs.
SAVE_EVERY_N_EPOCHS: int = 5  # Save checkpoint every N epochs

# ============================================================================
# MODEL CHECKPOINTING SYSTEM
# ============================================================================

class ModelCheckpoint:
    """Write, prune and restore training checkpoints kept in one directory.

    Every checkpoint written through :meth:`save_checkpoint` is recorded in
    ``self.best_models`` together with its accuracy. Only the ``save_top_k``
    highest-accuracy files are kept on disk; the others are deleted. The
    surviving entries are mirrored to ``best_models_info.json`` inside the
    checkpoint directory so that a later process can rediscover them.
    """

    def __init__(self, checkpoint_dir: str = CHECKPOINT_DIR, save_top_k: int = SAVE_TOP_K):
        """Create the checkpoint directory if needed and start with an empty ranking."""
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.save_top_k = save_top_k
        # (accuracy, checkpoint_path, checkpoint_name) tuples, best accuracy first.
        self.best_models = []

    def save_checkpoint(self, model, optimizer, scheduler, epoch, accuracy, loss,
                       model_name="model", additional_info=None):
        """Save model, optimizer, scheduler and RNG state to a new ``.pt`` file.

        Args:
            model, optimizer, scheduler: objects whose ``state_dict()`` is stored.
            epoch, accuracy, loss: metadata stored in the file; ``epoch`` and
                ``accuracy`` are also embedded in the file name.
            model_name: prefix of the file name.
            additional_info: optional extra payload stored under
                ``'additional_info'``.

        Returns:
            Path of the file that was written. If the new checkpoint does not
            rank within the top ``save_top_k`` it is deleted again before this
            method returns, so the returned path may no longer exist.
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        checkpoint_name = f"{model_name}_epoch{epoch}_acc{accuracy:.2f}_{timestamp}.pt"
        checkpoint_path = self.checkpoint_dir / checkpoint_name

        # Prepare checkpoint data
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'accuracy': accuracy,
            'loss': loss,
            'model_name': model_name,
            'timestamp': timestamp,
            'model_architecture': model.__class__.__name__,
            'random_state': {
                'torch': torch.get_rng_state(),
                'numpy': np.random.get_state(),
                'python': random.getstate(),
            }
        }

        if torch.cuda.is_available():
            checkpoint['cuda_random_state'] = torch.cuda.get_rng_state()

        if additional_info:
            checkpoint['additional_info'] = additional_info

        # Save checkpoint
        torch.save(checkpoint, checkpoint_path)
        print(f"✅ Checkpoint saved: {checkpoint_name} (Acc: {accuracy:.2f}%)")

        # Update best models list (stable sort: ties keep insertion order)
        self.best_models.append((accuracy, checkpoint_path, checkpoint_name))
        self.best_models.sort(key=lambda x: x[0], reverse=True)  # Sort by accuracy desc

        # Keep only top-k models
        if len(self.best_models) > self.save_top_k:
            # Delete old checkpoints
            for _, old_path, old_name in self.best_models[self.save_top_k:]:
                if old_path.exists():
                    old_path.unlink()
                    print(f"🗑️ Removed old checkpoint: {old_name}")
            self.best_models = self.best_models[:self.save_top_k]

        # Save best models info
        self.save_best_models_info()

        return checkpoint_path

    def save_best_models_info(self):
        """Write the current ranking to ``best_models_info.json``."""
        best_models_info = {
            'best_models': [
                {
                    'accuracy': acc,
                    'checkpoint_path': str(path),
                    'checkpoint_name': name,
                }
                for acc, path, name in self.best_models
            ],
            'last_updated': datetime.now().isoformat()
        }

        info_path = self.checkpoint_dir / "best_models_info.json"
        with open(info_path, 'w') as f:
            json.dump(best_models_info, f, indent=2)

    def load_best_checkpoint(self, model, optimizer=None, scheduler=None):
        """Load the highest-accuracy checkpoint; return its metadata or None."""
        if not self.best_models:
            self.load_best_models_info()

        if not self.best_models:
            print("⚠️ No checkpoints found")
            return None

        best_accuracy, best_path, best_name = self.best_models[0]
        return self.load_checkpoint(best_path, model, optimizer, scheduler)

    def load_checkpoint(self, checkpoint_path, model, optimizer=None, scheduler=None):
        """Restore state from ``checkpoint_path`` into the given objects.

        The model weights are always loaded; optimizer and scheduler state are
        loaded only when the corresponding object is passed. The torch, NumPy,
        Python (and, if present, CUDA) RNG states stored in the file are
        restored as well.

        Returns:
            A dict with ``epoch``, ``accuracy``, ``loss``, ``model_name`` and
            ``timestamp``, or None when the file is missing or loading fails.
            On failure the objects may already have been partially updated.
        """
        checkpoint_path = Path(checkpoint_path)

        if not checkpoint_path.exists():
            print(f"❌ Checkpoint not found: {checkpoint_path}")
            return None

        try:
            # The file contains pickled NumPy/Python RNG state in addition to
            # tensors, so it cannot be read in weights-only mode.
            # NOTE: unpickling executes arbitrary code - only load checkpoints
            # that come from a trusted source.
            checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

            # Load model state
            model.load_state_dict(checkpoint['model_state_dict'])

            # Load optimizer state
            if optimizer and 'optimizer_state_dict' in checkpoint:
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

            # Load scheduler state
            if scheduler and 'scheduler_state_dict' in checkpoint:
                scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

            # Restore random states
            if 'random_state' in checkpoint:
                # map_location may have moved the RNG state tensors to the GPU;
                # the set_rng_state functions require CPU ByteTensors.
                torch.set_rng_state(checkpoint['random_state']['torch'].cpu())
                np.random.set_state(checkpoint['random_state']['numpy'])
                random.setstate(checkpoint['random_state']['python'])

                if torch.cuda.is_available() and 'cuda_random_state' in checkpoint:
                    torch.cuda.set_rng_state(checkpoint['cuda_random_state'].cpu())

            print(f"✅ Loaded checkpoint: {checkpoint_path.name}")
            print(f"   Epoch: {checkpoint['epoch']}, Accuracy: {checkpoint['accuracy']:.2f}%")

            return {
                'epoch': checkpoint['epoch'],
                'accuracy': checkpoint['accuracy'],
                'loss': checkpoint['loss'],
                'model_name': checkpoint.get('model_name', 'unknown'),
                'timestamp': checkpoint.get('timestamp', 'unknown')
            }

        except Exception as e:
            # Best effort by design: callers treat None as "could not resume".
            print(f"❌ Error loading checkpoint {checkpoint_path}: {e}")
            return None

    def load_best_models_info(self):
        """Rebuild ``self.best_models`` from ``best_models_info.json``.

        Entries whose checkpoint file no longer exists are dropped. A missing
        info file leaves the ranking untouched; a read or parse error is
        reported and swallowed.
        """
        info_path = self.checkpoint_dir / "best_models_info.json"

        if info_path.exists():
            try:
                with open(info_path, 'r') as f:
                    info = json.load(f)

                self.best_models = []
                for model_info in info['best_models']:
                    path = Path(model_info['checkpoint_path'])
                    if path.exists():
                        self.best_models.append((
                            model_info['accuracy'],
                            path,
                            model_info['checkpoint_name']
                        ))

                print(f"📂 Loaded {len(self.best_models)} checkpoint(s) from history")

            except Exception as e:
                print(f"⚠️ Error loading checkpoint info: {e}")

    def list_checkpoints(self):
        """Print name, accuracy and path of every tracked checkpoint."""
        if not self.best_models:
            self.load_best_models_info()

        if not self.best_models:
            print("No checkpoints found")
            return

        print("\n📋 Available Checkpoints:")
        print("-" * 60)
        for i, (accuracy, path, name) in enumerate(self.best_models):
            print(f"{i+1}. {name}")
            print(f"   Accuracy: {accuracy:.2f}%")
            print(f"   Path: {path}")
            print("-" * 60)

    def get_resume_info(self):
        """Return path, name and accuracy of the best tracked checkpoint, or None.

        Loads the ranking from disk first when it is empty in memory.
        """
        if not self.best_models:
            self.load_best_models_info()

        if not self.best_models:
            return None

        best_accuracy, best_path, best_name = self.best_models[0]

        return {
            'checkpoint_path': best_path,
            'checkpoint_name': best_name,
            'best_accuracy': best_accuracy
        }

# ============================================================================
# ADVANCED DATA AUGMENTATION STRATEGIES
# ============================================================================

class CutMix(object):
    """CutMix augmentation for a batch of images.

    A random rectangle of every sample is replaced by the same rectangle taken
    from another sample of the batch (chosen by a random permutation).

    Calling the object returns ``(mixed_x, y_a, y_b, lam)`` where ``y_a`` are
    the original labels, ``y_b`` the labels of the pasted samples and ``lam``
    the fraction of each image that still comes from the original sample.
    """

    def __init__(self, alpha=1.0):
        # Parameter of the symmetric Beta(alpha, alpha) used to draw the mix ratio.
        self.alpha = alpha

    def __call__(self, x, y):
        # alpha <= 0 disables the augmentation; the batch is passed through.
        if self.alpha <= 0:
            return x, y, y, 1.0

        batch_size = x.size(0)
        lam = np.random.beta(self.alpha, self.alpha)
        # The permutation is drawn on the CPU and then moved to wherever the
        # batch lives, so the transform does not depend on a global device.
        index = torch.randperm(batch_size).to(x.device)

        bbx1, bby1, bbx2, bby2 = self._rand_bbox(x.size(), lam)

        # Work on a copy so the caller's batch is not modified in place.
        mixed_x = x.clone()
        mixed_x[:, :, bbx1:bbx2, bby1:bby2] = x[index, :, bbx1:bbx2, bby1:bby2]

        # The box is clipped at the image border, so recompute the ratio from
        # the area that was actually replaced.
        pasted_area = (bbx2 - bbx1) * (bby2 - bby1)
        lam = 1.0 - pasted_area / (x.size(-1) * x.size(-2))
        return mixed_x, y, y[index], lam

    def _rand_bbox(self, size, lam):
        """Return integer bounds (dim-2 start, dim-3 start, dim-2 end, dim-3 end).

        The box covers roughly a ``1 - lam`` fraction of the image, is centred
        on a uniformly drawn pixel and is clipped to the image.
        """
        W = size[2]
        H = size[3]
        cut_rat = np.sqrt(1. - lam)
        cut_w = int(W * cut_rat)
        cut_h = int(H * cut_rat)

        cx = np.random.randint(W)
        cy = np.random.randint(H)

        bbx1 = int(np.clip(cx - cut_w // 2, 0, W))
        bby1 = int(np.clip(cy - cut_h // 2, 0, H))
        bbx2 = int(np.clip(cx + cut_w // 2, 0, W))
        bby2 = int(np.clip(cy + cut_h // 2, 0, H))

        return bbx1, bby1, bbx2, bby2

class MixUp(object):
    """MixUp augmentation for a batch.

    Every sample is blended with another sample of the batch (chosen by a
    random permutation): ``lam * x + (1 - lam) * x[perm]`` with ``lam`` drawn
    from Beta(alpha, alpha). With ``alpha <= 0`` the ratio is fixed to 1, i.e.
    the batch values are returned unchanged.

    Calling the object returns ``(mixed_x, y_a, y_b, lam)`` where ``y_a`` are
    the original labels and ``y_b`` the labels of the blended-in samples.
    """

    def __init__(self, alpha=0.4):
        self.alpha = alpha

    def __call__(self, x, y):
        if self.alpha > 0:
            lam = float(np.random.beta(self.alpha, self.alpha))
        else:
            lam = 1

        # The permutation is drawn on the CPU and then moved to wherever the
        # batch lives, so the transform does not depend on a global device.
        index = torch.randperm(x.size(0)).to(x.device)

        mixed_x = lam * x + (1 - lam) * x[index, :]
        return mixed_x, y, y[index], lam

class AdvancedAugmentDataset(Dataset):
    """Wrap a dataset so each base sample appears ``num_augments`` times.

    Index ``i`` maps to base sample ``i // num_augments``. With
    ``heavy_aug=True`` every access additionally runs a random PIL-based
    augmentation pipeline that ends with ``ToTensor`` and
    ``Normalize((0.5,), (0.5,))``; with ``heavy_aug=False`` the base sample is
    returned untouched.
    """

    def __init__(self, base_dataset, num_augments=3, heavy_aug=True):
        self.base = base_dataset
        self.num_augments = num_augments
        self.heavy_aug = heavy_aug
        self.aug_transform = None

        if heavy_aug:
            # The pipeline expects a PIL image. Conversion from tensor/ndarray
            # happens once in __getitem__; a ToPILImage step here would raise
            # TypeError because it rejects inputs that already are PIL images.
            self.aug_transform = transforms.Compose([
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomRotation(20),
                transforms.RandomAffine(degrees=0, translate=(0.15, 0.15), scale=(0.8, 1.2)),
                transforms.RandomResizedCrop(28, scale=(0.7, 1.0)),
                transforms.ColorJitter(brightness=0.3, contrast=0.3),
                transforms.RandomPerspective(distortion_scale=0.2, p=0.5),
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,)),
                transforms.RandomErasing(p=0.6, scale=(0.02, 0.4), ratio=(0.3, 3.3)),
            ])

    def __len__(self):
        return len(self.base) * self.num_augments

    def __getitem__(self, idx):
        image, label = self.base[idx // self.num_augments]

        if self.heavy_aug and self.aug_transform is not None:
            if isinstance(image, (torch.Tensor, np.ndarray)):
                # NOTE(review): ToPILImage is documented for float tensors in
                # [0, 1]. get_advanced_data_loaders() in this file passes a
                # base dataset that already normalizes to [-1, 1] - confirm
                # whether the base dataset should be un-transformed instead.
                image = transforms.ToPILImage()(image)
            # Anything else is assumed to already be a PIL image.
            image = self.aug_transform(image)

        return image, label

def mixup_cutmix_criterion(criterion, pred, y_a, y_b, lam):
    """Blend the loss against both label sets of a mixed batch.

    Returns ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``,
    which serves MixUp and CutMix alike.
    """
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b

def get_advanced_data_loaders(dataset_name: str, batch_size: int) -> Tuple[DataLoader, DataLoader, tuple, int]:
    """Build the training and test loaders for ``dataset_name``.

    Args:
        dataset_name: only ``'FashionMNIST'`` is supported.
        batch_size: training batch size; the test loader uses twice this size.

    Returns:
        ``(train_loader, test_loader, input_dim, n_classes)`` where
        ``input_dim`` is ``(channels, height, width)``.

    Raises:
        ValueError: if ``dataset_name`` is not supported.
    """
    # Fail early with a clear message instead of an UnboundLocalError below.
    if dataset_name != 'FashionMNIST':
        raise ValueError(
            f"Unsupported dataset {dataset_name!r}; only 'FashionMNIST' is available"
        )

    transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(25),
        transforms.RandomAffine(degrees=0, translate=(0.2, 0.2), scale=(0.8, 1.3)),
        transforms.RandomResizedCrop(28, scale=(0.6, 1.0)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
        transforms.RandomErasing(p=0.7, scale=(0.02, 0.5), ratio=(0.3, 3.3)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform_train)
    test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform_test)
    n_classes = 10
    n_channels = 1
    img_size = 28

    # Each base sample is exposed three times. The wrapper applies its own
    # augmentation pipeline on top of transform_train, so training images are
    # augmented twice.
    # NOTE(review): confirm the double augmentation is intended.
    enhanced_train = AdvancedAugmentDataset(train_dataset, 3, heavy_aug=True)

    # Worker settings shared by both loaders.
    loader_kwargs = dict(
        num_workers=4,
        pin_memory=True,
        prefetch_factor=3,
        persistent_workers=True,
    )

    train_loader = DataLoader(
        enhanced_train,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        **loader_kwargs
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size * 2,
        shuffle=False,
        drop_last=False,
        **loader_kwargs
    )

    input_dim = (n_channels, img_size, img_size)
    return train_loader, test_loader, input_dim, n_classes

# ============================================================================
# VISION TRANSFORMER COMPONENTS
# ============================================================================

class MultiHeadAttention(nn.Module):
    """Multi-head self-attention over a ``(batch, tokens, dim)`` input.

    Args:
        dim: embedding size of every token; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.

    Raises:
        ValueError: if ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim, num_heads=8, dropout=0.1):
        super().__init__()
        # Without this check the mismatch only surfaces as a reshape error
        # inside forward().
        if dim % num_heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by num_heads ({num_heads})")

        self.num_heads = num_heads
        self.dim = dim
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        # One projection produces queries, keys and values together.
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.proj = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, N, C = x.shape
        # (B, N, 3*C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        # Scaled dot-product attention; dropout acts on the attention weights.
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = self.dropout(attn.softmax(dim=-1))

        # Merge the heads back into the channel dimension.
        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(out)

class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP, each behind a LayerNorm and a residual add.

    In training mode the whole block is skipped for the entire batch (input
    returned as is, no rescaling) with probability ``stochastic_depth``.
    """

    def __init__(self, dim, num_heads=8, mlp_ratio=4.0, dropout=0.1, stochastic_depth=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiHeadAttention(dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(dim)

        # Feed-forward part: expand by mlp_ratio, GELU, project back.
        inner_dim = int(dim * mlp_ratio)
        feed_forward = [
            nn.Linear(dim, inner_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*feed_forward)

        self.stochastic_depth = stochastic_depth

    def forward(self, x):
        # A random number is only drawn when the block can actually be dropped.
        skip_block = False
        if self.training and self.stochastic_depth > 0:
            skip_block = bool(torch.rand(1) < self.stochastic_depth)
        if skip_block:
            return x

        attended = x + self.attn(self.norm1(x))
        return attended + self.mlp(self.norm2(attended))

# ============================================================================
# ADVANCED ARCHITECTURE COMPONENTS
# ============================================================================

class StochasticDepth(nn.Module):
    """Per-sample stochastic depth applied to a residual branch.

    In eval mode, or with ``drop_rate == 0``, returns ``x + residual``. In
    training mode the residual of each sample is kept with probability
    ``1 - drop_rate`` and scaled by ``1 / (1 - drop_rate)``; otherwise it is
    zeroed. Works for inputs of any rank whose first dimension is the batch.
    """

    def __init__(self, drop_rate=0.1):
        super().__init__()
        self.drop_rate = drop_rate

    def forward(self, x, residual):
        if not self.training or self.drop_rate == 0:
            return x + residual

        keep_prob = 1 - self.drop_rate
        if keep_prob <= 0:
            # Every residual is dropped; dividing by keep_prob would give NaN.
            return x

        # One Bernoulli draw per sample, broadcast over all remaining dims.
        mask_shape = (x.shape[0],) + (1,) * (x.dim() - 1)
        mask = torch.rand(mask_shape, device=x.device) < keep_prob
        return x + residual * mask.float() / keep_prob

class SEBlock(nn.Module):
    """Squeeze-and-Excitation gate for ``(batch, channel, H, W)`` feature maps.

    Channels are averaged over the spatial dims, passed through a two-layer
    bottleneck (``channel -> channel // reduction -> channel``, SiLU then
    Sigmoid) and the resulting per-channel weights rescale the input.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        bottleneck = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, bottleneck, bias=False),
            nn.SiLU(inplace=True),
            nn.Linear(bottleneck, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch, channels = x.size(0), x.size(1)
        squeezed = self.avg_pool(x).view(batch, channels)
        gate = self.fc(squeezed).view(batch, channels, 1, 1)
        # Broadcasting applies each channel weight to every spatial position.
        return x * gate

class EfficientBlock(nn.Module):
    """MBConv-style block: optional 1x1 expansion, 3x3 depthwise conv, optional SE, 1x1 projection.

    A residual connection (through ``StochasticDepth(0.1)``) is used only when
    ``stride == 1`` and the channel count is unchanged.
    """

    def __init__(self, in_channels, out_channels, stride=1, expand_ratio=6, se_ratio=0.25):
        super().__init__()
        self.stride = stride
        hidden_dim = in_channels * expand_ratio
        self.use_residual = stride == 1 and in_channels == out_channels

        stages = []

        # Pointwise expansion (skipped when the block does not widen).
        if expand_ratio != 1:
            stages += [
                nn.Conv2d(in_channels, hidden_dim, 1, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(inplace=True),
            ]

        # Depthwise 3x3 convolution carries the stride.
        stages += [
            nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.SiLU(inplace=True),
        ]

        # Channel gating; the SE reduction factor is int(1 / se_ratio).
        if se_ratio > 0:
            stages.append(SEBlock(hidden_dim, int(1/se_ratio)))

        # Linear pointwise projection to the output width.
        stages += [
            nn.Conv2d(hidden_dim, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
        ]

        self.conv = nn.Sequential(*stages)
        self.stochastic_depth = StochasticDepth(0.1)

    def forward(self, x):
        branch = self.conv(x)
        if not self.use_residual:
            return branch
        return self.stochastic_depth(x, branch)

# ============================================================================
# HYBRID CNN-TRANSFORMER ARCHITECTURE
# ============================================================================

class HybridFashionNet(nn.Module):
    """Hybrid CNN-Transformer classifier.

    A convolutional stem and six ``EfficientBlock`` stages (three of them with
    stride 2) produce a 512-channel feature map. That map feeds two branches:
    a globally averaged CNN branch, and a transformer branch that treats every
    spatial position as a token (projected to 384 dims, learned positional
    embedding, four ``TransformerBlock`` layers, mean over tokens). Both
    branches are projected and concatenated to ``hidden_dim`` features and
    classified by an MLP head.

    Args:
        input_dim: ``(channels, height, width)`` of the input images.
        hidden_dim: width of the fused feature vector.
        output_dim: number of classes.
        dropout_rate: dropout of the first head layer; the next two layers use
            half and a quarter of it.
    """

    def __init__(self, input_dim: tuple, hidden_dim: int, output_dim: int, dropout_rate: float = 0.3):
        super(HybridFashionNet, self).__init__()

        channels, height, width = input_dim

        self.stem = nn.Sequential(
            nn.Conv2d(channels, 64, 3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.SiLU(inplace=True)
        )

        self.efficient_block1 = EfficientBlock(64, 64, stride=1)
        self.efficient_block2 = EfficientBlock(64, 128, stride=2)
        self.efficient_block3 = EfficientBlock(128, 128, stride=1)
        self.efficient_block4 = EfficientBlock(128, 256, stride=2)
        self.efficient_block5 = EfficientBlock(256, 256, stride=1)
        self.efficient_block6 = EfficientBlock(256, 512, stride=2)

        # Token grid after the three stride-2 stages (3x3 conv, padding 1):
        # each stage maps n -> (n - 1) // 2 + 1, e.g. 28 -> 14 -> 7 -> 4,
        # which gives 4 * 4 = 16 tokens for 28x28 inputs.
        tokens_h, tokens_w = height, width
        for _ in range(3):
            tokens_h = (tokens_h - 1) // 2 + 1
            tokens_w = (tokens_w - 1) // 2 + 1

        self.patch_embed = nn.Linear(512, 384)
        self.pos_embed = nn.Parameter(torch.zeros(1, tokens_h * tokens_w, 384))

        # Block-skip probability grows with depth: 0, 0.025, 0.05, 0.075.
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(384, num_heads=8, stochastic_depth=0.1 * i / 4)
            for i in range(4)
        ])

        self.norm = nn.LayerNorm(384)

        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        # The two branch widths always add up to hidden_dim, also when it is odd.
        cnn_features_dim = hidden_dim // 2
        self.cnn_classifier = nn.Linear(512, cnn_features_dim)
        self.transformer_classifier = nn.Linear(384, hidden_dim - cnn_features_dim)

        self.fusion_classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(inplace=True),
            nn.Dropout(dropout_rate),

            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.SiLU(inplace=True),
            # True division: floor division of a float rate yields 0.0 and
            # would silently disable these dropout layers.
            nn.Dropout(dropout_rate / 2),

            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.SiLU(inplace=True),
            nn.Dropout(dropout_rate / 4),

            nn.Linear(hidden_dim // 4, output_dim)
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal for convs, truncated normal (std 0.02) for linears and the positional embedding."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, x):
        x = self.stem(x)
        x = self.efficient_block1(x)
        x = self.efficient_block2(x)
        x = self.efficient_block3(x)
        x = self.efficient_block4(x)
        x = self.efficient_block5(x)
        x = self.efficient_block6(x)

        # CNN branch: global average over the spatial dims.
        cnn_features = self.global_avg_pool(x).flatten(1)
        cnn_out = self.cnn_classifier(cnn_features)

        # Transformer branch: (B, C, H, W) -> (B, H*W, C) tokens.
        tokens = x.flatten(2).transpose(1, 2)
        tokens = self.patch_embed(tokens)
        tokens = tokens + self.pos_embed

        for block in self.transformer_blocks:
            tokens = block(tokens)

        tokens = self.norm(tokens)
        transformer_out = self.transformer_classifier(tokens.mean(1))

        fused = torch.cat([cnn_out, transformer_out], dim=1)
        return self.fusion_classifier(fused)

class EfficientFashionNet(nn.Module):
    """Purely convolutional classifier: stem, eight ``EfficientBlock`` stages, pooled MLP head.

    Args:
        input_dim: ``(channels, height, width)``; only ``channels`` is used.
        hidden_dim: width of the hidden layer in the classifier head.
        output_dim: number of classes.
        dropout_rate: dropout probability inside the classifier head.
    """

    def __init__(self, input_dim: tuple, hidden_dim: int, output_dim: int, dropout_rate: float = 0.3):
        super().__init__()
        channels, height, width = input_dim

        self.stem = nn.Sequential(
            nn.Conv2d(channels, 32, 3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.SiLU(inplace=True)
        )

        # (in_channels, out_channels, stride, expand_ratio) for each stage, in order.
        stage_specs = [
            (32, 64, 1, 1),
            (64, 64, 1, 4),
            (64, 128, 2, 4),
            (128, 128, 1, 4),
            (128, 256, 2, 4),
            (256, 256, 1, 6),
            (256, 512, 2, 6),
            (512, 512, 1, 6),
        ]
        self.blocks = nn.Sequential(*[
            EfficientBlock(c_in, c_out, stride=step, expand_ratio=ratio)
            for c_in, c_out, step, ratio in stage_specs
        ])

        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(512, hidden_dim),
            nn.SiLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, output_dim)
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal for conv weights; truncated normal (std 0.02) and zero bias for linears."""
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
                continue
            if isinstance(layer, nn.Linear):
                nn.init.trunc_normal_(layer.weight, std=0.02)
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        features = self.blocks(self.stem(x))
        return self.classifier(features)

# ============================================================================
# TEST-TIME AUGMENTATION
# ============================================================================

@torch.no_grad()
def test_time_augmentation(model: nn.Module, x: torch.Tensor, num_augments: int = 8) -> torch.Tensor:
    """Test-time augmentation for better accuracy"""
    model.eval()
    predictions = []

    with autocast('cuda'):
        pred = F.softmax(model(x), dim=1)
        predictions.append(pred)

    for _ in range(num_augments - 1):
        aug_x = x.clone()

        if torch.rand(1) < 0.5:
            aug_x = torch.flip(aug_x, dims=[3])

        if torch.rand(1) < 0.7:
            noise = torch.randn_like(aug_x) * 0.02
            aug_x = aug_x + noise
            aug_x = torch.clamp(aug_x, -1, 1)

        with autocast('cuda'):
            pred = F.softmax(model(aug_x), dim=1)
            predictions.append(pred)

    return torch.stack(predictions).mean(0)

# ============================================================================
# KNOWLEDGE DISTILLATION
# ============================================================================

class KnowledgeDistillationLoss(nn.Module):
    """Weighted sum of hard-label cross-entropy and temperature-scaled KL distillation.

    ``loss = alpha * CE(student, labels) + (1 - alpha) * T^2 * KL(student_T || teacher_T)``
    where the ``_T`` distributions are softmaxes of the logits divided by the
    temperature ``T``.
    """

    def __init__(self, temperature=4.0, alpha=0.3):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.ce_loss = nn.CrossEntropyLoss()
        self.kl_loss = nn.KLDivLoss(reduction='batchmean')

    def forward(self, student_logits, teacher_logits, labels):
        temp = self.temperature

        # Soft targets from the teacher, log-probabilities from the student.
        soft_targets = F.softmax(teacher_logits / temp, dim=1)
        soft_predictions = F.log_softmax(student_logits / temp, dim=1)

        soft_term = self.kl_loss(soft_predictions, soft_targets) * (temp ** 2)
        hard_term = self.ce_loss(student_logits, labels)

        return self.alpha * hard_term + (1 - self.alpha) * soft_term

# ============================================================================
# ADVANCED TRAINING WITH CHECKPOINTING
# ============================================================================

def train_advanced_model(model: nn.Module, train_loader: DataLoader, criterion: nn.Module,
                        test_loader: DataLoader, optimizer: optim.Optimizer, scheduler,
                        num_epochs: int = 200, verbose: bool = True, teacher_model=None,
                        model_name: str = "model", resume_from_checkpoint: bool = True) -> float:
    """Train ``model`` with mixed precision, MixUp/CutMix, optional distillation and checkpointing.

    Per batch, after epoch 30, MixUp is applied with probability 0.3 and
    CutMix with probability 0.3; the remaining batches are trained unmixed.
    On unmixed batches after epoch 50 a knowledge-distillation term against
    ``teacher_model`` is blended in when a teacher is given. The scheduler is
    stepped once per epoch.

    Every ``SAVE_EVERY_N_EPOCHS`` epochs (and on the last epoch) the model is
    evaluated on ``test_loader``; a checkpoint is written when the accuracy
    improves and, additionally, every ``4 * SAVE_EVERY_N_EPOCHS`` epochs.
    Training stops early once 30 epochs (counted in evaluation intervals)
    pass without improvement. A ``*_final`` checkpoint is written at the end.

    Args:
        model: network to train; moved to the module-level ``device``.
        train_loader, test_loader: data loaders for training and evaluation.
        criterion: loss applied to the logits.
        optimizer, scheduler: optimisation objects; their state is restored
            on resume.
        num_epochs: last epoch number to train.
        verbose: print progress at every evaluation.
        teacher_model: optional teacher for knowledge distillation.
        model_name: prefix of the checkpoint files of this model.
        resume_from_checkpoint: continue from the best tracked checkpoint that
            was written under ``model_name``, if there is one.

    Returns:
        Best test accuracy (percent) seen, including a resumed checkpoint's.
    """
    model = model.to(device)
    scaler = GradScaler('cuda')
    mixup = MixUp(alpha=0.4)
    cutmix = CutMix(alpha=1.0)

    # Initialize checkpoint system
    checkpoint = ModelCheckpoint()

    kd_criterion = None
    if teacher_model is not None:
        teacher_model.eval()
        kd_criterion = KnowledgeDistillationLoss()

    # Resume from checkpoint if available
    start_epoch = 1
    best_accuracy = 0.0
    patience_counter = 0
    max_patience = 30

    if resume_from_checkpoint:
        # get_resume_info() also loads the ranking from disk when needed.
        resume_path = None
        if checkpoint.get_resume_info():
            # The ranking is shared by all models that write to the checkpoint
            # directory, so only consider files written under this model's
            # name (best, periodic or final). Entries are sorted best-first.
            own_prefixes = tuple(
                f"{model_name}{suffix}_epoch" for suffix in ("", "_periodic", "_final")
            )
            for _, candidate_path, candidate_name in checkpoint.best_models:
                if candidate_name.startswith(own_prefixes):
                    resume_path = candidate_path
                    break

        if resume_path is not None:
            print(f"\n🔄 Attempting to resume training for {model_name}...")
            loaded_info = checkpoint.load_checkpoint(resume_path, model, optimizer, scheduler)

            if loaded_info:
                start_epoch = loaded_info['epoch'] + 1
                best_accuracy = loaded_info['accuracy']
                print(f"✅ Resumed from epoch {loaded_info['epoch']} with accuracy {best_accuracy:.2f}%")
            else:
                print("⚠️ Failed to load checkpoint, starting from scratch")
        else:
            print(f"ℹ️ No checkpoint found for {model_name}, starting fresh training")

    # A resumed run may already be complete; without this guard the code
    # after the loop would reference an unbound `epoch`.
    if start_epoch > num_epochs:
        print(f"ℹ️ {model_name} already trained for {start_epoch - 1} epoch(s), nothing left to do")
        return best_accuracy

    for epoch in range(start_epoch, num_epochs + 1):
        model.train()
        total_loss = 0.0

        for batch_idx, (batch_x, batch_y) in enumerate(train_loader):
            batch_x = batch_x.to(device, non_blocking=True)
            batch_y = batch_y.to(device, non_blocking=True)

            optimizer.zero_grad()

            with autocast('cuda'):
                aug_prob = np.random.random()

                if epoch > 30 and aug_prob < 0.3:
                    mixed_x, y_a, y_b, lam = mixup(batch_x, batch_y)
                    outputs = model(mixed_x)
                    loss = mixup_cutmix_criterion(criterion, outputs, y_a, y_b, lam)
                elif epoch > 30 and aug_prob < 0.6:
                    mixed_x, y_a, y_b, lam = cutmix(batch_x, batch_y)
                    outputs = model(mixed_x)
                    loss = mixup_cutmix_criterion(criterion, outputs, y_a, y_b, lam)
                else:
                    outputs = model(batch_x)
                    loss = criterion(outputs, batch_y)

                    if kd_criterion is not None and epoch > 50:
                        with torch.no_grad():
                            teacher_outputs = teacher_model(batch_x)
                        kd_loss = kd_criterion(outputs, teacher_outputs, batch_y)
                        loss = 0.7 * loss + 0.3 * kd_loss

            # Unscale before clipping so the norm is measured on true gradients.
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item()

            if torch.cuda.is_available() and batch_idx % 50 == 0:
                torch.cuda.empty_cache()

        scheduler.step()
        avg_loss = total_loss / len(train_loader)

        # Evaluate and save checkpoints
        if epoch % SAVE_EVERY_N_EPOCHS == 0 or epoch == num_epochs:
            test_loss, test_acc = evaluate_advanced_model(model, test_loader, criterion)

            if verbose:
                current_time = datetime.now().strftime("%B %d, %Y at %I:%M:%S %p")
                print(f"Epoch {epoch}/{num_epochs}: Loss = {avg_loss:.4f} | "
                      f"Test Acc = {test_acc:.2f}% | LR = {scheduler.get_last_lr()[0]:.6f} | Time: {current_time}")
                print_gpu_utilization()

            # Save checkpoint if it's the best so far
            if test_acc > best_accuracy:
                # Remember the old value so the reported improvement is real.
                previous_best = best_accuracy
                best_accuracy = test_acc
                patience_counter = 0

                # Save the best model
                checkpoint.save_checkpoint(
                    model, optimizer, scheduler, epoch, test_acc, avg_loss,
                    model_name=model_name,
                    additional_info={
                        'training_config': {
                            'hidden_dim': model.fusion_classifier[0].in_features if hasattr(model, 'fusion_classifier') else 'unknown',
                            'dropout_rate': 'variable',
                            'batch_size': train_loader.batch_size,
                            'learning_rate': scheduler.get_last_lr()[0],
                        }
                    }
                )
                print(f"🎯 New best accuracy: {test_acc:.2f}% (improved by {test_acc - previous_best:.2f}%)")
            else:
                patience_counter += SAVE_EVERY_N_EPOCHS

            # Also save periodic checkpoint regardless of performance
            if epoch % (SAVE_EVERY_N_EPOCHS * 4) == 0:
                checkpoint.save_checkpoint(
                    model, optimizer, scheduler, epoch, test_acc, avg_loss,
                    model_name=f"{model_name}_periodic",
                    additional_info={'checkpoint_type': 'periodic'}
                )

            if patience_counter >= max_patience:
                print(f"⏹️ Early stopping at epoch {epoch}")
                break

    # Final checkpoint save. The loop above only ends at the last epoch or
    # through early stopping, so this always runs once training took place.
    final_test_loss, final_test_acc = evaluate_advanced_model(model, test_loader, criterion)
    checkpoint.save_checkpoint(
        model, optimizer, scheduler, epoch, final_test_acc, final_test_loss,
        model_name=f"{model_name}_final",
        additional_info={'checkpoint_type': 'final'}
    )

    return best_accuracy

@torch.no_grad()
def evaluate_advanced_model(model: nn.Module, test_loader: DataLoader, criterion: nn.Module,
                           use_tta: bool = False) -> Tuple[float, float]:
    """Evaluate ``model`` on ``test_loader``, optionally with test-time augmentation.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        test_loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss called as ``criterion(scores, labels)``.
        use_tta: When True, class probabilities come from
            ``test_time_augmentation`` and the loss is computed on their log.

    Returns:
        ``(avg_loss, accuracy)``: the mean of the per-batch losses and the
        accuracy in percent.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    for batch_x, batch_y in test_loader:
        batch_x = batch_x.to(device, non_blocking=True)
        batch_y = batch_y.to(device, non_blocking=True)

        if use_tta:
            outputs = test_time_augmentation(model, batch_x, TTA_AUGMENTS)
            # TTA returns averaged probabilities; the epsilon keeps log() finite.
            raw_outputs = torch.log(outputs + 1e-8)
        else:
            # Run the network once and reuse the logits for both the loss and
            # the predictions (previously the model was evaluated twice).
            with autocast('cuda'):
                logits = model(batch_x)
            # Cast back to float32 so the loss is not computed in half precision.
            raw_outputs = logits.float()
            outputs = F.softmax(raw_outputs, dim=1)

        loss = criterion(raw_outputs, batch_y)
        total_loss += loss.item()
        num_batches += 1

        preds = outputs.argmax(dim=1)
        correct += (preds == batch_y).sum().item()
        total += batch_y.size(0)

    if num_batches == 0 or total == 0:
        raise ValueError("test_loader yielded no samples to evaluate")

    avg_loss = total_loss / num_batches
    accuracy = 100.0 * correct / total
    return avg_loss, accuracy

# ============================================================================
# ADVANCED ENSEMBLE SYSTEM WITH CHECKPOINTING
# ============================================================================

def train_diverse_ensemble(input_dim, hidden_dim, n_classes, train_loader, test_loader,
                          ensemble_size=7, resume_training=True) -> Tuple[List[nn.Module], List[float]]:
    """Train up to ``ensemble_size`` differently configured models, one after another.

    Member ``i`` (0-based) differs from the others in architecture (taken from
    ``model_configs``), hidden width (``hidden_dim + i * 128``, some variants
    add more), dropout (``DROPOUT_RATE + i * 0.02``), optimizer (AdamW for
    even ``i``, Adam for odd ``i``), learning-rate scheduler (cycled by
    ``i % 3``) and label smoothing (``0.1 + i * 0.02``). The previously
    trained member is passed as the distillation teacher when its best
    accuracy exceeded 94%.

    Args:
        input_dim: Input shape forwarded to the model constructors.
        hidden_dim: Base hidden width.
        n_classes: Number of output classes.
        train_loader: Training batches.
        test_loader: Evaluation batches.
        ensemble_size: Requested number of members; capped at the number of
            configured architectures.
        resume_training: Forwarded to ``train_advanced_model`` as
            ``resume_from_checkpoint``.

    Returns:
        ``(models, best_accuracies)`` in training order.
    """

    models = []
    best_accuracies = []

    model_configs = [
        (HybridFashionNet, "Hybrid_CNN_Transformer"),
        (EfficientFashionNet, "EfficientNet_Style"),
        (lambda input_dim, hidden_dim, n_classes, dropout_rate:
         HybridFashionNet(input_dim, hidden_dim, n_classes, dropout_rate), "Hybrid_Variant_1"),
        (lambda input_dim, hidden_dim, n_classes, dropout_rate:
         EfficientFashionNet(input_dim, hidden_dim + 256, n_classes, dropout_rate), "Large_EfficientNet"),
        (lambda input_dim, hidden_dim, n_classes, dropout_rate:
         HybridFashionNet(input_dim, hidden_dim + 512, n_classes, dropout_rate * 0.8), "Large_Hybrid"),
        (lambda input_dim, hidden_dim, n_classes, dropout_rate:
         EfficientFashionNet(input_dim, hidden_dim, n_classes, dropout_rate * 1.2), "Regularized_EfficientNet"),
        (lambda input_dim, hidden_dim, n_classes, dropout_rate:
         HybridFashionNet(input_dim, hidden_dim, n_classes, dropout_rate * 1.1), "Regularized_Hybrid"),
    ]

    # Only as many members as there are configured architectures can be trained;
    # report the real count so the progress banner is not misleading.
    n_models = max(0, min(ensemble_size, len(model_configs)))
    if n_models < ensemble_size:
        print(f"⚠️ Requested {ensemble_size} models but only {len(model_configs)} "
              f"architectures are configured; training {n_models}.")

    for i, (model_class, arch_name) in enumerate(model_configs[:n_models]):
        print(f"\n{'='*80}")
        print(f"Training Ensemble Model {i+1}/{n_models}: {arch_name}")
        print(f"{'='*80}")

        model_name = f"ensemble_{i+1}_{arch_name}"
        model_hidden = hidden_dim + (i * 128)
        model_dropout = DROPOUT_RATE + i * 0.02

        model = model_class(input_dim, model_hidden, n_classes, model_dropout)
        model = model.to(device)

        # Alternate optimizers and scale lr / weight decay with the member index.
        if i % 2 == 0:
            optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE * (0.8 + i*0.05),
                                   weight_decay=weight_decay_ * (1 + i*0.1), betas=(0.9, 0.999))
        else:
            optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE * (0.9 + i*0.05),
                                  weight_decay=weight_decay_ * (1 + i*0.1))

        # Cycle through three learning-rate schedules.
        if i % 3 == 0:
            scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer, T_0=EPOCHS_PER_DATASET//4, eta_min=1e-7)
        elif i % 3 == 1:
            scheduler = optim.lr_scheduler.OneCycleLR(
                optimizer, max_lr=LEARNING_RATE * 2, epochs=EPOCHS_PER_DATASET,
                steps_per_epoch=len(train_loader))
        else:
            scheduler = optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=EPOCHS_PER_DATASET, eta_min=1e-7)

        criterion = nn.CrossEntropyLoss(label_smoothing=0.1 + i*0.02)

        # Distil from the previous member only if it was accurate enough.
        teacher_model = models[-1] if len(models) > 0 and best_accuracies[-1] > 94.0 else None

        best_acc = train_advanced_model(
            model, train_loader, criterion, test_loader, optimizer, scheduler,
            num_epochs=EPOCHS_PER_DATASET, teacher_model=teacher_model,
            model_name=model_name, resume_from_checkpoint=resume_training
        )

        models.append(model)
        best_accuracies.append(best_acc)

        print(f"✅ Model {i+1} ({arch_name}) Best Accuracy: {best_acc:.2f}%")

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return models, best_accuracies

@torch.no_grad()
def advanced_ensemble_predict(models: List[nn.Module], test_loader: DataLoader,
                            use_tta: bool = True) -> float:
    """Return the accuracy (percent) of the averaged ensemble prediction.

    Every model produces class probabilities for each batch, either through
    ``test_time_augmentation`` or a single softmax forward pass. The
    probabilities are averaged with equal weight across models and the argmax
    is compared with the labels.

    Args:
        models: Ensemble members; each is switched to eval mode.
        test_loader: Iterable yielding ``(inputs, labels)`` batches.
        use_tta: Whether to use test-time augmentation per model.

    Returns:
        Accuracy in percent over all samples of ``test_loader``.

    Raises:
        ValueError: If ``models`` is empty or ``test_loader`` yields no samples.
    """
    if not models:
        raise ValueError("advanced_ensemble_predict requires at least one model")

    for model in models:
        model.eval()

    correct = 0
    total = 0

    for batch_x, batch_y in test_loader:
        batch_x = batch_x.to(device, non_blocking=True)
        batch_y = batch_y.to(device, non_blocking=True)

        ensemble_outputs = []
        for model in models:
            if use_tta:
                outputs = test_time_augmentation(model, batch_x, TTA_AUGMENTS)
            else:
                with autocast('cuda'):
                    outputs = F.softmax(model(batch_x), dim=1)
            ensemble_outputs.append(outputs)

        # Unweighted mean of the per-model probabilities.
        avg_outputs = torch.stack(ensemble_outputs).mean(0)
        preds = avg_outputs.argmax(dim=1)

        correct += (preds == batch_y).sum().item()
        total += batch_y.size(0)

    if total == 0:
        raise ValueError("test_loader yielded no samples to evaluate")

    accuracy = 100.0 * correct / total
    return accuracy

# ============================================================================
# MAIN EXPERIMENT PIPELINE WITH CHECKPOINTING
# ============================================================================

def run_advanced_experiments(resume_training=True):
    """Train and evaluate the ensemble for every dataset in ``DATASET_NAMES``.

    For each dataset this builds the data loaders, lists existing checkpoints,
    trains the ensemble, evaluates it with and without test-time augmentation,
    writes a JSON summary to ``CHECKPOINT_DIR/ensemble_results.json`` and
    prints a report.

    Args:
        resume_training: Forwarded to ``train_diverse_ensemble``.
    """

    for dataset_name in DATASET_NAMES:
        print(f"\n{'='*90}")
        print(f"RUNNING ADVANCED SOTA EXPERIMENTS ON {dataset_name}")
        print(f"TARGET: 99% ACCURACY WITH ADVANCED TECHNIQUES + CHECKPOINTING")
        print(f"{'='*90}")

        train_loader, test_loader, input_dim, n_classes = get_advanced_data_loaders(dataset_name, BATCH_SIZE)
        print(f"Dataset: {dataset_name} | Input: {input_dim} | Classes: {n_classes}")
        print(f"Train: {len(train_loader.dataset)} | Test: {len(test_loader.dataset)}")
        print(f"Advanced Augmentation: Enabled | TTA: {TTA_AUGMENTS} augments")
        print(f"Checkpoint Directory: {CHECKPOINT_DIR}")
        print(f"Resume Training: {'Enabled' if resume_training else 'Disabled'}")
        print_gpu_utilization()

        # List existing checkpoints
        checkpoint_manager = ModelCheckpoint()
        checkpoint_manager.list_checkpoints()

        models, individual_accuracies = train_diverse_ensemble(
            input_dim, HIDDEN_DIM, n_classes, train_loader, test_loader,
            ENSEMBLE_SIZE, resume_training=resume_training
        )

        print(f"\n{'='*70}")
        print("INDIVIDUAL MODEL RESULTS")
        print(f"{'='*70}")
        for i, acc in enumerate(individual_accuracies):
            print(f"Model {i+1}: {acc:.2f}%")

        print(f"\n{'='*70}")
        print("ENSEMBLE RESULTS WITH TTA")
        print(f"{'='*70}")

        ensemble_acc_no_tta = advanced_ensemble_predict(models, test_loader, use_tta=False)
        print(f"Ensemble Accuracy (no TTA): {ensemble_acc_no_tta:.2f}%")

        ensemble_acc_tta = advanced_ensemble_predict(models, test_loader, use_tta=True)
        print(f"Ensemble Accuracy (with TTA): {ensemble_acc_tta:.2f}%")

        best_individual = max(individual_accuracies)
        final_accuracy = max(best_individual, ensemble_acc_no_tta, ensemble_acc_tta)

        # Save ensemble results
        ensemble_results = {
            'individual_accuracies': individual_accuracies,
            'ensemble_no_tta': ensemble_acc_no_tta,
            'ensemble_with_tta': ensemble_acc_tta,
            'final_accuracy': final_accuracy,
            'timestamp': datetime.now().isoformat(),
            'experiment_config': {
                'dataset': dataset_name,
                'ensemble_size': len(models),
                'epochs': EPOCHS_PER_DATASET,
                'batch_size': BATCH_SIZE,
                'hidden_dim': HIDDEN_DIM,
            }
        }

        # Make sure the target directory exists so the write cannot fail with
        # FileNotFoundError. The file name is shared by all datasets, so a
        # later dataset overwrites an earlier summary.
        results_dir = Path(CHECKPOINT_DIR)
        results_dir.mkdir(parents=True, exist_ok=True)
        results_path = results_dir / "ensemble_results.json"
        with open(results_path, 'w', encoding='utf-8') as f:
            json.dump(ensemble_results, f, indent=2)

        print(f"\n{'='*70}")
        print("FINAL ADVANCED SOTA RESULTS")
        print(f"{'='*70}")
        print(f"Best Individual Model: {best_individual:.2f}%")
        print(f"Best Ensemble (no TTA): {ensemble_acc_no_tta:.2f}%")
        print(f"Best Ensemble (with TTA): {ensemble_acc_tta:.2f}%")
        print(f"FINAL BEST ACCURACY: {final_accuracy:.2f}%")
        print(f"Results saved to: {results_path}")

        # Qualitative verdict by accuracy tier.
        if final_accuracy >= 99.0:
            print(f"🎉🎉 OUTSTANDING! {final_accuracy:.2f}% ≥ 99% TARGET ACHIEVED!")
        elif final_accuracy >= 97.0:
            print(f"🎉 EXCELLENT! {final_accuracy:.2f}% ≥ 97%")
        elif final_accuracy >= 95.0:
            print(f"✅ VERY GOOD! {final_accuracy:.2f}% ≥ 95%")
        elif final_accuracy >= 92.0:
            print(f"🟡 GOOD! {final_accuracy:.2f}% ≥ 92%")
        else:
            print(f"⚠️ NEEDS IMPROVEMENT: {final_accuracy:.2f}%")


def main():
    """Entry point: print the configuration, seed the RNGs and run the experiments.

    Asks on stdin whether to resume from checkpoints; any answer other than
    ``n`` (including an empty line or a closed stdin) means "resume".
    """
    SEED_ = 42
    print("="*90)
    print("ADVANCED SOTA FASHIONMNIST WITH CHECKPOINT SYSTEM")
    print("FEATURES: Model Saving/Loading, Resume Training, Hybrid Architecture")
    print("="*90)
    print(f"Device: {device}")
    print(f"Ensemble Size: {ENSEMBLE_SIZE}")
    print(f"Epochs per Model: {EPOCHS_PER_DATASET}")
    print(f"Hidden Dim: {HIDDEN_DIM}")
    print(f"Batch Size: {BATCH_SIZE}")
    print(f"Checkpoint Directory: {CHECKPOINT_DIR}")
    print(f"Save Top K Models: {SAVE_TOP_K}")
    print(f"Save Every N Epochs: {SAVE_EVERY_N_EPOCHS}")

    torch.manual_seed(SEED_)
    np.random.seed(SEED_)
    random.seed(SEED_)

    # cuDNN autotuning is enabled and determinism is off, so runs are seeded
    # but not guaranteed to be bit-for-bit reproducible.
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False

    # Option to resume training or start fresh
    try:
        resume_choice = input("\nResume training from checkpoints? (y/n, default=y): ").strip().lower()
    except EOFError:
        # stdin is closed (piped / batch execution): use the documented default.
        resume_choice = ''
    resume_training = resume_choice != 'n'

    run_advanced_experiments(resume_training=resume_training)
    print("\n🎉 ADVANCED SOTA EXPERIMENTS WITH CHECKPOINTING COMPLETED!")

    # Final checkpoint summary
    print(f"\n{'='*60}")
    print("CHECKPOINT SUMMARY")
    print(f"{'='*60}")
    checkpoint_manager = ModelCheckpoint()
    checkpoint_manager.list_checkpoints()

#if __name__ == "__main__":
#    main()


In [None]:
# Notebook cell: run the full pipeline. main() prompts on stdin for the resume choice.
main()


In [None]:
import math
from typing import Tuple, Dict, Any, Optional, List
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, Dataset
from torchvision import datasets, transforms
from torch.amp import autocast, GradScaler  # Updated import
import random
import time
from datetime import datetime
import copy
from collections import OrderedDict

# ============================================================================
# GPU DEVICE SETUP FOR T4
# ============================================================================

# Pick the compute device once at import time: CUDA when available, otherwise CPU.
# The module-level `device` is read by the augmentation and model-building code below.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if device.type == 'cuda':
    # Report the name and total memory (in GiB) of CUDA device 0.
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

def print_gpu_utilization():
    """Print allocated and reserved memory of CUDA device 0; do nothing without CUDA."""
    if not torch.cuda.is_available():
        return

    bytes_per_gib = 1024**3
    allocated_gib = torch.cuda.memory_allocated(0) / bytes_per_gib
    print(f"GPU Memory Used: {allocated_gib:.1f} GB")
    reserved_gib = torch.cuda.memory_reserved(0) / bytes_per_gib
    print(f"GPU Memory Cached: {reserved_gib:.1f} GB")

# ============================================================================
# ENHANCED CONFIGURATION FOR 99% TARGET
# ============================================================================

# Datasets the experiment loop iterates over.
DATASET_NAMES = ['FashionMNIST']
HIDDEN_DIM: int = 3072  # Base hidden width handed to the model constructors
LEARNING_RATE: float = 8e-4  # Base learning rate for the optimizers
weight_decay_ = 2e-4  # Base weight decay for the optimizers
DROPOUT_RATE: float = 0.3  # Base dropout probability for the classifier heads
EPOCHS_PER_DATASET: int = 200  # Maximum training epochs per model (also the scheduler horizon)
BATCH_SIZE: int = 512  # Training batch size; the test loader uses twice this value
ENSEMBLE_SIZE: int = 7  # Requested number of ensemble members
TTA_AUGMENTS: int = 8  # Predictions averaged per batch by test-time augmentation (incl. the original)

# ============================================================================
# ADVANCED DATA AUGMENTATION STRATEGIES
# ============================================================================

class CutMix(object):
    """CutMix augmentation: paste a random box from shuffled samples into each image.

    Calling the object returns ``(x, y_a, y_b, lam)``, where ``y_a`` are the
    original labels, ``y_b`` the labels of the samples the box was taken from,
    and ``lam`` the fraction of each image that still belongs to ``y_a``.
    """
    def __init__(self, alpha=1.0):
        # Beta(alpha, alpha) controls the box size; alpha <= 0 disables CutMix.
        self.alpha = alpha

    def __call__(self, x, y):
        """Apply CutMix to a batch.

        Args:
            x: Batch tensor indexed as ``(batch, channel, dim2, dim3)``.
                It is modified IN PLACE and also returned.
            y: Labels, indexable along the batch dimension.
        """
        if self.alpha <= 0:
            return x, y, y, 1.0

        batch_size = x.size(0)
        lam = np.random.beta(self.alpha, self.alpha)
        # Build the permutation on the input's own device instead of relying
        # on a module-level global, so the transform works for any tensor.
        index = torch.randperm(batch_size).to(x.device)

        bbx1, bby1, bbx2, bby2 = self._rand_bbox(x.size(), lam)
        # The right-hand side is materialised by advanced indexing before the
        # assignment, so reading and writing the same tensor is safe here.
        x[:, :, bbx1:bbx2, bby1:bby2] = x[index, :, bbx1:bbx2, bby1:bby2]

        # Adjust lambda to match pixel ratio (the box may have been clipped).
        lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size()[-1] * x.size()[-2]))
        y_a, y_b = y, y[index]
        return x, y_a, y_b, lam

    def _rand_bbox(self, size, lam):
        """Sample a box covering about ``1 - lam`` of the image, clipped to its bounds."""
        W = size[2]
        H = size[3]
        cut_rat = np.sqrt(1. - lam)
        cut_w = np.int32(W * cut_rat)
        cut_h = np.int32(H * cut_rat)

        # Box centre, drawn uniformly over the image.
        cx = np.random.randint(W)
        cy = np.random.randint(H)

        bbx1 = np.clip(cx - cut_w // 2, 0, W)
        bby1 = np.clip(cy - cut_h // 2, 0, H)
        bbx2 = np.clip(cx + cut_w // 2, 0, W)
        bby2 = np.clip(cy + cut_h // 2, 0, H)

        return bbx1, bby1, bbx2, bby2

class MixUp(object):
    """MixUp augmentation: blend every sample with a randomly chosen partner.

    Calling the object returns ``(mixed_x, y_a, y_b, lam)`` with
    ``mixed_x = lam * x + (1 - lam) * x[perm]``.
    """
    def __init__(self, alpha=0.4):
        # Beta(alpha, alpha) controls the blend; alpha <= 0 fixes lam to 1.
        self.alpha = alpha

    def __call__(self, x, y):
        """Mix a batch.

        Args:
            x: Batch tensor; the first dimension is the batch.
            y: Labels, indexable along the batch dimension.
        """
        if self.alpha > 0:
            lam = np.random.beta(self.alpha, self.alpha)
        else:
            lam = 1

        batch_size = x.size(0)
        # Build the permutation on the input's own device instead of relying
        # on a module-level global, so the transform works for any tensor.
        index = torch.randperm(batch_size).to(x.device)

        mixed_x = lam * x + (1 - lam) * x[index, :]
        y_a, y_b = y, y[index]
        return mixed_x, y_a, y_b, lam

class AdvancedAugmentDataset(Dataset):
    """Wrap a dataset, repeating every sample ``num_augments`` times with random augmentation.

    Index ``idx`` maps to base sample ``idx // num_augments``. When
    ``heavy_aug`` is true, a random augmentation pipeline is applied on every
    access, so repeated reads of the same index give different images.
    """
    def __init__(self, base_dataset, num_augments=3, heavy_aug=True):
        self.base = base_dataset
        self.num_augments = num_augments
        self.heavy_aug = heavy_aug
        self.aug_transform = None

        # Heavy augmentation for training. The pipeline expects a PIL image;
        # conversion from tensors/arrays is done once in __getitem__.
        if heavy_aug:
            self._to_pil = transforms.ToPILImage()
            self.aug_transform = transforms.Compose([
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomRotation(20),
                transforms.RandomAffine(degrees=0, translate=(0.15, 0.15), scale=(0.8, 1.2)),
                transforms.RandomResizedCrop(28, scale=(0.7, 1.0)),
                transforms.ColorJitter(brightness=0.3, contrast=0.3),
                transforms.RandomPerspective(distortion_scale=0.2, p=0.5),
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,)),
                transforms.RandomErasing(p=0.6, scale=(0.02, 0.4), ratio=(0.3, 3.3)),
            ])

    def __len__(self):
        return len(self.base) * self.num_augments

    def __getitem__(self, idx):
        base_idx = idx // self.num_augments
        image, label = self.base[base_idx]

        if self.heavy_aug and self.aug_transform is not None:
            # Convert exactly once. Previously a tensor was converted to PIL
            # here and then handed to a pipeline that started with another
            # ToPILImage(), which rejects PIL input with a TypeError.
            # NOTE(review): if the base dataset already normalises tensors to
            # [-1, 1], this conversion does not round-trip those values —
            # confirm the intended base transform.
            if isinstance(image, (torch.Tensor, np.ndarray)):
                image = self._to_pil(image)
            image = self.aug_transform(image)

        return image, label

def mixup_cutmix_criterion(criterion, pred, y_a, y_b, lam):
    """Blend the loss against both label sets, weighting ``y_a`` by ``lam``."""
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b

def get_advanced_data_loaders(dataset_name: str, batch_size: int) -> Tuple[DataLoader, DataLoader, tuple, int]:
    """Build the augmented training loader and the plain test loader.

    Args:
        dataset_name: Dataset identifier; only ``'FashionMNIST'`` is supported.
        batch_size: Training batch size. The test loader uses ``2 * batch_size``.

    Returns:
        ``(train_loader, test_loader, input_dim, n_classes)`` where
        ``input_dim`` is ``(channels, height, width)``.

    Raises:
        ValueError: If ``dataset_name`` is not supported.
    """
    # Fail early with a clear message instead of an unbound-variable error.
    if dataset_name != 'FashionMNIST':
        raise ValueError(
            f"Unsupported dataset {dataset_name!r}; only 'FashionMNIST' is available"
        )

    # Sophisticated training augmentation
    transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(25),
        transforms.RandomAffine(degrees=0, translate=(0.2, 0.2), scale=(0.8, 1.3)),
        transforms.RandomResizedCrop(28, scale=(0.6, 1.0)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
        transforms.RandomErasing(p=0.7, scale=(0.02, 0.5), ratio=(0.3, 3.3)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform_train)
    test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform_test)
    n_classes = 10
    n_channels = 1
    img_size = 28

    # Heavy augmentation with 3x data multiplication
    enhanced_train = AdvancedAugmentDataset(train_dataset, 3, heavy_aug=True)

    train_loader = DataLoader(
        enhanced_train,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        drop_last=True,  # drop the final incomplete training batch
        prefetch_factor=3,
        persistent_workers=True
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size * 2,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
        drop_last=False,  # evaluate on every test sample
        prefetch_factor=3,
        persistent_workers=True
    )

    input_dim = (n_channels, img_size, img_size)
    return train_loader, test_loader, input_dim, n_classes

# ============================================================================
# VISION TRANSFORMER COMPONENTS
# ============================================================================

class MultiHeadAttention(nn.Module):
    """Scaled dot-product self-attention over a token sequence, split into several heads."""
    def __init__(self, dim, num_heads=8, dropout=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.dim = dim
        self.head_dim = dim // num_heads
        # Standard 1/sqrt(head_dim) scaling of the attention scores.
        self.scale = self.head_dim ** -0.5

        # One fused projection yields queries, keys and values.
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.proj = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over ``x`` of shape (batch, tokens, dim); the output has the same shape."""
        batch, tokens, channels = x.shape

        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, self.head_dim)
        # -> (3, batch, heads, tokens, head_dim), then split the leading axis.
        queries, keys, values = packed.permute(2, 0, 3, 1, 4).unbind(dim=0)

        scores = torch.matmul(queries, keys.transpose(-2, -1)) * self.scale
        weights = self.dropout(scores.softmax(dim=-1))

        # Merge the heads back into the channel dimension.
        context = torch.matmul(weights, values)
        merged = context.transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj(merged)

class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP, each wrapped in a residual connection."""
    def __init__(self, dim, num_heads=8, mlp_ratio=4.0, dropout=0.1, stochastic_depth=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiHeadAttention(dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(dim)

        expanded = int(dim * mlp_ratio)
        feed_forward = [
            nn.Linear(dim, expanded),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(expanded, dim),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*feed_forward)

        # Probability of skipping the whole block during training.
        self.stochastic_depth = stochastic_depth

    def forward(self, x):
        """Run the block; while training it may be skipped for the entire batch."""
        skip_block = (
            self.training
            and self.stochastic_depth > 0
            and bool(torch.rand(1) < self.stochastic_depth)
        )
        if skip_block:
            return x

        attended = x + self.attn(self.norm1(x))
        return attended + self.mlp(self.norm2(attended))

# ============================================================================
# ADVANCED ARCHITECTURE COMPONENTS
# ============================================================================

class StochasticDepth(nn.Module):
    """Per-sample stochastic depth: randomly drop the residual branch during training.

    In eval mode (or with ``drop_rate == 0``) the module returns ``x + residual``.
    In training mode each sample keeps its residual with probability
    ``1 - drop_rate``; kept residuals are scaled by ``1 / (1 - drop_rate)``.
    """
    def __init__(self, drop_rate=0.1):
        super().__init__()
        # drop_rate == 1 would divide by a keep probability of zero.
        if not 0.0 <= drop_rate < 1.0:
            raise ValueError(f"drop_rate must be in [0, 1), got {drop_rate}")
        self.drop_rate = drop_rate

    def forward(self, x, residual):
        """Combine the skip path ``x`` with a possibly dropped ``residual``.

        Args:
            x: Skip-path tensor; its first dimension is the batch.
            residual: Branch output, broadcastable against ``x``.
        """
        if not self.training or self.drop_rate == 0:
            return x + residual

        keep_prob = 1 - self.drop_rate
        # One Bernoulli draw per sample, broadcast over all remaining
        # dimensions so inputs of any rank are supported (not only 4-D).
        mask_shape = (x.shape[0],) + (1,) * (x.dim() - 1)
        mask = torch.rand(mask_shape, device=x.device) < keep_prob
        return x + residual * mask.float() / keep_prob

class SEBlock(nn.Module):
    """Squeeze-and-Excitation: rescale channels by gates computed from their global average."""
    def __init__(self, channel, reduction=16):
        super().__init__()
        bottleneck = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Bottleneck MLP producing one sigmoid gate per channel.
        self.fc = nn.Sequential(
            nn.Linear(channel, bottleneck, bias=False),
            nn.SiLU(inplace=True),
            nn.Linear(bottleneck, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Gate the channels of a 4-D feature map; the output has the input's shape."""
        batch, channels, _height, _width = x.size()
        squeezed = self.avg_pool(x).view(batch, channels)
        gates = self.fc(squeezed).view(batch, channels, 1, 1)
        return x * gates.expand_as(x)

class EfficientBlock(nn.Module):
    """Inverted-bottleneck block: optional 1x1 expansion, depthwise 3x3, optional SE, 1x1 projection."""
    def __init__(self, in_channels, out_channels, stride=1, expand_ratio=6, se_ratio=0.25):
        super().__init__()
        self.stride = stride
        expanded = in_channels * expand_ratio
        # A skip connection is only possible when the shape is preserved.
        self.use_residual = stride == 1 and in_channels == out_channels

        stages = []

        # Pointwise expansion (skipped when nothing is expanded).
        if expand_ratio != 1:
            stages += [
                nn.Conv2d(in_channels, expanded, 1, bias=False),
                nn.BatchNorm2d(expanded),
                nn.SiLU(inplace=True),
            ]

        # Depthwise convolution: one group per channel.
        stages += [
            nn.Conv2d(expanded, expanded, 3, stride, 1, groups=expanded, bias=False),
            nn.BatchNorm2d(expanded),
            nn.SiLU(inplace=True),
        ]

        # Channel attention; the SE reduction factor is 1 / se_ratio.
        if se_ratio > 0:
            stages.append(SEBlock(expanded, int(1/se_ratio)))

        # Pointwise projection, deliberately without an activation.
        stages += [
            nn.Conv2d(expanded, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
        ]

        self.conv = nn.Sequential(*stages)
        self.stochastic_depth = StochasticDepth(0.1)

    def forward(self, x):
        """Apply the block, adding the stochastic-depth skip path when shapes allow."""
        branch = self.conv(x)
        if not self.use_residual:
            return branch
        return self.stochastic_depth(x, branch)

class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with optional SE attention and a stochastic-depth skip connection."""
    def __init__(self, in_channels, out_channels, stride=1, use_se=True, drop_path=0.1):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Identity shortcut unless the shape changes; then a 1x1 projection.
        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

        self.se = SEBlock(out_channels) if use_se else None
        self.dropout = nn.Dropout2d(0.1)
        self.stochastic_depth = StochasticDepth(drop_path)

    def forward(self, x):
        """Return shortcut(x) plus the (possibly dropped) convolutional branch."""
        skip = self.shortcut(x)

        branch = self.conv1(x)
        branch = F.silu(self.bn1(branch))
        branch = self.dropout(branch)
        branch = self.bn2(self.conv2(branch))

        if self.se is not None:
            branch = self.se(branch)

        return self.stochastic_depth(skip, branch)

# ============================================================================
# HYBRID CNN-TRANSFORMER ARCHITECTURE
# ============================================================================

class HybridFashionNet(nn.Module):
    """Hybrid CNN-Transformer classifier.

    A convolutional backbone (stem plus six ``EfficientBlock`` stages) yields
    a 512-channel feature map. That map feeds two heads: a globally pooled
    CNN head, and a transformer head that treats every spatial location as a
    token. Both heads project to ``hidden_dim // 2`` features, which are
    concatenated and classified by an MLP.

    Args:
        input_dim: ``(channels, height, width)`` of the input images.
        hidden_dim: Width of the fusion classifier; should be even so the two
            halves concatenate back to ``hidden_dim``.
        output_dim: Number of classes.
        dropout_rate: Dropout of the first fusion layer; the following two
            layers use half and a quarter of it.
    """

    def __init__(self, input_dim: tuple, hidden_dim: int, output_dim: int, dropout_rate: float = 0.3):
        super(HybridFashionNet, self).__init__()

        channels, height, width = input_dim

        # CNN Feature Extraction
        self.stem = nn.Sequential(
            nn.Conv2d(channels, 64, 3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.SiLU(inplace=True)
        )

        # EfficientNet-style blocks (three stride-2 stages downsample by 8)
        self.efficient_block1 = EfficientBlock(64, 64, stride=1)
        self.efficient_block2 = EfficientBlock(64, 128, stride=2)
        self.efficient_block3 = EfficientBlock(128, 128, stride=1)
        self.efficient_block4 = EfficientBlock(128, 256, stride=2)
        self.efficient_block5 = EfficientBlock(256, 256, stride=1)
        self.efficient_block6 = EfficientBlock(256, 512, stride=2)

        # Transformer components
        self.patch_embed = nn.Linear(512, 384)
        # Fixed at 16 tokens: the 4x4 feature map obtained from 28x28 inputs.
        self.pos_embed = nn.Parameter(torch.zeros(1, 16, 384))

        # Skip probability grows with depth: 0, 0.025, 0.05, 0.075.
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(384, num_heads=8, stochastic_depth=0.1 * i / 4)
            for i in range(4)
        ])

        self.norm = nn.LayerNorm(384)

        # Advanced classifier with multiple paths
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.cnn_classifier = nn.Linear(512, hidden_dim // 2)

        self.transformer_classifier = nn.Linear(384, hidden_dim // 2)

        # True division for the dropout probabilities: floor division of a
        # float (e.g. 0.3 // 2 == 0.0) silently disabled these two layers.
        self.fusion_classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(inplace=True),
            nn.Dropout(dropout_rate),

            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.SiLU(inplace=True),
            nn.Dropout(dropout_rate / 2),

            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.SiLU(inplace=True),
            nn.Dropout(dropout_rate / 4),

            nn.Linear(hidden_dim // 4, output_dim)
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming init for convolutions, truncated normal for linears, unit BatchNorm."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Initialize positional embeddings
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, x):
        """Return class logits for a batch of images."""
        # CNN path
        x = self.stem(x)
        x = self.efficient_block1(x)
        x = self.efficient_block2(x)
        x = self.efficient_block3(x)
        x = self.efficient_block4(x)
        x = self.efficient_block5(x)
        x = self.efficient_block6(x)

        # Split for dual processing
        cnn_features = self.global_avg_pool(x).flatten(1)
        cnn_out = self.cnn_classifier(cnn_features)

        # Transformer path: every spatial location becomes one token.
        transformer_input = x.flatten(2).transpose(1, 2)  # B, HW, C
        transformer_input = self.patch_embed(transformer_input)
        transformer_input = transformer_input + self.pos_embed

        for block in self.transformer_blocks:
            transformer_input = block(transformer_input)

        transformer_input = self.norm(transformer_input)
        # Mean-pool the tokens before the projection.
        transformer_out = self.transformer_classifier(transformer_input.mean(1))

        # Fusion
        fused = torch.cat([cnn_out, transformer_out], dim=1)
        return self.fusion_classifier(fused)

# ============================================================================
# ADVANCED ENSEMBLE ARCHITECTURE
# ============================================================================

class EfficientFashionNet(nn.Module):
    """Pure-CNN classifier: conv stem, eight ``EfficientBlock`` stages and a pooled MLP head."""
    def __init__(self, input_dim: tuple, hidden_dim: int, output_dim: int, dropout_rate: float = 0.3):
        super().__init__()
        channels, _height, _width = input_dim

        self.stem = nn.Sequential(
            nn.Conv2d(channels, 32, 3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.SiLU(inplace=True),
        )

        # (in_channels, out_channels, stride, expand_ratio) for every stage.
        stage_specs = [
            (32, 64, 1, 1),
            (64, 64, 1, 4),
            (64, 128, 2, 4),
            (128, 128, 1, 4),
            (128, 256, 2, 4),
            (256, 256, 1, 6),
            (256, 512, 2, 6),
            (512, 512, 1, 6),
        ]
        self.blocks = nn.Sequential(*[
            EfficientBlock(c_in, c_out, stride=step, expand_ratio=ratio)
            for c_in, c_out, step, ratio in stage_specs
        ])

        head = [
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(512, hidden_dim),
            nn.SiLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.classifier = nn.Sequential(*head)

        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming init for convolutions; truncated normal weights and zero bias for linears."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if isinstance(module, nn.Linear):
                nn.init.trunc_normal_(module.weight, std=0.02)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Return class logits for a batch of images."""
        features = self.blocks(self.stem(x))
        return self.classifier(features)

# ============================================================================
# TEST-TIME AUGMENTATION
# ============================================================================

@torch.no_grad()
def test_time_augmentation(model: nn.Module, x: torch.Tensor, num_augments: int = 8) -> torch.Tensor:
    """Average softmax predictions over the original batch and random variants of it.

    The first prediction uses ``x`` unchanged. Each of the remaining
    ``num_augments - 1`` predictions uses a copy that is horizontally flipped
    with probability 0.5 and, with probability 0.7, perturbed by Gaussian
    noise (std 0.02) and clamped to [-1, 1].

    Args:
        model: Network to query; it is switched to eval mode.
        x: Input batch; the flip acts on dimension 3, so a 4-D tensor is expected.
        num_augments: Total number of predictions to average (original included).

    Returns:
        The mean of the per-variant softmax outputs.
    """
    model.eval()
    predictions = []

    # Original prediction
    with autocast('cuda'):
        pred = F.softmax(model(x), dim=1)
        predictions.append(pred)

    # Augmented predictions
    for _ in range(num_augments - 1):
        # Work on a copy so the caller's batch is left untouched
        aug_x = x.clone()

        # Random horizontal flip (applied to the whole batch at once)
        if torch.rand(1) < 0.5:
            aug_x = torch.flip(aug_x, dims=[3])

        # Random additive noise (no geometric transform is applied here)
        if torch.rand(1) < 0.7:
            # Small Gaussian noise, then clamp back to the [-1, 1] range
            noise = torch.randn_like(aug_x) * 0.02
            aug_x = aug_x + noise
            aug_x = torch.clamp(aug_x, -1, 1)

        with autocast('cuda'):
            pred = F.softmax(model(aug_x), dim=1)
            predictions.append(pred)

    return torch.stack(predictions).mean(0)

# ============================================================================
# ADVANCED TRAINING WITH KNOWLEDGE DISTILLATION
# ============================================================================

class KnowledgeDistillationLoss(nn.Module):
    """Blend of hard-label cross-entropy and temperature-softened distillation.

    The returned value is
    ``alpha * CE(student, labels) + (1 - alpha) * T**2 * KL(teacher_T || student_T)``
    where the ``_T`` distributions are softmaxes of the logits divided by the
    temperature ``T`` and the KL term uses ``batchmean`` reduction.
    """
    def __init__(self, temperature=4.0, alpha=0.3):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.ce_loss = nn.CrossEntropyLoss()
        self.kl_loss = nn.KLDivLoss(reduction='batchmean')

    def forward(self, student_logits, teacher_logits, labels):
        temp = self.temperature

        # Hard-label term on the unscaled student logits.
        hard_term = self.ce_loss(student_logits, labels)

        # Soft-label term: KLDivLoss expects log-probabilities as input and
        # probabilities as target; T**2 restores the gradient magnitude.
        soft_targets = F.softmax(teacher_logits / temp, dim=1)
        soft_preds = F.log_softmax(student_logits / temp, dim=1)
        soft_term = self.kl_loss(soft_preds, soft_targets) * (temp ** 2)

        return self.alpha * hard_term + (1 - self.alpha) * soft_term

def train_advanced_model(model: nn.Module, train_loader: DataLoader, criterion: nn.Module,
                        test_loader: DataLoader, optimizer: optim.Optimizer, scheduler,
                        num_epochs: int = 200, verbose: bool = True, teacher_model=None) -> float:
    """Train ``model`` with mixed precision, MixUp/CutMix and optional distillation.

    Per batch, once ``epoch > 30``, MixUp is used with probability 0.3 and
    CutMix with probability 0.3; otherwise the batch is used unchanged. On
    unchanged batches, once ``epoch > 50`` and a teacher is given, the loss
    becomes ``0.7 * criterion + 0.3 * KnowledgeDistillationLoss``. Gradients
    are clipped to a global norm of 1.0.

    The model is evaluated on ``test_loader`` every 5 epochs and after the
    final epoch; training stops early after 30 epochs without a new best
    accuracy. The model keeps the weights of the last trained epoch, which
    are not necessarily those that produced the returned accuracy.

    Args:
        model: Network to train; moved to the module-level ``device``.
        train_loader: Yields ``(inputs, labels)`` training batches.
        criterion: Loss applied to logits and labels.
        test_loader: Loader used for periodic evaluation and early stopping.
        optimizer: Optimizer bound to ``model``'s parameters.
        scheduler: Learning-rate scheduler. A ``OneCycleLR`` whose
            ``total_steps`` covers ``num_epochs * len(train_loader)`` is
            stepped after every batch; any other scheduler is stepped once
            per epoch.
        num_epochs: Maximum number of epochs.
        verbose: Print a progress line at every evaluation.
        teacher_model: Optional frozen teacher for knowledge distillation.

    Returns:
        Best test accuracy (percent) seen at any evaluation; 0.0 if none ran.
    """
    eval_interval = 5   # epochs between evaluations
    max_patience = 30   # epochs without improvement before stopping

    model = model.to(device)

    # AMP only makes sense on CUDA; disabling it explicitly avoids the
    # "CUDA is not available" warnings on the CPU fallback path.
    use_amp = device.type == 'cuda'
    scaler = GradScaler('cuda', enabled=use_amp)
    mixup = MixUp(alpha=0.4)
    cutmix = CutMix(alpha=1.0)

    kd_criterion = None
    if teacher_model is not None:
        teacher_model.eval()
        kd_criterion = KnowledgeDistillationLoss()

    # OneCycleLR sized in optimizer steps must advance once per batch,
    # otherwise it never leaves the first few steps of its warm-up. The size
    # check keeps epoch-sized OneCycleLR schedules on the per-epoch path so
    # they cannot overrun their total_steps.
    step_per_batch = (
        isinstance(scheduler, optim.lr_scheduler.OneCycleLR)
        and scheduler.total_steps >= num_epochs * len(train_loader)
    )

    best_accuracy = 0.0
    patience_counter = 0
    last_eval_epoch = 0

    for epoch in range(1, num_epochs + 1):
        model.train()
        total_loss = 0.0

        for batch_idx, (batch_x, batch_y) in enumerate(train_loader):
            batch_x = batch_x.to(device, non_blocking=True)
            batch_y = batch_y.to(device, non_blocking=True)

            optimizer.zero_grad()

            with autocast('cuda', enabled=use_amp):
                # Decide augmentation strategy
                aug_prob = np.random.random()

                if epoch > 30 and aug_prob < 0.3:
                    # MixUp
                    mixed_x, y_a, y_b, lam = mixup(batch_x, batch_y)
                    outputs = model(mixed_x)
                    loss = mixup_cutmix_criterion(criterion, outputs, y_a, y_b, lam)
                elif epoch > 30 and aug_prob < 0.6:
                    # CutMix
                    mixed_x, y_a, y_b, lam = cutmix(batch_x, batch_y)
                    outputs = model(mixed_x)
                    loss = mixup_cutmix_criterion(criterion, outputs, y_a, y_b, lam)
                else:
                    # Standard training
                    outputs = model(batch_x)
                    loss = criterion(outputs, batch_y)

                    # Knowledge distillation if teacher available
                    if kd_criterion is not None and epoch > 50:
                        with torch.no_grad():
                            teacher_outputs = teacher_model(batch_x)
                        kd_loss = kd_criterion(outputs, teacher_outputs, batch_y)
                        loss = 0.7 * loss + 0.3 * kd_loss

            scaler.scale(loss).backward()
            # Unscale before clipping so max_norm applies to the true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()

            if step_per_batch:
                scheduler.step()

            total_loss += loss.item()

            if torch.cuda.is_available() and batch_idx % 50 == 0:
                torch.cuda.empty_cache()

        if not step_per_batch:
            scheduler.step()
        avg_loss = total_loss / max(len(train_loader), 1)

        # Also evaluate after the final epoch so runs whose length is not a
        # multiple of the interval still report an accuracy.
        if epoch % eval_interval == 0 or epoch == num_epochs:
            test_loss, test_acc = evaluate_advanced_model(model, test_loader, criterion)

            if verbose:
                current_time = datetime.now().strftime("%B %d, %Y at %I:%M:%S %p")
                print(f"Epoch {epoch}/{num_epochs}: Loss = {avg_loss:.4f} | "
                      f"Test Acc = {test_acc:.2f}% | LR = {scheduler.get_last_lr()[0]:.6f} | Time: {current_time}")
                print_gpu_utilization()

            if test_acc > best_accuracy:
                best_accuracy = test_acc
                patience_counter = 0
            else:
                # Count the epochs elapsed since the previous evaluation.
                patience_counter += epoch - last_eval_epoch
            last_eval_epoch = epoch

            if patience_counter >= max_patience:
                print(f"Early stopping at epoch {epoch}")
                break

    return best_accuracy

@torch.no_grad()
def evaluate_advanced_model(model: nn.Module, test_loader: DataLoader, criterion: nn.Module,
                           use_tta: bool = True) -> Tuple[float, float]:
    """Compute mean loss and accuracy of ``model`` over ``test_loader``.

    With ``use_tta`` the predictions come from ``test_time_augmentation``
    using ``TTA_AUGMENTS`` passes, and the loss is computed on the logarithm
    of the averaged probabilities. Without it a single forward pass supplies
    both the predictions and the logits used for the loss.

    Args:
        model: Network to evaluate; switched to eval mode.
        test_loader: Yields ``(inputs, labels)`` batches.
        criterion: Loss applied to logits (or log-probabilities) and labels.
        use_tta: Whether to use test-time augmentation.

    Returns:
        ``(avg_loss, accuracy)`` where ``avg_loss`` is the mean of the
        per-batch losses and ``accuracy`` is a percentage. ``(0.0, 0.0)``
        when the loader yields no samples.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    for batch_x, batch_y in test_loader:
        batch_x = batch_x.to(device, non_blocking=True)
        batch_y = batch_y.to(device, non_blocking=True)

        if use_tta:
            outputs = test_time_augmentation(model, batch_x, TTA_AUGMENTS)
            # TTA yields averaged probabilities; their log serves as logits
            # for the criterion (epsilon guards against log(0)).
            raw_outputs = torch.log(outputs + 1e-8)
        else:
            # One forward pass feeds both the loss and the predictions.
            with autocast('cuda', enabled=batch_x.is_cuda):
                logits = model(batch_x)
            # Upcast so the loss and softmax run in full precision even when
            # autocast produced half-precision logits.
            raw_outputs = logits.float()
            outputs = F.softmax(raw_outputs, dim=1)

        loss = criterion(raw_outputs, batch_y)
        total_loss += loss.item()
        num_batches += 1

        preds = outputs.argmax(dim=1)
        correct += (preds == batch_y).sum().item()
        total += batch_y.size(0)

    if total == 0:
        # Nothing to evaluate; avoid dividing by zero.
        return 0.0, 0.0

    avg_loss = total_loss / num_batches
    accuracy = 100.0 * correct / total
    return avg_loss, accuracy

# ============================================================================
# ADVANCED ENSEMBLE SYSTEM
# ============================================================================

def train_diverse_ensemble(input_dim, hidden_dim, n_classes, train_loader, test_loader,
                          ensemble_size=7) -> Tuple[List[nn.Module], List[float]]:
    """Train several differently configured models one after another.

    At most ``len(model_configs)`` (seven) members are trained. Member ``i``
    differs from the others in architecture, hidden width
    (``hidden_dim + i * 128`` before any per-config offset), dropout,
    optimizer (AdamW for even ``i``, Adam for odd ``i``), learning-rate
    schedule (chosen by ``i % 3``) and label smoothing.

    From the second member on, the most recently trained member acts as a
    distillation teacher if its best accuracy exceeded 94%.

    Args:
        input_dim: Forwarded to each model constructor.
        hidden_dim: Base hidden width.
        n_classes: Number of output classes.
        train_loader: Training batches; its length sizes the OneCycleLR schedule.
        test_loader: Used by the training routine for periodic evaluation.
        ensemble_size: Requested number of members (capped at seven).

    Returns:
        ``(models, best_accuracies)`` in training order.
    """
    models = []
    best_accuracies = []

    # Define diverse architectures. Each entry is (factory, display name);
    # factories are called as factory(input_dim, hidden, n_classes, dropout).
    model_configs = [
        (HybridFashionNet, "Hybrid CNN-Transformer"),
        (EfficientFashionNet, "EfficientNet-style"),
        (lambda input_dim, hidden_dim, n_classes, dropout_rate:
         HybridFashionNet(input_dim, hidden_dim, n_classes, dropout_rate), "Hybrid Variant 1"),
        (lambda input_dim, hidden_dim, n_classes, dropout_rate:
         EfficientFashionNet(input_dim, hidden_dim + 256, n_classes, dropout_rate), "Large EfficientNet"),
        (lambda input_dim, hidden_dim, n_classes, dropout_rate:
         HybridFashionNet(input_dim, hidden_dim + 512, n_classes, dropout_rate * 0.8), "Large Hybrid"),
        (lambda input_dim, hidden_dim, n_classes, dropout_rate:
         EfficientFashionNet(input_dim, hidden_dim, n_classes, dropout_rate * 1.2), "Regularized EfficientNet"),
        (lambda input_dim, hidden_dim, n_classes, dropout_rate:
         HybridFashionNet(input_dim, hidden_dim, n_classes, dropout_rate * 1.1), "Regularized Hybrid"),
    ]

    # Number of members actually trained; also used in the progress banner so
    # the reported total is correct when ensemble_size exceeds the config list.
    n_models = min(ensemble_size, len(model_configs))

    for i in range(n_models):
        model_class, arch_name = model_configs[i]

        print(f"\n{'='*70}")
        print(f"Training Ensemble Model {i+1}/{n_models}: {arch_name}")
        print(f"{'='*70}")

        model_hidden = hidden_dim + (i * 128)
        model_dropout = DROPOUT_RATE + i * 0.02

        model = model_class(input_dim, model_hidden, n_classes, model_dropout)
        model = model.to(device)

        # Varied optimizers and schedules
        if i % 2 == 0:
            optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE * (0.8 + i*0.05),
                                   weight_decay=weight_decay_ * (1 + i*0.1), betas=(0.9, 0.999))
        else:
            optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE * (0.9 + i*0.05),
                                  weight_decay=weight_decay_ * (1 + i*0.1))

        # Different schedulers for diversity
        if i % 3 == 0:
            scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer, T_0=EPOCHS_PER_DATASET//4, eta_min=1e-7)
        elif i % 3 == 1:
            # NOTE: this schedule spans EPOCHS_PER_DATASET * len(train_loader)
            # steps, i.e. it is sized for one scheduler step per batch.
            scheduler = optim.lr_scheduler.OneCycleLR(
                optimizer, max_lr=LEARNING_RATE * 2, epochs=EPOCHS_PER_DATASET,
                steps_per_epoch=len(train_loader))
        else:
            scheduler = optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=EPOCHS_PER_DATASET, eta_min=1e-7)

        criterion = nn.CrossEntropyLoss(label_smoothing=0.1 + i*0.02)

        # Distil from the most recently trained member (not necessarily the
        # most accurate one) when its best accuracy exceeded 94%.
        teacher_model = models[-1] if models and best_accuracies[-1] > 94.0 else None

        best_acc = train_advanced_model(
            model, train_loader, criterion, test_loader, optimizer, scheduler,
            num_epochs=EPOCHS_PER_DATASET, teacher_model=teacher_model
        )

        models.append(model)
        best_accuracies.append(best_acc)

        print(f"Model {i+1} ({arch_name}) Best Accuracy: {best_acc:.2f}%")

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return models, best_accuracies

@torch.no_grad()
def advanced_ensemble_predict(models: List[nn.Module], test_loader: DataLoader,
                            use_tta: bool = True) -> float:
    """Return the accuracy (percent) of an equal-weight probability ensemble.

    For every batch each model produces class probabilities, either through
    ``test_time_augmentation`` with ``TTA_AUGMENTS`` passes or through a
    single forward pass. The probabilities are averaged with equal weights
    and the arg-max of the average is taken as the prediction.

    Args:
        models: Ensemble members; each is switched to eval mode.
        test_loader: Yields ``(inputs, labels)`` batches.
        use_tta: Whether each member uses test-time augmentation.

    Returns:
        Accuracy in percent, or 0.0 when the loader yields no samples.

    Raises:
        ValueError: If ``models`` is empty.
    """
    if not models:
        raise ValueError("advanced_ensemble_predict requires at least one model")

    for model in models:
        model.eval()

    correct = 0
    total = 0

    for batch_x, batch_y in test_loader:
        batch_x = batch_x.to(device, non_blocking=True)
        batch_y = batch_y.to(device, non_blocking=True)

        ensemble_outputs = []
        for model in models:
            if use_tta:
                outputs = test_time_augmentation(model, batch_x, TTA_AUGMENTS)
            else:
                # Autocast only when the batch is on a CUDA device, so the
                # CPU fallback does not emit a warning per batch and model.
                with autocast('cuda', enabled=batch_x.is_cuda):
                    outputs = F.softmax(model(batch_x), dim=1)
            ensemble_outputs.append(outputs)

        # Equal-weight average of the member probabilities.
        avg_outputs = torch.stack(ensemble_outputs).mean(0)
        preds = avg_outputs.argmax(dim=1)

        correct += (preds == batch_y).sum().item()
        total += batch_y.size(0)

    if total == 0:
        # Empty loader: report 0% instead of dividing by zero.
        return 0.0

    return 100.0 * correct / total

# ============================================================================
# MAIN EXPERIMENT PIPELINE FOR 99% TARGET
# ============================================================================

def run_advanced_experiments():
    """Train and report an ensemble for every dataset in ``DATASET_NAMES``.

    For each dataset this builds the loaders, trains the ensemble, prints the
    per-model accuracies, evaluates the ensemble without and with test-time
    augmentation, and prints a verdict based on the best of those numbers.
    """
    wide_rule = '=' * 90
    rule = '=' * 70

    # Verdict templates, checked from the highest threshold downwards.
    verdicts = (
        (99.0, "🎉🎉 OUTSTANDING! {acc:.2f}% ≥ 99% TARGET ACHIEVED!"),
        (97.0, "🎉 EXCELLENT! {acc:.2f}% ≥ 97%"),
        (95.0, "✅ VERY GOOD! {acc:.2f}% ≥ 95%"),
        (92.0, "🟡 GOOD! {acc:.2f}% ≥ 92%"),
    )
    fallback_verdict = "⚠️ NEEDS IMPROVEMENT: {acc:.2f}%"

    def section(title):
        # Print a 70-column section header preceded by a blank line.
        print(f"\n{rule}")
        print(title)
        print(rule)

    for dataset_name in DATASET_NAMES:
        print(f"\n{wide_rule}")
        print(f"RUNNING ADVANCED SOTA EXPERIMENTS ON {dataset_name}")
        print("TARGET: 99% ACCURACY WITH ADVANCED TECHNIQUES")
        print(wide_rule)

        loaders = get_advanced_data_loaders(dataset_name, BATCH_SIZE)
        train_loader, test_loader, input_dim, n_classes = loaders
        print(f"Dataset: {dataset_name} | Input: {input_dim} | Classes: {n_classes}")
        print(f"Train: {len(train_loader.dataset)} | Test: {len(test_loader.dataset)}")
        print(f"Advanced Augmentation: Enabled | TTA: {TTA_AUGMENTS} augments")
        print_gpu_utilization()

        models, individual_accuracies = train_diverse_ensemble(
            input_dim, HIDDEN_DIM, n_classes, train_loader, test_loader, ENSEMBLE_SIZE
        )

        section("INDIVIDUAL MODEL RESULTS")
        for position, acc in enumerate(individual_accuracies, start=1):
            print(f"Model {position}: {acc:.2f}%")

        section("ENSEMBLE RESULTS WITH TTA")

        # Ensemble without TTA
        ensemble_acc_no_tta = advanced_ensemble_predict(models, test_loader, use_tta=False)
        print(f"Ensemble Accuracy (no TTA): {ensemble_acc_no_tta:.2f}%")

        # Ensemble with TTA
        ensemble_acc_tta = advanced_ensemble_predict(models, test_loader, use_tta=True)
        print(f"Ensemble Accuracy (with TTA): {ensemble_acc_tta:.2f}%")

        best_single = max(individual_accuracies)
        final_accuracy = max(best_single, ensemble_acc_no_tta, ensemble_acc_tta)

        section("FINAL ADVANCED SOTA RESULTS")
        print(f"Best Individual Model: {best_single:.2f}%")
        print(f"Best Ensemble (no TTA): {ensemble_acc_no_tta:.2f}%")
        print(f"Best Ensemble (with TTA): {ensemble_acc_tta:.2f}%")
        print(f"FINAL BEST ACCURACY: {final_accuracy:.2f}%")

        verdict = next(
            (template for floor, template in verdicts if final_accuracy >= floor),
            fallback_verdict,
        )
        print(verdict.format(acc=final_accuracy))

def main():
    """Print the run configuration, seed the RNGs and launch the experiments."""
    seed = 42
    rule = "=" * 90

    summary = (
        rule,
        "ADVANCED SOTA FASHIONMNIST IMPLEMENTATION FOR 99% TARGET",
        "FEATURES: Hybrid CNN-Transformer, Advanced Augmentation, TTA, Knowledge Distillation",
        rule,
        f"Device: {device}",
        f"Ensemble Size: {ENSEMBLE_SIZE}",
        f"Epochs per Model: {EPOCHS_PER_DATASET}",
        f"Hidden Dim: {HIDDEN_DIM}",
        f"Batch Size: {BATCH_SIZE}",
        "Mixed Precision: Enabled",
        f"Test-Time Augmentation: {TTA_AUGMENTS} augments",
        "Advanced Features: Hybrid Architecture, Stochastic Depth, Knowledge Distillation",
    )
    for line in summary:
        print(line)

    # Seed the torch, NumPy and stdlib generators, in that order.
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Let cuDNN autotune kernels; this trades determinism for speed.
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False

    run_advanced_experiments()
    print("\n🎉 ADVANCED SOTA EXPERIMENTS COMPLETED!")

