In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import numpy as np
import os

from data_loader import FaceDataset
from model import AgeGenderModel

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
model = AgeGenderModel().to(device)
model.load_state_dict(torch.load('models/age_gender_model.pth', map_location=torch.device('cpu')))
model.eval()

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [4]:
full_dataset = FaceDataset(root_dir='dataset', transform=transform)
# filtered_indices = [i for i, (data, label) in enumerate(full_dataset) if data is not None]
dataset = full_dataset

train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size
_, _, test_dataset = random_split(dataset, [train_size, val_size, test_size])

test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


In [5]:
all_ages = []
all_age_preds = []
all_genders = []
all_gender_preds = []

with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        ages = labels[:, 0].numpy()
        genders = labels[:, 1].numpy()
        
        age_pred, gender_pred = model(images)
        
        all_ages.extend(ages)
        all_genders.extend(genders)
        all_age_preds.extend(age_pred.squeeze().cpu().numpy())
        all_gender_preds.extend(torch.argmax(gender_pred, dim=1).cpu().numpy())


In [6]:
from sklearn.metrics import accuracy_score, mean_absolute_error, confusion_matrix

In [14]:

gender_accuracy = accuracy_score(all_genders, all_gender_preds)
gender_conf_matrix = confusion_matrix(all_genders, all_gender_preds)

age_accuracy = accuracy_score(all_ages, np.round(all_age_preds))
# age_conf_matrix = confusion_matrix(all_ages, np.round(all_age_preds))

mae_age = mean_absolute_error(all_ages, all_age_preds)
mae_gender = mean_absolute_error(all_genders, all_gender_preds)

print(f"Age Accuracy: {age_accuracy*100:.4f}%")
# print("Age Confusion Matrix:\n", age_conf_matrix)
print(f"Gender Accuracy: {gender_accuracy*100:.4f}%")
print("Gender Confusion Matrix:\n", gender_conf_matrix)
print(f"Age Mean Absolute Error (MAE): {mae:.2f}")
print(f"Gender Mean Absolute Error (MAE): {mae_gender:.2f}")

Age Accuracy: 18.9163%
Gender Accuracy: 77.7340%
Gender Confusion Matrix:
 [[326 111]
 [115 463]]
Age Mean Absolute Error (MAE): 2.46
Gender Mean Absolute Error (MAE): 0.22
