# Ultrasound: SimSiam 学習 + FastAP 評価 (Colab Ready)

このノートブックは Google Colab 上で超音波データに対する SimSiam 学習と FastAP 評価を実行するテンプレートです。
- Google Drive から `dataset.mat` を取得し、ローカル `data/` に配置します。疑似ラベル `dataset_labels.csv` は自動ダウンロードされないため、事前に `annotations/`（または `data/`。実行時に `annotations/` へ移動されます）に配置しておいてください。
- 息止め参照画像は Colab 上の `data/` (もしくはルート) に `invivo.jpg` としてアップロードしておくか、共有リンクを `SHARED_BREATH_HOLD_URL` に設定して自動ダウンロードさせます。
- 疑似ラベルは `notebook/pseudo_labeling.ipynb` と同じ設定で生成された `annotations/dataset_labels.csv` を想定します。
- セルを上から順に実行すれば、このノートブック単体で Colab 上でも完結するよう構成しています。
- 実行ごとに `outputs/<timestamp>/` 以下へチェックポイント、ログ、評価指標、Top-K レポートを保存します（Colab の `files` からダウンロード可能）。


In [None]:
# Install required packages (run only if they are missing, e.g. on a fresh Colab runtime)
!pip install --quiet gdown pandas matplotlib h5py


In [None]:
# Configure directories, download Google Drive assets, and set global config
import os
import random
import subprocess
import time
from pathlib import Path
from types import SimpleNamespace
from datetime import datetime

import h5py
import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as T
import torchvision.models as tvm

# Working-directory layout: inputs live in data/ and annotations/; every run writes
# its artifacts into its own timestamped folder below outputs/.
BASE_DIR = Path('.').resolve()
DATA_DIR, ANNOTATIONS_DIR, OUTPUTS_ROOT = (BASE_DIR / sub for sub in ('data', 'annotations', 'outputs'))
for path_obj in (DATA_DIR, ANNOTATIONS_DIR, OUTPUTS_ROOT):
    path_obj.mkdir(parents=True, exist_ok=True)

# Run identifier has one-second resolution (YYYYmmdd-HHMMSS).
RUN_ID = datetime.now().strftime('%Y%m%d-%H%M%S')
RUN_DIR = OUTPUTS_ROOT / RUN_ID
CHECKPOINTS_DIR, RESULTS_DIR = RUN_DIR / 'checkpoints', RUN_DIR / 'results'
for path_obj in (RUN_DIR, CHECKPOINTS_DIR, RESULTS_DIR):
    path_obj.mkdir(parents=True, exist_ok=True)

# Run-level files.
LOG_PATH = RUN_DIR / 'training_log.txt'
HISTORY_JSON_PATH = RUN_DIR / 'training_history.json'
SUMMARY_PATH = RUN_DIR / 'run_summary.txt'
# Evaluation outputs.
METRICS_JSON_PATH = RESULTS_DIR / 'metrics.json'
TOPK_MD_PATH = RESULTS_DIR / 'topk_result.md'
print(f'Artifacts for this run will be saved under: {RUN_DIR}')

# Google Drive share links fetched with the gdown CLI.
SHARED_DATASET_URL = 'https://drive.google.com/file/d/1fa0gaEmbtGmqZ92L0EqzhH5LiMUAztix/view?usp=sharing'
SHARED_BREATH_HOLD_URL = ''  # optional: share link for the breath-hold reference image (invivo.jpg)

DATASET_PATH = DATA_DIR / 'dataset.mat'
PSEUDO_LABELS_PATH = ANNOTATIONS_DIR / 'dataset_labels.csv'
BREATH_HOLD_IMAGE_PATH = DATA_DIR / 'invivo.jpg'


def _download_from_drive(url, destination, error_message):
    """Run ``gdown --fuzzy <url> -O <destination>``; raise RuntimeError(error_message) on a non-zero exit."""
    completed = subprocess.run(['gdown', '--fuzzy', url, '-O', str(destination)], check=False)
    if completed.returncode != 0:
        raise RuntimeError(error_message)


# 1) Raw ultrasound data: downloaded only when not already present.
if DATASET_PATH.exists():
    print('Dataset already present:', DATASET_PATH)
else:
    print(f'Downloading dataset to {DATASET_PATH} ...')
    _download_from_drive(
        SHARED_DATASET_URL,
        DATASET_PATH,
        'Failed to download dataset from Google Drive. Please verify the shared URL.',
    )

# 2) Pseudo labels: never downloaded. They must exist in annotations/, or in data/,
#    from where they are moved into annotations/.
if PSEUDO_LABELS_PATH.exists():
    print('Pseudo labels found:', PSEUDO_LABELS_PATH)
else:
    labels_in_data_dir = DATA_DIR / 'dataset_labels.csv'
    if not labels_in_data_dir.exists():
        raise FileNotFoundError('annotations/dataset_labels.csv が見つかりません。pseudo_labeling ノートブックで生成するか、Drive から配置してください。')
    labels_in_data_dir.replace(PSEUDO_LABELS_PATH)
    print('Moved pseudo labels from data/ to annotations/:', PSEUDO_LABELS_PATH)

# 3) Breath-hold reference image (optional). Order of lookup: data/invivo.jpg,
#    then ./invivo.jpg (moved into data/), then SHARED_BREATH_HOLD_URL; otherwise only a note.
image_in_root = BASE_DIR / 'invivo.jpg'
if BREATH_HOLD_IMAGE_PATH.exists():
    print('Breath-hold reference image located at', BREATH_HOLD_IMAGE_PATH)
elif image_in_root.exists():
    image_in_root.replace(BREATH_HOLD_IMAGE_PATH)
    print('Moved breath-hold image to data/:', BREATH_HOLD_IMAGE_PATH)
