In [1]:
from torchvision import transforms

def imshow(image):
    mean=torch.tensor([0.485, 0.456, 0.406])
    std=torch.tensor([0.229, 0.224, 0.225])

    #normalize = transforms.Normalize(mean.tolist(), std.) 

    unnormalize = transforms.Normalize((-mean / std).tolist(), (1.0 / std).tolist())
    img_unn = unnormalize(image)
    plt.imshow(img_unn.permute(1, 2, 0))
    plt.show()

<h1> Hyperparamters </h1>
<h2> We will use the following hyperparameters for training the model:
- Batch size: 512
- Learning rate: 0.01
- Number of epochs: 1000 </h2>
<h3> We don’t have data for certain ages (e.g no image of a 94 year old person), so we can group ages into ranges (like 90–95, 95–100) instead of predicting each specific age. 
This is called class binning. By using age ranges, the model can handle missing data more effectively and make more accurate predictions, even with a smaller dataset.</h4>

In [2]:
import math
SIZE_OF_BIN = 10
NUM_OF_CLASSES = math.ceil(116 / SIZE_OF_BIN)
BATCH_SIZE = 128
LEARNING_RATE = 0.001
EPOCHS = 1000

In [3]:
from torch.utils.data import Dataset, DataLoader, random_split, WeightedRandomSampler

import matplotlib.pyplot as plt
from PIL import Image
import os
import torch
from torchvision import transforms
# coutainer of classes

class CustomDataset(Dataset):
    def __init__(self, images_path):
        self.image_files = [os.path.join(images_path, f) for f in os.listdir(images_path) 
                            if os.path.isfile(os.path.join(images_path, f)) and f.lower().endswith('.jpg')]
        self.transform = transforms.Compose([
            transforms.Resize([256], transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop([224]),
            
            # Subtle augmentations because a person normally e.g does not stand on upside down
            transforms.RandomHorizontalFlip(p=0.2),       # Small chance of flipping
            transforms.RandomRotation(degrees=5),         # Small rotation range
            transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),  # Subtle color jitter
            
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image = Image.open(self.image_files[idx]).convert("RGB")
        image = self.transform(image)
        age = self.image_files[idx].split("\\")[-1].split('_')[0]
        return image, int(age) // SIZE_OF_BIN 


dataset = CustomDataset(r"C:\Users\morit\Downloads\UTKface_inthewild-20241024T082001Z-001\UTKface_inthewild")
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)

train_size = int(0.8 * len(dataset))






val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size

# Split the dataset
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                          shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
for images, labels in dataloader:
    for i, labels in enumerate(labels):
        
        print(f"{labels*SIZE_OF_BIN}-{labels*SIZE_OF_BIN + SIZE_OF_BIN}:")
    break


100-110:
100-110:
100-110:
100-110:
100-110:
100-110:
100-110:
100-110:
100-110:
100-110:
100-110:
100-110:
100-110:
100-110:
100-110:
100-110:
100-110:
100-110:
100-110:
100-110:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:
10-20:




Due to the large imbalance of the dataset, we need to calculate the class weights to ensure that the model does not overfit to the majority class.
 
We will use the inverse frequency of each class as the class weight. This will help the model to pay more attention to the minority classes.

<h1> Calculate Class Weights </h1>
Since the calculation of class takes a long time, we will use the hardcoded output of the calculation.

In [4]:
from torch import tensor
class_weights_tensor = tensor([0.1378, 0.0637, 0.3042, 0.1891, 0.0946, 0.0966, 0.0550, 0.0298, 0.0220,
        0.0058, 0.0007, 0.0006], dtype=torch.float)

In [5]:
# import torch
# from collections import defaultdict
# 
# # Step 1: Count occurrences of each class
# class_counts = defaultdict(int)
# 
# for _, labels in train_loader:
#     for label in labels.numpy():
#             
#         class_counts[label] += 1
# 
# # Step 2: Calculate class weights
# total_samples = sum(class_counts.values())
# 
# # Calculate weights: inverse frequency
# class_weights = {cls: count/ total_samples for cls, count in class_counts.items()}
# # Convert t a tensor for PyTorch
# class_weights_tensor = torch.tensor([class_weights[i] for i in range(NUM_OF_CLASSES)], dtype=torch.float)
# 
# print("Class weights:", class_weights_tensor)


In [6]:
# import matplotlib.pyplot as plt
# import numpy as np
# 
# all_labels = []
# 
# for _, labels in dataloader:
#     all_labels.extend(labels.numpy())  
# 
# all_labels = np.array(all_labels)
# 
# plt.figure(figsize=(10, 6))
# plt.hist(all_labels, bins=range(int(all_labels.min()), int(all_labels.max()) + 2), edgecolor='black', alpha=0.7)
# plt.title('Age Distribution in Dataset')
# plt.xlabel('Age')
# plt.ylabel('Frequency')
# plt.grid(axis='y', linestyle='--', alpha=0.7)
# plt.show()

<h1> Model </h1>

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class highLevelNN(nn.Module):
    def __init__(self):
        super(highLevelNN, self).__init__()
        self.CNN = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Dropout(0.3),

            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Dropout(0.3),

            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Dropout(0.4),
        )

    def forward(self, x):
        return self.CNN(x)


