In [None]:
import os
import scipy.io
import pandas as pd
import numpy as np
import requests
import tarfile
from PIL import Image
from datetime import datetime
from tqdm import tqdm

import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

# ========== CONFIG ==========
# Root folder for the downloaded archive, the extracted/resized images and the label CSVs.
DATA_DIR = 'data'
# Side length in pixels; images are resized to IMG_SIZE x IMG_SIZE.
IMG_SIZE = 128
# Mini-batch size passed to the DataLoaders below.
BATCH_SIZE = 64
# Number of passes over the training split in train_model().
EPOCHS = 10
# Learning rate handed to the Adam optimizer.
LR = 1e-3
# Destination of the best (lowest validation loss) model state_dict.
# NOTE(review): absolute Colab / Google Drive path — presumably requires Drive to be mounted; confirm.
MODEL_PATH = '/content/drive/MyDrive/best_age_model.pth'
# Maximum number of labelled samples kept by extract_labels().
LIMIT_IMAGES = 5000  # use smaller dataset for testing
# ============================

# ---------- Step 1: Download ----------
def download_dataset():
    """Download the ``wiki_crop.tar`` archive into DATA_DIR and extract it there.

    The download is skipped when ``DATA_DIR/wiki_crop.tar`` already exists;
    extraction runs on every call.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on connection problems or timeouts.
        tarfile.TarError: if the archive cannot be read or a member is rejected.
    """
    url = "https://data.vision.ee.ethz.ch/cvl/rrothe/imdb-wiki/static/wiki_crop.tar"
    os.makedirs(DATA_DIR, exist_ok=True)
    tar_path = os.path.join(DATA_DIR, "wiki_crop.tar")
    if not os.path.exists(tar_path):
        print("Downloading IMDB-WIKI subset...")
        # Stream into a temporary file first: an interrupted download must not
        # leave a truncated archive that the next run mistakes for a complete one.
        partial_path = tar_path + ".part"
        with requests.get(url, stream=True, timeout=60) as r:
            # Fail loudly instead of saving an HTML error page as the tarball.
            r.raise_for_status()
            with open(partial_path, 'wb') as f:
                for chunk in tqdm(r.iter_content(1024*1024), desc="Downloading", unit="MB"):
                    f.write(chunk)
        os.replace(partial_path, tar_path)
    with tarfile.open(tar_path) as tar:
        # Use the safe "data" extraction filter where the running Python
        # provides it, so archive members cannot escape DATA_DIR.
        if hasattr(tarfile, 'data_filter'):
            tar.extractall(DATA_DIR, filter='data')
        else:
            tar.extractall(DATA_DIR)
        print("Extracted dataset.")

# ---------- Step 2: Preprocess ----------
def extract_labels():
    """Build ``full_labels.csv`` (filename, age) from ``wiki_crop/wiki.mat``.

    Age is computed as the year the photo was taken minus the birth year.
    Only samples with ``0 < age <= 100`` are kept, and collection stops once
    LIMIT_IMAGES samples have been gathered. The CSV is written to DATA_DIR
    and the selected images are then resized via ``resize_and_copy``.

    ``dob`` is a MATLAB serial date number (per the IMDB-WIKI dataset
    description). MATLAB counts days from year 0000 while Python ordinals
    start at 0001-01-01, so 366 days are subtracted before conversion;
    without that shift every birth date is 366 days too late and most ages
    come out one year too small.
    Entries whose birth date cannot be converted are skipped.
    """
    print("Processing .mat label file...")
    mat = scipy.io.loadmat(os.path.join(DATA_DIR, "wiki_crop/wiki.mat"))
    mat = mat['wiki'][0, 0]
    full_path = mat['full_path'][0]
    dob = mat['dob'][0]
    photo_taken = mat['photo_taken'][0]

    matlab_to_python_ordinal = 366  # days between MATLAB day 1 and Python ordinal 1

    ages = []
    filenames = []
    for i in range(len(full_path)):
        try:
            birth_year = datetime.fromordinal(int(dob[i]) - matlab_to_python_ordinal).year
        except (ValueError, OverflowError):
            # Missing / corrupt date of birth: no usable label for this image.
            birth_year = None
        if birth_year is not None:
            # Plain Python ints: the raw numpy scalar is unsigned and would
            # wrap around (with a RuntimeWarning) on a negative difference.
            age = int(photo_taken[i]) - birth_year
            if 0 < age <= 100:
                ages.append(age)
                filenames.append(full_path[i][0])
        if len(filenames) >= LIMIT_IMAGES:
            break

    df = pd.DataFrame({'filename': filenames, 'age': ages})
    df.to_csv(os.path.join(DATA_DIR, 'full_labels.csv'), index=False)
    print(f"Saved full_labels.csv with {len(df)} samples.")

    resize_and_copy(df)

