In [None]:
%reload_ext autoreload
%autoreload 2

In [None]:
%load_ext tensorboard
%tensorboard --logdir /content/

In [None]:
!pip install -q timm pytorch_lightning

In [None]:
import os
import random
import shutil
import zipfile

import gdown
from PIL import Image
import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms
from timm import create_model
import pytorch_lightning as pl
from torchmetrics import Metric, MetricCollection
from torchmetrics.classification import BinaryAccuracy


pl.seed_everything(42)


INFO:lightning_fabric.utilities.seed:Seed set to 42


42

In [None]:
# Start from an empty log directory so runs from earlier sessions do not mix with this one.
# Pure-Python removal (instead of the `!rm -r` shell escape) also works outside IPython and
# on systems without `rm`, and raises instead of silently ignoring a failed delete.
if os.path.isdir("logs/"):
    shutil.rmtree("logs/")

In [None]:
# Integer dtype used for the label / spurious-label / group tensors built by CelebADataset.
DTYPE = torch.int


def transform_data(crop_size: int = 178, image_size: int = 224):
    """Build the image preprocessing pipeline used for every split.

    Steps: square center crop of side ``crop_size`` -> resize to ``image_size`` (the crop is
    square, so the output is ``image_size`` x ``image_size``) -> tensor -> per-channel
    mean/std normalisation (the values commonly used with ImageNet-pretrained backbones).

    Args:
        crop_size: side length of the center crop taken from the raw image (default 178).
        image_size: side length of the final image fed to the model (default 224).

    Returns:
        A ``torchvision.transforms.Compose`` callable mapping a PIL image to a tensor.
    """
    return transforms.Compose(
        [
            transforms.CenterCrop(crop_size),
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )

class CelebADataset(Dataset):
    """One CelebA split with ``Blond_Hair`` as the label and another attribute as the spurious label.

    Reads ``list_attr_celeba.csv`` and ``list_eval_partition.csv`` (whitespace-delimited)
    from ``root_dir`` and keeps the rows whose partition id equals ``split``. It maps
    attribute value ``-1`` to 0 and any other value to 1. Each sample gets the group id
    ``2 * label + spurious_label`` (0..3).

    Attributes:
        metadata: DataFrame indexed by image filename with columns
            ``label``, ``spurious_label``, ``group``.
        groups / labels: per-sample tensors of dtype ``DTYPE``.
        group_counts: number of samples in each of the 4 groups.
        label_counts: number of samples for each of the 2 labels.
    """

    def __init__(
        self, root_dir: str, spurious_label: str, split: int, transforms=None
    ) -> None:
        super().__init__()

        self.root_dir = root_dir
        self.split = split
        self.spurious_label = spurious_label
        self.transforms = transform_data() if transforms is None else transforms

        # sep=r"\s+" replaces delim_whitespace=True (deprecated in pandas 2.2, removed in 3.0).
        attributes = pd.read_csv(
            os.path.join(root_dir, "list_attr_celeba.csv"),
            sep=r"\s+",
            header=1,
            index_col=0,
        )
        partitions = pd.read_csv(
            os.path.join(root_dir, "list_eval_partition.csv"),
            sep=r"\s+",
            header=None,
            index_col=0,
        )
        indices = partitions[partitions.iloc[:, 0] == split].index

        # Keep only the two needed columns, then map -1 -> 0 / anything else -> 1 in one
        # vectorised step (instead of an element-wise apply over all attributes).
        binary = (
            attributes.loc[indices, ["Blond_Hair", spurious_label]]
            .ne(-1)
            .astype("int64")
        )
        binary.columns = ["label", "spurious_label"]
        self.metadata = binary.assign(
            group=2 * binary["label"] + binary["spurious_label"]
        )

        self.groups = torch.as_tensor(self.metadata["group"].values, dtype=DTYPE)
        self.group_counts = torch.bincount(self.groups.long(), minlength=4)
        self.labels = torch.as_tensor(self.metadata["label"].values, dtype=DTYPE)
        self.label_counts = torch.bincount(self.labels.long(), minlength=2)

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, index):
        """Load, convert to RGB and transform one image; return it with its metadata."""
        sample = self.metadata.iloc[index]

        filename = os.path.join(self.root_dir, "img_align_celeba/", sample.name)
        # Close the file handle once the RGB copy is made instead of leaking it until GC.
        with Image.open(filename) as raw_image:
            image = raw_image.convert("RGB")
        image = self.transforms(image)

        return {
            "filename": sample.name,
            "image": image,
            "label": torch.as_tensor(sample["label"], dtype=DTYPE),
            "spurious_label": torch.as_tensor(sample["spurious_label"], dtype=DTYPE),
            "group": torch.as_tensor(sample["group"], dtype=DTYPE),
        }


