# Facial Keypoint Detection
This notebook implements facial keypoint detection (68 landmarks) using four approaches:
1. **CNN** — Custom convolutional network (regression)
2. **ResNet** — Transfer learning with ResNet18 (regression)
3. **DINO** — Transfer learning with DINO Vision Transformer (regression)
4. **U-Net** — Heatmap-based keypoint detection

---
## 1. Setup

In [1]:
import os
import random
import subprocess
import zipfile

import cv2
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from PIL import Image
from scipy.ndimage import gaussian_filter
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from tqdm import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

objc[69209]: Class AVFFrameReceiver is implemented in both /Users/mersadabbasi/anaconda3/lib/python3.13/site-packages/cv2/.dylibs/libavdevice.61.3.100.dylib (0x104f5c3a8) and /Users/mersadabbasi/anaconda3/lib/python3.13/site-packages/av/.dylibs/libavdevice.61.3.100.dylib (0x121ce43a8). This may cause spurious casting failures and mysterious crashes. One of the duplicates must be removed or renamed.
objc[69209]: Class AVFAudioReceiver is implemented in both /Users/mersadabbasi/anaconda3/lib/python3.13/site-packages/cv2/.dylibs/libavdevice.61.3.100.dylib (0x104f5c3f8) and /Users/mersadabbasi/anaconda3/lib/python3.13/site-packages/av/.dylibs/libavdevice.61.3.100.dylib (0x121ce43f8). This may cause spurious casting failures and mysterious crashes. One of the duplicates must be removed or renamed.


Using device: cpu


### 1.1 Download Data

In [2]:
def download_data(data_dir='data'):
    if os.path.exists(os.path.join(data_dir, 'training')) and os.path.exists(os.path.join(data_dir, 'test')):
        print("Data already exists.")
        return
    print("Data not found. Downloading...")
    os.makedirs(data_dir, exist_ok=True)
    zip_path = os.path.join(data_dir, 'train-test-data.zip')
    subprocess.run([
        'wget', '-q', '-O', zip_path,
        'https://s3.amazonaws.com/video.udacity-data.com/topher/2018/May/5aea1b91_train-test-data/train-test-data.zip'
    ], check=True)
    with zipfile.ZipFile(zip_path, 'r') as z:
        z.extractall(data_dir)
    os.remove(zip_path)
    print("Data downloaded and extracted.")

download_data()

Data already exists.


### 1.2 Transforms

In [3]:
class Rescale:
    """Rescale the image in a sample to a given size."""
    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        image, key_pts = sample['image'], sample['keypoints']
        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size
        new_h, new_w = int(new_h), int(new_w)
        img = cv2.resize(image, (new_w, new_h))
        key_pts = key_pts * [new_w / w, new_h / h]
        return {'image': img, 'keypoints': key_pts}


class RandomCrop:
    """Crop randomly the image in a sample."""
    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            self.output_size = output_size

    def __call__(self, sample):
        image, key_pts = sample['image'], sample['keypoints']
        h, w = image.shape[:2]
        new_h, new_w = self.output_size
        top = np.random.randint(0, h - new_h)
        left = np.random.randint(0, w - new_w)
        image = image[top:top + new_h, left:left + new_w]
        key_pts = key_pts - [left, top]
        return {'image': image, 'keypoints': key_pts}


class NormalizeOriginal:
    """Convert to grayscale and normalize image [0,1], keypoints (pts-100)/50."""
    def __call__(self, sample):
        image, key_pts = sample['image'], sample['keypoints']
        image_copy = cv2.cvtColor(np.copy(image), cv2.COLOR_RGB2GRAY)
        image_copy = (image_copy / 255.0).astype(np.float32)
        key_pts_copy = (np.copy(key_pts) - 100) / 50.0
        return {'image': image_copy, 'keypoints': key_pts_copy}


class ToTensor:
    """Convert ndarrays in sample to Tensors."""
    def __call__(self, sample):
        image, key_pts = sample['image'], sample['keypoints']
        if len(image.shape) == 2:
            image = image.reshape(image.shape[0], image.shape[1], 1)
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image), 'keypoints': torch.from_numpy(key_pts)}


class RandomHorizontalFlip:
    """Random horizontal flip with correct keypoint mirroring for 68 landmarks."""
    def __call__(self, sample):
        image, key_pts = sample['image'], sample['keypoints']
        image_copy = np.copy(image)
        key_pts_copy = np.copy(key_pts)
        key_pts_copy_2 = np.copy(key_pts_copy)

        if random.choice([0, 1]) <= 0.5:
            image_copy = np.fliplr(image_copy)
            key_pts_copy[:, 0] = -key_pts_copy[:, 0]
            key_pts_copy[:, 0] = key_pts_copy[:, 0] + image_copy.shape[1]
            key_pts_copy_2 = np.copy(key_pts_copy)

            # Mirror jawline
            for i, j in zip(range(0, 8), range(16, 8, -1)):
                key_pts_copy_2[j] = key_pts_copy[i]
                key_pts_copy_2[i] = key_pts_copy[j]
            # Mirror eyebrows
            for i, j in zip(range(17, 22), range(26, 21, -1)):
                key_pts_copy_2[j] = key_pts_copy[i]
                key_pts_copy_2[i] = key_pts_copy[j]
            # Mirror nose tip
            key_pts_copy_2[35] = key_pts_copy[31]
            key_pts_copy_2[34] = key_pts_copy[32]
            key_pts_copy_2[32] = key_pts_copy[34]
            key_pts_copy_2[31] = key_pts_copy[35]
            # Mirror eyes
            key_pts_copy_2[45] = key_pts_copy[36]; key_pts_copy_2[44] = key_pts_copy[37]
            key_pts_copy_2[43] = key_pts_copy[38]; key_pts_copy_2[42] = key_pts_copy[39]
            key_pts_copy_2[47] = key_pts_copy[40]; key_pts_copy_2[46] = key_pts_copy[41]
            key_pts_copy_2[39] = key_pts_copy[42]; key_pts_copy_2[38] = key_pts_copy[43]
            key_pts_copy_2[37] = key_pts_copy[44]; key_pts_copy_2[36] = key_pts_copy[45]
            key_pts_copy_2[41] = key_pts_copy[46]; key_pts_copy_2[40] = key_pts_copy[47]
            # Mirror lips
            key_pts_copy_2[54] = key_pts_copy[48]; key_pts_copy_2[53] = key_pts_copy[49]
            key_pts_copy_2[52] = key_pts_copy[50]; key_pts_copy_2[50] = key_pts_copy[52]
            key_pts_copy_2[49] = key_pts_copy[53]; key_pts_copy_2[48] = key_pts_copy[54]
            key_pts_copy_2[59] = key_pts_copy[55]; key_pts_copy_2[58] = key_pts_copy[56]
            key_pts_copy_2[56] = key_pts_copy[58]; key_pts_copy_2[55] = key_pts_copy[59]
            key_pts_copy_2[64] = key_pts_copy[60]; key_pts_copy_2[63] = key_pts_copy[61]
            key_pts_copy_2[61] = key_pts_copy[63]; key_pts_copy_2[60] = key_pts_copy[64]
            key_pts_copy_2[67] = key_pts_copy[65]; key_pts_copy_2[65] = key_pts_copy[67]

        return {'image': image_copy, 'keypoints': key_pts_copy_2}


class RandomRotate:
    """Random rotation by +/- angle degrees."""
    def __init__(self, rotation=30):
        self.rotation = rotation

    def __call__(self, sample):
        image, key_pts = sample['image'], sample['keypoints']
        image_copy = np.copy(image)
        key_pts_copy = np.copy(key_pts)
        rows, cols = image.shape[:2]
        M = cv2.getRotationMatrix2D((rows / 2, cols / 2), random.choice([-self.rotation, self.rotation]), 1)
        image_copy = cv2.warpAffine(image_copy, M, (cols, rows))
        key_pts_copy = key_pts_copy.reshape((1, 136))
        new_keypoints = np.zeros(136)
        for i in range(68):
            coord_idx = 2 * i
            old_coord = key_pts_copy[0][coord_idx:coord_idx + 2]
            new_coord = np.matmul(M, np.append(old_coord, 1))
            new_keypoints[coord_idx] += new_coord[0]
            new_keypoints[coord_idx + 1] += new_coord[1]
        return {'image': image_copy, 'keypoints': new_keypoints.reshape((68, 2))}


class ColorJitter:
    """Random color jitter (brightness, contrast, saturation)."""
    def __call__(self, sample):
        image, key_pts = sample['image'], sample['keypoints']
        color_jitter = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4)
        image_copy = np.copy(image)
        if image_copy.dtype != np.uint8:
            image_copy = (image_copy * 255).astype(np.uint8)
        image_copy = np.array(color_jitter(Image.fromarray(image_copy)))
        return {'image': image_copy, 'keypoints': np.copy(key_pts)}

### 1.3 Datasets

