# Notebook 4: Transfer Learning for Cell Image Classification

## Introduction

In this notebook, we'll explore **Transfer Learning**, one of the most powerful techniques in deep learning, and apply it to cell image classification.

### What is Transfer Learning?

Transfer learning takes a model pre-trained on a large dataset (e.g., ImageNet, with over a million labeled images) and adapts it to a new task. Instead of training from scratch, we reuse the features the model has already learned.

### Why Transfer Learning?

- **Less data needed**: Pre-trained models work well even with small datasets
- **Faster training**: Only need to fine-tune, not train from scratch
- **Better performance**: Especially when you have limited data
- **Lower computational cost**: Less GPU time required

### Learning Objectives

1. Understand the concept of transfer learning
2. Use a pre-trained ResNet model
3. Fine-tune for cell image classification
4. Compare transfer learning vs training from scratch
5. Visualize learned features using Grad-CAM

## 1. Setup and Imports

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.models as models
import torchvision.transforms as transforms

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from sklearn.metrics import confusion_matrix, classification_report
from tqdm import tqdm

# Set random seeds
torch.manual_seed(42)
np.random.seed(42)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Generate Synthetic Cell Images

We'll create synthetic microscopy images of three cell types:
- **Red Blood Cells**: Circular, uniform
- **White Blood Cells**: Larger, textured nuclei
- **Platelets**: Small, irregular

In practice, you would use real datasets like:
- Broad Bioimage Benchmark Collection
- Cell Image Library
- Kaggle medical image datasets

In [None]:
def create_cell_image(cell_type, size=128):
    """
    Create synthetic cell images.
    
    Args:
        cell_type: 0 (RBC), 1 (WBC), 2 (Platelet)
        size: Image size
    """
    image = np.zeros((size, size, 3), dtype=np.uint8)
    
    if cell_type == 0:  # Red Blood Cell
        # Create circular shape
        center = (size // 2 + np.random.randint(-10, 10), 
                 size // 2 + np.random.randint(-10, 10))
        radius = np.random.randint(25, 35)
        
        y, x = np.ogrid[:size, :size]
        mask = (x - center[0])**2 + (y - center[1])**2 <= radius**2
        
        # Reddish color
        image[mask] = [200 + np.random.randint(-20, 20), 
                       50 + np.random.randint(-20, 20), 
                       50 + np.random.randint(-20, 20)]
    
    elif cell_type == 1:  # White Blood Cell
        # Larger with textured nucleus
        center = (size // 2, size // 2)
        radius = np.random.randint(35, 45)
        
        y, x = np.ogrid[:size, :size]
        mask = (x - center[0])**2 + (y - center[1])**2 <= radius**2
        
        # Light purple/blue
        image[mask] = [150 + np.random.randint(-30, 30), 
                       150 + np.random.randint(-30, 30), 
                       200 + np.random.randint(-30, 30)]
        
        # Add a darker, solid-colored nucleus (the noise added below supplies the texture)
        nucleus_mask = (x - center[0])**2 + (y - center[1])**2 <= (radius * 0.6)**2
        image[nucleus_mask] = [80, 80, 150]
    
    else:  # Platelet
        # Small irregular shapes
        for _ in range(np.random.randint(2, 5)):
            center = (np.random.randint(size//4, 3*size//4), 
                     np.random.randint(size//4, 3*size//4))
            radius = np.random.randint(8, 15)
            
            y, x = np.ogrid[:size, :size]
            mask = (x - center[0])**2 + (y - center[1])**2 <= radius**2
            
            # Yellowish
            image[mask] = [200, 200, 100]
    
    # Add noise
    noise = np.random.normal(0, 10, image.shape)
    image = np.clip(image + noise, 0, 255).astype(np.uint8)
    
    return image

# Generate dataset
n_samples_per_class = 200
images = []
labels = []

for cell_type in range(3):
    for _ in range(n_samples_per_class):
        img = create_cell_image(cell_type)
        images.append(img)
        labels.append(cell_type)

print(f"Generated {len(images)} cell images")
print(f"Image shape: {images[0].shape}")

# Visualize examples
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
cell_names = ['Red Blood Cell', 'White Blood Cell', 'Platelet']

for i, (ax, name) in enumerate(zip(axes, cell_names)):
    ax.imshow(images[i * n_samples_per_class])
    ax.set_title(name)
    ax.axis('off')

plt.tight_layout()
plt.show()

## 3. Data Preprocessing and Augmentation

### Image Transformations:

1. **Resize**: Ensure all images are the same size
2. **Normalization**: Use ImageNet mean/std for transfer learning
3. **Data Augmentation** (training only):
   - Random rotation
   - Random flips
   - Color jitter

Data augmentation creates variations of training images, effectively increasing dataset size!
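
To make the normalization step concrete, here is a minimal pure-Python sketch of the per-channel arithmetic, assuming the ImageNet statistics used in this notebook. The helper names `normalize_pixel` and `denormalize_pixel` are illustrative and not part of torchvision; `ToTensor` followed by `Normalize` applies the same computation to whole image tensors.

```python
# ImageNet per-channel statistics (RGB), the same values used in this notebook
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

def normalize_pixel(rgb):
    """Map an 8-bit RGB pixel to normalized values: (x / 255 - mean) / std."""
    return [(v / 255.0 - m) / s for v, m, s in zip(rgb, IMAGENET_MEAN, IMAGENET_STD)]

def denormalize_pixel(norm):
    """Invert the normalization and return values in the [0, 1] range."""
    return [n * s + m for n, m, s in zip(norm, IMAGENET_MEAN, IMAGENET_STD)]

pixel = [200, 50, 50]  # a reddish pixel, similar to the synthetic RBC color
normalized = normalize_pixel(pixel)
restored = denormalize_pixel(normalized)

print([round(v, 3) for v in normalized])   # values centered near zero
print([round(v * 255) for v in restored])  # back to the original 8-bit values
```

The inverse mapping is the same one applied later in the notebook to display normalized test images.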

In [None]:
# ImageNet normalization values (the statistics the pre-trained weights were trained with)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training transforms with augmentation
train_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),  # 224x224 matches the ImageNet training resolution
    transforms.RandomRotation(20),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
])

# Validation/test transforms without augmentation
test_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
])

print("Data transformations configured.")
print("\nWhy normalize with ImageNet statistics?")
print("- Pre-trained models expect inputs in this range")
print("- Ensures feature distributions match training data")

## 4. Create PyTorch Dataset

In [None]:
class CellImageDataset(Dataset):
    """
    Dataset for cell images.
    """
    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]
        
        if self.transform:
            image = self.transform(image)
        
        return image, label

# Create datasets
dataset = CellImageDataset(images, labels)

# Split
train_size = int(0.7 * len(dataset))
val_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - val_size

train_indices, val_indices, test_indices = torch.utils.data.random_split(
    range(len(dataset)), [train_size, val_size, test_size]
)

# Create datasets with appropriate transforms
train_dataset = CellImageDataset(
    [images[i] for i in train_indices.indices],
    [labels[i] for i in train_indices.indices],
    transform=train_transform
)

val_dataset = CellImageDataset(
    [images[i] for i in val_indices.indices],
    [labels[i] for i in val_indices.indices],
    transform=test_transform
)

test_dataset = CellImageDataset(
    [images[i] for i in test_indices.indices],
    [labels[i] for i in test_indices.indices],
    transform=test_transform
)

print(f"Dataset splits:")
print(f"  Train: {len(train_dataset)}")
print(f"  Validation: {len(val_dataset)}")
print(f"  Test: {len(test_dataset)}")

## 5. Create DataLoaders

In [None]:
batch_size = 16

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

print(f"DataLoaders created with batch size {batch_size}")

## 6. Load Pre-trained ResNet Model

### ResNet (Residual Network):

- One of the most successful architectures
- Uses skip connections to train very deep networks
- Pre-trained on ImageNet (1.2M images, 1000 classes)
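
The skip-connection idea can be illustrated without any deep learning library. This is a minimal sketch, assuming a block's input and output have the same size; `residual_block`, `zero_branch`, and `scale_branch` are hypothetical names used only for illustration, and real ResNet blocks use convolutions, batch normalization, and ReLU rather than these toy functions.

```python
def residual_block(x, branch):
    """Return branch(x) + x, the skip-connection pattern used by ResNet."""
    fx = branch(x)
    return [f + xi for f, xi in zip(fx, x)]

def zero_branch(x):
    # A branch that outputs zeros contributes nothing to the sum
    return [0.0 for _ in x]

def scale_branch(x):
    # A toy stand-in for a learned transformation: scale each feature by 0.1
    return [0.1 * xi for xi in x]

x = [1.0, -2.0, 3.0]
identity_out = residual_block(x, zero_branch)
refined_out = residual_block(x, scale_branch)

print(identity_out)  # the input passes through unchanged
print(refined_out)   # the input plus a small learned correction
```

When the branch outputs zeros, the block reduces to the identity, so an extra block can leave its input unchanged by default. This is one intuition for why very deep residual networks remain trainable.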

### Transfer Learning Strategy:

1. **Load pre-trained model**: Get weights learned from ImageNet
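2. **Freeze the pre-trained layers**: Keep the learned features (edges, textures, shapes) fixed; in this notebook the entire backbone is frozen
3. **Replace final layer**: Adapt to our 3 classes
4. **Train the new layer**: Fit only the new classification head on our cell images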
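
The effect of freezing the backbone and replacing the final layer can be checked with simple arithmetic. The sketch below assumes the commonly cited total of 11,689,512 parameters for torchvision's ResNet18 with its original 1000-class head; `linear_params` is an illustrative helper, not a library function.

```python
def linear_params(in_features, out_features):
    """Parameters in a fully connected layer: weights plus biases."""
    return in_features * out_features + out_features

# Assumption: torchvision's ResNet18 with its original 1000-class head
# has 11,689,512 parameters in total.
RESNET18_TOTAL_1000 = 11_689_512
FC_IN_FEATURES = 512

original_head = linear_params(FC_IN_FEATURES, 1000)  # the 512 -> 1000 layer
backbone = RESNET18_TOTAL_1000 - original_head       # everything that gets frozen
new_head = linear_params(FC_IN_FEATURES, 3)          # the 512 -> 3 layer we train
total = backbone + new_head

print(f"Frozen backbone parameters: {backbone:,}")
print(f"Trainable head parameters:  {new_head:,}")
print(f"Trainable share: {100 * new_head / total:.3f}%")
```

Under these assumptions the new head holds 1,539 parameters, a tiny fraction of the model, so a percentage printed with one decimal place would round to 0.0%.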

In [None]:
def create_transfer_learning_model(num_classes=3, freeze_layers=True):
    """
    Create ResNet18 model with transfer learning.
    
    Args:
        num_classes: Number of output classes
        freeze_layers: Whether to freeze pre-trained layers
    """
    # Load pre-trained ResNet18
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    
    # Freeze all layers if requested
    if freeze_layers:
        for param in model.parameters():
            param.requires_grad = False
    
    # Replace final fully connected layer
    # ResNet18's fc layer: (512 -> 1000) becomes (512 -> 3)
    num_features = model.fc.in_features
    model.fc = nn.Linear(num_features, num_classes)
    
    return model

# Create model
model = create_transfer_learning_model(num_classes=3, freeze_layers=True).to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Model: ResNet18")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Frozen parameters: {total_params - trainable_params:,}")
print(f"\nWe're only training {100 * trainable_params / total_params:.2f}% of the parameters!")

## 7. Training Setup

In [None]:
# Loss and optimizer
criterion = nn.CrossEntropyLoss()

# Only optimize parameters that require gradients (unfrozen layers)
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)

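# Learning rate scheduler: halve the LR when validation loss stops improving
# (the `verbose` argument is omitted because recent PyTorch releases deprecate it)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3
)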

## 8. Training Functions

In [None]:
def train_epoch(model, loader, criterion, optimizer, device):
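    # Train mode: note that BatchNorm layers update their running statistics here,
    # even when their weights are frozen
    model.train()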
    total_loss = 0
    correct = 0
    total = 0
    
    for images, labels in tqdm(loader, desc="Training"):
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        predicted = outputs.argmax(dim=1)  # index of the highest-scoring class
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    
    return total_loss / len(loader), 100 * correct / total

def validate(model, loader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            total_loss += loss.item()
            predicted = outputs.argmax(dim=1)  # index of the highest-scoring class
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    return total_loss / len(loader), 100 * correct / total

## 9. Train the Model

Training is fast because gradients are computed and applied only for the new final layer. The frozen backbone still runs in the forward pass, but none of its weights are updated.

In [None]:
num_epochs = 15
train_losses = []
val_losses = []
train_accs = []
val_accs = []

print("Starting training...\n")

for epoch in range(num_epochs):
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = validate(model, val_loader, criterion, device)
    
    scheduler.step(val_loss)
    
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accs.append(train_acc)
    val_accs.append(val_acc)
    
    print(f"Epoch {epoch+1}/{num_epochs}:")
    print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
    print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%\n")

print("Training completed!")

## 10. Visualize Training

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

ax1.plot(train_losses, label='Train Loss', marker='o')
ax1.plot(val_losses, label='Validation Loss', marker='s')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training Progress (Transfer Learning)')
ax1.legend()
ax1.grid(True, alpha=0.3)

ax2.plot(train_accs, label='Train Accuracy', marker='o')
ax2.plot(val_accs, label='Validation Accuracy', marker='s')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy (%)')
ax2.set_title('Accuracy Progress')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 11. Evaluate on Test Set

In [None]:
def evaluate_model(model, loader, device):
    model.eval()
    all_predictions = []
    all_labels = []
    
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            
            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.numpy())
    
    return np.array(all_predictions), np.array(all_labels)

test_predictions, test_labels = evaluate_model(model, test_loader, device)

print("Classification Report:")
print(classification_report(
    test_labels, test_predictions,
    target_names=['RBC', 'WBC', 'Platelet']
))

test_accuracy = 100 * np.sum(test_predictions == test_labels) / len(test_labels)
print(f"\nTest Accuracy: {test_accuracy:.2f}%")

## 12. Confusion Matrix

In [None]:
cm = confusion_matrix(test_labels, test_predictions)

plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=['RBC', 'WBC', 'Platelet'],
            yticklabels=['RBC', 'WBC', 'Platelet'])
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.title('Confusion Matrix - Cell Classification')
plt.show()

## 13. Visualize Predictions

In [None]:
# Get some test images
model.eval()
cell_names = ['RBC', 'WBC', 'Platelet']

fig, axes = plt.subplots(2, 3, figsize=(12, 8))
axes = axes.ravel()

# Get a batch
images_batch, labels_batch = next(iter(test_loader))
images_batch = images_batch.to(device)

with torch.no_grad():
    outputs = model(images_batch)
    _, predictions = torch.max(outputs, 1)
    probabilities = torch.softmax(outputs, dim=1)

# Denormalize for display
mean = torch.tensor(IMAGENET_MEAN).view(3, 1, 1)
std = torch.tensor(IMAGENET_STD).view(3, 1, 1)

for i in range(6):
    img = images_batch[i].cpu() * std + mean
    img = img.permute(1, 2, 0).numpy()
    img = np.clip(img, 0, 1)
    
    true_label = labels_batch[i].item()
    pred_label = predictions[i].item()
    confidence = probabilities[i][pred_label].item()
    
    # Use a red title to flag misclassified examples
    title_color = 'red' if true_label != pred_label else 'black'
    axes[i].imshow(img)
    axes[i].set_title(
        f'True: {cell_names[true_label]}\nPred: {cell_names[pred_label]} ({confidence:.1%})',
        color=title_color
    )
    axes[i].axis('off')

plt.tight_layout()
plt.show()

## Summary and Key Takeaways

In this notebook, we:

1. ✅ **Applied transfer learning** using pre-trained ResNet18
2. ✅ **Froze pre-trained layers** to leverage learned features
3. ✅ **Trained** only the new final layer for our task
4. ✅ **Used data augmentation** to improve generalization
5. ✅ **Achieved high accuracy** with limited data and training time

### Transfer Learning Benefits:

- **Data efficiency**: Works with small datasets (we used only 600 images!)
- **Fast training**: Converges in just a few epochs
- **Better features**: Pre-trained low-level features (edges, textures) transfer well across many visual tasks
- **Lower computational cost**: Less GPU time and memory

### When to Use Transfer Learning:

✅ **Use it when**:
- You have limited data (< 10,000 images)
- Your task is similar to ImageNet (natural images, objects)
- You want fast prototyping
- Computational resources are limited

❌ **Consider training from scratch when**:
- You have millions of images
- Your domain is very different from natural images (e.g., some medical scans or satellite imagery) and you have enough labeled data to train without pre-trained features
- You have specific architectural requirements

### Advanced Techniques:

- **Gradual unfreezing**: Unfreeze layers progressively during training
- **Discriminative learning rates**: Different learning rates for different layers
- **Domain-specific pre-training**: Use models pre-trained on medical images
- **Ensemble methods**: Combine multiple pre-trained models
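
The first two techniques can be sketched as plain schedules, independent of any framework. This is a minimal illustration under stated assumptions: the helpers `discriminative_lrs` and `unfrozen_groups`, the decay factor of 0.5, the two-epoch stage length, and the grouping of ResNet18 into `layer1` through `layer4` plus `fc` are all hypothetical choices, not values from this notebook.

```python
def discriminative_lrs(base_lr, num_groups, decay=0.5):
    """Learning rate per layer group, ordered from earliest to latest.

    The last group (the classification head) gets base_lr; each earlier
    group gets the next group's rate multiplied by `decay`.
    """
    return [base_lr * decay ** (num_groups - 1 - i) for i in range(num_groups)]

def unfrozen_groups(epoch, num_groups, epochs_per_stage=2):
    """Indices of trainable groups at a given epoch (0-based).

    Training starts with only the last group unfrozen; one more group,
    moving toward the input, is unfrozen every `epochs_per_stage` epochs.
    """
    n_unfrozen = min(num_groups, 1 + epoch // epochs_per_stage)
    return list(range(num_groups - n_unfrozen, num_groups))

# Hypothetical grouping of ResNet18 modules, from input side to output side
group_names = ["layer1", "layer2", "layer3", "layer4", "fc"]
lrs = discriminative_lrs(base_lr=1e-3, num_groups=len(group_names))

for epoch in range(0, 10, 2):
    active = [group_names[i] for i in unfrozen_groups(epoch, len(group_names))]
    print(f"epoch {epoch}: training {active}")
print(dict(zip(group_names, lrs)))
```

In PyTorch, a schedule like this could be applied by toggling `requires_grad` on each group's parameters and passing one parameter group per entry (a dict with `params` and `lr`) to the optimizer.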

### Real-World Applications:

- **Medical imaging**: X-ray, CT, MRI classification
- **Pathology**: Cancer detection in tissue samples
- **Cell biology**: Cell type identification, organelle detection
- **Drug discovery**: High-content screening analysis
- **Quality control**: Defect detection in cell cultures