# 05 - Results Analysis & Visualization

This notebook provides tools for analyzing MedJEPA training results and
visualizing what the model has learned.

**Sections:**
1. Training History
2. Embedding Space Visualization (t-SNE)
3. Attention Maps
4. Data Efficiency Curves
5. Evaluation Summary

In [None]:
import sys, os, json
sys.path.insert(0, os.path.abspath('..'))

import torch
import numpy as np
import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend if needed
import matplotlib.pyplot as plt

from medjepa.utils.visualization import (
    plot_training_history,
    plot_embedding_space,
    plot_attention_map,
    plot_data_efficiency,
    extract_attention_weights,
    plot_evaluation_summary,
)
from medjepa.utils.device import get_device

# Resolve the compute device once through the project helper; the model is
# moved onto this device after its weights are loaded.
device = get_device()
print(f'Device: {device}')

## 1. Training History

Load and plot the loss curve from pre-training.

In [None]:
# Load the pre-training history JSON (if present) and plot the loss curve.
history_path = '../checkpoints/training_history.json'
if os.path.exists(history_path):
    with open(history_path) as f:
        history = json.load(f)
    print(f"Epochs: {len(history['epochs'])}")
    # An empty loss list would make the [-1] lookup raise IndexError.
    if history['train_loss']:
        print(f"Final loss: {history['train_loss'][-1]:.6f}")
    else:
        print('Final loss: n/a (no loss values recorded)')
    # The figure is saved under ../results, which no earlier cell creates;
    # make sure the directory exists before plotting.
    os.makedirs('../results', exist_ok=True)
    plot_training_history(history, save_path='../results/training_history.png')
else:
    print(f'No training history found at {history_path}')
    print('Run pre-training first: python scripts/pretrain.py ...')

## 2. Embedding Space (t-SNE)

Visualize how the model organises different disease types in its
embedding space. Good representations should cluster similar
diseases together.

In [None]:
from medjepa.models import LeJEPA, ViTEncoder
from medjepa.evaluation import LinearProbeEvaluator
from medjepa.data.datasets import MedicalImageDataset
from torch.utils.data import DataLoader

# --- Configure these paths ---
CHECKPOINT = '../checkpoints/best_model.pt'
DATA_DIR   = '../data/raw/ham10000'
CSV_PATH   = '../data/raw/ham10000/HAM10000_metadata.csv'
LABEL_COL  = 'dx'
# ---

# Fail early with an actionable message instead of a bare traceback from
# torch.load; every later cell in this section depends on the checkpoint.
if not os.path.exists(CHECKPOINT):
    raise FileNotFoundError(
        f'No checkpoint found at {CHECKPOINT}. '
        'Run pre-training first: python scripts/pretrain.py ...'
    )

# Load checkpoint onto the CPU; the model is moved to `device` once built.
# NOTE(security): weights_only=False unpickles arbitrary Python objects, so
# only load checkpoints from a source you trust.
ckpt = torch.load(CHECKPOINT, map_location='cpu', weights_only=False)
# Tolerate a checkpoint whose 'config' entry is missing or stored as None.
cfg  = ckpt.get('config') or {}

# Architecture hyper-parameters, with fallbacks when the config omits a key.
embed_dim  = cfg.get('embed_dim', 768)
enc_depth  = cfg.get('encoder_depth', 12)
pred_depth = cfg.get('predictor_depth', 6)
image_size = cfg.get('image_size', 224)
patch_size = cfg.get('patch_size', 16)

print(f'Model: embed_dim={embed_dim}, encoder_depth={enc_depth}, '
      f'predictor_depth={pred_depth}')

In [None]:
# Recreate the architecture from the checkpoint hyper-parameters,
# then restore the trained weights into it.
encoder = ViTEncoder(
    image_size=image_size,
    patch_size=patch_size,
    embed_dim=embed_dim,
    depth=enc_depth,
)
model = LeJEPA(
    encoder=encoder,
    embed_dim=embed_dim,
    predictor_depth=pred_depth,
)
model.load_state_dict(ckpt['model_state_dict'])

# Place the model on the selected device and switch it to evaluation mode.
model = model.to(device)
model = model.eval()
print('Model loaded.')

In [None]:
# Load the labelled dataset and wrap it in an unshuffled loader so the
# extracted features stay aligned with dataset order.
dataset = MedicalImageDataset(
    data_dir=DATA_DIR, metadata_csv=CSV_PATH,
    label_column=LABEL_COL, image_size=image_size,
)
loader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=0)

# Number of distinct labels actually present in this dataset.
num_classes = len(set(dataset.labels))

# Build display names from the full label map (name -> index), not from the
# labels that happen to be present: if a class is absent from the data,
# sizing the list by `num_classes` would truncate it and break later
# `class_names[label]` lookups.
class_names = None
label_map = getattr(dataset, 'label_map', None)
if label_map:
    inv_map = {v: k for k, v in label_map.items()}
    # Positional indexing is only valid when indices are exactly 0..K-1.
    if sorted(inv_map) == list(range(len(inv_map))):
        class_names = [inv_map[i] for i in range(len(inv_map))]
    else:
        print('Warning: label_map indices are not contiguous 0..K-1; '
              'falling back to numeric class labels.')

