In [1]:
import os, re, math, random, time, copy
from pathlib import Path
from typing import Tuple, List, Dict, Optional
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image

# tqdm: avoid notebook widget; fall back to plain
try:
    from tqdm import tqdm
except Exception:
    def tqdm(x, **k): return x

In [4]:
# ------------------
# Config
# ------------------
DATA_DIR = r"D:\utk_gender_balanced_6000"
SAVE_DIR = "./checkpoints_fas_utk"
os.makedirs(SAVE_DIR, exist_ok=True)

SEED = 42
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

IMG_SIZE = 224
BATCH_SIZE = 64
NUM_WORKERS = 0  # Windows-safe
MAX_EPOCHS = 30
EARLY_STOP_PATIENCE = 5
BASE_LR = 1e-3
WEIGHT_DECAY = 1e-4
LABEL_SMOOTH = 0.0  # standard CE smoothing; handled outside FAS

# FAS hyperparams
FAS_C = 0.5            # blend between individual li and group beta
FAS_EMA_ALPHA = 0.9    # per-sample EMA of loss
FAS_CLIP_MIN = 0.2     # clamp weights
FAS_CLIP_MAX = 5.0
FAS_BETA_LR_SCALE = 0.1  # slower LR for beta

# EarlyStopping composite score: balanced_acc - GAP_W * |acc_f - acc_m|
GAP_W = 0.25

VAL_SPLIT = 0.2
TEST_SPLIT = 0.0  # set >0 if you want a separate test split

In [6]:

# ------------------
# Utils
# ------------------
def set_seed(seed=SEED):
    random.seed(seed); np.random.seed(seed)
    torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
set_seed(SEED)

def parse_gender_from_name(name: str) -> Optional[int]:
    # UTKFace filename: age_gender_race_date.jpg  (gender: 0 male, 1 female)
    # robust parse
    base = os.path.basename(name)
    m = re.match(r"(\d+)_(\d)_(\d)_", base)
    if m:
        return int(m.group(2))
    return None

def discover_images(data_dir: str) -> List[Tuple[str,int]]:
    """Return list of (path, gender_label). Supports:
       1) subfolders like .../male, .../female or .../0, .../1
       2) UTKFace filename convention
    """
    paths = []
    data_dir = Path(data_dir)

    # Case 1: subfolders
    subs = [p for p in data_dir.iterdir() if p.is_dir()]
    sub_ok = False
    if subs:
        mapping = {}
        for sub in subs:
            key = sub.name.lower()
            if key in ("male", "m", "0"): mapping[sub] = 0; sub_ok = True
            if key in ("female", "f", "1"): mapping[sub] = 1; sub_ok = True
        if sub_ok:
            for sub, lab in mapping.items():
                for ext in ("*.jpg","*.jpeg","*.png","*.bmp","*.webp"):
                    for f in sub.rglob(ext):
                        paths.append((str(f), lab))
    # Case 2: filename parse
    if not sub_ok:
        for ext in ("*.jpg","*.jpeg","*.png","*.bmp","*.webp"):
            for f in data_dir.rglob(ext):
                g = parse_gender_from_name(f.name)
                if g is not None:
                    paths.append((str(f), g))

    if not paths:
        raise RuntimeError(f"No images found or labels not parsed in: {data_dir}")
    return paths

In [8]:
# ------------------
# Dataset
# ------------------
class UTKGenderDataset(Dataset):
    def __init__(self, items: List[Tuple[str,int]], transform=None):
        self.items = items
        self.transform = transform

    def __len__(self): return len(self.items)

    def __getitem__(self, idx):
        path, y = self.items[idx]
        img = Image.open(path).convert("RGB")
        if self.transform: img = self.transform(img)
        # group = gender here (0 male, 1 female)
        return img, torch.tensor(y, dtype=torch.long), torch.tensor(y, dtype=torch.long), torch.tensor(idx, dtype=torch.long)

# ------------------
# Transforms
# ------------------
train_tfms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.1,0.1,0.1,0.05),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225)),
])
val_tfms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225)),
])

In [10]:
# ------------------
# Split data
# ------------------
all_items = discover_images(DATA_DIR)
random.shuffle(all_items)
n = len(all_items)
n_test = int(n * TEST_SPLIT)
n_val = int(n * VAL_SPLIT)

