# Facial Keypoint Detection
This notebook implements facial keypoint detection (68 landmarks) using four approaches:
1. **CNN** — Custom convolutional network (regression)
2. **ResNet** — Transfer learning with ResNet18 (regression)
3. **DINO** — Transfer learning with DINO Vision Transformer (regression)
4. **U-Net** — Heatmap-based keypoint detection

---
## 1. Setup

In [None]:
import os
import random
import subprocess
import zipfile

import cv2
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from PIL import Image
from scipy.ndimage import gaussian_filter
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from tqdm import tqdm

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
# Every model and batch below is moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

### 1.1 Download Data

In [None]:
def download_data(data_dir='data'):
    """Download and unpack the train/test archive unless it is already present.

    The download is skipped when both ``<data_dir>/training`` and
    ``<data_dir>/test`` exist.  Otherwise the archive is fetched with the
    external ``wget`` program, extracted into ``data_dir`` and deleted.

    Args:
        data_dir: Directory that receives the extracted files.

    Raises:
        subprocess.CalledProcessError: If ``wget`` exits with a non-zero status.
        zipfile.BadZipFile: If the downloaded file is not a valid zip archive.
    """
    required = [os.path.join(data_dir, name) for name in ('training', 'test')]
    if all(os.path.exists(folder) for folder in required):
        print("Data already exists.")
        return
    print("Data not found. Downloading...")
    os.makedirs(data_dir, exist_ok=True)
    zip_path = os.path.join(data_dir, 'train-test-data.zip')
    try:
        subprocess.run([
            'wget', '-q', '-O', zip_path,
            'https://s3.amazonaws.com/video.udacity-data.com/topher/2018/May/5aea1b91_train-test-data/train-test-data.zip'
        ], check=True)
        with zipfile.ZipFile(zip_path, 'r') as z:
            z.extractall(data_dir)
    finally:
        # Remove the archive even when the download or extraction failed, so
        # a truncated or corrupt zip is never left behind in the data folder.
        if os.path.exists(zip_path):
            os.remove(zip_path)
    print("Data downloaded and extracted.")

# Make sure the dataset is available under ./data before any loader is built
# (download_data returns immediately when the folders already exist).
download_data()

### 1.2 Transforms

In [None]:
class Rescale:
    """Resize the image of a sample and scale its keypoints accordingly.

    Args:
        output_size: An int resizes the shorter image side to that length and
            keeps the aspect ratio; a tuple is used directly as
            ``(height, width)``.
    """
    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        """Return a new sample dict with the resized image and keypoints."""
        picture = sample['image']
        points = sample['keypoints']
        src_h, src_w = picture.shape[:2]
        target = self.output_size

        if not isinstance(target, int):
            dst_h, dst_w = target
        elif src_h > src_w:
            # Portrait: width becomes `target`, height grows proportionally.
            dst_h, dst_w = target * src_h / src_w, target
        else:
            # Landscape or square: height becomes `target`.
            dst_h, dst_w = target, target * src_w / src_h
        dst_h = int(dst_h)
        dst_w = int(dst_w)

        # cv2.resize takes the destination size as (width, height).
        resized = cv2.resize(picture, (dst_w, dst_h))
        scale_xy = [dst_w / src_w, dst_h / src_h]
        return {'image': resized, 'keypoints': points * scale_xy}


class RandomCrop:
    """Cut a random window out of the image and shift the keypoints to match.

    Args:
        output_size: Crop size. An int gives a square crop; a tuple is read as
            ``(height, width)``.
    """
    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            self.output_size = output_size

    def __call__(self, sample):
        """Return a new sample dict holding the cropped image and keypoints.

        Raises:
            ValueError: If the requested crop is larger than the image.
        """
        image, key_pts = sample['image'], sample['keypoints']
        h, w = image.shape[:2]
        new_h, new_w = self.output_size
        if new_h > h or new_w > w:
            raise ValueError(
                f"Crop size {(new_h, new_w)} is larger than image size {(h, w)}")
        # randint's upper bound is exclusive: the +1 makes the last valid
        # offset reachable and lets an image that already has the crop size
        # pass through (offset 0) instead of raising.
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)
        image = image[top:top + new_h, left:left + new_w]
        # Keypoints are stored as (x, y), hence the (left, top) shift.
        key_pts = key_pts - [left, top]
        return {'image': image, 'keypoints': key_pts}


class NormalizeOriginal:
    """Grayscale the image, divide it by 255 and map keypoints to (pts - 100) / 50."""
    def __call__(self, sample):
        """Return a new sample with a float32 grayscale image and scaled keypoints."""
        source_image, source_pts = sample['image'], sample['keypoints']

        # Work on copies so the incoming sample is left untouched.
        gray = cv2.cvtColor(np.copy(source_image), cv2.COLOR_RGB2GRAY)
        gray = (gray / 255.0).astype(np.float32)

        scaled_pts = np.copy(source_pts)
        scaled_pts = (scaled_pts - 100) / 50.0
        return {'image': gray, 'keypoints': scaled_pts}


class ToTensor:
    """Convert the ndarrays of a sample into torch tensors."""
    def __call__(self, sample):
        """Return ``{'image': (C, H, W) tensor, 'keypoints': tensor}``.

        A 2-D (grayscale) image gets a leading channel axis of size one.
        """
        image, key_pts = sample['image'], sample['keypoints']
        if image.ndim == 2:
            image = image[:, :, np.newaxis]
        # H x W x C -> C x H x W.  torch.from_numpy rejects arrays with
        # negative strides (e.g. the view returned by np.fliplr), so make
        # both arrays contiguous first.
        image = np.ascontiguousarray(image.transpose((2, 0, 1)))
        key_pts = np.ascontiguousarray(key_pts)
        return {'image': torch.from_numpy(image), 'keypoints': torch.from_numpy(key_pts)}


