# SimSiam + FastAP Retrieval Pipeline

End-to-end notebook for SimSiam self-supervised pretraining, FastAP fine-tuning, embedding extraction, and retrieval evaluation. Designed to run in Google Colab: the dataset is downloaded from Google Drive and all frames are preloaded into memory for fast data loading.

In [None]:
# Install third-party packages used below (execute in Colab as needed).
# gdown is invoked as a CLI for Google Drive downloads, pandas reads the label
# CSV and h5py reads the .mat dataset. NOTE(review): matplotlib is not used in
# the visible cells; torch/torchvision/Pillow are not installed here and are
# presumably provided by the runtime.
!pip install --quiet gdown pandas matplotlib h5py

In [None]:
# Environment setup, configuration, and reproducibility helpers
import json
import math
import random
import subprocess
import time
from dataclasses import dataclass
from pathlib import Path
from types import SimpleNamespace

import h5py
import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as T
import torchvision.models as tvm

# Directory layout: everything lives under the current working directory.
BASE_DIR = Path('.').resolve()
DATA_DIR = BASE_DIR / 'data'                 # raw inputs (.mat dataset, reference image)
ANNOTATIONS_DIR = BASE_DIR / 'annotations'   # per-frame label CSV
SAVE_DIR = BASE_DIR / 'save'                 # training / inference artifacts
RESULTS_DIR, SIMSIAM_DIR, FASTAP_DIR = (
    SAVE_DIR / 'inference_results',
    SAVE_DIR / 'simsiam_stage',
    SAVE_DIR / 'fastap_stage',
)
# Create every directory up front; safe to re-run.
for path in (DATA_DIR, ANNOTATIONS_DIR, SAVE_DIR, RESULTS_DIR, SIMSIAM_DIR, FASTAP_DIR):
    path.mkdir(exist_ok=True, parents=True)

# Google Drive share links handed to `gdown --fuzzy`.
SHARED_DATASET_URL = 'https://drive.google.com/file/d/1fa0gaEmbtGmqZ92L0EqzhH5LiMUAztix/view?usp=sharing'
SHARED_BREATH_HOLD_URL = ''  # optional: share link for invivo.jpg

# Input files and the shared "latest" checkpoint.
DATASET_PATH = DATA_DIR / 'dataset.mat'
LABEL_CSV_PATH = ANNOTATIONS_DIR / 'dataset_labels.csv'
BREATH_HOLD_IMAGE_PATH = DATA_DIR / 'invivo.jpg'
CHECKPOINT_PATH = SAVE_DIR / 'latest.pth'

# Per-stage artifact locations.
PATHS = SimpleNamespace(
    simsiam_dir=SIMSIAM_DIR,
    fastap_dir=FASTAP_DIR,
    results_dir=RESULTS_DIR,
    simsiam_latest=SIMSIAM_DIR / 'simsiam_latest.pth',
    fastap_checkpoint=FASTAP_DIR / 'fastap_stage.pth',
    embeddings=RESULTS_DIR / 'fastap_embeddings.npy',
    metrics=RESULTS_DIR / 'fastap_metrics.json',
    report=RESULTS_DIR / 'fastap_report.md',
    simsiam_log=SIMSIAM_DIR / 'simsiam_train_log.json',
    fastap_log=FASTAP_DIR / 'fastap_train_log.json',
)

# Central experiment configuration, grouped by pipeline stage.
CONFIG = SimpleNamespace(
    backbone='resnet18',  # 'resnet18' or 'resnet50' (the only values SimSiam accepts)
    image_key='Acq/Amp',  # HDF5 path of the image stack inside dataset.mat
    image_axes=(0, 2, 1),  # passed as the axis override to _interpret_image_shape (source axes for N, H, W on 3-D data)
    normalize_255=False,  # divide integer-typed frames by 255 on load
    image_size=224,  # square crop size fed to the backbone
    proj_dim=2048,  # SimSiam projector hidden/output width
    emb_dim=256,  # FastAP embedding dimension
    eval_batch_size=128,
    dataloader=SimpleNamespace(
        num_workers=2,
        pin_memory=torch.cuda.is_available(),
        drop_last=True,  # used by the two training loaders; the eval loader does not pass it
    ),
    simsiam=SimpleNamespace(
        run=True,
        epochs=100,  # epochs trained each time the Stage 1 cell runs
        batch_size=128,
        lr=0.05,
        momentum=0.9,
        weight_decay=1e-4,
        cosine_t_max=200,  # CosineAnnealingLR T_max. NOTE(review): larger than `epochs`, so a single
                           # run covers only half the cosine schedule; confirm intended (e.g. resumed runs).
        amp=True,  # mixed precision; only takes effect on CUDA
        resume=True,  # load PATHS.simsiam_latest before training when it exists
        save_every=10,  # NOTE(review): not referenced by the visible training code
    ),
    fastap=SimpleNamespace(
        run=True,
        epochs=10,
        batch_size=64,  # frames per batch; each frame yields two augmented views
        lr=1e-3,
        weight_decay=1e-4,
        optimizer='adamw',  # 'adamw'; any other value selects SGD(momentum=0.9)
        scheduler='cosine',  # 'cosine' -> CosineAnnealingLR(T_max=epochs); anything else: no scheduler
        freeze_backbone=True,
        freeze_projector=False,
        use_labels=False,  # True: frames sharing a label are positives; False: only the two views of a frame
        grad_clip=1.0,  # max gradient norm; None disables clipping
        amp=False,  # mixed precision; only takes effect on CUDA
        num_bins=100,  # FastAP histogram bins
        sigma=0.05,  # softness of the FastAP soft bin assignment
        epsilon=1e-8,  # numerical stabiliser inside fastap_loss
    ),
    retrieval=SimpleNamespace(
        topk=[1, 5, 10],  # k values for precision/recall@k; the max also sets the demo result count
        distance='cosine',  # NOTE(review): not read by the visible evaluation/search code
    ),
    seeds=SimpleNamespace(base=42),
)

def seed_everything(seed: int, *, deterministic: bool = False) -> None:
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Args:
        seed: Seed applied to every random number generator.
        deterministic: When True, force deterministic cuDNN kernels and turn
            off the cuDNN autotuner so GPU runs are repeatable. The default
            (False) keeps autotuned, non-deterministic kernels: faster, but GPU
            results can differ slightly between runs even with the same seed.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = deterministic
    # The autotuner may choose different algorithms per run, so it is only
    # enabled when reproducibility is not requested.
    torch.backends.cudnn.benchmark = not deterministic

seed_everything(CONFIG.seeds.base)

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device:', DEVICE)

def _gdown(url: str, destination: Path, error_message: str) -> None:
    """Download a Google Drive share link to ``destination`` with the gdown CLI.

    Raises:
        RuntimeError: (with ``error_message``) gdown exited non-zero or did not
            create ``destination``.
    """
    cmd = ['gdown', '--fuzzy', url, '-O', str(destination)]
    result = subprocess.run(cmd, check=False)
    # Check the file as well as the exit status so a "successful" run that
    # wrote nothing is reported here instead of failing later during loading.
    if result.returncode != 0 or not destination.exists():
        raise RuntimeError(error_message)


def ensure_data_availability() -> None:
    """Make sure the dataset, pseudo labels and optional reference image are in place.

    - ``DATASET_PATH``: downloaded from ``SHARED_DATASET_URL`` when missing.
    - ``LABEL_CSV_PATH``: must exist; a copy left in ``data/`` is moved into
      ``annotations/``.
    - ``BREATH_HOLD_IMAGE_PATH`` (optional): moved from the working directory,
      otherwise downloaded when ``SHARED_BREATH_HOLD_URL`` is set; if neither
      source exists a notice is printed and execution continues.

    Raises:
        RuntimeError: a gdown download failed or produced no file.
        FileNotFoundError: the pseudo-label CSV could not be found.
    """
    if DATASET_PATH.exists():
        print('Dataset already present:', DATASET_PATH)
    else:
        print(f'Downloading dataset to {DATASET_PATH} ...')
        _gdown(SHARED_DATASET_URL, DATASET_PATH, 'Failed to download dataset from Google Drive.')

    if LABEL_CSV_PATH.exists():
        print('Pseudo labels found:', LABEL_CSV_PATH)
    else:
        alt_path = DATA_DIR / 'dataset_labels.csv'
        if not alt_path.exists():
            # Message: "annotations/dataset_labels.csv not found. Generate or place it beforehand."
            raise FileNotFoundError('annotations/dataset_labels.csv が見つかりません。事前に生成または配置してください。')
        alt_path.replace(LABEL_CSV_PATH)
        print('Moved pseudo labels to annotations/:', LABEL_CSV_PATH)

    if BREATH_HOLD_IMAGE_PATH.exists():
        print('Breath-hold reference image located at', BREATH_HOLD_IMAGE_PATH)
        return
    legacy_path = BASE_DIR / 'invivo.jpg'
    if legacy_path.exists():
        legacy_path.replace(BREATH_HOLD_IMAGE_PATH)
        print('Moved breath-hold image to data/:', BREATH_HOLD_IMAGE_PATH)
    elif SHARED_BREATH_HOLD_URL:
        print(f'Downloading breath-hold reference to {BREATH_HOLD_IMAGE_PATH} ...')
        _gdown(SHARED_BREATH_HOLD_URL, BREATH_HOLD_IMAGE_PATH, 'Failed to download breath-hold reference image.')
    else:
        # Optional input: make its absence visible instead of silently skipping it.
        print('Breath-hold reference image not found and no share link configured; continuing without it.')

