In [None]:
import os
import sys
sys.path.append('..')

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import yaml

from src.data.dataset import EpisodicDermaMNIST
from src.data.preprocessing import DermaMNISTPreprocessor

print("Imports successful!")

# Load the training split; download=True is forwarded to EpisodicDermaMNIST
# (fetching behaviour is defined in src.data.dataset).
train_split = "train"
dataset = EpisodicDermaMNIST(split=train_split, download=True)

n_images = len(dataset.images)
print(f"Dataset loaded: {n_images} images")
print("Classes: {}".format(dataset.class_names))

# Bar chart of per-class sample counts in the training split. Classes with
# fewer than MINORITY_THRESHOLD samples are highlighted in coral so the
# under-represented classes stand out. The figure is also saved to disk.
MINORITY_THRESHOLD = 200

fig, ax = plt.subplots(figsize=(12, 6))

# Sort the keys so bars appear in class-index order regardless of the order
# in which dataset.class_counts was populated.
classes = sorted(dataset.class_counts)
counts = [dataset.class_counts[c] for c in classes]

bars = ax.bar(range(len(classes)), counts, color='steelblue', alpha=0.7)

for bar, count in zip(bars, counts):
    if count < MINORITY_THRESHOLD:
        bar.set_color('coral')
        bar.set_alpha(0.9)
    # Annotate each bar with its exact count, centred just above its top.
    ax.text(bar.get_x() + bar.get_width() / 2., bar.get_height(),
            f'{count}',
            ha='center', va='bottom', fontsize=10)

ax.set_xlabel('Class', fontsize=12)
ax.set_ylabel('Number of Samples', fontsize=12)
ax.set_title('DermaMNIST Training Set Class Distribution', fontsize=14, fontweight='bold')
ax.set_xticks(range(len(classes)))
ax.set_xticklabels([f"Class {c}" for c in classes], rotation=45, ha='right')
ax.grid(axis='y', alpha=0.3)

fig.tight_layout()
# Create the output directory if needed; savefig does not create it.
out_path = '../experiments/class_distribution.png'
os.makedirs(os.path.dirname(out_path), exist_ok=True)
fig.savefig(out_path, dpi=150)
plt.show()

# Sample a single N-way K-shot episode and visualise its support set.
# The plot grid and title are derived from the episode parameters, so the
# figure stays consistent if these values change.
n_way, k_shot, n_query = 5, 5, 15

support_img, support_lbl, query_img, query_lbl = dataset.sample_episode(
    n_way=n_way, k_shot=k_shot, n_query=n_query
)

print("Episode sampled:")
print(f"  Support: {support_img.shape}")
print(f"  Query: {query_img.shape}")
# NOTE(review): .numpy() assumes support_lbl is a CPU tensor — confirm.
print(f"  Classes in episode: {np.unique(support_lbl.numpy())}")

preprocessor = DermaMNISTPreprocessor()

# One row per class, one column per shot. squeeze=False keeps `axes` 2-D
# even when n_way or k_shot is 1, so `axes.flat` is always valid.
fig, axes = plt.subplots(n_way, k_shot, figsize=(2.4 * k_shot, 2.4 * n_way),
                         squeeze=False)
fig.suptitle(f'{n_way}-Way {k_shot}-Shot Support Set', fontsize=16, fontweight='bold')

for i, ax in enumerate(axes.flat):
    # Hide the axes on every panel. This includes panels that stay empty
    # when the episode returns fewer support images than grid cells.
    ax.axis('off')
    if i >= len(support_img):
        continue
    img = preprocessor.denormalize(support_img[i])
    label = support_lbl[i].item()
    ax.imshow(img)
    ax.set_title(f"Class {label}", fontsize=10)

fig.tight_layout()
# Create the output directory if needed; savefig does not create it.
out_path = '../experiments/sample_episode.png'
os.makedirs(os.path.dirname(out_path), exist_ok=True)
fig.savefig(out_path, dpi=150)
plt.show()

print("\nData exploration complete!")