elif SHARED_BREATH_HOLD_URL:
    print(f'Downloading breath-hold reference to {BREATH_HOLD_IMAGE_PATH} ...')
    _download_from_drive(
        SHARED_BREATH_HOLD_URL,
        BREATH_HOLD_IMAGE_PATH,
        'Failed to download breath-hold image. Please verify SHARED_BREATH_HOLD_URL.',
    )
else:
    print('Note: data/invivo.jpg が存在しません。Colab のファイルアップロードや Drive から配置してください。')

# Every filesystem location used by later cells, bundled into one namespace.
_path_fields = {
    'data_dir': DATA_DIR,
    'annotations_dir': ANNOTATIONS_DIR,
    'checkpoints_dir': CHECKPOINTS_DIR,
    'dataset_mat': DATASET_PATH,
    'pseudo_labels': PSEUDO_LABELS_PATH,
    'breath_hold_image': BREATH_HOLD_IMAGE_PATH,
    'output_dir': RUN_DIR,
    'results_dir': RESULTS_DIR,
    'log_path': LOG_PATH,
    'history_json': HISTORY_JSON_PATH,
    'metrics_json': METRICS_JSON_PATH,
    'topk_markdown': TOPK_MD_PATH,
    'summary_path': SUMMARY_PATH,
    'run_id': RUN_ID,
}
PATHS = SimpleNamespace(**_path_fields)

# Training / evaluation hyper-parameters.
_simsiam_fields = {
    'seed': 42,
    # HDF5 path of the image stack inside dataset.mat, and the explicit axis
    # order passed to LazyMatImageDataset (first entry = frame axis).
    'image_key': 'Acq/Amp',
    'image_axes': (0, 2, 1),
    'train_batch_size': 64,
    'eval_batch_size': 128,
    'epochs': 10,
    'learning_rate': 0.05,
    'momentum': 0.9,
    'weight_decay': 1e-4,
    # NOTE(review): the cosine schedule is stepped per epoch; with epochs=10 and
    # cosine_t_max=100 the LR barely decays — confirm this is intended.
    'cosine_t_max': 100,
    'topk': 10,
    'num_workers': 0,  # multi-process DataLoader can hang on Colab
}
SIMSIAM = SimpleNamespace(**_simsiam_fields)

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Previously a truncated call (`print('Using device:', DEVI`), which made the whole cell a SyntaxError.
print('Using device:', DEVICE)

# Seed the Python, NumPy and PyTorch (CPU and, if present, all GPU) RNGs.
random.seed(SIMSIAM.seed)
np.random.seed(SIMSIAM.seed)
torch.manual_seed(SIMSIAM.seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SIMSIAM.seed)
    # cuDNN autotuning favours speed; per the PyTorch reproducibility notes it can
    # select non-deterministic kernels, so GPU runs are not bit-for-bit repeatable.
    torch.backends.cudnn.benchmark = True

# Plain-string aliases consumed by the data-loading and checkpoint cells.
DATA_MAT = str(PATHS.dataset_mat)
LABEL_CSV = str(PATHS.pseudo_labels)
IMAGE_KEY = SIMSIAM.image_key
IMAGE_AXES = SIMSIAM.image_axes
CHECKPOINT_DIR = str(PATHS.checkpoints_dir)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Start the run summary (truncates any previous file); later cells append to it.
with open(PATHS.summary_path, 'w', encoding='utf-8') as f:
    print(f'run_id={PATHS.run_id}', file=f)
    print(f'output_dir={PATHS.output_dir}', file=f)
    print(f'checkpoints_dir={PATHS.checkpoints_dir}', file=f)
    print(f'results_dir={PATHS.results_dir}', file=f)
    print(f'dataset_mat={PATHS.dataset_mat}', file=f)
    print(f'pseudo_labels_csv={PATHS.pseudo_labels}', file=f)


In [None]:
# Minimal LazyMatImageDataset implementation (self-contained)
from typing import List, Optional, Sequence, Tuple

