# Fine-tuning and validation of the ViT-Base model

In [35]:
# Import necessary libraries
import modal
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from PIL import Image
import io
from transformers import ViTForImageClassification
from torchvision import transforms
import matplotlib.pyplot as plt
from tqdm import tqdm

# Modal setup
# Modal application under which the remote training function is registered
stub = modal.App("vit-fairface-training")

# Python packages installed into the container image
_PIP_PACKAGES = (
    "torch",
    "torchvision",
    "transformers",
    "pandas",
    "pillow",
    "tqdm",
    "pyarrow",
    "accelerate",
)

# Container image: Debian slim base with the packages above pip-installed
image = modal.Image.debian_slim().pip_install(*_PIP_PACKAGES)

# Named volumes (created when missing): one is mounted for the dataset,
# the other for the saved model weights
volume = modal.Volume.from_name("fairface-data", create_if_missing=True)
model_volume = modal.Volume.from_name("vit-models", create_if_missing=True)


In [36]:
# Dataset class for FairFace
class FairFaceDataset(Dataset):
    """Map-style dataset backed by a FairFace parquet file.

    Every row must provide an ``image`` entry whose ``bytes`` key holds the
    encoded picture, and a ``race`` entry that is returned as the label.
    """

    def __init__(self, parquet_file, transform=None):
        """Load ``parquet_file`` into memory and pick the preprocessing pipeline.

        Args:
            parquet_file: Path of the parquet file to read with pandas.
            transform: Callable applied to each decoded RGB image. When it is
                omitted (or falsy) a default pipeline is used: resize to
                224x224, convert to a tensor, then normalise each channel
                with a fixed mean and standard deviation.
        """
        self.data = pd.read_parquet(parquet_file)
        if not transform:
            transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the loaded table."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the row at position ``idx``."""
        record = self.data.iloc[idx]
        # The image column stores a dict; the encoded picture is under 'bytes'
        encoded = record['image']['bytes']
        picture = Image.open(io.BytesIO(encoded)).convert('RGB')
        target = record['race']  # Adjust if your label column is different

        if self.transform:
            picture = self.transform(picture)
        return picture, target

def get_dataloaders(batch_size=32, data_dir="/root/data", num_workers=4):
    """Build the training and validation data loaders.

    Args:
        batch_size: Number of samples per batch for both loaders.
        data_dir: Directory containing ``train.parquet`` and
            ``validation.parquet``. Defaults to the path this file mounts the
            data volume at.
        num_workers: Number of worker processes used by each loader.

    Returns:
        ``(train_loader, val_loader)``. Only the training loader shuffles.
    """
    # Tolerate a trailing slash so the joined paths stay clean
    root = data_dir.rstrip("/")
    train_dataset = FairFaceDataset(f"{root}/train.parquet")
    val_dataset = FairFaceDataset(f"{root}/validation.parquet")

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )

    return train_loader, val_loader

In [37]:
# Model definition
def create_vit_model(num_classes):
    """Load the pretrained ViT-Base checkpoint with a ``num_classes``-way head.

    Args:
        num_classes: Number of output labels for the classification head.

    Returns:
        The ``ViTForImageClassification`` instance returned by
        ``from_pretrained``.
    """
    checkpoint = 'google/vit-base-patch16-224'
    # ignore_mismatched_sizes is passed so that loading does not fail when the
    # requested head size differs from the one stored in the checkpoint
    return ViTForImageClassification.from_pretrained(
        checkpoint,
        num_labels=num_classes,
        ignore_mismatched_sizes=True,
    )


from accelerate import Accelerator

