# Deepfake Detection Training Notebook

This notebook trains a LaDeDa ResNet50 model to classify real vs diffusion-generated fake images.

## Features:
- ✅ **Early Stopping** - Stops when validation loss doesn't improve for 4 epochs
- ✅ **Strong Data Augmentation** - Rotation, random crop, grayscale, color jitter
- ✅ **Dropout** - 0.4 probability in classifier head
- ✅ **Weight Decay** - 1e-3 for regularization
- ✅ **Layer Freezing** - Freezes conv1 and layer1–layer3; only layer4 + classifier head are trained (roughly two-thirds of the parameters; the exact count is printed in Section 10)

## 1. Setup and Imports

In [None]:
import os
import sys
import json
import random
import io
import time
from pathlib import Path
from dataclasses import dataclass, field
from typing import Tuple, List, Dict, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.cuda.amp import autocast, GradScaler
from torchvision.models import resnet50, ResNet50_Weights
from PIL import Image, ImageOps, ImageFilter, ImageEnhance
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

# Report the runtime environment so the notebook output records which
# PyTorch build and accelerator were used for this run.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Only query the device name when a CUDA device is actually present.
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Configuration

In [None]:
@dataclass
class TrainingConfig:
    """Hyperparameters and filesystem locations for one training run.

    Attributes:
        data_path: Root directory of the dataset.
        output_dir: Directory that receives checkpoints, plots and results.
        epochs: Upper bound on the number of training epochs.
        batch_size: Number of samples per mini-batch.
        lr: Learning rate handed to the optimizer.
        weight_decay: Weight-decay coefficient handed to the optimizer.
        early_stopping_patience: Consecutive non-improving epochs tolerated
            before training stops.
        freeze_layers: Names of the backbone stages to freeze.
        dropout_rate: Dropout probability used by the model.
        num_workers: Worker processes used by each DataLoader.
        seed: Seed for the random number generators.
    """

    # Filesystem locations
    data_path: str = "./newfinetune2"
    output_dir: str = "./outputs"

    # Optimization (epochs is a ceiling; early stopping may end the run sooner)
    epochs: int = 30
    batch_size: int = 16
    lr: float = 3e-5
    weight_decay: float = 1e-3

    # Early stopping
    early_stopping_patience: int = 4

    # Model (a fresh list per instance, so instances never share state)
    freeze_layers: List[str] = field(
        default_factory=lambda: ["conv1", "layer1", "layer2", "layer3"],
    )
    dropout_rate: float = 0.4

    # Miscellaneous
    num_workers: int = 2
    seed: int = 42


@dataclass
class AugmentConfig:
    """Parameters of the training-time augmentation pipeline.

    The list-valued fields hold the candidate values one of which is picked
    per application, the ``*_range`` fields are (low, high) sampling bounds,
    and every ``p_*`` field is the probability of applying that augmentation.
    """

    # Degradations: candidate settings for each operation
    jpeg_qualities: List[int] = field(
        default_factory=lambda: [30, 50, 75, 95],
    )
    resize_scales: List[float] = field(
        default_factory=lambda: [0.5, 0.75, 1.25, 1.5],
    )
    blur_sigmas: List[float] = field(
        default_factory=lambda: [0.5, 1.0, 1.5],
    )
    noise_sigmas: List[float] = field(
        default_factory=lambda: [3.0, 6.0, 10.0],
    )

    # Color jitter: (low, high) bounds
    brightness_range: Tuple[float, float] = (0.7, 1.3)
    contrast_range: Tuple[float, float] = (0.7, 1.3)
    saturation_range: Tuple[float, float] = (0.7, 1.3)
    hue_range: Tuple[float, float] = (-0.1, 0.1)

    # Geometry: rotation angle bounds and random-resized-crop bounds
    rotation_range: Tuple[float, float] = (-15, 15)
    random_crop_scale: Tuple[float, float] = (0.8, 1.0)
    random_crop_ratio: Tuple[float, float] = (0.9, 1.1)

    # Application probabilities
    p_jpeg: float = 0.5
    p_resize: float = 0.3
    p_blur: float = 0.3
    p_noise: float = 0.3
    p_color: float = 0.5
    p_flip: float = 0.5
    p_rotation: float = 0.4
    p_grayscale: float = 0.1
    p_random_crop: float = 0.3
    p_random_crop: float = 0.3


# Initialize configs
# Module-level instances that the later notebook cells read from.
config = TrainingConfig()
augment_config = AugmentConfig()

# Constants
# Every image is resized to TARGET_SIZE before it is converted to a tensor.
TARGET_SIZE = (256, 256)
# Per-channel mean / std used to normalize pixel values after they have been
# scaled to [0, 1].
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

# Echo the key hyperparameters so they are recorded in the notebook output.
print("Configuration loaded:")
print(f"  - Epochs: {config.epochs} (with early stopping patience={config.early_stopping_patience})")
print(f"  - Batch size: {config.batch_size}")
print(f"  - Learning rate: {config.lr}")
print(f"  - Weight decay: {config.weight_decay}")
print(f"  - Dropout rate: {config.dropout_rate}")
print(f"  - Frozen layers: {config.freeze_layers}")

## 3. Set Random Seed

