# Crop Disease Classification

**Objective:** Build a machine learning model to classify plant diseases from leaf images using transfer learning with PyTorch.  
**Dataset:** PlantVillage (12 selected classes across Tomato, Potato, and Pepper crops)  
**Models:** ResNet-50, EfficientNet-B0, MobileNetV3-Small  
**Author:** Santosh Shinde  
**Date:** February 2026

---
## 0. Setup
Import all modules, set seed, detect device, print environment.

In [None]:
# Standard library
import json
import os
import sys
from collections import Counter
from pathlib import Path

# Add project root to path
PROJECT_ROOT = Path.cwd().parent if Path.cwd().name == 'notebooks' else Path.cwd()
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

# Third party
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torchvision
from PIL import Image
from sklearn.metrics import confusion_matrix as sk_confusion_matrix

# Project modules
from src.config import Config, DataConfig, TrainConfig, ModelConfig
from src.utils.seed import set_seed
from src.data.dataset import PlantDiseaseDataset
from src.data.transforms import get_train_transforms, get_val_transforms
from src.data.splitter import create_stratified_split
from src.data.loader import create_dataloaders, SplitDataset
from src.utils.text_helpers import shorten_class_name, get_crop_name
from src.utils.plot_data import (
    plot_sample_images, plot_class_distribution, plot_augmentation_examples,
)
from src.utils.plot_training import plot_training_curves, plot_model_comparison
from src.models.factory import get_model, count_parameters, get_differential_lr_params
from src.models.freeze import freeze_backbone, partial_unfreeze, full_unfreeze
from src.training.trainer import Trainer
from src.evaluation.metrics import (
    compute_predictions, generate_classification_report, compute_summary_metrics
)
from src.evaluation.confusion import plot_confusion_matrix
from src.evaluation.predictions import get_prediction_examples, plot_prediction_grid
from src.evaluation.profiler import profile_model

# Matplotlib settings
plt.rcParams['figure.dpi'] = 100
plt.rcParams['savefig.dpi'] = 150
sns.set_style('whitegrid')
%matplotlib inline

print(f'Python: {sys.version}')
print(f'PyTorch: {torch.__version__}')
print(f'Torchvision: {torchvision.__version__}')
print(f'NumPy: {np.__version__}')
print(f'Project Root: {PROJECT_ROOT}')

In [None]:
# Seed the RNGs through the project's helper before any splitting or model init
set_seed(42)


def _select_device() -> torch.device:
    """Return the preferred backend (CUDA, then Apple MPS, then CPU), announcing the choice."""
    if torch.cuda.is_available():
        print(f'Using CUDA: {torch.cuda.get_device_name(0)}')
        return torch.device('cuda')
    if torch.backends.mps.is_available():
        print('Using Apple MPS (Metal Performance Shaders)')
        return torch.device('mps')
    print('Using CPU')
    return torch.device('cpu')


DEVICE = _select_device()
print(f'Device: {DEVICE}')

In [None]:
# Initialize configuration
config = Config()

# Candidate dataset locations in priority order -- adjust for your setup.
_dataset_candidates = [
    PROJECT_ROOT.parent / 'PlantVillage Dataset' / 'PlantVillage',
    PROJECT_ROOT / 'data' / 'raw' / 'PlantVillage',
    Path('../PlantVillage Dataset/PlantVillage'),
]
# First existing candidate wins; if none exists, keep the primary location so the
# "Dataset exists" line below surfaces the problem.
DATASET_ROOT = next(
    (candidate for candidate in _dataset_candidates if candidate.exists()),
    _dataset_candidates[0],
)

config.data.raw_data_dir = DATASET_ROOT
print(f'Dataset root: {DATASET_ROOT}')
print(f'Dataset exists: {DATASET_ROOT.exists()}')

_selected = config.data.selected_classes
print(f'Selected classes ({len(_selected)}):')
for selected_name in _selected:
    print(f'  - {selected_name}')