class RandomHorizontalFlip:
    """Mirror image and landmarks left-to-right for about half of the samples.

    Besides reflecting the x coordinate, the landmark rows are re-ordered so
    that every index keeps its anatomical meaning after the flip (for example
    the point that was the left eye corner is stored under the right eye
    corner's index).
    """
    # Landmark indices that trade places under a mirror.  Indices that are not
    # listed (8, 27-30, 33, 51, 57, 62, 66) keep their position.
    _MIRROR_PAIRS = (
        # jawline
        (0, 16), (1, 15), (2, 14), (3, 13), (4, 12), (5, 11), (6, 10), (7, 9),
        # eyebrows
        (17, 26), (18, 25), (19, 24), (20, 23), (21, 22),
        # nose tip
        (31, 35), (32, 34),
        # eyes
        (36, 45), (37, 44), (38, 43), (39, 42), (40, 47), (41, 46),
        # lips
        (48, 54), (49, 53), (50, 52), (55, 59), (56, 58),
        (60, 64), (61, 63), (65, 67),
    )

    def __call__(self, sample):
        """Return a new sample, flipped or an unchanged copy."""
        picture = np.copy(sample['image'])
        points = np.copy(sample['keypoints'])

        # random.choice yields 0 or 1; only the 0 draw triggers the flip.
        if random.choice([0, 1]) > 0.5:
            return {'image': picture, 'keypoints': np.copy(points)}

        picture = np.fliplr(picture)
        # Reflect x about the image width.
        points[:, 0] = -points[:, 0] + picture.shape[1]

        # Build the row permutation once and apply it with fancy indexing.
        order = np.arange(len(points))
        for left_idx, right_idx in self._MIRROR_PAIRS:
            order[left_idx], order[right_idx] = right_idx, left_idx
        return {'image': picture, 'keypoints': points[order]}


class RandomRotate:
    """Rotate image and keypoints by either ``+rotation`` or ``-rotation`` degrees.

    Args:
        rotation: Magnitude of the rotation in degrees; the sign is picked at
            random for every sample.
    """
    def __init__(self, rotation=30):
        self.rotation = rotation

    def __call__(self, sample):
        """Return a new sample rotated about the image centre.

        The output image keeps the input size, and the keypoints come back as a
        float64 array of shape ``(num_keypoints, 2)``.
        """
        image, key_pts = sample['image'], sample['keypoints']
        rows, cols = image.shape[:2]
        angle = random.choice([-self.rotation, self.rotation])
        # OpenCV expects the centre as (x, y), i.e. (column, row).  Passing
        # (rows / 2, cols / 2) is only correct for square images.
        M = cv2.getRotationMatrix2D((cols / 2, rows / 2), angle, 1)
        rotated = cv2.warpAffine(np.copy(image), M, (cols, rows))

        # Apply the same 2x3 affine matrix to all points at once using
        # homogeneous coordinates; this works for any number of keypoints.
        points = np.asarray(key_pts, dtype=np.float64).reshape(-1, 2)
        homogeneous = np.hstack([points, np.ones((points.shape[0], 1))])
        return {'image': rotated, 'keypoints': homogeneous @ M.T}


class ColorJitter:
    """Randomly perturb brightness, contrast and saturation of the image."""
    def __call__(self, sample):
        """Return a new sample with a jittered uint8 image and copied keypoints."""
        source_image, source_pts = sample['image'], sample['keypoints']
        jitter = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4)

        pixels = np.copy(source_image)
        if pixels.dtype != np.uint8:
            # Non-uint8 input is scaled by 255 before the cast to uint8,
            # which PIL needs.
            pixels = (pixels * 255).astype(np.uint8)

        pil_image = Image.fromarray(pixels)
        jittered = np.array(jitter(pil_image))
        return {'image': jittered, 'keypoints': np.copy(source_pts)}

### 1.3 Datasets

In [None]:
class FacialKeypointsDataset(Dataset):
    """Image/keypoint pairs for the regression models.

    Args:
        csv_file: CSV whose first column holds the image file name (used as
            the frame index) and whose remaining columns hold the flattened
            keypoint coordinates.
        root_dir: Directory containing the image files.
        transform: Optional callable applied to every sample dict.
    """
    def __init__(self, csv_file, root_dir, transform=None):
        self.key_pts_frame = pd.read_csv(csv_file, index_col=0)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.key_pts_frame)

    def __getitem__(self, idx):
        """Load sample ``idx`` as ``{'image': ..., 'keypoints': (N, 2) array}``."""
        file_name = self.key_pts_frame.index[idx]
        pixels = mpimg.imread(os.path.join(self.root_dir, file_name))
        if pixels.shape[2] == 4:
            # Drop the alpha channel of RGBA images.
            pixels = pixels[:, :, 0:3]

        flat_coords = self.key_pts_frame.iloc[idx, :].values.astype('float')
        item = {'image': pixels, 'keypoints': flat_coords.reshape(-1, 2)}
        if not self.transform:
            return item
        return self.transform(item)


class FacialKeypointsHeatmapDataset(Dataset):
    """Image/keypoint pairs plus one Gaussian target heatmap per keypoint (U-Net).

    Args:
        csv_file: CSV whose first column holds the image file name (used as
            the frame index) and whose remaining columns hold the flattened
            keypoint coordinates.
        root_dir: Directory containing the image files.
        transform: Optional callable applied to every sample dict.
        output_size: Side length of the square heatmaps.
        sigma: Standard deviation of the Gaussian blob, in heatmap pixels.
        image_size: Side length of the transformed image that the keypoints
            refer to.
    """
    def __init__(self, csv_file, root_dir, transform=None, output_size=64, sigma=1, image_size=224):
        self.key_pts_frame = pd.read_csv(csv_file, index_col=0)
        self.root_dir = root_dir
        self.transform = transform
        self.output_size = output_size
        self.sigma = sigma
        self.image_size = image_size

    def __len__(self):
        return len(self.key_pts_frame)

    def __getitem__(self, idx):
        """Load sample ``idx`` and attach its ``'heatmaps'`` tensor."""
        file_name = self.key_pts_frame.index[idx]
        pixels = mpimg.imread(os.path.join(self.root_dir, file_name))
        if pixels.shape[2] == 4:
            # Drop the alpha channel of RGBA images.
            pixels = pixels[:, :, 0:3]

        flat_coords = self.key_pts_frame.iloc[idx, :].values.astype('float')
        item = {'image': pixels, 'keypoints': flat_coords.reshape(-1, 2)}
        if self.transform:
            item = self.transform(item)
        item['heatmaps'] = self.generate_heatmaps(item['keypoints'])
        return item

    def generate_heatmaps(self, keypoints):
        """Render normalised keypoints as a ``(N, output_size, output_size)`` tensor.

        Keypoints are expected in the ``(pts - 100) / 50`` encoding.  They are
        mapped back to pixels, scaled to the heatmap grid and clamped to its
        border.  Rows containing NaN produce an all-zero map.
        """
        if isinstance(keypoints, torch.Tensor):
            keypoints = keypoints.numpy()

        size = self.output_size
        maps = np.zeros((keypoints.shape[0], size, size), dtype=np.float32)
        # Undo the normalisation, then go from image pixels to heatmap cells.
        grid_pts = (keypoints * 50 + 100) * (size / self.image_size)

        for idx, (px, py) in enumerate(grid_pts):
            if np.isnan(px) or np.isnan(py):
                continue
            col = max(0, min(size - 1, int(px)))
            row = max(0, min(size - 1, int(py)))
            # Blur a single hot pixel and rescale so the peak equals one.
            impulse = np.zeros((size, size), dtype=np.float32)
            impulse[row, col] = 1.0
            blurred = gaussian_filter(impulse, sigma=self.sigma)
            if blurred.max() > 0:
                blurred = blurred / blurred.max()
            maps[idx] = blurred
        return torch.from_numpy(maps)