def resize_and_copy(df):
    """Resize the labelled images and store them as ``DATA_DIR/images/<index>.jpg``.

    Args:
        df: DataFrame with a ``filename`` column holding paths relative to
            ``DATA_DIR/wiki_crop``. It is modified in place: ``filename`` is
            rewritten to the new ``<index>.jpg`` name, and rows whose image
            could not be processed are dropped.

    Side effects:
        Writes the resized JPEGs and overwrites ``DATA_DIR/full_labels.csv``
        with the updated frame.
    """
    src_folder = os.path.join(DATA_DIR, 'wiki_crop')
    dst_folder = os.path.join(DATA_DIR, 'images')
    os.makedirs(dst_folder, exist_ok=True)
    print("Resizing images...")
    failed = []
    for i, row in tqdm(df.iterrows(), total=len(df)):
        src = os.path.join(src_folder, row['filename'])
        dst = os.path.join(dst_folder, f"{i}.jpg")
        try:
            # Context manager closes the source file handle promptly.
            with Image.open(src) as img:
                img.convert('RGB').resize((IMG_SIZE, IMG_SIZE)).save(dst)
        except Exception:
            # Best effort: skip unreadable images, but (unlike a bare except)
            # let KeyboardInterrupt / SystemExit propagate.
            failed.append(i)
            continue
        df.at[i, 'filename'] = f"{i}.jpg"
    if failed:
        # A row left with its original path would point at a file that does
        # not exist in the images folder and break loading later on.
        df.drop(index=failed, inplace=True)
        print(f"Skipped {len(failed)} images that could not be processed.")
    df.to_csv(os.path.join(DATA_DIR, 'full_labels.csv'), index=False)

# ---------- Step 3: Dataset ----------
class AgeDataset(Dataset):
    """Dataset of (image, age) pairs described by a CSV file.

    The CSV must provide a ``filename`` column (path relative to ``img_dir``)
    and an ``age`` column.
    """

    def __init__(self, csv_file, img_dir, transform=None):
        """Load the label table.

        Args:
            csv_file: path of the CSV with ``filename`` and ``age`` columns.
            img_dir: folder that the ``filename`` entries are relative to.
            transform: optional callable applied to the RGB PIL image.
        """
        self.data = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the label table."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, age)`` for row ``idx``; age is a float32 scalar tensor."""
        row = self.data.iloc[idx]
        img_path = os.path.join(self.img_dir, row['filename'])
        # convert() yields an independent in-memory copy, so the underlying
        # file handle can be closed right away instead of lingering until GC.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        age = torch.tensor(row['age'], dtype=torch.float32)  # FIX: ensure float32
        if self.transform is not None:
            image = self.transform(image)
        return image, age