In [None]:
def set_seed(seed: int = 42, deterministic: bool = False):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value used to seed ``random``, ``numpy.random`` and PyTorch
            (CPU and every CUDA device).
        deterministic: When False (the default, matching the previous
            behaviour) cuDNN autotuning is enabled and deterministic kernels
            are not requested, so GPU runs are seeded but not guaranteed to
            be bit-for-bit repeatable. When True, autotuning is disabled and
            deterministic cuDNN kernels are requested, trading speed for
            repeatability.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # The two cuDNN switches are set as a pair: autotuning (benchmark) may
    # pick different kernels from run to run, which defeats determinism.
    torch.backends.cudnn.benchmark = not deterministic
    torch.backends.cudnn.deterministic = deterministic

# Seed all random number generators once, before any data or model is created.
set_seed(config.seed)
print(f"Random seed set to {config.seed}")

## 4. Data Loading Functions

In [None]:
@dataclass
class ImageMetadata:
    """Record describing one dataset image.

    Attributes:
        path: Location of the image file.
        label: Class index, 0 for real and 1 for fake.
        image_id: Identifier of the image (the file name without its suffix
            when the record is built through ``from_path``).
    """

    path: Path
    label: int
    image_id: str

    @classmethod
    def from_path(cls, path: Path, label: int):
        """Build a record whose ``image_id`` is the stem of ``path``."""
        stem = path.stem
        return cls(path, label, stem)


def load_split(data_root: Path, split: str) -> List[ImageMetadata]:
    """Collect the labelled images of one dataset split.

    Looks for ``<data_root>/<split>/real`` (label 0) and
    ``<data_root>/<split>/fake`` (label 1) and returns one record per image
    file found directly inside them. A missing class directory contributes
    no images.

    Files are selected by suffix, compared case-insensitively, so each file
    is returned exactly once regardless of filesystem case sensitivity, and
    mixed-case suffixes such as ``.Jpg`` are accepted. Within each class the
    files are sorted by path so the result does not depend on directory
    listing order.

    Args:
        data_root: Root directory of the dataset.
        split: Name of the split sub-directory.

    Returns:
        Records for all real images followed by all fake images.
    """
    image_suffixes = {".jpg", ".jpeg", ".png"}
    images: List[ImageMetadata] = []
    split_dir = Path(data_root) / split

    for label, class_name in ((0, "real"), (1, "fake")):
        class_dir = split_dir / class_name
        if not class_dir.is_dir():
            continue

        image_paths = sorted(
            candidate
            for candidate in class_dir.iterdir()
            if candidate.is_file() and candidate.suffix.lower() in image_suffixes
        )
        images.extend(ImageMetadata.from_path(path, label) for path in image_paths)

    return images

# Progress marker: confirms the data-loading cell ran to completion.
print("Data loading functions defined.")

## 5. Augmentation Pipeline

In [None]:
def jpeg_compress(image: Image.Image, quality: int) -> Image.Image:
    """Return ``image`` after an in-memory JPEG encode/decode round trip.

    Non-RGB inputs are converted to RGB first. Encoding uses the given
    ``quality`` with ``subsampling=0``.
    """
    rgb = image if image.mode == "RGB" else image.convert("RGB")
    with io.BytesIO() as stream:
        rgb.save(stream, format="JPEG", quality=quality, subsampling=0)
        stream.seek(0)
        decoded = Image.open(stream)
        # Decode the pixels now; the stream is closed when this block exits.
        decoded.load()
    return decoded


def resize_chain(image: Image.Image, scale: float) -> Image.Image:
    """Resize ``image`` by ``scale`` and back to its original size.

    Both steps use bilinear resampling; the intermediate size is at least
    one pixel per side.
    """
    original_size = image.size
    scaled_size = tuple(max(1, int(side * scale)) for side in original_size)
    bilinear = Image.Resampling.BILINEAR
    return image.resize(scaled_size, bilinear).resize(original_size, bilinear)


def gaussian_blur(image: Image.Image, sigma: float) -> Image.Image:
    """Return ``image`` filtered with a Gaussian blur of radius ``sigma``."""
    blur_filter = ImageFilter.GaussianBlur(radius=sigma)
    return image.filter(blur_filter)


def add_sensor_noise(image: Image.Image, sigma: float) -> Image.Image:
    """Add zero-mean Gaussian noise with standard deviation ``sigma``.

    The noise is drawn from NumPy's global random generator, added to the
    pixel values, and the result is clipped to [0, 255] and cast to uint8.
    """
    pixels = np.asarray(image, dtype=np.float32)
    perturbed = pixels + np.random.normal(0, sigma, pixels.shape)
    return Image.fromarray(np.clip(perturbed, 0, 255).astype(np.uint8))


def adjust_brightness_contrast(image: Image.Image, brightness: float, contrast: float) -> Image.Image:
    """Scale pixel values by ``brightness``, then stretch them by ``contrast``.

    The contrast stretch is applied around the mean of the brightness-scaled
    pixels; the result is clipped to [0, 255] and cast to uint8.
    """
    scaled = np.asarray(image, dtype=np.float32) * brightness
    pivot = scaled.mean()
    stretched = (scaled - pivot) * contrast + pivot
    as_bytes = np.clip(stretched, 0, 255).astype(np.uint8)
    return Image.fromarray(as_bytes)