---
## 1. Data Exploration (Part 1)

Load and visualize the dataset, analyze class distribution, and extract key insights.

In [None]:
# Exploration copy of the dataset: transform=None leaves images untransformed
# so they can be displayed directly.
dataset = PlantDiseaseDataset(
    root_dir=config.data.raw_data_dir,
    selected_classes=config.data.selected_classes,
    transform=None,
)

print(f'Total images: {len(dataset):,}')
print(f'Number of classes: {dataset.num_classes}')
print('\nClass-to-Index mapping:')
for class_name, class_idx in dataset.class_to_idx.items():
    print(f'  {class_idx:2d}: {class_name}')

In [None]:
# Per-class image counts, summarised to quantify class imbalance
class_counts = dataset.get_class_counts()
counts_series = pd.Series(class_counts)
smallest, largest = counts_series.min(), counts_series.max()

print('\nPer-class statistics:')
print(f'  Total images: {counts_series.sum():,}')
print(f'  Min count:    {smallest:,} ({counts_series.idxmin()})')
print(f'  Max count:    {largest:,} ({counts_series.idxmax()})')
print(f'  Mean count:   {counts_series.mean():,.0f}')
print(f'  Std count:    {counts_series.std():,.0f}')
print(f'  Imbalance ratio (max/min): {largest / smallest:.1f}x')

# One row per class, largest first; the bare name on the last line makes
# Jupyter render the DataFrame.
df_counts = (
    pd.DataFrame([
        {
            'Class': shorten_class_name(full_name),
            'Full Name': full_name,
            'Count': count,
            'Crop': get_crop_name(full_name),
        }
        for full_name, count in class_counts.items()
    ])
    .sort_values('Count', ascending=False)
    .reset_index(drop=True)
)
df_counts

In [None]:
# Sample grid: 5 classes x 5 images each, also written to outputs/
outputs_dir = PROJECT_ROOT / 'outputs'
outputs_dir.mkdir(parents=True, exist_ok=True)

fig = plot_sample_images(
    dataset,
    save_path=outputs_dir / 'sample_images_grid.png',
    num_classes=5,
    images_per_class=5,
    figsize=(20, 20),
)
plt.show()

In [None]:
# Bar chart of images per class, saved alongside the other figures
distribution_png = PROJECT_ROOT / 'outputs' / 'class_distribution.png'
fig = plot_class_distribution(class_counts, save_path=distribution_png, figsize=(12, 8))
plt.show()

### Key Insights from Data Exploration

**Insight 1 — Class Imbalance:**  
There is significant class imbalance in the dataset. The largest class (Tomato Bacterial Spot with ~2,127 images) has substantially more samples than the smallest class (Potato Healthy with ~152 images), yielding an imbalance ratio of approximately 14:1. This imbalance could bias the model toward majority classes, making stratified splitting and potentially class-weighted loss important considerations.

**Insight 2 — Visual Similarity Across Crops:**  
Diseases that occur across multiple crops (e.g., Early Blight on both Tomato and Potato, Bacterial Spot on Tomato and Pepper) share similar visual patterns — yellowing, spotting, and necrotic lesions. This cross-crop visual similarity is a key classification challenge that tests whether the model learns crop-specific vs. disease-specific features. Confusion between these pairs is expected and informative.

**Insight 3 — Lab-Controlled Image Quality:**  
PlantVillage images are captured under controlled laboratory conditions with uniform backgrounds and consistent lighting. While this produces a clean signal for training, it creates a significant domain gap with real-world field photography where images may include variable lighting, complex backgrounds, partial leaf occlusion, and motion blur. This limitation should be addressed through aggressive data augmentation and acknowledged as a deployment consideration.

In [None]:
# Visual similarity insight — show Early Blight comparison across crops.
# Each row: (label drawn left of the row, PlantVillage class folder name).
similarity_rows = [
    ('Tomato\nEarly Blight', 'Tomato_Early_blight'),
    ('Potato\nEarly Blight', 'Potato___Early_blight'),
]
n_per_row = 4