In [4]:
class FacialKeypointsDataset(Dataset):
    """Face Landmarks dataset for regression-based models."""
    def __init__(self, csv_file, root_dir, transform=None):
        self.key_pts_frame = pd.read_csv(csv_file, index_col=0)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.key_pts_frame)

    def __getitem__(self, idx):
        image_name = os.path.join(self.root_dir, self.key_pts_frame.index[idx])
        image = mpimg.imread(image_name)
        if image.shape[2] == 4:
            image = image[:, :, 0:3]
        key_pts = self.key_pts_frame.iloc[idx, :].values.astype('float').reshape(-1, 2)
        sample = {'image': image, 'keypoints': key_pts}
        if self.transform:
            sample = self.transform(sample)
        return sample


class FacialKeypointsHeatmapDataset(Dataset):
    """Face Landmarks dataset with heatmap generation for U-Net."""
    def __init__(self, csv_file, root_dir, transform=None, output_size=64, sigma=1, image_size=224):
        self.key_pts_frame = pd.read_csv(csv_file, index_col=0)
        self.root_dir = root_dir
        self.transform = transform
        self.output_size = output_size
        self.sigma = sigma
        self.image_size = image_size

    def __len__(self):
        return len(self.key_pts_frame)

    def __getitem__(self, idx):
        image_name = os.path.join(self.root_dir, self.key_pts_frame.index[idx])
        image = mpimg.imread(image_name)
        if image.shape[2] == 4:
            image = image[:, :, 0:3]
        key_pts = self.key_pts_frame.iloc[idx, :].values.astype('float').reshape(-1, 2)
        sample = {'image': image, 'keypoints': key_pts}
        if self.transform:
            sample = self.transform(sample)
        sample['heatmaps'] = self.generate_heatmaps(sample['keypoints'])
        return sample

    def generate_heatmaps(self, keypoints):
        if isinstance(keypoints, torch.Tensor):
            keypoints = keypoints.numpy()
        num_keypoints = keypoints.shape[0]
        heatmaps = np.zeros((num_keypoints, self.output_size, self.output_size), dtype=np.float32)
        keypoints_scaled = (keypoints * 50 + 100) * (self.output_size / self.image_size)
        for i in range(num_keypoints):
            x, y = keypoints_scaled[i]
            if np.isnan(x) or np.isnan(y):
                continue
            x_int = max(0, min(self.output_size - 1, int(x)))
            y_int = max(0, min(self.output_size - 1, int(y)))
            heatmap = np.zeros((self.output_size, self.output_size), dtype=np.float32)
            heatmap[y_int, x_int] = 1.0
            heatmap = gaussian_filter(heatmap, sigma=self.sigma)
            if heatmap.max() > 0:
                heatmap = heatmap / heatmap.max()
            heatmaps[i] = heatmap
        return torch.from_numpy(heatmaps)

### 1.4 Data Loaders

In [5]:
def load_regression_data(batch_size=64):
    """Load datasets and dataloaders for regression-based models (CNN, ResNet, DINO)."""
    data_transform = transforms.Compose([
        Rescale(250), RandomCrop(224), NormalizeOriginal(), ToTensor()
    ])
    train_dataset = FacialKeypointsDataset('data/training_frames_keypoints.csv', 'data/training', data_transform)
    test_dataset = FacialKeypointsDataset('data/test_frames_keypoints.csv', 'data/test', data_transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    return train_loader, test_loader


def load_heatmap_data(batch_size=256, heatmap_size=64):
    """Load datasets and dataloaders for heatmap-based models (U-Net)."""
    train_transform = transforms.Compose([
        Rescale(250), RandomCrop(224), RandomHorizontalFlip(), RandomRotate(15),
        ColorJitter(), NormalizeOriginal(), ToTensor()
    ])
    test_transform = transforms.Compose([
        Rescale((224, 224)), NormalizeOriginal(), ToTensor()
    ])
    train_dataset = FacialKeypointsHeatmapDataset(
        'data/training_frames_keypoints.csv', 'data/training',
        transform=train_transform, output_size=heatmap_size, sigma=2, image_size=224)
    test_dataset = FacialKeypointsHeatmapDataset(
        'data/test_frames_keypoints.csv', 'data/test',
        transform=test_transform, output_size=heatmap_size, sigma=2, image_size=224)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=True)
    return train_loader, test_loader

### 1.5 Training & Evaluation Utilities

In [None]:
def save_checkpoint(model, optimizer, epoch, step, model_name, path='checkpoints/last_checkpoint.pth'):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch, 'step': step, 'model_name': model_name,
    }, path)
    print(f"Checkpoint saved to {path}")


def evaluate_regression(model, test_loader, criterion, device):
    """Evaluate a regression model (CNN/ResNet/DINO) on the test set."""
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for batch in test_loader:
            images = batch['image'].to(device)
            keypoints = batch['keypoints'].view(images.size(0), -1).to(device)
            outputs = model(images)
            loss = model.compute_loss(outputs, keypoints, criterion)
            total_loss += loss.item()
    return total_loss / len(test_loader)


def evaluate_heatmap(model, test_loader, device, loss_type='mse'):
    """Evaluate U-Net on the test set."""
    model.eval()
    if loss_type == 'mse':
        criterion = nn.MSELoss()
    else:
        criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([10.0]).to(device))
    total_loss = 0
    with torch.no_grad():
        for batch in test_loader:
            images = batch['image'].to(device)
            target = batch['heatmaps'].to(device)
            logits = model(images)
            if loss_type == 'mse':
                total_loss += criterion(torch.sigmoid(logits), target).item()
            else:
                total_loss += criterion(logits, target).item()
    return total_loss / len(test_loader)


def heatmaps_to_keypoints(heatmaps, heatmap_size=64, image_size=224):
    """Extract keypoint coordinates from heatmaps using argmax."""
    batch_size, num_kpts, h, w = heatmaps.shape
    heatmaps_flat = heatmaps.view(batch_size, num_kpts, -1)
    max_indices = heatmaps_flat.argmax(dim=2)
    y_coords = (max_indices // w).float()
    x_coords = (max_indices % w).float()
    x_coords = x_coords * (image_size / heatmap_size)
    y_coords = y_coords * (image_size / heatmap_size)
    x_norm = (x_coords - 100) / 50.0
    y_norm = (y_coords - 100) / 50.0
    return torch.stack([x_norm, y_norm], dim=2)


def train_regression(model, train_loader, test_loader, optimizer, criterion, device,
                     model_name='model', num_epochs=10, eval_interval=10,
                     log_interval=5, save_interval=30):
    """Training loop for regression-based models (CNN, ResNet, DINO)."""
    step = 0
    running_loss = 0

    for epoch in range(num_epochs):
        model.train()
        for batch in train_loader:
            step += 1
            image = batch['image'].to(device)
            keypoints = batch['keypoints'].reshape(image.size(0), 136).float().to(device)

            preds = model(image)
            loss = model.compute_loss(preds, keypoints, criterion)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            if step % log_interval == 0:
                avg_loss = running_loss / log_interval
                wandb.log({'train_loss': avg_loss, 'step': step})
                print(f"Step {step}: Train Loss = {avg_loss:.4f}")
                running_loss = 0

            if step % eval_interval == 0:
                val_loss = evaluate_regression(model, test_loader, criterion, device)
                wandb.log({'val_loss': val_loss, 'step': step})
                model.train()

            if step % save_interval == 0:
                save_checkpoint(model, optimizer, epoch, step, model_name,
                                path=f'checkpoints-{model_name}/step_{step}.pth')

    print(f"Training complete for {model_name}!")
    return epoch, step


def train_heatmap(model, train_loader, test_loader, optimizer, device,
                  model_name='unet', num_epochs=10, eval_interval=10,
                  log_interval=5, save_interval=30, loss_type='mse',
                  scheduler_type='cosine'):
    """Training loop for heatmap-based models (U-Net).
    
    Args:
        loss_type: 'mse' (recommended) or 'bce'
        scheduler_type: 'cosine' (recommended) or 'plateau'
    """
    step = 0
    running_loss = 0
    best_val_loss = float('inf')

    if loss_type == 'mse':
        criterion = nn.MSELoss()
    else:
        criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([10.0]).to(device))

    total_steps = num_epochs * len(train_loader)
    if scheduler_type == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps, eta_min=1e-6)
    else:
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

    for epoch in range(num_epochs):
        model.train()
        for batch in train_loader:
            step += 1
            images = batch['image'].to(device)
            heatmaps_gt = batch['heatmaps'].to(device)

            logits = model(images)

            if loss_type == 'mse':
                loss = criterion(torch.sigmoid(logits), heatmaps_gt)
            else:
                loss = criterion(logits, heatmaps_gt)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if scheduler_type == 'cosine':
                scheduler.step()

            running_loss += loss.item()

            if step % log_interval == 0:
                avg_loss = running_loss / log_interval
                wandb.log({'train_loss': avg_loss, 'step': step, 'epoch': epoch,
                           'lr': optimizer.param_groups[0]['lr']})
                print(f"Epoch {epoch}, Step {step}: Train Loss = {avg_loss:.6f} (lr={optimizer.param_groups[0]['lr']:.6f})")
                running_loss = 0

            if step % eval_interval == 0:
                val_loss = evaluate_heatmap(model, test_loader, device, loss_type=loss_type)
                wandb.log({'val_loss': val_loss, 'step': step, 'epoch': epoch})
                print(f"Epoch {epoch}, Step {step}: Val Loss = {val_loss:.6f}")
                model.train()
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    save_checkpoint(model, optimizer, epoch, step, model_name,
                                    path=f'checkpoints-{model_name}/best.pth')
                    print(f"  -> New best val loss: {val_loss:.6f}")

            if step % save_interval == 0:
                save_checkpoint(model, optimizer, epoch, step, model_name,
                                path=f'checkpoints-{model_name}/step_{step}.pth')

        # End of epoch
        if scheduler_type == 'plateau':
            epoch_val_loss = evaluate_heatmap(model, test_loader, device, loss_type=loss_type)
            scheduler.step(epoch_val_loss)
            print(f"--- Epoch {epoch} done. Val Loss = {epoch_val_loss:.6f}, LR = {optimizer.param_groups[0]['lr']:.6f} ---")
        else:
            print(f"--- Epoch {epoch} done. LR = {optimizer.param_groups[0]['lr']:.6f} ---")
        model.train()

    print("Heatmap training complete!")
    return epoch, step