class lowLevelNN(nn.Module):
    def __init__(self, num_out):
        super(lowLevelNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=4, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(128)
        self.conv2 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(256)

        self.fc4 = nn.Linear(in_features=256, out_features=128)
        self.fc5 = nn.Linear(in_features=128, out_features=64)
        self.fc6 = nn.Linear(in_features=64, out_features=num_out)

        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, kernel_size=6, stride=3, padding=1)

        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, kernel_size=6, stride=3, padding=1)

        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc4(x))
        x = self.dropout(x)
        x = F.relu(self.fc5(x))
        x = self.dropout(x)
        return self.fc6(x)


class AgeNN(nn.Module):
    def __init__(self, num_age):
        super(AgeNN, self).__init__()
        self.CNN = highLevelNN()
        self.ageNN = lowLevelNN(num_out=num_age)
        self._initialize_weights()
        
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, 0)


    def forward(self, x):
        x = self.CNN(x)
        return self.ageNN(x)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if __name__ == '__main__':
    from torchsummary import summary
    print('Testing out Multi-Label NN')
    mlNN = AgeNN(NUM_OF_CLASSES).to(device)
    
    summary(mlNN, input_size=(3, 224, 224))


Testing out Multi-Label NN
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 224, 224]             896
       BatchNorm2d-2         [-1, 32, 224, 224]              64
              ReLU-3         [-1, 32, 224, 224]               0
            Conv2d-4         [-1, 32, 224, 224]           9,248
       BatchNorm2d-5         [-1, 32, 224, 224]              64
         MaxPool2d-6         [-1, 32, 112, 112]               0
              ReLU-7         [-1, 32, 112, 112]               0
           Dropout-8         [-1, 32, 112, 112]               0
            Conv2d-9         [-1, 64, 112, 112]          18,496
      BatchNorm2d-10         [-1, 64, 112, 112]             128
             ReLU-11         [-1, 64, 112, 112]               0
           Conv2d-12         [-1, 64, 112, 112]          36,928
      BatchNorm2d-13         [-1, 64, 112, 112]             128
        MaxP

In [8]:
import torch
import torch.nn as nn
from tqdm import tqdm
from sklearn.metrics import f1_score, balanced_accuracy_score
from torch.cuda.amp import GradScaler, autocast

def train(trainloader, testloader, model, opt, scheduler, num_epoch, save_path):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)

    age_loss = nn.CrossEntropyLoss()
    print("1")

    for epoch in range(num_epoch):
        model.train()

        loop = tqdm(enumerate(trainloader), total=len(trainloader), leave=False)
        age_correct = 0
        total = 0
        

        for _, (X, y) in loop:

            age = y.to(device).long()
            X = X.to(device)

            with autocast():
                pred = model(X)
                loss = age_loss(pred, age)

            opt.zero_grad()
            loss.backward()
            opt.step()

            age_correct += (pred.argmax(1) == age).type(torch.float).sum().item()
            total += age.size(0)

            loop.set_description(f"Epoch [{epoch+1}/{num_epoch}]")
            loop.set_postfix(loss=loss.item())
        
        scheduler.step()
        age_acc = age_correct / total
        print(f'Epoch: {epoch+1}/{num_epoch}, Age Accuracy: {age_acc * 100:.2f}%')

        torch.save(model.state_dict(), f"{save_path}/model_epoch_{epoch+1}.pth")

        if (epoch + 1) % 10 == 0:
            evaluate(testloader, model, epoch + 1)
            
