# Swin Transformer Baseline
## Disease Classification on PlantVillage

**Architecture**: Swin-Tiny -> Custom Head (768->256->15)

**Training Schedule**:
- Epochs 1-10: stages 0 and 1 frozen, lr=1e-4
- End of epoch 10: unfreeze all stages, lr=1e-5
- Epochs 11-50: full fine-tuning with a cosine-annealing learning-rate schedule

In [None]:
import sys
sys.path.insert(0, '../..')

import torch
import numpy as np
import matplotlib.pyplot as plt

from src.models.swin_classifier import SwinClassifier
from src.data.datamodule import PlantVillageDataModule
from src.features.augmentation import get_train_transforms, get_val_transforms
from src.utils.config import ConfigManager
from src.utils.seed import set_seed

# Load the shared experiment configuration (path is relative to this
# notebook's directory) and apply its seed before anything else is created.
cfg = ConfigManager('../../configs/config.yaml')
set_seed(cfg.seed)

# Use CUDA when PyTorch reports an available GPU; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Device: {device}')

## 1. Model Architecture

In [None]:
# Collect the constructor arguments from the config first so the model
# settings can be inspected in one place, then build the classifier.
model_kwargs = {
    'model_name': cfg.model.variant,
    'num_classes': cfg.model.num_classes,
    'pretrained': True,
    'hidden_dim': cfg.model.head.hidden_dim,
    'dropout': cfg.model.head.dropout,
}
model = SwinClassifier(**model_kwargs)

# Report parameter counts as returned by the model's own helpers.
print(f'Total parameters: {model.get_total_params():,}')
print(f'Trainable parameters: {model.get_trainable_params():,}')

# Smoke-test the forward pass on a random batch of two 3x224x224 inputs.
# Gradients are disabled because only the output shapes are of interest.
batch_shape = (2, 3, 224, 224)
x = torch.randn(*batch_shape)
with torch.no_grad():
    logits = model(x)
    features = model.get_features(x)

print(f'\nLogits shape: {logits.shape}')
print(f'Feature shape: {features.shape}')

## 2. Freeze Strategy Verification

In [None]:
# Report the trainable-parameter count before and after each freeze helper
# so the effect of every call is visible in the cell output.

# Baseline count before any freeze call in this cell.
print(f'Before freeze: {model.get_trainable_params():,} trainable')

# Freeze the stages with indices 0 and 1, then re-count.
stages_to_freeze = [0, 1]
model.freeze_stages(stages_to_freeze)
print(f'After freeze [0,1]: {model.get_trainable_params():,} trainable')

# Undo the freeze so the model leaves this cell with unfreeze_all() applied.
model.unfreeze_all()
print(f'After unfreeze all: {model.get_trainable_params():,} trainable')

## 3. Training

Run the full training from the repository root with:
```bash
python train_classifier.py --config configs/config.yaml
```

## 4. Load Results (after training and evaluation)

In [None]:
import json
from pathlib import Path


def format_metric(name, value):
    """Return one indented report line for a single metric.

    Real numbers are rendered with four decimal places. Any other value
    (string, list, nested dict, None, bool) is rendered as-is, because
    applying the ``.4f`` format spec to it would raise.
    """
    if isinstance(value, (int, float)) and not isinstance(value, bool):
        return f'  {name}: {value:.4f}'
    return f'  {name}: {value}'


# Print the overall test metrics if the metrics JSON file is present.
metrics_path = Path('../../reports/metrics/test_metrics.json')
if metrics_path.exists():
    # Read as UTF-8 explicitly instead of relying on the platform's
    # locale-dependent default text encoding.
    with open(metrics_path, encoding='utf-8') as f:
        metrics = json.load(f)
    print('Overall Metrics:')
    for k, v in metrics['overall'].items():
        print(format_metric(k, v))
else:
    print('Metrics file not found. Run train_classifier.py and evaluate.py first.')

## 5. Grad-CAM Visualization

In [None]:
# Path is imported here as well so this cell does not depend on the
# results-loading cell having been executed first.
from pathlib import Path

from PIL import Image

from src.visualization.gradcam import SwinGradCAM

# Load best model checkpoint
ckpt_path = Path('../../models/checkpoints/best_swin_classifier.pth')
if ckpt_path.exists():
    # Rebuild the classifier with the same head settings as the model built
    # in the architecture cell; omitting them would fall back to the class
    # defaults, and load_state_dict would fail on a shape mismatch whenever
    # the config differs from those defaults.
    model_eval = SwinClassifier(
        model_name=cfg.model.variant,
        num_classes=cfg.model.num_classes,
        pretrained=False,
        hidden_dim=cfg.model.head.hidden_dim,
        dropout=cfg.model.head.dropout,
    )
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source (consider weights_only=True if the checkpoint
    # contains nothing but tensors and plain containers).
    ckpt = torch.load(ckpt_path, map_location='cpu')
    model_eval.load_state_dict(ckpt['model_state_dict'])
    model_eval.eval()

    gradcam = SwinGradCAM(model_eval)
    transform = get_val_transforms(img_size=224)

    # Pick a sample image
    sample_dir = Path('../../data/raw/PlantVillage/Tomato_Early_blight')
    if sample_dir.exists():
        # Match the extension case-insensitively so files named `.JPG` are
        # found on case-sensitive filesystems, and sort so the same image is
        # picked on every run.
        jpeg_files = sorted(
            p for p in sample_dir.iterdir()
            if p.is_file() and p.suffix.lower() in {'.jpg', '.jpeg'}
        )
        if jpeg_files:
            sample_img = jpeg_files[0]
            img = np.array(Image.open(sample_img).convert('RGB'))
            # The transform is called with a keyword `image` and returns a
            # mapping; unsqueeze(0) adds the leading batch dimension.
            tensor = transform(image=img)['image'].unsqueeze(0)

            heatmap = gradcam.generate(tensor)
            overlay = gradcam.visualize(img, heatmap)

            fig, axes = plt.subplots(1, 3, figsize=(15, 5))
            axes[0].imshow(img)
            axes[0].set_title('Original')
            axes[1].imshow(heatmap, cmap='jet')
            axes[1].set_title('Grad-CAM')
            axes[2].imshow(overlay)
            axes[2].set_title('Overlay')
            for ax in axes:
                ax.axis('off')
            plt.tight_layout()
            plt.show()
        else:
            print(f'No JPEG images found in {sample_dir}.')
    else:
        print(f'Sample directory not found: {sample_dir}')
else:
    print('Model checkpoint not found. Train the model first.')