def visualize_keypoints(test_loader, model, device='cuda'):
    """Visualize predicted vs ground-truth keypoints on test samples."""
    model.eval()
    for i, data in enumerate(test_loader):
        image = data['image'][0]
        images = data['image']
        with torch.no_grad():
            images = images.to(device)
            predictions = model(images)
        if predictions.dim() == 4:
            predictions = heatmaps_to_keypoints(predictions)
        else:
            predictions = predictions.reshape(images.size(0), 68, 2)
        predictions = predictions.cpu().numpy()
        pred_kpts = predictions[0]
        gt_kpts = data['keypoints'][0].numpy()
        pred_kpts_denorm = (pred_kpts * 50) + 100
        gt_kpts_denorm = (gt_kpts * 50) + 100
        plt.figure(figsize=(10, 6))
        plt.imshow(image.numpy().transpose(1, 2, 0), cmap='gray')
        plt.scatter(pred_kpts_denorm[:, 0], pred_kpts_denorm[:, 1], c='r', s=20, label='Predicted')
        plt.scatter(gt_kpts_denorm[:, 0], gt_kpts_denorm[:, 1], c='g', s=20, label='Ground Truth')
        plt.legend()
        plt.title(f'Test Sample {i}')
        plt.show()
        if i >= 4:
            break

---
## 2. CNN

### 2.1 Model Definition

In [None]:
class CNN(nn.Module):
    def __init__(self, num_keypoints=68, dropout=0.5):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        self.fc1 = nn.Linear(256 * 14 * 14, 512)
        self.fc2 = nn.Linear(512, num_keypoints * 2)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.bn1(self.conv1(x))), 2)
        x = F.max_pool2d(F.relu(self.bn2(self.conv2(x))), 2)
        x = F.max_pool2d(F.relu(self.bn3(self.conv3(x))), 2)
        x = F.max_pool2d(F.relu(self.bn4(self.conv4(x))), 2)
        x = x.view(x.size(0), -1)
        x = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(x)

    def compute_loss(self, preds, labels, criterion):
        if criterion == 'mse':
            loss_fn = nn.MSELoss()
        elif criterion == 'Smoothl1Loss':
            loss_fn = nn.SmoothL1Loss()
        else:
            raise ValueError(f'Unknown criterion: {criterion}')
        return loss_fn(preds, labels)

### 2.2 Train CNN

In [None]:
# --- Configuration ---
CNN_EPOCHS = 50
CNN_CRITERION = 'Smoothl1Loss'  # or 'mse'
CNN_LR = 1e-3
CNN_BATCH_SIZE = 64

# --- Initialize ---
cnn_model = CNN().to(device)
cnn_train_loader, cnn_test_loader = load_regression_data(batch_size=CNN_BATCH_SIZE)
cnn_optimizer = torch.optim.Adam(cnn_model.parameters(), lr=CNN_LR)

print(f"CNN parameters: {sum(p.numel() for p in cnn_model.parameters()):,}")

# --- Train ---
wandb.init(project='facial-keypoints', name='cnn-train', reinit=True)
wandb.config.update({'model': 'cnn', 'criterion': CNN_CRITERION, 'lr': CNN_LR})

cnn_epoch, cnn_step = train_regression(
    cnn_model, cnn_train_loader, cnn_test_loader, cnn_optimizer,
    CNN_CRITERION, device, model_name='cnn', num_epochs=CNN_EPOCHS
)

save_checkpoint(cnn_model, cnn_optimizer, cnn_epoch, cnn_step, 'cnn',
                path='checkpoints-cnn/last_checkpoint.pth')
wandb.finish()

### 2.3 Visualize CNN

In [None]:
visualize_keypoints(cnn_test_loader, cnn_model, device)

---
## 3. Transfer Learning

### 3.1 ResNet Model

In [None]:
class ResNetKeypointDetector(nn.Module):
    def __init__(self, num_keypoints=68, backbone='resnet18', pretrained=True, freeze_backbone=True):
        super().__init__()
        if backbone == 'resnet18':
            self.backbone = models.resnet18(pretrained=pretrained)
        elif backbone == 'resnet34':
            self.backbone = models.resnet34(pretrained=pretrained)

        # Adapt first conv for grayscale
        original_conv = self.backbone.conv1
        self.backbone.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        if pretrained:
            with torch.no_grad():
                self.backbone.conv1.weight[:, 0, :, :] = original_conv.weight[:, 0, :, :]

        self.backbone_out_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()

        self.regression_head = nn.Sequential(
            nn.Linear(self.backbone_out_features, 1024), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(1024, 512), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(512, num_keypoints * 2),
        )

        self._backbone_frozen = freeze_backbone
        if freeze_backbone:
            self._freeze_backbone()

    def _freeze_backbone(self):
        for param in self.backbone.parameters():
            param.requires_grad = False
        self._backbone_frozen = True

    def _unfreeze_backbone(self):
        for param in self.backbone.parameters():
            param.requires_grad = True
        self._backbone_frozen = False

    def train(self, mode=True):
        """Override train() to keep frozen backbone in eval mode.
        This prevents BatchNorm from updating running stats on frozen features."""
        super().train(mode)
        if mode and self._backbone_frozen:
            self.backbone.eval()
        return self

    def forward(self, x):
        features = self.backbone(x)
        return self.regression_head(features)

    def compute_loss(self, preds, labels, criterion):
        if criterion == 'mse':
            loss_fn = nn.MSELoss()
        elif criterion == 'Smoothl1Loss':
            loss_fn = nn.SmoothL1Loss()
        else:
            raise ValueError(f'Unknown criterion: {criterion}')
        return loss_fn(preds, labels)

### 3.2 Train ResNet (Stage 1: Frozen Backbone)

In [None]:
# --- Configuration ---
RESNET_EPOCHS_FROZEN = 50
RESNET_CRITERION = 'Smoothl1Loss'
RESNET_LR_FROZEN = 5e-4
RESNET_BATCH_SIZE = 64

# --- Initialize ---
resnet_model = ResNetKeypointDetector(backbone='resnet18', pretrained=True, freeze_backbone=True).to(device)
resnet_train_loader, resnet_test_loader = load_regression_data(batch_size=RESNET_BATCH_SIZE)
resnet_optimizer = torch.optim.Adam(
    [p for p in resnet_model.parameters() if p.requires_grad], lr=RESNET_LR_FROZEN
)

print(f"ResNet parameters (total): {sum(p.numel() for p in resnet_model.parameters()):,}")
print(f"ResNet parameters (trainable): {sum(p.numel() for p in resnet_model.parameters() if p.requires_grad):,}")

# --- Train Stage 1: Frozen backbone ---
wandb.init(project='facial-keypoints', name='resnet-frozen', reinit=True)
wandb.config.update({'model': 'resnet', 'criterion': RESNET_CRITERION, 'freeze': True})

resnet_epoch, resnet_step = train_regression(
    resnet_model, resnet_train_loader, resnet_test_loader, resnet_optimizer,
    RESNET_CRITERION, device, model_name='resnet', num_epochs=RESNET_EPOCHS_FROZEN
)

save_checkpoint(resnet_model, resnet_optimizer, resnet_epoch, resnet_step, 'resnet',
                path='checkpoints-resnet/frozen_checkpoint.pth')
wandb.finish()
print("Stage 1 (frozen backbone) complete!")

### 3.3 Train ResNet (Stage 2: Fine-tune Full Model)

In [None]:
# --- Configuration ---
RESNET_EPOCHS_FINETUNE = 50
RESNET_LR_FINETUNE = 1e-4

# --- Unfreeze and set up new optimizer ---
resnet_model._unfreeze_backbone()
resnet_optimizer = torch.optim.Adam(resnet_model.parameters(), lr=RESNET_LR_FINETUNE)

print(f"ResNet parameters (trainable after unfreeze): {sum(p.numel() for p in resnet_model.parameters() if p.requires_grad):,}")

# --- Train Stage 2: Fine-tune ---
wandb.init(project='facial-keypoints', name='resnet-finetune', reinit=True)
wandb.config.update({'model': 'resnet', 'criterion': RESNET_CRITERION, 'freeze': False})

resnet_epoch, resnet_step = train_regression(
    resnet_model, resnet_train_loader, resnet_test_loader, resnet_optimizer,
    RESNET_CRITERION, device, model_name='resnet', num_epochs=RESNET_EPOCHS_FINETUNE
)

