# Orion Inference Visualizer

This notebook loads a model trained with `train_orion_patches.py` — a ConvNeXt regression model or, when `UNetSmall` can be imported from that script, a UNet spatial model — and lets you interactively inspect:

- H&E patch
- Ground-truth per-marker heatmaps (robust-normalized to [0,1])
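- Predicted per-marker output: one scalar per marker per patch for the ConvNeXt regression model (drawn as a uniform map), or a per-pixel heatmap for a UNet checkpoint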

Set `pairs_dir` and either `checkpoint_path` (a checkpoint file) or `model_dir` (a directory searched for checkpoints) below, then use the widgets to select a `basename` and a `patch_index` (`patch_index` defaults to 40).


In [None]:
# Config
from pathlib import Path
import os

# Change these as needed
pairs_dir = Path('core_patches_npy')  # directory with *_HE.npy and *_ORION.npy
# You can set either a checkpoint file OR a directory to search in
checkpoint_path = None  # e.g., Path('runs/orion_seg/best_model.pth') or None
model_dir = Path('runs/orion_seg')
# Epoch checkpoint tried first when searching model_dir (None = newest epoch first)
preferred_epoch = 10

# Visualization defaults
default_basename_idx = 0
default_patch_index = 40
patch_size = 224

if not pairs_dir.exists():
    raise FileNotFoundError(f"pairs_dir not found: {pairs_dir}")

# Resolve checkpoint. Priority:
#   1. checkpoint_path, if it points to an existing file;
#   2. inside model_dir: checkpoint_epoch_<preferred_epoch>.pth, then every
#      checkpoint_epoch_<N>.pth from highest N down, then best_model.pth,
#      then final_model.pth.
resolved_ckpt = None
if checkpoint_path is not None:
    checkpoint_path = Path(checkpoint_path)
    if checkpoint_path.is_file():
        resolved_ckpt = checkpoint_path
    else:
        print(f"checkpoint_path {checkpoint_path} is not a file; searching model_dir instead")

if resolved_ckpt is None and model_dir is not None and Path(model_dir).exists():
    model_dir = Path(model_dir)
    numbered = []
    for p in model_dir.glob('checkpoint_epoch_*.pth'):
        try:
            numbered.append((int(p.stem.rsplit('_', 1)[-1]), p))
        except ValueError:
            # e.g. 'checkpoint_epoch_10_backup.pth' has no plain epoch suffix;
            # skip it instead of aborting the whole search.
            continue
    numbered.sort(key=lambda item: item[0], reverse=True)

    candidates = []
    if preferred_epoch is not None:
        candidates.append(model_dir / f'checkpoint_epoch_{preferred_epoch}.pth')
    candidates.extend(p for _, p in numbered)
    candidates.extend([model_dir / 'best_model.pth', model_dir / 'final_model.pth'])
    resolved_ckpt = next((p for p in candidates if p.exists()), None)

if resolved_ckpt is None:
    raise FileNotFoundError(
        f"No checkpoint found. Set checkpoint_path or ensure model_dir has checkpoints. Tried model_dir={model_dir}"
    )
print(f"Using checkpoint: {resolved_ckpt}")


In [None]:
# Imports
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from ipywidgets import interact, IntSlider, Dropdown, fixed

# Reuse utility from training script (keep identical so GT heatmaps match training targets)
def robust_norm01(a: np.ndarray, p1=1, p99=99, eps=1e-6) -> np.ndarray:
    """Rescale ``a`` into [0, 1] using its ``p1``-th and ``p99``-th percentiles.

    Values at or below the low percentile map to 0 and values at or above the
    high percentile map to 1 (after clipping). When the percentile range is
    degenerate (high <= low) an all-zero float32 array of the same shape is
    returned. ``eps`` guards the division.
    """
    low, high = np.percentile(a, (p1, p99))
    if high <= low:
        return np.zeros_like(a, dtype=np.float32)
    span = high - low + eps
    scaled = np.clip((a - low) / span, 0, 1)
    return scaled.astype(np.float32)


def discover_basenames(pairs_dir: Path):
    """Return the cores in ``pairs_dir`` that have both an H&E and an ORION array.

    A core ``<base>`` is included when ``<base>_HE.npy`` (matched by the glob
    ``core_*_HE.npy``) and ``<base>_ORION.npy`` both exist.

    Args:
        pairs_dir: Directory containing the paired ``.npy`` files.

    Returns:
        List of basenames (H&E file stem without the trailing ``_HE``), in
        sorted H&E filename order.
    """
    suffix = '_HE'
    out = []
    for hef in sorted(pairs_dir.glob(f'core_*{suffix}.npy')):
        # Strip only the trailing '_HE' guaranteed by the glob; str.replace would
        # also delete '_HE' inside the core id ('core_HEX_HE' -> 'coreX').
        base = hef.stem[: -len(suffix)]
        if (pairs_dir / f"{base}_ORION.npy").exists():
            out.append(base)
    return out


def list_grid_coords(H, W, ps=224, stride=112):
    """Enumerate top-left ``(y, x)`` corners of a sliding window over an H x W image.

    Along each axis the starts are 0, stride, 2*stride, ... not exceeding
    ``extent - ps``; an axis no longer than ``ps`` yields the single start 0.
    The last start can fall short of ``extent - ps`` when that offset is not a
    multiple of ``stride``, so a bottom/right strip may stay uncovered.
    Corners are ordered row-major (y outer, x inner).
    """
    def _starts(extent):
        if extent <= ps:
            return [0]
        return list(range(0, max(1, extent - ps) + 1, stride))

    row_starts = _starts(H)
    col_starts = _starts(W)
    return [(y, x) for y in row_starts for x in col_starts]


def load_pair(pairs_dir: Path, basename: str):
    """Memory-map the H&E and ORION arrays stored for ``basename``.

    Opens ``<basename>_HE.npy`` and ``<basename>_ORION.npy`` in ``pairs_dir``
    with ``mmap_mode='r'``. A 3-D ORION array whose first axis has length 20 is
    transposed so that axis moves last, i.e. (20, H, W) -> (H, W, 20); any other
    ORION layout is returned as loaded.

    Returns:
        ``(he, orion)`` read-only memory-mapped arrays.
    """
    he_path = pairs_dir / f"{basename}_HE.npy"
    orion_path = pairs_dir / f"{basename}_ORION.npy"
    he = np.load(he_path, mmap_mode='r')
    orion = np.load(orion_path, mmap_mode='r')
    channels_first = orion.ndim == 3 and orion.shape[0] == 20
    if channels_first:
        orion = orion.transpose(1, 2, 0)
    return he, orion


def get_patch(he: np.ndarray, orion: np.ndarray, y0: int, x0: int, ps: int):
    """Cut the same ``ps`` x ``ps`` window, top-left at ``(y0, x0)``, from both arrays.

    Both arrays are indexed as (H, W, C). Windows that run past the border are
    truncated by ordinary slicing, not padded. Returns ``(he_crop, or_crop)`` as
    views into the inputs.
    """
    window = (slice(y0, y0 + ps), slice(x0, x0 + ps), slice(None))
    return he[window], orion[window]


def scale_orion(or_crop: np.ndarray):
    """Robust-normalise each channel of an (H, W, C) ORION crop independently.

    Every channel (indexed along the last axis) goes through ``robust_norm01``
    on its own. Returns a new float32 array of the same shape with values in [0, 1].
    """
    n_channels = or_crop.shape[2]
    scaled = np.zeros_like(or_crop, dtype=np.float32)
    for ch in range(n_channels):
        scaled[..., ch] = robust_norm01(or_crop[..., ch])
    return scaled


def to_tensor_image(he_crop: np.ndarray, ps: int):
    """Convert an H&E crop into the normalised model-input tensor.

    Mimics the evaluation transform from training: ``ToTensor`` ->
    ``Resize(ps)`` -> ImageNet ``Normalize``.

    Args:
        he_crop: (H, W, 3) crop. uint8 input is used as-is; any other dtype is
            assumed to hold intensities in [0, 1] (TODO confirm for your data)
            and is scaled to [0, 255].
        ps: Size handed to ``Resize``; as an int it sets the shorter side.

    Returns:
        Float tensor of shape (3, H', W').
    """
    tf_eval = T.Compose([
        T.ToTensor(),
        T.Resize(ps, antialias=True),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    if he_crop.dtype == np.uint8:
        # Copy so ToTensor never wraps a read-only memmap slice.
        he_img = np.array(he_crop)
    else:
        # Clip (and zero NaNs) before the cast: out-of-range floats would
        # otherwise wrap or turn into arbitrary uint8 values.
        scaled = np.nan_to_num(he_crop * 255, nan=0.0)
        he_img = np.clip(scaled, 0, 255).astype(np.uint8)
    return tf_eval(he_img)


def prepare_target_vector(or_crop_scaled: np.ndarray):
    """Average an (H, W, C) map over all pixels, one value per channel.

    Returns a float32 torch tensor of shape (C,).
    """
    n_channels = or_crop_scaled.shape[2]
    pixels_per_channel = np.transpose(or_crop_scaled, (2, 0, 1)).reshape(n_channels, -1)
    channel_means = pixels_per_channel.mean(axis=1)
    return torch.from_numpy(channel_means.astype(np.float32))

# Fluorescence colormap utilities
# RGB pseudo-colours with components in [0, 1]; 20 entries, one per ORION
# channel. Plotting picks colour c % len(FLUOR_COLORS) for marker c.
FLUOR_COLORS = [
    (0.0, 0.5, 1.0),   # Blue
    (0.0, 1.0, 0.0),   # Green
    (1.0, 0.0, 0.0),   # Red
    (1.0, 1.0, 0.0),   # Yellow
    (1.0, 0.0, 1.0),   # Magenta
    (0.0, 1.0, 1.0),   # Cyan
    (1.0, 0.5, 0.0),   # Orange
    (0.5, 0.0, 1.0),   # Purple
    (0.0, 0.8, 0.4),   # Spring Green
    (1.0, 0.2, 0.6),   # Pink
    (0.6, 1.0, 0.2),   # Lime
    (0.8, 0.4, 0.0),   # Brown
    (0.4, 0.6, 1.0),   # Light Blue
    (1.0, 0.8, 0.2),   # Gold
    (0.6, 0.0, 0.6),   # Dark Magenta
    (0.0, 0.6, 0.8),   # Steel Blue
    (0.8, 0.2, 0.4),   # Crimson
    (0.2, 0.8, 0.6),   # Sea Green
    (0.9, 0.6, 0.1),   # Dark Orange
    (0.3, 0.3, 0.9),   # Royal Blue
]

def create_fluorescence_colormap(color):
    """Build a 256-level colormap ramping linearly from black to ``color``.

    Zero intensity renders black and full intensity renders ``color``, similar
    to a single fluorescence channel. ``color`` is an RGB tuple in [0, 1].
    """
    ramp = [(0, 0, 0), color]
    return LinearSegmentedColormap.from_list('fluor', ramp, N=256)


# One black->colour colormap per FLUOR_COLORS entry, in the same order.
FLUOR_CMAPS = list(map(create_fluorescence_colormap, FLUOR_COLORS))


In [None]:
# Model definitions (UNet and ConvNeXt heads)
import torchvision.models as tvm

class ConvNeXtHead(nn.Module):
    """torchvision ConvNeXt whose final classifier layer outputs ``num_outputs`` values.

    Args:
        num_outputs: Length of the output vector (one regression value per marker).
        backbone: One of 'convnext_tiny', 'convnext_small', 'convnext_base',
            'convnext_large'.
        pretrained: Initialise from torchvision's ImageNet-1k weights (default,
            as before). Pass False when a checkpoint will overwrite the weights
            anyway, to skip the download.

    Raises:
        ValueError: If ``backbone`` is not one of the names above.
    """

    _BACKBONES = {
        'convnext_tiny': (tvm.convnext_tiny, tvm.ConvNeXt_Tiny_Weights),
        'convnext_small': (tvm.convnext_small, tvm.ConvNeXt_Small_Weights),
        'convnext_base': (tvm.convnext_base, tvm.ConvNeXt_Base_Weights),
        'convnext_large': (tvm.convnext_large, tvm.ConvNeXt_Large_Weights),
    }

    def __init__(self, num_outputs: int = 20, backbone: str = 'convnext_small', pretrained: bool = True):
        super().__init__()
        if backbone not in self._BACKBONES:
            raise ValueError(
                f"Unsupported backbone {backbone!r}; expected one of {sorted(self._BACKBONES)}"
            )
        builder, weights_enum = self._BACKBONES[backbone]
        m = builder(weights=weights_enum.IMAGENET1K_V1 if pretrained else None)
        # classifier[2] is the final Linear layer; swap it for a num_outputs-way head.
        m.classifier[2] = nn.Linear(m.classifier[2].in_features, num_outputs)
        self.backbone = m

    def forward(self, x):
        """Return the backbone's ``num_outputs``-wide prediction for batch ``x``."""
        return self.backbone(x)

# Import UNetSmall from training file if available (best-effort: without it only
# regression checkpoints can be loaded).
import os as _os
import sys
from pathlib import Path as _P

UNetSmall = None
# Directory containing train_orion_patches.py; set HEXIF_PROJECT_ROOT to override.
proj_root = _P(_os.environ.get(
    'HEXIF_PROJECT_ROOT',
    '/Users/ranystephan/Library/Mobile Documents/com~apple~CloudDocs/Desktop/RA/ra_biomed/hexif',
))
if str(proj_root) not in sys.path:
    sys.path.insert(0, str(proj_root))
try:
    from train_orion_patches import UNetSmall as _UNetSmall
except Exception as e:  # the training script may fail to import for many reasons
    print(f"UNetSmall unavailable ({type(e).__name__}: {e}); spatial checkpoints cannot be loaded.")
else:
    UNetSmall = _UNetSmall


def _pick_default_device() -> torch.device:
    """Return the preferred available device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        return torch.device('mps')
    return torch.device('cpu')


def _unwrap_checkpoint(ckpt):
    """Split a raw ``torch.load`` result into ``(state_dict, args_dict)``.

    Accepts a bare state dict or a dict wrapping one under 'model',
    'model_state_dict' or 'state_dict'. ``args`` is returned as a plain dict
    even when it was saved as an object with attributes (e.g. argparse.Namespace).
    A 'module.' prefix present on every key (DataParallel/DDP) is stripped.
    """
    if not isinstance(ckpt, dict):
        raise TypeError(f"Unsupported checkpoint type: {type(ckpt).__name__}")
    state = ckpt
    for key in ('model', 'model_state_dict', 'state_dict'):
        if isinstance(ckpt.get(key), dict):
            state = ckpt[key]
            break

    args = ckpt.get('args') or {}
    if not isinstance(args, dict):
        args = dict(vars(args)) if hasattr(args, '__dict__') else {}

    prefix = 'module.'
    if state and all(isinstance(k, str) and k.startswith(prefix) for k in state):
        state = {k[len(prefix):]: v for k, v in state.items()}
    return state, args


def load_model_any(checkpoint_path: Path, device: torch.device = None):
    """
    Load checkpoint and return (model, device, mode) where mode is one of:
    - 'seg' for spatial output model (H,W per channel), e.g., UNetSmall
    - 'reg' for regression vector model, e.g., ConvNeXtHead

    The architecture is guessed from state-dict key prefixes. ``device``
    defaults to CUDA, then MPS, then CPU. Missing/unexpected keys are printed
    rather than raised (weights load with ``strict=False``).

    Raises:
        RuntimeError: If a spatial checkpoint is found but UNetSmall is unavailable.
    """
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    # NOTE(review): newer torch releases changed the weights_only default; if this
    # fails on pickled extras (e.g. a Namespace under 'args'), pass
    # weights_only=False for trusted files.
    ckpt = torch.load(checkpoint_path, map_location='cpu')
    if device is None:
        device = _pick_default_device()
    state, args = _unwrap_checkpoint(ckpt)

    unet_prefixes = ('enc1.', 'dec1.', 'up1.', 'down1.', 'out', 'd1.', 'u1.', 'c1.')
    if any(k.startswith(unet_prefixes) for k in state):
        # Assume UNet-like spatial model
        if UNetSmall is None:
            raise RuntimeError('UNetSmall not importable; cannot render spatial predictions.')
        model = UNetSmall(in_ch=3, out_ch=20, base=args.get('base_features', 32))
        mode = 'seg'
    else:
        # Fallback to ConvNeXt regression. The 'backbone' args key is an
        # assumption about the training script -- TODO confirm; defaults to small.
        backbone = args.get('backbone', 'convnext_small')
        if not (isinstance(backbone, str) and backbone.startswith('convnext_')):
            backbone = 'convnext_small'
        model = ConvNeXtHead(num_outputs=20, backbone=backbone)
        mode = 'reg'

    missing, unexpected = model.load_state_dict(state, strict=False)
    if missing or unexpected:
        print('Loaded with missing/unexpected keys:', missing, unexpected)
    model.to(device)
    model.eval()
    return model, device, mode


In [None]:
# Prepare data lists and model
basenames = discover_basenames(pairs_dir)
if not basenames:
    raise FileNotFoundError(f"No paired cores found in {pairs_dir}")

model, device, model_mode = load_model_any(resolved_ckpt)
print(f"Model mode: {model_mode}")

# Cache shapes and grid coords for each basename. Grids are built from the
# ORION extent only, so flag cores whose H&E array has a different spatial
# shape: their H&E and ORION crops would not cover the same tissue.
shapes = {}
grids = {}
for b in basenames:
    he, orion = load_pair(pairs_dir, b)
    H, W = orion.shape[0], orion.shape[1]
    if tuple(he.shape[:2]) != (H, W):
        print(f"WARNING: {b}: H&E shape {tuple(he.shape[:2])} != ORION shape {(H, W)}; patches may be misaligned")
    shapes[b] = (H, W)
    grids[b] = list_grid_coords(H, W, ps=patch_size, stride=patch_size // 2)

print(f"Found {len(basenames)} cores. Example: {basenames[0]}")
print(f"Grid patches for first core: {len(grids[basenames[0]])}")


In [None]:
# Inference + visualization
from math import ceil

# If you have marker names, fill them here; else use indices 0..19
marker_names = [f"Marker {i}" for i in range(20)]


def _predict_marker_maps(he_crop: np.ndarray, out_hw):
    """Run the loaded model on one H&E crop and return per-marker maps.

    Uses the notebook globals ``model``, ``device``, ``model_mode`` and ``patch_size``.

    Args:
        he_crop: H&E crop as returned by ``get_patch``.
        out_hw: (H, W) of the ground-truth crop; maps are returned at this size.

    Returns:
        ``(maps, scalars)``. ``maps`` has shape (C_pred, H, W) with values in [0, 1].
        For 'seg' models it is the sigmoid of the spatial output, bilinearly resized
        to ``out_hw`` when needed, and ``scalars`` is None. For 'reg' models each
        prediction, clipped to [0, 1], fills a uniform float32 map, and ``scalars``
        holds the raw (unclipped) per-marker values.
    """
    out_hw = tuple(out_hw)
    he_t = to_tensor_image(he_crop, patch_size).unsqueeze(0).to(device)
    with torch.no_grad():
        out = model(he_t)
        if model_mode == 'seg':
            # Expect (1, C, H', W') logits -- TODO(review) confirm training used a sigmoid output.
            pred_map = torch.sigmoid(out).squeeze(0)
            if tuple(pred_map.shape[1:]) != out_hw:
                pred_map = F.interpolate(
                    pred_map.unsqueeze(0), size=out_hw, mode='bilinear', align_corners=False
                ).squeeze(0)
            return pred_map.cpu().numpy(), None
    # Regression vector per marker; repeat to form uniform maps
    pred_vec = out.squeeze(0).cpu().numpy()
    maps = np.stack(
        [np.full(out_hw, np.clip(float(v), 0.0, 1.0), dtype=np.float32) for v in pred_vec],
        axis=0,
    )
    return maps, pred_vec


def infer_and_plot(basename: str, patch_index: int):
    """Plot H&E, prediction and ground truth for one grid patch of one core.

    One row per marker: [H&E | Pred | GT], with Pred and GT drawn using the
    marker's fluorescence colormap on a fixed [0, 1] scale. ``patch_index`` is
    clamped to the core's grid. In regression mode the Pred title shows the raw
    predicted scalar; the GT title always shows the crop's mean normalised value.

    Raises:
        ValueError: If the core has no grid coordinates or no marker channels.
    """
    he, orion = load_pair(pairs_dir, basename)
    coords = grids[basename]
    if not coords:
        raise ValueError('No grid coordinates computed')
    patch_index = max(0, min(patch_index, len(coords) - 1))
    y0, x0 = coords[patch_index]

    he_crop, or_crop = get_patch(he, orion, y0, x0, patch_size)
    or_scaled = scale_orion(or_crop)
    pred_np, pred_scalars = _predict_marker_maps(he_crop, or_scaled.shape[:2])

    # Plot only channels present in both GT and prediction; a mismatch would
    # otherwise raise IndexError on pred_np[c].
    n_gt, n_pred = or_scaled.shape[2], pred_np.shape[0]
    num_markers = min(n_gt, n_pred)
    if n_gt != n_pred:
        print(f"WARNING: GT has {n_gt} channels, prediction has {n_pred}; plotting the first {num_markers}")
    if num_markers == 0:
        raise ValueError('No marker channels to plot')

    # display: For each marker, show [H&E | Predicted | Ground Truth]
    ncols = 3
    fig, axes = plt.subplots(
        nrows=num_markers, ncols=ncols,
        figsize=(ncols * 3.0, num_markers * 3.0),
        squeeze=False,  # always a 2-D axes array, even for a single row
    )

    for c in range(num_markers):
        name = marker_names[c] if c < len(marker_names) else f'Marker {c}'
        cmap = FLUOR_CMAPS[c % len(FLUOR_CMAPS)]

        ax = axes[c, 0]
        ax.imshow(he_crop)
        ax.set_title(f"{name} - H&E", fontsize=9)
        ax.axis('off')

        ax = axes[c, 1]
        ax.imshow(pred_np[c], cmap=cmap, vmin=0, vmax=1)
        pred_title = "Pred" if pred_scalars is None else f"Pred ({float(pred_scalars[c]):.3f})"
        ax.set_title(pred_title, fontsize=9)
        ax.axis('off')

        ax = axes[c, 2]
        gt = or_scaled[..., c]
        ax.imshow(gt, cmap=cmap, vmin=0, vmax=1)
        ax.set_title(f"GT (mean {float(gt.mean()):.3f})", fontsize=9)
        ax.axis('off')

    fig.suptitle(f"{basename} | patch #{patch_index} at (y={y0}, x={x0})", y=0.995)
    plt.tight_layout()
    plt.show()

# Interactive controls
# Validate the configured defaults so a bad index does not raise IndexError
# (negative indices keep their usual Python meaning).
_n_cores = len(basenames)
_basename_idx = default_basename_idx if -_n_cores <= default_basename_idx < _n_cores else 0
_default_basename = basenames[_basename_idx]
_max_patch = max(0, len(grids[_default_basename]) - 1)

basename_dd = Dropdown(options=basenames, value=_default_basename, description='basename')
patch_slider = IntSlider(
    value=min(max(0, default_patch_index), _max_patch),
    min=0, max=_max_patch, step=1, description='patch_index',
)


def update_patch_slider(*args):
    """Resize the patch slider to the grid of the newly selected core."""
    new_max = max(0, len(grids[basename_dd.value]) - 1)
    patch_slider.max = new_max
    # Keep the current index when it still exists; otherwise clamp to the last patch.
    patch_slider.value = min(patch_slider.value, new_max)


basename_dd.observe(update_patch_slider, names='value')

interact(infer_and_plot, basename=basename_dd, patch_index=patch_slider);