ensure_data_availability()
print('Artifacts will be stored under', SAVE_DIR)


In [None]:
# Dataset utilities and transforms (preloaded for speed)
from typing import List, Optional, Sequence, Tuple

def _resolve_image_dataset(f: h5py.File, key: str) -> h5py.Dataset:
    """Return the HDF5 object holding the image stack.

    Lookup order:
      1. ``key`` as a path (h5py accepts nested paths such as ``'Acq/Amp'``).
      2. A top-level member whose name ends with ``key``.
      3. The first nested *dataset* (depth-first walk) whose full path ends
         with ``'/' + key`` -- e.g. ``key='Amp'`` finds ``Acq/Amp``, which the
         top-level scan alone cannot reach.

    Raises:
        KeyError: nothing in the file matches ``key``.
    """
    if key in f:
        return f[key]
    for candidate in f.keys():
        if candidate.endswith(key):
            return f[candidate]

    suffix = '/' + key.lstrip('/')

    def _match(name: str, obj) -> Optional[str]:
        # visititems stops at the first non-None return value and returns it.
        if isinstance(obj, h5py.Dataset) and ('/' + name).endswith(suffix):
            return name
        return None

    found = f.visititems(_match)
    if found is not None:
        return f[found]
    raise KeyError(f"image_key '{key}' not found in {list(f.keys())}")


def _interpret_image_shape(
    shape: Tuple[int, ...],
    override: Optional[Tuple[int, ...]] = None,
) -> Tuple[int, int, int, int, Tuple[int, ...]]:
    if override is not None:
        axes = tuple(override)
        if len(axes) not in (3, 4):
            raise ValueError('image_axes override must have length 3 or 4')
        if len(axes) != len(shape):
            raise ValueError('image_axes override must match dataset rank')
        dims = [shape[a] for a in axes]
        if len(axes) == 4:
            n, c, h, w = dims
            return n, c, h, w, axes
        n, h, w = dims
        return n, 1, h, w, axes
    rank = len(shape)
    if rank == 4:
        candidates = [
            (0, 3, 1, 2),
            (0, 1, 2, 3),
            (3, 2, 0, 1),
            (3, 0, 1, 2),
        ]
        for axes in candidates:
            n = shape[axes[0]]
            c = shape[axes[1]]
            h = shape[axes[2]]
            w = shape[axes[3]]
            if all(v > 0 for v in (n, c, h, w)):
                return n, c, h, w, axes
        raise ValueError(f'Unable to infer N,C,H,W from shape {shape}')
    if rank == 3:
        candidates = [
            (0, 1, 2),
            (0, 2, 1),
            (2, 0, 1),
            (2, 1, 0),
            (1, 0, 2),
            (1, 2, 0),
        ]
        for axes in candidates:
            n = shape[axes[0]]
            h = shape[axes[1]]
            w = shape[axes[2]]
            if all(v > 0 for v in (n, h, w)):
                return n, 1, h, w, axes
        raise ValueError(f'Unable to infer N,H,W from shape {shape}')
    raise ValueError(f'Unsupported dataset rank {rank}; expected 3D or 4D')


def _to_nchw(arr: np.ndarray, axes: Tuple[int, ...]) -> np.ndarray:
    permuted = np.transpose(arr, axes)
    if permuted.ndim == 3:
        permuted = np.expand_dims(permuted, axis=1)
    return permuted

def load_frames_from_mat(
    mat_files: Sequence[str],
    image_key: str,
    image_axes: Optional[Tuple[int, ...]] = None,
    dtype: str = 'float32',
    normalize_255: bool = False,
) -> np.ndarray:
    """Read the image stack from every HDF5-format .mat file and concatenate them.

    Each file's dataset (located with ``_resolve_image_dataset``) is permuted
    into (N, C, H, W) using ``_interpret_image_shape`` and cast to ``dtype``.
    Integer-typed data is additionally divided by 255 when ``normalize_255``
    is set; float data is never rescaled.

    Raises:
        ValueError: ``mat_files`` is empty, or a shape/axis problem is
            reported by ``_interpret_image_shape``.
        KeyError: ``image_key`` cannot be found in a file.
    """
    def _read_one(mat_path: str) -> np.ndarray:
        with h5py.File(mat_path, 'r') as handle:
            dataset = _resolve_image_dataset(handle, image_key)
            axes = _interpret_image_shape(dataset.shape, override=image_axes)[-1]
            frames = _to_nchw(np.asarray(dataset), axes)
        was_integer = frames.dtype.kind in ('u', 'i')
        frames = frames.astype(dtype, copy=False)
        if was_integer and normalize_255:
            frames = frames / 255.0
        return frames

    buffers: List[np.ndarray] = [_read_one(mat_path) for mat_path in mat_files]
    if not buffers:
        raise ValueError('No frames were loaded from the provided .mat files.')
    return np.concatenate(buffers, axis=0)


class PreloadedMatDataset(torch.utils.data.Dataset):
    """Map-style dataset that keeps every frame of the .mat files in memory.

    ``self.data`` is the (N, C, H, W) array from ``load_frames_from_mat``.
    An item is one frame: passed through ``transform`` when set, otherwise
    returned as a tensor that shares memory with ``self.data``.
    """

    def __init__(
        self,
        mat_files: Sequence[str],
        image_key: str,
        image_axes: Optional[Tuple[int, ...]] = None,
        normalize_255: bool = False,
        transform=None,
        dtype: str = 'float32',
    ):
        self.data = load_frames_from_mat(
            mat_files,
            image_key,
            image_axes=image_axes,
            dtype=dtype,
            normalize_255=normalize_255,
        )
        self.transform = transform

    def __len__(self) -> int:
        return len(self.data)

    def get_raw(self, idx: int) -> np.ndarray:
        """Return frame ``idx`` untransformed (an indexing view into ``self.data``)."""
        return self.data[idx]

    def __getitem__(self, idx: int):
        frame = self.get_raw(idx)
        if self.transform is None:
            return torch.from_numpy(frame)
        return self.transform(frame)

    def view(self, transform):
        """Expose the same in-memory frames through a different ``transform`` (no copy)."""
        return PreloadedDatasetView(self, transform)


class PreloadedDatasetView(torch.utils.data.Dataset):
    """Read-only window onto a PreloadedMatDataset with its own ``transform``.

    Frames are fetched through ``base.get_raw`` (no copy); the base dataset's
    own ``transform`` is not applied.
    """

    def __init__(self, base: PreloadedMatDataset, transform):
        self.base = base
        self.transform = transform

    def __len__(self) -> int:
        return len(self.base)

    def __getitem__(self, idx: int):
        frame = self.base.get_raw(idx)
        return torch.from_numpy(frame) if self.transform is None else self.transform(frame)


def bmode_normalize(chw: np.ndarray, *, dynamic_range_db: float = 60.0) -> np.ndarray:
    """Log-compress an amplitude frame into [0, 1].

    The magnitude ``|chw|`` is divided by its maximum and converted to dB
    (``20 * log10``). The window ``[-dynamic_range_db, 0]`` dB is then mapped
    linearly onto ``[0, 1]``, and values outside the window are clipped.

    Args:
        chw: Frame array of any shape, real or complex.
        dynamic_range_db: Displayed dynamic range in dB. Must be positive; the
            default of 60 reproduces the previous fixed window.

    Returns:
        float32 array with the same shape as ``chw`` and values in [0, 1].

    Raises:
        ValueError: ``dynamic_range_db`` is not positive.
    """
    if dynamic_range_db <= 0:
        raise ValueError('dynamic_range_db must be positive.')
    # Take the magnitude before casting: casting complex data to float32 first
    # would silently drop the imaginary part.
    x = np.abs(np.asarray(chw)).astype(np.float32, copy=False)
    x = x / (x.max() + 1e-12)
    x = 20.0 * np.log10(x + 1e-12)
    return np.clip((x + dynamic_range_db) / dynamic_range_db, 0.0, 1.0)