In [None]:
def stratified_sampler(dataset: Dataset):
    """Return a label-balanced ``WeightedRandomSampler`` over ``dataset``.

    Each sample is weighted by ``len(dataset) / count(of its label)``, so every label is
    drawn at roughly the same rate. ``len(dataset)`` indices are drawn with replacement.
    ``dataset`` must expose ``labels`` (one label per sample) and ``label_counts``
    (indexed by label value).
    """
    n = len(dataset)
    inverse_frequency = n / dataset.label_counts
    per_sample_weights = inverse_frequency[dataset.labels.long()]
    return WeightedRandomSampler(
        weights=per_sample_weights, num_samples=n, replacement=True
    )


In [None]:
class CelebADataModule(pl.LightningDataModule):
    """Lightning data module that fetches CelebA (via ``gdown``) and serves train/val loaders.

    Args:
        root_dir: directory holding (or receiving) the CSV files and the image folder.
        spurious_label: CelebA attribute used as the spurious label by ``CelebADataset``.
        stratified_sampling: if true, the train loader uses ``stratified_sampler``;
            otherwise it shuffles.
        transforms: image transform applied to both splits; ``None`` uses ``transform_data()``.
        batch_size, num_workers: forwarded to every ``DataLoader``.
    """

    def __init__(self, root_dir: str, spurious_label: str, stratified_sampling, transforms, batch_size, num_workers):
        super().__init__()

        self.root_dir = root_dir
        self.spurious_label = spurious_label
        # Store an actual transform pipeline (calling the factory), not the factory function.
        self.transforms = transform_data() if transforms is None else transforms
        self.stratified_sampling = stratified_sampling

        self.batch_size = batch_size
        self.num_workers = num_workers

    def prepare_data(self):
        """Download the CSVs and image archive from Google Drive if missing, then unzip images."""
        files = {
            "list_eval_partition.csv": "1kDqtHZHpYMe7rt1zu9pOevzUbApkDNRa",
            "list_attr_celeba.csv": "1s8CyrddcxHdvwro-_M25H7uxsDWL_1Bs",
            "img_align_celeba.zip": "1mGM-w9373aW5UJ27xa5oAsesL06JOe3h",
        }
        os.makedirs(self.root_dir, exist_ok=True)
        for file, file_id in files.items():
            path = os.path.join(self.root_dir, file)
            if os.path.exists(path):
                print(f"{file} already exists, skipping download.")
            else:
                print(f"Downloading {file}...")
                gdown.download(id=file_id, output=path, quiet=True)
                # Fail here with a clear message instead of a confusing error in setup().
                if not os.path.exists(path):
                    raise RuntimeError(f"Failed to download {file} (Google Drive id {file_id}).")

            extracted_dir = os.path.join(self.root_dir, file.replace(".zip", ""))
            if file.endswith(".zip") and not os.path.exists(extracted_dir):
                print(f"Unzipping {file}...")
                with zipfile.ZipFile(path, "r") as zip_ref:
                    zip_ref.extractall(self.root_dir)

    def setup(self, stage: str = None):
        """Build the train (partition 0) and validation (partition 1) datasets."""
        self.train_dataset = CelebADataset(
            self.root_dir, self.spurious_label, 0, transforms=self.transforms
        )
        self.val_dataset = CelebADataset(
            self.root_dir, self.spurious_label, 1, transforms=self.transforms
        )

    def _make_dataloader(self, dataset, shuffle: bool, sampler=None):
        """Shared DataLoader construction for every split."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            # Pinned memory only helps host->GPU copies; avoid the warning on CPU-only runs.
            pin_memory=torch.cuda.is_available(),
            shuffle=shuffle,
            sampler=sampler,
        )

    def train_dataloader(self):
        if self.stratified_sampling:
            # A sampler and shuffle=True are mutually exclusive in DataLoader.
            return self._make_dataloader(
                self.train_dataset, shuffle=False, sampler=stratified_sampler(self.train_dataset)
            )
        return self._make_dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._make_dataloader(self.val_dataset, shuffle=False)

In [None]:
class WorstGroupAccuracy(Metric):
    """Per-group accuracy and its minimum (worst-group accuracy) over integer group ids.

    Groups are the integers ``0 .. num_groups - 1``; samples whose group id is outside that
    range are ignored. ``compute()`` returns ``(worst_group_accuracy, per_group_accuracy)``.
    A group with no samples reports 0 in ``per_group_accuracy``. It is excluded from the
    minimum, so an unobserved group (e.g. in a small batch or a sanity-check step) does
    not force the worst-group accuracy to 0.
    """

    # Both states are plain sums, so forward() may use the cheaper reduce-state path.
    full_state_update: bool = False
    higher_is_better: bool = True

    def __init__(self, num_groups: int = 4, **kwargs):
        super().__init__(**kwargs)
        self.num_groups = num_groups
        self.add_state("correct", torch.zeros(num_groups), dist_reduce_fx="sum")
        self.add_state("total", torch.zeros(num_groups), dist_reduce_fx="sum")

    def update(self, y_pred, y_true, g):
        """Accumulate per-group correct/total counts for one batch.

        ``y_pred`` and ``y_true`` are flattened before comparison, so ``(B,)`` and ``(B, 1)``
        inputs no longer broadcast to a ``(B, B)`` matrix. ``g`` holds one group id per sample.
        """
        g = g.reshape(-1)
        is_correct = y_pred.reshape(-1) == y_true.reshape(-1)
        # (num_groups, B) membership matrix; out-of-range group ids match no row.
        membership = torch.arange(self.num_groups, device=g.device).unsqueeze(1) == g.unsqueeze(0)
        self.total += membership.sum(dim=1)
        self.correct += (membership & is_correct.unsqueeze(0)).sum(dim=1)

    def compute(self):
        """Return ``(min accuracy over observed groups, per-group accuracies)``."""
        observed = self.total > 0
        per_group = torch.where(
            observed,
            self.correct / self.total.clamp(min=1),
            torch.zeros_like(self.total),
        )
        if observed.any():
            wg_acc = per_group[observed].min()
        else:
            wg_acc = per_group.new_zeros(())
        return wg_acc, per_group

In [None]:
class ResNet(pl.LightningModule):
    """LightningModule wrapping a ``timm`` backbone for the Blond_Hair classification task.

    ``hparams["model"]`` is forwarded to ``timm.create_model``. ``hparams["optimizer"]``
    holds an optional ``"optimizer_name"`` (a ``torch.optim`` class name, default
    ``"AdamW"``) plus that optimizer's keyword arguments. The module tracks average and
    worst-group accuracy: per step for training and per epoch for train and validation.
    """

    def __init__(
        self,
        hparams,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.save_hyperparameters()
        self.model = create_model(**hparams["model"])

        # Private copy: configure_optimizers() must never mutate the caller's config dict
        # (popping "optimizer_name" from the shared dict broke re-creating the model).
        self.optimizer_config = dict(hparams["optimizer"])

        self.loss_fn = nn.BCEWithLogitsLoss()

        self.train_accuracy = BinaryAccuracy()
        self.valid_accuracy = BinaryAccuracy()
        self.train_wga = WorstGroupAccuracy()
        self.valid_wga = WorstGroupAccuracy()

    def forward(self, batch, validation: bool = False):
        """Shared train/validation step: predict, compute the loss, update metrics and log.

        Args:
            batch: dict with at least ``"image"``, ``"label"`` and ``"group"`` tensors; any
                other keys (e.g. ``"filename"``) are ignored and the dict is not modified.
            validation: use the ``"val"`` metrics and log prefix instead of ``"train"``.

        Returns:
            The scalar loss for the batch.
        """
        stage = "val" if validation else "train"
        x = batch["image"]
        g = batch["group"]
        labels = batch["label"].reshape(-1).long()

        logits = self.model(x)
        if logits.size(1) < 2:
            # Single logit per sample: BCE on the logit, threshold the sigmoid at 0.5.
            y_pred = (torch.sigmoid(logits) >= 0.5).squeeze(1).long()
            loss = self.loss_fn(logits, labels.float().unsqueeze(1))
        else:
            # One logit per class: BCE against a (B, 1) target is a shape error, so use CE.
            y_pred = logits.argmax(dim=1)
            loss = nn.functional.cross_entropy(logits, labels)

        self.log(f"{stage}/loss", loss, prog_bar=validation, logger=True, sync_dist=True)

        if validation:
            self.valid_accuracy.update(y_pred, labels)
            self.valid_wga.update(y_pred, labels, g)
            return loss

        # Calling a metric both updates its epoch state and returns this batch's value.
        # (Metric.compute() takes no arguments; compute(y_pred, y_true) raised TypeError.)
        step_acc = self.train_accuracy(y_pred, labels)
        self.log("train/step:avg_acc", step_acc, on_step=True, on_epoch=False, logger=True)

        step_wga, step_group_acc = self.train_wga(y_pred, labels, g)
        self.log("train/step:wga", step_wga, on_step=True, logger=True)
        for k, v in enumerate(step_group_acc):
            self.log(f"train/step:acc_grp{k}", v, on_step=True, logger=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self(batch)

    def on_train_epoch_end(self):
        # Logged under "train/..." (this previously overwrote "val/epoch:avg_acc").
        self._log_epoch_metrics("train", self.train_accuracy, self.train_wga)

    def validation_step(self, batch, batch_idx):
        return self(batch, validation=True)

    def on_validation_epoch_end(self):
        self._log_epoch_metrics("val", self.valid_accuracy, self.valid_wga)

    def _log_epoch_metrics(self, stage, accuracy_metric, wga_metric):
        """Log epoch-level average, worst-group and per-group accuracy for ``stage``, then reset."""
        self.log(
            f"{stage}/epoch:avg_acc",
            accuracy_metric.compute(),
            prog_bar=True,
            on_step=False,
            on_epoch=True,
            logger=True,
        )

        wga, acc_groups = wga_metric.compute()
        self.log(
            f"{stage}/epoch:wga",
            wga,
            prog_bar=True,
            on_step=False,
            on_epoch=True,
            logger=True,
        )
        for k, v in enumerate(acc_groups):
            self.log(
                f"{stage}/epoch:acc_grp{k}",
                v,
                on_step=False,
                on_epoch=True,
                logger=True,
            )

        accuracy_metric.reset()
        wga_metric.reset()

    @property
    def num_training_steps(self) -> int:
        """Total number of optimizer steps for this fit, as estimated by the Trainer.

        ``estimated_stepping_batches`` already accounts for ``max_steps``, gradient
        accumulation and per-device dataloader length (the previous manual formula divided
        the per-device batch count by the device count a second time).
        """
        steps = self.trainer.estimated_stepping_batches
        if steps == float("inf"):
            raise ValueError("Training length is unbounded; set max_epochs or max_steps.")
        return int(steps)

    def configure_optimizers(self):
        """Build the optimizer named in the config plus a per-step LambdaLR schedule."""
        config = dict(self.optimizer_config)
        optimizer_name = config.pop("optimizer_name", "AdamW")
        optimizer_cls = getattr(torch.optim, optimizer_name)
        optimizer = optimizer_cls(self.model.parameters(), **config)

        # Per-step LR multipliers (currently constant). LambdaLR evaluates the lambda at
        # steps 0..N (one more call than optimizer steps), so clamp the index to the table.
        lr_factors = [1.0] * max(1, self.num_training_steps)
        last_index = len(lr_factors) - 1
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer, lambda step: lr_factors[min(step, last_index)]
        )
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]


In [None]:
hparams = {
    # Keyword arguments for CelebADataModule.
    "dataset": {
        "root_dir": "celeba/",
        "spurious_label": "Male",
        "stratified_sampling": True,
        "batch_size": 128,
        "num_workers": 2,
        "transforms": None,
    },
    # "optimizer_name" selects a torch.optim class; the other keys are its kwargs.
    "optimizer": {
        "optimizer_name": "SGD",
        "lr": 1e-5,
        "weight_decay": 1.0,
        "momentum": 0.9,
    },
    # Keyword arguments for timm.create_model.
    "model": {
        "model_name": "resnet18",
        "pretrained": True,
        "num_classes": 1,
    },
    "max_epochs": 10,
    "save_dir": "logs/",
}

# Numeric suffix naming this run's log/checkpoint directory. random.randint needs integer
# bounds: float bounds like 1e4 are deprecated since Python 3.10 and rejected from 3.12.
experiment_name = "celeba_" + str(random.randint(10_000, 100_000))


In [None]:
datamodule = CelebADataModule(**hparams["dataset"])

model = ResNet(hparams)

# Checkpoints and TensorBoard logs for this run share one directory under save_dir.
experiment_dir = os.path.join(hparams["save_dir"], experiment_name)

# ModelCheckpoint parses `{name:format}` placeholders in `filename`, splitting on the first
# ':'. Metric keys that themselves contain ':' (e.g. "val/epoch:wga") therefore cannot be
# templated; the previous template could not be formatted when saving a checkpoint.
ckpt_callback = pl.callbacks.ModelCheckpoint(
    dirpath=experiment_dir,
    filename="checkpoint-{epoch:03d}",
    monitor="val/epoch:wga",
    save_last=True,
    save_top_k=1,
    mode="max",
    auto_insert_metric_name=False,
)

progress_bar = pl.callbacks.progress.TQDMProgressBar(refresh_rate=1)
lr_callback = pl.callbacks.LearningRateMonitor(logging_interval="step")
logger = pl.loggers.TensorBoardLogger(save_dir=experiment_dir)

use_cuda = torch.cuda.is_available()
accelerator = "cuda" if use_cuda else "cpu"

trainer = pl.Trainer(
    accelerator=accelerator,
    devices="auto",
    benchmark=True,
    enable_progress_bar=True,
    log_every_n_steps=1,
    num_sanity_val_steps=1,
    check_val_every_n_epoch=1,
    max_epochs=hparams["max_epochs"],
    callbacks=[ckpt_callback, progress_bar, lr_callback],
    logger=logger,
)
trainer.fit(
    model=model,
    datamodule=datamodule,
    ckpt_path=None,
)
