### Download & extract PACS dataset (optional)

In [1]:
# Install deps (safe to re-run)
!pip -q install datasets pillow tqdm

import os, io
from PIL import Image
from tqdm import tqdm
from datasets import load_dataset

root = "/content/data/PACS"
os.makedirs(root, exist_ok=True)

ds = load_dataset("flwrlabs/pacs", split="train")

if "category" in ds.features:
    label_field = "category"
elif "class" in ds.features:
    label_field = "class"
elif "label" in ds.features:
    label_field = "label"
else:
    raise RuntimeError(f"Unexpected schema: {ds.features}")

def class_name(row):
    v = row[label_field]
    if isinstance(v, int):
        return ds.features[label_field].names[v]
    return str(v)

# Normalize domains → art_painting/cartoon/photo/sketch
def norm_domain(v: str):
    s = str(v).strip().lower().replace(" ", "_").replace("-", "_")
    if s in {"art_painting", "cartoon", "photo", "sketch"}:
        return s
    # map common variants just in case
    if s in {"artpainting", "art_paintings"}:
        return "art_painting"
    return s  # fallback (we'll skip unknowns below)

# Write images to /content/data/PACS/<domain>/<class>/<i>.jpg
for i, row in tqdm(enumerate(ds), total=len(ds)):
    dom = norm_domain(row["domain"])
    if dom not in {"art_painting", "cartoon", "photo", "sketch"}:
        continue  # skip anything weird
    cls = class_name(row)
    out_dir = os.path.join(root, dom, cls)
    os.makedirs(out_dir, exist_ok=True)

    img = row["image"]
    if not isinstance(img, Image.Image):
        # Some datasets provide bytes; convert to PIL.Image
        img = Image.open(io.BytesIO(img["bytes"])).convert("RGB")
    img.save(os.path.join(out_dir, f"{i}.jpg"), quality=95)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/191M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/9991 [00:00<?, ? examples/s]

100%|██████████| 9991/9991 [00:14<00:00, 681.31it/s]


### Imports, constants, and seed

In [2]:
import torch, random, os, numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, ConcatDataset
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from tqdm import tqdm
import numpy as np
import pandas as pd
import os

DATA_ROOT = "/content/data/PACS"
SOURCES   = ["art_painting", "cartoon", "photo"]
TARGET    = "sketch"
IMG_SIZE  = 224
BATCH_SIZE = 64
NUM_WORKERS = 2

# seed set for reproducibility
def set_seed(seed=1337):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(1337)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


Using device: cuda


In [3]:
_to_rgb = transforms.Lambda(lambda im: im.convert("RGB"))

