# Fragment Reconstruction: CPU Inference & Evaluation

This notebook demonstrates how to load a trained fragment autoencoder checkpoint and run CPU-only inference, clustering, metric computation, and visualization. It is designed to run on a local machine without GPUs.

In [None]:
# CPU-centric settings
# NOTE: the environment variable below is assigned *before* `torch` is imported
# further down, so that no CUDA device is visible by the time torch looks for one.
# Keep this ordering if you rearrange the cell.
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''  # force CPU

from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score

# Project-local modules: fragment dataset / collate function and the autoencoder module.
from study.fragmentation import FragmentBatchDataset, collate_fragment_batch
from study.train_module import FragmentAE

# Fix the global seed through Lightning so repeated runs of the notebook are comparable.
pl.seed_everything(42, workers=True)
# Sanity check: with CUDA hidden above, this is expected to report 'cpu'.
print('Torch device:', 'cuda' if torch.cuda.is_available() else 'cpu')

## 1) Paths & Parameters
Update these to point to your validation data directory and a trained checkpoint.

In [None]:
# --- Input locations -------------------------------------------------------
# Both paths are relative to the notebook's working directory.
VAL_DIR = Path('data/imagenet64/dev_data')  # change to your validation images
CKPT = Path('outputs/fragment_clustering_baseline/checkpoints/fragment-ae-epoch=00-val_loss=0.0836.ckpt')  # change to your best ckpt

# --- Evaluation parameters -------------------------------------------------
# IMAGES_PER_SAMPLE is passed to the dataset as `images_per_sample` and is also
# used as the KMeans cluster count `k` during evaluation.
IMAGES_PER_SAMPLE = 10  # number of source images per sample (k for clustering)
# NUM_SAMPLES is passed to the dataset as `steps_per_epoch`.
NUM_SAMPLES = 50        # how many samples to evaluate
BATCH_SIZE = 1          # keep 1; each batch is a full unordered fragment set
NUM_WORKERS = 0         # CPU-friendly, deterministic

# Fail fast with a readable message if either path is wrong, before any model
# or data loading is attempted.
assert VAL_DIR.exists(), f'Validation directory not found: {VAL_DIR}'
assert CKPT.exists(), f'Checkpoint not found: {CKPT}'
print('Validation dir:', VAL_DIR)
print('Checkpoint:', CKPT)

## 2) Load Model (CPU)
Loads the LightningModule on CPU.

In [None]:
# Load model strictly on CPU.
# Preferred path: let Lightning restore the module directly from the checkpoint.
try:
    model = FragmentAE.load_from_checkpoint(str(CKPT), map_location='cpu')
except Exception as load_err:
    # Best-effort fallback, but not a silent one: report why the preferred path
    # failed so a genuine problem (wrong file, incompatible checkpoint, ...) is
    # visible instead of being masked by the manual load below.
    print(
        f'load_from_checkpoint failed ({load_err!r}); '
        'falling back to FragmentAE() + load_state_dict'
    )
    # NOTE(review): FragmentAE() is built with its default constructor arguments;
    # if the checkpoint was trained with different hyperparameters,
    # load_state_dict is expected to raise on the mismatch - confirm.
    state = torch.load(CKPT, map_location='cpu')
    model = FragmentAE()
    model.load_state_dict(state['state_dict'])

# Inference mode on CPU.
model.eval()
model.cpu()
print('Model loaded. Parameters:', sum(p.numel() for p in model.parameters()))

## 3) DataLoader (CPU)
Build a small evaluation dataset on CPU.

In [None]:
# Evaluation dataset: NUM_SAMPLES samples, each drawn from IMAGES_PER_SAMPLE
# source images, with a fixed seed and augmentation switched off.
ds = FragmentBatchDataset(
    images_dir=VAL_DIR, images_per_sample=IMAGES_PER_SAMPLE,
    steps_per_epoch=NUM_SAMPLES, augment=False, seed=123,
)

# Plain sequential loader: no shuffling, no pinned memory, project collate function.
loader = DataLoader(
    dataset=ds,
    collate_fn=collate_fragment_batch,
    batch_size=BATCH_SIZE, num_workers=NUM_WORKERS,
    shuffle=False, pin_memory=False,
)

# Display the dataset length as the cell output.
len(ds)

## 4) Helper Functions for Metrics & Visualization
Purity, a simple grid helper, and one-batch evaluation.

In [None]:
from collections import Counter, defaultdict


def purity_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Return the clustering purity of ``y_pred`` with respect to ``y_true``.

    Each predicted cluster is credited with the size of its most frequent
    true label; the credits are summed and divided by ``len(y_true)``.

    Args:
        y_true: ground-truth label per item (values are converted with ``int``).
        y_pred: predicted cluster id per item (values are converted with ``int``).

    Returns:
        Purity as a float, or ``0.0`` when ``y_true`` is empty.
    """
    # cluster id -> histogram of the true labels that landed in that cluster
    label_hist = defaultdict(Counter)
    for truth, cluster in zip(y_true, y_pred):
        label_hist[int(cluster)][int(truth)] += 1

    n_items = len(y_true)
    majority_total = sum(max(hist.values()) for hist in label_hist.values())

    if n_items > 0:
        return majority_total / float(n_items)
    return 0.0


def make_grid(imgs: np.ndarray, cols: int = 10) -> np.ndarray:
    """Tile a stack of image fragments into a single row-major grid image.

    The tile height and width are taken from ``imgs`` itself when it is a
    4-D array, so fragments of any size can be tiled. For inputs that are not
    4-D (e.g. an empty 1-D array) the previous fixed 16x16 tile size is used.

    Args:
        imgs: array indexed by fragment on axis 0; each ``imgs[i]`` must be
            assignable into an ``(H, W, 3)`` slot (i.e. channels-last).
        cols: number of tiles per grid row.

    Returns:
        ``float32`` array of shape ``(rows * H, cols * W, 3)`` where
        ``rows = ceil(N / cols)``. Slots without a fragment stay filled with 1.0.
    """
    N = imgs.shape[0]
    rows = int(np.ceil(N / cols))

    # Derive the tile size from the data instead of assuming 16x16.
    if imgs.ndim == 4:
        H, W = int(imgs.shape[1]), int(imgs.shape[2])
    else:
        H, W = 16, 16

    # Start from an all-ones canvas so unused slots keep the fill value.
    grid = np.ones((rows * H, cols * W, 3), dtype=np.float32)
    for i in range(N):
        r, c = divmod(i, cols)
        grid[r*H:(r+1)*H, c*W:(c+1)*W] = imgs[i]
    return grid


@torch.no_grad()
def eval_one_batch(model: FragmentAE, batch, k: int):
    """Cluster one batch's latent codes with KMeans and score the clustering.

    Args:
        model: module called on ``batch.fragments``; it must return a 2-tuple
            whose second element is the latent tensor used for clustering.
        batch: object exposing ``fragments`` and ``source_ids`` tensors.
        k: number of KMeans clusters.

    Returns:
        Tuple ``(ari, nmi, purity, y_pred)`` where ``y_pred`` holds the KMeans
        cluster id assigned to each fragment.
    """
    # Forward pass on CPU; only the second output is used.
    _, latents = model(batch.fragments.cpu())
    features = latents.detach().cpu().numpy()
    labels_true = batch.source_ids.detach().cpu().numpy()

    # Fixed random_state keeps the clustering repeatable across runs.
    labels_pred = KMeans(n_clusters=k, n_init=10, random_state=0).fit_predict(features)

    return (
        adjusted_rand_score(labels_true, labels_pred),
        normalized_mutual_info_score(labels_true, labels_pred),
        purity_score(labels_true, labels_pred),
        labels_pred,
    )


## 5) Evaluate on CPU
Compute ARI, NMI, and Purity over a number of samples.

In [None]:
# Per-sample metric accumulators (one entry per batch from the loader).
ari_list, nmi_list, pur_list = [], [], []

for done, batch in enumerate(loader, start=1):
    ari, nmi, pur, _ = eval_one_batch(model, batch, k=IMAGES_PER_SAMPLE)
    ari_list.append(ari)
    nmi_list.append(nmi)
    pur_list.append(pur)
    # Lightweight progress report every 10 batches.
    if done % 10 == 0:
        print(f'Processed {done}/{len(loader)}')

# Aggregate: mean of each metric, or 0.0 when nothing was evaluated.
summary = {'samples': len(ari_list)}
for key, values in (
    ('ARI_mean', ari_list),
    ('NMI_mean', nmi_list),
    ('Purity_mean', pur_list),
):
    summary[key] = float(np.mean(values)) if values else 0.0
print(summary)

## 6) Visualize a Single Sample
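Show predicted clusters (top row) vs. true groups (bottom row) for one sample. Cluster indices assigned by KMeans are arbitrary, so column *i* of the top row does not necessarily correspond to column *i* of the bottom row.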

In [None]:
# Take the first batch from the loader and cluster it.
batch = next(iter(loader))
_, _, _, y_pred = eval_one_batch(model, batch, k=IMAGES_PER_SAMPLE)
# Move the channel axis last for imshow.
# assumes fragments are [N, 3, 16, 16] -> [N, 16, 16, 3] - TODO confirm
frags = batch.fragments.detach().cpu().numpy().transpose(0, 2, 3, 1)
y_true = batch.source_ids.detach().cpu().numpy()

# squeeze=False always yields a 2-D (2, IMAGES_PER_SAMPLE) axes array. The
# previous np.atleast_2d(axes) produced shape (1, 2) when IMAGES_PER_SAMPLE == 1,
# which made axes[1, 0] raise IndexError.
fig, axes = plt.subplots(
    2, IMAGES_PER_SAMPLE,
    figsize=(2 * IMAGES_PER_SAMPLE, 4),
    squeeze=False,
)

# Row 0: fragments grouped by predicted cluster; row 1: grouped by true source id.
# NOTE(review): the bottom row presumes source ids run over 0..IMAGES_PER_SAMPLE-1 - confirm.
for row, (labels, prefix) in enumerate(((y_pred, 'Pred'), (y_true, 'True'))):
    for cid in range(IMAGES_PER_SAMPLE):
        # Show at most 20 fragments per group to keep each panel small.
        idxs = np.where(labels == cid)[0][:20]
        # Empty groups get a single blank tile instead of an empty image.
        grid = make_grid(frags[idxs], cols=10) if len(idxs) else np.ones((16, 16, 3))
        ax = axes[row, cid]
        ax.imshow(grid)
        ax.set_title(f'{prefix} {cid}')
        ax.axis('off')

plt.tight_layout()
plt.show()