# 01 — Data Exploration

Explore the HAM10000 skin lesion dataset:
- Class distribution
- Image dimensions & statistics
- Sample visualisations per class
- Metadata analysis (age, sex and lesion localisation)
- Train/val/test split check

In [None]:
import sys, os
sys.path.insert(0, os.path.abspath('..'))

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from PIL import Image
from collections import Counter

from src.config import (
    RAW_DIR, PROCESSED_DIR, SPLITS_DIR,
    MALIGNANT_CLASSES, BENIGN_CLASSES, IDX_TO_LABEL,
)

# Original 7-class HAM10000 diagnosis labels (for raw data exploration).
# Built as the sorted union of the benign and malignant class collections
# imported from src.config, so the class order used by the plots below is
# deterministic (the `|` assumes both are sets — TODO confirm in src/config.py).
DX_CLASSES = sorted(BENIGN_CLASSES | MALIGNANT_CLASSES)

# Use seaborn's 'whitegrid' theme for the figures in this notebook
sns.set_theme(style='whitegrid')
# IPython magic: render matplotlib figures inline in the notebook output
%matplotlib inline

## 1. Load metadata

In [None]:
# Location of the HAM10000 metadata CSV — adjust if your copy lives elsewhere
META_CSV = RAW_DIR / 'HAM10000_metadata.csv'

# Read the whole metadata table into memory and report its row count
df = pd.read_csv(META_CSV)
print(f'Total samples: {df.shape[0]}')

# Preview the first rows (displayed as the cell output)
df.head()

In [None]:
# Print the column dtypes and non-null counts of the metadata table
df.info()
# Blank line to separate the info printout from the table displayed below
print()
# Summary statistics; as the cell's last expression, this is what the notebook displays
df.describe()

## 2. Class distribution

In [None]:
# Per-class sample counts, most frequent first. Computed once and reused by
# both plots and the printed summary so they are guaranteed to agree.
class_counts = df['dx'].value_counts()
order = class_counts.index
n_dx = len(class_counts)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Count plot.
# seaborn >= 0.13 deprecates passing `palette` without `hue`, so the class
# column is mapped to hue as well; dodge=False keeps one full-width bar per
# class on older seaborn releases that would otherwise dodge by hue.
sns.countplot(data=df, y='dx', hue='dx', order=order, hue_order=order,
              dodge=False, palette='viridis', ax=axes[0])
# A hue legend would only repeat the axis tick labels; drop it if one was drawn.
count_legend = axes[0].get_legend()
if count_legend is not None:
    count_legend.remove()
axes[0].set_title('Class distribution (count)')
axes[0].set_xlabel('Count')

# Pie chart (same ordering and palette as the count plot)
class_counts.plot.pie(autopct='%1.1f%%', ax=axes[1],
                      colors=sns.color_palette('viridis', n_dx))
axes[1].set_ylabel('')
axes[1].set_title('Class distribution (%)')

plt.tight_layout()
plt.show()

print(class_counts)

## 3. Image dimensions

In [None]:
# Sample a subset for speed
sample_ids = df['image_id'].sample(min(500, len(df)), random_state=42)
widths, heights = [], []

for img_id in sample_ids:
    path = RAW_DIR / f'{img_id}.jpg'
    if not path.exists():
        continue
    # Image.open is lazy (it only identifies the file), so reading .size is
    # cheap; the context manager makes sure the file handle is closed again.
    with Image.open(path) as img:
        w, h = img.size
    widths.append(w)
    heights.append(h)

if not widths:
    # Without this guard min()/max() below would fail with an unhelpful
    # "arg is an empty sequence" error when the images are not under RAW_DIR.
    print(f'No sampled images found in {RAW_DIR} (expected <image_id>.jpg files) — skipping size statistics')
else:
    print(f'Width  — min: {min(widths)}, max: {max(widths)}, mean: {np.mean(widths):.0f}')
    print(f'Height — min: {min(heights)}, max: {max(heights)}, mean: {np.mean(heights):.0f}')

    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    axes[0].hist(widths, bins=30, color='steelblue', edgecolor='white')
    axes[0].set_title('Image widths')
    axes[1].hist(heights, bins=30, color='coral', edgecolor='white')
    axes[1].set_title('Image heights')
    plt.tight_layout()
    plt.show()

## 4. Sample images per class