def to_pil_3ch_from_chw01(chw01: np.ndarray) -> Image.Image:
    """Convert a (C, H, W) array with values in [0, 1] into an 8-bit PIL image.

    Single-channel input is repeated to three channels so it can feed
    RGB-pretrained backbones; other channel counts pass through unchanged.
    Values are clipped to [0, 1] before scaling, so out-of-range inputs
    saturate at 0/255 instead of wrapping around in the uint8 cast.
    """
    if chw01.shape[0] == 1:
        chw01 = np.repeat(chw01, 3, axis=0)
    hwc = np.transpose(chw01, (1, 2, 0))
    hwc255 = (np.clip(hwc, 0.0, 1.0) * 255.0).astype(np.uint8)
    return Image.fromarray(hwc255)


def build_simsiam_augmentation(image_size: int = 224) -> T.Compose:
    """Build the SimSiam-style augmentation pipeline (PIL image -> normalised tensor).

    Steps: random resized crop covering 50-100% of the area (bicubic),
    horizontal flip, colour jitter (p=0.8), grayscale (p=0.2), Gaussian blur
    (p=0.5), tensor conversion and ImageNet mean/std normalisation. The blur
    kernel is about 10% of ``image_size``, at least 3, and always odd.
    """
    kernel = max(3, int(image_size * 0.1))
    kernel |= 1  # bump an even size to the next odd number

    jitter = T.ColorJitter(0.4, 0.4, 0.4, 0.1)
    blur = T.GaussianBlur(kernel_size=kernel, sigma=(0.1, 2.0))
    imagenet_norm = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    steps = [
        T.RandomResizedCrop(image_size, scale=(0.5, 1.0), interpolation=T.InterpolationMode.BICUBIC),
        T.RandomHorizontalFlip(),
        T.RandomApply([jitter], p=0.8),
        T.RandomGrayscale(p=0.2),
        T.RandomApply([blur], p=0.5),
        T.ToTensor(),
        imagenet_norm,
    ]
    return T.Compose(steps)


class TwoViewTransform:
    """Turn one raw frame into two independently augmented views.

    The frame is log-compressed with ``bmode_normalize`` and converted to a
    3-channel PIL image. It then goes through ``transform_q`` and
    ``transform_k``; ``transform_k`` falls back to ``transform_q`` when it is
    omitted or falsy.
    """

    def __init__(self, transform_q: T.Compose, transform_k: Optional[T.Compose] = None):
        self.transform_q = transform_q
        self.transform_k = transform_k if transform_k else transform_q

    def __call__(self, chw: np.ndarray):
        image = to_pil_3ch_from_chw01(bmode_normalize(chw))
        return self.transform_q(image), self.transform_k(image)


class TwoViewDataset(torch.utils.data.Dataset):
    """Dataset whose items are ``pair_transform(raw_frame)``: two augmented views per frame."""

    def __init__(self, base_dataset: PreloadedMatDataset, pair_transform: TwoViewTransform):
        self.base_dataset = base_dataset
        self.pair_transform = pair_transform

    def __len__(self) -> int:
        return len(self.base_dataset)

    def __getitem__(self, idx: int):
        return self.pair_transform(self.base_dataset.get_raw(idx))


class FastAPPairDataset(torch.utils.data.Dataset):
    """Two augmented views of a frame plus its index and (optional) label.

    Items are ``(view_q, view_k, idx, label)``. ``label`` is ``-1`` when no
    label array was supplied.

    Raises:
        ValueError: ``labels`` is given but its length differs from the dataset's.
    """

    def __init__(self, base_dataset: PreloadedMatDataset, pair_transform: TwoViewTransform, labels: Optional[np.ndarray] = None):
        if labels is not None and len(labels) != len(base_dataset):
            raise ValueError('Labels length must match dataset length.')
        self.base_dataset = base_dataset
        self.pair_transform = pair_transform
        self.labels = labels

    def __len__(self) -> int:
        return len(self.base_dataset)

    def __getitem__(self, idx: int):
        view_q, view_k = self.pair_transform(self.base_dataset.get_raw(idx))
        label = int(self.labels[idx]) if self.labels is not None else -1
        return view_q, view_k, idx, label


