# Age recognition 

**Authors**: Richard Šléher, Tomáš Majerník

**Dataset**: https://www.kaggle.com/datasets/arashnic/faces-age-detection-dataset/code?select=train.csv

# TODO
- Add experiment tracking with Weights & Biases (wandb)
- Run hyperparameter tuning (e.g. a wandb sweep)

In [None]:
import os
import matplotlib.pyplot as plt
import pandas as pd
from summarytools import dfSummary
import numpy as np
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torch.optim as optim
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
import torch.nn as nn
import torch.nn.functional as F


Hyperparameters

In [None]:
# Side length (in pixels) that every image is resized to before it is fed to the network.
# NOTE: the CNN's first fully connected layer is sized for 128x128 inputs
# (16 * 30 * 30 features), so changing this value requires updating that layer too.
IMAGE_SIZE = 128

In [None]:
# Run on the GPU whenever CUDA is usable; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f'Using {device} device')

## EDA

In [None]:
# Load the training annotations. Later cells read the 'ID' column (image file
# name under data/train/) and the 'Class' column (age-group label) from this frame.
data = pd.read_csv('data/train.csv')

In [None]:
# Overview of the raw annotation table via summarytools (shown as the cell output).
dfSummary(data)

In [None]:
# Preview the first nine annotated images in a 3x3 grid, each titled with its class.
fig = plt.figure()

for slot in range(1, 10):
    plt.subplot(3, 3, slot)
    sample = data.iloc[slot - 1]
    img = plt.imread('data/train/' + sample['ID'])
    plt.imshow(img)
    plt.title(sample['Class'])
    plt.axis('off')

plt.show()

In [None]:
# Balanced training split: draw 2000 rows per class. The sampled rows keep
# their ORIGINAL index labels here so that exactly those rows can be removed
# from `data` below. (Resetting the index before the drop would remove rows
# 0..N-1 instead, leaking training images into the validation split.)
train_rows = data.groupby('Class').sample(n=2000)

# Everything that was not sampled for training becomes the validation split.
remaining_data = data.drop(index=train_rows.index)

# Only now is it safe to discard the original index of the training split.
sampled_data = train_rows.reset_index(drop=True)

print(f"Training data size: {sampled_data.shape}")
print(f"Validation data size: {remaining_data.shape}")

In [None]:
# Summary of the training split (shown as the cell output).
dfSummary(sampled_data)

In [None]:
# Summary of the validation split (shown as the cell output).
dfSummary(remaining_data)

In [None]:
class AgeDataset(Dataset):
    """Dataset of face images paired with integer age-group labels.

    Args:
        dataframe: Table whose first column holds the image file name and
            whose second column holds the class name ('YOUNG', 'MIDDLE' or
            'OLD'). Rows are addressed by position, so the index is ignored.
        transform: Optional callable applied to the loaded RGB PIL image.
        image_dir: Directory the file names are relative to. Defaults to
            'data/train/'.
    """

    def __init__(self, dataframe, transform=None, image_dir='data/train/'):
        self.dataframe = dataframe
        self.transform = transform
        self.image_dir = image_dir
        # Class name -> integer target expected by the loss function.
        self.label_mapping = {'YOUNG': 0, 'MIDDLE': 1, 'OLD': 2}

    def __len__(self):
        """Return the number of samples (rows of the dataframe)."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the row at position ``idx``.

        Raises:
            KeyError: If the row's class name is not in ``label_mapping``.
        """
        img_name = os.path.join(self.image_dir, self.dataframe.iloc[idx, 0])
        # convert() returns a fully loaded copy, so the file handle can be
        # released as soon as the with-block exits.
        with Image.open(img_name) as raw_image:
            image = raw_image.convert('RGB')

        label = self.label_mapping[self.dataframe.iloc[idx, 1]]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

# Training pipeline: resize plus light random augmentation.
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.RandomPerspective(distortion_scale=0.2, p=0.1),
    transforms.RandomGrayscale(p=0.1),
    transforms.ToTensor(),
])

# Validation pipeline: deterministic preprocessing only. Random augmentation
# at evaluation time would make the reported metrics vary from run to run.
val_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
])

train_dataset = AgeDataset(dataframe=sampled_data, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

val_dataset = AgeDataset(dataframe=remaining_data, transform=val_transform)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)


## Model