### 1.4 Data Loaders

In [None]:
def load_regression_data(batch_size=64):
    """Build the train/test loaders for the regression models (CNN, ResNet, DINO).

    Both datasets share one pipeline: rescale the shorter side to 250, take a
    random 224x224 crop, convert to grayscale/normalise, convert to tensors.

    Args:
        batch_size: Batch size used by both loaders.

    Returns:
        ``(train_loader, test_loader)``.
    """
    data_transform = transforms.Compose([
        Rescale(250), RandomCrop(224), NormalizeOriginal(), ToTensor()
    ])
    train_dataset = FacialKeypointsDataset('data/training_frames_keypoints.csv', 'data/training', data_transform)
    test_dataset = FacialKeypointsDataset('data/test_frames_keypoints.csv', 'data/test', data_transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    # Evaluation must cover every test image in a stable order.  Shuffling adds
    # nothing here, and drop_last=True would skip the final partial batch (or
    # yield no batches at all when the test set is smaller than batch_size).
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
    return train_loader, test_loader


def load_heatmap_data(batch_size=256, heatmap_size=64):
    """Build the train/test loaders for the heatmap model (U-Net).

    Training samples are augmented (random crop, flip, +/-15 degree rotation,
    colour jitter).  Test samples are only resized to 224x224.

    Args:
        batch_size: Batch size used by both loaders.
        heatmap_size: Side length of the generated target heatmaps.

    Returns:
        ``(train_loader, test_loader)``.
    """
    image_size = 224  # network input size; also used to scale the heatmaps
    train_transform = transforms.Compose([
        Rescale(250), RandomCrop(image_size), RandomHorizontalFlip(), RandomRotate(15),
        ColorJitter(), NormalizeOriginal(), ToTensor()
    ])
    test_transform = transforms.Compose([
        Rescale((image_size, image_size)), NormalizeOriginal(), ToTensor()
    ])
    train_dataset = FacialKeypointsHeatmapDataset(
        'data/training_frames_keypoints.csv', 'data/training',
        transform=train_transform, output_size=heatmap_size, sigma=2, image_size=image_size)
    test_dataset = FacialKeypointsHeatmapDataset(
        'data/test_frames_keypoints.csv', 'data/test',
        transform=test_transform, output_size=heatmap_size, sigma=2, image_size=image_size)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    # Keep the final partial batch so validation covers every test image and
    # still yields batches when the test set is smaller than batch_size.
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
    return train_loader, test_loader

### 1.5 Training & Evaluation Utilities

In [None]:
def save_checkpoint(model, optimizer, epoch, step, model_name, path='checkpoints/last_checkpoint.pth'):
    """Write model and optimizer state plus training progress to ``path``.

    Args:
        model: Module whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        epoch: Epoch index recorded in the checkpoint.
        step: Global step recorded in the checkpoint.
        model_name: Free-form model identifier recorded in the checkpoint.
        path: Destination file; missing parent directories are created.
    """
    directory = os.path.dirname(path)
    # A bare file name has an empty dirname, and os.makedirs('') would raise.
    if directory:
        os.makedirs(directory, exist_ok=True)
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch, 'step': step, 'model_name': model_name,
    }, path)
    print(f"Checkpoint saved to {path}")


def evaluate_regression(model, test_loader, criterion, device):
    """Return the mean per-batch loss of a regression model on ``test_loader``.

    Args:
        model: Model exposing ``compute_loss(preds, labels, criterion)``; it is
            left in eval mode.
        test_loader: Iterable of batches with ``'image'`` and ``'keypoints'``
            entries that also supports ``len()``.
        criterion: Loss identifier forwarded to ``model.compute_loss``.
        device: Device the batches are moved to.

    Raises:
        ZeroDivisionError: If ``test_loader`` contains no batches.
    """
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for batch in test_loader:
            images = batch['image'].to(device)
            # Flatten to (batch, 2 * num_keypoints) and cast to float32 so the
            # targets have the same dtype as in the training loop.  reshape
            # also copes with non-contiguous tensors, unlike view.
            keypoints = batch['keypoints'].reshape(images.size(0), -1).float().to(device)
            outputs = model(images)
            loss = model.compute_loss(outputs, keypoints, criterion)
            total_loss += loss.item()
    return total_loss / len(test_loader)


def evaluate_heatmap(model, test_loader, device):
    """Return the mean per-batch weighted BCE loss of a heatmap model.

    The loss is ``BCEWithLogitsLoss`` with a positive-class weight of 10, so
    the model is expected to output raw logits.  The model is left in eval
    mode.
    """
    model.eval()
    pos_weight = torch.tensor([10.0]).to(device)
    loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    batch_losses = []
    with torch.no_grad():
        for batch in test_loader:
            inputs = batch['image'].to(device)
            targets = batch['heatmaps'].to(device)
            batch_losses.append(loss_fn(model(inputs), targets).item())
    return sum(batch_losses) / len(test_loader)