test_items = all_items[:n_test]
val_items = all_items[n_test:n_test+n_val]
train_items = all_items[n_test+n_val:]

train_ds = UTKGenderDataset(train_items, train_tfms)
val_ds = UTKGenderDataset(val_items, val_tfms)
test_ds = UTKGenderDataset(test_items, val_tfms) if n_test>0 else None

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True) if test_ds else None

In [12]:
# ------------------
# Model: ResNet18 + Dropout
# ------------------
class ResNet18WithDropout(nn.Module):
    def __init__(self, pretrained=True, p_drop=0.3, n_classes=2):
        super().__init__()
        m = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None)
        in_feat = m.fc.in_features
        # replace head with dropout + linear
        m.fc = nn.Sequential(
            nn.Dropout(p=p_drop),
            nn.Linear(in_feat, n_classes)
        )
        self.net = m
    def forward(self, x): return self.net(x)

model = ResNet18WithDropout(pretrained=True, p_drop=0.3, n_classes=2).to(DEVICE)

In [14]:
# ------------------
# Fair Adaptive Scaling (FAS) Loss
# ------------------
class FairAdaptiveScalingLoss(nn.Module):
    """
    Combines individual scaling l_i (EMA of past loss per-sample) and
    learnable group scaling beta[g], then weights standard CE loss.
    """
    def __init__(self, num_groups:int, num_samples:int, c:float=0.5, ema_alpha:float=0.9,
                 clip_min:float=0.2, clip_max:float=5.0):
        super().__init__()
        assert 0 <= c <= 1
        self.num_groups = num_groups
        self.c = c
        self.ema_alpha = ema_alpha
        self.clip_min = clip_min
        self.clip_max = clip_max

        # per-sample EMA of loss
        self.register_buffer("ema_loss", torch.zeros(num_samples, dtype=torch.float32))

        # learnable group weights beta
        self.beta = nn.Parameter(torch.ones(num_groups, dtype=torch.float32))

        self.ce = nn.CrossEntropyLoss(reduction='none')

    @torch.no_grad()
    def update_ema(self, idxs: torch.Tensor, losses: torch.Tensor):
        # ema[idx] = alpha*ema + (1-alpha)*current
        self.ema_loss[idxs] = self.ema_alpha * self.ema_loss[idxs] + (1 - self.ema_alpha) * losses.detach()

    def forward(self, logits: torch.Tensor, targets: torch.Tensor, groups: torch.Tensor, idxs: torch.Tensor):
        # base per-sample loss
        per_sample = self.ce(logits, targets)  # [B]

        # update EMA (before using)
        with torch.no_grad():
            self.update_ema(idxs, per_sample)

        # individual scaling l_i from EMA of loss (normalize by batch mean to be stable)
        with torch.no_grad():
            ema_vals = self.ema_loss[idxs]  # [B]
            batch_mean = ema_vals.mean().clamp(min=1e-8)
            l_i = (ema_vals / batch_mean)

        # group scaling beta_g (learnable)
        beta_g = self.beta[groups]  # [B]

        # combine
        c_i = self.c * l_i + (1 - self.c) * beta_g
        c_i = c_i.clamp(self.clip_min, self.clip_max)

        # weighted loss
        loss = (c_i * per_sample).mean()
        return loss, c_i.detach(), per_sample.detach()

In [16]:
# ------------------
# Metrics
# ------------------
@torch.no_grad()
def accuracies(y_true: torch.Tensor, y_pred: torch.Tensor, g: torch.Tensor) -> Dict[str, float]:
    y_true = y_true.cpu().numpy()
    y_pred = y_pred.cpu().numpy()
    g = g.cpu().numpy()
    overall = (y_true == y_pred).mean()
    acc0 = (y_pred[g==0] == y_true[g==0]).mean() if np.any(g==0) else np.nan
    acc1 = (y_pred[g==1] == y_true[g==1]).mean() if np.any(g==1) else np.nan
    # balanced class accuracy (over labels 0/1)
    b0 = (y_pred[y_true==0] == 0).mean() if np.any(y_true==0) else np.nan
    b1 = (y_pred[y_true==1] == 1).mean() if np.any(y_true==1) else np.nan
    bacc = np.nanmean([b0,b1])
    # worst-group accuracy over gender
    wg = np.nanmin([acc0, acc1])
    gap = np.abs(acc0 - acc1) if (not np.isnan(acc0) and not np.isnan(acc1)) else np.nan
    return dict(overall=overall, male=acc0, female=acc1, balanced=bacc, worst_group=wg, gap=gap)