# squeeze=False keeps `axes` 2-D even if only one row is configured.
fig, axes = plt.subplots(len(similarity_rows), n_per_row, figsize=(16, 8), squeeze=False)
fig.suptitle('Visual Similarity: Same Disease Across Crops', fontsize=14)

for row_axes, (row_label, class_name) in zip(axes, similarity_rows):
    class_idx = dataset.class_to_idx.get(class_name)
    if class_idx is None:
        # Never substitute another class: the row label would then describe
        # images of a different disease.
        print(f'WARNING: class {class_name!r} not found in dataset; leaving row empty')
    else:
        sample_indices = [
            i for i, (_, label) in enumerate(dataset.samples) if label == class_idx
        ][:n_per_row]
        for ax, sample_idx in zip(row_axes, sample_indices):
            ax.imshow(dataset[sample_idx][0])

    # Hide ticks and frame instead of calling axis('off'): axis('off') also
    # hides the y-axis label, which is where the row label is drawn.
    for ax in row_axes:
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
    row_axes[0].set_ylabel(row_label, fontsize=11, rotation=0, labelpad=80)

plt.tight_layout()
plt.show()

---
## 2. Data Pipeline (Part 2a)

Create stratified train/val/test splits, show augmentation examples, and build DataLoaders.

In [None]:
# Create stratified train/val/test splits, seeded from the config
splits = create_stratified_split(
    samples=dataset.samples,
    split_ratios=config.data.split_ratios,
    seed=config.data.random_seed,
)

total_images = len(dataset)
print('Split sizes:')
for split_name, samples in splits.items():
    print(f'  {split_name:5s}: {len(samples):,} images ({len(samples)/total_images*100:.1f}%)')

# Verify stratification: class proportions should be near-identical across
# splits. Only the first three class indices are shown to keep output compact.
print('\nStratification verification (class proportions):')
for split_name, samples in splits.items():
    label_counts = Counter(s[1] for s in samples)
    total = len(samples)
    print(f'  {split_name}: ', end='')
    for idx in sorted(label_counts.keys())[:3]:
        print(f'{dataset.idx_to_class[idx][:15]}={label_counts[idx]/total:.2%} ', end='')
    print('...')

In [None]:
# Training (augmenting) and evaluation transforms at the configured image size
image_size = config.data.image_size
train_transform = get_train_transforms(image_size)
val_transform = get_val_transforms(image_size)

# Demo: apply the training transform several times to the first sample's image
demo_image_path = str(dataset.samples[0][0])
fig = plot_augmentation_examples(demo_image_path, train_transform, num_augmented=5)
plt.show()

In [None]:
# Build DataLoaders for every split from the stratified sample lists
train_cfg = config.train
dataloaders = create_dataloaders(
    splits=splits,
    train_transform=train_transform,
    val_transform=val_transform,
    batch_size=train_cfg.batch_size,
    num_workers=train_cfg.num_workers,
    pin_memory=train_cfg.pin_memory,
)

# Smoke test: pull one batch per split and report tensor shapes and dtype
for split_name, loader in dataloaders.items():
    batch_images, batch_labels = next(iter(loader))
    print(f'{split_name:5s}: images={batch_images.shape}, labels={batch_labels.shape}, dtype={batch_images.dtype}')

---
## 3. Model Training (Part 2b)

Train three models using three-stage progressive fine-tuning:
1. **Stage 1** -- Feature Extraction (frozen backbone, head only)
2. **Stage 2** -- Adaptation (partial unfreeze + head)
3. **Stage 3** -- Full Refinement (all parameters, differential LR)

In [None]:
def _best_f1(history: dict) -> float:
    """Return the best validation F1 in a stage history, or 0.0 if none was recorded."""
    return max(history['val_f1']) if history['val_f1'] else 0.0