class LazyMatImageDataset(torch.utils.data.Dataset):
    """Read ultrasound frames on demand from MATLAB v7.3 (HDF5-based) ``.mat`` files.

    The constructor opens each file only long enough to locate the image dataset,
    infer its axis layout and build a flat ``(file_index, frame_index)`` table,
    then closes it again. Handles are (re)opened lazily in ``__getitem__`` and are
    owned by the process that opened them. An h5py handle is therefore never shared
    between the main process and forked DataLoader workers. The dataset can also
    be pickled for spawn-based workers, because handles are dropped from the
    pickled state.

    Args:
        mat_files: Paths of the ``.mat`` files to read.
        image_key: HDF5 path of the image dataset (e.g. ``'Acq/Amp'``). If it is
            not found directly, the first top-level key ending with it is used.
        transform: Optional callable applied to each ``[C, H, W]`` numpy frame.
            Without it, a ``torch.Tensor`` is returned.
        dtype: NumPy dtype every frame is cast to.
        normalize_255: Divide integer-typed frames by 255 after casting.
        image_axes: Optional explicit axis order of the stored array (frame axis
            first); inferred by ``_interpret_image_shape`` when omitted.

    Raises:
        ValueError: If ``mat_files`` is empty or a shape cannot be interpreted.
        KeyError: If ``image_key`` cannot be found in a file.
    """

    def __init__(
        self,
        mat_files: Sequence[str],
        image_key: str = 'Acq/Amp',
        transform=None,
        dtype: str = 'float32',
        normalize_255: bool = False,
        image_axes: Optional[Tuple[int, ...]] = None,
    ):
        self.mat_files = list(mat_files)
        if not self.mat_files:
            raise ValueError('No .mat files provided for dataset.')

        self.image_key = image_key
        self.transform = transform
        self.dtype = dtype
        self.normalize_255 = normalize_255
        self._axes_override = image_axes

        # Open handles, indexed like mat_files; populated lazily by _get_file().
        self._files: List[Optional[h5py.File]] = [None] * len(self.mat_files)
        # PID of the process that opened the handles currently in self._files.
        self._files_pid: Optional[int] = None
        # Flat index: dataset position -> (file index, frame index inside that file).
        self._index: List[Tuple[int, int]] = []
        # Per-file axis order from _interpret_image_shape (axes[0] is the frame axis).
        self._image_axes: List[Tuple[int, ...]] = []

        for fi, path in enumerate(self.mat_files):
            # Only metadata is needed here. The context manager closes the file
            # right away (also when key lookup or shape inference raises), so no
            # open HDF5 handle is inherited by forked DataLoader workers.
            with h5py.File(path, 'r') as f:
                ds = self._resolve_image_dataset(f, self.image_key)
                n, _c, _h, _w, axes = _interpret_image_shape(ds.shape, override=self._axes_override)
            self._image_axes.append(axes)
            self._index.extend((fi, li) for li in range(n))

    def _resolve_image_dataset(self, f: h5py.File, key: str) -> h5py.Dataset:
        """Return ``f[key]``, falling back to the first top-level key ending with ``key``."""
        if key in f:
            return f[key]
        for candidate in f.keys():
            if candidate.endswith(key):
                return f[candidate]
        raise KeyError(f"image_key '{key}' not found in {list(f.keys())}")

    def _get_file(self, file_idx: int) -> h5py.File:
        """Return an open handle for ``mat_files[file_idx]`` owned by the current process."""
        pid = os.getpid()
        if self._files_pid != pid:
            # Handles opened by another process (e.g. the parent of a forked worker)
            # must not be reused here: forget them and open fresh ones in this process.
            self._files = [None] * len(self.mat_files)
            self._files_pid = pid
        f = self._files[file_idx]
        if f is None:
            f = h5py.File(self.mat_files[file_idx], 'r')
            self._files[file_idx] = f
        return f

    def __len__(self) -> int:
        return len(self._index)

    def __getitem__(self, idx: int):
        """Return frame ``idx`` as ``[C, H, W]`` (transformed if a transform is set)."""
        file_idx, local_idx = self._index[idx]
        img_ds = self._resolve_image_dataset(self._get_file(file_idx), self.image_key)
        axes = self._image_axes[file_idx]
        # Select one frame along the frame axis and read only that slice from disk.
        slicer = [slice(None)] * img_ds.ndim
        slicer[axes[0]] = local_idx
        arr = np.asarray(img_ds[tuple(slicer)])
        arr = _ensure_chw(arr, axes)
        if arr.dtype.kind in ('u', 'i'):
            arr = arr.astype(self.dtype)
            if self.normalize_255:
                arr = arr / 255.0
        else:
            arr = arr.astype(self.dtype)

        if self.transform is not None:
            return self.transform(arr)
        return torch.from_numpy(arr)

    def __getstate__(self):
        """Pickle without open HDF5 handles (h5py objects are not picklable)."""
        state = self.__dict__.copy()
        state['_files'] = [None] * len(self.mat_files)
        state['_files_pid'] = None
        return state

    def close(self):
        """Close every handle opened by this instance; later reads reopen lazily."""
        for i, f in enumerate(self._files):
            if f is not None:
                f.close()
                self._files[i] = None

    def __del__(self):
        # Best-effort cleanup; __init__ may have failed before attributes existed.
        try:
            self.close()
        except Exception:
            pass


class TwoCropsTransform:
    """Wrap a stochastic transform so each sample produces two independent views.

    The base transform is invoked twice per call (first view first), and the two
    results are returned as a ``(view_1, view_2)`` tuple, as SimSiam requires.
    """

    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, x: np.ndarray):
        return tuple(self.base_transform(x) for _ in range(2))


def _interpret_image_shape(
    shape: Tuple[int, ...],
    override: Optional[Tuple[int, ...]] = None,
) -> Tuple[int, int, int, int, Tuple[int, ...]]:
    if override is not None:
        axes = tuple(override)
        if len(axes) not in (3, 4):
            raise ValueError('image_axes override must have length 3 or 4')
        if len(axes) != len(shape):
            raise ValueError('image_axes override must match dataset rank')
        dims = [shape[a] for a in axes]
        if len(axes) == 4:
            n, c, h, w = dims
            return n, c, h, w, axes
        n, h, w = dims
        return n, 1, h, w, axes

    rank = len(shape)
    if rank == 4:
        candidates = [
            (0, 3, 1, 2),
            (0, 1, 2, 3),
            (3, 2, 0, 1),
            (3, 0, 1, 2),
        ]
        for axes in candidates:
            n = shape[axes[0]]
            c = shape[axes[1]]
            h = shape[axes[2]]
            w = shape[axes[3]]
            if all(v > 0 for v in (n, c, h, w)):
                return n, c, h, w, axes
        raise ValueError(f'Unable to infer N,C,H,W from shape {shape}')
    if rank == 3:
        candidates = [
            (0, 1, 2),
            (0, 2, 1),
            (2, 0, 1),
            (2, 1, 0),
            (1, 0, 2),
            (1, 2, 0),
        ]
        for axes in candidates:
            n = shape[axes[0]]
            h = shape[axes[1]]
            w = shape[axes[2]]
            if all(v > 0 for v in (n, h, w)):
                return n, 1, h, w, axes
        raise ValueError(f'Unable to infer N,H,W from shape {shape}')
    raise ValueError(f'Unsupported image rank {rank}; expected 3D or 4D')


