# 15. Transfer Learning Mini

This notebook demonstrates transfer learning with tiny models.

## Experiment Overview
- **Goal**: Apply transfer learning to small datasets
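- **Model**: Small CNN whose convolutional feature extractor can be frozen or fine-tuned (in this notebook the weights are randomly initialized; no pre-trained checkpoint is loaded)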
- **Features**: Feature extraction, fine-tuning, domain adaptation
- **Learning**: Understanding transfer learning principles

## What You'll Learn
- Transfer learning concepts
- Feature extraction vs fine-tuning
- Domain adaptation
- Pre-trained model utilization


In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import sys
import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Add scripts directory to path
sys.path.append('../scripts')
from utils import get_device, set_seed

# Seed through the project helper before anything below draws random numbers
set_seed(42)

# Pick the compute device once; later cells move models and batches onto it
device = get_device()
print(f"Using device: {device}")

# CIFAR-10: convert to tensors, then shift/scale each channel with
# mean 0.5 and std 0.5
print("Loading CIFAR-10 dataset...")
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

train_dataset = datasets.CIFAR10(
    root='../data/cifar10',
    train=True,
    download=True,
    transform=transform,
)
test_dataset = datasets.CIFAR10(
    root='../data/cifar10',
    train=False,
    download=True,
    transform=transform,
)

# Keep only a small random slice of each split (sampled without replacement).
# The train draw happens before the test draw, so the RNG stream is consumed
# in a fixed order.
train_size, test_size = 1000, 200
train_indices = np.random.choice(
    len(train_dataset), size=train_size, replace=False
)
test_indices = np.random.choice(
    len(test_dataset), size=test_size, replace=False
)

train_subset = torch.utils.data.Subset(train_dataset, train_indices)
test_subset = torch.utils.data.Subset(test_dataset, test_indices)

# Shuffle only the training batches; evaluation order stays fixed
train_loader = DataLoader(
    train_subset,
    batch_size=32,
    shuffle=True,
)
test_loader = DataLoader(
    test_subset,
    batch_size=32,
    shuffle=False,
)

print(f"Training samples: {len(train_subset)}")
print(f"Test samples: {len(test_subset)}")


In [None]:
# Define models for transfer learning comparison
class TransferLearningCNN(nn.Module):
    """Small CNN with an optionally frozen convolutional feature extractor.

    ``features`` is three conv3x3 -> ReLU -> 2x2 max-pool stages
    (3 -> 32 -> 64 -> 128 channels). ``classifier`` is
    Linear(128*4*4, 256) -> ReLU -> Dropout(0.5) -> Linear(256, num_classes).
    The first linear layer's input size (128 * 4 * 4) matches 32x32 inputs,
    which the three pooling stages reduce to 4x4.

    Args:
        num_classes: Number of output logits.
        freeze_features: When True, every parameter in ``features`` has
            ``requires_grad`` switched off so only ``classifier`` is trained.
    """

    def __init__(self, num_classes=10, freeze_features=True):
        super().__init__()

        # Feature extractor (pretrained-like): layers are appended in the
        # same order as a hand-written Sequential, so indices 0/3/6 hold the
        # convolutions.
        stage_widths = ((3, 32), (32, 64), (64, 128))
        conv_layers = []
        for in_channels, out_channels in stage_widths:
            conv_layers.append(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
            )
            conv_layers.append(nn.ReLU())
            conv_layers.append(nn.MaxPool2d(kernel_size=2))
        self.features = nn.Sequential(*conv_layers)

        # Classifier head on the flattened 128 x 4 x 4 feature map
        head_layers = (
            nn.Linear(128 * 4 * 4, 256),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(256, num_classes),
        )
        self.classifier = nn.Sequential(*head_layers)

        # Freezing: turn off gradients for the whole conv stack in one call
        if freeze_features:
            self.features.requires_grad_(False)

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        feature_maps = self.features(x)
        # Collapse everything except the batch dimension
        flattened = feature_maps.view(feature_maps.size(0), -1)
        return self.classifier(flattened)

class FromScratchCNN(nn.Module):
    """Baseline CNN in which every layer is trainable.

    ``features`` is three conv3x3 -> ReLU -> 2x2 max-pool stages
    (3 -> 32 -> 64 -> 128 channels); ``classifier`` is
    Linear(128*4*4, 256) -> ReLU -> Dropout(0.5) -> Linear(256, num_classes).

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # Consecutive pairs of this tuple give (in, out) channels per stage
        widths = (3, 32, 64, 128)
        stack = []
        for c_in, c_out in zip(widths, widths[1:]):
            stack.extend(
                (
                    nn.Conv2d(c_in, c_out, 3, padding=1),
                    nn.ReLU(),
                    nn.MaxPool2d(2),
                )
            )
        self.features = nn.Sequential(*stack)

        # 128 channels x 4 x 4 spatial positions feed the first linear layer
        flat_dim = 128 * 4 * 4
        self.classifier = nn.Sequential(
            nn.Linear(flat_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        batch = x.size(0)
        # Flatten the conv output to (batch, -1) before the classifier
        return self.classifier(self.features(x).view(batch, -1))

# Training function
def train_model(model, train_loader, test_loader, epochs=20, lr=0.001, model_name="", device=None):
    """Train ``model`` with Adam and cross-entropy, evaluating after each epoch.

    Args:
        model: ``nn.Module`` that maps a batch of inputs to class logits.
        train_loader: Iterable of ``(data, target)`` batches used for the
            optimisation steps; it is iterated once per epoch.
        test_loader: Iterable of ``(data, target)`` batches used to measure
            accuracy after every epoch.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.
        model_name: Label prefixed to the progress lines.
        device: Device the batches are moved to. When omitted it is taken
            from the model's first parameter (CPU if the model has none), so
            the function does not depend on a module-level ``device`` global.

    Returns:
        ``(train_losses, test_accuracies)``: two lists with one entry per
        epoch. Loss is the mean of the per-batch losses; accuracy is a
        percentage. An epoch with no training batches or no test samples
        records ``nan`` instead of raising ``ZeroDivisionError``.
    """
    if device is None:
        # Batches must live where the weights live, so follow the model.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    train_losses = []
    test_accuracies = []

    for epoch in range(epochs):
        # Training
        model.train()
        running_loss = 0.0
        # Count batches while iterating instead of calling len(), so loaders
        # without __len__ work and an empty loader is detectable.
        num_batches = 0
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1

        train_loss = running_loss / num_batches if num_batches else float("nan")
        train_losses.append(train_loss)

        # Test accuracy (dropout disabled, no gradient tracking)
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                pred = output.argmax(dim=1)
                correct += pred.eq(target).sum().item()
                total += target.size(0)

        accuracy = 100. * correct / total if total else float("nan")
        test_accuracies.append(accuracy)

        # Report every fifth epoch to keep the notebook output short
        if (epoch + 1) % 5 == 0:
            print(f'{model_name} - Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Test Acc: {accuracy:.2f}%')

    return train_losses, test_accuracies

# Train the three variants one after another on the same loaders.
# NOTE(review): no pre-trained weights are loaded in this notebook, so the
# "frozen" run trains only the classifier on top of randomly initialized
# convolutional features.

# 1) Frozen feature extractor
print("Training Transfer Learning Model (Frozen Features)...")
transfer_model = TransferLearningCNN(freeze_features=True)
transfer_model.to(device)
transfer_losses, transfer_accuracies = train_model(
    model=transfer_model,
    train_loader=train_loader,
    test_loader=test_loader,
    model_name="Transfer",
)

# 2) Same architecture with every layer trainable
print("\nTraining Transfer Learning Model (Fine-tuned)...")
finetune_model = TransferLearningCNN(freeze_features=False)
finetune_model.to(device)
finetune_losses, finetune_accuracies = train_model(
    model=finetune_model,
    train_loader=train_loader,
    test_loader=test_loader,
    model_name="Fine-tune",
)

# 3) Baseline network
print("\nTraining From Scratch Model...")
scratch_model = FromScratchCNN()
scratch_model.to(device)
scratch_losses, scratch_accuracies = train_model(
    model=scratch_model,
    train_loader=train_loader,
    test_loader=test_loader,
    model_name="From Scratch",
)

# Plot results: a 2x2 grid comparing the three training runs
plt.figure(figsize=(15, 10))

# Top-left: per-epoch training loss
plt.subplot(2, 2, 1)
plt.plot(transfer_losses, label='Transfer Learning (Frozen)', linewidth=2)
plt.plot(finetune_losses, label='Transfer Learning (Fine-tuned)', linewidth=2)
plt.plot(scratch_losses, label='From Scratch', linewidth=2)
plt.title('Training Loss Comparison')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)

# Top-right: per-epoch test accuracy
plt.subplot(2, 2, 2)
plt.plot(transfer_accuracies, label='Transfer Learning (Frozen)', linewidth=2)
plt.plot(finetune_accuracies, label='Transfer Learning (Fine-tuned)', linewidth=2)
plt.plot(scratch_accuracies, label='From Scratch', linewidth=2)
plt.title('Test Accuracy Comparison')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()
plt.grid(True)

plt.subplot(2, 2, 3)
# Final accuracies (last epoch of each run)
models = ['Transfer (Frozen)', 'Transfer (Fine-tuned)', 'From Scratch']
final_accuracies = [transfer_accuracies[-1], finetune_accuracies[-1], scratch_accuracies[-1]]
plt.bar(models, final_accuracies)
plt.title('Final Test Accuracy')
plt.ylabel('Accuracy (%)')
plt.xticks(rotation=45)
plt.grid(True)

plt.subplot(2, 2, 4)
# Training efficiency: final test accuracy divided by final training loss
efficiency = [acc / loss for acc, loss in zip(final_accuracies, [transfer_losses[-1], finetune_losses[-1], scratch_losses[-1]])]
plt.bar(models, efficiency)
plt.title('Training Efficiency (Accuracy/Loss)')
plt.ylabel('Efficiency')
plt.xticks(rotation=45)
plt.grid(True)

plt.tight_layout()
# Create the output directory first: savefig does not create missing parent
# folders, and failing here would happen only after all three models have
# already been trained.
plot_path = '../results/plots/transfer_learning_comparison.png'
os.makedirs(os.path.dirname(plot_path), exist_ok=True)
plt.savefig(plot_path, dpi=300, bbox_inches='tight')
plt.show()

# Print summary: final-epoch test accuracy of each run
print("\nTransfer Learning Summary:")
print(f"Transfer Learning (Frozen): {transfer_accuracies[-1]:.2f}%")
print(f"Transfer Learning (Fine-tuned): {finetune_accuracies[-1]:.2f}%")
print(f"From Scratch: {scratch_accuracies[-1]:.2f}%")

# Save models (weights only, via state_dict).
# torch.save does not create missing parent folders, so make sure the target
# directory exists before writing the first checkpoint.
os.makedirs('../results/logs', exist_ok=True)
torch.save(transfer_model.state_dict(), '../results/logs/transfer_model.pth')
torch.save(finetune_model.state_dict(), '../results/logs/finetune_model.pth')
torch.save(scratch_model.state_dict(), '../results/logs/scratch_model.pth')

print("\nModels saved successfully!")