def _print_trainable(model) -> None:
    """Print trainable vs. total parameter counts after a freeze/unfreeze step."""
    params = count_parameters(model)
    print(f'Trainable: {params["trainable"]:,} / {params["total"]:,} ({params["trainable"]/params["total"]*100:.1f}%)')


def _build_trainer(model, model_name: str, config: Config, device: torch.device, **optimizer_kwargs):
    """Create a Trainer with the settings shared by every stage.

    ``optimizer_kwargs`` carries the stage-specific optimizer setup: either
    ``learning_rate=...`` (stages 1-2) or ``param_groups=...`` (stage 3).
    """
    return Trainer(
        model=model,
        num_classes=config.model.num_classes,
        weight_decay=config.train.weight_decay,
        label_smoothing=config.train.label_smoothing,
        device=device,
        checkpoint_dir=config.train.checkpoint_dir,
        model_name=model_name,
        max_grad_norm=config.train.max_grad_norm,
        use_amp=config.train.use_amp,
        **optimizer_kwargs,
    )


def _run_stage(trainer, dataloaders: dict, num_epochs: int, patience: int, combined_history: dict) -> dict:
    """Fit one stage, append its per-epoch history to ``combined_history`` in place, and return it."""
    history = trainer.fit(
        dataloaders['train'], dataloaders['val'],
        num_epochs=num_epochs,
        scheduler_type='cosine',
        patience=patience,
    )
    for key in combined_history:
        combined_history[key].extend(history[key])
    return history


def train_model_three_stages(
    model_name: str,
    config: Config,
    dataloaders: dict,
    device: torch.device,
    *,
    head_lr_multiplier: float = 5.0,
) -> dict:
    """Train a model through all three progressive fine-tuning stages.

    Stage 1 trains with the backbone frozen, stage 2 after a partial unfreeze,
    and stage 3 with every parameter trainable under differential learning
    rates. A fresh Trainer is built per stage so its optimizer covers the
    parameters that stage unfroze.

    Args:
        model_name: Architecture key understood by ``get_model`` and the freeze helpers.
        config: Project configuration (model, train sections are read).
        dataloaders: Mapping with at least ``'train'`` and ``'val'`` loaders.
        device: Device the model is moved to and trained on.
        head_lr_multiplier: Stage-3 head learning rate as a multiple of
            ``config.train.stage3_lr`` (the backbone uses ``stage3_lr`` itself).

    Returns:
        Per-epoch history (train/val loss and accuracy, val F1, lr, epoch time)
        concatenated across the three stages.
    """
    print(f'\n{"="*60}')
    print(f'Training {model_name.upper()}')
    print(f'{"="*60}')

    model = get_model(
        name=model_name,
        num_classes=config.model.num_classes,
        pretrained=config.model.pretrained,
        dropout=config.model.dropout,
    ).to(device)

    combined_history = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': [], 'val_f1': [],
        'lr': [], 'epoch_time': [],
    }
    patience = config.train.early_stopping_patience

    # ---- STAGE 1: Feature Extraction ----
    print('\n--- Stage 1: Feature Extraction (frozen backbone) ---')
    freeze_backbone(model, model_name)
    _print_trainable(model)
    trainer = _build_trainer(model, model_name, config, device, learning_rate=config.train.stage1_lr)
    history1 = _run_stage(trainer, dataloaders, config.train.stage1_epochs, patience, combined_history)

    # ---- STAGE 2: Adaptation ----
    print('\n--- Stage 2: Adaptation (partial unfreeze) ---')
    partial_unfreeze(model, model_name)
    _print_trainable(model)
    trainer = _build_trainer(model, model_name, config, device, learning_rate=config.train.stage2_lr)
    # Carry over the best F1 so far so this stage is measured against stage 1
    # (presumably this also gates Trainer's best-checkpoint saving -- confirm).
    trainer.best_val_f1 = _best_f1(history1)
    history2 = _run_stage(trainer, dataloaders, config.train.stage2_epochs, patience, combined_history)

    # ---- STAGE 3: Full Refinement ----
    print('\n--- Stage 3: Full refinement (all parameters, differential LR) ---')
    full_unfreeze(model)
    _print_trainable(model)
    # Differential LR: the head learns faster than the pretrained backbone
    param_groups = get_differential_lr_params(
        model, model_name,
        backbone_lr=config.train.stage3_lr,
        head_lr=config.train.stage3_lr * head_lr_multiplier,
    )
    trainer = _build_trainer(model, model_name, config, device, param_groups=param_groups)
    # Carry over the best F1 across stages 1 and 2
    trainer.best_val_f1 = max(_best_f1(history1), _best_f1(history2))
    _run_stage(trainer, dataloaders, config.train.stage3_epochs, patience, combined_history)

    print(f'\n{model_name.upper()} training complete!')
    print(f'Best val F1: {trainer.best_val_f1:.4f}')

    return combined_history

