In [5]:
from typing import Tuple

from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageFolder
from torch import nn
from torchvision.models import resnet50, ResNet50_Weights, ResNet
from torchvision.transforms import PILToTensor


def build_resnet50_pixel(pretrained: str = True) -> ResNet:
    """ResNet50 with custom classifier for testing normal DIRE"""
    weights = ResNet50_Weights.DEFAULT
    model = resnet50(weights=weights)
    for param in model.parameters():
        param.requires_grad = False
    model.fc = nn.Sequential(nn.Linear(2048, 128), nn.ReLU(inplace=True), nn.Linear(128, 2), nn.Softmax(dim=1))
    model.eval()
    return model


def preprocess_resnet50_pixel(img):
    weights = ResNet50_Weights.DEFAULT
    img = PILToTensor()(img)
    batch = weights.transforms()(img)
    return batch

MODEL_DICT = {
    #"resnet50_latent": build_resnet50_latent,
    "resnet50_pixel": build_resnet50_pixel,
    #"mlp": build_mlp,
    #"cnn": build_cnn,
}

def get_dataloaders(root: str, batch_size: int, shuffle: bool = True) -> Tuple[DataLoader, DataLoader, DataLoader]:
    dataset = ImageFolder(root, transform=preprocess_resnet50_pixel)
    train_dataset, val_dataset, test_dataset = random_split(dataset, lengths=[0.8, 0.1, 0.1])
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=shuffle)

    return train_loader, val_loader, test_loader

In [83]:
import argparse
from typing import Tuple

import numpy as np
from torchmetrics.functional.classification import accuracy
from torchmetrics.functional import average_precision

import torch
import torch.nn as nn
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchvision.transforms.functional import hflip
import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor, TQDMProgressBar
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.loggers import WandbLogger

class Classifier(pl.LightningModule):
    def __init__(self, model: str, optimizer: str, learning_rate: float) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.classifier = MODEL_DICT[model]()
        self.loss = nn.CrossEntropyLoss()

    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> dict:
        dire, label = batch
        if np.random.rand() < 0.5:  # 50% chance for horizontal flip
            dire = hflip(dire)
        pred = self.classifier(dire)
        loss = self.loss(pred, label)
        acc = accuracy(pred.argmax(axis=1), label, task="binary")
        ap = average_precision(pred[:, 1], label, task="binary")
        metrics = {"val_loss": loss, "val_acc": acc, "val_ap": ap}
        self.log_dict(metrics)

        return loss

    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
        dire, label = batch
        pred = self.classifier(dire)
        loss = self.loss(pred, label)
        acc = accuracy(pred.argmax(axis=1), label, task="binary")
        ap = average_precision(pred[:, 1], label, task="binary")
        metrics = {"val_loss": loss, "val_acc": acc, "val_ap": ap}
        self.log_dict(metrics)

    def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
        dire, label = batch
        pred = self.classifier(dire)
        loss = self.loss(pred, label)
        acc = accuracy(pred.argmax(axis=1), label, task="binary")
        ap = average_precision(pred[:, 1], label, task="binary")
        metrics = {"val_loss": loss, "val_acc": acc, "val_ap": ap}
        self.log_dict(metrics)

    def configure_optimizers(self) -> torch.optim.Optimizer:
        optimizer = Adam if self.hparams.optimizer == "Adam" else SGD
        if self.hparams.model == "resnet50_pixel":
            optimizer = optimizer(self.classifier.fc.parameters(), lr=self.hparams.learning_rate)
        else:
            optimizer = optimizer(self.classifier.parameters(), lr=self.hparams.learning_rate)

        lr_scheduler = ReduceLROnPlateau(optimizer, mode="max", factor=0.1, patience=2)

        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler, "monitor": "val_acc"}


def main(args: argparse.Namespace) -> None:
    seed_everything(33914, workers=True)

    # Setup Weights & Biases
    wandb_logger = WandbLogger(project="Training", entity="latent-dire", config=vars(args))

    # Load the data
    train_loader, val_loader, test_loader = get_dataloaders(args.data_dir, args.batch_size, shuffle=True)

    # Setup callbacks
    early_stop = EarlyStopping(monitor="val_acc", mode="max", min_delta=0.0, patience=5, verbose=True)
    checkpoint = ModelCheckpoint(save_top_k=2, monitor="val_acc", mode="max", dirpath="models/")
    lr_monitor = LearningRateMonitor(logging_interval="epoch")
    bar = TQDMProgressBar()

    clf = Classifier(args.model, args.optimizer, args.learning_rate)
    trainer = Trainer(
        fast_dev_run=args.dev_run,  # uncomment to debug
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices="auto",  # use all available GPUs
        min_epochs=1,
        max_epochs=args.max_epochs,
        callbacks=[bar, early_stop, checkpoint, lr_monitor],
        # deterministic=True,  # slower, but reproducable: https://lightning.ai/docs/pytorch/stable/common/trainer.html#reproducibility
        precision="16-mixed",
        default_root_dir="models/",
        logger=wandb_logger,
    )
    trainer.fit(clf, train_loader, val_loader)
    trainer.test(clf, test_loader)


In [84]:
parser = argparse.ArgumentParser()
parser.add_argument("-d", "--dev_run", action="store_true", help="Whether to run a test run.")
parser.add_argument("--model", type=str, default="resnet50_pixel")
parser.add_argument("--latent", type=bool, default=False, help="Whether to use Latent DIRE")
parser.add_argument("--batch_size", type=int, default=16)
parser.add_argument("--max_epochs", type=int, default=100)
parser.add_argument("--use_early_stopping", type=int, default=1, help="Whether to use early stopping.")
parser.add_argument("--optimizer", type=str, default="Adam", choices=["Adam", "SGD"], help="Optimizer to use")
parser.add_argument("--learning_rate", type=float, default=0.001)
parser.add_argument("--data_dir", type=str, default="../data/data_dev")
args = parser.parse_known_args()[0]

In [85]:
main(args)

[rank: 0] Global seed set to 33914
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type             | Params
------------------------------------------------
0 | classifier | ResNet           | 23.8 M
1 | loss       | CrossEntropyLoss | 0     
------------------------------------------------
262 K     Trainable params
23.5 M    Non-trainable params
23.8 M    Total params
95.082    Total estimated model params size (MB)
SLURM auto-requeueing enabled. Setting signal handlers.


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_acc improved. New best score: 0.967


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Monitored metric val_acc did not improve in the last 5 records. Best score: 0.967. Signaling Trainer to stop.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.
  rank_zero_warn(
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]