# A Reference-Guided Stacked Hourglass Network for Facial Landmark Detection
Albert Kim - Undergraduate Honors Thesis

## Setup

In [None]:
import os
# Root of the working directory; every dataset path below is built from it.
DATA_ROOT_DIR = "/content"

# Directory containing the extracted WFLW images.
IMG_DIR = os.path.join(DATA_ROOT_DIR, "extracted_data", "WFLW_images")

# The train and test annotation lists sit side by side in this folder.
_ANNOTATION_DIR = os.path.join(
    DATA_ROOT_DIR,
    "extracted_annotations",
    "WFLW_annotations",
    "list_98pt_rect_attr_train_test",
)
TRAIN_LANDMARK_FILE_PATH = os.path.join(_ANNOTATION_DIR, "list_98pt_rect_attr_train.txt")
TEST_LANDMARK_FILE_PATH = os.path.join(_ANNOTATION_DIR, "list_98pt_rect_attr_test.txt")

# Size handed to transforms.Resize and used as the default heatmap size.
IMAGE_RESIZE = (128, 128)

In [None]:
import gdown
import os

# WFLW Images
# Download the image archive from Google Drive by file id, then unpack it.
file_id = '1hzBd48JIdWTJSsATBEB_eFVvPL1bx6UC'
output_path = '/content/WFLW_images.tar.gz'
gdown.download(id=file_id, output=output_path, quiet=False)

extract_dir = '/content/extracted_data'
os.makedirs(extract_dir, exist_ok=True)
# IPython shell escape: extract the archive into extract_dir.
!tar -xvf {output_path} -C {extract_dir}

print(f"File downloaded to {output_path}")
print(f"File extracted to {extract_dir}")

# WFLW Annotations
# Fetch the annotation archive over HTTP and unpack it into its own directory.
annotations_url = 'https://wywu.github.io/projects/LAB/support/WFLW_annotations.tar.gz'
annotations_output_path = '/content/WFLW_annotations.tar.gz'
annotations_extract_dir = '/content/extracted_annotations'
os.makedirs(annotations_extract_dir, exist_ok=True)

# IPython shell escapes: download with wget, then extract the gzip'd tarball.
!wget {annotations_url} -O {annotations_output_path}
!tar -xzvf {annotations_output_path} -C {annotations_extract_dir}

print(f"Annotations downloaded to {annotations_output_path}")
print(f"Annotations extracted to {annotations_extract_dir}")

## Dependencies

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
import numpy as np
import os
import matplotlib.pyplot as plt
import random
import cv2
import glob
from typing import Union
from tqdm import tqdm
import json

## Dataset Class

In [None]:
def generate_heatmap(landmark, heatmap_size, sigma=5):
    """Render a single landmark as a 2-D Gaussian heatmap.

    Args:
        landmark: ``(x, y)`` position in heatmap pixel coordinates.
        heatmap_size: ``(H, W)`` of the heatmap to produce.
        sigma: Standard deviation of the Gaussian, in pixels.

    Returns:
        A ``float32`` array of shape ``(H, W)``. It is all zeros when the
        landmark lies outside the heatmap; otherwise it is an unnormalized
        Gaussian whose value is 1 at the landmark position.
    """
    x, y = landmark
    H, W = heatmap_size

    # A landmark outside the canvas contributes an empty heatmap.
    if x < 0 or y < 0 or x >= W or y >= H:
        return np.zeros((H, W), dtype=np.float32)

    xv, yv = np.meshgrid(np.arange(W), np.arange(H))
    squared_dist = (xv - x) ** 2 + (yv - y) ** 2

    # Cast so that both branches return float32; previously the in-bounds
    # branch returned float64 while the out-of-bounds branch returned float32.
    return np.exp(-squared_dist / (2 * sigma ** 2)).astype(np.float32)

def heatmaps_to_coordinates(heatmaps, tau=0.05):
    """Decode heatmaps into (x, y) coordinates with a differentiable soft-argmax.

    Each heatmap is turned into a probability map by a temperature softmax
    over all of its pixels; the returned coordinate is the expected pixel
    position under that distribution.

    Args:
        heatmaps: Tensor of shape ``(B, N, H, W)``.
        tau: Softmax temperature; smaller values concentrate the weight on
            the largest heatmap values.

    Returns:
        Tensor of shape ``(B, N, 2)`` holding ``(x, y)`` per heatmap, in
        pixel units of the heatmap grid.
    """
    B, N, H, W = heatmaps.shape

    # reshape (rather than view) so non-contiguous inputs, e.g. permuted or
    # sliced tensors, are accepted instead of raising a RuntimeError.
    probs = F.softmax(heatmaps.reshape(B, N, H * W) / tau, dim=-1)

    y_coords = torch.arange(H, device=heatmaps.device, dtype=torch.float32)
    x_coords = torch.arange(W, device=heatmaps.device, dtype=torch.float32)
    yy, xx = torch.meshgrid(y_coords, x_coords, indexing='ij')

    # Expected position = sum over pixels of probability * pixel coordinate.
    x = torch.sum(probs * xx.reshape(-1), dim=-1)
    y = torch.sum(probs * yy.reshape(-1), dim=-1)
    return torch.stack([x, y], dim=-1)