def _ensure_chw(arr: np.ndarray, axes: Tuple[int, ...]) -> np.ndarray:
    n_axis = axes[0]
    remaining = []
    for a in axes[1:]:
        remaining.append(a if a < n_axis else a - 1)

    if len(remaining) == 3:
        c_pos, h_pos, w_pos = remaining
        return np.moveaxis(arr, (c_pos, h_pos, w_pos), (0, 1, 2))
    if len(remaining) == 2:
        h_pos, w_pos = remaining
        hw = np.moveaxis(arr, (h_pos, w_pos), (0, 1))
        return np.expand_dims(hw, axis=0)
    raise ValueError('Unexpected axes configuration for image array')



In [None]:
# 画像前処理（B-mode 変換 + 3ch化）
def bmode_normalize(chw: np.ndarray, dynamic_range_db: float = 60.0) -> np.ndarray:
    """Log-compress a raw amplitude frame into a B-mode image scaled to [0, 1].

    The magnitude is normalised by the frame's own peak and converted to decibels
    (``20 * log10``). The window ``[-dynamic_range_db, 0]`` dB is then mapped
    linearly onto ``[0, 1]``; anything quieter than the window becomes 0.

    Args:
        chw: Frame array, typically ``[C, H, W]``; cast to float32 (complex or
            signed input is reduced to its magnitude).
        dynamic_range_db: Displayed dynamic range in dB. Defaults to 60, the
            value previously hard-coded. Must be positive.

    Returns:
        float32 array of the same shape with values in [0, 1].

    Raises:
        ValueError: If ``dynamic_range_db`` is not positive.
    """
    if not dynamic_range_db > 0:
        raise ValueError(f'dynamic_range_db must be positive, got {dynamic_range_db}')
    eps = 1e-12  # avoids division by zero for all-zero frames and log10(0)
    mag = np.abs(chw.astype(np.float32))
    peak = float(mag.max())
    db = 20.0 * np.log10(mag / (peak + eps) + eps)
    return np.clip((db + dynamic_range_db) / dynamic_range_db, 0.0, 1.0)

def to_pil_3ch_from_chw01(chw01: np.ndarray) -> Image.Image:
    """Convert a ``[C, H, W]`` array with values in [0, 1] into an 8-bit RGB PIL image.

    Single-channel input is replicated to three channels. Values are clipped to
    [0, 1] before scaling to 0..255, so slightly out-of-range inputs saturate
    instead of wrapping around in the ``uint8`` cast.

    Args:
        chw01: ``[1, H, W]`` or ``[3, H, W]`` float array, nominally in [0, 1].

    Returns:
        RGB ``PIL.Image.Image`` of size ``W x H``.

    Raises:
        ValueError: If the array is not 3-D with 1 or 3 channels. The downstream
            transforms normalise exactly three channels.
    """
    if chw01.ndim != 3 or chw01.shape[0] not in (1, 3):
        raise ValueError(f'Expected a [1,H,W] or [3,H,W] array, got shape {chw01.shape}')
    if chw01.shape[0] == 1:
        chw01 = np.repeat(chw01, 3, axis=0)
    hwc01 = np.transpose(chw01, (1, 2, 0))
    hwc255 = (np.clip(hwc01, 0.0, 1.0) * 255.0).astype(np.uint8)
    return Image.fromarray(hwc255)

# ImageNet channel statistics: the backbone starts from ImageNet weights, so its
# inputs are normalised the same way.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]


def _frame_to_bmode_pil(x):
    """Convert a raw numpy ``[C, H, W]`` frame into a 3-channel B-mode PIL image."""
    return to_pil_3ch_from_chw01(bmode_normalize(x))


# SimSiam training augmentation: B-mode conversion, then random crop / flip /
# colour jitter / grayscale, then tensor conversion and normalisation.
simsiam_aug = T.Compose([
    T.Lambda(_frame_to_bmode_pil),
    T.RandomResizedCrop(224, scale=(0.2, 1.0), interpolation=T.InterpolationMode.BICUBIC),
    T.RandomHorizontalFlip(p=0.5),
    T.ColorJitter(0.4, 0.4, 0.4, 0.1),
    T.RandomGrayscale(p=0.2),
    T.ToTensor(),
    T.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])

# Evaluation transform: same B-mode conversion, but a deterministic resize and
# centre crop instead of random augmentation.
eval_transform = T.Compose([
    T.Lambda(_frame_to_bmode_pil),
    T.Resize(256, interpolation=T.InterpolationMode.BICUBIC),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])


In [None]:
# SimSiam モデル定義（ResNet-18 バックボーン）
class Projector(nn.Module):
    """SimSiam projection MLP: ``Linear -> BN -> ReLU -> Linear -> BN``.

    Both linear layers are bias-free, because each is followed by BatchNorm. The
    layers live in ``self.net`` with the same indices as before, so checkpoint
    keys (``net.0.weight`` ... ``net.4.*``) are unchanged.
    """

    def __init__(self, in_dim=512, hid_dim=2048, out_dim=2048):
        super().__init__()
        hidden_block = [
            nn.Linear(in_dim, hid_dim, bias=False),
            nn.BatchNorm1d(hid_dim),
            nn.ReLU(inplace=True),
        ]
        output_block = [
            nn.Linear(hid_dim, out_dim, bias=False),
            nn.BatchNorm1d(out_dim),
        ]
        self.net = nn.Sequential(*hidden_block, *output_block)

    def forward(self, x):
        return self.net(x)

class Predictor(nn.Module):
    """SimSiam prediction head: bottleneck MLP ``in_dim -> hid_dim -> out_dim``.

    The hidden layer is bias-free and batch-normalised; the output layer keeps
    its bias. Layer indices inside ``self.net`` match the checkpoint key layout.
    """

    def __init__(self, in_dim=2048, hid_dim=512, out_dim=2048):
        super().__init__()
        bottleneck = [
            nn.Linear(in_dim, hid_dim, bias=False),
            nn.BatchNorm1d(hid_dim),
            nn.ReLU(inplace=True),
        ]
        head = nn.Linear(hid_dim, out_dim)
        self.net = nn.Sequential(*bottleneck, head)

    def forward(self, x):
        return self.net(x)

