In [7]:
# Advanced PyTorch ResNet-50 Training Pipeline (From-scratch, >78% recipe)
# -----------------------------------------------------------
# Pure PyTorch implementation with full recipe to reach >78% Top-1 on ImageNet-100
# Features included:
# - Fully customizable ResNet-50 (all classes visible and editable)
# - Stochastic Depth (DropPath) in Bottleneck blocks
# - Strong augmentations: RandomResizedCrop, RandAugment, MixUp, CutMix, ColorJitter, RandomErasing
# - Label smoothing
# - AMP (torch.amp) mixed precision training
# - Gradient accumulation to simulate large global batch
# - EMA (Exponential Moving Average) of model weights
# - Cosine LR scheduler with linear warmup
# - Checkpointing and resume support
# - Designed to run on single GPU (RTX 3070 / Colab Free) with sensible defaults

# %%
# 0) Install dependencies (run in Colab/local once)
!pip install --upgrade pip
!pip install torch torchvision datasets --quiet

# %%
# 1) Imports
import os
import math
import time
from pathlib import Path
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from datasets import load_dataset

# %%
# 2) Configuration / Hyperparameters (tweakable)
import random

CFG = {
    'data_hf_repo': 'clane9/imagenet-100',
    'data_dir': './imagenet100_hf',
    'num_classes': 100,
    'img_size': 224,
    'batch_size': 32,                # per-step batch size
    'accum_steps': 8,                # accumulation to simulate effective batch (32*8=256)
    'epochs': 120,
    'base_lr': 0.002,                # good starting LR for AdamW with effective batch 256
    'weight_decay': 0.05,
    'opt': 'adamw',                  # 'sgd' or 'adamw'
    'momentum': 0.9,
    'warmup_epochs': 5,
    'min_lr': 1e-6,
    'label_smoothing': 0.1,
    'mixup_alpha': 0.8,
    'cutmix_alpha': 1.0,
    'use_mixup': True,
    'use_cutmix': True,
    'randaugment_n': 2,
    'randaugment_m': 9,
    'reprob': 0.25,                  # random erase prob
    'drop_path_prob': 0.2,           # stochastic depth max probability
    'ema_decay': 0.9999,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'save_dir': './checkpoints',
    'save_every': 5,
    # DataLoader workers: on Windows workers are spawned and usually cannot import a Dataset
    # class defined inside the notebook, so load in-process there.
    'num_workers': 0 if os.name == 'nt' else 4,
    'seed': 42,                      # seeds python `random` (mixup/cutmix choice) and torch RNGs
}

# Seed the RNGs used for weight init, torchvision augmentations and MixUp/CutMix sampling,
# so runs started from a fresh kernel are repeatable (up to cuDNN non-determinism).
random.seed(CFG['seed'])
torch.manual_seed(CFG['seed'])

os.makedirs(CFG['data_dir'], exist_ok=True)
os.makedirs(CFG['save_dir'], exist_ok=True)

# %%
# 3) Prepare dataset from Hugging Face: export each split to <data_dir>/<split>/<label>/<idx>.jpg
TRAIN_DIR = Path(CFG['data_dir']) / 'train'
VAL_DIR = Path(CFG['data_dir']) / 'val'
# Written into a split directory only after that split has been exported completely.
EXPORT_DONE_MARKER = '.export_done'


def _export_split(split_name, split_data, split_path, jpeg_quality=95):
    """Save every example of one HF split as ``split_path/<label>/<idx>.jpg``.

    Images are converted to RGB (JPEG cannot store alpha or palette modes) and written at
    ``jpeg_quality`` rather than Pillow's default of 75. The done-marker is created last,
    so an interrupted export is redone on the next run instead of being silently reused.
    """
    split_path.mkdir(parents=True, exist_ok=True)
    print(f'Converting {split_name}, num_samples={len(split_data)} ->', split_path)
    for idx, item in enumerate(split_data):
        label_dir = split_path / str(item['label'])
        label_dir.mkdir(parents=True, exist_ok=True)
        img = item['image']
        # ensure PIL Image (in case the feature is not decoded to PIL)
        if not isinstance(img, Image.Image):
            img = Image.fromarray(img)
        img.convert('RGB').save(label_dir / f'{idx}.jpg', quality=jpeg_quality)
    (split_path / EXPORT_DONE_MARKER).touch()


_splits = [('train', TRAIN_DIR), ('validation', VAL_DIR)]
_pending = [(name, path) for name, path in _splits if not (path / EXPORT_DONE_MARKER).exists()]
if _pending:
    # Re-encoding ~130k images is slow, so only load and convert splits that are missing.
    print('Loading HuggingFace dataset...')
    dataset = load_dataset(CFG['data_hf_repo'])
    for split_name, split_path in _pending:
        _export_split(split_name, dataset[split_name], split_path)
    print('Dataset conversion done.')
else:
    print('Dataset already exported; skipping conversion.')


Collecting pip
  Using cached pip-25.2-py3-none-any.whl.metadata (4.7 kB)
Using cached pip-25.2-py3-none-any.whl (1.8 MB)



[notice] A new release of pip is available: 25.1.1 -> 25.2
[notice] To update, run: python.exe -m pip install --upgrade pip
ERROR: To modify pip, please run the following command:
E:\ML\ERA\s9\venv_RESNET\Scripts\python.exe -m pip install --upgrade pip

[notice] A new release of pip is available: 25.1.1 -> 25.2
[notice] To update, run: python.exe -m pip install --upgrade pip


Loading HuggingFace dataset...


Resolving data files:   0%|          | 0/17 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/17 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/17 [00:00<?, ?it/s]

Converting train, num_samples=126689 -> imagenet100_hf\train
Converting validation, num_samples=5000 -> imagenet100_hf\val
Dataset conversion done.


In [9]:
import torch

# Report which accelerator (if any) PyTorch can see before starting long jobs.
if not torch.cuda.is_available():
    print("No GPU available.")
else:
    print(f"GPU Device: {torch.cuda.get_device_name(0)}")

GPU Device: NVIDIA GeForce RTX 3070 Laptop GPU


In [None]:

# %%
# 4) Transforms & Dataloaders
# ImageNet channel statistics shared by both pipelines.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Training: random crop/flip plus colour jitter and RandAugment on the PIL image, then
# tensor conversion, normalisation and RandomErasing (which operates on tensors).
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(CFG['img_size']),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
    transforms.RandAugment(num_ops=CFG['randaugment_n'], magnitude=CFG['randaugment_m']),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    transforms.RandomErasing(p=CFG['reprob']),
])

# Validation: resize the shorter side using the usual 256/224 ratio, then centre-crop.
_val_resize = int(CFG['img_size'] * 256 / 224)
val_transform = transforms.Compose([
    transforms.Resize(_val_resize),
    transforms.CenterCrop(CFG['img_size']),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])

class ImageFolderFromHF(Dataset):
    """Map-style dataset over ``root/<integer label>/<image file>`` directories.

    Only sub-directories whose names are all digits are treated as classes (this skips e.g.
    ``.ipynb_checkpoints``), and only files with an image suffix are used. Classes are ordered
    by numeric label and files by name, so the index -> sample mapping is the same on every
    run and operating system.

    Raises:
        FileNotFoundError: if ``root`` does not exist.
        RuntimeError: if no image files are found under ``root``.
    """

    # Case-insensitive file suffixes accepted as images.
    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')

    def __init__(self, root, transform=None):
        self.root = Path(root)
        self.transform = transform
        self.samples = []
        class_dirs = sorted(
            (p for p in self.root.iterdir() if p.is_dir() and p.name.isdigit()),
            key=lambda p: int(p.name),
        )
        for class_dir in class_dirs:
            label = int(class_dir.name)
            files = sorted(
                p for p in class_dir.iterdir() if p.suffix.lower() in self.IMG_EXTENSIONS
            )
            self.samples.extend((p, label) for p in files)
        if not self.samples:
            raise RuntimeError(f'No images found under {self.root}')

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        # The context manager releases the file handle even if decoding fails.
        with Image.open(path) as raw:
            img = raw.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, label

train_ds = ImageFolderFromHF(TRAIN_DIR, transform=train_transform)
val_ds = ImageFolderFromHF(VAL_DIR, transform=val_transform)

# On Windows, DataLoader workers are spawned and usually cannot import a Dataset class defined
# in the notebook (loading hangs or fails), so default to in-process loading there.
# CFG['num_workers'] overrides the default when present.
_num_workers = CFG.get('num_workers', 0 if os.name == 'nt' else 4)
_loader_kwargs = dict(
    batch_size=CFG['batch_size'],
    num_workers=_num_workers,
    pin_memory=torch.cuda.is_available(),  # pinned memory only helps host -> GPU copies
    persistent_workers=_num_workers > 0,   # keep workers alive between epochs
)
train_loader = DataLoader(train_ds, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **_loader_kwargs)

print('Train samples:', len(train_ds), 'Val samples:', len(val_ds))


In [None]:

# %%
# 5) Utilities: Mixup / CutMix / Label Smoothing
import random
from torch.distributions.beta import Beta

def rand_bbox(size, lam):
    """Sample a CutMix box whose area is roughly ``1 - lam`` of the image.

    Args:
        size: tensor size in NCHW order, i.e. ``(N, C, H, W)``.
        lam: mixing ratio in [0, 1]; each box side is ``sqrt(1 - lam)`` of the image side.

    Returns:
        ``(bbx1, bby1, bbx2, bby2)``. The x coordinates index width (dim 3), and the
        y coordinates index height (dim 2). The box is clipped to the image, so it can be
        smaller than requested near the borders; callers should recompute lam from its area.
    """
    H, W = size[2], size[3]  # NCHW: dim 2 is height, dim 3 is width
    cut_rat = math.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # box centre sampled uniformly over pixel positions [0, W) x [0, H)
    cx = random.randrange(W)
    cy = random.randrange(H)

    bbx1 = max(0, cx - cut_w // 2)
    bby1 = max(0, cy - cut_h // 2)
    bbx2 = min(W, cx + cut_w // 2)
    bby2 = min(H, cy + cut_h // 2)

    return bbx1, bby1, bbx2, bby2


def mixup_data(x, y, alpha=0.8):
    """MixUp: blend every sample with a randomly chosen partner from the same batch.

    ``lam`` is drawn from Beta(alpha, alpha); it is fixed to 1 (no mixing) when ``alpha <= 0``.

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` where ``mixed_x = lam * x + (1 - lam) * x[perm]``,
        ``y_a = y`` and ``y_b = y[perm]``.
    """
    lam = Beta(alpha, alpha).sample().item() if alpha > 0 else 1
    perm = torch.randperm(x.size(0)).to(x.device)
    partner_x = x[perm, :]
    mixed_x = lam * x + (1 - lam) * partner_x
    return mixed_x, y, y[perm], lam


def cutmix_data(x, y, alpha=1.0):
    """CutMix: paste a random box from a shuffled partner batch into ``x`` (modified in place).

    ``lam`` is first drawn from Beta(alpha, alpha) (1 when ``alpha <= 0``) to size the box.
    It is then recomputed from the pasted box's real area, so it equals the fraction of each
    image that still comes from the original sample.

    Returns:
        ``(x, y_a, y_b, lam)`` with ``y_a = y`` and ``y_b = y[perm]``.
    """
    lam = Beta(alpha, alpha).sample().item() if alpha > 0 else 1
    perm = torch.randperm(x.size(0)).to(x.device)
    left, top, right, bottom = rand_bbox(x.size(), lam)
    x[:, :, top:bottom, left:right] = x[perm, :, top:bottom, left:right]

    height, width = x.size()[-2], x.size()[-1]
    box_area = (right - left) * (bottom - top)
    lam = 1 - (box_area / (width * height))
    return x, y, y[perm], lam

class LabelSmoothingLoss(nn.Module):
    """Cross-entropy against a smoothed one-hot target distribution.

    The true class receives ``1 - smoothing``. The remaining ``smoothing`` mass is spread
    evenly over the other ``classes - 1`` classes. The loss is averaged over the batch.
    """
    def __init__(self, classes, smoothing=0.0):
        super().__init__()
        self.smoothing = smoothing
        self.confidence = 1.0 - smoothing
        self.cls = classes

    def forward(self, x, target):
        log_probs = F.log_softmax(x, dim=-1)
        with torch.no_grad():
            off_value = self.smoothing / (self.cls - 1)
            smoothed = torch.full_like(log_probs, off_value)
            smoothed.scatter_(1, target.unsqueeze(1), self.confidence)
        per_sample = (-smoothed * log_probs).sum(dim=-1)
        return per_sample.mean()


In [None]:

# %%
# 6) Model: ResNet-50 with Stochastic Depth (DropPath)
class DropPath(nn.Module):
    """Drop whole residual branches per sample (Stochastic Depth, https://arxiv.org/abs/1603.09382).

    In training mode, each sample's input is zeroed with probability ``drop_prob``. Kept samples
    are rescaled by ``1 / (1 - drop_prob)`` so the expected value is unchanged. In eval mode,
    or when ``drop_prob == 0``, the input is returned as-is.
    """
    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        # one Bernoulli(keep_prob) draw per sample, broadcast over all remaining dims
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()
        return x.div(keep_prob) * random_tensor


class Bottleneck(nn.Module):
    """ResNet bottleneck block: 1x1 reduce -> 3x3 (carries the stride) -> 1x1 expand (x4).

    The residual branch goes through ``drop_path`` before it is added to the identity path;
    ``downsample`` (if given) projects the identity when shape or channels change.
    """
    expansion = 4

    def __init__(self, in_planes, planes, stride=1, downsample=None, drop_prob=0.0):
        super().__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.drop_path = DropPath(drop_prob) if drop_prob > 0.0 else nn.Identity()

    def forward(self, x):
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        if self.downsample is not None:
            identity = self.downsample(x)
        out = self.drop_path(out)
        out += identity
        return self.relu(out)


class ResNet50Custom(nn.Module):
    """Bottleneck ResNet (ResNet-50 by default) with Stochastic Depth.

    Args:
        num_classes: output size of the final linear layer.
        drop_path_prob: drop probability of the deepest residual block. Probabilities grow
            linearly with block index across the whole network, from 0 for the first block
            to ``drop_path_prob`` for the last one.
        layers: number of bottleneck blocks in each of the 4 stages; ``(3, 4, 6, 3)`` is ResNet-50.
        zero_init_residual: zero the last BatchNorm scale of every block so each residual
            branch starts as an identity mapping (torchvision's ``zero_init_residual`` trick).

    Raises:
        ValueError: if ``layers`` does not have exactly 4 entries.
    """
    def __init__(self, num_classes=100, drop_path_prob=0.0, layers=(3, 4, 6, 3), zero_init_residual=True):
        super().__init__()
        if len(layers) != 4:
            raise ValueError(f'layers must list 4 stage depths, got {layers!r}')
        self.in_planes = 64
        # Network-wide block counter so the drop-path schedule spans all stages, not each stage.
        self._num_blocks = sum(layers)
        self._block_idx = 0
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(64, layers[0], drop_path_prob)
        self.layer2 = self._make_layer(128, layers[1], drop_path_prob, stride=2)
        self.layer3 = self._make_layer(256, layers[2], drop_path_prob, stride=2)
        self.layer4 = self._make_layer(512, layers[3], drop_path_prob, stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * Bottleneck.expansion, num_classes)
        self._init_weights(zero_init_residual)

    def _make_layer(self, planes, blocks, drop_path_prob, stride=1):
        """Build one stage of ``blocks`` bottlenecks; only the first block strides/projects."""
        downsample = None
        if stride != 1 or self.in_planes != planes * Bottleneck.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_planes, planes * Bottleneck.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * Bottleneck.expansion),
            )
        layers = []
        for i in range(blocks):
            # linear drop-path schedule over the global block index (0 ... num_blocks - 1)
            prob = drop_path_prob * self._block_idx / max(1, self._num_blocks - 1)
            self._block_idx += 1
            if i == 0:
                layers.append(Bottleneck(self.in_planes, planes, stride, downsample, drop_prob=prob))
            else:
                layers.append(Bottleneck(self.in_planes, planes, 1, None, drop_prob=prob))
            self.in_planes = planes * Bottleneck.expansion
        return nn.Sequential(*layers)

    def _init_weights(self, zero_init_residual):
        """He (fan_out) init for convs and unit/zero BatchNorm affine, as in torchvision's ResNet."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.zeros_(m.bn3.weight)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = torch.flatten(self.avgpool(x), 1)
        return self.fc(x)

# %%
# 7) EMA Helper
class ModelEMA:
    """Exponential moving average (EMA) of a model's parameters and buffers.

    ``ema_model`` is a frozen deep copy of the source model. Each ``update`` moves its
    floating-point tensors towards the source by a factor ``1 - decay``. Non-floating
    buffers (e.g. BatchNorm's integer ``num_batches_tracked``) cannot be averaged, because
    the result would be truncated back to an integer. They are copied from the source instead.
    """
    def __init__(self, model, decay=0.9999, device=None):
        self.ema_model = self._clone_model(model)
        self.decay = decay
        self.device = device
        if device:
            self.ema_model.to(device)
        for p in self.ema_model.parameters():
            p.requires_grad_(False)

    def _clone_model(self, model):
        import copy
        return copy.deepcopy(model)

    def update(self, model):
        """Blend ``model`` into the EMA copy: ``ema = decay * ema + (1 - decay) * model``."""
        with torch.no_grad():
            model_state = model.state_dict()
            for k, ema_v in self.ema_model.state_dict().items():
                model_v = model_state[k].detach().to(ema_v.device)
                if ema_v.is_floating_point():
                    # in-place form of the blend above, without temporaries
                    ema_v.mul_(self.decay).add_(model_v, alpha=1. - self.decay)
                else:
                    ema_v.copy_(model_v)

    def state_dict(self):
        return self.ema_model.state_dict()

    def load_state_dict(self, sd):
        self.ema_model.load_state_dict(sd)

# %%


In [None]:
# 8) Create model, optimizer, criterion, scaler, scheduler
device = torch.device(CFG['device'])
model = ResNet50Custom(num_classes=CFG['num_classes'], drop_path_prob=CFG['drop_path_prob']).to(device)

if CFG['opt'] == 'adamw':
    optimizer = optim.AdamW(model.parameters(), lr=CFG['base_lr'], weight_decay=CFG['weight_decay'])
else:
    optimizer = optim.SGD(model.parameters(), lr=CFG['base_lr'], momentum=CFG['momentum'], weight_decay=CFG['weight_decay'])

# label smoothing criterion
criterion = LabelSmoothingLoss(CFG['num_classes'], smoothing=CFG['label_smoothing'])

# Loss scaling is only needed for fp16 autocast on CUDA, so it is disabled on other devices.
# torch.amp.GradScaler replaces the deprecated torch.cuda.amp.GradScaler; older torch falls back.
_use_cuda_amp = device.type == 'cuda'
if hasattr(torch.amp, 'GradScaler'):
    scaler = torch.amp.GradScaler('cuda', enabled=_use_cuda_amp)
else:
    scaler = torch.cuda.amp.GradScaler(enabled=_use_cuda_amp)

# The LR schedule (get_lr) is evaluated once per loader batch (step = epoch * len(train_loader) + i
# in the training loop), so both lengths are counted in loader batches, not optimizer steps.
total_steps = len(train_loader) * CFG['epochs']
warmup_steps = max(1, int(len(train_loader) * CFG['warmup_epochs']))

def get_lr(step):
    """LR multiplier for loader-batch ``step``: linear warmup from 0 to 1, then cosine decay.

    Reads the module-level ``warmup_steps`` and ``total_steps``.
    """
    if step >= warmup_steps:
        decay_steps = max(1, total_steps - warmup_steps)
        progress = float(step - warmup_steps) / float(decay_steps)
        return 0.5 * (1. + math.cos(math.pi * progress))
    return float(step) / float(max(1, warmup_steps))

# EMA (deep copy of the freshly initialised model; replaced below when resuming)
ema = ModelEMA(model, decay=CFG['ema_decay'], device=device)

# optionally load checkpoint
start_epoch = 0
ckpt_path = os.path.join(CFG['save_dir'], 'last_checkpoint.pth')
if os.path.exists(ckpt_path):
    print('Loading checkpoint', ckpt_path)
    ckpt = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(ckpt['model_state'])
    optimizer.load_state_dict(ckpt['opt_state'])
    start_epoch = ckpt.get('epoch', 0) + 1
    # The EMA was cloned from the random init above. Without a saved EMA, restart it from the
    # resumed weights rather than keeping the random ones.
    ema.load_state_dict(ckpt.get('ema_state', model.state_dict()))
    # A disabled GradScaler saves an empty dict, which load_state_dict rejects, so skip it.
    if ckpt.get('scaler_state'):
        scaler.load_state_dict(ckpt['scaler_state'])


In [None]:


# %%
# 9) Training loop (with AMP, accumulation, mixup/cutmix, dynamic LR)
# Phase 2 fixes applied:
# - LR warmup logging & per-step lr set
# - Disable MixUp/CutMix during warmup_epochs
# - EMA decay ramp (start smaller during warmup, then increase)
# - Log both EMA and raw model validation accuracies
# - Ensure correct GradScaler usage order and gradient clipping

def get_ema_decay(epoch, base_decay=CFG['ema_decay'], warmup_epochs=CFG['warmup_epochs']):
    """EMA decay for ``epoch``: rises linearly from 0.9 towards ``base_decay`` during warmup.

    A weaker EMA early on keeps the average from being dominated by noisy initial weights.
    From ``warmup_epochs`` onwards the full ``base_decay`` is used.
    """
    if epoch >= warmup_epochs:
        return base_decay
    start = 0.9
    ramp = epoch / max(1, warmup_epochs)
    return start + (base_decay - start) * ramp


def _sample_mix(images, labels, epoch):
    """Optionally apply MixUp or CutMix to a batch (never during the warmup epochs).

    MixUp is chosen with p=0.5; otherwise CutMix is chosen with p=0.5.
    Returns ``(images, y_a, y_b, lam, mixed)``. When ``mixed`` is False the batch is
    unchanged and ``y_b``/``lam`` are None.
    """
    after_warmup = epoch >= CFG['warmup_epochs']
    use_mix = CFG['use_mixup'] and after_warmup and random.random() < 0.5
    use_cut = CFG['use_cutmix'] and after_warmup and random.random() < 0.5
    if use_mix:
        images, y_a, y_b, lam = mixup_data(images, labels, alpha=CFG['mixup_alpha'])
        return images, y_a, y_b, lam, True
    if use_cut:
        images, y_a, y_b, lam = cutmix_data(images, labels, alpha=CFG['cutmix_alpha'])
        return images, y_a, y_b, lam, True
    return images, labels, None, None, False


def _optimizer_step(epoch):
    """Apply accumulated gradients: unscale, clip, step, clear grads, then update the EMA."""
    scaler.unscale_(optimizer)  # clip true gradient magnitudes, not loss-scaled ones
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
    scaler.step(optimizer)      # when enabled, skips the update if grads contain inf/NaN
    scaler.update()
    optimizer.zero_grad(set_to_none=True)
    ema.decay = get_ema_decay(epoch)
    ema.update(model)


def train_one_epoch(epoch):
    """Train ``model`` for one epoch over ``train_loader``.

    For each micro-batch:
    - set the scheduled LR;
    - optionally apply MixUp/CutMix;
    - run the forward pass under autocast (fp16 on CUDA only);
    - backpropagate the loss divided by ``accum_steps``.

    An optimizer step is taken every ``accum_steps`` micro-batches, plus once at the end of the
    epoch for any remainder, so gradients never leak into the next epoch. The remainder step
    averages over fewer micro-batches and is therefore slightly down-weighted.

    Exceptions from data loading or the model propagate to the caller.

    Returns:
        ``(epoch_loss, epoch_acc)``, or ``(0.0, 0.0)`` if the loader is empty. Accuracy is
        measured against the un-mixed labels, so it is only indicative on MixUp/CutMix batches.
    """
    from tqdm.auto import tqdm

    model.train()
    accum_steps = CFG['accum_steps']
    num_batches = len(train_loader)
    use_amp = device.type == 'cuda'
    running_loss = 0.0
    running_correct = 0
    running_total = 0
    current_lr = optimizer.param_groups[0]['lr']
    # Drop any gradients left behind by an interrupted previous epoch.
    optimizer.zero_grad(set_to_none=True)

    pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{CFG["epochs"]}', ncols=100, mininterval=2.0)
    for i, (images, labels) in enumerate(pbar):
        # LR is set every micro-batch; only the value in effect at optimizer steps matters.
        step = epoch * num_batches + i
        current_lr = max(CFG['min_lr'], CFG['base_lr'] * get_lr(step))
        for param_group in optimizer.param_groups:
            param_group['lr'] = current_lr

        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        images, y_a, y_b, lam, mixed = _sample_mix(images, labels, epoch)

        # autocast needs a concrete device type string; fp16 autocast is enabled on CUDA only.
        with torch.amp.autocast(device_type=device.type, enabled=use_amp):
            outputs = model(images)
            if mixed:
                loss = lam * criterion(outputs, y_a) + (1 - lam) * criterion(outputs, y_b)
            else:
                loss = criterion(outputs, labels)
            loss = loss / accum_steps

        scaler.scale(loss).backward()

        if (i + 1) % accum_steps == 0 or (i + 1) == num_batches:
            _optimizer_step(epoch)

        running_loss += loss.item() * images.size(0) * accum_steps
        running_total += labels.size(0)
        running_correct += (outputs.argmax(1) == labels).sum().item()

        if i % 10 == 0:
            pbar.set_postfix({
                'loss': f'{running_loss/running_total:.4f}',
                'acc': f'{running_correct/running_total:.4f}',
                'lr': f'{current_lr:.6f}'
            })

    if running_total == 0:
        print(f'Epoch {epoch+1}: train_loader yielded no batches.', flush=True)
        return 0.0, 0.0
    epoch_loss = running_loss / running_total
    epoch_acc = running_correct / running_total
    print(f'Epoch {epoch+1} Train Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f} LR: {current_lr:.6f}', flush=True)
    return epoch_loss, epoch_acc


def validate(epoch, use_ema=True):
    """Top-1 accuracy on ``val_loader`` for both the raw model and its EMA copy.

    Both models are always evaluated (for debugging). ``epoch`` and ``use_ema`` are kept for
    call compatibility but are not used. Each evaluated model is left in eval mode.

    Returns:
        ``{'raw': acc, 'ema': acc}`` with accuracies in [0, 1] (0.0 for an empty loader).
    """
    from tqdm.auto import tqdm

    print(f"  Starting validation...", flush=True)
    use_amp = device.type == 'cuda'
    results = {}
    for name, model_to_eval in [('raw', model), ('ema', ema.ema_model)]:
        model_to_eval.eval()
        total = 0
        correct = 0
        with torch.no_grad():
            val_pbar = tqdm(val_loader, desc=f'Validation ({name})', leave=False)
            for images, labels in val_pbar:
                images = images.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)
                # concrete device type for autocast; fp16 only on CUDA
                with torch.amp.autocast(device_type=device.type, enabled=use_amp):
                    outputs = model_to_eval(images)
                total += labels.size(0)
                correct += (outputs.argmax(1) == labels).sum().item()
                val_pbar.set_postfix({'acc': f'{correct/total:.4f}'})
        acc = correct / total if total else 0.0
        print(f'Validation Acc ({name}): {acc:.4f}', flush=True)
        results[name] = acc
    return results


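### PATCH: Fix DataLoader for Windows (run this cell before training!)
On Windows, DataLoader worker processes are started with `spawn`. A `Dataset` class defined inside the notebook usually cannot be imported by those workers, so loading with `num_workers > 0` can hang or fail. The next cell rebuilds both loaders with `num_workers=0`.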


In [None]:
# CRITICAL FIX: Recreate DataLoaders with num_workers=0 for Windows
print("Recreating DataLoaders with num_workers=0 for Windows compatibility...", flush=True)

# Shared settings for both loaders; only shuffling differs.
# num_workers=0 (CHANGED FROM 4 TO 0) loads batches inside the notebook process itself.
_windows_loader_kwargs = {
    'batch_size': CFG['batch_size'],
    'num_workers': 0,
    'pin_memory': True,
}
train_loader = DataLoader(train_ds, shuffle=True, **_windows_loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **_windows_loader_kwargs)

print(f"✓ DataLoaders recreated successfully", flush=True)
print(f"  Train batches: {len(train_loader)}", flush=True)
print(f"  Val batches: {len(val_loader)}", flush=True)
print(f"  Num workers: {train_loader.num_workers}", flush=True)
print("Now run the training cell!", flush=True)


In [None]:

# 10) Run training: train -> validate (raw + EMA) -> checkpoint, once per epoch.
import traceback


def _save_checkpoint(state, path):
    """Write ``state`` to ``path`` atomically (temp file + rename).

    A crash mid-save therefore never leaves a truncated checkpoint at ``path``.
    """
    tmp_path = f'{path}.tmp'
    torch.save(state, tmp_path)
    os.replace(tmp_path, path)


print("=" * 60, flush=True)
print("Training Started - Outputs will appear in real-time", flush=True)
print("=" * 60, flush=True)
print(f"Device: {device}", flush=True)
print(f"Model on device: {next(model.parameters()).device}", flush=True)
print(f"Total epochs to run: {CFG['epochs']}", flush=True)
print(f"Start epoch: {start_epoch}", flush=True)
print(f"Train loader batches: {len(train_loader)}", flush=True)
print(f"Train loader num_workers: {train_loader.num_workers}", flush=True)
print(f"Train dataset length: {len(train_loader.dataset)}", flush=True)

# When resuming (the resume cell loaded `ckpt`), carry over the best accuracy so that a worse
# epoch after the restart cannot overwrite best_checkpoint.pth.
best_val = 0.0
if start_epoch > 0 and 'ckpt' in globals():
    best_val = float(ckpt.get('best_val', 0.0))

for epoch in range(start_epoch, CFG['epochs']):
    print(f"\n>>> Starting Epoch {epoch+1}/{CFG['epochs']} <<<", flush=True)
    tic = time.time()

    try:
        train_one_epoch(epoch)
        print(f">>> Training epoch {epoch+1} completed <<<", flush=True)
    except Exception as e:
        # Stop instead of validating/checkpointing a partially trained epoch.
        print(f"ERROR in train_one_epoch: {e}", flush=True)
        traceback.print_exc()
        break

    try:
        val_results = validate(epoch, use_ema=True)
        print(f">>> Validation completed <<<", flush=True)
    except Exception as e:
        # Keep training; a failed validation (acc 0.0) can never count as a new best.
        print(f"ERROR in validate: {e}", flush=True)
        traceback.print_exc()
        val_results = {'ema': 0.0}
    toc = time.time()
    print(f'Epoch time: {(toc-tic)/60:.2f} mins', flush=True)

    # checkpoint
    val_acc = val_results.get('ema', 0.0)
    is_best = val_acc > best_val
    if is_best:
        best_val = val_acc
    ckpt = {
        'epoch': epoch,
        'model_state': model.state_dict(),
        'opt_state': optimizer.state_dict(),
        'ema_state': ema.state_dict(),
        'scaler_state': scaler.state_dict(),
        'best_val': best_val,
        'cfg': CFG,
    }
    # Refresh the resume point every epoch, so an interruption loses at most one epoch.
    _save_checkpoint(ckpt, os.path.join(CFG['save_dir'], 'last_checkpoint.pth'))
    if is_best:
        _save_checkpoint(ckpt, os.path.join(CFG['save_dir'], 'best_checkpoint.pth'))
    # Numbered snapshots only every `save_every` epochs, to bound disk usage.
    if (epoch + 1) % CFG['save_every'] == 0:
        _save_checkpoint(ckpt, os.path.join(CFG['save_dir'], f'checkpoint_epoch{epoch+1}.pth'))
    print('Saved checkpoint' + (' (new best)' if is_best else ''), flush=True)

print('Training finished. Best Val Acc (ema):', best_val, flush=True)

# End of Phase 2 updates: debug guards applied (EMA ramp, disabled mixup in warmup, raw+ema eval).