def make_loaders(data_root, img_size=224, batch_size=64, num_workers=2, sources=None, target=None):
    tfm_train = transforms.Compose([
        _to_rgb,
        transforms.RandomResizedCrop(img_size, scale=(0.7, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(0.1, 0.1, 0.1, 0.05),
        transforms.ToTensor(),
        transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
    ])
    tfm_eval = transforms.Compose([
        _to_rgb,
        transforms.Resize(256),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
    ])

    def load_domain(name, tfm):
        p = Path(data_root)/name
        assert p.exists(), f"Missing domain folder: {p}"
        return datasets.ImageFolder(str(p), transform=tfm)

    # Build datasets
    src_train, per_domain_eval = [], {}
    class_to_idx = None

    for d in sources:
        ds_tr = load_domain(d, tfm_train)
        ds_ev = load_domain(d, tfm_eval)
        if class_to_idx is None:
            class_to_idx = ds_tr.class_to_idx
        else:
            assert ds_tr.class_to_idx == class_to_idx, "Class mapping differs across domains."
        assert ds_ev.class_to_idx == class_to_idx
        src_train.append(ds_tr)
        per_domain_eval[d] = ds_ev

    target_eval = load_domain(target, tfm_eval)
    assert target_eval.class_to_idx == class_to_idx
    per_domain_eval[target] = target_eval

    # Loaders
    train_ds = ConcatDataset(src_train)
    train_loader = DataLoader(
        train_ds, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True,
        persistent_workers=(num_workers > 0), drop_last=True
    )
    eval_loaders = {
        d: DataLoader(ds, batch_size=batch_size, shuffle=False,
                      num_workers=num_workers, pin_memory=True,
                      persistent_workers=(num_workers > 0))
        for d, ds in per_domain_eval.items()
    }

    num_classes = len(target_eval.classes)
    return train_loader, eval_loaders, num_classes, target_eval.classes

train_loader, eval_loaders, num_classes, classes = make_loaders(
    DATA_ROOT, IMG_SIZE, BATCH_SIZE, NUM_WORKERS, SOURCES, TARGET
)
print(f"Train size: {len(train_loader.dataset)} | Num classes: {num_classes}")
print("Domains loaded:", list(eval_loaders.keys()))
print("Classes:", classes)

Train size: 6062 | Num classes: 7
Domains loaded: ['art_painting', 'cartoon', 'photo', 'sketch']
Classes: ['dog', 'elephant', 'giraffe', 'guitar', 'horse', 'house', 'person']


In [4]:
def build_resnet50(num_classes: int):
    # pretrained ResNet-50, replace final layer
    m = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
    m.fc = nn.Linear(m.fc.in_features, num_classes)
    return m

@torch.no_grad()
def evaluate(model, loaders, device):
    model.eval()
    acc = {}
    for dname, loader in loaders.items():
        correct, total = 0, 0
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            pred = logits.argmax(1)
            correct += (pred == y).sum().item()
            total   += y.numel()
        acc[dname] = correct / max(total, 1)
    return acc


def train_erm(
    sources,
    target,
    train_loader,
    eval_loaders,
    num_classes,
    epochs=20,
    lr=3e-4,
    wd=0.05,
    out_dir="outputs_erm",
    seed=1337,
    use_amp=True,
):
    np.random.seed(seed); torch.manual_seed(seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = build_resnet50(num_classes).to(device)

    opt   = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    sched = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    loss_fn = nn.CrossEntropyLoss()
    scaler = torch.cuda.amp.GradScaler(enabled=(use_amp and device.type=="cuda"))

    best_target = 0.0
    logs = []
    os.makedirs(out_dir, exist_ok=True)

    for ep in range(1, epochs+1):
        model.train()
        running_loss = 0.0
        seen = 0

        pbar = tqdm(train_loader, desc=f"Epoch {ep}/{epochs}", leave=False)
        for x, y in pbar:
            x, y = x.to(device), y.to(device)
            opt.zero_grad(set_to_none=True)

            with torch.cuda.amp.autocast(enabled=(use_amp and device.type=="cuda")):
                logits = model(x)
                loss = loss_fn(logits, y)

            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()

            running_loss += loss.item() * y.size(0)
            seen += y.size(0)

        sched.step()

        acc = evaluate(model, eval_loaders, device)
        src_accs = [acc[d] for d in sources]
        avg_src = float(np.mean(src_accs))
        worst_src = float(min(src_accs))

        # Log
        row = {
            "epoch": ep,
            "train_loss": running_loss / max(seen, 1),
            "target_acc": acc[target],
            "avg_source_acc": avg_src,
            "worst_source_acc": worst_src,
        }
        for d in sources + [target]:
            row[f"acc_{d}"] = acc[d]
        logs.append(row)

        print(f"[Ep {ep:02d}] loss={row['train_loss']:.4f} | "
              f"tgt({target})={acc[target]:.3f} | src_avg={avg_src:.3f} | worst_src={worst_src:.3f}")

        # Save best-by-target
        if acc[target] > best_target:
            best_target = acc[target]
            torch.save(model.state_dict(), os.path.join(out_dir, "best_model.pt"))

    df = pd.DataFrame(logs)
    df.to_csv(os.path.join(out_dir, "training_log.csv"), index=False)
    print(f"\nBest target ({target}) accuracy: {best_target:.3f}")
    return df


In [5]:
# Uses the loaders you already created:
# train_loader, eval_loaders, num_classes, classes = make_loaders(...)

df_logs = train_erm(
    sources=SOURCES,
    target=TARGET,
    train_loader=train_loader,
    eval_loaders=eval_loaders,
    num_classes=num_classes,
    epochs=20,
    lr=3e-4,
    wd=0.05,
    out_dir="outputs_erm",
    seed=1337,
    use_amp=True,   # set False if you hit AMP issues
)

# Peek last few rows
df_logs.tail()


Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth


100%|██████████| 97.8M/97.8M [00:00<00:00, 156MB/s]
  scaler = torch.cuda.amp.GradScaler(enabled=(use_amp and device.type=="cuda"))
  with torch.cuda.amp.autocast(enabled=(use_amp and device.type=="cuda")):


[Ep 01] loss=0.4052 | tgt(sketch)=0.711 | src_avg=0.967 | worst_src=0.958




[Ep 02] loss=0.1278 | tgt(sketch)=0.582 | src_avg=0.976 | worst_src=0.965




[Ep 03] loss=0.0820 | tgt(sketch)=0.663 | src_avg=0.985 | worst_src=0.980




[Ep 04] loss=0.0750 | tgt(sketch)=0.601 | src_avg=0.990 | worst_src=0.985




[Ep 05] loss=0.0539 | tgt(sketch)=0.549 | src_avg=0.984 | worst_src=0.967




[Ep 06] loss=0.0448 | tgt(sketch)=0.701 | src_avg=0.996 | worst_src=0.992




[Ep 07] loss=0.0387 | tgt(sketch)=0.734 | src_avg=0.995 | worst_src=0.987




[Ep 08] loss=0.0231 | tgt(sketch)=0.644 | src_avg=0.994 | worst_src=0.991




[Ep 09] loss=0.0165 | tgt(sketch)=0.697 | src_avg=0.998 | worst_src=0.998




[Ep 10] loss=0.0224 | tgt(sketch)=0.694 | src_avg=0.997 | worst_src=0.996




[Ep 11] loss=0.0146 | tgt(sketch)=0.724 | src_avg=1.000 | worst_src=0.999




[Ep 12] loss=0.0089 | tgt(sketch)=0.707 | src_avg=0.999 | worst_src=0.998




[Ep 13] loss=0.0076 | tgt(sketch)=0.729 | src_avg=1.000 | worst_src=1.000




[Ep 14] loss=0.0064 | tgt(sketch)=0.700 | src_avg=1.000 | worst_src=1.000




[Ep 15] loss=0.0026 | tgt(sketch)=0.689 | src_avg=1.000 | worst_src=1.000




[Ep 16] loss=0.0032 | tgt(sketch)=0.729 | src_avg=1.000 | worst_src=1.000




[Ep 17] loss=0.0044 | tgt(sketch)=0.728 | src_avg=1.000 | worst_src=1.000




[Ep 18] loss=0.0039 | tgt(sketch)=0.736 | src_avg=1.000 | worst_src=1.000




[Ep 19] loss=0.0029 | tgt(sketch)=0.728 | src_avg=1.000 | worst_src=1.000




[Ep 20] loss=0.0041 | tgt(sketch)=0.729 | src_avg=1.000 | worst_src=1.000

Best target (sketch) accuracy: 0.736


Unnamed: 0,epoch,train_loss,target_acc,avg_source_acc,worst_source_acc,acc_art_painting,acc_cartoon,acc_photo,acc_sketch
15,16,0.003183,0.729448,1.0,1.0,1.0,1.0,1.0,0.729448
16,17,0.00443,0.72843,1.0,1.0,1.0,1.0,1.0,0.72843
17,18,0.003913,0.735811,1.0,1.0,1.0,1.0,1.0,0.735811
18,19,0.002883,0.728175,1.0,1.0,1.0,1.0,1.0,0.728175
19,20,0.00406,0.729193,1.0,1.0,1.0,1.0,1.0,0.729193


In [6]:
import pandas as pd

log_path = "outputs_erm/training_log.csv"
df = pd.read_csv(log_path)

best_idx = df["target_acc"].idxmax()
best = df.iloc[best_idx]

summary = {
    "best_epoch": int(best["epoch"]),
    "target_sketch_acc": round(float(best["target_acc"]), 4),
    "avg_source_acc": round(float(best["avg_source_acc"]), 4),
    "worst_source_acc": round(float(best["worst_source_acc"]), 4),
}
for d in SOURCES + [TARGET]:
    summary[f"{d}_acc"] = round(float(best[f"acc_{d}"]), 4)

summary


{'best_epoch': 18,
 'target_sketch_acc': 0.7358,
 'avg_source_acc': 1.0,
 'worst_source_acc': 1.0,
 'art_painting_acc': 1.0,
 'cartoon_acc': 1.0,
 'photo_acc': 1.0,
 'sketch_acc': 0.7358}

In [7]:
from torch.utils.data import DataLoader

def make_per_domain_train_loaders(data_root, img_size=224, batch_size=32, num_workers=2, sources=None):
    # same train transform as before
    tfm_train = transforms.Compose([
        _to_rgb,
        transforms.RandomResizedCrop(img_size, scale=(0.7, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(0.1,0.1,0.1,0.05),
        transforms.ToTensor(),
        transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
    ])
    loaders = {}
    class_to_idx = None
    for d in sources:
        ds = datasets.ImageFolder(str(Path(data_root)/d), transform=tfm_train)
        if class_to_idx is None: class_to_idx = ds.class_to_idx
        else: assert ds.class_to_idx == class_to_idx
        loaders[d] = DataLoader(ds, batch_size=batch_size, shuffle=True,
                                num_workers=num_workers, pin_memory=True,
                                persistent_workers=(num_workers>0), drop_last=True)
    return loaders


In [8]:
import itertools, math

def irm_penalty(loss, dummy_w):
    # gradient of loss wrt dummy scalar w, squared
    g = torch.autograd.grad(loss, [dummy_w], create_graph=True)[0]
    return torch.sum(g**2)

def train_irm(
    sources,
    target,
    per_domain_train_loaders,  # dict: domain -> loader
    eval_loaders,              # from your existing make_loaders (eval tfm)
    num_classes,
    epochs=20,
    lr=3e-4,
    wd=0.0,             # IRM often with little/no wd
    out_dir="outputs_irm",
    seed=1337,
    use_amp=True,
    lambda_start=100.0,  # starting IRM penalty weight
    warmup_epochs=1,     # keep lambda small then ramp
    lambda_multiplier=10.0
):
    torch.manual_seed(seed); np.random.seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = build_resnet50(num_classes).to(device)

    opt   = optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    sched = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    ce    = nn.CrossEntropyLoss()
    scaler = torch.amp.GradScaler('cuda', enabled=(use_amp and device.type=="cuda"))

    # iterators for each domain
    domain_iters = {d: itertools.cycle(loader) for d, loader in per_domain_train_loaders.items()}

    best_target = 0.0
    logs = []
    os.makedirs(out_dir, exist_ok=True)

    lam = lambda_start

    for ep in range(1, epochs+1):
        if ep == warmup_epochs+1:
            lam = lambda_start * lambda_multiplier  # ramp the penalty

        model.train()
        n_steps = min(len(loader) for loader in per_domain_train_loaders.values())
        running_loss = 0.0
        seen = 0

        for _ in tqdm(range(n_steps), desc=f"[IRM] Epoch {ep}/{epochs}", leave=False):
            opt.zero_grad(set_to_none=True)

            # gather a mini-batch from EACH source domain
            losses, penalties = [], []
            for d in sources:
                x, y = next(domain_iters[d])
                x, y = x.to(device), y.to(device)
                dummy_w = torch.tensor(1.0, requires_grad=True, device=device)

                with torch.amp.autocast('cuda', enabled=(use_amp and device.type=="cuda")):
                    logits = model(x) * dummy_w
                    loss_d = ce(logits, y)

                penalty_d = irm_penalty(loss_d, dummy_w)
                losses.append(loss_d)
                penalties.append(penalty_d)

            loss = torch.stack(losses).mean()
            penalty = torch.stack(penalties).mean()
            total = loss + lam * penalty

            scaler.scale(total).backward()
            scaler.step(opt)
            scaler.update()

            running_loss += total.item() * sum(y.size(0) for _, y in [next(iter(per_domain_train_loaders[d])) for d in sources])
            seen += sum(next(iter(per_domain_train_loaders[d]))[1].size(0) for d in sources)

        sched.step()

        # eval
        acc = evaluate(model, eval_loaders, device)
        src_accs = [acc[d] for d in sources]
        avg_src, worst_src = float(np.mean(src_accs)), float(min(src_accs))

        row = {
            "epoch": ep,
            "lambda": lam,
            "target_acc": acc[target],
            "avg_source_acc": avg_src,
            "worst_source_acc": worst_src,
        }
        for d in sources + [target]:
            row[f"acc_{d}"] = acc[d]
        logs.append(row)

        print(f"[IRM Ep {ep:02d}] tgt={acc[target]:.3f} | src_avg={avg_src:.3f} | worst_src={worst_src:.3f} | λ={lam:.1f}")

        if acc[target] > best_target:
            best_target = acc[target]
            torch.save(model.state_dict(), os.path.join(out_dir, "best_model.pt"))

    df = pd.DataFrame(logs)
    df.to_csv(os.path.join(out_dir, "training_log.csv"), index=False)
    print(f"\n✅ IRM best target ({target}) accuracy: {best_target:.3f}")
    return df


In [None]:
# Build per-domain train loaders
per_domain_train = make_per_domain_train_loaders(
    DATA_ROOT, img_size=IMG_SIZE, batch_size=32, num_workers=NUM_WORKERS, sources=SOURCES
)

# Reuse your existing eval_loaders (from make_loaders)
df_irm = train_irm(
    sources=SOURCES,
    target=TARGET,
    per_domain_train_loaders=per_domain_train,
    eval_loaders=eval_loaders,
    num_classes=num_classes,
    epochs=20,
    lr=3e-4,
    wd=0.0,
    out_dir="outputs_irm",
    lambda_start=100.0,
    warmup_epochs=1,
    lambda_multiplier=10.0,
)

# Summarize
best_idx = df_irm["target_acc"].idxmax()
best = df_irm.iloc[best_idx]
{
    "best_epoch": int(best["epoch"]),
    "target_sketch_acc": round(float(best["target_acc"]), 4),
    "avg_source_acc": round(float(best["avg_source_acc"]), 4),
    "worst_source_acc": round(float(best["worst_source_acc"]), 4),
    "lambda_used": float(best["lambda"]),
    **{f"{d}_acc": round(float(best[f"acc_{d}"]), 4) for d in SOURCES+[TARGET]}
}




[IRM Ep 01] tgt=0.201 | src_avg=0.197 | worst_src=0.190 | λ=100.0


[IRM] Epoch 2/20:  50%|█████     | 26/52 [03:14<03:21,  7.76s/it]