class SimSiam(nn.Module):
    """SimSiam network: shared backbone -> projector -> predictor.

    Args:
        backbone_name: Only ``'resnet18'`` is supported; it is built from
            torchvision's ImageNet (IMAGENET1K_V1) weights.
        proj_dim: Width of the projector's hidden/output layers and of the
            predictor's input/output.

    Raises:
        ValueError: For any other ``backbone_name``.
    """

    def __init__(self, backbone_name='resnet18', proj_dim=2048):
        super().__init__()
        if backbone_name != 'resnet18':
            raise ValueError('Unsupported backbone')
        resnet = tvm.resnet18(weights=tvm.ResNet18_Weights.IMAGENET1K_V1)
        feat_dim = 512  # channel count of ResNet-18's last stage
        # Keep every child up to and including global average pooling; drop the final fc classifier.
        *trunk, _classifier = resnet.children()
        self.backbone = nn.Sequential(*trunk)  # output: [B, 512, 1, 1]
        self.feat_dim = feat_dim
        self.projector = Projector(in_dim=feat_dim, hid_dim=proj_dim, out_dim=proj_dim)
        self.predictor = Predictor(in_dim=proj_dim, hid_dim=512, out_dim=proj_dim)

    def forward_backbone(self, x):
        """Return pooled backbone features flattened to ``[B, feat_dim]``."""
        return self.backbone(x).flatten(1)

    def forward(self, x1, x2):
        """Encode both views; return ``(p1, z1, p2, z2)`` (predictions and projections)."""
        outputs = []
        for view in (x1, x2):
            z = self.projector(self.forward_backbone(view))
            outputs.extend((self.predictor(z), z))
        return tuple(outputs)

def negative_cosine(p, z):
    """SimSiam loss term: negative mean cosine similarity between ``p`` and stop-grad ``z``.

    ``z`` is detached, so no gradient flows into the target branch. Both inputs
    are L2-normalised along dim 1 before the row-wise dot product.

    Returns:
        Scalar tensor in ``[-1, 1]``; -1 means perfectly aligned.
    """
    target = F.normalize(z.detach(), dim=1)
    prediction = F.normalize(p, dim=1)
    row_cosines = torch.sum(prediction * target, dim=1)
    return row_cosines.mean().neg()


In [None]:
# データセットとローダの作成 (SimSiam 学習用)
# Each training sample yields two independently augmented views of the same frame.
_two_view_transform = TwoCropsTransform(simsiam_aug)
train_dataset = LazyMatImageDataset(
    [DATA_MAT],
    transform=_two_view_transform,
    image_key=IMAGE_KEY,
    image_axes=IMAGE_AXES,
    normalize_255=False,
)

def collate_two_crops(batch):
    """Stack a list of ``(view_1, view_2)`` tensor pairs into two batched tensors."""
    first_views = [pair[0] for pair in batch]
    second_views = [pair[1] for pair in batch]
    return torch.stack(first_views), torch.stack(second_views)

# drop_last: the projector/predictor use BatchNorm1d, which raises in training mode
# when a batch holds a single sample. A final partial batch of size 1 would
# otherwise crash the epoch, and very small tail batches give noisy BN statistics.
# Only drop when at least one full batch exists, so the loader is never empty.
train_loader = DataLoader(
    train_dataset,
    batch_size=SIMSIAM.train_batch_size,
    shuffle=True,
    num_workers=SIMSIAM.num_workers,
    pin_memory=(DEVICE == 'cuda'),
    collate_fn=collate_two_crops,
    drop_last=len(train_dataset) >= SIMSIAM.train_batch_size,
)
len(train_dataset), len(train_loader)


In [None]:

# 学習ループ
import json

def log_to_console_and_file(message: str) -> None:
    '''Print ``message`` (flushed immediately) and append the same line to the run log file.'''
    print(message, flush=True)
    with open(PATHS.log_path, mode='a', encoding='utf-8') as log_file:
        log_file.write(f'{message}\n')

# Per-run checkpoint directory (exist_ok makes this a no-op if it already exists).
run_ckpt_dir = Path(CHECKPOINT_DIR)
run_ckpt_dir.mkdir(parents=True, exist_ok=True)

# Build the model, then optionally overlay weights from BASE_DIR/save/latest.*
# (the most recently modified match wins). The SimSiam constructor already
# initialises the backbone from ImageNet weights, so this is an optional warm start.
model = SimSiam(backbone_name='resnet18', proj_dim=2048)
save_dir = BASE_DIR / 'save'
latest_ckpt = None
if save_dir.exists():
    candidates = [p for p in save_dir.glob('latest.*') if p.is_file()]
    if candidates:
        latest_ckpt = max(candidates, key=lambda p: p.stat().st_mtime)

