### 1. Prepare Dataset

In [1]:
import random
import torch

# Pick the compute device. Apple-Silicon MPS stays the first choice, with CUDA
# and then CPU as fallbacks so the notebook also runs on machines that have no
# MPS backend (where moving a model to "mps" would fail).
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Seed both PyTorch's and Python's global RNGs for repeatable runs.
torch.manual_seed(42)
random.seed(42)

In [None]:
# Load and examine the CIFAR-10 dataset
import matplotlib.pyplot as plt
from torchvision.datasets import CIFAR10

# Fetch both CIFAR-10 splits into ./data (downloading if necessary). No
# transform is attached at this point.
train_set = CIFAR10(root="./data", train=True, download=True, transform=None)
test_set = CIFAR10(root="./data", train=False, download=True, transform=None)

# Class-name lookup plus the first four training images and their labels.
classes = train_set.classes
imgs, labels = train_set.data[:4], train_set.targets[:4]

# Show the four samples side by side, each titled with its class name.
fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(6, 6))
for ax, img, lbl in zip(axes.flatten(), imgs, labels):
    ax.imshow(img)
    ax.set_title(classes[lbl])
    ax.axis("off")

plt.tight_layout()
plt.show()

In [3]:
# Transform the dataset
import torchvision.transforms as T
from torch.utils.data import DataLoader

batch_size = 64

# Preprocessing pipeline shared by the training and test splits.
transform_data = T.Compose(
    [
        T.Resize((299, 299)),  # Inception-v4 spatial dimension
        T.ToTensor(),
        # ImageNet statistics
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Attach the pipeline to the already-constructed datasets.
train_set.transform = transform_data
test_set.transform = transform_data

# Shuffle only the training data. Evaluation gains nothing from shuffling, and
# a fixed order keeps the evaluation batches reproducible from run to run.
train_loader = DataLoader(train_set, batch_size, shuffle=True, num_workers=8)
test_loader = DataLoader(test_set, batch_size, shuffle=False, num_workers=8)

### 2. Load the Inception-v4 model

In [None]:
import timm

# Pre-trained Inception-v4 with its default classifier head, placed on the
# selected device; printing it shows the layer structure.
model_ori = timm.create_model("inception_v4", pretrained=True).to(device)
print(model_ori)

In [None]:
# Build a second pre-trained Inception-v4, this time asking timm for a
# 10-output classifier to match the ten CIFAR-10 classes.
model_modified = timm.create_model(
    "inception_v4",
    pretrained=True,
    num_classes=10,
).to(device)
print(model_modified)

### 3. Training & Validation

In [6]:
from tqdm import tqdm

# Define training procedure
def training(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Network to train; it is switched to training mode.
        train_loader: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.
            The per-batch value is weighted by the batch size, which assumes
            the criterion returns a batch mean -- confirm for custom losses.
        optimizer: Optimizer whose ``zero_grad``/``step`` drive the update.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(train_loss, train_accuracy)``: the mean loss per processed sample
        and the accuracy in percent. Both are ``0.0`` when the loader yields
        no samples.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    train_bar = tqdm(train_loader, desc="Training")
    for inputs, targets in train_bar:
        inputs, targets = inputs.to(device), targets.to(device)

        # Zero gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        # Track statistics (loss is re-weighted by the batch size so that a
        # smaller final batch does not skew the epoch average)
        running_loss += loss.item() * inputs.size(0)
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        # Update progress bar
        train_bar.set_postfix({
            'loss': running_loss/total,
            'acc': 100.*correct/total
        })

    # Nothing was processed: report zeros instead of dividing by zero.
    if total == 0:
        return 0.0, 0.0

    # Average over the samples actually seen. len(train_loader.dataset) would
    # be the wrong denominator when the loader drops the last batch or samples
    # only part of the dataset, and it would disagree with the accuracy below.
    train_loss = running_loss / total
    train_accuracy = 100. * correct / total

    return train_loss, train_accuracy


# Define validation procedure
def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without updating any weights.

    Args:
        model: Network to evaluate; it is switched to evaluation mode.
        val_loader: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.
            The per-batch value is weighted by the batch size, which assumes
            the criterion returns a batch mean -- confirm for custom losses.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(val_loss, val_accuracy)``: the mean loss per processed sample and
        the accuracy in percent. Both are ``0.0`` when the loader yields no
        samples.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    # Gradients are not needed for evaluation.
    with torch.no_grad():
        val_bar = tqdm(val_loader, desc="Validation")
        for inputs, targets in val_bar:
            inputs, targets = inputs.to(device), targets.to(device)

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # Track statistics (loss is re-weighted by the batch size so that
            # a smaller final batch does not skew the average)
            running_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            # Update progress bar
            val_bar.set_postfix({
                'loss': running_loss/total,
                'acc': 100.*correct/total
            })

    # Nothing was processed: report zeros instead of dividing by zero.
    if total == 0:
        return 0.0, 0.0

    # Average over the samples actually seen. len(val_loader.dataset) would be
    # the wrong denominator when the loader drops the last batch or samples
    # only part of the dataset, and it would disagree with the accuracy below.
    val_loss = running_loss / total
    val_accuracy = 100. * correct / total

    return val_loss, val_accuracy

In [None]:
import torch.nn as nn
import torch.optim as optim

# Loss function used for this evaluation (and available to later cells).
criterion = nn.CrossEntropyLoss()

# Baseline: score the pre-trained Inception-v4 as-is, before any fine-tuning.
# NOTE(review): model_ori was created without num_classes=10, so its output
# indices presumably do not correspond to CIFAR-10 labels -- confirm before
# reading much into this accuracy.
print("Evaluating original Inception v4 on CIFAR-10...")
original_val_loss, original_val_accuracy = validate(
    model=model_ori,
    val_loader=test_loader,
    criterion=criterion,
    device=device,
)
print(f"Original model - Val Loss: {original_val_loss:.4f}, Val Acc: {original_val_accuracy:.2f}%")

In [None]:
# Fine-tune the 10-class model with SGD

# Hyperparameters for the run
epochs, lr = 2, 5e-3
momentum, weight_decay = 0.9, 1e-4

optimizer = optim.SGD(
    params=model_modified.parameters(),
    lr=lr,
    momentum=momentum,
    weight_decay=weight_decay,
)

# Per-epoch metric history, kept for later analysis
training_losses, training_accuracies = [], []
validation_losses, validation_accuracies = [], []

for epoch in range(epochs):
    print(f"Epoch {epoch+1}/{epochs}")

    # Optimise over the training loader and record the epoch's metrics
    train_loss, train_accuracy = training(
        model=model_modified,
        train_loader=train_loader,
        criterion=criterion,
        optimizer=optimizer,
        device=device,
    )
    training_losses.append(train_loss)
    training_accuracies.append(train_accuracy)

    # Score the updated weights on the test loader and record the metrics
    val_loss, val_accuracy = validate(
        model=model_modified,
        val_loader=test_loader,
        criterion=criterion,
        device=device,
    )
    validation_losses.append(val_loss)
    validation_accuracies.append(val_accuracy)

    # Epoch summary
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.2f}%")
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.2f}%")
    print("-" * 50)

In [None]:
# Write the fine-tuned model's state_dict to disk.
fine_tuned_state = model_modified.state_dict()
torch.save(fine_tuned_state, 'inception_v4_cifar10.pth')
print("Model saved to inception_v4_cifar10.pth")