In [None]:
# Train each architecture with the three-stage schedule; keep its concatenated
# per-epoch history keyed by model name.
model_names = ['resnet50', 'efficientnet_b0', 'mobilenetv3']
all_histories = {}
for name in model_names:
    all_histories[name] = train_model_three_stages(name, config, dataloaders, DEVICE)

In [None]:
# Stage boundaries for the plot: cumulative epoch counts after stages 1 and 2.
# NOTE(review): these use the *configured* epoch budgets. Trainer.fit receives
# patience=early_stopping_patience, so if a stage stops early the concatenated
# history is shorter and these markers will be misplaced -- confirm.
stage1_end = config.train.stage1_epochs
stage2_end = stage1_end + config.train.stage2_epochs
stage_boundaries = [stage1_end, stage2_end]

curves_png = PROJECT_ROOT / 'outputs' / 'training_curves.png'
fig = plot_training_curves(
    all_histories,
    stage_boundaries=stage_boundaries,
    figsize=(16, 12),
    save_path=curves_png,
)
plt.show()

---
## 4. Evaluation (Part 3a)

Load best checkpoints, generate predictions on test set, and analyze performance.

In [None]:
# Evaluate every trained model on the held-out test split
all_metrics = {}
all_predictions = {}
# Index-ordered class names for the classification report (loop-invariant)
class_names = [dataset.idx_to_class[i] for i in range(config.model.num_classes)]
rule = '=' * 60

for model_name in model_names:
    print(f'\n{rule}')
    print(f'Evaluating {model_name.upper()}')
    print(rule)

    # Rebuild the architecture without pretrained weights; the best checkpoint,
    # when present, supplies the trained weights.
    model = get_model(
        name=model_name,
        num_classes=config.model.num_classes,
        pretrained=False,
        dropout=config.model.dropout,
    )

    checkpoint_path = config.train.checkpoint_dir / f'{model_name}_best.pth'
    if not checkpoint_path.exists():
        # Evaluation still proceeds, so these metrics would come from weights
        # that were not loaded from a checkpoint.
        print(f'WARNING: No checkpoint found at {checkpoint_path}')
    else:
        checkpoint = torch.load(checkpoint_path, map_location=DEVICE, weights_only=True)
        model.load_state_dict(checkpoint['model_state_dict'])
        print(f'Loaded checkpoint from epoch {checkpoint["epoch"]} (val_f1={checkpoint["val_f1"]:.4f})')

    model = model.to(DEVICE).eval()

    preds, labels, probs = compute_predictions(model, dataloaders['test'], DEVICE)
    all_predictions[model_name] = (preds, labels, probs)

    report = generate_classification_report(labels, preds, class_names)
    print('\nClassification Report:')
    print(report)

    metrics = compute_summary_metrics(labels, preds)
    all_metrics[model_name] = metrics
    print(f'Summary: Accuracy={metrics["accuracy"]:.4f}, F1 Macro={metrics["f1_macro"]:.4f}, F1 Weighted={metrics["f1_weighted"]:.4f}')