if latest_ckpt is not None:
    log_to_console_and_file(f'Found pretrained weights: {latest_ckpt}')
    # Any failure below (unreadable file, non-dict payload, shape mismatch inside
    # load_state_dict, ...) is logged and training continues with the model as built.
    try:
        checkpoint = torch.load(latest_ckpt, map_location='cpu')
        # Accept both {'state_dict': ...} wrappers (the format this notebook's
        # training loop saves) and a bare state_dict mapping.
        state_dict = checkpoint.get('state_dict', checkpoint)
        if not isinstance(state_dict, dict):
            raise TypeError('Checkpoint does not contain a state_dict mapping.')
        # Strip a leading 'module.' from keys (typically present when the model was
        # saved while wrapped in DataParallel / DistributedDataParallel).
        state_dict = {k[len('module.'):] if k.startswith('module.') else k: v for k, v in state_dict.items()}
        load_result = None
        weights_loaded = False
        if any(k.startswith(('backbone.', 'projector.', 'predictor.')) for k in state_dict):
            # Keys already follow this SimSiam module's naming: load non-strictly.
            # strict=False tolerates missing/extra keys but still raises on shape
            # mismatches (handled by the except below).
            load_result = model.load_state_dict(state_dict, strict=False)
            # Number of source keys that found a matching parameter/buffer.
            matched = len(state_dict) - len(load_result.unexpected_keys)
            weights_loaded = matched > 0
        else:
            # Otherwise treat the file as a plain ResNet state_dict (no 'backbone.'
            # prefix). Source keys, minus the 'fc.*' classifier, are paired with
            # the backbone's keys purely by position; a tensor is copied only when
            # its shape matches the target.
            # NOTE(review): positional pairing assumes the source key order matches
            # torchvision's ResNet-18. A different ordering with coincidentally
            # equal shapes would be mis-assigned silently — verify the checkpoint layout.
            backbone_state = model.backbone.state_dict()
            target_keys = list(backbone_state.keys())
            resnet_keys = [k for k in state_dict.keys() if not k.startswith('fc.')]
            new_backbone_state = backbone_state.copy()
            matched = 0
            skipped = []
            for tgt_key, src_key in zip(target_keys, resnet_keys):
                tensor = state_dict[src_key]
                if tensor.shape == new_backbone_state[tgt_key].shape:
                    new_backbone_state[tgt_key] = tensor
                    matched += 1
                else:
                    skipped.append(src_key)
            if matched:
                model.backbone.load_state_dict(new_backbone_state)
                weights_loaded = True
                log_to_console_and_file(f'Loaded {matched} backbone tensors from {latest_ckpt}.')
                # Source keys beyond the zipped length were never paired with a target.
                zipped_count = min(len(target_keys), len(resnet_keys))
                unused = resnet_keys[zipped_count:] + skipped
                if unused:
                    log_to_console_and_file(f'Warning: skipped backbone keys (shape mismatch or excess): {unused}')
        # load_result is only set on the SimSiam-key path; it stays None otherwise.
        if load_result:
            if load_result.missing_keys:
                log_to_console_and_file(f'Warning: missing keys when loading weights: {load_result.missing_keys}')
            if load_result.unexpected_keys:
                log_to_console_and_file(f'Warning: unexpected keys when loading weights: {load_result.unexpected_keys}')
        if weights_loaded:
            log_to_console_and_file('Pretrained weights loaded successfully.')
        else:
            # NOTE(review): despite the wording of this message, the backbone keeps
            # its ImageNet initialisation from the SimSiam constructor; only the
            # projector/predictor start from random initialisation.
            log_to_console_and_file('Warning: no matching parameters found in pretrained weights; proceeding with random init.')
    except Exception as exc:
        log_to_console_and_file(f'Failed to load weights from {latest_ckpt}: {exc}')
model = model.to(DEVICE)
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=SIMSIAM.learning_rate,
    momentum=SIMSIAM.momentum,
    weight_decay=SIMSIAM.weight_decay,
)
epochs = SIMSIAM.epochs  # configured in the SIMSIAM namespace at the top of the notebook
# The scheduler is stepped once per epoch, so T_max is measured in epochs.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=SIMSIAM.cosine_t_max)
# Mixed precision only on CUDA; on CPU, autocast and GradScaler are disabled.
device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
scaler = torch.amp.GradScaler(device=device_type, enabled=(device_type == 'cuda'))

total_batches = len(train_loader)
if total_batches == 0:
    raise RuntimeError('train_loader にバッチが存在しません。データの読み込み設定を確認してください。')