def adjust_saturation(image: Image.Image, factor: float) -> Image.Image:
    """Return ``image`` with PIL's color enhancer applied at ``factor``."""
    return ImageEnhance.Color(image).enhance(factor)


def adjust_hue(image: Image.Image, factor: float) -> Image.Image:
    """Shift the hue of ``image`` with a linear transform of its RGB channels.

    ``factor`` is multiplied by pi to obtain the rotation angle in radians.
    Non-RGB inputs are converted to RGB first. The result is clipped to
    [0, 255] and returned as a uint8 image.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    # Work on channel values scaled to [0, 1].
    arr = np.asarray(image, dtype=np.float32) / 255.0
    cos_h = np.cos(factor * np.pi)
    sin_h = np.sin(factor * np.pi)
    r, g, b = arr[:,:,0], arr[:,:,1], arr[:,:,2]
    # Each output channel is a weighted sum of the input channels whose
    # weights depend on cos/sin of the angle.
    # NOTE(review): the coefficients look like the usual YIQ-space hue
    # rotation matrix -- confirm against a reference before changing them.
    new_r = r * (0.299 + 0.701*cos_h + 0.168*sin_h) + g * (0.587 - 0.587*cos_h + 0.330*sin_h) + b * (0.114 - 0.114*cos_h - 0.497*sin_h)
    new_g = r * (0.299 - 0.299*cos_h - 0.328*sin_h) + g * (0.587 + 0.413*cos_h + 0.035*sin_h) + b * (0.114 - 0.114*cos_h + 0.292*sin_h)
    new_b = r * (0.299 - 0.300*cos_h + 1.250*sin_h) + g * (0.587 - 0.588*cos_h - 1.050*sin_h) + b * (0.114 + 0.886*cos_h - 0.203*sin_h)
    result = np.stack([new_r, new_g, new_b], axis=2)
    # Back to the 0-255 range; the uint8 cast truncates toward zero.
    result = np.clip(result * 255, 0, 255).astype(np.uint8)
    return Image.fromarray(result)


def random_rotation(image: Image.Image, angle: float) -> Image.Image:
    """Rotate ``image`` by ``angle`` degrees without changing its size.

    The angle is supplied by the caller; this function draws no random
    numbers itself. Uncovered corners are filled with mid-gray.
    """
    mid_gray = (128, 128, 128)
    return image.rotate(
        angle,
        resample=Image.Resampling.BILINEAR,
        expand=False,
        fillcolor=mid_gray,
    )


def random_resized_crop(image: Image.Image, scale: Tuple[float, float], ratio: Tuple[float, float]) -> Image.Image:
    """Crop a random window and resize it back to the original size.

    The window area is a random fraction of the image area drawn from
    ``scale`` and its aspect ratio is drawn from ``ratio``. If the sampled
    window does not fit inside the image, no crop is taken and the whole
    image goes through the final resize.
    """
    width, height = image.size
    crop_area = random.uniform(scale[0], scale[1]) * (width * height)
    aspect = random.uniform(ratio[0], ratio[1])
    crop_w = int(round(np.sqrt(crop_area * aspect)))
    crop_h = int(round(np.sqrt(crop_area / aspect)))

    fits = crop_w <= width and crop_h <= height
    if fits:
        left = random.randint(0, width - crop_w)
        top = random.randint(0, height - crop_h)
        image = image.crop((left, top, left + crop_w, top + crop_h))

    return image.resize((width, height), Image.Resampling.LANCZOS)


def to_grayscale(image: Image.Image) -> Image.Image:
    """Convert ``image`` to single-channel luminance and back to RGB mode."""
    luminance = image.convert("L")
    return luminance.convert("RGB")


class AugmentationPipeline:
    """Callable that applies a random chain of augmentations to a PIL image.

    Each stage fires independently with the probability configured in
    ``AugmentConfig``. Stages run in a fixed order: geometric, then
    compression/degradation, then color, then grayscale.
    """

    def __init__(self, config: AugmentConfig = None):
        self.config = config if config else AugmentConfig()

    def __call__(self, image: Image.Image) -> Image.Image:
        """Return an augmented RGB copy of ``image``."""
        cfg = self.config
        img = image if image.mode == "RGB" else image.convert("RGB")

        def fires(probability: float) -> bool:
            # One draw per stage decision, taken in stage order.
            return random.random() < probability

        # Geometric stages
        if fires(cfg.p_flip):
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
        if fires(cfg.p_rotation):
            img = random_rotation(img, random.uniform(*cfg.rotation_range))
        if fires(cfg.p_random_crop):
            img = random_resized_crop(img, cfg.random_crop_scale, cfg.random_crop_ratio)

        # Compression / degradation stages
        if fires(cfg.p_jpeg):
            img = jpeg_compress(img, random.choice(cfg.jpeg_qualities))
        if fires(cfg.p_resize):
            img = resize_chain(img, random.choice(cfg.resize_scales))
        if fires(cfg.p_blur):
            img = gaussian_blur(img, random.choice(cfg.blur_sigmas))
        if fires(cfg.p_noise):
            img = add_sensor_noise(img, random.choice(cfg.noise_sigmas))

        # Color stages: brightness/contrast and saturation always go
        # together; the hue shift is applied in half of those cases.
        if fires(cfg.p_color):
            brightness = random.uniform(*cfg.brightness_range)
            contrast = random.uniform(*cfg.contrast_range)
            img = adjust_brightness_contrast(img, brightness, contrast)
            img = adjust_saturation(img, random.uniform(*cfg.saturation_range))
            if fires(0.5):
                img = adjust_hue(img, random.uniform(*cfg.hue_range))

        # Occasional grayscale conversion
        if fires(cfg.p_grayscale):
            img = to_grayscale(img)

        return img

# Progress marker: confirms the augmentation cell ran to completion.
print("Augmentation pipeline defined.")

## 6. Dataset and DataLoader

In [None]:
class DeepfakeDataset(Dataset):
    """Map-style dataset yielding normalized image tensors with labels.

    Each item is a dict with keys ``"image"`` (float tensor in
    channel-first layout), ``"label"`` (long scalar tensor) and
    ``"image_id"`` (taken from the image's metadata record).
    """

    def __init__(self, images: List[ImageMetadata], augment: bool = True, augment_config: AugmentConfig = None):
        """Store the records and print a one-line class-balance summary.

        Args:
            images: Metadata records of the images in this dataset.
            augment: Whether to run the augmentation pipeline on each image.
            augment_config: Settings forwarded to ``AugmentationPipeline``;
                only used when ``augment`` is true.
        """
        self.images = tuple(images)
        self.augment = augment
        self.aug_pipeline = None
        if augment:
            self.aug_pipeline = AugmentationPipeline(augment_config)

        real_count = sum(record.label == 0 for record in self.images)
        fake_count = sum(record.label == 1 for record in self.images)
        mode = "augmented" if augment else "no augmentation"
        print(f"  Dataset: {len(self.images)} images (Real: {real_count}, Fake: {fake_count}) [{mode}]")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        record = self.images[idx]

        # Apply the EXIF orientation and force RGB while the file is open.
        with Image.open(record.path) as handle:
            picture = ImageOps.exif_transpose(handle).convert("RGB")

        if self.augment and self.aug_pipeline:
            picture = self.aug_pipeline(picture)

        picture = picture.resize(TARGET_SIZE, Image.Resampling.LANCZOS)

        # Scale to [0, 1], then normalize each channel.
        pixels = np.asarray(picture, dtype=np.float32) / 255.0
        pixels = (pixels - IMAGENET_MEAN) / IMAGENET_STD

        # Height-width-channel -> channel-height-width.
        chw = torch.from_numpy(pixels).permute(2, 0, 1).contiguous()

        return {
            "image": chw,
            "label": torch.tensor(record.label, dtype=torch.long),
            "image_id": record.image_id,
        }


def collate_fn(batch):
    """Merge a list of dataset samples into one batch dict.

    Image and label tensors are stacked along a new leading dimension; the
    image ids are gathered into a plain list in sample order.
    """
    images, labels, image_ids = [], [], []
    for sample in batch:
        images.append(sample["image"])
        labels.append(sample["label"])
        image_ids.append(sample["image_id"])

    return {
        "image": torch.stack(images, dim=0),
        "label": torch.stack(labels, dim=0),
        "image_id": image_ids,
    }

# Progress marker: confirms the dataset cell ran to completion.
print("Dataset class defined.")

## 7. Model Definition

In [None]:
class AttentionPooling(nn.Module):
    """Pool per-patch logits with a learned spatial attention map.

    A two-layer 1x1-convolution network scores every spatial position of the
    feature map; the scores are softmax-normalized over all positions and
    used as weights for summing the patch logits.
    """

    def __init__(self, in_channels=2048, hidden_dim=512):
        super().__init__()
        scoring_layers = [
            nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_dim, 1, kernel_size=1),
        ]
        self.attention_fc = nn.Sequential(*scoring_layers)

    def forward(self, features, patch_logits):
        """Return ``(pooled, attn)``.

        ``attn`` has one channel and the spatial size of the attention
        scores, and sums to 1 over the spatial positions of each sample.
        ``pooled`` is the attention-weighted sum of ``patch_logits`` over
        the two spatial dimensions.
        """
        scores = self.attention_fc(features)
        batch, _, height, width = scores.shape

        flat = scores.view(batch, -1)
        # Shift by the per-sample maximum before the softmax.
        shifted = flat - flat.max(dim=1, keepdim=True).values
        attn = F.softmax(shifted, dim=1).view(batch, 1, height, width)

        pooled = (patch_logits * attn).sum(dim=(2, 3))
        return pooled, attn


class LaDeDaResNet50(nn.Module):
    """LaDeDa-style ResNet50 for deepfake detection with dropout regularization.

    The torchvision ResNet50 stages are reused, but the stem is a 3x3
    stride-1 convolution and ``forward`` applies no max-pooling. A 1x1
    convolution scores every spatial position of the layer4 feature map and
    an attention module pools those patch scores into one logit per image.

    Stages named in ``freeze_layers`` have ``requires_grad`` disabled and are
    kept in eval mode even while the rest of the model is in training mode.
    Without this, ``model.train()`` would put their BatchNorm layers back
    into training mode and the running statistics of the supposedly frozen
    stages would keep changing.

    Args:
        pretrained: Load ImageNet weights for the backbone.
        freeze_layers: Stage names to freeze; valid names are ``"conv1"``
            (stem convolution and its BatchNorm), ``"layer1"``, ``"layer2"``,
            ``"layer3"`` and ``"layer4"``. Unknown names are ignored with a
            warning.
        dropout_rate: Probability of the channel dropout applied to the
            layer4 feature map.
    """

    def __init__(self, pretrained=True, freeze_layers=None, dropout_rate=0.4):
        super().__init__()

        base = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2 if pretrained else None)

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        if pretrained:
            with torch.no_grad():
                # Initialize from the central 3x3 window of the pretrained
                # 7x7 stem kernel.
                self.conv1.weight.copy_(base.conv1.weight[:, :, 2:5, 2:5])

        self.bn1 = base.bn1
        self.relu = base.relu
        self.layer1 = base.layer1
        self.layer2 = base.layer2
        self.layer3 = base.layer3
        self.layer4 = base.layer4

        # Dropout for regularization
        self.dropout = nn.Dropout2d(p=dropout_rate)

        self.patch_classifier = nn.Conv2d(2048, 1, kernel_size=1)
        self.attention_pool = AttentionPooling(2048)

        self.freeze_layers = freeze_layers or []
        self._freeze()

    def _stage_modules(self):
        """Map each freezable stage name to the modules it covers."""
        return {
            "conv1": [self.conv1, self.bn1],
            "layer1": [self.layer1],
            "layer2": [self.layer2],
            "layer3": [self.layer3],
            "layer4": [self.layer4],
        }

    def _frozen_modules(self):
        """Return the modules selected by ``self.freeze_layers``."""
        stages = self._stage_modules()
        return [module for name in self.freeze_layers for module in stages.get(name, [])]

    def _freeze(self):
        """Disable gradients for the frozen stages and put them in eval mode."""
        stages = self._stage_modules()
        unknown = [name for name in self.freeze_layers if name not in stages]
        if unknown:
            # A misspelled name would otherwise silently leave a stage trainable.
            import warnings  # local import: only needed on this rare path

            warnings.warn(
                f"Ignoring unknown freeze_layers entries {unknown}; "
                f"valid names are {sorted(stages)}",
                stacklevel=3,
            )

        for module in self._frozen_modules():
            module.eval()
            for p in module.parameters():
                p.requires_grad = False

    def train(self, mode: bool = True):
        """Set the training mode while keeping frozen stages in eval mode."""
        super().train(mode)
        if mode:
            for module in self._frozen_modules():
                module.eval()
        return self

    def forward(self, x):
        """Return ``(pooled_logits, patch_logits, attn)`` for a batch of images.

        ``pooled_logits`` is flattened to one logit per image,
        ``patch_logits`` is the single-channel score map produced by the 1x1
        classifier, and ``attn`` is the attention map used for pooling.
        """
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.dropout(x)

        patch_logits = self.patch_classifier(x)
        pooled_logits, attn = self.attention_pool(x, patch_logits)
        pooled_logits = pooled_logits.view(-1)

        return pooled_logits, patch_logits, attn

# Progress marker: confirms the model-definition cell ran to completion.
print("Model class defined.")

## 8. Metrics and Training Functions

In [None]:
def compute_metrics(preds, labels, probs=None) -> Dict:
    """Compute binary classification metrics, with label 1 as the positive class.

    Args:
        preds: Predicted class indices (0 or 1).
        labels: Ground-truth class indices (0 or 1).
        probs: Optional predicted probabilities of class 1; when given, ROC
            AUC is added under ``"auc"`` if scikit-learn is installed and the
            score is defined for the data.

    Returns:
        Dict with float values for ``"accuracy"``, ``"precision"``,
        ``"recall"`` and ``"f1"`` (plus ``"auc"`` when available). A ratio
        whose denominator is zero -- including every metric for empty
        input -- is reported as 0.0.
    """
    preds = np.asarray(preds).astype(int)
    labels = np.asarray(labels).astype(int)

    def _ratio(numerator, denominator) -> float:
        # Exact division with an explicit zero guard: a perfect score is
        # exactly 1.0 rather than slightly below it.
        return float(numerator / denominator) if denominator else 0.0

    correct = int((preds == labels).sum())
    tp = int(((preds == 1) & (labels == 1)).sum())
    fp = int(((preds == 1) & (labels == 0)).sum())
    fn = int(((preds == 0) & (labels == 1)).sum())

    acc = _ratio(correct, labels.size)
    prec = _ratio(tp, tp + fp)
    rec = _ratio(tp, tp + fn)
    f1 = _ratio(2 * prec * rec, prec + rec)

    metrics = {"accuracy": acc, "precision": prec, "recall": rec, "f1": f1}

    if probs is not None:
        try:
            from sklearn.metrics import roc_auc_score
        except ImportError:
            # scikit-learn is optional; without it AUC is simply not reported.
            roc_auc_score = None

        if roc_auc_score is not None:
            try:
                auc = float(roc_auc_score(labels, probs))
            except ValueError:
                # Raised for inputs on which the score is undefined or malformed.
                auc = float("nan")
            if np.isfinite(auc):
                metrics["auc"] = auc

    return metrics


def train_epoch(model, dataloader, criterion, optimizer, scaler, device):
    """Train ``model`` for one pass over ``dataloader``.

    Uses mixed precision on CUDA devices, clips the gradient norm to 1.0
    and steps the optimizer through ``scaler``.

    Args:
        model: Network returning ``(pooled_logits, patch_logits, attn)``.
        dataloader: Yields dicts with ``"image"`` and ``"label"`` tensors.
        criterion: Loss applied to the pooled logits and float labels.
        optimizer: Optimizer updating the trainable parameters.
        scaler: Gradient scaler used for the backward pass and the step.
        device: Target device (a ``torch.device`` or a device string).

    Returns:
        Metrics dict from ``compute_metrics`` (predictions thresholded at
        0.5) plus ``"loss"``, the sample-weighted mean training loss.

    Raises:
        ValueError: If the dataloader yields no samples.
    """
    model.train()
    device = torch.device(device)
    use_amp = device.type == "cuda"
    total_loss = 0.0
    all_probs, all_labels = [], []

    pbar = tqdm(dataloader, desc="Training", leave=False)
    for batch in pbar:
        images = batch["image"].to(device, non_blocking=True)
        labels = batch["label"].to(device, non_blocking=True).float()

        optimizer.zero_grad(set_to_none=True)

        # With autocast disabled the device type is irrelevant; "cpu" is
        # passed then because it is accepted on every build.
        with torch.autocast(device_type="cuda" if use_amp else "cpu", enabled=use_amp):
            pooled, _, _ = model(images)
            pooled = pooled.view(-1)
            loss = criterion(pooled, labels)

        scaler.scale(loss).backward()
        # Unscale first so the clipping threshold applies to true gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()

        # Read the loss once: every .item() call synchronizes with the device.
        batch_loss = loss.item()
        total_loss += batch_loss * labels.size(0)
        probs = torch.sigmoid(pooled.detach()).view(-1).cpu().numpy()
        all_probs.extend(probs.tolist())
        all_labels.extend(labels.detach().cpu().numpy().tolist())

        pbar.set_postfix({"loss": f"{batch_loss:.4f}"})

    if not all_labels:
        raise ValueError(
            "Training dataloader produced no batches; check the dataset size, "
            "batch_size and drop_last."
        )

    all_probs = np.asarray(all_probs)
    all_labels = np.asarray(all_labels)

    metrics = compute_metrics((all_probs > 0.5).astype(int), all_labels, all_probs)
    metrics["loss"] = total_loss / len(all_labels)
    return metrics


def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without updating it.

    Args:
        model: Network returning ``(pooled_logits, patch_logits, attn)``.
        dataloader: Yields dicts with ``"image"`` and ``"label"`` tensors.
        criterion: Loss applied to the pooled logits and float labels.
        device: Device the batches are moved to.

    Returns:
        Metrics dict from ``compute_metrics`` (predictions thresholded at
        0.5) plus ``"loss"``, the sample-weighted mean loss.

    Raises:
        ValueError: If the dataloader yields no samples (for example when
            the split directory is missing or empty).
    """
    model.eval()
    total_loss = 0.0
    all_probs, all_labels = [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validation", leave=False):
            images = batch["image"].to(device, non_blocking=True)
            labels = batch["label"].to(device, non_blocking=True).float()

            pooled, _, _ = model(images)
            pooled = pooled.view(-1)
            loss = criterion(pooled, labels)

            total_loss += loss.item() * labels.size(0)
            probs = torch.sigmoid(pooled).view(-1).cpu().numpy()
            all_probs.extend(probs.tolist())
            all_labels.extend(labels.cpu().numpy().tolist())

    if not all_labels:
        # Fail with a clear message instead of a ZeroDivisionError below.
        raise ValueError(
            "Validation dataloader produced no samples; check that the split "
            "directory exists and contains images."
        )

    all_probs = np.asarray(all_probs)
    all_labels = np.asarray(all_labels)

    metrics = compute_metrics((all_probs > 0.5).astype(int), all_labels, all_probs)
    metrics["loss"] = total_loss / len(all_labels)
    return metrics

# Progress marker: confirms the training-functions cell ran to completion.
print("Training functions defined.")

## 9. Load Dataset

In [None]:
# Setup directories
# Checkpoints are written to a "checkpoints" sub-directory of the output dir.
os.makedirs(config.output_dir, exist_ok=True)
checkpoint_dir = os.path.join(config.output_dir, "checkpoints")
os.makedirs(checkpoint_dir, exist_ok=True)

# Device
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load data
print("\nLoading dataset...")
data_root = Path(config.data_path)

# One metadata list per split; the split name is the sub-directory name.
train_images = load_split(data_root, "train")
val_images = load_split(data_root, "val")
test_images = load_split(data_root, "test")

print(f"  Train: {len(train_images)} images")
print(f"  Val: {len(val_images)} images")
print(f"  Test: {len(test_images)} images")

In [None]:
# Create datasets
# Only the training split is augmented; validation and test images are
# loaded without augmentation.
print("Creating datasets...")
train_dataset = DeepfakeDataset(train_images, augment=True, augment_config=augment_config)
val_dataset = DeepfakeDataset(val_images, augment=False)
test_dataset = DeepfakeDataset(test_images, augment=False)

# Create dataloaders
# The training loader shuffles and drops the last incomplete batch; the
# validation and test loaders keep every sample in a fixed order.
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, 
                          num_workers=config.num_workers, pin_memory=True, drop_last=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False,
                        num_workers=config.num_workers, pin_memory=True, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False,
                         num_workers=config.num_workers, pin_memory=True, collate_fn=collate_fn)

print(f"\nDataLoaders: {len(train_loader)} train, {len(val_loader)} val, {len(test_loader)} test batches")

## 10. Initialize Model

In [None]:
# Build the model with the configured frozen stages and dropout rate, then
# move it to the selected device.
print("Initializing model...")
model = LaDeDaResNet50(pretrained=True, freeze_layers=config.freeze_layers, dropout_rate=config.dropout_rate)
model = model.to(device)

# Report how many parameters will receive gradients versus stay frozen.
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"  Total params: {total_params:,}")
print(f"  Trainable: {trainable_params:,} ({100*trainable_params/total_params:.1f}%)")
print(f"  Frozen: {total_params - trainable_params:,} ({100*(total_params-trainable_params)/total_params:.1f}%)")

