In [1]:
import numpy as np
import pandas as pd
from glob import glob
from os.path import join
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from torch.nn.parallel import DataParallel
from torchvision.models import ResNet50_Weights

import torch
import torch.nn as nn
import torchvision
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, RandomHorizontalFlip, RandomRotation
import torch.optim as optim
from torch.optim import lr_scheduler
from sklearn.metrics import mean_absolute_error

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Data Augmentation
class AgeDataset(torch.utils.data.Dataset):

    def __init__(self, data_path, annot_path, train=True):
        super(AgeDataset, self).__init__()

        self.annot_path = annot_path
        self.data_path = data_path
        self.train = train

        self.ann = pd.read_csv(annot_path)
        self.files = self.ann['file_id']
        if train:
            self.ages = self.ann['age']
        
        self.transform = self._transform(224)

    @staticmethod    
    def _convert_image_to_rgb(image):
        return image.convert("RGB")

    def _transform(self, n_px):
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        return Compose([
            Resize(n_px),
            RandomHorizontalFlip(),  
            RandomRotation(10),     
            self._convert_image_to_rgb,
            ToTensor(),
            Normalize(mean, std),
        ])

    def read_img(self, file_name):
        im_path = join(self.data_path,file_name)   
        img = Image.open(im_path)
        img = self.transform(img)
        return img

    def __getitem__(self, index):
        file_name = self.files[index]
        img = self.read_img(file_name)
        if self.train:
            age = self.ages[index]
            return img, age
        else:
            return img

    def __len__(self):
        return len(self.files)

# Paths
train_path = '/kaggle/input/smai-24-age-prediction/content/faces_dataset/train'
train_ann = '/kaggle/input/smai-24-age-prediction/content/faces_dataset/train.csv'
test_path = '/kaggle/input/smai-24-age-prediction/content/faces_dataset/test'
test_ann = '/kaggle/input/smai-24-age-prediction/content/faces_dataset/submission.csv'

# Datasets
train_dataset = AgeDataset(train_path, train_ann, train=True)
test_dataset = AgeDataset(test_path, test_ann, train=False)

# Data Loaders
train_indices, val_indices = train_test_split(range(len(train_dataset)), test_size=0.1, random_state=42, shuffle=True)
train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
val_sampler = torch.utils.data.SubsetRandomSampler(val_indices)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, sampler=train_sampler)
val_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, sampler=val_sampler)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

# Model with Transfer Learning and Fine-tuning
class AgePredictor(nn.Module):
    def __init__(self, model_name="resnet50", pretrained=True):
        super(AgePredictor, self).__init__()
        if model_name == "resnet34":
            self.model = torchvision.models.resnet34(pretrained=pretrained)
        elif model_name == "resnet50":
#             self.model = torchvision.models.resnet50(pretrained=pretrained)
            self.model = torchvision.models.resnet50(weights=ResNet50_Weights.DEFAULT)

        else:
            raise ValueError("Invalid model name. Choose 'resnet18' or 'resnet34'")
        
        for param in self.model.parameters():
            param.requires_grad = False
        
        num_ftrs = self.model.fc.in_features
        self.model.fc = nn.Sequential(
            nn.Linear(num_ftrs, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 1)
        )
        
        # Wrap the model with DataParallel
        if torch.cuda.device_count() > 1:
            print(f"Using {torch.cuda.device_count()} GPUs!")
            self.model = nn.DataParallel(self.model)

    def forward(self, x):
        return self.model(x)

model = AgePredictor().to(device)

def test_loss(model, test_loader, criterion):
    model.eval()
    total_loss = 0
    total_samples = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs.flatten(), labels.float())
            total_loss += loss.item() * inputs.size(0)
            total_samples += inputs.size(0)
    return total_loss / total_samples

def train(model, train_loader, val_loader, optimizer, criterion, epochs=25, patience=5):
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, verbose=True)
    best_mae = float('inf')
    no_improve_epochs = 0
    
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs.flatten(), labels.float()) 
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        
        val_mae = test_loss(model, val_loader, criterion)
        
        scheduler.step(val_mae)
        
        print(f'Epoch {epoch+1}/{epochs}, Train Loss: {running_loss/len(train_loader):.4f}, Val MAE: {val_mae:.4f}')
        
        if val_mae < best_mae:
            best_mae = val_mae
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
                }, 'best_model.pth')
            no_improve_epochs = 0
        else:
            no_improve_epochs += 1
        
        if no_improve_epochs >= patience:
            print(f'Early stopping after {patience} epochs of no improvement.')
            break
            
    print(f'Finished Training. Best Validation MAE: {best_mae:.4f}')

def evaluate(model, data_loader):
    model.eval()
    total_loss = 0
    total_samples = 0
    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = torch.nn.L1Loss()(outputs.flatten(), labels.float())
            total_loss += loss.item() * inputs.size(0)
            total_samples += inputs.size(0)
    return total_loss / total_samples

def predict(loader, model):
    model.eval()
    predictions = []
    with torch.no_grad():
        for inputs in tqdm(loader):
            inputs = inputs.to(device)
            pred = model(inputs)
            predictions.extend(pred.flatten().cpu().detach().numpy().tolist())
    return predictions

criterion = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

train(model, train_loader, val_loader, optimizer, criterion, epochs=25, patience=5)

# Load the best model
checkpoint = torch.load('best_model.pth')
model.load_state_dict(checkpoint['model_state_dict'])

# Make predictions on the test set
preds = predict(test_loader, model)

# Calculate MAE on test set
test_labels = pd.read_csv(test_ann)['age'].values
test_mae = mean_absolute_error(test_labels, preds)
print(f'Test MAE: {test_mae:.4f}')

# Save predictions to a CSV file
submit = pd.read_csv(test_ann)
submit['age'] = preds
submit.to_csv('submission.csv', index=False)


Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 163MB/s]


Using 2 GPUs!
Epoch 1/25, Train Loss: 18.6430, Val MAE: 14.7736
Epoch 2/25, Train Loss: 14.2433, Val MAE: 13.3796
Epoch 3/25, Train Loss: 13.1826, Val MAE: 12.6083
Epoch 4/25, Train Loss: 12.3185, Val MAE: 11.5019
Epoch 5/25, Train Loss: 11.6804, Val MAE: 11.1777
Epoch 6/25, Train Loss: 11.4136, Val MAE: 10.8381
Epoch 7/25, Train Loss: 11.1687, Val MAE: 11.1522
Epoch 8/25, Train Loss: 11.1114, Val MAE: 10.5965
Epoch 9/25, Train Loss: 10.8922, Val MAE: 10.4471
Epoch 10/25, Train Loss: 10.7854, Val MAE: 10.3827
Epoch 11/25, Train Loss: 10.5876, Val MAE: 10.4239
Epoch 12/25, Train Loss: 10.5749, Val MAE: 10.3782
Epoch 13/25, Train Loss: 10.4411, Val MAE: 10.0837
Epoch 14/25, Train Loss: 10.3560, Val MAE: 10.1082
Epoch 15/25, Train Loss: 10.3169, Val MAE: 10.0059
Epoch 16/25, Train Loss: 10.3020, Val MAE: 10.0258
Epoch 17/25, Train Loss: 10.2608, Val MAE: 9.8412
Epoch 18/25, Train Loss: 10.1521, Val MAE: 10.0290
Epoch 19/25, Train Loss: 10.1237, Val MAE: 10.0921
Epoch 20/25, Train Loss: 10

100%|██████████| 31/31 [00:18<00:00,  1.68it/s]

Test MAE: 11.6730



