In [2]:
# Age & Gender Detection Model Training Project
# Final version with an 80% Training, 20% Testing data split.

# ### Stage 1: Import Libraries & Initial Setup ###
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import models, transforms
from PIL import Image
import os
from tqdm import tqdm
import numpy as np
import copy

# Check for GPU availability and set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


In [3]:
# ### Stage 2: Data Preparation ###
# Define transforms for data augmentation (train) and normalization (val/test)
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([ # Used for both validation and testing
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# --- Dataset Classes (separated from transforms for flexibility) ---
class CustomDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform
    def __len__(self):
        return len(self.image_paths)
    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        try:
            image = Image.open(img_path).convert("RGB")
        except Exception as e:
            print(f"Warning: Skipping corrupted image {img_path}. Error: {e}")
            return None # We will handle this in the DataLoader
        
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, torch.tensor(label, dtype=torch.float32)

def load_initial_data(root_dir, task='gender'):
    image_paths = []
    labels = []
    if task == 'gender':
        for label, gender in enumerate(['female', 'male']):
            gender_path = os.path.join(root_dir, gender)
            if not os.path.isdir(gender_path): continue
            for img_name in os.listdir(gender_path):
                image_paths.append(os.path.join(gender_path, img_name))
                labels.append(label)
    elif task == 'age':
        for age_folder in os.listdir(root_dir):
            folder_path = os.path.join(root_dir, age_folder)
            if os.path.isdir(folder_path):
                try:
                    parts = age_folder.split('-')
                    avg_age = (int(parts[0]) + int(parts[1])) / 2.0 if len(parts) == 2 else float(age_folder)
                    for img_name in os.listdir(folder_path):
                        image_paths.append(os.path.join(folder_path, img_name))
                        labels.append(avg_age)
                except ValueError: continue
    return image_paths, labels
    
def collate_fn_skip_corrupted(batch):
    # Filter out None items from the batch, which happen if an image fails to load
    batch = list(filter(lambda x: x is not None, batch))
    if not batch:
        return torch.tensor([]), torch.tensor([])
    return torch.utils.data.dataloader.default_collate(batch)


# --- Dataset Paths ---
GENDER_DATA_PATH = './dataset/gender/Training/'
AGE_DATA_PATH = './dataset/age/Training/'

# --- Load and Split Datasets (80/10/10) ---
# Gender
gender_paths, gender_labels = load_initial_data(GENDER_DATA_PATH, 'gender')
full_gender_dataset = CustomDataset(gender_paths, gender_labels) # Create dataset without transforms first

train_size = int(0.8 * len(full_gender_dataset))
val_size = int(0.1 * len(full_gender_dataset))
test_size = len(full_gender_dataset) - train_size - val_size
gender_train_subset, gender_val_subset, gender_test_subset = random_split(full_gender_dataset, [train_size, val_size, test_size])

# Apply the correct transforms to each subset
gender_train_dataset = copy.deepcopy(gender_train_subset); gender_train_dataset.dataset.transform = data_transforms['train']
gender_val_dataset = copy.deepcopy(gender_val_subset); gender_val_dataset.dataset.transform = data_transforms['val']
gender_test_dataset = copy.deepcopy(gender_test_subset); gender_test_dataset.dataset.transform = data_transforms['val']

# Age
age_paths, age_labels = load_initial_data(AGE_DATA_PATH, 'age')
full_age_dataset = CustomDataset(age_paths, age_labels)
train_size = int(0.8 * len(full_age_dataset))
val_size = int(0.1 * len(full_age_dataset))
test_size = len(full_age_dataset) - train_size - val_size
age_train_subset, age_val_subset, age_test_subset = random_split(full_age_dataset, [train_size, val_size, test_size])

age_train_dataset = copy.deepcopy(age_train_subset); age_train_dataset.dataset.transform = data_transforms['train']
age_val_dataset = copy.deepcopy(age_val_subset); age_val_dataset.dataset.transform = data_transforms['val']
age_test_dataset = copy.deepcopy(age_test_subset); age_test_dataset.dataset.transform = data_transforms['val']

# Create DataLoaders
BATCH_SIZE = 32
gender_train_loader = DataLoader(gender_train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn_skip_corrupted)
gender_val_loader = DataLoader(gender_val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn_skip_corrupted)
gender_test_loader = DataLoader(gender_test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn_skip_corrupted)
age_train_loader = DataLoader(age_train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn_skip_corrupted)
age_val_loader = DataLoader(age_val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn_skip_corrupted)
age_test_loader = DataLoader(age_test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn_skip_corrupted)

print(f"Gender Data -> Training: {len(gender_train_dataset)}, Validation: {len(gender_val_dataset)}, Testing: {len(gender_test_dataset)}")
print(f"Age Data    -> Training: {len(age_train_dataset)}, Validation: {len(age_val_dataset)}, Testing: {len(age_test_dataset)}")


Gender Data -> Training: 176, Validation: 22, Testing: 22
Age Data    -> Training: 100, Validation: 12, Testing: 13


In [4]:
# ### Stage 3: Model Architectures (No changes) ###
class ResNet50AgeGenderModel(nn.Module):
    def __init__(self): super(ResNet50AgeGenderModel, self).__init__(); self.base_model = models.resnet50(pretrained=True); in_features = self.base_model.fc.in_features; self.base_model.fc = nn.Identity(); self.gender_head = nn.Sequential(nn.Linear(in_features, 512), nn.ReLU(), nn.Dropout(0.5), nn.Linear(512, 1)); self.age_head = nn.Sequential(nn.Linear(in_features, 512), nn.ReLU(), nn.Dropout(0.5), nn.Linear(512, 1));
    def forward(self, x): features = self.base_model(x); return self.gender_head(features), self.age_head(features)
class MobileNetV2AgeGenderModel(nn.Module):
    def __init__(self): super(MobileNetV2AgeGenderModel, self).__init__(); self.base_model = models.mobilenet_v2(pretrained=True); in_features = self.base_model.classifier[1].in_features; self.base_model.classifier = nn.Identity(); self.gender_head = nn.Sequential(nn.Linear(in_features, 256), nn.ReLU(), nn.Dropout(0.5), nn.Linear(256, 1)); self.age_head = nn.Sequential(nn.Linear(in_features, 256), nn.ReLU(), nn.Dropout(0.5), nn.Linear(256, 1));
    def forward(self, x): features = self.base_model(x); return self.gender_head(features), self.age_head(features)
class EfficientNetAgeGenderModel(nn.Module):
    def __init__(self): super(EfficientNetAgeGenderModel, self).__init__(); self.base_model = models.efficientnet_b0(pretrained=True); in_features = self.base_model.classifier[1].in_features; self.base_model.classifier = nn.Identity(); self.gender_head = nn.Sequential(nn.Linear(in_features, 256), nn.ReLU(), nn.Dropout(0.5), nn.Linear(256, 1)); self.age_head = nn.Sequential(nn.Linear(in_features, 256), nn.ReLU(), nn.Dropout(0.5), nn.Linear(256, 1));
    def forward(self, x): features = self.base_model(x); return self.gender_head(features), self.age_head(features)

In [5]:
# ### Stage 4: Training & Validation Function ###
def train_and_validate(model, model_name, gender_train_loader, age_train_loader, gender_val_loader, age_val_loader, num_epochs=15):
    model.to(device)
    criterion_gender = nn.BCEWithLogitsLoss()
    criterion_age = nn.L1Loss()
    optimizer = optim.Adam(model.parameters(), lr=0.0001) # Use a smaller learning rate for fine-tuning
    
    print(f'--- Starting Training for {model_name} ---')

    for epoch in range(num_epochs):
        # --- TRAINING PHASE ---
        model.train()
        running_gender_loss, running_age_loss, gender_corrects = 0.0, 0.0, 0
        
        age_train_iter = iter(age_train_loader)
        progress_bar = tqdm(gender_train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Training]")
        
        for gender_inputs, gender_labels in progress_bar:
            if gender_inputs.nelement() == 0: continue # Skip empty batches
            gender_inputs, gender_labels = gender_inputs.to(device), gender_labels.to(device).unsqueeze(1)
            
            try: age_inputs, age_labels = next(age_train_iter)
            except StopIteration: age_train_iter = iter(age_train_loader); age_inputs, age_labels = next(age_train_iter)
            if age_inputs.nelement() == 0: continue
            age_inputs, age_labels = age_inputs.to(device), age_labels.to(device).unsqueeze(1)
            
            optimizer.zero_grad()
            gender_outputs, _ = model(gender_inputs); _, age_outputs = model(age_inputs)
            loss_gender = criterion_gender(gender_outputs, gender_labels); loss_age = criterion_age(age_outputs, age_labels)
            total_loss = loss_gender + loss_age
            total_loss.backward()
            optimizer.step()
            
            running_gender_loss += loss_gender.item() * gender_inputs.size(0)
            running_age_loss += loss_age.item() * age_inputs.size(0)
            preds = torch.sigmoid(gender_outputs) > 0.5
            gender_corrects += torch.sum(preds == gender_labels.data)
            progress_bar.set_postfix(loss=f"{total_loss.item():.4f}")

        train_gender_loss = running_gender_loss / len(gender_train_loader.dataset); train_age_mae = running_age_loss / len(age_train_loader.dataset); train_gender_acc = gender_corrects.double() / len(gender_train_loader.dataset)

        # --- VALIDATION PHASE ---
        model.eval()
        val_gender_loss, val_age_loss, val_gender_corrects = 0.0, 0.0, 0
        age_val_iter = iter(age_val_loader)
        
        with torch.no_grad():
            for gender_inputs, gender_labels in tqdm(gender_val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Validation]"):
                if gender_inputs.nelement() == 0: continue
                gender_inputs, gender_labels = gender_inputs.to(device), gender_labels.to(device).unsqueeze(1)
                
                try: age_inputs, age_labels = next(age_val_iter)
                except StopIteration: age_val_iter = iter(age_val_loader); age_inputs, age_labels = next(age_val_iter)
                if age_inputs.nelement() == 0: continue
                age_inputs, age_labels = age_inputs.to(device), age_labels.to(device).unsqueeze(1)
                
                gender_outputs, _ = model(gender_inputs); _, age_outputs = model(age_inputs)
                loss_gender = criterion_gender(gender_outputs, gender_labels); loss_age = criterion_age(age_outputs, age_labels)
                
                val_gender_loss += loss_gender.item() * gender_inputs.size(0); val_age_loss += loss_age.item() * age_inputs.size(0)
                preds = torch.sigmoid(gender_outputs) > 0.5
                val_gender_corrects += torch.sum(preds == gender_labels.data)

        val_gender_loss = val_gender_loss / len(gender_val_loader.dataset); val_age_mae = val_age_loss / len(age_val_loader.dataset); val_gender_acc = val_gender_corrects.double() / len(gender_val_loader.dataset)
        
        print(f"Epoch {epoch+1}/{num_epochs} | "
              f"Train: [Acc: {train_gender_acc:.4f}, MAE: {train_age_mae:.4f}] | "
              f"Val: [Acc: {val_gender_acc:.4f}, MAE: {val_age_mae:.4f}]")

    print(f'--- Training for {model_name} Complete! ---')
    return model

In [6]:
# ### Stage 5: Testing Function (NEW) ###
def test_model(model_architecture, model_path, gender_test_loader, age_test_loader):
    print(f"\n--- Starting Final Testing for {model_path} ---")
    
    model = model_architecture()
    model.load_state_dict(torch.load(model_path))
    model.to(device)
    model.eval()

    criterion_age = nn.L1Loss()
    test_age_loss, test_gender_corrects = 0.0, 0
    age_test_iter = iter(age_test_loader)

    with torch.no_grad():
        for gender_inputs, gender_labels in tqdm(gender_test_loader, desc="[Final Testing]"):
            if gender_inputs.nelement() == 0: continue
            gender_inputs, gender_labels = gender_inputs.to(device), gender_labels.to(device).unsqueeze(1)
            
            try: age_inputs, age_labels = next(age_test_iter)
            except StopIteration: age_test_iter = iter(age_test_loader); age_inputs, age_labels = next(age_test_iter)
            if age_inputs.nelement() == 0: continue
            age_inputs, age_labels = age_inputs.to(device), age_labels.to(device).unsqueeze(1)

            gender_outputs, _ = model(gender_inputs); _, age_outputs = model(age_inputs)
            
            loss_age = criterion_age(age_outputs, age_labels)
            test_age_loss += loss_age.item() * age_inputs.size(0)
            preds = torch.sigmoid(gender_outputs) > 0.5
            test_gender_corrects += torch.sum(preds == gender_labels.data)

    final_age_mae = test_age_loss / len(age_test_loader.dataset)
    final_gender_acc = test_gender_corrects.double() / len(gender_test_loader.dataset)

    print("\n--- FINAL TEST RESULTS ---")
    print(f"Model: {model_path}")
    print(f"  Gender Accuracy on Test Set: {final_gender_acc:.4f}")
    print(f"  Age MAE on Test Set: {final_age_mae:.4f}")
    print("--------------------------")

In [9]:
# ### Stage 6: Main Execution Block ###
def main():
    # --- CHOOSE THE MODEL TO TRAIN AND ITS OUTPUT FILENAME HERE ---
    # model_architecture = ResNet50AgeGenderModel
    # output_filename = 'resnet50_age_gender.pth'
    
    # model_architecture = MobileNetV2AgeGenderModel
    # output_filename = 'mobilenetv2_age_gender.pth'
    
    model_architecture = EfficientNetAgeGenderModel
    output_filename = 'efficientnet_age_gender.pth'

    # 1. Train and Validate the Model
    model_to_train = model_architecture()
    trained_model = train_and_validate(
        model=model_to_train, 
        model_name=output_filename.split('_')[0].upper(),
        gender_train_loader=gender_train_loader, 
        age_train_loader=age_train_loader,
        gender_val_loader=gender_val_loader,
        age_val_loader=age_val_loader,
        num_epochs=15
    )
    # Save the best model
    torch.save(trained_model.state_dict(), output_filename)
    print(f"Model successfully saved as {output_filename}")

    # 2. Perform Final Testing on the Unseen Test Set
    test_model(
        model_architecture=model_architecture,
        model_path=output_filename,
        gender_test_loader=gender_test_loader,
        age_test_loader=age_test_loader
    )

if __name__ == '__main__':
    main()





--- Starting Training for EFFICIENTNET ---


Epoch 1/15 [Training]: 100%|██████████| 6/6 [01:09<00:00, 11.51s/it, loss=33.0303]
Epoch 1/15 [Validation]: 100%|██████████| 1/1 [00:03<00:00,  3.71s/it]


Epoch 1/15 | Train: [Acc: 0.4716, MAE: 58.6599] | Val: [Acc: 0.5909, MAE: 35.5828]


Epoch 2/15 [Training]: 100%|██████████| 6/6 [01:06<00:00, 11.14s/it, loss=38.2669]
Epoch 2/15 [Validation]: 100%|██████████| 1/1 [00:03<00:00,  3.61s/it]


Epoch 2/15 | Train: [Acc: 0.6420, MAE: 59.3499] | Val: [Acc: 0.7273, MAE: 34.9668]


Epoch 3/15 [Training]: 100%|██████████| 6/6 [01:06<00:00, 11.11s/it, loss=36.3815]
Epoch 3/15 [Validation]: 100%|██████████| 1/1 [00:03<00:00,  3.69s/it]


Epoch 3/15 | Train: [Acc: 0.7443, MAE: 56.8466] | Val: [Acc: 0.8182, MAE: 34.1495]


Epoch 4/15 [Training]: 100%|██████████| 6/6 [01:06<00:00, 11.12s/it, loss=35.1789]
Epoch 4/15 [Validation]: 100%|██████████| 1/1 [00:03<00:00,  3.65s/it]


Epoch 4/15 | Train: [Acc: 0.8182, MAE: 55.5145] | Val: [Acc: 0.8636, MAE: 32.8771]


Epoch 5/15 [Training]: 100%|██████████| 6/6 [01:06<00:00, 11.14s/it, loss=32.5387]
Epoch 5/15 [Validation]: 100%|██████████| 1/1 [00:03<00:00,  3.61s/it]


Epoch 5/15 | Train: [Acc: 0.8125, MAE: 53.4834] | Val: [Acc: 0.8636, MAE: 30.9741]


Epoch 6/15 [Training]: 100%|██████████| 6/6 [01:07<00:00, 11.19s/it, loss=33.2304]
Epoch 6/15 [Validation]: 100%|██████████| 1/1 [00:03<00:00,  3.64s/it]


Epoch 6/15 | Train: [Acc: 0.8750, MAE: 51.4855] | Val: [Acc: 0.8636, MAE: 28.6725]


Epoch 7/15 [Training]: 100%|██████████| 6/6 [01:07<00:00, 11.19s/it, loss=28.4866]
Epoch 7/15 [Validation]: 100%|██████████| 1/1 [00:03<00:00,  3.83s/it]


Epoch 7/15 | Train: [Acc: 0.8693, MAE: 48.4894] | Val: [Acc: 0.8636, MAE: 25.6587]


Epoch 8/15 [Training]: 100%|██████████| 6/6 [01:06<00:00, 11.09s/it, loss=24.1122]
Epoch 8/15 [Validation]: 100%|██████████| 1/1 [00:03<00:00,  3.64s/it]


Epoch 8/15 | Train: [Acc: 0.8750, MAE: 44.6286] | Val: [Acc: 0.8636, MAE: 22.2592]


Epoch 9/15 [Training]: 100%|██████████| 6/6 [01:06<00:00, 11.12s/it, loss=23.2229]
Epoch 9/15 [Validation]: 100%|██████████| 1/1 [00:03<00:00,  3.74s/it]


Epoch 9/15 | Train: [Acc: 0.8750, MAE: 41.1288] | Val: [Acc: 0.8636, MAE: 20.1599]


Epoch 10/15 [Training]: 100%|██████████| 6/6 [01:18<00:00, 13.01s/it, loss=22.9787]
Epoch 10/15 [Validation]: 100%|██████████| 1/1 [00:04<00:00,  4.45s/it]


Epoch 10/15 | Train: [Acc: 0.8920, MAE: 36.8466] | Val: [Acc: 0.8636, MAE: 18.2080]


Epoch 11/15 [Training]: 100%|██████████| 6/6 [01:19<00:00, 13.25s/it, loss=20.1371]
Epoch 11/15 [Validation]: 100%|██████████| 1/1 [00:04<00:00,  4.45s/it]


Epoch 11/15 | Train: [Acc: 0.8807, MAE: 33.5454] | Val: [Acc: 0.8636, MAE: 15.0443]


Epoch 12/15 [Training]: 100%|██████████| 6/6 [01:09<00:00, 11.59s/it, loss=18.2822]
Epoch 12/15 [Validation]: 100%|██████████| 1/1 [00:03<00:00,  3.79s/it]


Epoch 12/15 | Train: [Acc: 0.9148, MAE: 29.7282] | Val: [Acc: 0.8636, MAE: 13.4695]


Epoch 13/15 [Training]: 100%|██████████| 6/6 [01:09<00:00, 11.58s/it, loss=14.6191]
Epoch 13/15 [Validation]: 100%|██████████| 1/1 [00:03<00:00,  3.71s/it]


Epoch 13/15 | Train: [Acc: 0.8977, MAE: 23.6088] | Val: [Acc: 0.8636, MAE: 11.0238]


Epoch 14/15 [Training]: 100%|██████████| 6/6 [01:09<00:00, 11.63s/it, loss=11.1859]
Epoch 14/15 [Validation]: 100%|██████████| 1/1 [00:04<00:00,  4.45s/it]


Epoch 14/15 | Train: [Acc: 0.9034, MAE: 19.0689] | Val: [Acc: 0.8636, MAE: 11.1100]


Epoch 15/15 [Training]: 100%|██████████| 6/6 [01:14<00:00, 12.35s/it, loss=12.4815]
Epoch 15/15 [Validation]: 100%|██████████| 1/1 [00:03<00:00,  3.60s/it]


Epoch 15/15 | Train: [Acc: 0.9091, MAE: 17.3838] | Val: [Acc: 0.8636, MAE: 9.2213]
--- Training for EFFICIENTNET Complete! ---
Model successfully saved as efficientnet_age_gender.pth

--- Starting Final Testing for efficientnet_age_gender.pth ---


[Final Testing]: 100%|██████████| 1/1 [00:03<00:00,  3.89s/it]


--- FINAL TEST RESULTS ---
Model: efficientnet_age_gender.pth
  Gender Accuracy on Test Set: 0.8182
  Age MAE on Test Set: 10.5810
--------------------------



