# FacesMultiNet Training Notebook | Shreyan Chaubey

The purpose of this notebook is to run multitask learning experiments in which a single shared feature extractor (the backbone) feeds two task-specific heads: age regression and gender classification.

- **Author**: Shreyan Chaubey (22f3001642@ds.study.iitm.ac.in)
- **Model**: [thethinkmachine/EfficientNetV2-S-FacesMTL-EXP1](https://huggingface.co/thethinkmachine/EfficientNetV2-S-FacesMTL-EXP1) (best candidate)
- **Training Dataset**: [thethinkmachine/faces-mtl](https://huggingface.co/datasets/thethinkmachine/faces-mtl) (train split) (Sep '25 DLGenAI NPPE 1 Competition Dataset)
- **Evaluation Dataset**: [thethinkmachine/faces-mtl](https://huggingface.co/datasets/thethinkmachine/faces-mtl) (eval split) (Sep '25 DLGenAI NPPE 1 Competition Dataset)
- **Libraries Used**: PyTorch 2.9 (cuda-13.0)
- **Training Hardware**:
    - **GPU**: Nvidia RTX 2070 Super (8 GB VRAM)
    - **CPU**: AMD Ryzen 5 3600
    - **RAM**: 32 GB DDR4 @ 3.2 GHz
- **Date**: 12-11-2025

## This is an inference notebook. It doesn't train. Training experiments were carried out on a separate training notebook which is available on GitHub at [github.com/thethinkmachine/FacesMultiNet](https://github.com/thethinkmachine/FacesMultiNet)

### Experimentation Details
All pre- and post-training experiments were carried out on the following backbone architectures; each experiment's performance on the Kaggle leaderboard is summarized with an emoji.
- Scratch CNN 😒 (pre-training)
- ResNet50 (ImageNet1K v2 checkpoints) (8k training datapoints) 🙂
- ResNet50 (ImageNet1K v2 checkpoints) (27.7k training datapoints) 😊
- EfficientNetV1 B4 (ImageNet1K v1 checkpoints) 😐
- EfficientNetV2-S (ImageNet1K v1 checkpoints) (best candidate) 🤩

As such, this notebook has been preconfigured to perform inference & generate `submission.csv` using the best candidate model.

## Imports

In [None]:
import os
import trackio
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import torchvision.ops.focal_loss as focal_loss
import numpy as np
import matplotlib.pyplot as plt
from dataclasses import dataclass
from typing import Optional, Tuple
from transformers.file_utils import ModelOutput
from torch.utils.data import DataLoader, Dataset
from transformers import Trainer, TrainingArguments, PreTrainedModel, PretrainedConfig
from datasets import load_dataset, DatasetDict, Dataset, concatenate_datasets, load_from_disk
from PIL import Image
from pathlib import Path
from sklearn.metrics import mean_squared_error, mean_absolute_error, f1_score, accuracy_score

## **Data**

### Load

In [None]:
if os.path.exists("./faces_mtl_transformed"):
    # A previous run already split, preprocessed and cached the data: reuse it
    # and skip the preprocessing cells below.
    faces_mtl = DatasetDict.load_from_disk("./faces_mtl_transformed")
    LOAD_FRESH = False
else:
    # No cache yet: fetch the raw train split from the Hub. The following
    # cells (all guarded by LOAD_FRESH) split, transform and cache it.
    dataset = load_dataset("thethinkmachine/faces-mtl", split="train", streaming=False)
    LOAD_FRESH = True

### Split `(0.8:0.1:0.1)`

In [None]:
if LOAD_FRESH:
    # 80 / 10 / 10: hold out 20% of the rows, then halve the held-out part
    # into validation and test. shuffle=False keeps the dataset's existing
    # row order, so each split is taken as-is rather than randomly sampled.
    train_val = dataset.train_test_split(test_size=0.2, shuffle=False)
    val_test = train_val['test'].train_test_split(test_size=0.5, shuffle=False)

    faces_mtl = DatasetDict({
        'train': train_val['train'],        # 80%
        'validation': val_test['train'],    # 10%
        'test': val_test['test'],           # 10%
    })

### Define transforms

In [None]:
if LOAD_FRESH:
    # Per-channel RGB normalisation statistics used by the ImageNet-pretrained
    # backbones this notebook loads.
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    # Train: resize to 256x256, random 224 crop, then flip / colour / rotation
    # augmentation before tensor conversion and normalisation.
    train_transform = transforms.Compose(
        [
            transforms.Resize((256, 256)),
            transforms.RandomCrop(224),
        ]
        + [
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            transforms.RandomRotation(degrees=10),
        ]
        + [
            transforms.ToTensor(),
            transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
        ]
    )

    # Validation / test: the same geometry made deterministic (center crop),
    # with no augmentation.
    val_test_transform = transforms.Compose(
        [
            transforms.Resize((256, 256)),
            transforms.CenterCrop(224),
        ]
        + [
            transforms.ToTensor(),
            transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
        ]
    )

### Apply transforms

In [None]:
if LOAD_FRESH:
    def preprocess_train(batch):
        """Attach augmented, normalised ``pixel_values`` to a batched train example dict."""
        # NOTE(review): this runs once inside Dataset.map and the result is then
        # saved to disk, so every image keeps one fixed random augmentation for
        # all epochs. On-the-fly augmentation (e.g. Dataset.set_transform) would
        # need the set_format/save cells changed as well.
        rgb_images = (img.convert("RGB") for img in batch["image"])
        batch["pixel_values"] = [train_transform(img) for img in rgb_images]
        return batch

    def preprocess_val_test(batch):
        """Attach deterministic, normalised ``pixel_values`` to a batched eval example dict."""
        rgb_images = (img.convert("RGB") for img in batch["image"])
        batch["pixel_values"] = [val_test_transform(img) for img in rgb_images]
        return batch

    # Train gets the augmenting transform; validation and test share the
    # deterministic one.
    for split_name, split_fn in (
        ('train', preprocess_train),
        ('validation', preprocess_val_test),
        ('test', preprocess_val_test),
    ):
        faces_mtl[split_name] = faces_mtl[split_name].map(split_fn, batched=True)

### Set training format

In [None]:
if LOAD_FRESH:
    # Expose only the model input and the two targets, as torch tensors.
    faces_mtl.set_format(
        type='torch',
        columns=['pixel_values', 'age', 'gender'],
    )

### Save to disk for later reuse

In [None]:
if LOAD_FRESH:
    # Cache the split + preprocessed data; the loading cell checks this same
    # path and reloads it instead of recomputing on later runs.
    faces_mtl.save_to_disk(
        "./faces_mtl_transformed"
    )

# **Model**

### Define dataclass for model outputs

In [None]:
@dataclass
class MTLOutput(ModelOutput):
    """Output container returned by ``FacesMultiNet.forward`` when ``return_dict=True``.

    ``logits`` holds ``(age_logits, gender_logits)``. ``loss`` is the weighted
    multitask loss and is only set when both age and gender targets were given.
    """

    loss: Optional[torch.FloatTensor] = None  # weighted age + gender loss, or None without labels
    logits: Optional[Tuple[torch.Tensor, ...]] = None  # (age_logits, gender_logits)
    hidden_states: Optional[Tuple[torch.Tensor, ...]] = None  # not populated by FacesMultiNet.forward
    attentions: Optional[Tuple[torch.Tensor, ...]] = None  # not populated by FacesMultiNet.forward

### Define model configuration

In [None]:
class FacesMultiNetConfig(PretrainedConfig):
    """Configuration for the two-headed ``FacesMultiNet`` model.

    Args:
        num_age_labels: output width of the age head (the age loss reshapes
            targets to one column, so 1 is the working value).
        num_gender_labels: number of gender classes for the classification head.
        age_loss_weight: multiplier applied to the age (MSE) loss term.
        gender_loss_weight: multiplier applied to the gender (cross-entropy) loss term.
        backbone_type: one of "resnet50", "efficientnet_b4", "efficientnetv2_s"
            or "custom_cnn". Defaults to None, which the model rejects with a
            ValueError, so it must be set before building a model.
        **kwargs: forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "facesmultitasknet"

    def __init__(
        self,
        num_age_labels=1,
        num_gender_labels=2,
        age_loss_weight=1.0,
        gender_loss_weight=1.0,
        backbone_type=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Head output sizes.
        self.num_age_labels, self.num_gender_labels = num_age_labels, num_gender_labels
        # Per-task loss weights.
        self.age_loss_weight, self.gender_loss_weight = age_loss_weight, gender_loss_weight
        # Which feature extractor to build.
        self.backbone_type = backbone_type

### Define model architecture

`PreTrainedModel` subclasses `torch.nn.Module` and adds convenience features such as `save_pretrained()`, `from_pretrained()` and `push_to_hub()`.

In [None]:
class FacesMultiNet(PreTrainedModel):
    """Shared CNN backbone feeding two heads: age regression and gender classification."""

    config_class = FacesMultiNetConfig

    def __init__(self, config: FacesMultiNetConfig):
        super().__init__(config)
        self.config = config

        # The backbone is the experimental variable; its feature width sizes
        # the first Linear layer of both task heads.
        self.backbone, feature_dim = self._create_backbone(config.backbone_type)

        self.age_head = self._create_task_head(feature_dim, config.num_age_labels)
        self.gender_head = self._create_task_head(feature_dim, config.num_gender_labels)

        # One loss per head: MSE for age, cross-entropy for gender.
        self.age_loss_fn = nn.MSELoss()
        self.gender_loss_fn = nn.CrossEntropyLoss()

    @staticmethod
    def _create_backbone(backbone_type):
        """Build the feature extractor named by ``backbone_type``; return ``(module, out_features)``."""
        if backbone_type == "efficientnet_b4":
            net = models.efficientnet_b4(weights=models.EfficientNet_B4_Weights.IMAGENET1K_V1)
            net.classifier = nn.Identity()  # drop the ImageNet classifier so the module emits features
            return net, 1792
        if backbone_type == "resnet50":
            net = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
            net.fc = nn.Identity()
            return net, 2048
        if backbone_type == "efficientnetv2_s":
            net = models.efficientnet_v2_s(weights=models.EfficientNet_V2_S_Weights.IMAGENET1K_V1)
            net.classifier = nn.Identity()
            return net, 1280
        if backbone_type == "custom_cnn":
            # Four Conv-BN-ReLU-MaxPool stages (3->32->64->128->256 channels),
            # then global average pooling flattened to a 256-d vector.
            layers = []
            for in_ch, out_ch in ((3, 32), (32, 64), (64, 128), (128, 256)):
                layers += [
                    nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                    nn.BatchNorm2d(out_ch),
                    nn.ReLU(inplace=True),
                    nn.MaxPool2d(kernel_size=2, stride=2),
                ]
            layers += [nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten()]
            return nn.Sequential(*layers), 256
        raise ValueError(f"Unsupported backbone type: {backbone_type}")

    @staticmethod
    def _create_task_head(in_features, out_features):
        """Build a task head: Linear(in, 512) -> ReLU -> Dropout(0.5) -> Linear(512, out)."""
        return nn.Sequential(
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, out_features),
        )

    def forward(self, pixel_values, age=None, gender=None, return_dict=True, **kwargs):
        """Run the shared backbone and both heads.

        Args:
            pixel_values: image batch passed to the backbone.
            age: optional age targets; cast to float and reshaped to ``(-1, 1)``
                for the MSE loss.
            gender: optional class-index targets for the cross-entropy loss.
            return_dict: return an ``MTLOutput`` when True, otherwise a tuple.
            **kwargs: accepted and ignored.

        Returns:
            ``MTLOutput(loss=..., logits=(age_logits, gender_logits))``, or with
            ``return_dict=False`` the tuple ``(loss, age_logits, gender_logits)``,
            falling back to ``(age_logits, gender_logits)`` when no loss was computed.
            The loss is only computed when both ``age`` and ``gender`` are given.
        """
        shared_features = self.backbone(pixel_values)
        task_logits = (self.age_head(shared_features), self.gender_head(shared_features))
        age_logits, gender_logits = task_logits

        loss = None
        if age is not None and gender is not None:
            age_term = self.config.age_loss_weight * self.age_loss_fn(age_logits, age.float().view(-1, 1))
            gender_term = self.config.gender_loss_weight * self.gender_loss_fn(gender_logits, gender)
            loss = age_term + gender_term

        if return_dict:
            return MTLOutput(loss=loss, logits=task_logits)
        return task_logits if loss is None else (loss, *task_logits)

# **Metrics**

### Define metrics to be computed from model outputs

In [None]:
def compute_metrics_multitask(eval_pred):
    """Compute gender-classification and age-regression metrics from Trainer predictions.

    ``eval_pred.predictions`` is ``(age_predictions, gender_logits)`` and
    ``eval_pred.label_ids`` is ``(age_labels, gender_labels)``. These are the
    tuples returned by ``MultiTaskTrainer.prediction_step``.

    Returns:
        dict with ``eval_gender_accuracy``, ``eval_gender_f1`` (macro),
        ``eval_age_mae`` and ``eval_age_rmse``.
    """
    age_out, gender_out = eval_pred.predictions
    true_age, true_gender = eval_pred.label_ids

    metrics = {}

    # Gender: most likely class per row.
    predicted_gender = np.argmax(gender_out, axis=-1)
    metrics['eval_gender_accuracy'] = accuracy_score(true_gender, predicted_gender)
    metrics['eval_gender_f1'] = f1_score(true_gender, predicted_gender, average='macro')

    # Age: the head emits an (N, 1) column; compare it as a flat vector.
    predicted_age = age_out.flatten()
    metrics['eval_age_mae'] = mean_absolute_error(true_age, predicted_age)
    metrics['eval_age_rmse'] = np.sqrt(mean_squared_error(true_age, predicted_age))

    return metrics

# **Training**

### Instantiate model object from defined config

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Build the best-candidate configuration (EfficientNetV2-S backbone), then
# move the model onto the selected device.
model_config = FacesMultiNetConfig(backbone_type="efficientnetv2_s")
model = FacesMultiNet(model_config)
model = model.to(device)

### Define training args

In [None]:
# Hyperparameters plus Hub / experiment-tracking targets for this run.
# hub_model_id and run_name must name the backbone actually being trained
# (backbone_type="efficientnetv2_s" in the model cell); the best-candidate
# checkpoint is published as thethinkmachine/EfficientNetV2-S-FacesMTL-EXP1.
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=5,
    auto_find_batch_size=True,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    optim="adamw_torch",
    weight_decay=1.0e-4,
    learning_rate=1.0e-4,
    lr_scheduler_type='cosine',
    logging_dir='./logs',
    logging_steps=10,
    # Evaluate and checkpoint on the same 150-step cadence so the best
    # checkpoint (lowest validation age RMSE) can be reloaded at the end.
    eval_strategy="steps",
    eval_steps=150,
    save_strategy="steps",
    save_steps=150,
    load_best_model_at_end=True,
    metric_for_best_model="eval_age_rmse",
    greater_is_better=False,
    remove_unused_columns=False,
    report_to=['trackio'],
    fp16=torch.cuda.is_available(),
    hub_model_id="thethinkmachine/EfficientNetV2-S-FacesMTL-EXP1",
    hub_private_repo=True,
    # NOTE(review): push_to_hub is not enabled here, so hub_strategy likely has
    # no effect; uploads happen via trainer.push_to_hub() at the end. Confirm.
    hub_strategy="checkpoint",
    run_name="FacesMTL-Experiment-EfficientNetV2-S"
)

### Subclass `Trainer` to handle multitask prediction steps

In [None]:
class MultiTaskTrainer(Trainer):
    """Trainer whose evaluation step handles FacesMultiNet's two-headed output.

    Predictions are returned as ``(age_logits, gender_logits)`` and labels as
    ``(age, gender)``, which is what ``compute_metrics_multitask`` unpacks.
    """

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """Run one no-grad evaluation step.

        Returns:
            ``(loss, logits, labels)``. ``loss`` is None unless both ``age`` and
            ``gender`` are present in ``inputs``. ``logits`` and ``labels`` are
            None when ``prediction_loss_only`` is set.
        """
        # Move tensors to the training device as the base prediction_step does;
        # this is a no-op when the dataloader has already placed them.
        inputs = self._prepare_inputs(inputs)

        has_labels = all(inputs.get(k) is not None for k in ("age", "gender"))
        labels = (inputs.get("age"), inputs.get("gender")) if has_labels else None

        with torch.no_grad():
            outputs = model(**inputs)
            age_logits, gender_logits = outputs.logits

        loss = None
        if has_labels and outputs.loss is not None:
            # Collapse per-device losses (e.g. under nn.DataParallel) to a single
            # scalar; a no-op for an already-scalar loss.
            loss = outputs.loss.mean().detach()

        if prediction_loss_only:
            return (loss, None, None)

        logits = (age_logits.detach(), gender_logits.detach())

        if labels is not None:
            labels = tuple(lab.detach() if isinstance(lab, torch.Tensor) else lab for lab in labels)
        return (loss, logits, labels)

### Instantiate `MultiTaskTrainer`

In [None]:
# Validate on the validation split during training; the held-out test split
# is only scored in the final evaluation cell.
trainer = MultiTaskTrainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics_multitask,
    train_dataset=faces_mtl['train'],
    eval_dataset=faces_mtl['validation'],
)

# Train

In [None]:
# Run the training loop configured in training_args: evaluate and checkpoint
# every 150 steps, and reload the best checkpoint (lowest eval_age_rmse) at the end.
trainer.train()

# Evaluate

In [None]:
# Score the held-out test split. metric_key_prefix="test" keeps these results
# apart from the validation metrics Trainer logs under "eval_*". Because
# compute_metrics_multitask already returns "eval_"-prefixed keys, the task
# metrics appear here as "test_eval_*" and the loss as "test_loss".
test_results = trainer.evaluate(faces_mtl['test'], metric_key_prefix="test")
print("Test Results:")
for key, value in test_results.items():
    print(f"{key}: {value:.4f}")

# Push to Hub

In [None]:
# Generate the Hub model card for this run from the trainer's state.
# NOTE(review): finetuned_from is the *string* "None", not None. The backbones
# built in this notebook start from torchvision ImageNet weights, so confirm
# what this field should say before publishing.
trainer.create_model_card(
    language="en",
    license="apache-2.0",
    tasks=["image-classification", "regression"],
    tags=["multitask-learning", "cnn", "computer-vision"],
    finetuned_from="None",
    dataset="thethinkmachine/faces-mtl",
    dataset_tags=["faces-mtl"],
    dataset_args="thethinkmachine/faces-mtl",
)

In [None]:
# Upload the trainer's output to the Hub repo named by training_args.hub_model_id.
trainer.push_to_hub()