# ---------- Step 4: Model ----------
class SimpleAgeCNN(nn.Module):
    """Small convolutional regressor producing one scalar per input image.

    Three 3x3 conv stages (32, 64, 128 channels) feed a global average pool,
    followed by a 128 -> 64 -> 1 fully connected head.
    """

    def __init__(self):
        super().__init__()

        # Two conv stages that each halve the spatial resolution.
        feature_layers = []
        in_channels = 3
        for out_channels in (32, 64):
            feature_layers.append(nn.Conv2d(in_channels, out_channels, 3, padding=1))
            feature_layers.append(nn.ReLU())
            feature_layers.append(nn.MaxPool2d(2))
            in_channels = out_channels

        # Final conv stage collapses the feature map to 1x1 per channel.
        feature_layers.append(nn.Conv2d(in_channels, 128, 3, padding=1))
        feature_layers.append(nn.ReLU())
        feature_layers.append(nn.AdaptiveAvgPool2d((1, 1)))
        self.cnn = nn.Sequential(*feature_layers)

        head_layers = [
            nn.Flatten(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        ]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """Map a batch of images to a 1-D tensor with one prediction per image."""
        features = self.cnn(x)
        prediction = self.fc(features)
        # Drop the trailing singleton dimension left by the last Linear layer.
        return prediction.squeeze(1)

# ---------- Step 5: Training ----------
def train_model():
    """Train SimpleAgeCNN on the labelled images and keep the best checkpoint.

    Reads ``full_labels.csv`` from DATA_DIR, splits it 90/10 (fixed
    random_state), writes ``train_labels.csv`` / ``val_labels.csv``, then
    trains for EPOCHS epochs with Adam and MSE loss. Whenever the average
    validation loss improves, the model state_dict is saved to MODEL_PATH.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    preprocess = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    images_dir = os.path.join(DATA_DIR, 'images')
    train_csv = os.path.join(DATA_DIR, 'train_labels.csv')
    val_csv = os.path.join(DATA_DIR, 'val_labels.csv')

    # Split once and persist both parts so later cells can reuse them.
    labels = pd.read_csv(os.path.join(DATA_DIR, 'full_labels.csv'))
    train_part, val_part = train_test_split(labels, test_size=0.1, random_state=42)
    train_part.to_csv(train_csv, index=False)
    val_part.to_csv(val_csv, index=False)

    train_dataset = AgeDataset(train_csv, images_dir, preprocess)
    val_dataset = AgeDataset(val_csv, images_dir, preprocess)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    model = SimpleAgeCNN().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    criterion = nn.MSELoss()

    def run_pass(loader, description, training):
        """Run one pass over ``loader``; return the summed per-sample loss."""
        model.train(training)
        summed_loss = 0
        with torch.set_grad_enabled(training):
            for imgs, ages in tqdm(loader, desc=description):
                imgs, ages = imgs.to(device), ages.to(device)
                loss = criterion(model(imgs), ages)
                if training:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                # The criterion returns a batch mean; weight it by batch size.
                summed_loss += loss.item() * imgs.size(0)
        return summed_loss

    best_val_loss = float('inf')

    for epoch in range(1, EPOCHS + 1):
        total_train_loss = run_pass(train_loader, f"Epoch {epoch} - Train", True)
        total_val_loss = run_pass(val_loader, f"Epoch {epoch} - Val", False)

        avg_train = total_train_loss / len(train_dataset)
        avg_val = total_val_loss / len(val_dataset)
        print(f"Epoch {epoch}: Train Loss = {avg_train:.4f}, Val Loss = {avg_val:.4f}")

        if avg_val >= best_val_loss:
            continue
        best_val_loss = avg_val
        torch.save(model.state_dict(), MODEL_PATH)
        print("📦 Saved best model")

# ---------- Run All ----------
# Full pipeline in order: fetch + extract the archive, build the label CSV
# (which also resizes the images), then train and checkpoint the model.
if __name__ == '__main__':
    download_dataset()
    extract_labels()
    train_model()


Downloading IMDB-WIKI subset...


Downloading: 774MB [00:37, 20.75MB/s]


Extracted dataset.
Processing .mat label file...


  age = photo_taken[i] - datetime.fromordinal(int(dob[i])).year


Saved full_labels.csv with 5000 samples.
Resizing images...


100%|██████████| 5000/5000 [00:10<00:00, 467.63it/s]
Epoch 1 - Train: 100%|██████████| 71/71 [00:06<00:00, 10.86it/s]
Epoch 1 - Val: 100%|██████████| 8/8 [00:00<00:00, 16.26it/s]


Epoch 1: Train Loss = 533.7124, Val Loss = 340.5427
📦 Saved best model


Epoch 2 - Train: 100%|██████████| 71/71 [00:04<00:00, 17.36it/s]
Epoch 2 - Val: 100%|██████████| 8/8 [00:00<00:00, 15.84it/s]


Epoch 2: Train Loss = 309.1904, Val Loss = 308.5546
📦 Saved best model


Epoch 3 - Train: 100%|██████████| 71/71 [00:04<00:00, 17.59it/s]
Epoch 3 - Val: 100%|██████████| 8/8 [00:00<00:00, 11.65it/s]


Epoch 3: Train Loss = 298.2868, Val Loss = 291.3585
📦 Saved best model


Epoch 4 - Train: 100%|██████████| 71/71 [00:05<00:00, 13.12it/s]
Epoch 4 - Val: 100%|██████████| 8/8 [00:00<00:00, 16.01it/s]


Epoch 4: Train Loss = 294.4931, Val Loss = 282.2455
📦 Saved best model


Epoch 5 - Train: 100%|██████████| 71/71 [00:03<00:00, 17.76it/s]
Epoch 5 - Val: 100%|██████████| 8/8 [00:00<00:00, 16.02it/s]


Epoch 5: Train Loss = 282.8377, Val Loss = 275.2137
📦 Saved best model


Epoch 6 - Train: 100%|██████████| 71/71 [00:04<00:00, 15.38it/s]
Epoch 6 - Val: 100%|██████████| 8/8 [00:00<00:00,  9.94it/s]


Epoch 6: Train Loss = 279.4037, Val Loss = 273.1553
📦 Saved best model


Epoch 7 - Train: 100%|██████████| 71/71 [00:04<00:00, 15.68it/s]
Epoch 7 - Val: 100%|██████████| 8/8 [00:00<00:00, 16.46it/s]


Epoch 7: Train Loss = 278.6101, Val Loss = 282.6181


Epoch 8 - Train: 100%|██████████| 71/71 [00:04<00:00, 17.66it/s]
Epoch 8 - Val: 100%|██████████| 8/8 [00:00<00:00, 15.95it/s]


Epoch 8: Train Loss = 273.5891, Val Loss = 276.2077


Epoch 9 - Train: 100%|██████████| 71/71 [00:05<00:00, 13.60it/s]
Epoch 9 - Val: 100%|██████████| 8/8 [00:00<00:00, 12.55it/s]


Epoch 9: Train Loss = 271.4840, Val Loss = 271.1701
📦 Saved best model


Epoch 10 - Train: 100%|██████████| 71/71 [00:04<00:00, 17.07it/s]
Epoch 10 - Val: 100%|██████████| 8/8 [00:00<00:00, 15.69it/s]

Epoch 10: Train Loss = 274.4053, Val Loss = 271.1621
📦 Saved best model





In [None]:
# Debug helper: show the working directory that relative paths such as DATA_DIR resolve against.
import os
print("Current working directory:", os.getcwd())


Current working directory: /content


In [None]:
def evaluate_model(model_path=MODEL_PATH, tolerance=5):
    """Report regression quality of a saved model on the validation split.

    Args:
        model_path: path of a state_dict saved from SimpleAgeCNN.
        tolerance: a prediction counts as a hit when its absolute error is
            at most this many years.

    Prints the mean absolute error and the percentage of validation samples
    predicted within ``tolerance`` years. Returns nothing.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    val_dataset = AgeDataset(
        os.path.join(DATA_DIR, 'val_labels.csv'),
        os.path.join(DATA_DIR, 'images'),
        transform=transform
    )
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    model = SimpleAgeCNN().to(device)
    # map_location lets a checkpoint saved on GPU load on a CPU-only runtime
    # (and vice versa) instead of failing during deserialization.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    mae = 0
    total = 0
    within_tol = 0

    with torch.no_grad():
        for imgs, ages in tqdm(val_loader, desc="Evaluating"):
            imgs, ages = imgs.to(device), ages.to(device)
            preds = model(imgs)
            error = torch.abs(preds - ages)
            # Accumulate sums; they are normalised by the sample count below.
            mae += error.sum().item()
            within_tol += (error <= tolerance).sum().item()
            total += imgs.size(0)

    mae /= total
    acc_tol = within_tol / total * 100
    print(f"\n📊 MAE: {mae:.2f} years")
    print(f"🎯 Accuracy within ±{tolerance} years: {acc_tol:.2f}%")


In [None]:
# Evaluate the checkpoint at MODEL_PATH with the default ±5-year tolerance.
evaluate_model()


Evaluating: 100%|██████████| 8/8 [00:00<00:00,  8.99it/s]


📊 MAE: 13.20 years
🎯 Accuracy within ±5 years: 20.40%





In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

def age_to_class(age):
    """Map an age in years to a decade class index in ``0..9``.

    Ages 0-9 map to class 0, 10-19 to class 1, and so on; everything from 90
    upwards shares class 9. Negative values (which an unconstrained
    regression output can produce) are clamped to class 0 rather than
    yielding a class of -1.

    Args:
        age: age in years (int, float or numpy scalar).

    Returns:
        int: the class index.
    """
    decade = age // 10
    # Clamp on both sides, then normalise numpy/float results to a plain int
    # so true and predicted labels share one label type.
    return int(max(0, min(decade, 9)))

def evaluate_classification(model_path=MODEL_PATH):
    """Score the regression model as a decade classifier on the validation split.

    True and predicted ages are both bucketed with ``age_to_class``; the
    function then prints accuracy plus weighted precision, recall and F1.

    Args:
        model_path: path of a state_dict saved from SimpleAgeCNN.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    dataset = AgeDataset(
        csv_file=os.path.join(DATA_DIR, 'val_labels.csv'),
        img_dir=os.path.join(DATA_DIR, 'images'),
        transform=transform
    )
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)

    model = SimpleAgeCNN().to(device)
    # map_location lets a checkpoint saved on GPU load on a CPU-only runtime
    # (and vice versa) instead of failing during deserialization.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    y_true = []
    y_pred = []

    with torch.no_grad():
        for imgs, ages in tqdm(loader, desc="Evaluating Classification"):
            imgs = imgs.to(device)
            # Predictions go back to the CPU; the labels were never moved off it.
            preds = model(imgs).cpu().numpy()
            true_ages = ages.numpy()

            y_true.extend([age_to_class(a) for a in true_ages])
            y_pred.extend([age_to_class(p) for p in preds])

    acc = accuracy_score(y_true, y_pred)
    # zero_division=0 keeps classes that are never predicted from raising warnings.
    precision = precision_score(y_true, y_pred, average='weighted', zero_division=0)
    recall = recall_score(y_true, y_pred, average='weighted', zero_division=0)
    f1 = f1_score(y_true, y_pred, average='weighted', zero_division=0)

    print(f"🎯 Accuracy:  {acc:.4f}")
    print(f"🎯 Precision: {precision:.4f}")
    print(f"🎯 Recall:    {recall:.4f}")
    print(f"🎯 F1 Score:  {f1:.4f}")


In [None]:
# Score the checkpoint at MODEL_PATH as a decade-bucket classifier.
evaluate_classification()

Evaluating Classification: 100%|██████████| 8/8 [00:01<00:00,  7.25it/s]

🎯 Accuracy:  0.2220
🎯 Precision: 0.3477
🎯 Recall:    0.2220
🎯 F1 Score:  0.1359