def heatmaps_to_keypoints(heatmaps, heatmap_size=64, image_size=224):
    """Decode ``(B, K, H, W)`` heatmaps into normalised ``(B, K, 2)`` keypoints.

    For every map the arg-max cell is located, scaled from heatmap cells to
    image pixels by ``image_size / heatmap_size`` and encoded as
    ``(pixel - 100) / 50``.  The last axis of the result is ``(x, y)``.
    """
    n_batch, n_points, _, width = heatmaps.shape
    # Arg-max over the flattened H*W grid, then split into row and column.
    peak = heatmaps.view(n_batch, n_points, -1).argmax(dim=2)
    scale = image_size / heatmap_size
    cols = (peak % width).float() * scale
    rows = (peak // width).float() * scale
    return torch.stack([(cols - 100) / 50.0, (rows - 100) / 50.0], dim=2)


def train_regression(model, train_loader, test_loader, optimizer, criterion, device,
                     model_name='model', num_epochs=10, eval_interval=10,
                     log_interval=5, save_interval=30):
    """Training loop for regression-based models (CNN, ResNet, DINO).

    Args:
        model: Model exposing ``compute_loss(preds, labels, criterion)``.
        train_loader: Iterable of batches with ``'image'`` and ``'keypoints'``.
        test_loader: Loader passed to ``evaluate_regression`` for validation.
        optimizer: Optimizer updating the model parameters.
        criterion: Loss identifier forwarded to ``model.compute_loss``.
        device: Device the batches are moved to.
        model_name: Used in the checkpoint directory ``checkpoints-<name>/``.
        num_epochs: Number of passes over ``train_loader``.
        eval_interval: Validate and log ``val_loss`` every this many steps.
        log_interval: Log the running mean training loss every this many steps.
        save_interval: Save a step checkpoint every this many steps.

    Returns:
        ``(epoch, step)``: index of the last epoch that ran (``-1`` when
        ``num_epochs`` is 0) and the total number of optimisation steps.
    """
    step = 0
    running_loss = 0
    # Bound up front so the return below works even if the loop never runs.
    epoch = -1

    for epoch in range(num_epochs):
        model.train()
        for batch in train_loader:
            step += 1
            image = batch['image'].to(device)
            # Flatten (B, K, 2) targets to (B, 2K) for any keypoint count K.
            keypoints = batch['keypoints'].reshape(image.size(0), -1).float().to(device)

            preds = model(image)
            loss = model.compute_loss(preds, keypoints, criterion)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            if step % log_interval == 0:
                avg_loss = running_loss / log_interval
                wandb.log({'train_loss': avg_loss, 'step': step})
                print(f"Step {step}: Train Loss = {avg_loss:.4f}")
                running_loss = 0

            if step % eval_interval == 0:
                val_loss = evaluate_regression(model, test_loader, criterion, device)
                wandb.log({'val_loss': val_loss, 'step': step})
                # evaluate_regression leaves the model in eval mode.
                model.train()

            if step % save_interval == 0:
                save_checkpoint(model, optimizer, epoch, step, model_name,
                                path=f'checkpoints-{model_name}/step_{step}.pth')

    print(f"Training complete for {model_name}!")
    return epoch, step


def train_heatmap(model, train_loader, test_loader, optimizer, device,
                  model_name='unet', num_epochs=10, eval_interval=10,
                  log_interval=5, save_interval=30):
    """Training loop for heatmap-based models (U-Net).

    Trains with ``BCEWithLogitsLoss`` (positive weight 10) and halves the
    learning rate via ``ReduceLROnPlateau`` when the end-of-epoch validation
    loss stops improving.  The best model seen at the periodic validations is
    saved as ``checkpoints-<model_name>/best.pth``.

    Args:
        model: Network producing heatmap logits.
        train_loader: Iterable of batches with ``'image'`` and ``'heatmaps'``.
        test_loader: Loader passed to ``evaluate_heatmap`` for validation.
        optimizer: Optimizer updating the model parameters.
        device: Device the batches are moved to.
        model_name: Used in the checkpoint directory ``checkpoints-<name>/``.
        num_epochs: Number of passes over ``train_loader``.
        eval_interval: Validate every this many steps.
        log_interval: Log the running mean training loss every this many steps.
        save_interval: Save a step checkpoint every this many steps.

    Returns:
        ``(epoch, step)``: index of the last epoch that ran (``-1`` when
        ``num_epochs`` is 0) and the total number of optimisation steps.
    """
    step = 0
    running_loss = 0
    best_val_loss = float('inf')
    # Bound up front so the return below works even if the loop never runs.
    epoch = -1
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([10.0]).to(device))

    for epoch in range(num_epochs):
        model.train()
        for batch in train_loader:
            step += 1
            images = batch['image'].to(device)
            heatmaps_gt = batch['heatmaps'].to(device)

            logits = model(images)
            loss = criterion(logits, heatmaps_gt)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            if step % log_interval == 0:
                avg_loss = running_loss / log_interval
                wandb.log({'train_loss': avg_loss, 'step': step, 'epoch': epoch,
                           'lr': optimizer.param_groups[0]['lr']})
                print(f"Epoch {epoch}, Step {step}: Train Loss = {avg_loss:.6f} (lr={optimizer.param_groups[0]['lr']:.6f})")
                running_loss = 0

            if step % eval_interval == 0:
                val_loss = evaluate_heatmap(model, test_loader, device)
                wandb.log({'val_loss': val_loss, 'step': step, 'epoch': epoch})
                print(f"Epoch {epoch}, Step {step}: Val Loss = {val_loss:.6f}")
                # evaluate_heatmap leaves the model in eval mode.
                model.train()
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    save_checkpoint(model, optimizer, epoch, step, model_name,
                                    path=f'checkpoints-{model_name}/best.pth')
                    print(f"  -> New best val loss: {val_loss:.6f}")

            if step % save_interval == 0:
                save_checkpoint(model, optimizer, epoch, step, model_name,
                                path=f'checkpoints-{model_name}/step_{step}.pth')

        # The scheduler is driven by one validation pass per epoch.
        epoch_val_loss = evaluate_heatmap(model, test_loader, device)
        scheduler.step(epoch_val_loss)
        print(f"--- Epoch {epoch} done. Val Loss = {epoch_val_loss:.6f}, LR = {optimizer.param_groups[0]['lr']:.6f} ---")
        model.train()

    print("Heatmap training complete!")
    return epoch, step