## 11. Setup Optimizer and Scheduler

In [None]:
# Loss on the raw pooled logit (the sigmoid is folded into the loss).
criterion = nn.BCEWithLogitsLoss()

# Only parameters that still require gradients are handed to the optimizer.
optimizer = optim.AdamW(
    [p for p in model.parameters() if p.requires_grad],
    lr=config.lr,
    weight_decay=config.weight_decay,
)

# Cosine schedule with warm restarts; stepped once per epoch by the training loop.
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=5,
    T_mult=2,
    eta_min=1e-6,
)

# Gradient scaler used by the mixed-precision training step.
scaler = GradScaler()

print("Optimizer: AdamW")
print(f"  Learning rate: {config.lr}")
print(f"  Weight decay: {config.weight_decay}")
print("Scheduler: CosineAnnealingWarmRestarts")

## 12. Training Loop with Early Stopping and Checkpoints

This cell runs the main training loop. Progress is saved after each epoch.

In [None]:
# Initialize tracking variables
# Re-running this cell resets the best scores, the patience counter and the
# history, so the training loop starts again from epoch 1.
best_val_loss = float('inf')
best_val_f1 = 0
epochs_without_improvement = 0
history = {"train": [], "val": []}
# The clock starts here, so the reported training time includes any time
# spent between running this cell and the training cell.
start_time = time.time()