def evaluate(testloader, model, epoch):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.eval()
    test_correct = 0
    test_total = 0
    all_preds = []
    all_labels = []
    
    for X_test, y_test in testloader:
        X_test, y_test = X_test.to(device), y_test.to(device).long()
        test_pred = model(X_test)
        
        test_correct += (test_pred.argmax(1) == y_test).type(torch.float).sum().item()
        test_total += y_test.size(0)
        
        all_preds.extend(test_pred.argmax(1).cpu().numpy())
        all_labels.extend(y_test.cpu().numpy())
    
    test_acc = test_correct / test_total
    f1 = f1_score(all_labels, all_preds, average='weighted')
    balanced_acc = balanced_accuracy_score(all_labels, all_preds)
    
    print(f'Test Accuracy after Epoch {epoch}: {test_acc * 100:.2f}%')
    print(f'Test F1 Score after Epoch {epoch}: {f1 * 100:.2f}%')
    print(f'Test Balanced Accuracy after Epoch {epoch}: {balanced_acc * 100:.2f}%\n')

In [9]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.optim.lr_scheduler import StepLR
# Decrease the learning rate by 0.1 every 10 epochs
import os   
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
model = AgeNN(NUM_OF_CLASSES)  
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)
# Train the model
train(train_loader, test_loader, model, optimizer,scheduler, EPOCHS, "model_weights")

1


  with autocast():
                                                                            

Epoch: 1/1000, Age Accuracy: 26.25%


                                                                            

Epoch: 2/1000, Age Accuracy: 30.41%


                                                                            

Epoch: 3/1000, Age Accuracy: 33.89%


                                                                            

Epoch: 4/1000, Age Accuracy: 36.73%


                                                                            

Epoch: 5/1000, Age Accuracy: 38.72%


                                                                            

Epoch: 6/1000, Age Accuracy: 40.52%


                                                                            

Epoch: 7/1000, Age Accuracy: 42.03%


                                                                            

Epoch: 8/1000, Age Accuracy: 43.06%


                                                                            

Epoch: 9/1000, Age Accuracy: 44.09%


                                                                             

Epoch: 10/1000, Age Accuracy: 44.86%
Test Accuracy after Epoch 10: 44.73%
Test F1 Score after Epoch 10: 39.02%
Test Balanced Accuracy after Epoch 10: 18.97%



  with autocast():
                                                                             

Epoch: 11/1000, Age Accuracy: 46.41%


                                                                             

Epoch: 12/1000, Age Accuracy: 46.81%


                                                                             

Epoch: 13/1000, Age Accuracy: 46.92%


                                                                             

Epoch: 14/1000, Age Accuracy: 47.48%


                                                                             

Epoch: 15/1000, Age Accuracy: 47.54%


                                                                             

Epoch: 16/1000, Age Accuracy: 47.58%


                                                                             

Epoch: 17/1000, Age Accuracy: 48.23%


                                                                             

Epoch: 18/1000, Age Accuracy: 47.95%


                                                                             

Epoch: 19/1000, Age Accuracy: 47.82%


                                                                             

Epoch: 20/1000, Age Accuracy: 48.26%
Test Accuracy after Epoch 20: 49.17%
Test F1 Score after Epoch 20: 41.41%
Test Balanced Accuracy after Epoch 20: 23.45%



  with autocast():
                                                                             

Epoch: 21/1000, Age Accuracy: 48.65%


                                                                             

Epoch: 22/1000, Age Accuracy: 48.19%


                                                                             

Epoch: 23/1000, Age Accuracy: 48.14%


                                                                             

Epoch: 24/1000, Age Accuracy: 48.27%


                                                                             

Epoch: 25/1000, Age Accuracy: 47.91%


                                                                             

Epoch: 26/1000, Age Accuracy: 48.50%


                                                                             

Epoch: 27/1000, Age Accuracy: 47.94%


                                                                             

Epoch: 28/1000, Age Accuracy: 48.44%


                                                                             

Epoch: 29/1000, Age Accuracy: 48.58%


                                                                             

Epoch: 30/1000, Age Accuracy: 48.28%
Test Accuracy after Epoch 30: 49.42%
Test F1 Score after Epoch 30: 41.83%
Test Balanced Accuracy after Epoch 30: 23.39%



  with autocast():
                                                                             

Epoch: 31/1000, Age Accuracy: 48.87%


                                                                             

Epoch: 32/1000, Age Accuracy: 48.62%


                                                                             