save_checkpoint(resnet_model, resnet_optimizer, resnet_epoch, resnet_step, 'resnet',
                path='checkpoints-resnet/last_checkpoint.pth')
wandb.finish()
print("Stage 2 (fine-tune) complete!")

### 3.4 Visualize ResNet

In [None]:
visualize_keypoints(resnet_test_loader, resnet_model, device)

### 3.5 DINO Model

In [9]:
class DINOKeypointDetector(nn.Module):
    def __init__(self, num_keypoints=68, model_name='vit_base_patch16_224.dino',
                 pretrained=True, freeze_backbone=True):
        super().__init__()
        self.backbone = timm.create_model(model_name, pretrained=pretrained)
        self.backbone_out_features = self.backbone.embed_dim
        self.backbone.head = nn.Identity()

        # ImageNet normalization (DINO was trained with this)
        self.register_buffer('img_mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('img_std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

        # BatchNorm to amplify inter-sample feature differences
        self.feature_norm = nn.BatchNorm1d(self.backbone_out_features)

        self.regression_head = nn.Sequential(
            nn.Linear(self.backbone_out_features, 1024), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(1024, 512), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(512, num_keypoints * 2),
        )

        self._backbone_frozen = freeze_backbone
        if freeze_backbone:
            self._freeze_backbone()

    def _freeze_backbone(self):
        for param in self.backbone.parameters():
            param.requires_grad = False
        self._backbone_frozen = True

    def _unfreeze_backbone(self):
        for param in self.backbone.parameters():
            param.requires_grad = True
        self._backbone_frozen = False

    def train(self, mode=True):
        """Keep frozen backbone in eval mode to disable DropPath."""
        super().train(mode)
        if mode and self._backbone_frozen:
            self.backbone.eval()
        return self

    def forward(self, x):
        # Grayscale → RGB
        if x.size(1) == 1:
            x = x.repeat(1, 3, 1, 1)

        # Apply ImageNet normalization
        x = (x - self.img_mean) / self.img_std

        # Extract CLS token features (more discriminative than patch avg)
        features = self.backbone.forward_features(x)
        if len(features.shape) > 2:
            features = features[:, 0, :]  # CLS token

        # Normalize features — amplifies inter-sample differences
        features = self.feature_norm(features)

        return self.regression_head(features)

    def compute_loss(self, preds, labels, criterion):
        if criterion == 'mse':
            loss_fn = nn.MSELoss()
        elif criterion == 'Smoothl1Loss':
            loss_fn = nn.SmoothL1Loss()
        else:
            raise ValueError(f'Unknown criterion: {criterion}')
        return loss_fn(preds, labels)

### 3.6 Train DINO (Stage 1: Frozen Backbone)

In [None]:
# --- Configuration ---
DINO_EPOCHS_FROZEN = 50
DINO_CRITERION = 'Smoothl1Loss'
DINO_LR_FROZEN = 5e-4
DINO_BATCH_SIZE = 64

# --- Initialize ---
dino_model = DINOKeypointDetector(pretrained=True, freeze_backbone=True).to(device)
dino_train_loader, dino_test_loader = load_regression_data(batch_size=DINO_BATCH_SIZE)
dino_optimizer = torch.optim.Adam(
    [p for p in dino_model.parameters() if p.requires_grad], lr=DINO_LR_FROZEN
)

print(f"DINO parameters (total): {sum(p.numel() for p in dino_model.parameters()):,}")
print(f"DINO parameters (trainable): {sum(p.numel() for p in dino_model.parameters() if p.requires_grad):,}")

# --- Train Stage 1: Frozen backbone ---
wandb.init(project='facial-keypoints', name='dino-frozen', reinit=True)
wandb.config.update({'model': 'dino', 'criterion': DINO_CRITERION, 'freeze': True})

dino_epoch, dino_step = train_regression(
    dino_model, dino_train_loader, dino_test_loader, dino_optimizer,
    DINO_CRITERION, device, model_name='dino', num_epochs=DINO_EPOCHS_FROZEN
)

save_checkpoint(dino_model, dino_optimizer, dino_epoch, dino_step, 'dino',
                path='checkpoints-dino/frozen_checkpoint.pth')
wandb.finish()
print("DINO Stage 1 (frozen backbone) complete!")

### 3.7 Train DINO (Stage 2: Fine-tune Full Model)

In [None]:
# --- Configuration ---
DINO_EPOCHS_FINETUNE = 50
DINO_LR_FINETUNE = 1e-4

# --- Unfreeze and set up new optimizer ---
dino_model._unfreeze_backbone()
dino_optimizer = torch.optim.Adam(dino_model.parameters(), lr=DINO_LR_FINETUNE)

print(f"DINO parameters (trainable after unfreeze): {sum(p.numel() for p in dino_model.parameters() if p.requires_grad):,}")

# --- Train Stage 2: Fine-tune ---
wandb.init(project='facial-keypoints', name='dino-finetune', reinit=True)
wandb.config.update({'model': 'dino', 'criterion': DINO_CRITERION, 'freeze': False})

dino_epoch, dino_step = train_regression(
    dino_model, dino_train_loader, dino_test_loader, dino_optimizer,
    DINO_CRITERION, device, model_name='dino', num_epochs=DINO_EPOCHS_FINETUNE
)

save_checkpoint(dino_model, dino_optimizer, dino_epoch, dino_step, 'dino',
                path='checkpoints-dino/last_checkpoint.pth')
wandb.finish()
print("DINO Stage 2 (fine-tune) complete!")

### 3.8 Visualize DINO

In [None]:
visualize_keypoints(dino_test_loader, dino_model, device)

### 3.9 DEBUG: DINO 32-Sample Overfit Test
If the model can't overfit 32 samples, there's a bug in labels / transforms / optimizer / forward.

In [10]:
# --- VERIFY FIX: Run AFTER re-running DINO class cell (3.5) ---
overfit_transform = transforms.Compose([
    Rescale((224, 224)), NormalizeOriginal(), ToTensor()
])
overfit_dataset = FacialKeypointsDataset(
    'data/training_frames_keypoints.csv', 'data/training', overfit_transform
)
overfit_subset = torch.utils.data.Subset(overfit_dataset, range(32))
batch = next(iter(DataLoader(overfit_subset, batch_size=32, shuffle=False)))
images = batch['image'].to(device)
keypoints = batch['keypoints'].reshape(32, 136).float().to(device)

mean_mse = ((keypoints - keypoints.mean(0, keepdim=True)) ** 2).mean().item()
print(f"Mean-baseline MSE: {mean_mse:.6f}\n")

# Full model overfit test (ImageNet norm + CLS token + BatchNorm + MLP)
dino_overfit = DINOKeypointDetector(pretrained=True, freeze_backbone=True).to(device)

# Disable dropout for overfit test
for m in dino_overfit.regression_head.modules():
    if isinstance(m, nn.Dropout):
        m.p = 0.0

dino_overfit.train()  # backbone stays eval (via override), BN/MLP in train mode
opt = torch.optim.Adam([p for p in dino_overfit.parameters() if p.requires_grad], lr=1e-3)
loss_fn = nn.MSELoss()

losses = []
for step in range(1, 501):
    preds = dino_overfit(images)
    loss = loss_fn(preds, keypoints)
    opt.zero_grad(); loss.backward(); opt.step()
    losses.append(loss.item())
    if step in [1, 50, 100, 200, 300, 500]:
        print(f"Step {step:3d}: MSE = {loss.item():.6f}")

plt.figure(figsize=(10, 4))
plt.plot(losses)
plt.xlabel('Step'); plt.ylabel('MSE'); plt.yscale('log')
plt.title('DINO overfit: ImageNet norm + CLS + BatchNorm1d + MLP')
plt.grid(True); plt.show()

print(f"\nFinal MSE: {losses[-1]:.6f}  (mean-baseline: {mean_mse:.6f})")
if losses[-1] < 1e-3:
    print("PASS! Pipeline works. Re-run DINO training cells.")
else:
    print(f"Improvement: {mean_mse/losses[-1]:.1f}x better than mean-baseline")

del dino_overfit

Mean-baseline MSE: 0.048850

Step   1: MSE = 0.364251
Step  50: MSE = 0.008731


KeyboardInterrupt: 

---
## 4. U-Net (Heatmap-based)

### 4.1 Model Definition

In [None]:
class DoubleConv(nn.Module):
    """Two consecutive conv-bn-relu blocks."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)


class UNet(nn.Module):
    def __init__(self, in_channels=1, num_keypoints=68, heatmap_size=64):
        super().__init__()
        self.heatmap_size = heatmap_size

        # Encoder
        self.enc1 = DoubleConv(in_channels, 32)
        self.enc2 = DoubleConv(32, 64)
        self.enc3 = DoubleConv(64, 128)
        self.enc4 = DoubleConv(128, 256)
        self.pool = nn.MaxPool2d(2)
        self.dropout_enc = nn.Dropout2d(p=0.3)

        # Bottleneck
        self.bottleneck = DoubleConv(256, 512)
        self.dropout_bottleneck = nn.Dropout2d(p=0.3)
        self.dropout_dec = nn.Dropout2d(p=0.2)

        # Decoder
        self.up4 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec4 = DoubleConv(512, 256)
        self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec3 = DoubleConv(256, 128)
        self.up2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec2 = DoubleConv(128, 64)
        self.up1 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.dec1 = DoubleConv(64, 32)

        self.out_conv = nn.Conv2d(32, num_keypoints, kernel_size=1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.dropout_enc(self.enc4(self.pool(e3)))

        b = self.dropout_bottleneck(self.bottleneck(self.pool(e4)))

        d4 = self.dropout_dec(self.dec4(torch.cat([self.up4(b), e4], dim=1)))
        d3 = self.dropout_dec(self.dec3(torch.cat([self.up3(d4), e3], dim=1)))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))

        out = self.out_conv(d1)
        out = F.interpolate(out, size=self.heatmap_size, mode='bilinear', align_corners=False)
        return out

### 4.2 Train U-Net

In [None]:
# --- Configuration ---
# KEY FIXES vs original:
#   1. batch_size 256->32 (gives ~90 steps/epoch instead of ~11)
#   2. MSE loss instead of BCE (BCE+pos_weight causes vertical stripe collapse)
#   3. Cosine annealing scheduler (smooth LR decay)
#   4. weight_decay 1e-4 (stronger regularization)
UNET_EPOCHS = 50
UNET_LR = 3e-4
UNET_BATCH_SIZE = 32
UNET_LOSS = 'mse'           # 'mse' or 'bce'
UNET_SCHEDULER = 'cosine'   # 'cosine' or 'plateau'

# --- Initialize ---
unet_model = UNet(in_channels=1, num_keypoints=68, heatmap_size=64).to(device)
unet_train_loader, unet_test_loader = load_heatmap_data(batch_size=UNET_BATCH_SIZE)
unet_optimizer = torch.optim.Adam(unet_model.parameters(), lr=UNET_LR, weight_decay=1e-4)

print(f"U-Net parameters: {sum(p.numel() for p in unet_model.parameters()):,}")
print(f"Batches per epoch: {len(unet_train_loader)}")
print(f"Total steps: {UNET_EPOCHS * len(unet_train_loader)}")
print(f"Loss: {UNET_LOSS}, Scheduler: {UNET_SCHEDULER}")

# --- Train ---
wandb.init(project='facial-keypoints', name=f'unet-{UNET_LOSS}', reinit=True)
wandb.config.update({'model': 'unet', 'criterion': UNET_LOSS, 'lr': UNET_LR,
                      'batch_size': UNET_BATCH_SIZE, 'scheduler': UNET_SCHEDULER})

unet_epoch, unet_step = train_heatmap(
    unet_model, unet_train_loader, unet_test_loader, unet_optimizer,
    device, model_name='unet', num_epochs=UNET_EPOCHS,
    loss_type=UNET_LOSS, scheduler_type=UNET_SCHEDULER
)

save_checkpoint(unet_model, unet_optimizer, unet_epoch, unet_step, 'unet',
                path='checkpoints-unet/last_checkpoint.pth')
wandb.finish()

### 4.3 Visualize U-Net

In [None]:
visualize_keypoints(unet_test_loader, unet_model, device)

### 4.4 DEBUG: U-Net Post-Training Diagnostics
Run this after training to diagnose **overfitting**, **keypoint collapse**, and **heatmap quality**.
Symptoms: val loss plateaus at ~0.2, train keeps improving, predictions cluster to a few face points.

In [None]:
# =====================================================================
# U-NET POST-TRAINING DIAGNOSTICS
# =====================================================================
# Requires: unet_model (trained), unet_test_loader already defined

unet_model.eval()

# Use deterministic test loader for diagnostics
diag_transform = transforms.Compose([Rescale((224, 224)), NormalizeOriginal(), ToTensor()])
diag_dataset = FacialKeypointsHeatmapDataset(
    'data/test_frames_keypoints.csv', 'data/test',
    transform=diag_transform, output_size=64, sigma=2, image_size=224)
diag_loader = DataLoader(diag_dataset, batch_size=32, shuffle=False)

# ---- 1. Collect predictions on full test set ----
print("=" * 60)
print("1. HEATMAP & KEYPOINT STATISTICS")
print("=" * 60)

all_pred_peaks = []
all_pixel_errors = []
all_pred_kpts_px = []
all_gt_kpts_px = []

with torch.no_grad():
    for batch in diag_loader:
        images = batch['image'].to(device)
        logits = unet_model(images)
        probs = torch.sigmoid(logits)

        # Peak confidence per keypoint
        peaks = probs.view(probs.size(0), 68, -1).max(dim=-1)[0]
        all_pred_peaks.append(peaks.cpu())

        # Keypoint coords and errors
        pred_kpts = heatmaps_to_keypoints(logits).cpu()
        gt_kpts = batch['keypoints']
        pred_px = pred_kpts * 50 + 100
        gt_px = gt_kpts * 50 + 100
        errors = torch.sqrt(((pred_px - gt_px) ** 2).sum(dim=-1))

        all_pixel_errors.append(errors)
        all_pred_kpts_px.append(pred_px)
        all_gt_kpts_px.append(gt_px)

pred_peaks = torch.cat(all_pred_peaks, dim=0)      # (N, 68)
pixel_errors = torch.cat(all_pixel_errors, dim=0)   # (N, 68)
pred_kpts_px = torch.cat(all_pred_kpts_px, dim=0)   # (N, 68, 2)
gt_kpts_px = torch.cat(all_gt_kpts_px, dim=0)       # (N, 68, 2)

print(f"Test samples analyzed: {pred_peaks.size(0)}")
print(f"\nPredicted heatmap peak confidence:")
print(f"  Mean: {pred_peaks.mean():.4f}, Std: {pred_peaks.std():.4f}")
print(f"  Min across keypoints: {pred_peaks.mean(0).min():.4f} (kpt {pred_peaks.mean(0).argmin().item()})")
print(f"  Max across keypoints: {pred_peaks.mean(0).max():.4f} (kpt {pred_peaks.mean(0).argmax().item()})")
print(f"\nPixel error (test set):")
print(f"  Mean:   {pixel_errors.mean():.2f} px")
print(f"  Median: {pixel_errors.median():.2f} px")
print(f"  90th %: {pixel_errors.quantile(0.9):.2f} px")

# NME (Normalized Mean Error) — normalize by inter-ocular distance
# Keypoints 36-41 = left eye, 42-47 = right eye
left_eye_center = gt_kpts_px[:, 36:42, :].mean(dim=1)   # (N, 2)
right_eye_center = gt_kpts_px[:, 42:48, :].mean(dim=1)
iod = torch.sqrt(((left_eye_center - right_eye_center) ** 2).sum(dim=-1))  # (N,)
nme = (pixel_errors.mean(dim=1) / iod).mean().item()
print(f"  NME (inter-ocular): {nme:.4f} ({nme*100:.2f}%)")

# ---- 2. Keypoint collapse detection ----
print("\n" + "=" * 60)
print("2. KEYPOINT COLLAPSE DETECTION")
print("=" * 60)

# For each sample, measure the spread of predicted keypoints
pred_spread_x = pred_kpts_px[:, :, 0].std(dim=1)  # (N,)
pred_spread_y = pred_kpts_px[:, :, 1].std(dim=1)
gt_spread_x = gt_kpts_px[:, :, 0].std(dim=1)
gt_spread_y = gt_kpts_px[:, :, 1].std(dim=1)

print(f"Predicted keypoint spread (std across 68 kpts):")
print(f"  X: {pred_spread_x.mean():.1f} px (GT: {gt_spread_x.mean():.1f} px)")
print(f"  Y: {pred_spread_y.mean():.1f} px (GT: {gt_spread_y.mean():.1f} px)")
print(f"  Ratio (pred/GT): X={pred_spread_x.mean()/gt_spread_x.mean():.2f}, Y={pred_spread_y.mean()/gt_spread_y.mean():.2f}")

if pred_spread_x.mean() < gt_spread_x.mean() * 0.5 or pred_spread_y.mean() < gt_spread_y.mean() * 0.5:
    print("  >>> KEYPOINT COLLAPSE DETECTED: predictions cluster too tightly!")
    print("      Model is predicting ~same location for many keypoints.")
else:
    print("  Spread looks reasonable (no collapse).")

# Check per-sample: how many unique argmax positions?
with torch.no_grad():
    sample_batch = next(iter(diag_loader))
    sample_logits = unet_model(sample_batch['image'][:4].to(device))
    sample_flat = sample_logits.view(4, 68, -1)
    sample_argmax = sample_flat.argmax(dim=-1)  # (4, 68)

for s_idx in range(4):
    unique_positions = sample_argmax[s_idx].unique().numel()
    print(f"  Sample {s_idx}: {unique_positions}/68 unique argmax positions", end="")
    if unique_positions < 40:
        print(" << MANY KEYPOINTS SHARE SAME PEAK!")
    else:
        print(" (OK)")

# ---- 3. Visual: summed heatmaps (collapse = single blob) ----
print("\n" + "=" * 60)
print("3. HEATMAP VISUALIZATIONS")
print("=" * 60)

with torch.no_grad():
    vis_batch = next(iter(diag_loader))
    vis_images = vis_batch['image'][:4].to(device)
    vis_logits = unet_model(vis_images)
    vis_probs = torch.sigmoid(vis_logits).cpu()
    vis_gt = vis_batch['heatmaps'][:4]

fig, axes = plt.subplots(3, 4, figsize=(20, 15))
for col in range(4):
    img = vis_images[col, 0].cpu().numpy()
    summed_pred = vis_probs[col].sum(dim=0).numpy()
    summed_gt = vis_gt[col].sum(dim=0).numpy()

    axes[0, col].imshow(img, cmap='gray')
    axes[0, col].set_title(f'Input {col}')
    axes[0, col].axis('off')

    axes[1, col].imshow(summed_pred, cmap='hot')
    axes[1, col].set_title(f'Sum pred (max={summed_pred.max():.1f})')
    axes[1, col].axis('off')

    axes[2, col].imshow(summed_gt, cmap='hot')
    axes[2, col].set_title(f'Sum GT (max={summed_gt.max():.1f})')
    axes[2, col].axis('off')

axes[0, 0].set_ylabel('Image', fontsize=12)
axes[1, 0].set_ylabel('Pred (sum 68)', fontsize=12)
axes[2, 0].set_ylabel('GT (sum 68)', fontsize=12)
plt.suptitle('Collapse check: summed pred should spread like summed GT\n'
             '(If pred is a single blob = keypoint collapse)', fontsize=14)
plt.tight_layout()
plt.show()

# ---- 4. Per-keypoint heatmaps for a single sample ----
# Show 10 representative keypoints: pred vs GT side by side
fig, axes = plt.subplots(2, 10, figsize=(25, 5))
kpt_indices = [0, 8, 16, 27, 30, 33, 36, 42, 48, 57]  # jawline, brow, nose, eyes, mouth
for col_idx, kpt_idx in enumerate(kpt_indices):
    axes[0, col_idx].imshow(vis_gt[0, kpt_idx].numpy(), cmap='hot', vmin=0, vmax=1)
    axes[0, col_idx].set_title(f'GT kpt {kpt_idx}', fontsize=8)
    axes[0, col_idx].axis('off')
    axes[1, col_idx].imshow(vis_probs[0, kpt_idx].numpy(), cmap='hot', vmin=0, vmax=1)
    axes[1, col_idx].set_title(f'Pred kpt {kpt_idx}', fontsize=8)
    axes[1, col_idx].axis('off')
axes[0, 0].set_ylabel('GT', fontsize=11)
axes[1, 0].set_ylabel('Pred', fontsize=11)
plt.suptitle('Individual keypoint heatmaps (sample 0) — pred should match GT location', fontsize=13)
plt.tight_layout()
plt.show()

# ---- 5. Per-keypoint error bar chart ----
print("\n" + "=" * 60)
print("4. PER-KEYPOINT ERROR BREAKDOWN")
print("=" * 60)

mean_per_kpt = pixel_errors.mean(dim=0)  # (68,)
worst_5 = mean_per_kpt.argsort(descending=True)[:5]
best_5 = mean_per_kpt.argsort()[:5]

kpt_names = {0: 'jaw-R', 8: 'chin', 16: 'jaw-L', 17: 'brow-R-out', 21: 'brow-R-in',
             22: 'brow-L-in', 26: 'brow-L-out', 27: 'nose-top', 30: 'nose-tip',
             36: 'eye-R-out', 39: 'eye-R-in', 42: 'eye-L-in', 45: 'eye-L-out',
             48: 'mouth-R', 54: 'mouth-L', 51: 'lip-top', 57: 'lip-bottom', 62: 'lip-inner-top'}

print("Best 5 keypoints:")
for idx in best_5:
    name = kpt_names.get(idx.item(), '')
    print(f"  Kpt {idx.item():2d} {name:15s}: {mean_per_kpt[idx]:.2f} px")
print("Worst 5 keypoints:")
for idx in worst_5:
    name = kpt_names.get(idx.item(), '')
    print(f"  Kpt {idx.item():2d} {name:15s}: {mean_per_kpt[idx]:.2f} px")

plt.figure(figsize=(16, 4))
colors = ['#e74c3c' if e > mean_per_kpt.mean() + mean_per_kpt.std() else '#3498db'
          for e in mean_per_kpt]
plt.bar(range(68), mean_per_kpt.numpy(), color=colors)
plt.axhline(y=mean_per_kpt.mean(), color='k', linestyle='--', alpha=0.5, label=f'Mean={mean_per_kpt.mean():.1f}px')
plt.xlabel('Keypoint index')
plt.ylabel('Mean pixel error')
plt.title('Per-keypoint error on test set (red = >1 std above mean)')
plt.legend()
plt.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.show()

# ---- 6. Overlay predictions on face ----
print("\n" + "=" * 60)
print("5. PREDICTION OVERLAY")
print("=" * 60)

fig, axes = plt.subplots(2, 4, figsize=(20, 10))
vis_pred_kpts = heatmaps_to_keypoints(vis_logits.cpu()).numpy()
vis_gt_kpts = vis_batch['keypoints'][:4].numpy()

for i in range(4):
    img = vis_images[i, 0].cpu().numpy()
    pred = vis_pred_kpts[i] * 50 + 100
    gt = vis_gt_kpts[i] * 50 + 100

    axes[0, i].imshow(img, cmap='gray')
    axes[0, i].scatter(pred[:, 0], pred[:, 1], c='r', s=15, alpha=0.8, zorder=5)
    axes[0, i].scatter(gt[:, 0], gt[:, 1], c='g', s=15, alpha=0.5, marker='x', zorder=4)
    axes[0, i].set_title(f'Sample {i}: Pred(red) vs GT(green)')
    axes[0, i].axis('off')

    # Per-keypoint error on this sample
    errs = np.sqrt(((pred - gt) ** 2).sum(axis=-1))
    axes[1, i].imshow(img, cmap='gray')
    sc = axes[1, i].scatter(pred[:, 0], pred[:, 1], c=errs, cmap='RdYlGn_r',
                            s=25, vmin=0, vmax=30, zorder=5)
    # Draw lines from pred to GT for worst errors
    worst_10 = errs.argsort()[-10:]
    for w in worst_10:
        axes[1, i].plot([pred[w, 0], gt[w, 0]], [pred[w, 1], gt[w, 1]], 'w-', alpha=0.6, linewidth=0.8)
    axes[1, i].set_title(f'Error map (mean={errs.mean():.1f}px)')
    axes[1, i].axis('off')

plt.colorbar(sc, ax=axes[1, :], label='Pixel error', shrink=0.8)
plt.suptitle('Top: predictions overlay | Bottom: error heatmap (white lines = worst 10)', fontsize=14)
plt.tight_layout()
plt.show()

# ---- 7. Data augmentation sanity check ----
print("\n" + "=" * 60)
print("6. AUGMENTATION SANITY CHECK")
print("=" * 60)
print("Checking that augmented training samples have aligned heatmaps...")

train_transform_debug = transforms.Compose([
    Rescale(250), RandomCrop(224), RandomHorizontalFlip(), RandomRotate(15),
    ColorJitter(), NormalizeOriginal(), ToTensor()
])
debug_hm_dataset = FacialKeypointsHeatmapDataset(
    'data/training_frames_keypoints.csv', 'data/training',
    transform=train_transform_debug, output_size=64, sigma=2, image_size=224)

fig, axes = plt.subplots(2, 4, figsize=(20, 5))
for col in range(4):
    sample = debug_hm_dataset[col]
    img = sample['image'][0].numpy()  # (224, 224)
    hm_sum = sample['heatmaps'].sum(dim=0).numpy()  # (64, 64)
    kpts = sample['keypoints'].numpy() * 50 + 100  # (68, 2) in pixel space

    axes[0, col].imshow(img, cmap='gray')
    axes[0, col].scatter(kpts[:, 0], kpts[:, 1], c='lime', s=8, alpha=0.7)
    axes[0, col].set_title(f'Aug sample {col}')
    axes[0, col].axis('off')

    axes[1, col].imshow(hm_sum, cmap='hot')
    # Overlay expected keypoint positions on heatmap
    kpts_hm = kpts * (64 / 224)
    axes[1, col].scatter(kpts_hm[:, 0], kpts_hm[:, 1], c='lime', s=8, alpha=0.5)
    axes[1, col].set_title(f'Summed heatmap + kpts')
    axes[1, col].axis('off')

plt.suptitle('Augmentation check: green dots should align with heatmap peaks\n'
             '(Misalignment = augmentation bug)', fontsize=13)
plt.tight_layout()
plt.show()

# ---- DIAGNOSIS SUMMARY ----
print("\n" + "=" * 60)
print("DIAGNOSIS SUMMARY")
print("=" * 60)

mean_err = pixel_errors.mean().item()
spread_ratio_x = pred_spread_x.mean() / gt_spread_x.mean()
spread_ratio_y = pred_spread_y.mean() / gt_spread_y.mean()
peak_conf = pred_peaks.mean().item()

issues = []
if spread_ratio_x < 0.5 or spread_ratio_y < 0.5:
    issues.append(f"KEYPOINT COLLAPSE: spread ratio = ({spread_ratio_x:.2f}x, {spread_ratio_y:.2f}y)")
if peak_conf < 0.3:
    issues.append(f"LOW CONFIDENCE: peak = {peak_conf:.3f} (heatmaps are too diffuse)")
if peak_conf > 0.95:
    issues.append(f"OVERCONFIDENT: peak = {peak_conf:.3f} (may be memorizing noise)")
if mean_err > 15:
    issues.append(f"HIGH ERROR: {mean_err:.1f} px mean (poor generalization)")
if nme > 0.10:
    issues.append(f"HIGH NME: {nme*100:.2f}% (>10% is poor)")

if issues:
    print("Issues found:")
    for issue in issues:
        print(f"  - {issue}")
    print("\nRecommended fixes (try in order):")
    print("  1. SMALLER BATCH SIZE: 256 -> 32 or 64 (more gradient updates/epoch)")
    print("  2. STRONGER REGULARIZATION: increase dropout 0.3->0.5, weight_decay 1e-5->1e-4")
    print("  3. MSE LOSS: replace BCE with MSE (less prone to false-positive activation)")
    print("  4. LOWER LR: start at 3e-4 with cosine annealing to 1e-6")
    print("  5. FEWER CHANNELS: reduce encoder to 16-32-64-128 (less overfitting)")
    print("  6. LARGER SIGMA: sigma=2->3 in heatmap generation (easier targets)")
    print("  7. EARLY STOPPING: save best val_loss model, stop if no improvement for 15 epochs")
else:
    print(f"Model looks healthy! Mean error: {mean_err:.1f} px, NME: {nme*100:.2f}%, Peak: {peak_conf:.3f}")

### 4.4 DEBUG: U-Net 8-Sample Overfit Test
If the U-Net can't overfit a tiny batch of heatmaps, there's a bug in the heatmap generation, model architecture, or loss.

In [None]:

# =====================================================================
# U-NET POST-TRAINING DIAGNOSTICS
# =====================================================================
# Requires: unet_model (trained), unet_test_loader already defined

unet_model.eval()

# Use deterministic test loader for diagnostics
diag_transform = transforms.Compose([Rescale((224, 224)), NormalizeOriginal(), ToTensor()])
diag_dataset = FacialKeypointsHeatmapDataset(
    'data/test_frames_keypoints.csv', 'data/test',
    transform=diag_transform, output_size=64, sigma=2, image_size=224)
diag_loader = DataLoader(diag_dataset, batch_size=32, shuffle=False)

# ---- 1. Collect predictions on full test set ----
print("=" * 60)
print("1. HEATMAP & KEYPOINT STATISTICS")
print("=" * 60)

all_pred_peaks = []
all_pixel_errors = []
all_pred_kpts_px = []
all_gt_kpts_px = []

with torch.no_grad():
    for batch in diag_loader:
        images = batch['image'].to(device)
        logits = unet_model(images)
        probs = torch.sigmoid(logits)

        # Peak confidence per keypoint
        peaks = probs.view(probs.size(0), 68, -1).max(dim=-1)[0]
        all_pred_peaks.append(peaks.cpu())

        # Keypoint coords and errors
        pred_kpts = heatmaps_to_keypoints(logits).cpu()
        gt_kpts = batch['keypoints']
        pred_px = pred_kpts * 50 + 100
        gt_px = gt_kpts * 50 + 100
        errors = torch.sqrt(((pred_px - gt_px) ** 2).sum(dim=-1))

        all_pixel_errors.append(errors)
        all_pred_kpts_px.append(pred_px)
        all_gt_kpts_px.append(gt_px)

pred_peaks = torch.cat(all_pred_peaks, dim=0)      # (N, 68)
pixel_errors = torch.cat(all_pixel_errors, dim=0)   # (N, 68)
pred_kpts_px = torch.cat(all_pred_kpts_px, dim=0)   # (N, 68, 2)
gt_kpts_px = torch.cat(all_gt_kpts_px, dim=0)       # (N, 68, 2)

print(f"Test samples analyzed: {pred_peaks.size(0)}")
print(f"\nPredicted heatmap peak confidence:")
print(f"  Mean: {pred_peaks.mean():.4f}, Std: {pred_peaks.std():.4f}")
print(f"  Min across keypoints: {pred_peaks.mean(0).min():.4f} (kpt {pred_peaks.mean(0).argmin().item()})")
print(f"  Max across keypoints: {pred_peaks.mean(0).max():.4f} (kpt {pred_peaks.mean(0).argmax().item()})")
print(f"\nPixel error (test set):")
print(f"  Mean:   {pixel_errors.mean():.2f} px")
print(f"  Median: {pixel_errors.median():.2f} px")
print(f"  90th %: {pixel_errors.quantile(0.9):.2f} px")

# NME (Normalized Mean Error) — normalize by inter-ocular distance
# Keypoints 36-41 = left eye, 42-47 = right eye
left_eye_center = gt_kpts_px[:, 36:42, :].mean(dim=1)   # (N, 2)
right_eye_center = gt_kpts_px[:, 42:48, :].mean(dim=1)
iod = torch.sqrt(((left_eye_center - right_eye_center) ** 2).sum(dim=-1))  # (N,)
nme = (pixel_errors.mean(dim=1) / iod).mean().item()
print(f"  NME (inter-ocular): {nme:.4f} ({nme*100:.2f}%)")

# ---- 2. Keypoint collapse detection ----
print("\n" + "=" * 60)
print("2. KEYPOINT COLLAPSE DETECTION")
print("=" * 60)

# For each sample, measure the spread of predicted keypoints
pred_spread_x = pred_kpts_px[:, :, 0].std(dim=1)  # (N,)
pred_spread_y = pred_kpts_px[:, :, 1].std(dim=1)
gt_spread_x = gt_kpts_px[:, :, 0].std(dim=1)
gt_spread_y = gt_kpts_px[:, :, 1].std(dim=1)

print(f"Predicted keypoint spread (std across 68 kpts):")
print(f"  X: {pred_spread_x.mean():.1f} px (GT: {gt_spread_x.mean():.1f} px)")
print(f"  Y: {pred_spread_y.mean():.1f} px (GT: {gt_spread_y.mean():.1f} px)")
print(f"  Ratio (pred/GT): X={pred_spread_x.mean()/gt_spread_x.mean():.2f}, Y={pred_spread_y.mean()/gt_spread_y.mean():.2f}")

if pred_spread_x.mean() < gt_spread_x.mean() * 0.5 or pred_spread_y.mean() < gt_spread_y.mean() * 0.5:
    print("  >>> KEYPOINT COLLAPSE DETECTED: predictions cluster too tightly!")
    print("      Model is predicting ~same location for many keypoints.")
else:
    print("  Spread looks reasonable (no collapse).")

# Check per-sample: how many unique argmax positions?
with torch.no_grad():
    sample_batch = next(iter(diag_loader))
    sample_logits = unet_model(sample_batch['image'][:4].to(device))
    sample_flat = sample_logits.view(4, 68, -1)
    sample_argmax = sample_flat.argmax(dim=-1)  # (4, 68)

for s_idx in range(4):
    unique_positions = sample_argmax[s_idx].unique().numel()
    print(f"  Sample {s_idx}: {unique_positions}/68 unique argmax positions", end="")
    if unique_positions < 40:
        print(" << MANY KEYPOINTS SHARE SAME PEAK!")
    else:
        print(" (OK)")

# ---- 3. Visual: summed heatmaps (collapse = single blob) ----
print("\n" + "=" * 60)
print("3. HEATMAP VISUALIZATIONS")
print("=" * 60)

with torch.no_grad():
    vis_batch = next(iter(diag_loader))
    vis_images = vis_batch['image'][:4].to(device)
    vis_logits = unet_model(vis_images)
    vis_probs = torch.sigmoid(vis_logits).cpu()
    vis_gt = vis_batch['heatmaps'][:4]

fig, axes = plt.subplots(3, 4, figsize=(20, 15))
for col in range(4):
    img = vis_images[col, 0].cpu().numpy()
    summed_pred = vis_probs[col].sum(dim=0).numpy()
    summed_gt = vis_gt[col].sum(dim=0).numpy()

    axes[0, col].imshow(img, cmap='gray')
    axes[0, col].set_title(f'Input {col}')
    axes[0, col].axis('off')

    axes[1, col].imshow(summed_pred, cmap='hot')
    axes[1, col].set_title(f'Sum pred (max={summed_pred.max():.1f})')
    axes[1, col].axis('off')

    axes[2, col].imshow(summed_gt, cmap='hot')
    axes[2, col].set_title(f'Sum GT (max={summed_gt.max():.1f})')
    axes[2, col].axis('off')

axes[0, 0].set_ylabel('Image', fontsize=12)
axes[1, 0].set_ylabel('Pred (sum 68)', fontsize=12)
axes[2, 0].set_ylabel('GT (sum 68)', fontsize=12)
plt.suptitle('Collapse check: summed pred should spread like summed GT\n'
             '(If pred is a single blob = keypoint collapse)', fontsize=14)
plt.tight_layout()
plt.show()

# ---- 4. Per-keypoint heatmaps for a single sample ----
# Show 10 representative keypoints: pred vs GT side by side
fig, axes = plt.subplots(2, 10, figsize=(25, 5))
kpt_indices = [0, 8, 16, 27, 30, 33, 36, 42, 48, 57]  # jawline, brow, nose, eyes, mouth
for col_idx, kpt_idx in enumerate(kpt_indices):
    axes[0, col_idx].imshow(vis_gt[0, kpt_idx].numpy(), cmap='hot', vmin=0, vmax=1)
    axes[0, col_idx].set_title(f'GT kpt {kpt_idx}', fontsize=8)
    axes[0, col_idx].axis('off')
    axes[1, col_idx].imshow(vis_probs[0, kpt_idx].numpy(), cmap='hot', vmin=0, vmax=1)
    axes[1, col_idx].set_title(f'Pred kpt {kpt_idx}', fontsize=8)
    axes[1, col_idx].axis('off')
axes[0, 0].set_ylabel('GT', fontsize=11)
axes[1, 0].set_ylabel('Pred', fontsize=11)
plt.suptitle('Individual keypoint heatmaps (sample 0) — pred should match GT location', fontsize=13)
plt.tight_layout()
plt.show()

# ---- 5. Per-keypoint error bar chart ----
print("\n" + "=" * 60)
print("4. PER-KEYPOINT ERROR BREAKDOWN")
print("=" * 60)

mean_per_kpt = pixel_errors.mean(dim=0)  # (68,)
worst_5 = mean_per_kpt.argsort(descending=True)[:5]
best_5 = mean_per_kpt.argsort()[:5]

kpt_names = {0: 'jaw-R', 8: 'chin', 16: 'jaw-L', 17: 'brow-R-out', 21: 'brow-R-in',
             22: 'brow-L-in', 26: 'brow-L-out', 27: 'nose-top', 30: 'nose-tip',
             36: 'eye-R-out', 39: 'eye-R-in', 42: 'eye-L-in', 45: 'eye-L-out',
             48: 'mouth-R', 54: 'mouth-L', 51: 'lip-top', 57: 'lip-bottom', 62: 'lip-inner-top'}

print("Best 5 keypoints:")
for idx in best_5:
    name = kpt_names.get(idx.item(), '')
    print(f"  Kpt {idx.item():2d} {name:15s}: {mean_per_kpt[idx]:.2f} px")
print("Worst 5 keypoints:")
for idx in worst_5:
    name = kpt_names.get(idx.item(), '')
    print(f"  Kpt {idx.item():2d} {name:15s}: {mean_per_kpt[idx]:.2f} px")

plt.figure(figsize=(16, 4))
colors = ['#e74c3c' if e > mean_per_kpt.mean() + mean_per_kpt.std() else '#3498db'
          for e in mean_per_kpt]
plt.bar(range(68), mean_per_kpt.numpy(), color=colors)
plt.axhline(y=mean_per_kpt.mean(), color='k', linestyle='--', alpha=0.5, label=f'Mean={mean_per_kpt.mean():.1f}px')
plt.xlabel('Keypoint index')
plt.ylabel('Mean pixel error')
plt.title('Per-keypoint error on test set (red = >1 std above mean)')
plt.legend()
plt.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.show()

# ---- 6. Overlay predictions on face ----
print("\n" + "=" * 60)
print("5. PREDICTION OVERLAY")
print("=" * 60)

fig, axes = plt.subplots(2, 4, figsize=(20, 10))
vis_pred_kpts = heatmaps_to_keypoints(vis_logits.cpu()).numpy()
vis_gt_kpts = vis_batch['keypoints'][:4].numpy()

for i in range(4):
    img = vis_images[i, 0].cpu().numpy()
    pred = vis_pred_kpts[i] * 50 + 100
    gt = vis_gt_kpts[i] * 50 + 100

    axes[0, i].imshow(img, cmap='gray')
    axes[0, i].scatter(pred[:, 0], pred[:, 1], c='r', s=15, alpha=0.8, zorder=5)
    axes[0, i].scatter(gt[:, 0], gt[:, 1], c='g', s=15, alpha=0.5, marker='x', zorder=4)
    axes[0, i].set_title(f'Sample {i}: Pred(red) vs GT(green)')
    axes[0, i].axis('off')

    # Per-keypoint error on this sample
    errs = np.sqrt(((pred - gt) ** 2).sum(axis=-1))
    axes[1, i].imshow(img, cmap='gray')
    sc = axes[1, i].scatter(pred[:, 0], pred[:, 1], c=errs, cmap='RdYlGn_r',
                            s=25, vmin=0, vmax=30, zorder=5)
    # Draw lines from pred to GT for worst errors
    worst_10 = errs.argsort()[-10:]
    for w in worst_10:
        axes[1, i].plot([pred[w, 0], gt[w, 0]], [pred[w, 1], gt[w, 1]], 'w-', alpha=0.6, linewidth=0.8)
    axes[1, i].set_title(f'Error map (mean={errs.mean():.1f}px)')
    axes[1, i].axis('off')

plt.colorbar(sc, ax=axes[1, :], label='Pixel error', shrink=0.8)
plt.suptitle('Top: predictions overlay | Bottom: error heatmap (white lines = worst 10)', fontsize=14)
plt.tight_layout()
plt.show()

# ---- 7. Data augmentation sanity check ----
print("\n" + "=" * 60)
print("6. AUGMENTATION SANITY CHECK")
print("=" * 60)
print("Checking that augmented training samples have aligned heatmaps...")

train_transform_debug = transforms.Compose([
    Rescale(250), RandomCrop(224), RandomHorizontalFlip(), RandomRotate(15),
    ColorJitter(), NormalizeOriginal(), ToTensor()
])
debug_hm_dataset = FacialKeypointsHeatmapDataset(
    'data/training_frames_keypoints.csv', 'data/training',
    transform=train_transform_debug, output_size=64, sigma=2, image_size=224)

fig, axes = plt.subplots(2, 4, figsize=(20, 5))
for col in range(4):
    sample = debug_hm_dataset[col]
    img = sample['image'][0].numpy()  # (224, 224)
    hm_sum = sample['heatmaps'].sum(dim=0).numpy()  # (64, 64)
    kpts = sample['keypoints'].numpy() * 50 + 100  # (68, 2) in pixel space

    axes[0, col].imshow(img, cmap='gray')
    axes[0, col].scatter(kpts[:, 0], kpts[:, 1], c='lime', s=8, alpha=0.7)
    axes[0, col].set_title(f'Aug sample {col}')
    axes[0, col].axis('off')

    axes[1, col].imshow(hm_sum, cmap='hot')
    # Overlay expected keypoint positions on heatmap
    kpts_hm = kpts * (64 / 224)
    axes[1, col].scatter(kpts_hm[:, 0], kpts_hm[:, 1], c='lime', s=8, alpha=0.5)
    axes[1, col].set_title(f'Summed heatmap + kpts')
    axes[1, col].axis('off')

plt.suptitle('Augmentation check: green dots should align with heatmap peaks\n'
             '(Misalignment = augmentation bug)', fontsize=13)
plt.tight_layout()
plt.show()

# ---- DIAGNOSIS SUMMARY ----
print("\n" + "=" * 60)
print("DIAGNOSIS SUMMARY")
print("=" * 60)

mean_err = pixel_errors.mean().item()
spread_ratio_x = pred_spread_x.mean() / gt_spread_x.mean()
spread_ratio_y = pred_spread_y.mean() / gt_spread_y.mean()
peak_conf = pred_peaks.mean().item()

issues = []
if spread_ratio_x < 0.5 or spread_ratio_y < 0.5:
    issues.append(f"KEYPOINT COLLAPSE: spread ratio = ({spread_ratio_x:.2f}x, {spread_ratio_y:.2f}y)")
if peak_conf < 0.3:
    issues.append(f"LOW CONFIDENCE: peak = {peak_conf:.3f} (heatmaps are too diffuse)")
if peak_conf > 0.95:
    issues.append(f"OVERCONFIDENT: peak = {peak_conf:.3f} (may be memorizing noise)")
if mean_err > 15:
    issues.append(f"HIGH ERROR: {mean_err:.1f} px mean (poor generalization)")
if nme > 0.10:
    issues.append(f"HIGH NME: {nme*100:.2f}% (>10% is poor)")

if issues:
    print("Issues found:")
    for issue in issues:
        print(f"  - {issue}")
    print("\nRecommended fixes (try in order):")
    print("  1. SMALLER BATCH SIZE: 256 -> 32 or 64 (more gradient updates/epoch)")
    print("  2. STRONGER REGULARIZATION: increase dropout 0.3->0.5, weight_decay 1e-5->1e-4")
    print("  3. MSE LOSS: replace BCE with MSE (less prone to false-positive activation)")
    print("  4. LOWER LR: start at 3e-4 with cosine annealing to 1e-6")
    print("  5. FEWER CHANNELS: reduce encoder to 16-32-64-128 (less overfitting)")
    print("  6. LARGER SIGMA: sigma=2->3 in heatmap generation (easier targets)")
    print("  7. EARLY STOPPING: save best val_loss model, stop if no improvement for 15 epochs")
else:
    print(f"Model looks healthy! Mean error: {mean_err:.1f} px, NME: {nme*100