print(f'Dataset: {len(dataset)} images, {num_classes} classes')
if class_names:
    print(f'Classes: {class_names}')

In [None]:
# Run every batch of the loader through the evaluator to collect
# per-image features together with their labels.
evaluator = LinearProbeEvaluator(model)
features, labels = evaluator.extract_features(loader)
print(f'Extracted: features {features.shape}, labels {labels.shape}')

# Make sure the output directory exists, then draw and save the t-SNE plot.
os.makedirs('../results', exist_ok=True)
plot_embedding_space(
    features,
    labels,
    class_names=class_names,
    title=f'MedJEPA Embedding Space ({len(dataset)} images)',
    save_path='../results/embedding_tsne.png',
)

## 3. Attention Maps

Visualize where the model focuses its attention. Clinicians can use
this to verify the model looks at the right areas (lesion, not background).

In [None]:
from PIL import Image
import cv2

# Pick a few sample images (indices beyond the dataset length are skipped).
sample_indices = [0, 100, 500, 1000]

# Figures are saved under ../results; create it here so this section does
# not depend on the t-SNE cell having been run first.
os.makedirs('../results', exist_ok=True)

for idx in sample_indices:
    if idx >= len(dataset):
        continue
    img_tensor, label = dataset[idx]

    # Get attention map
    attn_map = extract_attention_weights(model, img_tensor)

    # Resize attention to image size so it can be overlaid on the image
    attn_resized = cv2.resize(
        attn_map, (image_size, image_size),
        interpolation=cv2.INTER_LINEAR,
    )

    # Convert tensor image (channels-first) to a channels-last numpy array.
    # detach()/cpu() keep .numpy() valid even if the tensor carries autograd
    # state or lives on an accelerator.
    img_np = img_tensor.detach().cpu().permute(1, 2, 0).numpy()
    # Min-max rescale to [0, 1] for display; the epsilon avoids division by
    # zero on a constant image.
    lo, hi = img_np.min(), img_np.max()
    img_np = (img_np - lo) / (hi - lo + 1e-8)

    label_name = class_names[label] if class_names else f'Class {label}'
    plot_attention_map(
        img_np, attn_resized,
        title=f'Sample {idx} — {label_name}',
        save_path=f'../results/attention_sample_{idx}.png',
    )

## 4. Data Efficiency Curve

Load few-shot / data-efficiency evaluation results and plot the
"money plot" — how accuracy scales with labeled data.

In [None]:
# Plot the data-efficiency curve when evaluation results are on disk;
# otherwise tell the user how to produce them.
eval_path = '../results/evaluation_results.json'
if not os.path.exists(eval_path):
    print(f'No evaluation results at {eval_path}')
    print('Run: python scripts/evaluate.py --checkpoint ... --data_dir ...')
else:
    with open(eval_path) as f:
        eval_results = json.load(f)

    # The curve needs the dedicated 'data_efficiency' section of the results.
    if 'data_efficiency' not in eval_results:
        print('No data_efficiency key in results. '
              'Run: python scripts/evaluate.py ...')
    else:
        plot_data_efficiency(
            eval_results['data_efficiency'],
            title='MedJEPA Data Efficiency (HAM10000)',
            save_path='../results/data_efficiency.png',
        )

## 5. Evaluation Summary

Bar chart of all evaluation metrics at a glance.

In [None]:
# Collect headline accuracies from the evaluation results and chart them.
if not os.path.exists(eval_path):
    print('Run evaluation first.')
else:
    with open(eval_path) as f:
        eval_results = json.load(f)

    summary = {}

    # Linear probe accuracy
    if 'linear_probe' in eval_results:
        lp = eval_results['linear_probe']
        if 'accuracy' in lp:
            summary['Linear Probe'] = lp['accuracy']

    # Few-shot results: keep only entries that are dicts carrying an accuracy
    if 'few_shot' in eval_results:
        summary.update({
            key: val['accuracy']
            for key, val in eval_results['few_shot'].items()
            if isinstance(val, dict) and 'accuracy' in val
        })

    if not summary:
        print('No summary metrics found in results.')
    else:
        plot_evaluation_summary(
            summary,
            title='MedJEPA Evaluation Summary',
            save_path='../results/evaluation_summary.png',
        )

---

**Generated plots are saved to the `results/` folder.**

For full evaluation, run from the repository root:
```bash
python scripts/evaluate.py \
    --checkpoint checkpoints/best_model.pt \
    --data_dir data/raw/ham10000 \
    --metadata_csv data/raw/ham10000/HAM10000_metadata.csv \
    --label_column dx --num_classes 7
```