def build_eval_transform(image_size: int = 224, resize_size: Optional[int] = None) -> T.Compose:
    """Deterministic evaluation pipeline: raw (C, H, W) frame -> normalised tensor.

    Steps: log compression with ``bmode_normalize``, 3-channel PIL conversion,
    bicubic resize of the shorter side to ``resize_size``, centre crop to
    ``image_size``, tensor conversion and ImageNet mean/std normalisation.

    Args:
        image_size: Side length of the square centre crop.
        resize_size: Shorter-side length before cropping. Defaults to the
            conventional 256/224 ratio of ``image_size``, which gives 256 for
            224. Scaling with ``image_size`` keeps the crop inside the resized
            image; a fixed 256 would force CenterCrop to zero-pad whenever
            ``image_size`` > 256.
    """
    if resize_size is None:
        resize_size = int(round(image_size * 256 / 224))
    return T.Compose([
        T.Lambda(lambda x: to_pil_3ch_from_chw01(bmode_normalize(x))),
        T.Resize(resize_size, interpolation=T.InterpolationMode.BICUBIC),
        T.CenterCrop(image_size),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

eval_transform = build_eval_transform(CONFIG.image_size)


In [None]:
# Model components, losses, and training utilities
from typing import Any, Dict, Optional, Sequence

class Projector(nn.Module):
    """SimSiam projection MLP: Linear -> BN -> ReLU -> Linear -> BN, with bias-free linears.

    Maps ``in_dim`` features through a ``hid_dim`` hidden layer to ``out_dim``.
    The layer order fixes the ``net.<i>`` parameter names that saved state
    dicts are loaded against, so it must stay as is.
    """

    def __init__(self, in_dim: int, hid_dim: int, out_dim: int):
        super().__init__()
        layers = [
            nn.Linear(in_dim, hid_dim, bias=False),
            nn.BatchNorm1d(hid_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hid_dim, out_dim, bias=False),
            nn.BatchNorm1d(out_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.net(x)
        return out


class Predictor(nn.Module):
    """SimSiam prediction MLP: Linear(no bias) -> BN -> ReLU -> Linear(with bias).

    Bottlenecks ``in_dim`` to ``hid_dim`` and expands to ``out_dim``. The
    ``net.<i>`` layout is what saved state dicts refer to.
    """

    def __init__(self, in_dim: int, hid_dim: int, out_dim: int):
        super().__init__()
        layers = [
            nn.Linear(in_dim, hid_dim, bias=False),
            nn.BatchNorm1d(hid_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hid_dim, out_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.net(x)
        return out


class EmbeddingHead(nn.Module):
    """FastAP embedding MLP: Linear(no bias) -> BN -> ReLU -> Linear, optionally L2-normalised.

    Output rows are unit-length along dim 1 when ``normalize`` is True (the default).
    """

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int):
        super().__init__()
        layers = [
            nn.Linear(in_dim, hidden_dim, bias=False),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor, normalize: bool = True) -> torch.Tensor:
        out = self.net(x)
        return F.normalize(out, dim=1) if normalize else out


class FastAPModel(nn.Module):
    """Embedding network for FastAP fine-tuning assembled from SimSiam parts.

    Pipeline: backbone -> flatten -> projector -> embedding head (L2-normalised).
    ``predictor`` is registered so SimSiam state dicts load with matching keys,
    but no method of this class calls it.
    """

    def __init__(self, backbone: nn.Module, projector: Projector, embedding_head: EmbeddingHead, predictor: Predictor):
        super().__init__()
        self.backbone = backbone
        self.projector = projector
        self.predictor = predictor
        self.embedding_head = embedding_head

    def forward_backbone(self, x: torch.Tensor) -> torch.Tensor:
        """Backbone output flattened to [batch, features]."""
        feats = self.backbone(x)
        return torch.flatten(feats, 1)

    def forward_projector(self, x: torch.Tensor, normalize: bool = False) -> torch.Tensor:
        """Projector output, optionally L2-normalised along dim 1."""
        h = self.forward_backbone(x)
        z = self.projector(h)
        if normalize:
            z = F.normalize(z, dim=1)
        return z

    def forward_embedding(self, x: torch.Tensor) -> torch.Tensor:
        """L2-normalised embeddings; the head receives the un-normalised projector output."""
        z = self.forward_projector(x, normalize=False)
        emb = self.embedding_head(z, normalize=True)
        return emb

    def encode(self, x: torch.Tensor, stage: str = 'embedding') -> torch.Tensor:
        """Features from one stage: 'backbone' (raw), 'projector' (normalised) or 'embedding'.

        Raises:
            ValueError: unknown ``stage``.
        """
        if stage == 'backbone':
            return self.forward_backbone(x)
        if stage == 'projector':
            return self.forward_projector(x, normalize=True)
        if stage == 'embedding':
            return self.forward_embedding(x)
        raise ValueError(f'Unknown encode stage: {stage}')

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Default call path: ``model(x)`` returns the L2-normalised embeddings.

        Without this, calling the module raises NotImplementedError, so hooks,
        ``torch.compile`` and wrapper modules could not be used with it.
        """
        return self.forward_embedding(x)


class SimSiam(nn.Module):
    """SimSiam network: ResNet encoder + projection MLP + prediction MLP.

    The torchvision ResNet keeps every child module except the last one
    (its classification layer), so ``forward_backbone`` returns flattened
    features of size ``feat_dim`` (512 for resnet18, 2048 for resnet50).

    Args:
        backbone_name: 'resnet18' or 'resnet50'.
        proj_dim: Projector hidden/output width; also the predictor's input and
            output width (the predictor bottleneck is fixed at 512).
        pretrained: Initialise the ResNet with torchvision ImageNet weights
            (default, as before; the weights must be cached or downloadable).
            Pass False to skip that when the weights are about to be
            overwritten from a checkpoint anyway.

    Raises:
        ValueError: unsupported ``backbone_name``.
    """

    def __init__(self, backbone_name: str = 'resnet18', proj_dim: int = 2048, *, pretrained: bool = True):
        super().__init__()
        if backbone_name == 'resnet18':
            weights = tvm.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
            backbone = tvm.resnet18(weights=weights)
            feat_dim = 512
        elif backbone_name == 'resnet50':
            weights = tvm.ResNet50_Weights.IMAGENET1K_V2 if pretrained else None
            backbone = tvm.resnet50(weights=weights)
            feat_dim = 2048
        else:
            raise ValueError(f'Unsupported backbone: {backbone_name}')
        self.backbone = nn.Sequential(*list(backbone.children())[:-1])
        self.feat_dim = feat_dim
        self.projector = Projector(in_dim=feat_dim, hid_dim=proj_dim, out_dim=proj_dim)
        self.predictor = Predictor(in_dim=proj_dim, hid_dim=512, out_dim=proj_dim)

    def forward_backbone(self, x: torch.Tensor) -> torch.Tensor:
        """Backbone features flattened to [batch, feat_dim]."""
        h = self.backbone(x)
        h = torch.flatten(h, 1)
        return h

    def forward(self, x1: torch.Tensor, x2: torch.Tensor):
        """Return ``(p1, z1, p2, z2)``: predictor and projector outputs for both views."""
        h1 = self.forward_backbone(x1)
        z1 = self.projector(h1)
        p1 = self.predictor(z1)
        h2 = self.forward_backbone(x2)
        z2 = self.projector(h2)
        p2 = self.predictor(z2)
        return p1, z1, p2, z2


def build_fastap_model(backbone_name: str, proj_dim: int, emb_dim: int) -> FastAPModel:
    """Assemble a FastAPModel from a freshly built SimSiam network plus a new embedding head.

    Backbone, projector and predictor are the SimSiam instance's own modules,
    so SimSiam state-dict keys line up. The head maps the ``proj_dim``
    projector output through a ``proj_dim`` hidden layer to ``emb_dim``.
    """
    base = SimSiam(backbone_name=backbone_name, proj_dim=proj_dim)
    head = EmbeddingHead(in_dim=proj_dim, hidden_dim=proj_dim, out_dim=emb_dim)
    return FastAPModel(
        backbone=base.backbone,
        projector=base.projector,
        embedding_head=head,
        predictor=base.predictor,
    )


def build_simsiam(backbone_name: str, proj_dim: int) -> SimSiam:
    """Construct a SimSiam network with the given backbone and projector width."""
    model = SimSiam(backbone_name=backbone_name, proj_dim=proj_dim)
    return model


def negative_cosine(p: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    """Mean negative cosine similarity between matching rows of ``p`` and ``z``.

    ``z`` is detached (stop-gradient), so gradients reach only ``p``. Returns
    a scalar in [-1, 1]; -1 means every pair is perfectly aligned.
    """
    target = F.normalize(z.detach(), dim=1)
    pred = F.normalize(p, dim=1)
    cosine = torch.sum(pred * target, dim=1)
    return cosine.mean().neg()


def make_positive_mask(ids_or_labels: torch.Tensor) -> torch.BoolTensor:
    """Boolean [N, N] mask that is True where two *different* entries share an id/label."""
    flat = ids_or_labels.view(-1)
    same = flat[:, None] == flat[None, :]
    off_diagonal = ~torch.eye(flat.numel(), dtype=torch.bool, device=flat.device)
    return same & off_diagonal


def fastap_loss(
    embeddings: torch.Tensor,
    pos_mask: torch.BoolTensor,
    *,
    num_bins: int = 50,
    sigma: float = 0.05,
    epsilon: float = 1e-8,
) -> torch.Tensor:
    """Negative FastAP: a histogram approximation of mean Average Precision.

    Every row of ``embeddings`` queries all other rows. Cosine similarities
    are soft-assigned to ``num_bins`` evenly spaced bins with sigmoid edges of
    temperature ``sigma``. Each pair's assignment is normalised to sum to one.
    Bins are ordered from the most similar (+1) to the least similar (-1), so
    the cumulative counts give precision/recall at a decreasing similarity
    threshold, i.e. down the retrieval ranking.

    Args:
        embeddings: [N, D] tensor; rows are L2-normalised internally.
        pos_mask: [N, N] bool mask, True where column j is a positive for query
            i. The diagonal is ignored.
        num_bins: Number of histogram bins (at least 2).
        sigma: Softness of the bin edges.
        epsilon: Stabiliser for the divisions.

    Returns:
        Scalar ``-mean(AP)`` over queries that have at least one positive, or a
        zero tensor (requiring grad) when no query has a positive.

    Raises:
        ValueError: ``embeddings`` is not 2-D, or ``num_bins < 2``.
    """
    if embeddings.dim() != 2:
        raise ValueError('Expected embeddings of shape [N, D].')
    if num_bins < 2:
        raise ValueError('num_bins must be at least 2.')
    embeddings = F.normalize(embeddings, dim=1)
    device = embeddings.device
    n = embeddings.size(0)

    sim = embeddings @ embeddings.T
    diag = torch.eye(n, device=device, dtype=torch.bool)
    pos_mask = pos_mask.to(device) & ~diag

    # Descending centres: the cumsums below must start at the top of the
    # ranking (highest similarity). Ascending centres would compute AP of the
    # reversed ranking and push training the wrong way.
    bin_centers = torch.linspace(1.0, -1.0, steps=num_bins, device=device)
    # Bin width equals the centre spacing, so the bins tile [-1 - hw, 1 + hw].
    half_width = (2.0 / (num_bins - 1)) / 2.0

    s = sim.unsqueeze(-1)
    c = bin_centers.view(1, 1, -1)
    assign = torch.sigmoid((s - (c - half_width)) / sigma) - torch.sigmoid((s - (c + half_width)) / sigma)
    assign = assign.clamp(min=0.0)
    # Give every pair exactly one unit of mass. Otherwise pairs near +/-1 lose
    # part of their mass beyond the outer bin edges, which biases recall and
    # penalises very high similarities.
    assign = assign / assign.sum(dim=-1, keepdim=True).clamp(min=epsilon)
    assign = assign * (~diag).unsqueeze(-1)  # a query never retrieves itself

    H = assign.sum(dim=1)                                  # all candidates per bin  [N, B]
    P = (assign * pos_mask.unsqueeze(-1).float()).sum(dim=1)  # positives per bin   [N, B]
    cum_H = torch.cumsum(H, dim=1)
    cum_P = torch.cumsum(P, dim=1)

    pos_counts = pos_mask.float().sum(dim=1)
    valid_queries = pos_counts > 0

    precision = (cum_P + epsilon) / (cum_H + epsilon)
    recall = (cum_P + epsilon) / (pos_counts.unsqueeze(-1) + epsilon)
    delta_recall = torch.cat([recall[:, :1], recall[:, 1:] - recall[:, :-1]], dim=1)

    ap = (precision * delta_recall).sum(dim=1)[valid_queries]
    if ap.numel() == 0:
        return embeddings.new_tensor(0.0, requires_grad=True)
    return -ap.mean()


@dataclass
class TrainLog:
    """Per-epoch training history for one stage, plus a free-form config record."""

    epochs: List[int]       # epoch numbers as recorded by the training loop
    losses: List[float]     # mean loss per epoch
    lrs: List[float]        # learning rate read at the end of each epoch
    elapsed: List[float]    # wall-clock seconds per epoch
    config: Dict[str, Any]  # stage metadata stored alongside the history

    def to_dict(self) -> Dict[str, Any]:
        """Return the fields as a plain dict (the list/dict objects are shared, not copied)."""
        names = ('epochs', 'losses', 'lrs', 'elapsed', 'config')
        return {name: getattr(self, name) for name in names}


def train_simsiam_stage(
    dataloader: DataLoader,
    model: SimSiam,
    optimizer: torch.optim.Optimizer,
    scheduler,
    epochs: int,
    device: torch.device,
    *,
    amp_enabled: bool = True,
    log_every: int = 50,
) -> TrainLog:
    """Run SimSiam pre-training for ``epochs`` epochs and return the per-epoch history.

    Every step minimises ``0.5 * D(p1, z2) + 0.5 * D(p2, z1)`` with
    ``D = negative_cosine`` (stop-gradient on the z targets). Autocast and
    GradScaler are active only when ``amp_enabled`` and ``device`` is CUDA.
    ``scheduler.step()`` (if given) runs once per epoch. A running mean loss is
    printed every ``log_every`` steps (0 disables step logging).
    """
    use_amp = amp_enabled and device.type == 'cuda'
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    history = TrainLog(epochs=[], losses=[], lrs=[], elapsed=[], config={'stage': 'simsiam'})
    n_batches = len(dataloader)

    for epoch in range(1, epochs + 1):
        model.train()
        t0 = time.time()
        loss_sum = 0.0
        for step, (view_q, view_k) in enumerate(dataloader, start=1):
            x1 = view_q.to(device, non_blocking=True)
            x2 = view_k.to(device, non_blocking=True)
            optimizer.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast(enabled=use_amp):
                p1, z1, p2, z2 = model(x1, x2)
                loss = 0.5 * negative_cosine(p1, z2) + 0.5 * negative_cosine(p2, z1)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            loss_sum += float(loss.item())
            if log_every and step % log_every == 0:
                print(f'Epoch {epoch:02d} step {step:04d}/{n_batches:04d} | loss={loss_sum / step:.4f}')

        if scheduler is not None:
            scheduler.step()
        seconds = time.time() - t0
        mean_loss = loss_sum / max(1, n_batches)
        current_lr = optimizer.param_groups[0]['lr']
        for series, value in (
            (history.epochs, epoch),
            (history.losses, mean_loss),
            (history.lrs, current_lr),
            (history.elapsed, seconds),
        ):
            series.append(value)
        print(f'Epoch {epoch:02d}/{epochs:02d} completed | loss={mean_loss:.4f} | lr={current_lr:.3e} | time={seconds:.1f}s')
    return history


def _fastap_pair_ids(ids: torch.Tensor, labels: torch.Tensor, use_labels: bool) -> torch.Tensor:
    """Group ids for the concatenated ``[views_q; views_k]`` batch.

    Both views of a frame always share an id. With ``use_labels`` (and at
    least one labelled frame in the batch), labelled frames (label >= 0) are
    grouped by label. Unlabelled frames get ``-(idx + 1)``: a negative id that
    is unique per frame and can never collide with a real label.
    """
    if use_labels and bool((labels >= 0).any()):
        group = torch.where(labels >= 0, labels, -(ids + 1))
    else:
        group = ids
    return torch.cat([group, group], dim=0)


def _eval_frozen_batchnorm(model: nn.Module) -> None:
    """Switch BatchNorm layers whose parameters are all frozen to eval mode.

    ``model.train()`` would otherwise keep updating their running statistics
    and normalise with batch statistics, so a "frozen" backbone would still
    drift away from its pretrained state.
    """
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    for module in model.modules():
        if isinstance(module, bn_types):
            params = list(module.parameters(recurse=False))
            if params and not any(p.requires_grad for p in params):
                module.eval()


def train_fastap_stage(
    dataloader: DataLoader,
    model: FastAPModel,
    optimizer: torch.optim.Optimizer,
    scheduler,
    epochs: int,
    cfg,
    *,
    log_every: int = 50,
) -> TrainLog:
    """Fine-tune ``model`` with the FastAP loss and return the per-epoch history.

    Batches are ``(view_q, view_k, idx, label)``. Both views are embedded with
    ``model.forward_embedding`` and concatenated; positives are pairs sharing
    an id from ``_fastap_pair_ids``.

    ``cfg`` must provide ``num_bins``, ``sigma``, ``use_labels`` and
    ``grad_clip`` (None disables clipping). It may also provide ``epsilon``
    (default 1e-8), ``amp`` (default False; only honoured on CUDA) and
    ``scheduler_step`` ('epoch' by default, or 'step').

    BatchNorm layers whose parameters are all frozen stay in eval mode, so
    frozen parts of the network keep their pretrained statistics.
    """
    device = next(model.parameters()).device
    amp_enabled = getattr(cfg, 'amp', False) and device.type == 'cuda'
    scaler = torch.cuda.amp.GradScaler(enabled=amp_enabled)
    epsilon = getattr(cfg, 'epsilon', 1e-8)
    scheduler_step = getattr(cfg, 'scheduler_step', 'epoch')
    history = TrainLog(epochs=[], losses=[], lrs=[], elapsed=[], config={
        'stage': 'fastap',
        'num_bins': cfg.num_bins,
        'sigma': cfg.sigma,
        'epsilon': epsilon,
        'use_labels': cfg.use_labels,
    })
    for epoch in range(1, epochs + 1):
        model.train()
        _eval_frozen_batchnorm(model)
        running_loss = 0.0
        start = time.time()
        for step, (view_q, view_k, idx, labels) in enumerate(dataloader, start=1):
            view_q = view_q.to(device, non_blocking=True)
            view_k = view_k.to(device, non_blocking=True)
            pair_ids = _fastap_pair_ids(
                idx.to(device, non_blocking=True),
                labels.to(device, non_blocking=True),
                cfg.use_labels,
            )
            optimizer.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast(enabled=amp_enabled):
                emb_q = model.forward_embedding(view_q)
                emb_k = model.forward_embedding(view_k)
            # Evaluate the histogram loss outside autocast in float32 so the
            # narrow soft-bin sigmoids are not computed in half precision.
            embeddings = torch.cat([emb_q, emb_k], dim=0).float()
            loss = fastap_loss(
                embeddings,
                make_positive_mask(pair_ids),
                num_bins=cfg.num_bins,
                sigma=cfg.sigma,
                epsilon=epsilon,
            )
            # A disabled GradScaler is a pass-through (scale/unscale_ do nothing
            # and step() just calls optimizer.step()), so one path covers both cases.
            scaler.scale(loss).backward()
            if cfg.grad_clip is not None:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
            scaler.step(optimizer)
            scaler.update()
            if scheduler is not None and scheduler_step == 'step':
                scheduler.step()
            running_loss += float(loss.item())
            if log_every and step % log_every == 0:
                avg = running_loss / step
                print(f'Epoch {epoch:02d} step {step:04d}/{len(dataloader):04d} | loss={avg:.4f}')
        if scheduler is not None and scheduler_step == 'epoch':
            scheduler.step()
        elapsed = time.time() - start
        avg_loss = running_loss / max(1, len(dataloader))
        lr = optimizer.param_groups[0]['lr']
        history.epochs.append(epoch)
        history.losses.append(avg_loss)
        history.lrs.append(lr)
        history.elapsed.append(elapsed)
        print(f'Epoch {epoch:02d}/{epochs:02d} completed | loss={avg_loss:.4f} | lr={lr:.3e} | time={elapsed:.1f}s')
    return history


def evaluate_map(embeddings: np.ndarray, labels: np.ndarray, topk: Optional[Sequence[int]] = None) -> Dict[str, Any]:
    """Leave-one-out retrieval evaluation ranked by inner-product similarity.

    Each row queries all other rows (itself excluded), ranked by descending
    dot product. This equals cosine ranking when rows are L2-normalised, as
    ``forward_embedding`` outputs are. A retrieved row is relevant when it
    shares the query's label.

    A query with no other row of its label has undefined AP and recall. It
    gets 0.0 in ``per_query_ap`` (which keeps one entry per row) but is
    excluded from ``mAP``, ``precision_at_k`` and ``recall_at_k``. Those
    averages run over the ``num_valid_queries`` queries with at least one
    relevant row (0.0 when there are none). Counting such queries as zeros
    would make the scores depend on how many singleton labels exist rather
    than on the embeddings.

    Args:
        embeddings: [N, D] array.
        labels: [N] integer labels.
        topk: Cut-offs for precision/recall@k (default [1, 5, 10]); non-positive
            values are dropped. precision@k always divides hits by k.

    Returns:
        Dict with 'mAP', 'per_query_ap', 'precision_at_k', 'recall_at_k' and
        'num_valid_queries'.

    Raises:
        ValueError: ``embeddings`` is not 2-D or ``labels`` does not match its rows.
    """
    if embeddings.ndim != 2:
        raise ValueError('Expected embeddings with shape [N, D].')
    if labels.ndim != 1 or labels.shape[0] != embeddings.shape[0]:
        raise ValueError('Labels must have shape [N].')
    embeddings = embeddings.astype(np.float32)
    labels = labels.astype(np.int64)
    if topk is None:
        topk = [1, 5, 10]
    topk = sorted(set(int(k) for k in topk if k > 0))

    sim = embeddings @ embeddings.T
    np.fill_diagonal(sim, -np.inf)  # pushes the query itself to the end of its ranking
    order = np.argsort(-sim, axis=1)

    ap_list: List[float] = []
    precision_sum = dict.fromkeys(topk, 0.0)
    recall_sum = dict.fromkeys(topk, 0.0)
    num_valid = 0
    for i in range(sim.shape[0]):
        rel = (labels == labels[i]).astype(np.int32)
        rel[i] = 0
        rel_ranked = rel[order[i]]
        n_rel = int(rel_ranked.sum())
        if n_rel == 0:
            ap_list.append(0.0)
            continue
        num_valid += 1
        hits_so_far = np.cumsum(rel_ranked)
        precision = hits_so_far / np.arange(1, len(rel_ranked) + 1)
        ap_list.append(float((precision * rel_ranked).sum() / n_rel))
        for k in topk:
            hits = float(rel_ranked[:k].sum())
            precision_sum[k] += hits / k
            recall_sum[k] += hits / n_rel

    denom = max(num_valid, 1)
    return {
        'mAP': float(sum(ap_list) / denom),  # skipped queries contribute 0.0 to the sum
        'per_query_ap': ap_list,
        'precision_at_k': {k: precision_sum[k] / denom for k in topk},
        'recall_at_k': {k: recall_sum[k] / denom for k in topk},
        'num_valid_queries': num_valid,
    }


def extract_embeddings(model: FastAPModel, loader: DataLoader, device: torch.device, stage: str = 'embedding') -> np.ndarray:
    """Encode every batch from ``loader`` with ``model.encode(..., stage=stage)`` and stack the rows.

    Puts ``model`` in eval mode (and leaves it there) and disables autograd.
    The loader must yield bare input tensors. Returns a NumPy array with one
    row per sample, in loader order.
    """
    model.eval()
    with torch.no_grad():
        chunks = [
            model.encode(batch.to(device, non_blocking=True), stage=stage).cpu().numpy()
            for batch in loader
        ]
    return np.concatenate(chunks, axis=0)


In [None]:
# Build datasets, run SimSiam pretraining and FastAP fine-tuning
mat_files = [str(DATASET_PATH)]
# All frames are read into memory once and shared by every stage below.
raw_dataset = PreloadedMatDataset(
    mat_files,
    transform=None,
    image_key=CONFIG.image_key,
    image_axes=CONFIG.image_axes,
    normalize_255=CONFIG.normalize_255,
)
print('Loaded frames:', len(raw_dataset))

# One integer label per loaded frame, looked up by the CSV's `frame` column
# (pandas raises KeyError if any frame index 0..N-1 is absent from the CSV).
labels_df = (
    pd.read_csv(LABEL_CSV_PATH)
    .set_index('frame')
    .sort_index()
)
frame_labels = labels_df.loc[range(len(raw_dataset)), 'label'].to_numpy().astype(np.int64)

# Populated by the training stages when they run.
simsiam_log = fastap_log = None

# Stage 1: SimSiam pretraining
if CONFIG.simsiam.run:
    pair_transform = TwoViewTransform(build_simsiam_augmentation(CONFIG.image_size))
    simsiam_dataset = TwoViewDataset(raw_dataset, pair_transform)
    simsiam_loader = DataLoader(
        simsiam_dataset,
        batch_size=CONFIG.simsiam.batch_size,
        shuffle=True,
        drop_last=CONFIG.dataloader.drop_last,
        num_workers=CONFIG.dataloader.num_workers,
        pin_memory=CONFIG.dataloader.pin_memory,
    )
    if len(simsiam_loader) == 0:
        raise RuntimeError('SimSiam DataLoader returned 0 batches. Reduce batch size or disable drop_last.')
    simsiam_model = build_simsiam(CONFIG.backbone, CONFIG.proj_dim).to(DEVICE)
    optimizer = torch.optim.SGD(
        simsiam_model.parameters(),
        lr=CONFIG.simsiam.lr,
        momentum=CONFIG.simsiam.momentum,
        weight_decay=CONFIG.simsiam.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CONFIG.simsiam.cosine_t_max)

    # Epochs completed by earlier runs (0 for a fresh start).
    start_epoch = 0
    resume_path = PATHS.simsiam_latest
    if CONFIG.simsiam.resume and resume_path.exists():
        ckpt = torch.load(resume_path, map_location=DEVICE)
        state_dict = ckpt.get('state_dict', ckpt)
        load_result = simsiam_model.load_state_dict(state_dict, strict=False)
        print('Resumed SimSiam checkpoint:', resume_path)
        if getattr(load_result, 'missing_keys', None):
            print('  Missing keys:', load_result.missing_keys)
        if getattr(load_result, 'unexpected_keys', None):
            print('  Unexpected keys:', load_result.unexpected_keys)
        start_epoch = int(ckpt.get('epoch', 0))
        # Restore optimizer momentum and the schedule position so a resumed run
        # continues the cosine schedule (T_max = cosine_t_max) instead of
        # restarting at the base LR. Only attempted when the weights matched
        # exactly; weight-only checkpoints restore the weights alone.
        weights_match = not load_result.missing_keys and not load_result.unexpected_keys
        if weights_match and 'optimizer' in ckpt and 'scheduler' in ckpt:
            optimizer.load_state_dict(ckpt['optimizer'])
            scheduler.load_state_dict(ckpt['scheduler'])
            print(f'  Restored optimizer/scheduler state after {start_epoch} epochs (lr={optimizer.param_groups[0]["lr"]:.3e}).')

    log_every = max(1, len(simsiam_loader) // 5)
    simsiam_log = train_simsiam_stage(
        dataloader=simsiam_loader,
        model=simsiam_model,
        optimizer=optimizer,
        scheduler=scheduler,
        epochs=CONFIG.simsiam.epochs,
        device=DEVICE,
        amp_enabled=CONFIG.simsiam.amp,
        log_every=log_every,
    )
    torch.save(simsiam_log.to_dict(), PATHS.simsiam_log)
    # 'epoch' counts epochs across all resumed runs. Optimizer/scheduler state
    # is stored so the next run can continue the schedule.
    torch.save(
        {
            'epoch': start_epoch + len(simsiam_log.epochs),
            'state_dict': simsiam_model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
        },
        PATHS.simsiam_latest,
    )
    torch.save({'state_dict': simsiam_model.state_dict()}, CHECKPOINT_PATH)
    print('SimSiam checkpoint saved to', PATHS.simsiam_latest)
else:
    print('Skipping SimSiam pretraining stage.')
    simsiam_model = None

# Load SimSiam weights for FastAP initialisation
simsim_state_dict = None
if CONFIG.simsiam.run and simsiam_model is not None:
    simsim_state_dict = simsiam_model.state_dict()
else:
    # Prefer the dedicated SimSiam checkpoint. latest.pth is also overwritten
    # by the FastAP stage, so it may hold fine-tuned weights (incl. the head).
    fallback_paths = [PATHS.simsiam_latest, CHECKPOINT_PATH]
    for candidate in fallback_paths:
        if candidate.exists():
            ckpt = torch.load(candidate, map_location=DEVICE)
            simsim_state_dict = ckpt.get('state_dict', ckpt)
            print('Using SimSiam state from', candidate)
            break

# Stage 2: FastAP fine-tuning
model = build_fastap_model(CONFIG.backbone, CONFIG.proj_dim, CONFIG.emb_dim).to(DEVICE)
if simsim_state_dict is not None:
    load_result = model.load_state_dict(simsim_state_dict, strict=False)
    print('Loaded SimSiam weights into FastAP model.')
    if getattr(load_result, 'missing_keys', None):
        print('  Missing keys:', load_result.missing_keys)
    if getattr(load_result, 'unexpected_keys', None):
        print('  Unexpected keys:', load_result.unexpected_keys)
else:
    print('Warning: no SimSiam weights available; FastAP model initialised randomly.')

if CONFIG.fastap.freeze_backbone:
    for param in model.backbone.parameters():
        param.requires_grad = False
if CONFIG.fastap.freeze_projector:
    for param in model.projector.parameters():
        param.requires_grad = False
# The SimSiam predictor is carried along only so state-dict keys match;
# forward_embedding (backbone -> projector -> head) never calls it. Freeze it
# so it stays out of the optimizer and the trainable-parameter count.
for param in model.predictor.parameters():
    param.requires_grad = False

trainable_params = [p for p in model.parameters() if p.requires_grad]
print('Trainable parameter count:', sum(p.numel() for p in trainable_params))

if CONFIG.fastap.run:
    # Same augmentations as SimSiam; the two views of a frame form the positive pair.
    pair_transform = TwoViewTransform(build_simsiam_augmentation(CONFIG.image_size))
    labels_for_training = frame_labels if CONFIG.fastap.use_labels else None
    fastap_dataset = FastAPPairDataset(raw_dataset, pair_transform, labels=labels_for_training)
    train_loader = DataLoader(
        fastap_dataset,
        shuffle=True,
        batch_size=CONFIG.fastap.batch_size,
        drop_last=CONFIG.dataloader.drop_last,
        pin_memory=CONFIG.dataloader.pin_memory,
        num_workers=CONFIG.dataloader.num_workers,
    )
    if len(train_loader) == 0:
        raise RuntimeError('FastAP DataLoader returned 0 batches. Reduce batch size or disable drop_last.')
    if not trainable_params:
        raise RuntimeError('No parameters left to optimise after freezing selections.')

    # 'adamw' selects AdamW; any other value falls back to SGD with momentum.
    optimizer = (
        torch.optim.AdamW(trainable_params, lr=CONFIG.fastap.lr, weight_decay=CONFIG.fastap.weight_decay)
        if CONFIG.fastap.optimizer.lower() == 'adamw'
        else torch.optim.SGD(trainable_params, lr=CONFIG.fastap.lr, momentum=0.9, weight_decay=CONFIG.fastap.weight_decay)
    )
    scheduler = (
        torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CONFIG.fastap.epochs)
        if CONFIG.fastap.scheduler == 'cosine'
        else None
    )
    scheduler_step = 'epoch'  # the cosine schedule is stepped once per epoch

    fastap_cfg = SimpleNamespace(
        num_bins=CONFIG.fastap.num_bins,
        sigma=CONFIG.fastap.sigma,
        epsilon=CONFIG.fastap.epsilon,
        use_labels=CONFIG.fastap.use_labels,
        grad_clip=CONFIG.fastap.grad_clip,
        amp=CONFIG.fastap.amp,
        scheduler_step=scheduler_step,
    )
    log_every = max(1, len(train_loader) // 5)
    fastap_log = train_fastap_stage(
        dataloader=train_loader,
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        epochs=CONFIG.fastap.epochs,
        cfg=fastap_cfg,
        log_every=log_every,
    )

    torch.save(fastap_log.to_dict(), PATHS.fastap_log)
    checkpoint_config = {
        'proj_dim': CONFIG.proj_dim,
        'emb_dim': CONFIG.emb_dim,
        'fastap': vars(CONFIG.fastap),
        'backbone': CONFIG.backbone,
    }
    torch.save({'model_state': model.state_dict(), 'config': checkpoint_config}, PATHS.fastap_checkpoint)
    torch.save({'state_dict': model.state_dict()}, CHECKPOINT_PATH)
    print('FastAP checkpoint saved to', PATHS.fastap_checkpoint)
else:
    print('Skipping FastAP fine-tuning stage.')

model.eval()


In [None]:
# Extract embeddings and evaluate retrieval metrics
# Embed every frame with the deterministic eval pipeline, then score retrieval.
eval_dataset = raw_dataset.view(eval_transform)
eval_loader = DataLoader(
    eval_dataset,
    shuffle=False,
    batch_size=CONFIG.eval_batch_size,
    pin_memory=CONFIG.dataloader.pin_memory,
    num_workers=CONFIG.dataloader.num_workers,
)

embeddings = extract_embeddings(model, eval_loader, DEVICE, stage='embedding')
np.save(PATHS.embeddings, embeddings)
print('Embeddings shape:', embeddings.shape)
print('Saved embeddings to', PATHS.embeddings)

labels_for_eval = frame_labels[: len(eval_dataset)].astype(np.int64)
metrics = evaluate_map(embeddings, labels_for_eval, topk=CONFIG.retrieval.topk)
PATHS.metrics.write_text(json.dumps(metrics, indent=2), encoding='utf-8')
print('mAP:', metrics['mAP'])
print('Metrics saved to', PATHS.metrics)


In [None]:
# Similar-frame retrieval helpers and example queries
from typing import Tuple

# Number of neighbours shown per query: the largest K requested for retrieval metrics.
TOPK = max(CONFIG.retrieval.topk)

# Deterministic preprocessing for external query images: bicubic resize of the
# shorter side to 256, center crop to CONFIG.image_size, tensor conversion, and
# per-channel mean/std normalisation (ImageNet statistics).
# NOTE(review): the resize target is fixed at 256 while the crop follows
# CONFIG.image_size; a crop larger than 256 would be zero-padded by CenterCrop.
# Confirm this matches eval_transform used for the dataset frames.
retrieval_transform = T.Compose(
    [
        T.Resize(size=256, interpolation=T.InterpolationMode.BICUBIC),
        T.CenterCrop(size=CONFIG.image_size),
        T.ToTensor(),
        T.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def preprocess_external_image(path: str) -> torch.Tensor:
    """Load an image file and turn it into a model-ready tensor.

    The image is converted to RGB and passed through ``retrieval_transform``.

    Args:
        path: Path to the image on disk.

    Returns:
        The transformed image tensor (no batch dimension).
    """
    # Close the underlying file handle; convert() returns a fully loaded copy,
    # so the image data remains usable after the file is closed.
    with Image.open(path) as img:
        pil = img.convert('RGB')
    return retrieval_transform(pil)

def embed_tensor(model: FastAPModel, img: torch.Tensor) -> np.ndarray:
    """Embed a single (unbatched) image tensor with ``model.forward_embedding``.

    The model is put into eval mode, the image gets a leading batch axis and is
    moved to ``DEVICE``, and the forward pass runs without gradient tracking.

    Returns:
        The embedding of the single image as a NumPy array on the CPU.
    """
    model.eval()
    with torch.no_grad():
        batch = img.unsqueeze(0).to(DEVICE)
        batch_embeddings = model.forward_embedding(batch)
    as_numpy = batch_embeddings.cpu().numpy()
    return as_numpy[0]

def search_similar(emb_matrix: np.ndarray, query_emb: np.ndarray, topk: int) -> Tuple[np.ndarray, np.ndarray]:
    """Rank the rows of ``emb_matrix`` by cosine similarity to ``query_emb``.

    Args:
        emb_matrix: ``(N, D)`` matrix of stored embeddings.
        query_emb: Query embedding with ``D`` elements; ``(D,)`` or ``(1, D)``
            are both accepted (it is flattened).
        topk: Number of best matches to return; clamped to ``[0, N]``.

    Returns:
        ``(indices, scores)``: row indices sorted by descending similarity (ties
        keep ascending index order) and their cosine-similarity scores.

    Raises:
        ValueError: If ``emb_matrix`` is not 2-D or the query length differs
            from the embedding dimension.
    """
    emb_matrix = np.asarray(emb_matrix)
    # Flatten so a (1, D) query does not yield (N, 1) scores, for which argsort
    # along the last axis would silently return all-zero indices.
    query = np.asarray(query_emb).reshape(-1)
    if emb_matrix.ndim != 2:
        raise ValueError(f'emb_matrix must be 2-D, got shape {emb_matrix.shape}')
    if query.shape[0] != emb_matrix.shape[1]:
        raise ValueError(
            f'query has {query.shape[0]} elements but embeddings have dimension {emb_matrix.shape[1]}'
        )
    # Small epsilon guards against division by zero for all-zero vectors.
    emb_norm = emb_matrix / (np.linalg.norm(emb_matrix, axis=1, keepdims=True) + 1e-12)
    query_norm = query / (np.linalg.norm(query) + 1e-12)
    scores = emb_norm @ query_norm
    # Clamp so a negative topk returns nothing instead of "all but the last |k|".
    k = max(0, min(int(topk), scores.shape[0]))
    # Stable sort makes tie ordering deterministic (lower index first).
    top_indices = np.argsort(-scores, kind='stable')[:k]
    return top_indices, scores[top_indices]

def _collect_matches(match_indices, match_scores):
    """Print ranked matches and return them as JSON-serialisable rows.

    Each row holds the 1-based ``rank``, the ``frame`` index, the cosine
    ``score`` and the frame's ``label``; frames past the end of
    ``frame_labels`` are reported with label -1.
    """
    rows = []
    for position, (frame_idx, frame_score) in enumerate(zip(match_indices, match_scores), start=1):
        frame_label = -1
        if frame_idx < len(frame_labels):
            frame_label = int(frame_labels[frame_idx])
        print(f'Rank {position:02d}: frame={frame_idx} | score={frame_score:.4f} | label={frame_label}')
        rows.append({
            'rank': int(position),
            'frame': int(frame_idx),
            'score': float(frame_score),
            'label': int(frame_label),
        })
    return rows


retrieval_summary = {
    'query_index': 0,
    'internal': [],
    'external_query': None,
}

# Internal query: rank all stored frames against one of the stored embeddings.
query_index = retrieval_summary['query_index']
indices, scores = search_similar(embeddings, embeddings[query_index], topk=min(TOPK, len(embeddings)))
print(f'Query frame index: {query_index}')
retrieval_summary['internal'].extend(_collect_matches(indices, scores))

# External query: embed the breath-hold reference image when it is present.
if not BREATH_HOLD_IMAGE_PATH.exists():
    warning_msg = 'Breath-hold reference image is missing; external query skipped.'
    print(warning_msg)
    retrieval_summary['external_query'] = {
        'path': str(BREATH_HOLD_IMAGE_PATH),
        'warning': warning_msg,
    }
else:
    external_tensor = preprocess_external_image(str(BREATH_HOLD_IMAGE_PATH))
    external_embedding = embed_tensor(model, external_tensor)
    ext_indices, ext_scores = search_similar(embeddings, external_embedding, topk=min(TOPK, len(embeddings)))
    print('\nExternal query (invivo.jpg) top matches:')
    ext_results = _collect_matches(ext_indices, ext_scores)
    retrieval_summary['external_query'] = {
        'path': str(BREATH_HOLD_IMAGE_PATH),
        'results': ext_results,
    }


In [None]:
# Markdown report generation
from datetime import datetime, timezone

if 'metrics' not in globals():
    raise RuntimeError('Please run the evaluation cell before generating the report.')


def _history_section(title, stage_name, log):
    """Build the Markdown lines describing one training stage's history.

    Args:
        title: Section heading text (rendered as a level-2 heading).
        stage_name: Stage name used in the "skipped" message.
        log: Training log exposing ``epochs``, ``losses``, ``lrs`` and
            ``elapsed`` sequences, or ``None`` when the stage did not run.

    Returns:
        List of Markdown lines, always ending with a blank line.
    """
    lines = [f'## {title}']
    if log is None:
        lines.append(f'{stage_name} stage was skipped or not executed in this run.')
    else:
        final_loss = log.losses[-1] if log.losses else float('nan')
        best_loss = min(log.losses) if log.losses else float('nan')
        total_time = sum(log.elapsed) if log.elapsed else 0.0
        lines.extend([
            '',
            f'- epochs_trained: {len(log.epochs)}',
            f'- final_loss: {final_loss:.4f}',
            f'- best_loss: {best_loss:.4f}',
            f'- total_time_sec: {total_time:.1f}',
            '',
            '| Epoch | Loss | LR | Time (s) |',
            '| ----- | ---- | -- | -------- |',
        ])
        for epoch, loss, lr, elapsed in zip(log.epochs, log.losses, log.lrs, log.elapsed):
            lines.append(f'| {epoch} | {loss:.4f} | {lr:.3e} | {elapsed:.1f} |')
    lines.append('')
    return lines


def _retrieval_table(rows):
    """Render retrieval rows (dicts with rank/frame/score/label) as a Markdown table."""
    lines = ['| Rank | Frame | Score | Label |', '| ---- | ----- | ----- | ----- |']
    lines.extend(
        f"| {row['rank']} | {row['frame']} | {row['score']:.4f} | {row['label']} |"
        for row in rows
    )
    return lines


# datetime.utcnow() is deprecated since Python 3.12; use an aware UTC timestamp and
# drop tzinfo so the ISO string keeps the original "<naive-iso>Z" format.
generated_at = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

report_lines = [
    '# FastAP Stage Report',
    '',
    f'- generated_at: {generated_at}Z',
    f'- device: {DEVICE}',
    f'- checkpoint_path: {CHECKPOINT_PATH}',
    f'- embeddings_path: {PATHS.embeddings}',
    f'- metrics_path: {PATHS.metrics}',
    '',
    '## Config',
    '',
    f'- backbone: {CONFIG.backbone}',
    f'- projector_dim: {CONFIG.proj_dim}',
    f'- embedding_dim: {CONFIG.emb_dim}',
    f'- simsiam_epochs: {CONFIG.simsiam.epochs}',
    f'- fastap_epochs: {CONFIG.fastap.epochs}',
    f'- fastap_bins: {CONFIG.fastap.num_bins}',
    f'- fastap_sigma: {CONFIG.fastap.sigma}',
    f'- freeze_backbone: {CONFIG.fastap.freeze_backbone}',
    f'- freeze_projector: {CONFIG.fastap.freeze_projector}',
    f'- use_labels: {CONFIG.fastap.use_labels}',
    '',
]

# The stage logs may be undefined (not just None) when their cells were never run,
# so look them up defensively instead of raising NameError.
report_lines += _history_section('SimSiam Training History', 'SimSiam', globals().get('simsiam_log'))
report_lines += _history_section('FastAP Training History', 'FastAP', globals().get('fastap_log'))

report_lines += ['## Evaluation Metrics', '', f'- mAP: {metrics["mAP"]:.4f}']
precision_at_k = metrics.get('precision_at_k') or {}
# recall_at_k may be missing; report NaN per K rather than raising KeyError.
recall_at_k = metrics.get('recall_at_k') or {}
if precision_at_k:
    report_lines += ['', '| K | Precision | Recall |', '| - | --------- | ------ |']
    for k in sorted(precision_at_k):
        rec = recall_at_k.get(k, float('nan'))
        report_lines.append(f'| {k} | {precision_at_k[k]:.4f} | {rec:.4f} |')
report_lines.append('')

summary = globals().get('retrieval_summary')
if summary:
    report_lines += [
        '## Retrieval Examples',
        '',
        f"### Internal query (index={summary.get('query_index')})",
    ]
    internal = summary.get('internal', [])
    if internal:
        report_lines += _retrieval_table(internal)
    else:
        report_lines.append('No retrieval results were recorded.')
    report_lines.append('')
    external = summary.get('external_query')
    if external:
        report_lines.append(f"### External query ({external.get('path')})")
        if external.get('warning'):
            report_lines.append(f"> {external['warning']}")
        else:
            report_lines += _retrieval_table(external.get('results', []))
        report_lines.append('')

report_text = '\n'.join(report_lines) + '\n'
with open(PATHS.report, 'w', encoding='utf-8') as f:
    f.write(report_text)
print('Markdown report saved to', PATHS.report)