# Roughly five progress lines per epoch.
log_interval = max(1, total_batches // 5)

if SIMSIAM.cosine_t_max != epochs:
    # With T_max > epochs the LR never approaches zero (e.g. T_max=100 with 10
    # epochs keeps it above ~97% of the base LR). With T_max < epochs, the cosine
    # curve rises again after T_max. Make the mismatch visible instead of silent.
    log_to_console_and_file(
        f'Warning: cosine_t_max ({SIMSIAM.cosine_t_max}) != epochs ({epochs}); '
        'the cosine LR schedule will not end exactly when training ends.'
    )

log_to_console_and_file(f'Start training for {epochs} epochs ({total_batches} batches/epoch).')

history = []
latest_path = run_ckpt_dir / 'simsiam_latest.pth'
model.train()
for epoch in range(epochs):
    epoch_loss = 0.0
    t0 = time.time()
    log_to_console_and_file(f'-- Epoch {epoch + 1}/{epochs} --')
    for batch_idx, (x1, x2) in enumerate(train_loader, start=1):
        x1 = x1.to(DEVICE, non_blocking=True)
        x2 = x2.to(DEVICE, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast(device_type=device_type, enabled=(device_type == 'cuda')):
            p1, z1, p2, z2 = model(x1, x2)
            # Symmetrised SimSiam loss: each view's prediction is compared with the
            # other view's (stop-gradient) projection.
            loss = 0.5 * negative_cosine(p1, z2) + 0.5 * negative_cosine(p2, z1)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        epoch_loss += float(loss.item())

        if batch_idx % log_interval == 0 or batch_idx == total_batches:
            elapsed = time.time() - t0
            avg_so_far = epoch_loss / batch_idx
            lr_now = optimizer.param_groups[0]['lr']
            log_to_console_and_file(
                f'  batch {batch_idx}/{total_batches} '
                f'loss={loss.item():.4f} avg_loss={avg_so_far:.4f} '
                f'lr={lr_now:.6f} elapsed={elapsed:.1f}s'
            )
    scheduler.step()  # per-epoch cosine schedule
    dt = time.time() - t0
    avg_loss = epoch_loss / max(1, total_batches)
    msg = f'Epoch {epoch+1}/{epochs} loss={avg_loss:.4f} time={dt:.1f}s'
    log_to_console_and_file(msg)
    history.append({'epoch': epoch + 1, 'loss': avg_loss, 'time_sec': dt})

    # 'epoch' and 'state_dict' keep the format read by the weight-loading code.
    # The "latest" checkpoint additionally stores optimizer / scheduler / scaler
    # state so an interrupted run can be resumed. Per-epoch snapshots stay
    # weights-only to limit disk usage.
    weights_ckpt = {'epoch': epoch + 1, 'state_dict': model.state_dict()}
    resume_ckpt = {
        **weights_ckpt,
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'scaler': scaler.state_dict(),
    }
    torch.save(resume_ckpt, latest_path)
    epoch_path = run_ckpt_dir / f'simsiam_epoch_{epoch + 1:02d}.pth'
    torch.save(weights_ckpt, epoch_path)
    log_to_console_and_file(f'Checkpoint updated: {latest_path}')
    # Rewrite the loss history every epoch so it survives an interrupted run.
    with open(PATHS.history_json, 'w', encoding='utf-8') as f:
        json.dump(history, f, indent=2)

# Final write also covers epochs == 0 (an empty history list).
with open(PATHS.history_json, 'w', encoding='utf-8') as f:
    json.dump(history, f, indent=2)

if history:
    with open(PATHS.summary_path, 'a', encoding='utf-8') as f:
        print(f'epochs_completed={len(history)}', file=f)
        print(f'final_loss={history[-1]["loss"]:.4f}', file=f)
        print(f'checkpoint_latest={latest_path}', file=f)



In [None]:
# 埋め込み抽出（評価用）
from pathlib import Path
# Evaluation data: the same frames as training, with the deterministic
# eval_transform (one view per frame).
eval_dataset = LazyMatImageDataset(
    [DATA_MAT],
    transform=eval_transform,
    image_key=IMAGE_KEY,
    image_axes=IMAGE_AXES,
    normalize_255=False,
)
eval_loader = DataLoader(
    eval_dataset,
    batch_size=SIMSIAM.eval_batch_size,
    shuffle=False,  # keep frame order: embedding row i corresponds to frame i
    num_workers=SIMSIAM.num_workers,
    pin_memory=(DEVICE == 'cuda'),
)

def extract_embeddings(model: SimSiam, loader: DataLoader) -> np.ndarray:
    """Return L2-normalised projector embeddings for every sample in ``loader``.

    The model is switched to eval mode (and left there). Each batch goes through
    backbone -> projector. Rows are normalised so that dot products are cosine
    similarities. Per-batch results are concatenated in loader order.
    """
    model.eval()
    with torch.no_grad():
        chunks = [
            F.normalize(model.projector(model.forward_backbone(batch.to(DEVICE))), dim=1).cpu().numpy()
            for batch in loader
        ]
    return np.concatenate(chunks, axis=0)

embeddings = extract_embeddings(model, eval_loader)
# Stored alongside this run's checkpoints.
embeddings_path = Path(CHECKPOINT_DIR) / 'simsiam_embeddings.npy'
np.save(embeddings_path, embeddings)
print(f'Embeddings saved to {embeddings_path}')
with open(PATHS.summary_path, 'a', encoding='utf-8') as summary_file:
    summary_file.write(f'embeddings_path={embeddings_path}\n')
embeddings.shape  # displayed as the cell output


In [None]:
# FastAP 相当のランキング評価 (AP/mAP) を疑似ラベルで評価
import json
# Pseudo labels: one row per frame, with columns 'frame' (index into the .mat
# stack) and 'label'. labels[i] must correspond to embeddings[i].
labels_df = pd.read_csv(LABEL_CSV)
missing_columns = {'frame', 'label'} - set(labels_df.columns)
if missing_columns:
    raise KeyError(f'{LABEL_CSV} is missing required column(s): {sorted(missing_columns)}')

label_series = labels_df.set_index('frame')['label']
frame_ids = range(len(eval_dataset))
missing_frames = pd.Index(frame_ids).difference(label_series.index)
if len(missing_frames):
    raise KeyError(
        f'{LABEL_CSV} has no label for {len(missing_frames)} frame(s), '
        f'e.g. {list(missing_frames[:10])}'
    )
selected = label_series.loc[frame_ids]
# A duplicated 'frame' value makes .loc return extra rows, which would silently
# shift every following label relative to the embeddings.
if len(selected) != len(frame_ids):
    duplicated = sorted(set(selected.index[selected.index.duplicated()]))
    raise ValueError(f'{LABEL_CSV} contains duplicate rows for frame(s) {duplicated[:10]}')
labels = selected.to_numpy()
labels = labels.astype(int)

def average_precision_for_query(sim_vec: np.ndarray, rel: np.ndarray) -> float:
    """Average precision of a single ranked retrieval.

    Items are ranked by descending similarity. AP is the mean of precision@k over
    the ranks k at which relevant items appear.

    Args:
        sim_vec: ``(N,)`` similarities to the query (the query itself must be
            excluded, e.g. by setting its similarity to ``-inf``).
        rel: ``(N,)`` 0/1 relevance flags aligned with ``sim_vec``.

    Returns:
        AP in [0, 1]; 0.0 when no item is relevant.
    """
    order = np.argsort(-sim_vec)
    rel_sorted = rel[order]
    n_rel = int(rel_sorted.sum())
    if n_rel == 0:
        return 0.0
    hits_so_far = np.cumsum(rel_sorted)
    ranks = np.arange(1, len(rel_sorted) + 1)
    # precision@k counted only at ranks holding a relevant item
    prec_at_hits = (hits_so_far / ranks) * rel_sorted
    return float(prec_at_hits.sum() / n_rel)


def compute_map(embs: np.ndarray, labels: np.ndarray) -> Tuple[float, List[float]]:
    """Leave-one-out mean average precision using label equality as relevance.

    Every row of ``embs`` is used once as a query against all other rows; items
    sharing the query's label are relevant. Queries without any other same-label
    item contribute an AP of 0.0.

    Args:
        embs: ``(N, D)`` embeddings, assumed L2-normalised so that the dot
            product is the cosine similarity.
        labels: ``(N,)`` labels aligned with the rows of ``embs``.

    Returns:
        ``(mAP, per_query_ap_list)``. For ``N == 0`` this is ``(nan, [])``.

    Raises:
        ValueError: If ``labels`` and ``embs`` disagree in length.
    """
    labels = np.asarray(labels)
    n = embs.shape[0]
    if len(labels) != n:
        # A mismatch would otherwise be silently mis-scored (or fail deep inside indexing).
        raise ValueError(f'labels has {len(labels)} entries but embs has {n} rows')
    if n == 0:
        return float('nan'), []
    sim_all = embs @ embs.T
    ap_list = []
    for i in range(n):
        sim_i = sim_all[i].copy()
        rel_i = (labels == labels[i]).astype(np.int32)
        # Exclude the query itself: ranked last and marked irrelevant.
        sim_i[i] = -np.inf
        rel_i[i] = 0
        ap_list.append(average_precision_for_query(sim_i, rel_i))
    return float(np.mean(ap_list)), ap_list

mAP, ap_list = compute_map(embeddings, labels)
metrics = dict(
    mAP=mAP,
    num_frames=len(eval_dataset),
    # NOTE(review): this sums label values, so it equals the positive count only
    # for 0/1 labels — confirm against the pseudo-labeling notebook.
    num_positive_labels=int(labels.sum()),
    timestamp=time.time(),
)
PATHS.metrics_json.write_text(json.dumps(metrics, indent=2), encoding='utf-8')
print(f'mAP (pseudo labels): {mAP:.4f} -> saved to {PATHS.metrics_json}')
with open(PATHS.summary_path, 'a', encoding='utf-8') as summary_file:
    summary_file.write(f'mAP={mAP:.4f}\n')


In [None]:
# 息止め画像を与えて、類似フレーム Top-K を返す評価
from pathlib import Path
BREATH_HOLD_IMAGE = str(PATHS.breath_hold_image)
TOPK = SIMSIAM.topk

# Deterministic preprocessing for an external query image, built once instead of
# on every call. It matches the resize / crop / normalise steps used for frames at
# evaluation time but skips the B-mode conversion, because the query is read as
# an ordinary image file rather than raw amplitude data.
# NOTE(review): confirm the query image's intensity mapping is comparable to the
# log-compressed (60 dB window) frames; otherwise similarities may be biased.
_EXTERNAL_IMAGE_TRANSFORM = T.Compose([
    T.Resize(256, interpolation=T.InterpolationMode.BICUBIC),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def preprocess_external_image(path: str) -> torch.Tensor:
    """Load an image file and turn it into a normalised ``[3, 224, 224]`` tensor.

    Args:
        path: Any image file PIL can open; it is converted to RGB.

    Returns:
        Float tensor ready to be batched and passed to the model.
    """
    # Image.open is lazy and keeps the file open. convert() materialises the
    # pixels into a new image, so the context manager can close the handle.
    with Image.open(path) as img:
        rgb = img.convert('RGB')
    return _EXTERNAL_IMAGE_TRANSFORM(rgb)

def embed_image_tensor(model: SimSiam, img_t: torch.Tensor) -> np.ndarray:
    """Embed one preprocessed ``[3, H, W]`` image; return its L2-normalised projector vector.

    The model is put in eval mode and run without gradients on a batch of one.
    """
    model.eval()
    single_batch = img_t.unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        projected = F.normalize(model.projector(model.forward_backbone(single_batch)), dim=1)
    return projected[0].cpu().numpy()

path_obj = Path(BREATH_HOLD_IMAGE)
if not path_obj.exists():
    # No query image: record that the Top-K report was skipped.
    warning_msg = 'data/invivo.jpg が見つかりません。Colab にアップロードするか、SHARED_BREATH_HOLD_URL を設定して再実行してください。'
    print(warning_msg)
    with open(PATHS.summary_path, 'a', encoding='utf-8') as f:
        print('topk_report=missing_input', file=f)
else:
    # Cosine similarity of the query to every frame (both sides are L2-normalised).
    q_emb = embed_image_tensor(model, preprocess_external_image(BREATH_HOLD_IMAGE))
    sims = embeddings @ q_emb
    top_idx = np.argsort(-sims)[:TOPK]

    lines = [
        f'# Top-{TOPK} Similar Frames',
        f'- Query image: {path_obj}',
        f'- Embedding source: {Path(CHECKPOINT_DIR) / "simsiam_embeddings.npy"}',
        '',
        '| Rank | Frame | Cosine sim | Label |',
        '| ---- | ----- | ---------- | ----- |',
    ]
    print(f'Top-{TOPK} 類似フレーム:')
    for rank, idx in enumerate(top_idx, start=1):
        sim_value, label_value = sims[idx], labels[idx]
        print(f'{rank:2d}: frame={idx} sim={sim_value:.4f} label={label_value}')
        lines.append(f'| {rank} | {idx} | {sim_value:.4f} | {label_value} |')

    with open(PATHS.topk_markdown, 'w', encoding='utf-8') as f:
        f.write('\n'.join(lines) + '\n')
    print(f'Top-{TOPK} report saved to {PATHS.topk_markdown}')
    with open(PATHS.summary_path, 'a', encoding='utf-8') as f:
        print(f'topk_report={PATHS.topk_markdown}', file=f)