In [None]:
# Normalized confusion matrix for each model, saved to outputs/
class_names = [dataset.idx_to_class[i] for i in range(config.model.num_classes)]

for model_name in model_names:
    # Slice instead of unpacking into `_`, which would clobber IPython's last-output variable
    preds, labels = all_predictions[model_name][:2]
    fig = plot_confusion_matrix(
        labels, preds, class_names,
        normalize=True,
        save_path=PROJECT_ROOT / 'outputs' / f'confusion_matrix_{model_name}.png',
        figsize=(12, 10),
    )
    # Title the figure we were handed rather than pyplot's implicit "current" figure.
    # NOTE(review): the PNG is written inside plot_confusion_matrix before this
    # title is added, so the saved file presumably lacks it -- confirm.
    fig.suptitle(f'Confusion Matrix -- {model_name}', fontsize=14)
    plt.show()

In [None]:
# Pick the model with the highest test macro-F1, then collect 5 correct and
# 5 incorrect test predictions from it.
best_model_name, best_metrics = max(all_metrics.items(), key=lambda item: item[1]['f1_macro'])
print(f'Best model: {best_model_name} (F1 Macro: {best_metrics["f1_macro"]:.4f})')

best_model = get_model(
    name=best_model_name,
    num_classes=config.model.num_classes,
    pretrained=False,
    dropout=config.model.dropout,
)
best_checkpoint_path = config.train.checkpoint_dir / f'{best_model_name}_best.pth'
checkpoint = torch.load(best_checkpoint_path, map_location=DEVICE, weights_only=True)
best_model.load_state_dict(checkpoint['model_state_dict'])
best_model = best_model.to(DEVICE).eval()

# Test split wrapped with the same transform the validation loader uses
test_dataset = SplitDataset(splits['test'], transform=val_transform)

correct_examples, incorrect_examples = get_prediction_examples(
    best_model, test_dataset, dataset.idx_to_class, DEVICE,
    num_correct=5, num_incorrect=5,
)

print(f'\nCollected {len(correct_examples)} correct and {len(incorrect_examples)} incorrect examples')

In [None]:
# Grid of correctly classified test examples from the best model
fig = plot_prediction_grid(
    correct_examples,
    save_path=PROJECT_ROOT / 'outputs' / 'correct_predictions.png',
    title='Correct Predictions',
)
plt.show()

In [None]:
# Grid of misclassified test examples from the best model
fig = plot_prediction_grid(
    incorrect_examples,
    save_path=PROJECT_ROOT / 'outputs' / 'incorrect_predictions.png',
    title='Incorrect Predictions',
)
plt.show()

In [None]:
# Error analysis -- most confused class pairs for the best model
TOP_K = 5
preds, labels = all_predictions[best_model_name][:2]

# Pass the full label range so row/column i always corresponds to class i.
# Without `labels=`, sklearn builds the matrix only over labels that appear in
# y_true or y_pred, which would shift indices relative to `class_names`.
cm = sk_confusion_matrix(labels, preds, labels=list(range(len(class_names))))

# Work on a copy to avoid modifying the original matrix
cm_off_diag = cm.copy()
np.fill_diagonal(cm_off_diag, 0)

# (true_idx, pred_idx, count) for every non-zero off-diagonal cell, row-major;
# the stable sort keeps that order among equal counts.
rows, cols = np.nonzero(cm_off_diag)
confused_pairs = [(int(i), int(j), int(cm_off_diag[i, j])) for i, j in zip(rows, cols)]
confused_pairs.sort(key=lambda pair: pair[2], reverse=True)

print(f'\nTop {TOP_K} Most Confused Class Pairs ({best_model_name}):')
print(f'{"True Class":<35} {"Predicted As":<35} {"Count":<10}')
print('-' * 80)
for true_idx, pred_idx, count in confused_pairs[:TOP_K]:
    true_name = shorten_class_name(class_names[true_idx])
    pred_name = shorten_class_name(class_names[pred_idx])
    print(f'{true_name:<35} {pred_name:<35} {count:<10}')