In [None]:
# Grid of example images: one row per diagnosis class, up to 5 images per row.
n_cols = 5
# squeeze=False keeps `axes` two-dimensional even if there is only one class row.
fig, axes = plt.subplots(len(DX_CLASSES), n_cols,
                         figsize=(15, 3 * len(DX_CLASSES)), squeeze=False)

for row, cls in enumerate(DX_CLASSES):
    # Hide the frame/ticks of every cell in the row up front, so cells that
    # end up without an image (class has < 5 samples, or the file is missing)
    # stay blank instead of showing an empty axes.
    for ax in axes[row]:
        ax.axis('off')

    subset = df[df['dx'] == cls]
    if subset.empty:
        axes[row, 0].set_title(f'{cls} (no samples)', fontsize=14, fontweight='bold')
        continue

    # The class name is shown once, above the first image of the row.
    axes[row, 0].set_title(cls, fontsize=14, fontweight='bold')

    samples = subset.sample(min(n_cols, len(subset)), random_state=42)
    for col, image_id in enumerate(samples['image_id']):
        path = RAW_DIR / f"{image_id}.jpg"
        if path.exists():
            # imshow copies the pixel data, so the file can be closed right away.
            with Image.open(path) as img:
                axes[row, col].imshow(img)

plt.suptitle('Sample images per class (HAM10000 dx labels)', fontsize=16, y=1.01)
plt.tight_layout()
plt.show()

## 5. Metadata analysis

In [None]:
# NOTE: seaborn >= 0.13 deprecates `palette` without `hue`. The categorical
# column is therefore also passed as hue, with dodge=False so older seaborn
# releases still draw one full-width element per category, and any legend
# (which would only repeat the tick labels) is removed.

# Age distribution by class
if 'age' in df.columns:
    plt.figure(figsize=(12, 5))
    age_ax = sns.boxplot(data=df, x='dx', y='age', hue='dx',
                         order=DX_CLASSES, hue_order=DX_CLASSES,
                         dodge=False, palette='viridis')
    age_legend = age_ax.get_legend()
    if age_legend is not None:
        age_legend.remove()
    plt.title('Age distribution per class')
    plt.show()

# Sex distribution
if 'sex' in df.columns:
    fig, ax = plt.subplots(figsize=(10, 5))
    # Build the lookup once instead of scanning the whole column per class.
    seen_classes = set(df['dx'].dropna().unique())
    present_classes = [c for c in DX_CLASSES if c in seen_classes]
    # crosstab drops rows whose `sex` is missing, so a class can be present in
    # `dx` yet absent from the table; reindex shows it as zeros where .loc
    # would raise KeyError.
    sex_counts = pd.crosstab(df['dx'], df['sex']).reindex(present_classes, fill_value=0)
    sex_counts.plot(kind='bar', ax=ax)
    plt.title('Sex distribution per class')
    plt.tight_layout()
    plt.show()

# Localisation
if 'localization' in df.columns:
    plt.figure(figsize=(12, 5))
    top_locations = df['localization'].value_counts().index[:10]
    loc_ax = sns.countplot(data=df, y='localization', hue='localization',
                           order=top_locations, hue_order=top_locations,
                           dodge=False, palette='viridis')
    loc_legend = loc_ax.get_legend()
    if loc_legend is not None:
        loc_legend.remove()
    plt.title('Top 10 lesion locations')
    plt.tight_layout()
    plt.show()

## 6. Check splits

In [None]:
# Report the size and per-label sample counts of each split CSV, if it exists.
for split in ['train', 'val', 'test']:
    path = SPLITS_DIR / f'{split}.csv'
    if not path.exists():
        print(f'{split:>5}: NOT FOUND  (run src/data/split_data.py first)')
        continue

    sdf = pd.read_csv(path)
    label_counts = sdf['label'].value_counts().sort_index()

    label_named = {}
    for idx, count in label_counts.items():
        try:
            name = IDX_TO_LABEL[idx]
        except (KeyError, IndexError, TypeError):
            # Surface labels the mapping does not know instead of aborting
            # the check for the remaining splits.
            name = f'unknown({idx})'
        # Plain int: NumPy 2 scalars would otherwise print as `np.int64(...)`.
        label_named[name] = int(count)

    print(f'{split:>5}: {len(sdf)} samples — {label_named}')