def composite_score(balanced_acc: float, gap: float, gap_w: float=GAP_W) -> float:
    if np.isnan(gap): gap = 0.0
    return float(balanced_acc - gap_w * gap)

In [18]:
# ------------------
# Train setup
# ------------------
train_len = len(train_ds)
num_groups = 2  # gender
fas_loss = FairAdaptiveScalingLoss(num_groups=num_groups, num_samples=train_len,
                                   c=FAS_C, ema_alpha=FAS_EMA_ALPHA,
                                   clip_min=FAS_CLIP_MIN, clip_max=FAS_CLIP_MAX).to(DEVICE)

# CE for reporting (unweighted)
ce_report = nn.CrossEntropyLoss(reduction='mean')

# Optimizer with smaller LR for beta
params = [
    {"params": [p for n,p in model.named_parameters() if p.requires_grad], "lr": BASE_LR, "weight_decay": WEIGHT_DECAY},
    {"params": fas_loss.beta, "lr": BASE_LR * FAS_BETA_LR_SCALE, "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(params)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2, verbose=True)

best_state = None
best_score = -1e9
epochs_no_improve = 0



In [20]:
# ------------------
# Training / Validation loops
# ------------------
def run_epoch(loader, model, criterion_fas, train=True):
    if train: model.train()
    else: model.eval()
    y_true, y_pred, g_all = [], [], []
    total_loss, total_ce = 0.0, 0.0
    n_total = 0

    it = tqdm(loader, desc=("Train" if train else "Val"))
    for imgs, ys, gs, idxs in it:
        imgs, ys, gs, idxs = imgs.to(DEVICE), ys.to(DEVICE), gs.to(DEVICE), idxs.to(DEVICE)

        with torch.set_grad_enabled(train):
            logits = model(imgs)
            if LABEL_SMOOTH > 0 and train:
                # Optional CE smoothing just for reporting (FAS already uses CE internally)
                # We still keep criterion inside FAS as standard CE.
                targets = F.one_hot(ys, num_classes=2).float()
                ls_logits = logits.log_softmax(dim=1)
                smooth = (1 - LABEL_SMOOTH) * targets + LABEL_SMOOTH / 2
                ce_loss = -(smooth * ls_logits).sum(dim=1).mean()
            else:
                ce_loss = ce_report(logits, ys)

            if train:
                loss, weights, per_sample = criterion_fas(logits, ys, gs, idxs)
            else:
                # during eval, don't update EMA -> do a forward without update
                # simple hack: call criterion but stop EMA update via no_grad outer (already not train)
                loss, weights, per_sample = criterion_fas(logits, ys, gs, idxs)

            if train:
                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                optimizer.step()

        total_loss += float(loss.item()) * imgs.size(0)
        total_ce   += float(ce_loss.item()) * imgs.size(0)
        n_total    += imgs.size(0)

        preds = logits.argmax(dim=1)
        y_true.append(ys.detach())
        y_pred.append(preds.detach())
        g_all.append(gs.detach())

        it.set_postfix(loss=f"{total_loss/n_total:.4f}")

    y_true = torch.cat(y_true)
    y_pred = torch.cat(y_pred)
    g_all  = torch.cat(g_all)

    mets = accuracies(y_true, y_pred, g_all)
    mets.update({
        "loss": total_loss / max(n_total,1),
        "ce": total_ce / max(n_total,1),
    })
    return mets

history = []

for epoch in range(1, MAX_EPOCHS+1):
    print(f"\nEpoch {epoch}/{MAX_EPOCHS}")

    train_m = run_epoch(train_loader, model, fas_loss, train=True)
    val_m   = run_epoch(val_loader,   model, fas_loss, train=False)

    # Composite early-stopping score
    val_score = composite_score(val_m["balanced"], val_m["gap"], GAP_W)
    scheduler.step(val_score)

    row = {"epoch": epoch, "lr": optimizer.param_groups[0]["lr"], 
           **{f"train_{k}":v for k,v in train_m.items()},
           **{f"val_{k}":v for k,v in val_m.items()},
           "val_score": val_score}
    history.append(row)

    # Log concise line
    print(f"train_acc={train_m['overall']:.3f} | val_acc={val_m['overall']:.3f} | "
          f"val_bal={val_m['balanced']:.3f} | val_gap={val_m['gap']:.3f} | "
          f"val_worst={val_m['worst_group']:.3f} | score={val_score:.4f}")

    # Early stopping on best composite score
    if val_score > best_score:
        best_score = val_score
        best_state = {"model": copy.deepcopy(model.state_dict()),
                      "fas": copy.deepcopy(fas_loss.state_dict()),
                      "epoch": epoch,
                      "score": best_score}
        torch.save(best_state, os.path.join(SAVE_DIR, "best.pt"))
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= EARLY_STOP_PATIENCE:
            print(f"Early stopping at epoch {epoch}. Best score: {best_score:.4f}")
            break

# Load best weights
if best_state is not None:
    model.load_state_dict(best_state["model"])
    fas_loss.load_state_dict(best_state["fas"])
    print(f"Loaded best checkpoint from epoch {best_state['epoch']} (score={best_state['score']:.4f})")

# Final validation summary
print("\n==== Final Validation (Best) ====")
val_best = run_epoch(val_loader, model, fas_loss, train=False)
for k,v in val_best.items():
    if isinstance(v, float):
        print(f"{k}: {v:.4f}")

# Optional: Test set
if test_loader is not None:
    print("\n==== Test ====")
    test_m = run_epoch(test_loader, model, fas_loss, train=False)
    for k,v in test_m.items():
        if isinstance(v, float):
            print(f"{k}: {v:.4f}")


Epoch 1/30


Train: 100%|██████████████████████████████████████████████████████████████| 75/75 [11:31<00:00,  9.22s/it, loss=0.7089]
Val: 100%|████████████████████████████████████████████████████████████████| 19/19 [00:55<00:00,  2.94s/it, loss=0.4397]


train_acc=0.796 | val_acc=0.847 | val_bal=0.846 | val_gap=0.033 | val_worst=0.830 | score=0.8379

Epoch 2/30


Train: 100%|██████████████████████████████████████████████████████████████| 75/75 [11:35<00:00,  9.27s/it, loss=0.4616]
Val: 100%|████████████████████████████████████████████████████████████████| 19/19 [01:01<00:00,  3.23s/it, loss=0.4379]


train_acc=0.865 | val_acc=0.861 | val_bal=0.863 | val_gap=0.143 | val_worst=0.792 | score=0.8273

Epoch 3/30


Train: 100%|██████████████████████████████████████████████████████████████| 75/75 [11:31<00:00,  9.21s/it, loss=0.3964]
Val: 100%|████████████████████████████████████████████████████████████████| 19/19 [00:56<00:00,  2.98s/it, loss=0.4889]


train_acc=0.896 | val_acc=0.861 | val_bal=0.863 | val_gap=0.130 | val_worst=0.798 | score=0.8305

Epoch 4/30


Train: 100%|██████████████████████████████████████████████████████████████| 75/75 [11:30<00:00,  9.21s/it, loss=0.3901]
Val: 100%|████████████████████████████████████████████████████████████████| 19/19 [00:58<00:00,  3.06s/it, loss=0.4363]


train_acc=0.900 | val_acc=0.874 | val_bal=0.874 | val_gap=0.004 | val_worst=0.872 | score=0.8733

Epoch 5/30


Train: 100%|██████████████████████████████████████████████████████████████| 75/75 [11:17<00:00,  9.03s/it, loss=0.3497]
Val: 100%|████████████████████████████████████████████████████████████████| 19/19 [00:56<00:00,  2.97s/it, loss=0.4831]


train_acc=0.911 | val_acc=0.859 | val_bal=0.858 | val_gap=0.074 | val_worst=0.821 | score=0.8395

Epoch 6/30


Train: 100%|██████████████████████████████████████████████████████████████| 75/75 [11:32<00:00,  9.24s/it, loss=0.3383]
Val: 100%|████████████████████████████████████████████████████████████████| 19/19 [00:57<00:00,  3.01s/it, loss=0.7810]


train_acc=0.915 | val_acc=0.781 | val_bal=0.786 | val_gap=0.321 | val_worst=0.625 | score=0.7056

Epoch 7/30


Train: 100%|██████████████████████████████████████████████████████████████| 75/75 [11:37<00:00,  9.29s/it, loss=0.3234]
Val: 100%|████████████████████████████████████████████████████████████████| 19/19 [00:56<00:00,  3.00s/it, loss=0.4635]


train_acc=0.915 | val_acc=0.868 | val_bal=0.865 | val_gap=0.144 | val_worst=0.793 | score=0.8293

Epoch 8/30


Train: 100%|██████████████████████████████████████████████████████████████| 75/75 [11:38<00:00,  9.31s/it, loss=0.2319]
Val: 100%|████████████████████████████████████████████████████████████████| 19/19 [00:57<00:00,  3.02s/it, loss=0.4963]


train_acc=0.950 | val_acc=0.886 | val_bal=0.884 | val_gap=0.096 | val_worst=0.836 | score=0.8604

Epoch 9/30


Train: 100%|██████████████████████████████████████████████████████████████| 75/75 [11:44<00:00,  9.39s/it, loss=0.1730]
Val: 100%|████████████████████████████████████████████████████████████████| 19/19 [00:58<00:00,  3.07s/it, loss=0.6549]


train_acc=0.961 | val_acc=0.882 | val_bal=0.884 | val_gap=0.101 | val_worst=0.834 | score=0.8589
Early stopping at epoch 9. Best score: 0.8733
Loaded best checkpoint from epoch 4 (score=0.8733)

==== Final Validation (Best) ====


Val: 100%|████████████████████████████████████████████████████████████████| 19/19 [00:55<00:00,  2.93s/it, loss=0.4721]

overall: 0.8742
male: 0.8724
female: 0.8761
balanced: 0.8742
worst_group: 0.8724
gap: 0.0037
loss: 0.4721
ce: 0.2896





In [28]:
# ==== Robust TEST evaluation (works with/without 'per_group' in run_epoch output) ====
import numpy as np
import torch

assert test_loader is not None, "test_loader is missing. Build it before running this cell."


test_m = run_epoch(test_loader, model, fas_loss, train=False)


overall   = test_m.get("overall")
male_acc  = test_m.get("male")
female_acc= test_m.get("female")
balanced  = test_m.get("balanced")
worst     = test_m.get("worst_group")
gap       = test_m.get("gap")

# If any of male/female/balanced/worst/gap is missing, recompute from raw preds
need_recompute = any(v is None for v in [male_acc, female_acc, balanced, worst, gap])

if need_recompute:
    ys_all, yhat_all, g_all = [], [], []
    model.eval()
    with torch.no_grad():
        for imgs, ys, gs, idxs in test_loader:
            preds = model(imgs.to(DEVICE)).argmax(1).cpu().numpy()
            ys_all.extend(ys.numpy().tolist())
            yhat_all.extend(preds.tolist())
            g_all.extend(gs.numpy().tolist())

    ys_all   = np.array(ys_all)
    yhat_all = np.array(yhat_all)
    g_all    = np.array(g_all)

    if overall is None:
        overall = float((ys_all == yhat_all).mean())

    # class-wise (0 = male, 1 = female)
    m_mask = ys_all == 0
    f_mask = ys_all == 1
    male_acc   = float((yhat_all[m_mask] == ys_all[m_mask]).mean()) if m_mask.any() else float("nan")
    female_acc = float((yhat_all[f_mask] == ys_all[f_mask]).mean()) if f_mask.any() else float("nan")

    if balanced is None:
        balanced = float(np.nanmean([male_acc, female_acc]))
    if worst is None:
        worst = float(np.nanmin([male_acc, female_acc]))
    if gap is None:
        gap = float(abs(male_acc - female_acc))


print("\n==== External TEST ====")
print(f"overall:   {overall:.4f}")
print(f"male (0):  {male_acc:.4f}")
print(f"female(1): {female_acc:.4f}")
print(f"balanced:  {balanced:.4f}")
print(f"worst_grp: {worst:.4f}")
print(f"gap:       {gap:.4f}")

Val: 100%|████████████████████████████████████████████████████████████████| 16/16 [00:42<00:00,  2.67s/it, loss=0.3310]


==== External TEST ====
overall:   0.8799
male (0):  0.8819
female(1): 0.8775
balanced:  0.8797
worst_grp: 0.8775
gap:       0.0045