print("=" * 60)
print("TRAINING")
print("=" * 60)
print(f"Early stopping patience: {config.early_stopping_patience} epochs")
print()

In [None]:
# Main training loop - Run this cell to train
# You can re-run this cell to continue training from last checkpoint
# (the starting epoch is taken from the length of the recorded history).

for epoch in range(len(history["train"]), config.epochs):
    print(f"\nEpoch {epoch + 1}/{config.epochs}")
    print("-" * 40)

    # Train
    train_metrics = train_epoch(model, train_loader, criterion, optimizer, scaler, device)

    # Validate
    val_metrics = validate(model, val_loader, criterion, device)

    # The learning-rate schedule advances once per epoch.
    scheduler.step()

    history["train"].append(train_metrics)
    history["val"].append(val_metrics)

    print(f"Train | Loss: {train_metrics['loss']:.4f} | Acc: {train_metrics['accuracy']:.4f} | F1: {train_metrics['f1']:.4f}")
    print(f"Val   | Loss: {val_metrics['loss']:.4f} | Acc: {val_metrics['accuracy']:.4f} | F1: {val_metrics['f1']:.4f}")

    # Early stopping check (tracks validation loss)
    if val_metrics["loss"] < best_val_loss:
        best_val_loss = val_metrics["loss"]
        epochs_without_improvement = 0
        print(f"✓ New best val loss: {best_val_loss:.4f}")
    else:
        epochs_without_improvement += 1
        print(f"⚠ No improvement for {epochs_without_improvement}/{config.early_stopping_patience} epochs")

    # Save best model (selected by validation F1). The first completed epoch
    # always writes the file, so best_model.pth exists for the evaluation
    # section even if validation F1 never rises above its initial value.
    is_first_epoch = len(history["val"]) == 1
    if is_first_epoch or val_metrics["f1"] > best_val_f1:
        best_val_f1 = val_metrics["f1"]
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "val_metrics": val_metrics,
            "history": history,
        }, os.path.join(checkpoint_dir, "best_model.pth"))
        print(f"💾 Saved best model (Val F1: {best_val_f1:.4f})")

    # Save latest checkpoint (for resuming). The scaler state is stored as
    # well so that resume code can restore the mixed-precision loss scale.
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "scaler_state_dict": scaler.state_dict(),
        "best_val_loss": best_val_loss,
        "best_val_f1": best_val_f1,
        "epochs_without_improvement": epochs_without_improvement,
        "history": history,
    }, os.path.join(checkpoint_dir, "latest_checkpoint.pth"))

    # Early stopping
    if epochs_without_improvement >= config.early_stopping_patience:
        print(f"\n🛑 Early stopping triggered! No improvement for {config.early_stopping_patience} epochs.")
        break