---
## 5. Model Comparison (Part 3b)

Profile all models and compare on accuracy, F1, model size, and latency.

In [None]:
# Profile every model (size, parameter count, latency) and join with test metrics
comparison_data = []
METRIC_KEYS = ('accuracy', 'f1_macro', 'f1_weighted')
# NOTE(review): profile_model is called with device=DEVICE, yet the latency keys
# are named cpu_latency_*; if DEVICE is a GPU, confirm the timing is really CPU.
PROFILE_KEYS = ('model_size_mb', 'total_params', 'cpu_latency_mean_ms', 'cpu_latency_p95_ms')

for model_name in model_names:
    print(f'\nProfiling {model_name}...')

    model = get_model(
        name=model_name,
        num_classes=config.model.num_classes,
        pretrained=False,
        dropout=config.model.dropout,
    )
    checkpoint_path = config.train.checkpoint_dir / f'{model_name}_best.pth'
    if checkpoint_path.exists():
        state = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
        model.load_state_dict(state['model_state_dict'])

    profile = profile_model(model, device=DEVICE, num_warmup=5, num_runs=50)

    # Column order: model name, then metrics, then profile figures
    row = {'model': model_name}
    row.update({key: all_metrics[model_name][key] for key in METRIC_KEYS})
    row.update({key: profile[key] for key in PROFILE_KEYS})
    comparison_data.append(row)

comparison_df = pd.DataFrame(comparison_data)
divider = '=' * 80
print('\n' + divider)
print('MODEL COMPARISON')
print(divider)
print(comparison_df.to_string(index=False))

In [None]:
# Side-by-side comparison charts built from the profiling table
comparison_png = PROJECT_ROOT / 'outputs' / 'model_comparison.png'
fig = plot_model_comparison(comparison_df, save_path=comparison_png, figsize=(14, 6))
plt.show()

---
## 6. Business Recommendation (Part 3c)

### Which model to deploy for a mobile app for farmers?

**Recommendation: EfficientNet-B0**

For a mobile crop disease detection app, we recommend **EfficientNet-B0** as the deployment model based on the following analysis:

| Criterion | ResNet-50 | EfficientNet-B0 | MobileNetV3-Small |
|-----------|-----------|-----------------|--------------------|
| Model Size | ~98 MB | ~20 MB | ~10 MB |
| TFLite INT8 | ~25 MB | ~5 MB | ~2.5 MB |
| CPU Inference | Slowest | Moderate | Fastest |
| Accuracy | Highest | High (within 1-2%) | Lower (2-3% drop) |

**Why not ResNet-50?** At ~98 MB, it is too large for mobile deployment. The marginal accuracy gain (~1-2%, per the table above) does not justify the ~5x larger model size and roughly 3x higher inference latency on mobile devices.

**Why not MobileNetV3-Small?** While ultra-lightweight (~10 MB), its 2-3% accuracy drop translates to more missed or misidentified diseases in production. In agriculture, a false negative (missing a disease) can lead to crop loss from an untreated infection, and a misdiagnosis can lead to incorrect treatment; both directly impact farmer outcomes.

**EfficientNet-B0 is the sweet spot:** It achieves competitive accuracy while being small enough for mobile deployment (~5 MB after INT8 quantization). The deployment pipeline would be: PyTorch (.pth) -> ONNX (.onnx) -> TFLite (.tflite) with INT8 quantization, running entirely on-device for offline functionality.

### Confidence Thresholding Strategy
Predictions with confidence below 70% should be suppressed and replaced with a message: "Low confidence. Please retake the photo with better lighting and a clear view of the leaf." This reduces the risk of the worst UX failure: an uncertain prediction presented as a definitive diagnosis, leading to incorrect crop treatment. It does not catch predictions that are wrong yet highly confident.