def visualize_keypoints(test_loader, model, device='cuda', num_samples=5):
    """Plot predicted vs. ground-truth keypoints for the first item of each batch.

    Works for both regression models (2-D output, reshaped to ``(B, K, 2)``)
    and heatmap models (4-D output, decoded with ``heatmaps_to_keypoints``).

    Args:
        test_loader: Loader yielding batches with ``'image'`` and ``'keypoints'``.
        model: Trained model; it is switched to eval mode.
        device: Device the images are moved to for inference.
        num_samples: Number of batches to draw a figure for.
    """
    if num_samples <= 0:
        return
    model.eval()
    for i, data in enumerate(test_loader):
        images = data['image'].to(device)
        with torch.no_grad():
            predictions = model(images)
        if predictions.dim() == 4:
            # Read the sizes from the tensors so heatmaps and inputs of any
            # (square) resolution decode to the right pixel coordinates.
            predictions = heatmaps_to_keypoints(
                predictions, heatmap_size=predictions.size(-1), image_size=images.size(-1))
        else:
            predictions = predictions.reshape(images.size(0), -1, 2)
        pred_kpts = predictions[0].cpu().numpy()
        gt_kpts = data['keypoints'][0].numpy()
        # Undo the (pts - 100) / 50 normalisation to get pixel coordinates.
        pred_kpts_denorm = (pred_kpts * 50) + 100
        gt_kpts_denorm = (gt_kpts * 50) + 100
        plt.figure(figsize=(10, 6))
        plt.imshow(data['image'][0].numpy().transpose(1, 2, 0), cmap='gray')
        plt.scatter(pred_kpts_denorm[:, 0], pred_kpts_denorm[:, 1], c='r', s=20, label='Predicted')
        plt.scatter(gt_kpts_denorm[:, 0], gt_kpts_denorm[:, 1], c='g', s=20, label='Ground Truth')
        plt.legend()
        plt.title(f'Test Sample {i}')
        plt.show()
        if i + 1 >= num_samples:
            break

---
## 2. CNN

### 2.1 Model Definition