Epoch: 33/1000, Age Accuracy: 48.62%


                                                                             

Epoch: 34/1000, Age Accuracy: 48.50%


                                                                             

Epoch: 35/1000, Age Accuracy: 48.61%


                                                                             

Epoch: 36/1000, Age Accuracy: 48.74%


                                                                             

Epoch: 37/1000, Age Accuracy: 48.23%


                                                                             

Epoch: 38/1000, Age Accuracy: 48.72%


                                                                             

Epoch: 39/1000, Age Accuracy: 48.28%


                                                                             

Epoch: 40/1000, Age Accuracy: 48.44%
Test Accuracy after Epoch 40: 50.17%
Test F1 Score after Epoch 40: 42.70%
Test Balanced Accuracy after Epoch 40: 23.87%



  with autocast():
                                                                             

Epoch: 41/1000, Age Accuracy: 48.56%


                                                                             

Epoch: 42/1000, Age Accuracy: 48.59%


                                                                             

Epoch: 43/1000, Age Accuracy: 48.52%


                                                                             

Epoch: 44/1000, Age Accuracy: 48.40%


                                                                             

Epoch: 45/1000, Age Accuracy: 48.52%


                                                                             

Epoch: 46/1000, Age Accuracy: 48.39%


                                                                             

Epoch: 47/1000, Age Accuracy: 48.75%


                                                                             

Epoch: 48/1000, Age Accuracy: 48.57%


                                                                             

Epoch: 49/1000, Age Accuracy: 48.55%


                                                                             

Epoch: 50/1000, Age Accuracy: 48.95%
Test Accuracy after Epoch 50: 49.17%
Test F1 Score after Epoch 50: 41.43%
Test Balanced Accuracy after Epoch 50: 22.80%



  with autocast():
                                                                             

Epoch: 51/1000, Age Accuracy: 48.70%


                                                                             

Epoch: 52/1000, Age Accuracy: 48.25%


                                                                             

Epoch: 53/1000, Age Accuracy: 48.49%


                                                                             

Epoch: 54/1000, Age Accuracy: 48.50%


                                                                             

Epoch: 55/1000, Age Accuracy: 48.71%


                                                                             

Epoch: 56/1000, Age Accuracy: 48.21%


                                                                             

Epoch: 57/1000, Age Accuracy: 48.44%


                                                                             

Epoch: 58/1000, Age Accuracy: 48.63%


                                                                             

Epoch: 59/1000, Age Accuracy: 48.64%


                                                                             

Epoch: 60/1000, Age Accuracy: 48.82%
Test Accuracy after Epoch 60: 49.34%
Test F1 Score after Epoch 60: 41.86%
Test Balanced Accuracy after Epoch 60: 23.10%



  with autocast():
                                                                             

Epoch: 61/1000, Age Accuracy: 48.41%


                                                                             

Epoch: 62/1000, Age Accuracy: 48.50%


                                                                             

Epoch: 63/1000, Age Accuracy: 48.45%


                                                                             

Epoch: 64/1000, Age Accuracy: 48.80%


                                                                             

Epoch: 65/1000, Age Accuracy: 48.63%


                                                                             

Epoch: 66/1000, Age Accuracy: 48.79%


                                                                             

Epoch: 67/1000, Age Accuracy: 48.05%


                                                                             

Epoch: 68/1000, Age Accuracy: 48.64%


                                                                             

Epoch: 69/1000, Age Accuracy: 48.66%


                                                                             

Epoch: 70/1000, Age Accuracy: 48.43%
Test Accuracy after Epoch 70: 49.59%
Test F1 Score after Epoch 70: 42.01%
Test Balanced Accuracy after Epoch 70: 23.54%



  with autocast():
                                                                             

Epoch: 71/1000, Age Accuracy: 48.41%


                                                                             

Epoch: 72/1000, Age Accuracy: 48.47%


                                                                             

Epoch: 73/1000, Age Accuracy: 48.58%


                                                                             

Epoch: 74/1000, Age Accuracy: 48.61%


                                                                             

Epoch: 75/1000, Age Accuracy: 48.59%


                                                                             

Epoch: 76/1000, Age Accuracy: 48.71%


                                                                             

Epoch: 77/1000, Age Accuracy: 48.59%


                                                                             

Epoch: 78/1000, Age Accuracy: 48.59%


                                                                             

KeyboardInterrupt: 