class WFLWLandmarksDataset_ReferenceData(Dataset):
    """WFLW landmark dataset that pairs every sample with a reference face.

    Each item yields the cropped main image and its landmark heatmaps, plus a
    second, randomly chosen image (with its heatmaps) whose six attribute
    flags are all zero, whenever such an image exists in the loaded list.
    """

    # Attribute vector an image must have to be eligible as a reference.
    _REFERENCE_ATTRS = [0, 0, 0, 0, 0, 0]
    # Factor by which the annotated face rectangle is enlarged before cropping.
    _RECT_SCALE = 1.2

    def __init__(self, img_dir, landmark_file, transform=None, heatmap_size=IMAGE_RESIZE, max_images=None, attr_map=None):
        """Parse the annotation file and prepare the dataset.

        Args:
            img_dir: Directory that the image ids in the annotation file are
                relative to.
            landmark_file: Path to a WFLW annotation list (207 whitespace
                separated tokens per line).
            transform: Callable applied to the cropped PIL image. Defaults to
                resize to ``IMAGE_RESIZE``, ``ToTensor`` and normalization to
                ``[-1, 1]``.
            heatmap_size: ``(H, W)`` of the generated heatmaps; landmarks are
                rescaled from crop coordinates to this grid.
            max_images: If given, only the first ``max_images`` lines of the
                annotation file are read (deprecated).
            attr_map: If given, keep only samples whose six attribute flags
                equal this sequence.
        """
        if transform is None:
            transform = transforms.Compose([
                transforms.Resize(IMAGE_RESIZE),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
            ])

        self.img_dir = img_dir
        self.landmark_file = landmark_file
        self.transform = transform
        self.heatmap_size = heatmap_size
        self.max_images = max_images
        self.attr_map = attr_map

        self.data = self._load_data()
        # Cache: attribute string -> indices of all samples with those attributes.
        self.data_typed = {}

    def _load_data(self):
        """Read the annotation file into a list of per-image dicts.

        Lines that do not have exactly 207 tokens (196 landmark values, 4
        bounding-box values, 6 attribute flags, 1 image name) are skipped.
        The bounding box is enlarged around its centre by ``_RECT_SCALE``.
        """
        with open(self.landmark_file, 'r') as f:
            lines = f.readlines()

        # Limit number of images (deprecated)
        if self.max_images is not None:
            lines = lines[:self.max_images]

        wanted_attrs = None if self.attr_map is None else list(self.attr_map)

        data_list = []
        for line in lines:
            tokens = line.strip().split()
            if len(tokens) != 207:
                continue

            landmarks = list(map(float, tokens[:196]))
            x1, y1, x2, y2 = map(float, tokens[196:200])
            attrs = list(map(float, tokens[200:206]))
            img_id = tokens[206]

            if wanted_attrs is not None and wanted_attrs != attrs:
                continue

            # Enlarge the face rectangle about its centre.
            w = x2 - x1
            h = y2 - y1
            cx = x1 + w / 2
            cy = y1 + h / 2
            new_w = w * self._RECT_SCALE
            new_h = h * self._RECT_SCALE

            data_list.append({
                'img_id': img_id,
                'landmarks': landmarks,
                'rect': [
                    cx - new_w / 2,
                    cy - new_h / 2,
                    cx + new_w / 2,
                    cy + new_h / 2
                ],
                'attributes': attrs
            })

        return data_list

    def _crop_image_by_rect(self, image, rect):
        """Crop ``image`` (H x W x C array) to ``rect``, zero-padding any part
        of the rectangle that falls outside the image."""
        x1, y1, x2, y2 = map(int, rect)
        h, w = image.shape[:2]

        pad_left = max(0, -x1)
        pad_top = max(0, -y1)
        pad_right = max(0, x2 - w)
        pad_bottom = max(0, y2 - h)

        if any([pad_left, pad_top, pad_right, pad_bottom]):
            image = np.pad(
                image,
                ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)),
                mode='constant',
                constant_values=0
            )

            # The padding shifted the image origin; shift the rectangle too.
            x1 += pad_left
            y1 += pad_top
            x2 += pad_left
            y2 += pad_top

        return image[y1:y2, x1:x2]

    def _adjust_landmarks_to_crop(self, landmarks, rect):
        """Return a copy of ``landmarks`` (N x 2) expressed relative to the
        top-left corner of ``rect``.

        NOTE: the offset uses the float rectangle while the crop truncates it
        to integers, so the two can differ by less than one source pixel.
        """
        x1, y1, x2, y2 = rect
        adjusted_landmarks = landmarks.copy()
        adjusted_landmarks[:, 0] -= x1
        adjusted_landmarks[:, 1] -= y1
        return adjusted_landmarks

    def _render_heatmaps(self, landmarks):
        """Build an (N, H, W) float32 stack with one Gaussian per landmark.

        Landmarks outside the heatmap grid keep an all-zero heatmap.
        """
        hm_h, hm_w = self.heatmap_size
        heatmaps = np.zeros((landmarks.shape[0], hm_h, hm_w), dtype=np.float32)
        for i, (x, y) in enumerate(landmarks):
            if 0 <= x < hm_w and 0 <= y < hm_h:
                heatmaps[i] = generate_heatmap((x, y), (hm_h, hm_w))
        return heatmaps

    def _load_sample(self, item, label):
        """Load one annotated image and derive everything needed from it.

        Args:
            item: Entry of ``self.data``.
            label: Prefix used in the error message when the file is missing.

        Returns:
            ``(image_tensor, heatmaps_tensor, landmarks)`` where ``landmarks``
            is an (N, 2) float32 array in heatmap pixel coordinates.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image.
        """
        img_path = os.path.join(self.img_dir, item['img_id'])
        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f"{label} not found: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        cropped = self._crop_image_by_rect(image, item['rect'])
        crop_h, crop_w = cropped.shape[:2]

        if self.transform is not None:
            image_tensor = self.transform(Image.fromarray(cropped))
        else:
            image_tensor = torch.from_numpy(cropped.transpose(2, 0, 1) / 255.0).float()

        landmarks = np.array(item['landmarks'], dtype=np.float32).reshape(-1, 2)
        landmarks = self._adjust_landmarks_to_crop(landmarks, item['rect'])

        # Rescale from crop pixels to the heatmap grid. heatmap_size is
        # (H, W), matching generate_heatmap and transforms.Resize.
        hm_h, hm_w = self.heatmap_size
        landmarks[:, 0] *= hm_w / crop_w
        landmarks[:, 1] *= hm_h / crop_h

        heatmaps_tensor = torch.from_numpy(self._render_heatmaps(landmarks))
        return image_tensor, heatmaps_tensor, landmarks

    def _pick_reference_index(self, idx):
        """Choose the index of the reference image for sample ``idx``.

        Prefers a random sample whose attributes equal ``_REFERENCE_ATTRS``
        and that is not ``idx`` itself; otherwise any other sample; and, for
        a single-item dataset, ``idx`` itself.
        """
        key = ','.join(map(str, self._REFERENCE_ATTRS))
        pool = self.data_typed.get(key)
        if pool is None:
            # Cache every eligible index. The current sample is excluded at
            # pick time, not here, so the cache is valid for all samples.
            pool = [
                i for i, data_item in enumerate(self.data)
                if data_item['attributes'] == self._REFERENCE_ATTRS
            ]
            self.data_typed[key] = pool

        candidates = [i for i in pool if i != idx]
        if candidates:
            return random.choice(candidates)

        if len(self.data) <= 1:
            # Nothing else to pair with; avoid looping forever.
            return idx

        ref_idx = idx
        while ref_idx == idx:
            ref_idx = random.randint(0, len(self.data) - 1)
        return ref_idx

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` together with a reference sample.

        Returns:
            An 8-tuple ``(image, heatmaps, ref_image, ref_heatmaps, img_id,
            landmarks, ref_img_id, ref_landmarks)``. Images and heatmaps are
            tensors; the landmark entries are (N, 2) float32 arrays in
            heatmap pixel coordinates.
        """
        item = self.data[idx]
        image_tensor, heatmaps_tensor, landmarks_adjusted = self._load_sample(item, "Image")

        ref_item = self.data[self._pick_reference_index(idx)]
        ref_image_tensor, ref_heatmaps_tensor, ref_landmarks_adjusted = \
            self._load_sample(ref_item, "Reference image")

        return image_tensor, heatmaps_tensor, ref_image_tensor, ref_heatmaps_tensor, \
              item['img_id'], landmarks_adjusted, ref_item['img_id'], ref_landmarks_adjusted

## NN Class

In [None]:
class ConvBlock(nn.Module):
    """Convolution followed by batch normalization and an in-place ReLU.

    With the default 3x3 kernel, stride 1 and padding 1 the spatial size of
    the input is preserved.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(ConvBlock, self).__init__()
        layers = [
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        # Kept under the attribute name ``conv`` so parameter names are unchanged.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv(x)
        return out


class HourglassModule(nn.Module):
    """Recursive hourglass: a full-resolution branch added to a pooled branch.

    The pooled branch is max-pooled by 2, processed (recursively when
    ``depth > 1``), upsampled back with nearest-neighbour interpolation and
    summed with the full-resolution branch.
    """

    def __init__(self, n_channels, depth):
        super().__init__()
        self.depth = depth
        self.n_channels = n_channels

        # Sub-modules are registered in this order; keep it stable so that
        # parameter ordering and names stay the same.
        self.up1 = ConvBlock(n_channels, n_channels)
        self.pool = nn.MaxPool2d(2, 2)
        self.low1 = ConvBlock(n_channels, n_channels)
        self.low2 = (
            HourglassModule(n_channels, depth - 1)
            if depth > 1
            else ConvBlock(n_channels, n_channels)
        )
        self.low3 = ConvBlock(n_channels, n_channels)

    def forward(self, x):
        skip = self.up1(x)

        low = self.pool(x)
        for stage in (self.low1, self.low2, self.low3):
            low = stage(low)

        # Upsample to the skip branch's exact size so odd input sizes still add up.
        upsampled = F.interpolate(low, size=skip.shape[2:], mode='nearest')
        return skip + upsampled

def get_preds_from_heatmaps(heatmaps):
    """Decode heatmaps of shape (B, N, H, W) with a hard arg-max.

    Returns a float tensor of shape (B, N, 2) holding the ``(x, y)`` pixel
    position of the largest value in each heatmap.
    """
    batch, n_maps, height, width = heatmaps.shape
    flat_idx = heatmaps.view(batch, n_maps, -1).argmax(dim=2)

    # A flat index decomposes into row = idx // W and column = idx % W.
    rows = (flat_idx // width).float()
    cols = (flat_idx % width).float()
    return torch.stack((cols, rows), dim=2)

In [None]:
import math

from math import sqrt


class WingLoss(nn.Module):
    """Piecewise loss on coordinate errors: logarithmic near zero, linear beyond.

    For an absolute error ``d`` the per-element loss is
    ``w * log(1 + d / epsilon)`` when ``d < w`` and ``d - c`` otherwise, with
    ``c = w * (1 - log(1 + w / epsilon))`` so the two pieces meet at ``d = w``.
    The result is the mean over all elements.
    """

    def __init__(self, w = 10.0, epsilon = 2.0):
        super().__init__()
        self.w = w
        self.epsilon = epsilon
        # Offset that makes the linear piece continuous with the log piece.
        self.c = w * (1.0 - math.log(1.0 + w / epsilon))

    def forward(self, pred_coords, true_coords):
        # Compute in float32 even when called under mixed precision.
        residual = torch.abs(pred_coords.float() - true_coords.float())

        log_piece = self.w * torch.log(1.0 + residual / self.epsilon)
        linear_piece = residual - self.c

        per_element = torch.where(residual < self.w, log_piece, linear_piece)
        return per_element.mean()

In [None]:
class StackedHourglassNet_ReferenceData(nn.Module):
    """Reference-guided stacked hourglass network.

    The input image and a reference image go through a shared stem and
    hourglass backbone. A flow field predicted from both feature maps warps
    the encoded reference heatmaps, which are then fused with the input
    features to predict the output heatmaps.
    """

    def __init__(self, num_landmarks=98, num_channels=128, depth=4, num_hourglass_modules=2):
        """
        Args:
            num_landmarks: Number of heatmap channels, for both the reference
                heatmaps fed in and the heatmaps predicted.
            num_channels: Feature width of the backbone.
            depth: Recursion depth of every hourglass module.
            num_hourglass_modules: Number of stacked backbone hourglasses.
        """
        super().__init__()

        # Image stem (shared between the input and the reference image)
        self.shared_stem = nn.Sequential(
            ConvBlock(3, 32),
            ConvBlock(32, 64),
            ConvBlock(64, num_channels)
        )

        # Hourglass backbone
        self.shared_hourglass = nn.ModuleList([
            HourglassModule(num_channels, depth) for _ in range(num_hourglass_modules)
        ])
        self.shared_intermediate_convs = nn.ModuleList([
            ConvBlock(num_channels, num_channels) for _ in range(num_hourglass_modules)
        ])

        # Reference heatmap stem
        self.stem_ref_heatmaps = nn.Sequential(
            ConvBlock(num_landmarks, 128),
            ConvBlock(128, num_channels)
        )

        # Flow predictor
        # NOTE(review): the last layer is a ConvBlock, which ends in ReLU, so
        # the predicted flow can never be negative. Changing the layer would
        # alter parameter names and break saved checkpoints - confirm intent.
        self.flow_predictor = nn.Sequential(
            HourglassModule(num_channels*2, depth),
            ConvBlock(num_channels*2, 2, kernel_size=1, stride=1, padding=0),
        )
        # Final heatmap predictor
        self.heatmap_predictor = nn.Sequential(
            HourglassModule(num_channels*2, depth),
            ConvBlock(num_channels*2, num_landmarks, kernel_size=1, stride=1, padding=0),
        )

    def forward_through_hourglass(self, x):
        """Run an image batch through the shared stem and backbone."""
        features = self.shared_stem(x)
        for hourglass, conv in zip(self.shared_hourglass, self.shared_intermediate_convs):
            features = conv(hourglass(features))
        return features

    def warp_heatmaps(self, heatmaps, flow):
        """Resample ``heatmaps`` at positions displaced by ``flow``.

        Args:
            heatmaps: Tensor of shape ``(B, C, H, W)``.
            flow: Tensor of shape ``(B, 2, H, W)``; channel 0 is the x
                displacement and channel 1 the y displacement, in pixels.

        Returns:
            Tensor of shape ``(B, C, H, W)`` where each output pixel is
            sampled from ``heatmaps`` at its own position plus the flow.
        """
        B, C, H, W = heatmaps.shape

        # Identity sampling grid in the [-1, 1] coordinates grid_sample uses.
        # indexing='ij' is spelled out to match the row-major layout relied
        # on here and to avoid torch.meshgrid's missing-argument warning.
        grid_y, grid_x = torch.meshgrid(
            torch.linspace(-1, 1, H, device=heatmaps.device),
            torch.linspace(-1, 1, W, device=heatmaps.device),
            indexing='ij'
        )
        grid = torch.stack((grid_x, grid_y), dim=-1)
        grid = grid.unsqueeze(0).repeat(B, 1, 1, 1)

        # Convert pixel displacements to the normalized units of the grid
        # (with align_corners=True one pixel spans 2 / (size - 1)).
        flow_norm = torch.stack(
            (
                flow[:, 0, :, :] / ((W - 1) / 2),
                flow[:, 1, :, :] / ((H - 1) / 2),
            ),
            dim=-1,
        )

        warped_grid = grid + flow_norm
        return F.grid_sample(heatmaps, warped_grid, align_corners=True)

    def forward(self, x, ref_x, ref_heatmaps):
        """Predict heatmaps for ``x`` guided by a reference image and its heatmaps.

        Args:
            x: Input image batch with 3 channels.
            ref_x: Reference image batch with 3 channels.
            ref_heatmaps: Reference heatmaps with ``num_landmarks`` channels
                and the same spatial size as the images.

        Returns:
            Predicted heatmaps with ``num_landmarks`` channels.
        """
        # Images through the shared stem and backbone
        x_features = self.forward_through_hourglass(x)
        ref_img_features = self.forward_through_hourglass(ref_x)

        # Reference heatmaps through their own stem
        ref_heat_features = self.stem_ref_heatmaps(ref_heatmaps)

        # Flow field from the concatenated image features
        flow_input = torch.cat((x_features, ref_img_features), dim=1)
        flow_field = self.flow_predictor(flow_input)

        # Warp the encoded reference heatmaps towards the input face
        warped_ref_heatmaps = self.warp_heatmaps(ref_heat_features, flow_field)

        # Feature fusion and final prediction
        combined = torch.cat((x_features, warped_ref_heatmaps), dim=1)
        out_heatmaps = self.heatmap_predictor(combined)

        return out_heatmaps

## Training

In [None]:
def train_stacked_hourglass_refdata_wflw(
    img_dir,
    landmark_file,
    model_save_path,
    num_hourglass_modules=2,
    train_split_ratio=0.85,
    batch_size=32,
    learning_rate=1e-3,
    num_epochs=20,
    transform=None,
    criterion=None,
    optimizer_class=torch.optim.Adam,
    optimizer_params=None,
    num_channels=128,
    mixed_precision=True,
    gradient_accumulation_steps=1,
    checkpoint_dir="./checkpoints",
    resume_checkpoint=None
):
    """Train the reference-guided stacked hourglass network on WFLW.

    Args:
        img_dir: Directory containing the images.
        landmark_file: Path to the annotation list used for train/validation.
        model_save_path: Where the best model's state dict is written.
        num_hourglass_modules: Number of stacked backbone hourglasses.
        train_split_ratio: Fraction of the dataset used for training.
        batch_size: Batch size of both data loaders.
        learning_rate: Learning rate, used when ``optimizer_params`` is None.
        num_epochs: Epoch index at which training stops.
        transform: Image transform; defaults to resize, ToTensor and
            normalization to [-1, 1].
        criterion: Loss on decoded coordinates; defaults to ``WingLoss()``.
        optimizer_class: Optimizer constructor.
        optimizer_params: Keyword arguments for the optimizer.
        num_channels: Feature width of the model.
        mixed_precision: Use CUDA automatic mixed precision when on GPU.
        gradient_accumulation_steps: Batches accumulated per optimizer step.
        checkpoint_dir: Directory for per-epoch checkpoints and statistics.
        resume_checkpoint: Optional checkpoint path to resume from.

    Returns:
        ``(model, best_val_loss, train_losses, val_losses)`` with the model
        holding the best weights seen in this run, or four ``None`` values
        when the dataset or one side of the split is empty.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Default transform
    if transform is None:
        transform = transforms.Compose([
            transforms.Resize(IMAGE_RESIZE),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

    dataset = WFLWLandmarksDataset_ReferenceData(
        img_dir=img_dir,
        landmark_file=landmark_file,
        transform=transform
    )

    if len(dataset) == 0:
        print("Error: Dataset is empty.")
        return None, None, None, None

    # Train/validation split
    train_size = int(train_split_ratio * len(dataset))
    val_size = len(dataset) - train_size
    if train_size == 0 or val_size == 0:
        # Either loop would otherwise divide by a zero batch count later on.
        print("Error: Train/validation split leaves one side empty.")
        return None, None, None, None
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    # Data loaders. os.cpu_count() may return None, hence the fallback.
    num_workers = min(8, (os.cpu_count() or 1) // 2)
    loader_kwargs = {
        'batch_size': batch_size,
        'num_workers': num_workers,
        'pin_memory': True,
    }
    if num_workers > 0:
        # These options are only meaningful with worker processes.
        loader_kwargs['persistent_workers'] = True
        loader_kwargs['prefetch_factor'] = 2

    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Mixed precision is only enabled on CUDA.
    scaler = torch.cuda.amp.GradScaler() if mixed_precision and device.type == 'cuda' else None

    model = StackedHourglassNet_ReferenceData(
        num_hourglass_modules=num_hourglass_modules,
        num_landmarks=98,
        num_channels=num_channels
    ).to(device)

    if criterion is None:
        criterion = WingLoss()

    if optimizer_params is None:
        optimizer_params = {"lr": learning_rate}
    optimizer = optimizer_class(model.parameters(), **optimizer_params)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )

    # Training state
    start_epoch = 0
    best_val_loss = float('inf')
    train_losses = []
    val_losses = []
    best_model = None

    # Resume from checkpoint if provided
    if resume_checkpoint and os.path.exists(resume_checkpoint):
        print(f"Loading checkpoint: {resume_checkpoint}")
        checkpoint = torch.load(resume_checkpoint, map_location=device)

        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        start_epoch = checkpoint['epoch'] + 1
        best_val_loss = checkpoint['best_val_loss']
        train_losses = checkpoint['train_losses']
        val_losses = checkpoint['val_losses']

        if scaler is not None and 'scaler_state_dict' in checkpoint:
            scaler.load_state_dict(checkpoint['scaler_state_dict'])

        print(f"Resumed from epoch {start_epoch}, best val loss: {best_val_loss:.4f}")

    print("\nStarting Training (Reference Data)...")

    def batch_loss(batch):
        """Forward one batch and return the loss on decoded coordinates."""
        images, heatmaps, ref_images, ref_heatmaps = (
            tensor.to(device, non_blocking=True) for tensor in batch[:4]
        )
        outputs = model(images, ref_images, ref_heatmaps)
        pred_coords = heatmaps_to_coordinates(outputs)
        gt_coords = heatmaps_to_coordinates(heatmaps)
        return criterion(pred_coords, gt_coords)

    @torch.no_grad()
    def validate_epoch():
        """Return the mean per-batch loss over the validation loader."""
        model.eval()
        total = 0.0
        for batch in val_loader:
            total += batch_loss(batch).item()
        return total / len(val_loader)

    for epoch in range(start_epoch, num_epochs):
        model.train()
        train_loss = 0.0
        num_batches = 0

        if device.type == 'cuda':
            torch.cuda.empty_cache()

        pbar = tqdm(
            train_loader,
            desc=f'Epoch {epoch+1}/{num_epochs}',
            leave=True,
            ncols=100
        )

        optimizer.zero_grad()
        last_batch = len(train_loader) - 1

        for i, batch in enumerate(pbar):
            if scaler is not None:
                with torch.cuda.amp.autocast():
                    loss = batch_loss(batch) / gradient_accumulation_steps
                scaler.scale(loss).backward()
            else:
                loss = batch_loss(batch) / gradient_accumulation_steps
                loss.backward()

            current_loss = loss.item() * gradient_accumulation_steps
            train_loss += current_loss
            num_batches += 1

            pbar.set_postfix({
                'batch_loss': f'{current_loss:.4f}',
                'avg_loss': f'{(train_loss / num_batches):.4f}'
            })

            # Step once per accumulation group, and also on the last batch so
            # a trailing partial group is applied instead of being discarded.
            if (i + 1) % gradient_accumulation_steps == 0 or i == last_batch:
                if scaler is not None:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()
                optimizer.zero_grad()

        avg_train_loss = train_loss / num_batches
        train_losses.append(avg_train_loss)

        avg_val_loss = validate_epoch()
        val_losses.append(avg_val_loss)

        scheduler.step(avg_val_loss)

        print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}")

        # Update the best loss before building the checkpoint so the stored
        # value already reflects this epoch.
        improved = avg_val_loss < best_val_loss
        if improved:
            print(f"Validation loss improved from {best_val_loss:.4f} to {avg_val_loss:.4f}. Saving model...")
            best_val_loss = avg_val_loss

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_loss': best_val_loss,
            'train_losses': train_losses,
            'val_losses': val_losses,
            'config': {
                'num_hourglass_modules': num_hourglass_modules,
                'num_channels': num_channels,
                'learning_rate': learning_rate,
                'batch_size': batch_size
            }
        }

        if scaler is not None:
            checkpoint['scaler_state_dict'] = scaler.state_dict()

        checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth')
        torch.save(checkpoint, checkpoint_path)

        if improved:
            # Clone the tensors: state_dict() returns references to the live
            # parameters, so a shallow dict copy would keep changing as
            # training continues.
            best_model = {
                name: tensor.detach().cpu().clone()
                for name, tensor in model.state_dict().items()
            }
            torch.save(model.state_dict(), model_save_path)

            best_checkpoint_path = os.path.join(checkpoint_dir, 'best_model.pth')
            torch.save(checkpoint, best_checkpoint_path)

    print("\nTraining Finished!")

    # Restore the best weights seen in this run before returning.
    if best_model is not None:
        model.load_state_dict(best_model)

    training_stats = {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'best_val_loss': best_val_loss,
        'final_epoch': num_epochs
    }

    stats_path = os.path.join(checkpoint_dir, 'training_stats.json')
    with open(stats_path, 'w') as f:
        json.dump(training_stats, f, indent=2)

    return model, best_val_loss, train_losses, val_losses

In [None]:
# Mount Google Drive (Colab only) unless the mount point already exists.
from google.colab import drive
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

In [None]:


# Drop any previously held model and release cached GPU memory before training.
model = None
torch.cuda.empty_cache()

# Train on the WFLW training list; the best weights are written to Google Drive.
model, best_loss, refdata_train_losses, refdata_val_losses = train_stacked_hourglass_refdata_wflw(
    img_dir=IMG_DIR,
    landmark_file=TRAIN_LANDMARK_FILE_PATH,
    model_save_path="/content/drive/MyDrive/refdata.pth",
    train_split_ratio=0.8,
    batch_size=32,
    learning_rate=1e-4,
    num_epochs=50,
    num_hourglass_modules=2,
    num_channels=128,
    mixed_precision = True
)

## Evaluation

In [None]:
def calculate_nme(pred_landmarks, gt_landmarks, return_all=False):
    """Normalized mean error between predicted and ground-truth landmarks.

    For every image the mean point-to-point distance is divided by the
    distance between ground-truth landmarks 60 and 72. Images for which that
    normalizing distance is zero are skipped.

    Args:
        pred_landmarks: Per-image sequences of predicted ``(x, y)`` points.
        gt_landmarks: Per-image sequences of ground-truth ``(x, y)`` points.
        return_all: Return the list of per-image values instead of their mean.

    Returns:
        The per-image list when ``return_all`` is true (possibly empty);
        otherwise the mean, or ``0.0`` when no image was valid.
    """
    def distance(a, b):
        return sqrt(
            (a[0] - b[0]) ** 2 +
            (a[1] - b[1]) ** 2
        )

    per_image = []
    for pred_lms, gt_lms in zip(pred_landmarks, gt_landmarks):
        # Inter-ocular normalization
        d_norm = distance(gt_lms[60], gt_lms[72])
        if d_norm == 0:
            continue

        total_error = 0.0
        for pred_lm, gt_lm in zip(pred_lms, gt_lms):
            total_error += distance(pred_lm, gt_lm)

        per_image.append((total_error / len(gt_lms)) / d_norm)

    if return_all:
        return per_image
    return np.mean(per_image) if per_image else 0.0

def calculate_auc(errors, failure_threshold=0.1, step=0.0001):
    """Area under the cumulative error curve, plus the failure rate.

    Args:
        errors: Per-sample error values (e.g. NME); clipped to ``[0, 1]``.
        failure_threshold: Upper limit of the curve; errors above it count
            as failures.
        step: Spacing of the thresholds at which the curve is evaluated.

    Returns:
        ``(auc, failure_rate)``. ``auc`` is the area under the fraction of
        samples with error <= t for t in ``[0, failure_threshold]``, divided
        by ``failure_threshold``. ``failure_rate`` is the fraction of samples
        whose error exceeds the threshold. Both are ``0.0`` for empty input.
    """
    errors = np.clip(np.asarray(errors, dtype=np.float64).ravel(), 0, 1.0)
    if errors.size == 0:
        return 0.0, 0.0

    # linspace pins both end points; a float-step arange can overshoot the
    # threshold by one step because of rounding.
    n_intervals = max(1, int(round(failure_threshold / step)))
    x = np.linspace(0.0, failure_threshold, n_intervals + 1)

    # Fraction of errors <= each threshold, via one sorted lookup.
    y = np.searchsorted(np.sort(errors), x, side='right') / errors.size

    # np.trapezoid only exists from NumPy 2.0; older releases call it np.trapz.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    auc = trapezoid(y, x) / failure_threshold
    failure_rate = np.mean(errors > failure_threshold)

    return auc, failure_rate



def evaluate_nme_auc(model, data_loader, device):
    """Evaluate a model over a loader and report NME, AUC and failure rate.

    Args:
        model: Network called as ``model(images, ref_images, ref_heatmaps)``;
            if it returns a tuple, the first element is used.
        data_loader: Loader yielding the dataset's 8-tuples.
        device: Device the inputs are moved to.

    Returns:
        ``(mean_nme, auc, failure_rate)`` over all samples with a valid
        normalization distance; all ``0.0`` when there are none.
    """
    model.eval()
    all_nmes = []

    with torch.no_grad():
        for batch_idx, batch in enumerate(data_loader):
            images, _, ref_images, ref_gt_heatmaps, _, gt_landmarks, _, _ = batch

            images = images.to(device)
            ref_images = ref_images.to(device)
            ref_gt_heatmaps = ref_gt_heatmaps.to(device)

            pred_heatmaps = model(images, ref_images, ref_gt_heatmaps)
            if isinstance(pred_heatmaps, tuple):
                pred_heatmaps = pred_heatmaps[0]

            # Plain Python lists keep the per-landmark arithmetic in
            # calculate_nme off the tensor dispatch path.
            pred_landmarks = heatmaps_to_coordinates(pred_heatmaps.cpu()).tolist()
            gt_landmarks_list = gt_landmarks.tolist() \
                if isinstance(gt_landmarks, torch.Tensor) else gt_landmarks

            batch_nmes = calculate_nme(pred_landmarks, gt_landmarks_list, return_all=True)
            all_nmes.extend(batch_nmes)

            if batch_nmes:
                print(f"Batch {batch_idx+1}/{len(data_loader)} | Batch NME: {np.mean(batch_nmes):.6f}")
            else:
                # Every sample was skipped; averaging would give NaN.
                print(f"Batch {batch_idx+1}/{len(data_loader)} | Batch NME: n/a (no valid samples)")

    num_samples = len(all_nmes)
    if num_samples > 0:
        final_nme = np.mean(all_nmes)
        auc, failure_rate = calculate_auc(all_nmes, failure_threshold=0.1)
    else:
        final_nme, auc, failure_rate = 0.0, 0.0, 0.0

    print(f"\n--- Evaluation Complete ---")
    print(f"Samples: {num_samples}")
    print(f"Mean NME: {final_nme:.6f}")
    print(f"AUC (10%): {auc:.6f}")
    print(f"Failure Rate (>{0.1*100:.1f}%): {failure_rate*100:.2f}%")

    return final_nme, auc, failure_rate

In [None]:
# Drop any previously held model and release cached GPU memory.
model = None
torch.cuda.empty_cache()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Initialize model
model = StackedHourglassNet_ReferenceData(
    num_hourglass_modules=2,
    num_landmarks=98,
    num_channels=128
)
# NOTE(review): this loads refdata_2.pth, whereas the training cell saves to
# refdata.pth - confirm the intended weights file.
model.load_state_dict(torch.load("/content/drive/MyDrive/refdata_2.pth"))
model.eval()
model.to(device)

# Test split, with the same preprocessing as the dataset's default transform.
dataset = WFLWLandmarksDataset_ReferenceData(
    img_dir=IMG_DIR,
    landmark_file=TEST_LANDMARK_FILE_PATH,
    transform=transforms.Compose([
        transforms.Resize(IMAGE_RESIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ]),
    #attr_map = [0, 0, 0, 0, 0, 1]
)

test_loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=2,
    pin_memory=True
)

# a holds (mean NME, AUC, failure rate).
a = evaluate_nme_auc(model, test_loader, device)

## Visualize

In [None]:
def visualize_prediction(model, data_loader, device, sample_idx=0):
    """Plot predictions for one sample of the loader's first batch.

    The left panel shows the input image with ground-truth (green) and
    predicted (red) landmarks; the right panel shows the reference image with
    its landmarks (blue).

    Args:
        model: Network called as ``model(images, ref_images, ref_heatmaps)``.
        data_loader: Loader yielding the dataset's 8-tuples.
        device: Device the inputs are moved to.
        sample_idx: Position within the batch; falls back to 0 when it is
            beyond the batch size.

    Returns:
        Dict with the displayed images, landmarks, ids and the sample's NME.
    """
    model.eval()

    # Only the first batch is used.
    (images, gt_heatmaps, ref_images, ref_gt_heatmaps,
     img_ids, gt_landmarks, ref_img_ids, ref_gt_landmarks) = next(iter(data_loader))

    if sample_idx >= len(images):
        sample_idx = 0
        print(f"Sample index out of range, using first sample instead")

    images = images.to(device)
    ref_images = ref_images.to(device)
    ref_gt_heatmaps = ref_gt_heatmaps.to(device)

    with torch.no_grad():
        outputs = model(images, ref_images, ref_gt_heatmaps)
        # Bring the prediction to the ground-truth heatmap resolution.
        outputs = F.interpolate(outputs, size=gt_heatmaps.shape[2:], mode='bilinear', align_corners=False)
        lm_pred = heatmaps_to_coordinates(outputs.cpu())

    def to_display(chw_tensor):
        # CHW tensor -> HWC array in [0, 1]; undo the [-1, 1] normalization
        # when negative values are present.
        img = chw_tensor.cpu().permute(1, 2, 0).numpy()
        if img.min() < 0:
            img = (img + 1) / 2
        return np.clip(img, 0, 1)

    def split_xy(points):
        return [p[0] for p in points], [p[1] for p in points]

    image = to_display(images[sample_idx])
    ref_image = to_display(ref_images[sample_idx])

    pred_landmarks = lm_pred[sample_idx]
    gt_landmarks_sample = gt_landmarks[sample_idx]
    ref_landmarks_sample = ref_gt_landmarks[sample_idx]

    sample_nme = calculate_nme([pred_landmarks], [gt_landmarks_sample])

    fig, (main_ax, ref_ax) = plt.subplots(1, 2, figsize=(15, 6))

    # Main image: GT + preds
    main_ax.imshow(image)
    gt_x, gt_y = split_xy(gt_landmarks_sample)
    main_ax.scatter(gt_x, gt_y, c='green', s=10, label='GT', alpha=0.7)
    pred_x, pred_y = split_xy(pred_landmarks)
    main_ax.scatter(pred_x, pred_y, c='red', s=10, label='Preds', alpha=0.7)
    main_ax.set_title(f'Main Image (NME: {sample_nme:.4f})')
    main_ax.axis('off')
    main_ax.legend()

    # Reference image: ref landmarks
    ref_ax.imshow(ref_image)
    ref_x, ref_y = split_xy(ref_landmarks_sample)
    ref_ax.scatter(ref_x, ref_y, c='blue', s=10, label='Reference GT', alpha=0.7)
    ref_ax.set_title('Reference Image')
    ref_ax.axis('off')
    ref_ax.legend()

    plt.tight_layout()
    plt.show()

    sample_id = img_ids[sample_idx]
    ref_id = ref_img_ids[sample_idx]

    print(f"Sample ID: {sample_id}")
    print(f"Reference Image ID: {ref_id}")
    print(f"Sample NME: {sample_nme:.6f}")
    print(f"Image shape: {image.shape}")
    print(f"Number of landmarks: {len(pred_landmarks)}")

    return {
        'image': image,
        'pred_landmarks': pred_landmarks,
        'gt_landmarks': gt_landmarks_sample,
        'nme': sample_nme,
        'image_id': sample_id,
        'ref_image': ref_image,
        'ref_landmarks': ref_landmarks_sample,
        'ref_image_id': ref_id
    }

In [None]:
# Initialize model
model = StackedHourglassNet_ReferenceData(
    num_hourglass_modules=2,
    num_landmarks=98,
    num_channels=128
)
model.load_state_dict(torch.load("/content/drive/MyDrive/refdata_2.pth"))
#model.load_state_dict(torch.load("/content/checkpoints/checkpoint_epoch_10.pth")["model_state_dict"])
model.eval()
# `device` is the variable assigned in the evaluation cell; run that cell first.
model.to(device)

# Test split restricted to samples whose attribute flags equal attr_map.
dataset = WFLWLandmarksDataset_ReferenceData(
    img_dir=IMG_DIR,
    landmark_file=TEST_LANDMARK_FILE_PATH,
    transform=transforms.Compose([
        transforms.Resize(IMAGE_RESIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),

    ]),
    attr_map = [0, 0, 1, 0, 0, 0]
)

# shuffle=False so the first batch, and hence sample_idx, is reproducible.
test_loader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=False,
    num_workers=2,
    pin_memory=True
)

# x holds the dict returned by visualize_prediction for the chosen sample.
x = visualize_prediction(model, test_loader, device, sample_idx = 5)