In [None]:
class CNN(nn.Module):
    """Four conv/batch-norm/pool stages followed by a two-layer regression head.

    The first fully connected layer expects a 256 x 14 x 14 feature map, which
    is what a 1 x 224 x 224 input produces after four 2x poolings.

    Args:
        num_keypoints: Number of (x, y) keypoints; the output has twice as
            many values.
        dropout: Dropout probability applied after the first dense layer.
    """
    def __init__(self, num_keypoints=68, dropout=0.5):
        super().__init__()
        # Attribute names (conv1/bn1 ... conv4/bn4) and creation order are
        # kept stable so existing checkpoints and seeded runs still match.
        widths = (1, 32, 64, 128, 256)
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, f'conv{stage}', nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
            setattr(self, f'bn{stage}', nn.BatchNorm2d(c_out))
        self.fc1 = nn.Linear(256 * 14 * 14, 512)
        self.fc2 = nn.Linear(512, num_keypoints * 2)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map a ``(B, 1, 224, 224)`` batch to ``(B, 2 * num_keypoints)``."""
        for stage in range(1, 5):
            conv = getattr(self, f'conv{stage}')
            norm = getattr(self, f'bn{stage}')
            x = F.max_pool2d(F.relu(norm(conv(x))), 2)
        flat = x.view(x.size(0), -1)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(self.dropout(hidden))

    def compute_loss(self, preds, labels, criterion):
        """Apply the loss named by ``criterion`` ('mse' or 'Smoothl1Loss').

        Raises:
            ValueError: If ``criterion`` is not one of the two known names.
        """
        known = (('mse', nn.MSELoss), ('Smoothl1Loss', nn.SmoothL1Loss))
        for name, loss_cls in known:
            if criterion == name:
                return loss_cls()(preds, labels)
        raise ValueError(f'Unknown criterion: {criterion}')

### 2.2 Train CNN

In [None]:
# --- Configuration ---
CNN_EPOCHS = 50
# Loss name understood by CNN.compute_loss.
CNN_CRITERION = 'Smoothl1Loss'  # or 'mse'
CNN_LR = 1e-3
CNN_BATCH_SIZE = 64

# --- Initialize ---
# Build the model on the selected device, the shared regression loaders and
# an Adam optimizer over all model parameters.
cnn_model = CNN().to(device)
cnn_train_loader, cnn_test_loader = load_regression_data(batch_size=CNN_BATCH_SIZE)
cnn_optimizer = torch.optim.Adam(cnn_model.parameters(), lr=CNN_LR)

print(f"CNN parameters: {sum(p.numel() for p in cnn_model.parameters()):,}")

# --- Train ---
# Open a W&B run named 'cnn-train' and record the hyper-parameters.
wandb.init(project='facial-keypoints', name='cnn-train', reinit=True)
wandb.config.update({'model': 'cnn', 'criterion': CNN_CRITERION, 'lr': CNN_LR})

# train_regression logs, validates and checkpoints at its default intervals
# and returns the last epoch index and the global step count.
cnn_epoch, cnn_step = train_regression(
    cnn_model, cnn_train_loader, cnn_test_loader, cnn_optimizer,
    CNN_CRITERION, device, model_name='cnn', num_epochs=CNN_EPOCHS
)

# Store the final weights separately from the periodic step checkpoints.
save_checkpoint(cnn_model, cnn_optimizer, cnn_epoch, cnn_step, 'cnn',
                path='checkpoints-cnn/last_checkpoint.pth')
wandb.finish()

### 2.3 Visualize CNN

In [None]:
# Plot CNN predictions (red) against ground truth (green) for a few test batches.
visualize_keypoints(cnn_test_loader, cnn_model, device)

---
## 3. Transfer Learning

### 3.1 ResNet Model

In [None]:
class ResNetKeypointDetector(nn.Module):
    """ResNet backbone adapted to grayscale input, with an MLP regression head.

    Args:
        num_keypoints: Number of (x, y) keypoints; the output has twice as
            many values.
        backbone: ``'resnet18'`` or ``'resnet34'``.
        pretrained: Load pretrained torchvision weights for the backbone.
        freeze_backbone: Disable gradients for all backbone parameters so only
            the regression head is trained.

    Raises:
        ValueError: If ``backbone`` is not a supported architecture name.
    """
    def __init__(self, num_keypoints=68, backbone='resnet18', pretrained=True, freeze_backbone=True):
        super().__init__()
        # Fail early with a clear message; without this check an unknown name
        # would leave self.backbone unset and fail later with AttributeError.
        if backbone not in ('resnet18', 'resnet34'):
            raise ValueError(f'Unsupported backbone: {backbone}')
        # NOTE(review): newer torchvision releases prefer the `weights=`
        # argument over `pretrained=` - confirm against the installed version.
        self.backbone = getattr(models, backbone)(pretrained=pretrained)

        # Adapt first conv for grayscale
        original_conv = self.backbone.conv1
        self.backbone.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        if pretrained:
            with torch.no_grad():
                # Sum the RGB filters: on a gray image replicated over three
                # channels this gives the same response as the original
                # layer, whereas copying only the red-channel weights would
                # discard two thirds of each pretrained filter.
                self.backbone.conv1.weight.copy_(original_conv.weight.sum(dim=1, keepdim=True))

        # Replace the classifier with a pass-through and regress from the
        # pooled features instead.
        self.backbone_out_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()

        self.regression_head = nn.Sequential(
            nn.Linear(self.backbone_out_features, 1024), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(1024, 512), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(512, num_keypoints * 2),
        )

        self.freeze_backbone = freeze_backbone
        if freeze_backbone:
            self._freeze_backbone()

    def _freeze_backbone(self):
        """Disable gradients for every backbone parameter."""
        # NOTE(review): this only toggles requires_grad; batch-norm running
        # statistics presumably still update while the model is in train mode.
        for param in self.backbone.parameters():
            param.requires_grad = False

    def _unfreeze_backbone(self):
        """Re-enable gradients for every backbone parameter."""
        for param in self.backbone.parameters():
            param.requires_grad = True

    def forward(self, x):
        """Map a ``(B, 1, H, W)`` batch to ``(B, 2 * num_keypoints)``."""
        features = self.backbone(x)
        return self.regression_head(features)

    def compute_loss(self, preds, labels, criterion):
        """Apply the loss named by ``criterion`` ('mse' or 'Smoothl1Loss').

        Raises:
            ValueError: If ``criterion`` is not one of the two known names.
        """
        if criterion == 'mse':
            loss_fn = nn.MSELoss()
        elif criterion == 'Smoothl1Loss':
            loss_fn = nn.SmoothL1Loss()
        else:
            raise ValueError(f'Unknown criterion: {criterion}')
        return loss_fn(preds, labels)

### 3.2 Train ResNet (Stage 1: Frozen Backbone)

In [None]:
# --- Configuration ---
RESNET_EPOCHS_FROZEN = 50
# Loss name understood by ResNetKeypointDetector.compute_loss.
RESNET_CRITERION = 'Smoothl1Loss'
RESNET_LR_FROZEN = 5e-4
RESNET_BATCH_SIZE = 64

# --- Initialize ---
# The backbone is frozen in the constructor, so only parameters that still
# require gradients (the regression head) are handed to Adam.
resnet_model = ResNetKeypointDetector(backbone='resnet18', pretrained=True, freeze_backbone=True).to(device)
resnet_train_loader, resnet_test_loader = load_regression_data(batch_size=RESNET_BATCH_SIZE)
resnet_optimizer = torch.optim.Adam(
    [p for p in resnet_model.parameters() if p.requires_grad], lr=RESNET_LR_FROZEN
)

print(f"ResNet parameters (total): {sum(p.numel() for p in resnet_model.parameters()):,}")
print(f"ResNet parameters (trainable): {sum(p.numel() for p in resnet_model.parameters() if p.requires_grad):,}")

# --- Train Stage 1: Frozen backbone ---
# Open a W&B run named 'resnet-frozen' and record the setup.
wandb.init(project='facial-keypoints', name='resnet-frozen', reinit=True)
wandb.config.update({'model': 'resnet', 'criterion': RESNET_CRITERION, 'freeze': True})

resnet_epoch, resnet_step = train_regression(
    resnet_model, resnet_train_loader, resnet_test_loader, resnet_optimizer,
    RESNET_CRITERION, device, model_name='resnet', num_epochs=RESNET_EPOCHS_FROZEN
)

# Keep the end-of-stage weights under their own file name.
save_checkpoint(resnet_model, resnet_optimizer, resnet_epoch, resnet_step, 'resnet',
                path='checkpoints-resnet/frozen_checkpoint.pth')
wandb.finish()
print("Stage 1 (frozen backbone) complete!")

### 3.3 Train ResNet (Stage 2: Fine-tune Full Model)

In [None]:
# --- Configuration ---
RESNET_EPOCHS_FINETUNE = 50
RESNET_LR_FINETUNE = 1e-4

# --- Unfreeze and set up new optimizer ---
# Continues from the stage-1 model held in memory; a fresh Adam instance is
# needed because the backbone parameters were not part of the first optimizer.
resnet_model._unfreeze_backbone()
resnet_optimizer = torch.optim.Adam(resnet_model.parameters(), lr=RESNET_LR_FINETUNE)

print(f"ResNet parameters (trainable after unfreeze): {sum(p.numel() for p in resnet_model.parameters() if p.requires_grad):,}")

# --- Train Stage 2: Fine-tune ---
wandb.init(project='facial-keypoints', name='resnet-finetune', reinit=True)
wandb.config.update({'model': 'resnet', 'criterion': RESNET_CRITERION, 'freeze': False})

# NOTE(review): train_regression restarts its step counter at 0 and writes
# step_<n>.pth under checkpoints-resnet/, so the periodic checkpoints saved
# during stage 1 are overwritten here.
resnet_epoch, resnet_step = train_regression(
    resnet_model, resnet_train_loader, resnet_test_loader, resnet_optimizer,
    RESNET_CRITERION, device, model_name='resnet', num_epochs=RESNET_EPOCHS_FINETUNE
)

save_checkpoint(resnet_model, resnet_optimizer, resnet_epoch, resnet_step, 'resnet',
                path='checkpoints-resnet/last_checkpoint.pth')
wandb.finish()
print("Stage 2 (fine-tune) complete!")

### 3.4 Visualize ResNet

In [None]:
# Plot ResNet predictions (red) against ground truth (green) for a few test batches.
visualize_keypoints(resnet_test_loader, resnet_model, device)

### 3.5 DINO Model

In [None]:
class DINOKeypointDetector(nn.Module):
    """timm vision-transformer backbone with an MLP keypoint regression head.

    Args:
        num_keypoints: Number of (x, y) keypoints; the output has twice as
            many values.
        model_name: timm model identifier for the backbone.
        pretrained: Load pretrained weights for the backbone.
        freeze_backbone: Disable gradients for all backbone parameters so only
            the regression head is trained.
    """
    def __init__(self, num_keypoints=68, model_name='vit_base_patch16_224.dino',
                 pretrained=True, freeze_backbone=True):
        super().__init__()
        vit = timm.create_model(model_name, pretrained=pretrained)
        self.backbone_out_features = vit.embed_dim
        # The backbone's own head is replaced by a pass-through.
        vit.head = nn.Identity()
        self.backbone = vit

        # embed_dim -> 1024 -> 512 -> 2 * num_keypoints, with ReLU and dropout
        # after each hidden layer.
        layers = []
        width = self.backbone_out_features
        for units in (1024, 512):
            layers += [nn.Linear(width, units), nn.ReLU(), nn.Dropout(0.5)]
            width = units
        layers.append(nn.Linear(width, num_keypoints * 2))
        self.regression_head = nn.Sequential(*layers)

        self.freeze_backbone = freeze_backbone
        if freeze_backbone:
            self._freeze_backbone()

    def _freeze_backbone(self):
        """Disable gradients for every backbone parameter."""
        self.backbone.requires_grad_(False)

    def _unfreeze_backbone(self):
        """Re-enable gradients for every backbone parameter."""
        self.backbone.requires_grad_(True)

    def forward(self, x):
        """Map a ``(B, 1 or 3, H, W)`` batch to ``(B, 2 * num_keypoints)``."""
        # Single-channel input is replicated to three channels.
        rgb = x.repeat(1, 3, 1, 1) if x.size(1) == 1 else x
        tokens = self.backbone.forward_features(rgb)
        # When a token sequence comes back, keep only the first token
        # (presumably the class token - confirm against the timm model).
        pooled = tokens[:, 0, :] if tokens.dim() > 2 else tokens
        return self.regression_head(pooled)

    def compute_loss(self, preds, labels, criterion):
        """Apply the loss named by ``criterion`` ('mse' or 'Smoothl1Loss').

        Raises:
            ValueError: If ``criterion`` is not one of the two known names.
        """
        known = (('mse', nn.MSELoss), ('Smoothl1Loss', nn.SmoothL1Loss))
        for name, loss_cls in known:
            if criterion == name:
                return loss_cls()(preds, labels)
        raise ValueError(f'Unknown criterion: {criterion}')

### 3.6 Train DINO (Stage 1: Frozen Backbone)

In [None]:
# --- Configuration ---
DINO_EPOCHS_FROZEN = 50
# Loss name understood by DINOKeypointDetector.compute_loss.
DINO_CRITERION = 'Smoothl1Loss'
DINO_LR_FROZEN = 5e-4
DINO_BATCH_SIZE = 64

# --- Initialize ---
# The backbone is frozen in the constructor, so only parameters that still
# require gradients (the regression head) are handed to Adam.
dino_model = DINOKeypointDetector(pretrained=True, freeze_backbone=True).to(device)
dino_train_loader, dino_test_loader = load_regression_data(batch_size=DINO_BATCH_SIZE)
dino_optimizer = torch.optim.Adam(
    [p for p in dino_model.parameters() if p.requires_grad], lr=DINO_LR_FROZEN
)

print(f"DINO parameters (total): {sum(p.numel() for p in dino_model.parameters()):,}")
print(f"DINO parameters (trainable): {sum(p.numel() for p in dino_model.parameters() if p.requires_grad):,}")

# --- Train Stage 1: Frozen backbone ---
# Open a W&B run named 'dino-frozen' and record the setup.
wandb.init(project='facial-keypoints', name='dino-frozen', reinit=True)
wandb.config.update({'model': 'dino', 'criterion': DINO_CRITERION, 'freeze': True})

dino_epoch, dino_step = train_regression(
    dino_model, dino_train_loader, dino_test_loader, dino_optimizer,
    DINO_CRITERION, device, model_name='dino', num_epochs=DINO_EPOCHS_FROZEN
)

# Keep the end-of-stage weights under their own file name.
save_checkpoint(dino_model, dino_optimizer, dino_epoch, dino_step, 'dino',
                path='checkpoints-dino/frozen_checkpoint.pth')
wandb.finish()
print("DINO Stage 1 (frozen backbone) complete!")

### 3.7 Train DINO (Stage 2: Fine-tune Full Model)

In [None]:
# --- Configuration ---
DINO_EPOCHS_FINETUNE = 50
DINO_LR_FINETUNE = 1e-4

# --- Unfreeze and set up new optimizer ---
# Continues from the stage-1 model held in memory; a fresh Adam instance is
# needed because the backbone parameters were not part of the first optimizer.
dino_model._unfreeze_backbone()
dino_optimizer = torch.optim.Adam(dino_model.parameters(), lr=DINO_LR_FINETUNE)

print(f"DINO parameters (trainable after unfreeze): {sum(p.numel() for p in dino_model.parameters() if p.requires_grad):,}")

# --- Train Stage 2: Fine-tune ---
wandb.init(project='facial-keypoints', name='dino-finetune', reinit=True)
wandb.config.update({'model': 'dino', 'criterion': DINO_CRITERION, 'freeze': False})

# NOTE(review): train_regression restarts its step counter at 0 and writes
# step_<n>.pth under checkpoints-dino/, so the periodic checkpoints saved
# during stage 1 are overwritten here.
dino_epoch, dino_step = train_regression(
    dino_model, dino_train_loader, dino_test_loader, dino_optimizer,
    DINO_CRITERION, device, model_name='dino', num_epochs=DINO_EPOCHS_FINETUNE
)

save_checkpoint(dino_model, dino_optimizer, dino_epoch, dino_step, 'dino',
                path='checkpoints-dino/last_checkpoint.pth')
wandb.finish()
print("DINO Stage 2 (fine-tune) complete!")

### 3.8 Visualize DINO

In [None]:
# Plot DINO predictions (red) against ground truth (green) for a few test batches.
visualize_keypoints(dino_test_loader, dino_model, device)

---
## 4. U-Net (Heatmap-based)

### 4.1 Model Definition

In [None]:
class DoubleConv(nn.Module):
    """Two stacked 3x3 conv -> batch-norm -> ReLU units (spatial size preserved).

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by both convolutions.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        units = []
        # First unit changes the channel count, the second one keeps it.
        for source_channels in (in_channels, out_channels):
            units.extend([
                nn.Conv2d(source_channels, out_channels, 3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*units)

    def forward(self, x):
        """Run both units on ``x``."""
        return self.conv(x)


class UNet(nn.Module):
    """Four-level U-Net that outputs one heatmap logit map per keypoint.

    The decoder output has the input resolution and is bilinearly resized to
    ``heatmap_size`` x ``heatmap_size``.  The four 2x poolings followed by four
    2x up-convolutions mean the skip connections only line up when the input
    height and width are multiples of 16.

    Args:
        in_channels: Channels of the input image.
        num_keypoints: Number of output heatmaps.
        heatmap_size: Side length of the returned heatmaps.
    """
    def __init__(self, in_channels=1, num_keypoints=68, heatmap_size=64):
        super().__init__()
        self.heatmap_size = heatmap_size
        widths = (32, 64, 128, 256)

        # Encoder: enc1..enc4 (names and creation order kept stable).
        previous = in_channels
        for level, width in enumerate(widths, start=1):
            setattr(self, f'enc{level}', DoubleConv(previous, width))
            previous = width
        self.pool = nn.MaxPool2d(2)
        self.dropout_enc = nn.Dropout2d(p=0.3)

        # Bottleneck
        self.bottleneck = DoubleConv(256, 512)
        self.dropout_bottleneck = nn.Dropout2d(p=0.3)
        self.dropout_dec = nn.Dropout2d(p=0.2)

        # Decoder: up4/dec4 .. up1/dec1.  Each dec block sees the upsampled
        # features concatenated with the skip, hence 2 * width inputs.
        previous = 512
        for level in (4, 3, 2, 1):
            width = widths[level - 1]
            setattr(self, f'up{level}', nn.ConvTranspose2d(previous, width, kernel_size=2, stride=2))
            setattr(self, f'dec{level}', DoubleConv(2 * width, width))
            previous = width

        self.out_conv = nn.Conv2d(32, num_keypoints, kernel_size=1)

    def forward(self, x):
        """Return ``(B, num_keypoints, heatmap_size, heatmap_size)`` logits."""
        skips = []
        feat = x
        for encoder in (self.enc1, self.enc2, self.enc3):
            feat = encoder(feat)
            skips.append(feat)
            feat = self.pool(feat)
        # The deepest encoder level is regularised before it is used as a skip.
        feat = self.dropout_enc(self.enc4(feat))
        skips.append(feat)

        feat = self.dropout_bottleneck(self.bottleneck(self.pool(feat)))

        # Dropout is applied after the two deepest decoder blocks only.
        feat = self.dropout_dec(self.dec4(torch.cat([self.up4(feat), skips[3]], dim=1)))
        feat = self.dropout_dec(self.dec3(torch.cat([self.up3(feat), skips[2]], dim=1)))
        feat = self.dec2(torch.cat([self.up2(feat), skips[1]], dim=1))
        feat = self.dec1(torch.cat([self.up1(feat), skips[0]], dim=1))

        logits = self.out_conv(feat)
        return F.interpolate(logits, size=self.heatmap_size, mode='bilinear', align_corners=False)

### 4.2 Train U-Net

In [None]:
# --- Configuration ---
UNET_EPOCHS = 50
UNET_LR = 1e-3
UNET_BATCH_SIZE = 256

# --- Initialize ---
# heatmap_size=64 matches load_heatmap_data's default target size.
unet_model = UNet(in_channels=1, num_keypoints=68, heatmap_size=64).to(device)
unet_train_loader, unet_test_loader = load_heatmap_data(batch_size=UNET_BATCH_SIZE)
unet_optimizer = torch.optim.Adam(unet_model.parameters(), lr=UNET_LR, weight_decay=1e-5)

print(f"U-Net parameters: {sum(p.numel() for p in unet_model.parameters()):,}")

# --- Train ---
# Open a W&B run named 'unet-train' and record the hyper-parameters.
wandb.init(project='facial-keypoints', name='unet-train', reinit=True)
wandb.config.update({'model': 'unet', 'criterion': 'bce', 'lr': UNET_LR})

# train_heatmap builds its own loss and LR scheduler and saves the best
# model to checkpoints-unet/best.pth along the way.
unet_epoch, unet_step = train_heatmap(
    unet_model, unet_train_loader, unet_test_loader, unet_optimizer,
    device, model_name='unet', num_epochs=UNET_EPOCHS
)

# Store the final weights separately from the best and step checkpoints.
save_checkpoint(unet_model, unet_optimizer, unet_epoch, unet_step, 'unet',
                path='checkpoints-unet/last_checkpoint.pth')
wandb.finish()

### 4.3 Visualize U-Net

In [None]:
# Plot U-Net predictions (decoded from heatmaps, red) against ground truth (green).
visualize_keypoints(unet_test_loader, unet_model, device)