### Known Limitations
1. **Lab-to-field domain gap**: Model trained on clean lab images may underperform on messy field photos
2. **Single-disease assumption**: Cannot detect co-infections
3. **Limited crop coverage**: Only 3 crops / 12 classes (same pipeline scales to more)
4. **No severity grading**: Detects disease type but not progression stage

---
## 7. Export

Save best model checkpoint, class mapping, and all figures.

In [None]:
# Export the index -> class-name mapping next to the checkpoints.
# Keys are strings because JSON object keys must be strings.
class_mapping = {str(i): class_label for i, class_label in dataset.idx_to_class.items()}

ckpt_dir = config.train.checkpoint_dir
ckpt_dir.mkdir(parents=True, exist_ok=True)

mapping_path = ckpt_dir / 'class_mapping.json'
mapping_path.write_text(json.dumps(class_mapping, indent=2))
print(f'Class mapping saved to {mapping_path}')

# Echo the mapping in numeric (not lexicographic) index order
print('\nClass Mapping:')
for idx, name in sorted(class_mapping.items(), key=lambda kv: int(kv[0])):
    print(f'  {idx}: {name}')

In [None]:
# Save the run configuration together with the final test metrics
config_dict = {
    'data': {
        'raw_data_dir': str(config.data.raw_data_dir),
        'image_size': config.data.image_size,
        'selected_classes': config.data.selected_classes,
        'split_ratios': config.data.split_ratios,
        'random_seed': config.data.random_seed,
    },
    'train': {
        'batch_size': config.train.batch_size,
        'stage1_epochs': config.train.stage1_epochs,
        'stage1_lr': config.train.stage1_lr,
        'stage2_epochs': config.train.stage2_epochs,
        'stage2_lr': config.train.stage2_lr,
        'stage3_epochs': config.train.stage3_epochs,
        'stage3_lr': config.train.stage3_lr,
        'weight_decay': config.train.weight_decay,
        'label_smoothing': config.train.label_smoothing,
        'early_stopping_patience': config.train.early_stopping_patience,
        'scheduler': config.train.scheduler,
    },
    'model': {
        'num_classes': config.model.num_classes,
        'dropout': config.model.dropout,
        'pretrained': config.model.pretrained,
    },
    'results': {
        model_name: all_metrics[model_name]
        for model_name in model_names
        if model_name in all_metrics
    }
}

config.train.checkpoint_dir.mkdir(parents=True, exist_ok=True)
config_path = config.train.checkpoint_dir / 'training_config.json'
# Serialize fully before opening the file: json.dump streams chunks into an
# already-truncated file, so a serialization error midway (e.g. a non-JSON
# value in the metrics) would leave a corrupt training_config.json behind.
config_json = json.dumps(config_dict, indent=2)
with open(config_path, 'w') as f:
    f.write(config_json)
print(f'Configuration saved to {config_path}')

In [None]:
# Summary of all outputs: saved model files, figures, and final test metrics
print('\n' + '=' * 60)
print('EXPORT SUMMARY')
print('=' * 60)

outputs_dir = PROJECT_ROOT / 'outputs'
models_dir = config.train.checkpoint_dir

print('\nSaved Models:')
if models_dir.exists():
    for f in sorted(models_dir.iterdir()):
        # Skip sub-directories: a directory's st_size is not a file size.
        if not f.is_file():
            continue
        size_mb = f.stat().st_size / (1024 ** 2)
        print(f'  {f.name} ({size_mb:.1f} MB)')

print('\nSaved Figures:')
if outputs_dir.exists():
    for f in sorted(outputs_dir.iterdir()):
        if f.suffix == '.png':
            print(f'  {f.name}')

print('\nFinal Model Performance:')
for model_name, metrics in all_metrics.items():
    print(f'  {model_name:20s}: Acc={metrics["accuracy"]:.4f}, F1={metrics["f1_macro"]:.4f}')

print('\nAll exports complete!')