def train_epoch(model, train_loader, criterion, optimizer, accelerator, print_every=10):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Callable whose result exposes the class scores as ``.logits``.
        train_loader: Sized iterable yielding ``(images, labels)`` batches.
        criterion: Loss function called as ``criterion(logits, labels)``.
        optimizer: Optimizer stepped once per batch.
        accelerator: Object providing ``backward(loss)`` and ``print(...)``.
        print_every: Log running metrics every this many batches. A value of
            ``0`` or ``None`` disables the periodic logs; the last batch is
            always logged.

    Returns:
        ``(mean_loss, accuracy)``: the mean of the per-batch losses and the
        accuracy in percent over all samples seen. ``(0.0, 0.0)`` when the
        loader has no batches.
    """
    model.train()
    num_batches = len(train_loader)
    if num_batches == 0:
        # Nothing to train on; returning early avoids dividing by zero below
        return 0.0, 0.0

    total_loss = 0
    correct = 0
    total = 0

    for batch_idx, (images, labels) in enumerate(train_loader):
        outputs = model(images).logits
        loss = criterion(outputs, labels)
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()

        total_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        step = batch_idx + 1
        # Guard the modulo so print_every=0/None does not raise
        at_interval = bool(print_every) and print_every > 0 and step % print_every == 0
        if at_interval or step == num_batches:
            accelerator.print(f"Batch {step}/{num_batches} - "
                              f"Train Loss: {total_loss/step:.4f}, "
                              f"Train Acc: {100.*correct/max(total, 1):.2f}%")

    return total_loss / num_batches, 100. * correct / max(total, 1)

def validate(model, val_loader, criterion, accelerator, print_every=10):
    """Evaluate ``model`` on ``val_loader`` without computing gradients.

    Args:
        model: Callable whose result exposes the class scores as ``.logits``.
        val_loader: Sized iterable yielding ``(images, labels)`` batches.
        criterion: Loss function called as ``criterion(logits, labels)``.
        accelerator: Object providing ``print(...)`` for progress logs.
        print_every: Log running metrics every this many batches. A value of
            ``0`` or ``None`` disables the periodic logs; the last batch is
            always logged.

    Returns:
        ``(mean_loss, accuracy)``: the mean of the per-batch losses and the
        accuracy in percent over all samples seen. ``(0.0, 0.0)`` when the
        loader has no batches.
    """
    model.eval()
    num_batches = len(val_loader)
    if num_batches == 0:
        # Nothing to evaluate; returning early avoids dividing by zero below
        return 0.0, 0.0

    total_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch_idx, (images, labels) in enumerate(val_loader):
            outputs = model(images).logits
            loss = criterion(outputs, labels)

            total_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            step = batch_idx + 1
            # Guard the modulo so print_every=0/None does not raise
            at_interval = bool(print_every) and print_every > 0 and step % print_every == 0
            if at_interval or step == num_batches:
                accelerator.print(f"Batch {step}/{num_batches} - "
                                  f"Val Loss: {total_loss/step:.4f}, "
                                  f"Val Acc: {100.*correct/max(total, 1):.2f}%")

    return total_loss / num_batches, 100. * correct / max(total, 1)

In [42]:
@stub.function(
    image=image,
    gpu="A100",  # Requests a single A100 (Modal's "A100:4" form would ask for four)
    volumes={"/root/data": volume, "/root/models": model_volume},
    timeout=14400  # 4 hours, expressed in seconds
)
def train_model(num_epochs=10, batch_size=32, learning_rate=2e-5):
    """Fine-tune the ViT classifier remotely and save its weights.

    Each epoch runs one training pass and one validation pass. The weights
    are written to ``/root/models/vit_fairface_best.pth`` whenever the
    validation accuracy improves, and to
    ``/root/models/vit_fairface_final.pth`` after the last epoch.

    Args:
        num_epochs: Number of training epochs.
        batch_size: Batch size passed to ``get_dataloaders``.
        learning_rate: Learning rate for the AdamW optimizer.

    Returns:
        The best validation accuracy (in percent) seen across all epochs.
    """
    from accelerate import Accelerator
    accelerator = Accelerator()

    train_loader, val_loader = get_dataloaders(batch_size)
    num_classes = 7  # size of the classification head
    model = create_vit_model(num_classes)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)

    # Let Accelerate handle device placement (and distribution, if any)
    model, optimizer, train_loader, val_loader = accelerator.prepare(
        model, optimizer, train_loader, val_loader
    )

    best_val_acc = 0
    for epoch in range(num_epochs):
        accelerator.print(f"\nEpoch {epoch+1}/{num_epochs}")

        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, accelerator)
        accelerator.print(f"Training Loss: {train_loss:.4f}, Training Accuracy: {train_acc:.2f}%")

        val_loss, val_acc = validate(model, val_loader, criterion, accelerator)
        accelerator.print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.2f}%")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # Save the unwrapped model so checkpoint keys carry no wrapper
            # prefix if `prepare` wrapped the model
            accelerator.save(
                accelerator.unwrap_model(model).state_dict(),
                "/root/models/vit_fairface_best.pth",
            )
            accelerator.print(f"Saved new best model with validation accuracy: {val_acc:.2f}%")

    accelerator.save(
        accelerator.unwrap_model(model).state_dict(),
        "/root/models/vit_fairface_final.pth",
    )
    accelerator.print("\nTraining completed!")
    return best_val_acc

In [44]:
# Run the Modal app and invoke the training function remotely. The remote
# call must stay inside the `stub.run()` context.
with stub.run():
    # batch_size=16 overrides train_model's default of 32; the call returns
    # the best validation accuracy reached during training
    best_acc = train_model.remote(num_epochs=10, batch_size=16)
    print(f"Best validation accuracy: {best_acc:.2f}%")

Best validation accuracy: 72.06%