total_time = time.time() - start_time
print(f"\n{'='*60}")
print(f"Training complete in {total_time/60:.1f} minutes")
print(f"Best Val F1: {best_val_f1:.4f}")
print(f"Best Val Loss: {best_val_loss:.4f}")

## 13. Plot Training History

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# One panel per metric. Each entry pairs the history key with the text used
# for both the y-axis label and the panel title.
for _ax, (_metric, _caption) in zip(
    axes, [("loss", "Loss"), ("accuracy", "Accuracy"), ("f1", "F1 Score")]
):
    _ax.plot([m[_metric] for m in history['train']], 'b-', label='Train', linewidth=2)
    _ax.plot([m[_metric] for m in history['val']], 'r-', label='Val', linewidth=2)
    _ax.set_xlabel('Epoch')
    _ax.set_ylabel(_caption)
    _ax.set_title(_caption)
    _ax.legend()
    _ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(config.output_dir, "training_history.png"), dpi=150, bbox_inches='tight')
plt.show()
print(f"Training history saved to: {os.path.join(config.output_dir, 'training_history.png')}")

## 14. Load Best Model and Evaluate on Test Set

In [None]:
# Restore the weights of the checkpoint written by the training loop as
# "best_model.pth" before running the final evaluation.
print("Loading best model for evaluation...")
# NOTE: torch.load unpickles the file -- only load checkpoints you created yourself.
checkpoint = torch.load(os.path.join(checkpoint_dir, "best_model.pth"), map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
# The stored epoch index is zero-based; print it one-based.
print(f"Loaded model from epoch {checkpoint['epoch'] + 1}")

In [None]:
def evaluate(model, dataloader, device):
    """Score ``model`` on ``dataloader`` and return classification metrics.

    The model is put in eval mode and run without gradient tracking.
    Sigmoid probabilities of the pooled logits are thresholded at 0.5 to
    obtain hard predictions; the result comes from ``compute_metrics``.
    """
    model.eval()
    collected_probs = []
    collected_labels = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating", leave=False):
            logits, _, _ = model(batch["image"].to(device))
            batch_probs = torch.sigmoid(logits).view(-1).cpu().numpy()
            batch_labels = batch["label"].cpu().numpy().astype(int)

            collected_probs += batch_probs.tolist()
            collected_labels += batch_labels.tolist()

    prob_array = np.asarray(collected_probs)
    label_array = np.asarray(collected_labels)
    hard_preds = (prob_array > 0.5).astype(int)

    return compute_metrics(hard_preds, label_array, prob_array)

# Run the final evaluation on the held-out test loader and print the metrics.
print("\n" + "=" * 60)
print("FINAL EVALUATION")
print("=" * 60)

test_metrics = evaluate(model, test_loader, device)
print(f"\nTest Results:")
print(f"  Accuracy:  {test_metrics['accuracy']:.4f}")
print(f"  Precision: {test_metrics['precision']:.4f}")
print(f"  Recall:    {test_metrics['recall']:.4f}")
print(f"  F1:        {test_metrics['f1']:.4f}")
# "auc" is only present when compute_metrics was able to calculate it.
if "auc" in test_metrics:
    print(f"  AUC:       {test_metrics['auc']:.4f}")

## 15. Save Final Results

In [None]:
# Summary of the run: the settings used, the best validation scores reached,
# the test metrics and the elapsed training time.
results = {
    "config": {
        "epochs_trained": len(history["train"]),
        "max_epochs": config.epochs,
        "batch_size": config.batch_size,
        "lr": config.lr,
        "weight_decay": config.weight_decay,
        "dropout_rate": config.dropout_rate,
        "early_stopping_patience": config.early_stopping_patience,
        "frozen_layers": config.freeze_layers,
    },
    "best_val_f1": best_val_f1,
    "best_val_loss": best_val_loss,
    "test_metrics": test_metrics,
    "training_time_minutes": total_time / 60,
}

# Explicit encoding so the file does not depend on the platform default.
with open(os.path.join(config.output_dir, "results.json"), "w", encoding="utf-8") as f:
    json.dump(results, f, indent=2)

print(f"Results saved to: {os.path.join(config.output_dir, 'results.json')}")
print("\n✅ Training complete!")

## Optional: Resume Training from Checkpoint

Run this cell if you need to resume training after a kernel restart.

In [None]:
# Uncomment and run to resume from latest checkpoint

# checkpoint_path = os.path.join(checkpoint_dir, "latest_checkpoint.pth")
# if os.path.exists(checkpoint_path):
#     checkpoint = torch.load(checkpoint_path, map_location=device)
#     model.load_state_dict(checkpoint["model_state_dict"])
#     optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
#     scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
#     best_val_loss = checkpoint["best_val_loss"]
#     best_val_f1 = checkpoint["best_val_f1"]
#     epochs_without_improvement = checkpoint["epochs_without_improvement"]
#     history = checkpoint["history"]
#     print(f"Resumed from epoch {checkpoint['epoch'] + 1}")
#     print(f"Best val F1: {best_val_f1:.4f}, Best val loss: {best_val_loss:.4f}")
# else:
#     print("No checkpoint found.")