In [None]:
class CNN(nn.Module):
    """LeNet-style classifier for 128x128 images.

    Two convolution + max-pool stages are followed by three fully connected
    layers, with dropout after each of the first two. ``forward`` returns raw
    (unnormalised) class scores.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        # Feature extractor; shapes below assume a 128x128 input.
        self.conv1 = nn.Conv2d(in_channels, 6, kernel_size=5, stride=1, padding=2)  # -> 6 x 128 x 128
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)  # -> 6 x 64 x 64
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0)  # -> 16 x 60 x 60
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)  # -> 16 x 30 x 30
        # Classifier head.
        self.fc1 = nn.Linear(16 * 30 * 30, 120)
        self.dropout1 = nn.Dropout(0.3)
        self.fc2 = nn.Linear(120, 84)
        self.dropout2 = nn.Dropout(0.3)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        features = self.pool1(F.relu(self.conv1(x)))
        features = self.pool2(F.relu(self.conv2(features)))
        # Flatten to (batch, 16 * 30 * 30), the input width of fc1.
        flat = features.view(-1, self.fc1.in_features)
        hidden = self.dropout1(F.relu(self.fc1(flat)))
        hidden = self.dropout2(F.relu(self.fc2(hidden)))
        return self.fc3(hidden)

In [None]:
# Three input channels (images are loaded as RGB) and three output classes
# (YOUNG, MIDDLE, OLD).
model = CNN(in_channels=3, num_classes=3)

In [None]:
def evaluate(model, val_loader, criterion):
    """Evaluate ``model`` on every batch of ``val_loader``.

    The model is switched to eval mode and left in it; callers that continue
    training must call ``model.train()`` again.

    Args:
        model: Network returning per-class scores of shape (batch, classes).
        val_loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.

    Returns:
        Tuple ``(mean_batch_loss, accuracy, precision, recall, f1)``.
        Accuracy is a fraction in [0, 1]; precision, recall and F1 are
        support-weighted averages over the classes.

    Raises:
        ValueError: If ``val_loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    all_labels = []
    all_predictions = []

    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            total_loss += criterion(outputs, labels).item()
            num_batches += 1
            # The predicted class is the index of the highest score.
            predicted = outputs.argmax(dim=1)
            all_labels.extend(labels.cpu().tolist())
            all_predictions.extend(predicted.cpu().tolist())

    if not all_labels:
        raise ValueError("val_loader yielded no samples; cannot compute metrics")

    accuracy = accuracy_score(all_labels, all_predictions)
    # zero_division=0 keeps the score at 0 for classes that were never
    # predicted, without emitting an undefined-metric warning every epoch.
    precision = precision_score(all_labels, all_predictions, average='weighted', zero_division=0)
    recall = recall_score(all_labels, all_predictions, average='weighted', zero_division=0)
    f1 = f1_score(all_labels, all_predictions, average='weighted', zero_division=0)

    return total_loss / num_batches, accuracy, precision, recall, f1

In [None]:
# Move the model to the target device before building the optimizer, as the
# PyTorch optimizer documentation recommends.
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Per-epoch history, consumed by the plotting cells below.
train_losses = []
val_losses = []
val_accuracies = []  # fractions in [0, 1]
num_epochs = 100

# Training loop
for epoch in range(num_epochs):
    model.train()  # evaluate() leaves the model in eval mode
    running_loss = 0.0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    train_loss = running_loss / len(train_loader)
    # evaluate() returns five values; unpacking only four raises ValueError.
    val_loss, val_accuracy, precision, recall, f1 = evaluate(model, val_loader, criterion)

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    val_accuracies.append(val_accuracy)

    # Accuracy is a fraction, so scale it before printing it with a '%' sign.
    print(
        f'Epoch {epoch+1}/{num_epochs}, Loss: {train_loss:.3f}, Val Loss: {val_loss:.3f}, '
        f'Val Accuracy: {val_accuracy * 100:.2f}%, Precision: {precision:.3f}, '
        f'Recall: {recall:.3f}, F1: {f1:.3f}'
    )

In [None]:
# Training vs. validation loss, one point per epoch
plt.figure(figsize=(10, 5))
for loss_history, caption in ((train_losses, 'Training Loss'), (val_losses, 'Validation Loss')):
    plt.plot(loss_history, label=caption)
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()

In [None]:
# Plotting the validation accuracy
plt.figure(figsize=(10, 5))
# val_accuracies holds fractions in [0, 1]; scale them to match the '%' axis label.
plt.plot([acc * 100 for acc in val_accuracies], label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()
plt.title('Validation Accuracy')
plt.show()

In [None]:
# Persist only the learned weights (the state_dict), not the model object itself;
# to reload them, build a CNN with the same architecture and call load_state_dict.
torch.save(model.state_dict(), 'age_recognition_model.pth')