# In [None]:
# Cell 1: Imports
import sys
# Put the parent directory on the import path so the `src.*` imports below
# resolve (presumably this notebook lives one level below the project root —
# TODO confirm).
sys.path.append('..')

import json
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch

# Project-local modules. NOTE(review): neither these nor `torch` are referenced
# by the analysis cells visible in this file — confirm whether they are still
# needed (e.g. by cells outside this view) before removing.
from src.models.baseline_transfer import TransferLearningBaseline
from src.data.standard_dataset import StandardDermaMNIST

# Simple marker that every import above succeeded.
print("Imports successful!")

# Cell 2: Load results
# Read the saved evaluation metrics of the transfer-learning baseline and print
# the headline numbers. `results` is reused by the plotting/summary cells below.
# JSON is UTF-8 by specification, so the encoding is given explicitly instead
# of relying on the platform's locale default.
with open('../experiments/baseline/transfer_learning/test_results.json', 'r',
          encoding='utf-8') as f:
    results = json.load(f)

print("Baseline 1: Transfer Learning Results")
print("="*50)
print(f"Best Epoch: {results['best_epoch']}")
print(f"Best Val Accuracy: {results['best_val_accuracy']:.2f}%")
print(f"Test Accuracy: {results['test_accuracy']:.2f}%")

# Cell 3: Per-class accuracy visualization
# Bar chart of the per-class test accuracy with the overall test accuracy drawn
# as a dashed reference line. The figure is saved as a PNG and then displayed.
import yaml

# Explicit encoding so the config is decoded identically on every platform.
with open('../configs/config.yaml', 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

class_names = list(config['dataset']['class_names'].values())
per_class_acc = results['per_class_accuracy']

# One bar position per configured class; computed once and reused for the
# bars, the tick locations and the tick labels.
positions = range(len(class_names))

fig, ax = plt.subplots(figsize=(12, 6))

bars = ax.bar(positions, per_class_acc, color='steelblue', alpha=0.7)

# Highlight minority classes
minority_classes = [3, 6]  # Dermatofibroma, Vascular
for i in minority_classes:
    bars[i].set_color('coral')
    bars[i].set_alpha(0.9)

ax.set_xlabel('Class', fontsize=12)
ax.set_ylabel('Test Accuracy (%)', fontsize=12)
ax.set_title('Baseline 1: Per-Class Test Accuracy', fontsize=14, fontweight='bold')
ax.set_xticks(positions)
# Short "C<index>" labels keep the axis readable.
ax.set_xticklabels([f"C{i}" for i in positions], rotation=0)
ax.grid(axis='y', alpha=0.3)
ax.axhline(y=results['test_accuracy'], color='red', linestyle='--',
           label=f"Overall: {results['test_accuracy']:.2f}%")
ax.legend()

# Add values on bars: centred horizontally, sitting just above each bar top.
for bar, acc in zip(bars, per_class_acc):
    ax.text(bar.get_x() + bar.get_width() / 2., bar.get_height(),
            f'{acc:.1f}%',
            ha='center', va='bottom', fontsize=9)

plt.tight_layout()
# Save before show(): the saved file must be written while the figure is open.
plt.savefig('../experiments/baseline/transfer_learning/per_class_accuracy.png', dpi=150)
plt.show()

# Cell 4: Confusion matrix
# Annotated heatmap of the stored confusion matrix. Axis titles below label
# rows as "True" and columns as "Predicted".
conf_matrix = np.array(results['confusion_matrix'])

# Derive the tick labels from the matrix itself rather than hard-coding the
# number of classes, so labels always match the data being plotted.
row_labels = [f"C{i}" for i in range(conf_matrix.shape[0])]
col_labels = [f"C{i}" for i in range(conf_matrix.shape[1])]

fig, ax = plt.subplots(figsize=(10, 8))
# Pass `ax` explicitly so the heatmap is guaranteed to land on the axes that
# are labelled below, instead of relying on the implicit "current axes".
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
            xticklabels=col_labels,
            yticklabels=row_labels,
            cbar_kws={'label': 'Count'},
            ax=ax)
ax.set_xlabel('Predicted', fontsize=12)
ax.set_ylabel('True', fontsize=12)
ax.set_title('Baseline 1: Confusion Matrix', fontsize=14, fontweight='bold')

plt.tight_layout()
plt.savefig('../experiments/baseline/transfer_learning/confusion_matrix.png', dpi=150)
plt.show()

# Cell 5: Analysis summary
# Print overall accuracy, the average accuracy of minority vs. majority
# classes, the three lowest-scoring classes, and the distance to the target.
print("\nKey Observations:")
print("-" * 50)

overall_acc = results['test_accuracy']

# Minority class indices; every other class index counts as "majority", so the
# two groups can never drift out of sync.
minority_idx = [3, 6]
majority_idx = [i for i in range(len(per_class_acc)) if i not in minority_idx]

minority_acc = np.mean([per_class_acc[i] for i in minority_idx])
majority_acc = np.mean([per_class_acc[i] for i in majority_idx])

print(f"Overall Accuracy: {overall_acc:.2f}%")
print(f"Majority Classes Avg: {majority_acc:.2f}%")
print(f"Minority Classes Avg: {minority_acc:.2f}%")
print(f"Gap: {majority_acc - minority_acc:.2f}%")

print("\nWorst Performing Classes:")
# argsort is ascending, so the first three indices are the lowest accuracies.
sorted_idx = np.argsort(per_class_acc)
for i in sorted_idx[:3]:
    name = class_names[i]
    # Only add an ellipsis when the name was actually cut off.
    shown_name = f"{name[:30]}..." if len(name) > 30 else name
    print(f"  Class {i} ({shown_name}): {per_class_acc[i]:.2f}%")

print("\nThis establishes our baseline to beat!")
# Distance to the lower bound of the 75-80% target; avoid printing a negative
# "improvement needed" when the baseline is already at or above that bound.
shortfall = 75 - overall_acc
if shortfall > 0:
    print(f"RareSight Target: 75-80% (need +{shortfall:.2f}% improvement)")
else:
    print("RareSight Target: 75-80% (baseline already meets the lower bound)")