# TP2: Fine-Tuning for Medical Image Classification

**Day 3 - AI for Sciences Winter School**

**Instructor:** Raphael Cousin

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/racousin/ai_for_sciences/blob/main/day3/tp2.ipynb)

---

## Objectives

In this exercise, you will apply transfer learning to a **real scientific task**: classifying blood cells from microscopy images.

By the end of this practical, you will:

1. **Apply fine-tuning** to a medical imaging dataset (BloodMNIST)
2. **Compare approaches**: frozen backbone vs full fine-tuning
3. **Analyze results**: understand when each approach works best
4. **Evaluate models**: using accuracy, confusion matrix, and per-class metrics

---

# Part 1: The Dataset - Blood Cell Classification

## About BloodMNIST

**BloodMNIST** is a dataset of **17,092 microscopy images** of individual normal blood cells:

- 8 classes of blood cells (basophil, eosinophil, erythroblast, etc.)
- Image size: 28x28 RGB pixels (we'll resize to 224x224, the input size used to pre-train ResNet on ImageNet)
- Clinical relevance: automated blood cell counting aids diagnosis

```
Blood Cell Types:
├── Basophil             (class 0)
├── Eosinophil           (class 1)
├── Erythroblast         (class 2)
├── Immature granulocyte (class 3)
├── Lymphocyte           (class 4)
├── Monocyte             (class 5)
├── Neutrophil           (class 6)
└── Platelet             (class 7)
```

With low-resolution images and a clear visual gap from the natural photos in ImageNet, this is a realistic scientific test case for transfer learning.

## Setup

In [None]:
!pip install -q git+https://github.com/racousin/ai_for_sciences.git
!pip install -q medmnist

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision import models
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import medmnist
from medmnist import INFO
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Set seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

print(f"MedMNIST version: {medmnist.__version__}")
print("Setup complete!")

## Load the BloodMNIST Dataset

In [None]:
# Dataset info
data_flag = 'bloodmnist'
info = INFO[data_flag]
n_classes = len(info['label'])
class_names = list(info['label'].values())

print(f"Dataset: {info['description']}")
print(f"Number of classes: {n_classes}")
print(f"\nClass names:")
for i, name in enumerate(class_names):
    print(f"  {i}: {name}")

In [None]:
# Data transforms - resize to 224x224 for pre-trained models
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet stats
])

# Load datasets
DataClass = getattr(medmnist, info['python_class'])

train_dataset = DataClass(split='train', transform=transform, download=True)
val_dataset = DataClass(split='val', transform=transform, download=True)
test_dataset = DataClass(split='test', transform=transform, download=True)

print(f"\nDataset sizes:")
print(f"  Training:   {len(train_dataset):,} samples")
print(f"  Validation: {len(val_dataset):,} samples")
print(f"  Test:       {len(test_dataset):,} samples")

In [None]:
# Create data loaders
BATCH_SIZE = 64

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

print(f"Batches per epoch: {len(train_loader)}")

## Visualize the Data

In [None]:
# Helper function to denormalize images for visualization
def denormalize(tensor):
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    return (tensor * std + mean).clamp(0, 1)

# Show samples from each class
fig, axes = plt.subplots(2, 4, figsize=(12, 6))

# Get one sample per class
samples_per_class = {i: None for i in range(n_classes)}
for img, label in train_dataset:
    label_idx = label.item()
    if samples_per_class[label_idx] is None:
        samples_per_class[label_idx] = img
    if all(v is not None for v in samples_per_class.values()):
        break

for i, ax in enumerate(axes.flat):
    img = denormalize(samples_per_class[i])
    ax.imshow(img.permute(1, 2, 0).numpy())
    ax.set_title(class_names[i], fontsize=10)
    ax.axis('off')

plt.suptitle('Blood Cell Types', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

### Question 1

Look at the cell images:
1. Can you visually distinguish between the different cell types?
2. Which classes look most similar to each other?
3. Why might this be a challenging classification task?

---

# Part 2: Training Utilities

We'll reuse the training functions from TP1.
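
One detail specific to MedMNIST: each label is stored as an array of shape `(1,)`, so a batch of labels has shape `(B, 1)`, while `nn.CrossEntropyLoss` expects integer class indices of shape `(B,)`. `.squeeze()` removes *every* size-1 dimension, so for a batch containing a single sample it also drops the batch axis; `.view(-1)` always flattens to `(B,)`. The toy tensors below are made up for illustration.

```python
import torch
import torch.nn as nn

# MedMNIST-style label batches of shape (B, 1)
labels_single = torch.tensor([[3]])           # B = 1
labels_batch = torch.tensor([[3], [0], [7]])  # B = 3

print(labels_single.squeeze().shape)  # torch.Size([])  -> batch dimension lost
print(labels_single.view(-1).shape)   # torch.Size([1]) -> shape (B,)
print(labels_batch.view(-1).shape)    # torch.Size([3])

# CrossEntropyLoss: logits of shape (B, C) and integer class indices of shape (B,)
criterion = nn.CrossEntropyLoss()
logits = torch.randn(3, 8)
loss = criterion(logits, labels_batch.view(-1).long())
print(loss.shape)                     # torch.Size([]): mean over the batch
```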

In [None]:
def train_one_epoch(model, loader, criterion, optimizer, device):
    """Train for one epoch."""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for inputs, labels in tqdm(loader, desc='Training', leave=False):
        inputs = inputs.to(device)
        labels = labels.view(-1).long().to(device)  # MedMNIST labels have shape (B, 1); flatten to (B,)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
    
    return running_loss / len(loader), 100. * correct / total


def evaluate(model, loader, criterion, device):
    """Evaluate model on a dataset."""
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.view(-1).long().to(device)  # flatten (B, 1) -> (B,)
            
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    return running_loss / len(loader), 100. * correct / total, all_preds, all_labels


def count_parameters(model):
    """Count trainable parameters."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total

---

# Part 3: Exercise - Frozen Backbone

Your first task: train a ResNet18 with a **frozen backbone** on the blood cell dataset.

This approach works well when:
- You have limited data
- You want fast training
- The source domain (ImageNet) shares some features with your domain
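
What does freezing actually do? Setting `requires_grad = False` on a parameter stops gradients, and therefore optimizer updates, for that parameter. It does not freeze BatchNorm's running mean and variance: those are buffers updated on every forward pass in `train()` mode. The toy model below is a hypothetical stand-in for a backbone plus a new head (`BatchNorm1d` plays the role of ResNet's `BatchNorm2d` layers) and checks both effects.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Toy "backbone" (Linear + BatchNorm) and a new classification "head"
backbone = nn.Sequential(nn.Linear(16, 32), nn.BatchNorm1d(32), nn.ReLU())
head = nn.Linear(32, 8)
model = nn.Sequential(backbone, head)

# Freeze the backbone's learnable parameters
for p in backbone.parameters():
    p.requires_grad = False

# Give the optimizer only the parameters that are still trainable
optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=1e-3)

w_before = backbone[0].weight.clone()
head_before = head.weight.clone()
bn_mean_before = backbone[1].running_mean.clone()

# One training step in train() mode
model.train()
x = torch.randn(64, 16) * 3 + 1
y = torch.randint(0, 8, (64,))
loss = nn.CrossEntropyLoss()(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

print("backbone weight changed:", not torch.equal(w_before, backbone[0].weight))                # -> False
print("head weight changed:    ", not torch.equal(head_before, head.weight))                   # -> True
print("BN running_mean changed:", not torch.equal(bn_mean_before, backbone[1].running_mean))   # -> True
```

In the notebook this means the frozen ResNet's BatchNorm statistics still drift toward BloodMNIST during training. Whether that helps is an empirical question; switching the backbone's BatchNorm layers to `.eval()` after calling `model.train()` is one way to keep them fixed.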

In [None]:
# TODO: Create a ResNet18 model with frozen backbone
# Step 1: Load pre-trained ResNet18
model_frozen = models.resnet18(weights='IMAGENET1K_V1')

# Step 2: Freeze all parameters
# TODO: Complete this loop  # <-- Modify this!
for param in model_frozen.parameters():
    pass  # What should go here?

# Step 3: Replace the classifier for 8 classes
# TODO: Replace the fc layer  # <-- Modify this!
# Hint: model_frozen.fc = nn.Linear(..., n_classes)

model_frozen = model_frozen.to(device)

trainable, total = count_parameters(model_frozen)
print(f'Trainable parameters: {trainable:,} / {total:,} ({100*trainable/total:.2f}%)')

In [None]:
# Train the frozen model
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_frozen.fc.parameters(), lr=0.001)  # Only train classifier

EPOCHS = 5
history_frozen = {'train_acc': [], 'val_acc': []}

print("Training ResNet18 (frozen backbone)...")
print("="*50)

for epoch in range(EPOCHS):
    train_loss, train_acc = train_one_epoch(model_frozen, train_loader, criterion, optimizer, device)
    val_loss, val_acc, _, _ = evaluate(model_frozen, val_loader, criterion, device)
    
    history_frozen['train_acc'].append(train_acc)
    history_frozen['val_acc'].append(val_acc)
    
    print(f'Epoch {epoch+1}/{EPOCHS} | Train: {train_acc:.1f}% | Val: {val_acc:.1f}%')

# Final test evaluation
_, test_acc_frozen, preds_frozen, labels_frozen = evaluate(model_frozen, test_loader, criterion, device)
print(f'\nTest Accuracy (frozen): {test_acc_frozen:.1f}%')

---

# Part 4: Exercise - Full Fine-Tuning

Now train a ResNet18 with **full fine-tuning** - all layers are trainable.

Remember the key tips:
- Use a **lower learning rate** (about 10x lower than when training from scratch), so the pre-trained weights are refined rather than overwritten
- All layers will adapt to the new domain

In [None]:
# TODO: Create a ResNet18 model for full fine-tuning
model_finetune = models.resnet18(weights='IMAGENET1K_V1')

# TODO: Replace the classifier (but DON'T freeze any layers!)  # <-- Modify this!
# model_finetune.fc = ...

model_finetune = model_finetune.to(device)

trainable, total = count_parameters(model_finetune)
print(f'Trainable parameters: {trainable:,} / {total:,} ({100*trainable/total:.2f}%)')

In [None]:
# TODO: Train with full fine-tuning
criterion = nn.CrossEntropyLoss()

# TODO: What learning rate should you use?  # <-- Modify this!
# Hint: Lower than 0.001!
optimizer = optim.Adam(model_finetune.parameters(), lr=0.001)  # <-- Modify this!

history_finetune = {'train_acc': [], 'val_acc': []}

print("Training ResNet18 (full fine-tuning)...")
print("="*50)

for epoch in range(EPOCHS):
    train_loss, train_acc = train_one_epoch(model_finetune, train_loader, criterion, optimizer, device)
    val_loss, val_acc, _, _ = evaluate(model_finetune, val_loader, criterion, device)
    
    history_finetune['train_acc'].append(train_acc)
    history_finetune['val_acc'].append(val_acc)
    
    print(f'Epoch {epoch+1}/{EPOCHS} | Train: {train_acc:.1f}% | Val: {val_acc:.1f}%')

# Final test evaluation
_, test_acc_finetune, preds_finetune, labels_finetune = evaluate(model_finetune, test_loader, criterion, device)
print(f'\nTest Accuracy (fine-tuned): {test_acc_finetune:.1f}%')

---

# Part 5: Compare Results

Let's visualize and compare both approaches.

In [None]:
# Plot training curves
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

epochs_range = range(1, EPOCHS + 1)

# Training accuracy
axes[0].plot(epochs_range, history_frozen['train_acc'], 'g-o', label='Frozen', linewidth=2)
axes[0].plot(epochs_range, history_finetune['train_acc'], 'r-^', label='Fine-tuned', linewidth=2)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Training Accuracy (%)')
axes[0].set_title('Training Accuracy')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Validation accuracy
axes[1].plot(epochs_range, history_frozen['val_acc'], 'g-o', label='Frozen', linewidth=2)
axes[1].plot(epochs_range, history_finetune['val_acc'], 'r-^', label='Fine-tuned', linewidth=2)
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Validation Accuracy (%)')
axes[1].set_title('Validation Accuracy')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Summary
print("\n" + "="*50)
print("TEST ACCURACY COMPARISON")
print("="*50)
print(f"Frozen backbone:  {test_acc_frozen:.1f}%")
print(f"Full fine-tuning: {test_acc_finetune:.1f}%")
print("="*50)

## Confusion Matrix Analysis

Let's see which cell types are most often confused.
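
The matrices below are normalized by row: each row is divided by the number of true samples of that class, so entry (i, j) is the fraction of class-i cells predicted as class j, and the diagonal equals per-class recall (the "recall" column of `classification_report`). A small check on made-up labels for 3 classes; scikit-learn's `normalize='true'` option computes the same thing.

```python
import numpy as np
from sklearn.metrics import confusion_matrix, recall_score

# Made-up labels for a 3-class problem
y_true = [0, 0, 0, 0, 1, 1, 2, 2, 2, 2]
y_pred = [0, 0, 1, 0, 1, 2, 2, 2, 0, 2]

cm = confusion_matrix(y_true, y_pred)
cm_normalized = cm.astype(float) / cm.sum(axis=1)[:, np.newaxis]

print(cm_normalized.round(2))
print(np.diag(cm_normalized))                      # per-class recall: 0.75, 0.5, 0.75
print(recall_score(y_true, y_pred, average=None))  # same values
```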

In [None]:
# Plot confusion matrices for both models
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

for ax, preds, labels, title in [
    (axes[0], preds_frozen, labels_frozen, 'Frozen Backbone'),
    (axes[1], preds_finetune, labels_finetune, 'Full Fine-tuning')
]:
    cm = confusion_matrix(labels, preds)
    # Normalize by row (true labels)
    cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
    
    sns.heatmap(cm_normalized, annot=True, fmt='.2f', cmap='Blues',
                xticklabels=[c[:6] for c in class_names],
                yticklabels=[c[:6] for c in class_names],
                ax=ax)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    ax.set_title(f'{title}\n(normalized by row)')

plt.tight_layout()
plt.show()

In [None]:
# Per-class metrics for the model with the higher test accuracy
best_preds = preds_finetune if test_acc_finetune > test_acc_frozen else preds_frozen
best_labels = labels_finetune if test_acc_finetune > test_acc_frozen else labels_frozen
best_name = "Fine-tuned" if test_acc_finetune > test_acc_frozen else "Frozen"

print(f"\nDetailed Classification Report ({best_name} model):")
print("="*60)
print(classification_report(best_labels, best_preds, target_names=class_names))

### Question 2

Looking at the confusion matrices:
1. Which cell types are most often confused with each other?
2. Does full fine-tuning improve classification for the difficult classes?
3. What biological reason might explain why certain classes are confused?

---

# Part 6: Bonus Exercise - Data Augmentation

Data augmentation can improve generalization, especially with limited data.

In [None]:
# TODO: Create a transform with data augmentation
# Experiment with different augmentations!

transform_augmented = transforms.Compose([
    transforms.Resize((224, 224)),
    # TODO: Add augmentations here  # <-- Modify this!
    # Suggestions:
    # - transforms.RandomHorizontalFlip()
    # - transforms.RandomRotation(10)
    # - transforms.ColorJitter(brightness=0.2, contrast=0.2)
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

print("Transform with augmentation created!")
print("\nTo use it, create new datasets with this transform:")
print("  train_dataset_aug = DataClass(split='train', transform=transform_augmented, download=True)")

### Question 3

1. Which augmentations make sense for microscopy images? (Think about what variations occur naturally)
2. Which augmentations would NOT make sense? (e.g., would vertical flip be appropriate?)
3. How could you validate that augmentation helps?

---

# Part 7: Bonus - Try Another Medical Dataset

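MedMNIST includes many other medical imaging datasets that load with the same API, so switching is mostly a matter of changing `data_flag`. Check each dataset's `task` first: multi-label and ordinal tasks need a different loss or metric. Grayscale datasets also have a single channel and must be converted to 3 channels for an ImageNet backbone.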

In [None]:
# Available MedMNIST datasets
print("Available MedMNIST datasets:\n")
for flag, flag_info in INFO.items():  # don't reuse `info`: it holds the BloodMNIST metadata
    print(f"{flag:20s}: {flag_info['description'][:60]}...")
    print(f"{'':20s}  Classes: {len(flag_info['label'])}, Task: {flag_info['task']}")
    print()

In [None]:
# TODO: Try a different dataset!
# Change data_flag to try another dataset:
# - 'dermamnist': Skin lesion images (7 classes)
# - 'pathmnist': Colon pathology (9 classes)
# - 'chestmnist': Chest X-ray (14 classes, multi-label)
# - 'retinamnist': Retinal fundus images (5-level ordinal grading)

# new_data_flag = 'dermamnist'  # <-- Modify this!
# new_info = INFO[new_data_flag]
# print(f"Dataset: {new_info['description']}")
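
The training loop above assumes exactly one class per image. If you pick a dataset whose `task` string contains `multi-label` (ChestMNIST), each image has a 14-dimensional 0/1 target, so `CrossEntropyLoss`, the `.long()` label conversion, and the `outputs.max(1)` accuracy no longer fit. A sketch of the usual replacement, with toy tensors and a hypothetical helper `make_criterion`:

```python
import torch
import torch.nn as nn

def make_criterion(task: str) -> nn.Module:
    """Hypothetical helper: choose a loss from a MedMNIST-style task string."""
    if "multi-label" in task:
        return nn.BCEWithLogitsLoss()  # one independent sigmoid per label
    return nn.CrossEntropyLoss()       # one softmax over mutually exclusive classes

torch.manual_seed(0)
logits = torch.randn(4, 14)                     # batch of 4 images, 14 labels
targets = torch.randint(0, 2, (4, 14)).float()  # multi-hot targets, float for BCE

criterion = make_criterion("multi-label, binary-class")
loss = criterion(logits, targets)
preds = (torch.sigmoid(logits) > 0.5).long()    # threshold each label independently
print(loss.item(), preds.shape)
```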

---

# Summary

## What You Learned

1. **Applied transfer learning** to a real medical imaging task

2. **Compared approaches**:
   - Frozen backbone: Fast training, good for limited data
   - Full fine-tuning: adapts all features to the new domain; often higher accuracy, at a higher compute cost

3. **Analyzed results** using confusion matrices and per-class metrics

4. **Key findings**:
   - Pre-trained ImageNet features can transfer to medical images despite the domain gap
   - Some cell types are inherently harder to distinguish
   - Fine-tuning typically improves performance, especially on difficult classes

## For Your Research

- **Medical imaging**: Consider fine-tuning from RadImageNet or other medical pre-trained models
- **Microscopy**: Transfer from ImageNet often works surprisingly well
- **Limited data**: Start with frozen backbone, add augmentation
- **Different modalities**: Domain-specific pre-training helps (X-ray, MRI, microscopy, etc.)

---

## Reflection Questions

1. **Your data**: Could you apply similar techniques to images in your research?

2. **Domain gap**: How different is your imaging modality from natural images? Does transfer still help?

3. **Class imbalance**: Many scientific datasets have imbalanced classes. How would you handle this?

4. **Interpretation**: How would you explain model predictions to domain experts?
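
For question 3, one common starting point is to weight the loss by inverse class frequency so that rare classes contribute more per sample. The counts below are made up for illustration; in the notebook they could come from `np.bincount(train_dataset.labels.ravel())`. The formula is the "balanced" scheme also used by scikit-learn's `compute_class_weight`; oversampling with `torch.utils.data.WeightedRandomSampler` is an alternative.

```python
import numpy as np
import torch
import torch.nn as nn

# Made-up per-class training counts for an imbalanced 4-class problem
counts = np.array([800, 200, 100, 900])

# "Balanced" weights: n_samples / (n_classes * count_c),
# so every class carries the same total weight (weight_c * count_c)
weights = counts.sum() / (len(counts) * counts)
class_weights = torch.tensor(weights, dtype=torch.float32)
print(class_weights)

# Rare classes now cost more per misclassified sample
criterion = nn.CrossEntropyLoss(weight=class_weights)
```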