# 05 - Results Analysis & Visualization

This notebook provides tools for analyzing MedJEPA training results and
visualizing what the model has learned.

**Sections:**
1. Training History
2. Embedding Space Visualization (t-SNE)
3. Attention Maps
4. Data Efficiency Curves (All Datasets)
5. N-Shot Results + Dice Scores
6. Full Results Summary Table

In [None]:
import sys, os, json
sys.path.insert(0, os.path.abspath('..'))

import torch
import numpy as np
import matplotlib

# Force the non-interactive Agg backend only when this code runs outside a
# Jupyter kernel (e.g. the notebook exported to a script on a headless box).
# Inside a kernel, selecting Agg would replace the inline backend, so figures
# would still be written to disk but would never be displayed in the notebook.
if 'ipykernel' not in sys.modules:
    matplotlib.use('Agg')
import matplotlib.pyplot as plt

from medjepa.utils.visualization import (
    plot_training_history,
    plot_embedding_space,
    plot_attention_map,
    plot_data_efficiency,
    extract_attention_weights,
    plot_evaluation_summary,
)
from medjepa.utils.device import get_device

# Device the model is moved to later in the notebook
device = get_device()
print(f'Device: {device}')

## 1. Training History

Load and plot the loss curve from pre-training.

In [None]:
# Load the pre-training history (JSON with 'epochs' and 'train_loss' lists)
# and plot the loss curve.
history_path = '../checkpoints/training_history.json'
if os.path.exists(history_path):
    with open(history_path) as f:
        history = json.load(f)
    print(f"Epochs: {len(history['epochs'])}")

    train_loss = history['train_loss']
    if train_loss:
        print(f"Final loss: {train_loss[-1]:.6f}")
        # The results folder is not created anywhere before this cell, so make
        # sure it exists before the figure is written into it.
        os.makedirs('../results', exist_ok=True)
        plot_training_history(history, save_path='../results/training_history.png')
    else:
        # An empty loss list would raise IndexError on [-1]; report it instead.
        print('Training history contains no recorded losses yet — nothing to plot.')
else:
    print(f'No training history found at {history_path}')
    print('Run pre-training first: python scripts/pretrain.py ...')

## 2. Embedding Space (t-SNE)

Visualize how the model organises different disease types in its
embedding space. Good representations should cluster similar
diseases together.

In [None]:
from medjepa.models.lejepa import LeJEPA
from medjepa.evaluation.linear_probe import LinearProbeEvaluator
from medjepa.data.datasets import MedicalImageDataset
from torch.utils.data import DataLoader

# --- Configure these paths ---
CHECKPOINT = '../checkpoints/best_model.pt'
DATA_DIR   = '../data/raw/ham10000'
CSV_PATH   = f'{DATA_DIR}/HAM10000_metadata.csv'
LABEL_COL  = 'dx'
# ---

# Read the checkpoint (when present) and recover the architecture settings
# stored under its 'config' entry; each setting has a default that is used
# when the key is absent from the saved config.
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects — only open checkpoints that come from a trusted source.
if not os.path.exists(CHECKPOINT):
    print(f'No checkpoint found at {CHECKPOINT}')
    print('Run pre-training first: python scripts/run_gpu_full.py')
else:
    ckpt = torch.load(CHECKPOINT, map_location='cpu', weights_only=False)
    cfg = ckpt.get('config', {})

    # Input geometry
    image_size = cfg.get('image_size', 224)
    patch_size = cfg.get('patch_size', 16)

    # Network width and depths
    embed_dim = cfg.get('embed_dim', 768)
    enc_depth = cfg.get('encoder_depth', 12)
    pred_depth = cfg.get('predictor_depth', 6)

    model_summary = (f'Model: embed_dim={embed_dim}, encoder_depth={enc_depth}, '
                     f'predictor_depth={pred_depth}')
    print(model_summary)

In [None]:
# Instantiate LeJEPA with the hyper-parameters read from the checkpoint,
# restore its weights, then move it to `device` and switch to eval mode.
if not os.path.exists(CHECKPOINT):
    model = None
    print('Skipping — no checkpoint.')
else:
    model_kwargs = dict(
        image_size=image_size,
        patch_size=patch_size,
        embed_dim=embed_dim,
        encoder_depth=enc_depth,
        predictor_depth=pred_depth,
    )
    model = LeJEPA(**model_kwargs)
    model.load_state_dict(ckpt['model_state_dict'])
    model = model.to(device)
    model = model.eval()
    print('Model loaded.')

In [None]:
# Load the labelled dataset used for the t-SNE and attention-map cells.
# Both the image folder and the metadata CSV must be present: the CSV path is
# handed to MedicalImageDataset together with the label column, so a missing
# CSV is treated like missing data and the cell skips instead of crashing.
if model is not None and os.path.exists(DATA_DIR) and os.path.exists(CSV_PATH):
    dataset = MedicalImageDataset(
        image_dir=DATA_DIR, metadata_csv=CSV_PATH,
        label_column=LABEL_COL, target_size=(image_size, image_size),
    )
    # shuffle=False keeps the batch order identical to the dataset order
    loader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=0)

    # Number of distinct label values present in the dataset
    num_classes = len(set(dataset.labels))
    class_names = None
    if hasattr(dataset, 'label_map'):
        # Invert label_map so class names can be looked up by integer index
        # NOTE(review): assumes label_map maps name -> index 0..num_classes-1 — confirm
        inv_map = {v: k for k, v in dataset.label_map.items()}
        class_names = [inv_map[i] for i in range(num_classes)]

    print(f'Dataset: {len(dataset)} images, {num_classes} classes')
    if class_names:
        print(f'Classes: {class_names}')
else:
    # Downstream cells check `dataset is not None` before using it
    dataset = None
    print('Skipping dataset load — no model or data.')

In [None]:
# Run the dataset through the model to collect per-image features, then
# hand features and labels to the t-SNE plotting helper.
if model is None or dataset is None:
    print('Skipping t-SNE — no model/dataset.')
else:
    evaluator = LinearProbeEvaluator(
        model,
        num_classes=num_classes,
        embed_dim=embed_dim,
    )
    features, labels = evaluator.extract_features(loader)
    print(f'Extracted: features {features.shape}, labels {labels.shape}')

    # The figure is written into ../results, so create the folder first
    os.makedirs('../results', exist_ok=True)
    tsne_title = f'MedJEPA Embedding Space ({len(dataset)} images)'
    plot_embedding_space(
        features.numpy(),
        labels.numpy(),
        class_names=class_names,
        title=tsne_title,
        save_path='../results/embedding_tsne.png',
    )

## 3. Attention Maps

Visualize where the model focuses its attention. Clinicians can use
this to verify the model looks at the right areas (lesion, not background).

In [None]:
# Overlay the model's attention on a handful of dataset samples and save one
# figure per sample. Requires OpenCV for resizing the attention map.
if model is not None and dataset is not None:
    from PIL import Image
    try:
        import cv2
    except ImportError:
        cv2 = None
        print("Install opencv-python for attention maps: pip install opencv-python")

    if cv2 is not None:
        # Make this cell independent of the t-SNE cell: create the output
        # folder here too, as the other saving cells do.
        os.makedirs('../results', exist_ok=True)

        sample_indices = [0, 100, 500, 1000]
        for idx in sample_indices:
            # Skip indices beyond the end of a smaller dataset
            if idx >= len(dataset):
                continue
            img_tensor, label = dataset[idx]

            attn_map = extract_attention_weights(model, img_tensor)
            # Resize the attention map to the input image resolution
            attn_resized = cv2.resize(
                attn_map, (image_size, image_size),
                interpolation=cv2.INTER_LINEAR,
            )

            # Move the first axis to the end for plotting
            # (assumes a channel-first image tensor — TODO confirm)
            img_np = img_tensor.permute(1, 2, 0).numpy()
            # Min-max rescale to [0, 1]; the epsilon avoids division by zero
            # for a constant image
            img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min() + 1e-8)

            label_name = class_names[label] if class_names else f'Class {label}'
            plot_attention_map(
                img_np, attn_resized,
                title=f'Sample {idx} — {label_name}',
                save_path=f'../results/attention_sample_{idx}.png',
            )
else:
    print('Skipping attention maps — no model/dataset.')

## 4. Data Efficiency Curves (All Datasets)

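The "money plot" — how MedJEPA performs across datasets. The cell below loads
the saved evaluation results and plots linear-probe accuracy per dataset, next
to the supervised baseline when one is available for every dataset. How
accuracy scales with the amount of labeled data (MedJEPA's value: strong
performance even with very few labels) is shown by the N-shot results in
Section 5.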

In [None]:
# Read the evaluation results written by run_gpu_full.py (Phase 3);
# fall back to an empty dict so the plotting cells below simply skip.
eval_path = '../results/evaluation_results.json'
if not os.path.exists(eval_path):
    print(f'No evaluation results at {eval_path}')
    print('Run: python scripts/run_gpu_full.py')
    eval_results = {}
else:
    with open(eval_path) as f:
        eval_results = json.load(f)
    print(f"Loaded results for {len(eval_results)} datasets:")
    for dataset_name in eval_results:
        print(f"  - {dataset_name}")

# --- Multi-Dataset Linear Probe Comparison (bar chart) ---
if eval_results:
    # Keep only the datasets that carry a linear-probing entry
    datasets_with_lp = {
        ds_name: ds_res
        for ds_name, ds_res in eval_results.items()
        if 'linear_probing' in ds_res
    }

    if not datasets_with_lp:
        print("No linear probing results found.")
    else:
        names = list(datasets_with_lp)
        lp_entries = list(datasets_with_lp.values())
        medjepa_accs = [entry['linear_probing']['accuracy'] for entry in lp_entries]

        # The baseline series is drawn only when every dataset provides one
        has_baseline = all('supervised_baseline' in entry for entry in lp_entries)
        baseline_accs = []
        if has_baseline:
            baseline_accs = [entry['supervised_baseline']['accuracy'] for entry in lp_entries]

        fig, ax = plt.subplots(figsize=(14, 6))
        x = np.arange(len(names))
        width = 0.35

        # Side-by-side bars when a baseline exists, centred bars otherwise
        medjepa_x = x - width/2 if has_baseline else x
        bar_groups = [
            ax.bar(medjepa_x, medjepa_accs, width,
                   label='MedJEPA (pre-trained)', color='steelblue')
        ]
        if has_baseline:
            bar_groups.append(
                ax.bar(x + width/2, baseline_accs, width,
                       label='Random Init (baseline)', color='lightcoral')
            )

        ax.set_ylabel('Accuracy')
        ax.set_title('Linear Probe Accuracy: MedJEPA vs Supervised Baseline')
        ax.set_xticks(x)
        ax.set_xticklabels(names, rotation=30, ha='right')
        ax.legend()
        ax.set_ylim(0, 1)
        ax.grid(axis='y', alpha=0.3)

        # Write each bar's value just above its top edge
        for group in bar_groups:
            for rect in group:
                height = rect.get_height()
                ax.annotate(
                    f'{height:.3f}',
                    xy=(rect.get_x() + rect.get_width()/2, height),
                    xytext=(0, 3), textcoords='offset points',
                    ha='center', fontsize=8,
                )

        os.makedirs('../results', exist_ok=True)
        plt.tight_layout()
        plt.savefig('../results/multi_dataset_linear_probe.png', dpi=150, bbox_inches='tight')
        plt.show()

## 5. N-Shot Results + Dice Scores

Plot the pre-computed 5-shot, 10-shot and 20-shot classification results and the Dice segmentation scores.

In [None]:
# --- N-Shot Classification Results ---
# Grouped bar chart: one group per dataset, one bar per shot count.
if eval_results:
    n_shot_datasets = {
        ds_name: ds_res
        for ds_name, ds_res in eval_results.items()
        if 'n_shot' in ds_res
    }

    if not n_shot_datasets:
        print("No n-shot results found.")
    else:
        fig, ax = plt.subplots(figsize=(12, 6))
        x = np.arange(len(n_shot_datasets))
        width = 0.25

        shot_styles = [
            ('5-shot', '#2ecc71'),
            ('10-shot', '#3498db'),
            ('20-shot', '#9b59b6'),
        ]

        for offset, (shot, colour) in enumerate(shot_styles):
            # A shot setting absent from a dataset's results is drawn as 0
            accs = [
                ds_res['n_shot'].get(shot, {}).get('accuracy', 0)
                for ds_res in n_shot_datasets.values()
            ]
            ax.bar(x + offset*width, accs, width, label=shot, color=colour)

        ax.set_ylabel('Accuracy')
        ax.set_title('N-Shot Classification Across Datasets')
        # Centre the tick under the middle bar of each three-bar group
        ax.set_xticks(x + width)
        ax.set_xticklabels(list(n_shot_datasets.keys()), rotation=30, ha='right')
        ax.legend()
        ax.set_ylim(0, 1)
        ax.grid(axis='y', alpha=0.3)
        plt.tight_layout()
        plt.savefig('../results/n_shot_results.png', dpi=150, bbox_inches='tight')
        plt.show()

# --- Dice Segmentation Scores ---
# Bar chart of the mean Dice score per segmentation dataset, followed by a
# printed per-class breakdown.
if eval_results:
    # Only segmentation entries that actually report a mean Dice score can be
    # plotted; the summary table likewise treats 'mean_dice' as optional, so an
    # entry without it is skipped here rather than raising KeyError.
    seg_datasets = {
        k: v for k, v in eval_results.items()
        if v.get('type') == 'segmentation' and 'mean_dice' in v
    }

    if seg_datasets:
        fig, ax = plt.subplots(figsize=(10, 5))
        names = list(seg_datasets.keys())
        dice_scores = [v['mean_dice'] for v in seg_datasets.values()]

        bars = ax.bar(names, dice_scores, color='teal')
        ax.set_ylabel('Mean Dice Score')
        ax.set_title('Segmentation Performance (Dice Score)')
        ax.set_ylim(0, 1)
        # Fixed reference line at 0.5, labelled as the baseline in the legend
        ax.axhline(y=0.5, color='red', linestyle='--', alpha=0.5, label='Baseline (0.5)')
        ax.legend()
        ax.grid(axis='y', alpha=0.3)

        # Write each bar's value just above its top edge
        for bar in bars:
            h = bar.get_height()
            ax.annotate(f'{h:.3f}', xy=(bar.get_x() + bar.get_width()/2, h),
                       xytext=(0, 3), textcoords='offset points', ha='center', fontsize=10)

        plt.xticks(rotation=30, ha='right')
        plt.tight_layout()
        plt.savefig('../results/dice_scores.png', dpi=150, bbox_inches='tight')
        plt.show()

        # Per-class Dice
        for name, res in seg_datasets.items():
            print(f"\n{name}:")
            print(f"  Mean Dice: {res['mean_dice']:.4f}")
            for cls, d in res.get('per_class_dice', {}).items():
                # Class key '0' is reported as background, every other key as foreground
                label = 'Background' if str(cls) == '0' else 'Foreground'
                print(f"  Class {cls} ({label}): {d:.4f}")
    else:
        print("No segmentation results found.")

## 6. Full Results Summary Table

In [None]:
# --- Comprehensive Results Table ---
# One row per dataset with linear-probe, baseline, n-shot and Dice metrics;
# metrics a dataset does not report are shown as 'N/A'.
import pandas as pd


def _fmt_metric(value, signed=False):
    """Format a metric with 4 decimals, or return 'N/A' when it is missing.

    Missing means ``None``; a legitimate value of 0.0 is formatted normally.
    With ``signed=True`` an explicit +/- sign is included.
    """
    if value is None:
        return 'N/A'
    return f"{value:+.4f}" if signed else f"{value:.4f}"


if eval_results:
    rows = []
    for name, res in eval_results.items():
        row = {'Dataset': name, 'Type': res.get('type', '')}

        # Linear probe
        lp = res.get('linear_probing', {})
        lp_acc = lp.get('accuracy')
        row['LP Accuracy'] = _fmt_metric(lp_acc)
        # Compare against None rather than truthiness so an AUC of 0.0 is kept
        row['LP AUC'] = _fmt_metric(lp.get('auc'))

        # Baseline
        bl = res.get('supervised_baseline', {})
        bl_acc = bl.get('accuracy')
        row['Baseline Acc'] = _fmt_metric(bl_acc)

        # Improvement of the linear probe over the baseline (needs both values)
        if lp_acc is not None and bl_acc is not None:
            row['Improvement'] = _fmt_metric(lp_acc - bl_acc, signed=True)
        else:
            row['Improvement'] = 'N/A'

        # N-shot: same three settings as the n-shot chart, including 20-shot
        ns = res.get('n_shot', {})
        for shot in ('5-shot', '10-shot', '20-shot'):
            row[shot] = _fmt_metric((ns.get(shot) or {}).get('accuracy'))

        # Dice
        row['Dice'] = _fmt_metric(res.get('mean_dice'))

        rows.append(row)

    results_df = pd.DataFrame(rows)
    print("=" * 100)
    print("MEDJEPA FULL EVALUATION RESULTS")
    print("=" * 100)
    print(results_df.to_string(index=False))
    print("=" * 100)

    # Save as CSV for the submission
    os.makedirs('../results', exist_ok=True)
    results_df.to_csv('../results/full_results_table.csv', index=False)
    print("\nSaved to results/full_results_table.csv")
else:
    print("No results to summarize. Run: python scripts/run_gpu_full.py")

---

**All plots and the results table are saved to the `results/` folder.**

### Generated Artifacts:
- `results/training_history.png` — Loss curve
- `results/embedding_tsne.png` — t-SNE embedding space
- `results/attention_sample_*.png` — Attention maps
- `results/multi_dataset_linear_probe.png` — MedJEPA vs baseline (all datasets)
- `results/n_shot_results.png` — 5/10/20-shot classification
- `results/dice_scores.png` — Segmentation Dice scores
- `results/full_results_table.csv` — Complete results table

### To reproduce full results:
```bash
python scripts/run_gpu_full.py
```