## Installs and imports

In [None]:
# Colab setup: install the logging/training stack. Only pytorch-lightning is pinned;
# wandb and torchmetrics resolve to whatever is current at install time.
!pip install -qqq wandb pytorch-lightning==1.9.3 torchmetrics

In [None]:

import numpy as np
import random

import torch
from torch.nn import functional as F
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
from torchsummary import summary

from torchvision import transforms
from torchvision.datasets import ImageFolder

In [None]:
import pytorch_lightning as pl
import torchmetrics
# Fix the global seeds for reproducibility.
# NOTE(review): seed_everything presumably already seeds torch and numpy, which
# would make the two explicit calls below redundant — confirm before removing.
pl.seed_everything(42)
torch.manual_seed(42)
np.random.seed(42)

import wandb

from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning import Trainer

import os
from typing import Any, Dict, cast

import matplotlib.pyplot as plt

from torch import Tensor
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
from torchmetrics import Metric, MetricCollection
from torchmetrics.classification import (
    MulticlassAccuracy,
    MulticlassFBetaScore,
    AUROC,
    BinaryROC,
    Recall,
    Specificity,
)

from torchmetrics.classification.precision_recall_curve import BinaryPrecisionRecallCurve
from torchmetrics.utilities.data import dim_zero_cat

# Authenticate with Weights & Biases before any run or sweep is started.
wandb.login()

INFO:lightning_fabric.utilities.seed:Global seed set to 42
[34m[1mwandb[0m: Currently logged in as: [33mmyntiuk[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [None]:
# Model dependencies: timm (backbones) and torchgeo (pretrained weights); both unpinned.
! pip install -qqq timm torchgeo

In [None]:
from torchgeo.datasets import unbind_samples
from torchgeo.models import get_weight
from torchgeo.trainers import utils
import timm

In [None]:
# Colab only: mount Google Drive so the dataset archive can be copied from it.
from google.colab import drive
drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Copy the dataset archive out of Drive and unpack it. unzip extracts into the
# current working directory (presumably /content on Colab — later cells expect
# /content/minefree-class-128).
! cp /content/drive/MyDrive/minefree-class-128.zip /content/
! unzip -q /content/minefree-class-128.zip

In [None]:
# Prefix the class folders with 0/1 so that a name-sorted class listing puts
# "not-bombed" first and "bombed" second; the rest of the notebook treats label 1
# as the positive class (assumes ImageFolder assigns indices in sorted folder
# order — TODO confirm).
# NOTE: these moves are not idempotent; re-running the cell fails once the
# folders have already been renamed.
! mv /content/minefree-class-128/train/bombed /content/minefree-class-128/train/1bombed
! mv /content/minefree-class-128/val/bombed /content/minefree-class-128/val/1bombed
! mv /content/minefree-class-128/train/not-bombed /content/minefree-class-128/train/0not-bombed
! mv /content/minefree-class-128/val/not-bombed /content/minefree-class-128/val/0not-bombed

## Define the sweep

In [None]:
# W&B sweep definition: every key under "parameters" is read from `wandb.config`
# by the training function.
sweep_config = {
    "method": "random",  # grid, random
    "metric": {
        # Must name a metric that is actually logged. The LightningModule logs
        # its validation metrics with a "val_" prefix (e.g. "val_AUROC"), and
        # nothing named "accuracy" is ever logged.
        "name": "val_AUROC",
        "goal": "maximize",
    },
    "parameters": {
        "epochs": {
            "values": [100]
        },
        "batch_size": {
            "values": [32, 64]
        },
        "dropout": {
            "values": [0., 0.2]
        },
        "weight_decay": {
            "values": [0, 0.00005, 0.0005]
        },
        "learning_rate": {
            "values": [1e-4, 1e-5]
        },
        "lr_scheduler": {
            "values": ["on_plateau"]
        },
        "optimizer": {
            "values": [
                "adamw",
                "sgd"
            ]
        },
        "model_name": {
            "values": [
                "resnet50",
                "vit_small_patch16_224"
            ]
        },
        "weights": {
            "values": [
                "imagenet",
                "sentinel2"
            ]
        },
        # None means: fine-tune every layer.
        "num_layers_to_finetune": {
            "values": [None]
        },
        "learning_rate_schedule_patience": {
            "values": [5, 3]
        },
        "early_stop_patience": {
            "values": [10],
        },
        "normalize": {
            "values": [False, True]
        },
        "base_size": {
            "values": [64, 128]
        },
        # "up" enables minority-class upsampling in the train dataloader.
        "dro": {
            "values": ["up", None]
        },
    },
}

### DataModule

In [None]:
def compute_data_stats(data_path, transform, seed=0):
    """Return the per-channel mean and std of every image under ``data_path``.

    Args:
        data_path: root directory of an ``ImageFolder``-style dataset.
        transform: list of torchvision transforms; it is wrapped in a
            ``Compose`` and applied to each image before the statistics are
            taken (assumes the list ends in ``ToTensor`` so samples can be
            batched — the callers in this notebook do that).
        seed: accepted for interface compatibility; not used by the body.

    Returns:
        ``(mean, std)`` tensors reduced over axes 0, 2 and 3 of the image batch.
    """
    dataset = ImageFolder(root=data_path, transform=transforms.Compose(transform))

    # The statistics cover the whole folder (including whatever is later split
    # off for validation), which is a mild form of leakage. A single batch as
    # large as the dataset pulls every image into memory at once.
    whole_set_loader = DataLoader(dataset, batch_size=len(dataset), shuffle=False)
    images, _ = next(iter(whole_set_loader))

    # Reduce over everything except the channel axis (axis 1).
    non_channel_axes = [0, 2, 3]
    return images.mean(non_channel_axes), images.std(non_channel_axes)

In [None]:
class DataModule(pl.LightningDataModule):
    """Serves the ``train`` / ``val`` image folders under ``data_dir``.

    ``<data_dir>/train`` is split into a training and a validation subset
    (``val_size`` is the validation fraction); ``<data_dir>/val`` is used as
    the test set.

    Args:
        data_dir: dataset root containing ``train`` and ``val`` sub-folders.
        batch_size: batch size for all three dataloaders.
        val_size: fraction of ``train`` held out for validation.
        normalize: when true, append a ``Normalize`` transform built from the
            per-channel statistics of the ``train`` folder.
        dro: ``"up"`` upsamples the rarer class in the train dataloader with a
            weighted sampler; any other value gives a plain shuffled loader.
        base_size: images are resized to ``base_size x base_size``.
        num_workers: worker count for all three dataloaders.
    """

    def __init__(self, data_dir, batch_size=32, val_size=0.1, normalize=False, dro=False, base_size=64, num_workers=2):
        super().__init__()
        self.base_size = base_size
        self.dro = dro
        self.val_size = val_size
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.train_transforms_list = [
            transforms.Resize((self.base_size, self.base_size)),
            transforms.RandomRotation(45),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ToTensor(),
        ]
        self.test_transforms_list = [
            transforms.Resize((self.base_size, self.base_size)),
            transforms.ToTensor(),
        ]
        if normalize:
            # Use the deterministic (non-augmenting) pipeline for the statistics:
            # random rotations/flips would make them vary from run to run, and
            # rotation padding would skew them.
            train_mean, train_std = compute_data_stats(
                os.path.join(self.data_dir, "train"),
                self.test_transforms_list,
            )
            normalization = transforms.Normalize(train_mean, train_std)
            self.train_transforms_list.append(normalization)
            self.test_transforms_list.append(normalization)
        self.train_transforms = transforms.Compose(self.train_transforms_list)
        self.test_transforms = transforms.Compose(self.test_transforms_list)

    def setup(self, stage=None):
        """Build the datasets needed for ``stage`` (``None`` builds all of them)."""
        if stage in ["fit", "validate"] or stage is None:
            train_dir = os.path.join(self.data_dir, "train")
            # Two views of the same folder, one per transform pipeline. With a
            # single shared ImageFolder, both subsets would point at the same
            # `.transform` attribute and the last assignment would win, leaving
            # the training subset without augmentation.
            augmented_view = ImageFolder(train_dir, self.train_transforms)
            plain_view = ImageFolder(train_dir, self.test_transforms)

            n_val = int(np.floor(self.val_size * len(augmented_view)))
            train_split, val_split = random_split(
                augmented_view, [len(augmented_view) - n_val, n_val]
            )
            self.train = train_split
            # Same held-out indices, but read through the non-augmenting view.
            self.validate = torch.utils.data.Subset(plain_view, val_split.indices)

        if stage == "test" or stage is None:
            self.test = ImageFolder(os.path.join(self.data_dir, "val"),
                                    self.test_transforms,)

    def _train_labels(self):
        """Class index of every training sample, read without decoding images."""
        all_targets = np.asarray(self.train.dataset.targets)
        return all_targets[np.asarray(self.train.indices, dtype=np.int64)]

    def _balanced_sampler(self):
        """Sampler drawing each sample with weight 1 / (frequency of its class)."""
        labels = self._train_labels()
        class_counts = np.bincount(labels)
        # Float weights: len / count == 1 / proportion. (An integer array here
        # would truncate the weights and distort the class balance.)
        sample_weights = len(labels) / class_counts[labels].astype(np.float64)
        return torch.utils.data.WeightedRandomSampler(
            torch.as_tensor(sample_weights, dtype=torch.double),
            # One epoch draws as many samples as the training subset holds,
            # not just a single batch.
            num_samples=len(labels),
            replacement=True,
        )

    def train_dataloader(self):
        """Training loader: class-balanced when ``dro == "up"``, else shuffled."""
        if self.dro == "up":  # Upsampling the minority class.
            return DataLoader(self.train,
                              batch_size=self.batch_size,
                              sampler=self._balanced_sampler(),
                              num_workers=self.num_workers,
                              )
        return DataLoader(self.train,
                          batch_size=self.batch_size,
                          shuffle=True,  # reshuffle every epoch
                          num_workers=self.num_workers,
                          )

    def val_dataloader(self):
        """Validation loader over the held-out part of ``train`` (no shuffling)."""
        return DataLoader(self.validate,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          )

    def test_dataloader(self):
        """Test loader over the ``val`` folder (no shuffling)."""
        return DataLoader(self.test,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          )

### Metric - TPR @ fixed FPR

In [None]:
class TPR_at_FPR(BinaryPrecisionRecallCurve):
  """Highest true-positive rate achievable while keeping FPR <= ``max_fpr``.

  The parent class is used only for its accumulation of predictions and
  targets; ``compute`` reads ``self.preds`` and ``self.target``.
  NOTE(review): those list states are presumably only populated when
  ``thresholds`` is None (the default) — confirm before passing thresholds.

  Args:
    max_fpr: upper bound on the false-positive rate.
    thresholds: forwarded to the parent class.
    ignore_index: forwarded to the parent class.
    **kwargs: forwarded to the parent class.
  """

  def __init__(
    self,
    max_fpr,
    thresholds = None,
    ignore_index = None,
    **kwargs: Any,
  ) -> None:
    super().__init__(thresholds, ignore_index, validate_args=False, **kwargs)
    self.max_fpr = max_fpr

  def _compute_TPR_at_FPR(self, max_fpr, pred, target):
    """Return the maximal TPR among ROC points whose FPR is at most ``max_fpr``.

    Returns a zero scalar tensor when no ROC point satisfies the bound.
    """
    # BinaryROC is already the binary-specific class, so it is built without a
    # `task` keyword (that selector belongs to the generic ROC wrapper).
    fprs, tprs, _ = BinaryROC()(pred, target)

    # Vectorised selection instead of iterating over tensor elements in Python.
    within_budget = fprs <= max_fpr
    if not bool(within_budget.any()):
      return torch.tensor(0.0, device=fprs.device, dtype=fprs.dtype)
    return tprs[within_budget].max()

  def compute(self):
    """Compute the metric over everything accumulated since the last reset."""
    return self._compute_TPR_at_FPR(self.max_fpr, dim_zero_cat(self.preds), dim_zero_cat(self.target))

### ClassificationTask

In [None]:
class ClassificationTask(pl.LightningModule):
    """LightningModule for binary image classification on a timm backbone.

    Supports any available `Timm model
    <https://rwightman.github.io/pytorch-image-models/>`_
    as an architecture choice. To see a list of available
    models, you can do:

    .. code-block:: python

        import timm
        print(timm.list_models())

    All metrics are binary and are fed the softmax probability of class
    index 1, so the task expects ``num_classes=2``.
    """

    def config_model(self) -> None:
        """Configures the model based on kwargs parameters passed to the constructor."""
        # Create model
        weights = self.hyperparams["weights"]
        imagenet_pretrained = weights == "imagenet"
        self.model = timm.create_model(
            self.hyperparams["model"],
            num_classes=self.hyperparams["num_classes"],
            in_chans=self.hyperparams["in_channels"],
            drop_rate=self.hyperparams["dropout"],
            pretrained=imagenet_pretrained,
        )

        # Load non-ImageNet weights through torchgeo. This stays best-effort
        # (training continues from the timm initialisation on failure), but the
        # failure is reported instead of being silently swallowed.
        if not imagenet_pretrained and weights is not None:
            try:
                state_dict = get_weight(weights).get_state_dict(progress=True)
                # NOTE(review): assumes utils.load_state_dict returns the model,
                # as the original code did — confirm for the installed torchgeo.
                self.model = utils.load_state_dict(self.model, state_dict)
                print(f"Loaded {weights} successfully.")
            except Exception as err:
                import warnings

                warnings.warn(
                    f"Could not load weights {weights!r} ({err!r}); "
                    "continuing without them."
                )

        # Freeze everything except the last `num_layers_to_finetune` parameter
        # tensors (None leaves the whole model trainable).
        num_layers_to_finetune = self.hyperparams["num_layers_to_finetune"]
        if num_layers_to_finetune is not None:
            for parameter in list(self.model.parameters())[:-num_layers_to_finetune]:
                parameter.requires_grad = False

    def config_task(self) -> None:
        """Configures the task based on kwargs parameters passed to the constructor."""
        self.config_model()

        if self.hyperparams["loss"] == "ce":
            self.loss: nn.Module = nn.CrossEntropyLoss()
        else:
            raise ValueError(f"Loss type '{self.hyperparams['loss']}' is not valid.")

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the LightningModule with a model and loss function.

        Keyword Args:
            model: Name of the timm classification model to use
            loss: Name of the loss function; only "ce" is accepted
            weights: "imagenet" for timm's pretrained weights, None for no
                pretrained weights, otherwise a weight name handed to
                torchgeo's ``get_weight``
            num_classes: Number of prediction classes (the metrics assume 2)
            in_channels: Number of input channels to model
            dropout: Dropout rate passed to timm as ``drop_rate``
            num_layers_to_finetune: Number of trailing parameter tensors left
                trainable, or None to train everything
            optimizer: "adamw" or "sgd" (case-insensitive)
            weight_decay: Weight decay for the optimizer
            learning_rate: Learning rate for optimizer
            lr_scheduler: "on_plateau" for ReduceLROnPlateau, anything else
                for StepLR
            learning_rate_schedule_patience: Patience (ReduceLROnPlateau) or
                step size (StepLR) of the learning rate scheduler
        """
        super().__init__()

        # Creates `self.hparams` from kwargs
        self.save_hyperparameters()
        self.hyperparams = cast(Dict[str, Any], self.hparams)

        self.config_task()

        self.train_metrics = MetricCollection(
            {
              "AUROC": AUROC(
                  task="binary",
                  num_classes=self.hyperparams["num_classes"],
              ),
              "TPR": Recall(
                  task="binary",
                  average="macro",
                  num_classes=self.hyperparams["num_classes"],
              ),
              "TNR": Specificity(
                  task="binary",
                  average="macro",
                  num_classes=self.hyperparams["num_classes"],
              ),
              "TPR@FPR=0_2": TPR_at_FPR(
                  max_fpr=0.2,
              ),
              "TPR@FPR=0_1": TPR_at_FPR(
                  max_fpr=0.1,
              ),
              "TPR@FPR=0_05": TPR_at_FPR(
                  max_fpr=0.05,
              ),
            },
            prefix="train_",
        )
        self.val_metrics = self.train_metrics.clone(prefix="val_")
        self.test_metrics = self.train_metrics.clone(prefix="test_")

    @staticmethod
    def _positive_class_probability(logits: Tensor) -> Tensor:
        """Softmax probability of class index 1, detached from the graph.

        The binary metrics need a score whose 0.5 threshold agrees with the
        argmax decision; the raw class-1 logit alone does not, because it
        ignores the class-0 logit.
        """
        return logits.detach().softmax(dim=1)[:, 1]

    def _loss_and_scores(self, batch: Any):
        """Run the model on ``batch``; return (loss, positive-class scores, targets)."""
        x, y = batch
        y_hat = self(x)
        return self.loss(y_hat, y), self._positive_class_probability(y_hat), y

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Forward pass of the model.

        Args:
            x: input image
        Returns:
            prediction
        """
        return self.model(*args, **kwargs)

    def training_step(self, *args: Any, **kwargs: Any) -> Tensor:
        """Compute and return the training loss.

        Args:
            batch: the output of your DataLoader
        Returns:
            training loss
        """
        loss, scores, y = self._loss_and_scores(args[0])

        # by default, the train step logs every `log_every_n_steps` steps where
        # `log_every_n_steps` is a parameter to the `Trainer` object
        self.log("train_loss", loss, on_step=True, on_epoch=False)
        self.train_metrics(scores, y)

        return cast(Tensor, loss)

    def training_epoch_end(self, outputs: Any) -> None:
        """Logs epoch-level training metrics.

        Args:
            outputs: list of items returned by training_step
        """
        self.log_dict(self.train_metrics.compute())
        self.train_metrics.reset()

    def validation_step(self, *args: Any, **kwargs: Any) -> None:
        """Compute the validation loss and update the validation metrics.

        Args:
            batch: the output of your DataLoader
            batch_idx: the index of this batch
        """
        loss, scores, y = self._loss_and_scores(args[0])

        self.log("val_loss", loss, on_step=False, on_epoch=True)
        self.val_metrics(scores, y)

    def validation_epoch_end(self, outputs: Any) -> None:
        """Logs epoch level validation metrics.

        Args:
            outputs: list of items returned by validation_step
        """
        self.log_dict(self.val_metrics.compute())
        self.val_metrics.reset()

    def test_step(self, *args: Any, **kwargs: Any) -> None:
        """Compute the test loss and update the test metrics.

        Args:
            batch: the output of your DataLoader
        """
        loss, scores, y = self._loss_and_scores(args[0])

        # by default, the test and validation steps only log per *epoch*
        self.log("test_loss", loss, on_step=False, on_epoch=True)
        self.test_metrics(scores, y)

    def test_epoch_end(self, outputs: Any) -> None:
        """Logs epoch level test metrics.

        Args:
            outputs: list of items returned by test_step
        """
        self.log_dict(self.test_metrics.compute())
        self.test_metrics.reset()

    def predict_step(self, *args: Any, **kwargs: Any) -> Tensor:
        """Compute and return the predictions.

        Args:
            batch: the output of your DataLoader
        Returns:
            predicted softmax probabilities
        """
        batch = args[0]
        x, y = batch
        y_hat: Tensor = self(x).softmax(dim=-1)
        return y_hat

    def configure_optimizers(self) -> Dict[str, Any]:
        """Initialize the optimizer and learning rate scheduler.

        Returns:
            a "lr dict" according to the pytorch lightning documentation --
            https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers
        Raises:
            ValueError: if the ``optimizer`` hyperparameter is not "adamw" or "sgd"
        """
        optimizer_name = self.hyperparams["optimizer"].lower()
        if optimizer_name == "adamw":
            optimizer = torch.optim.AdamW(
                self.model.parameters(),
                lr=self.hyperparams["learning_rate"],
                weight_decay=self.hyperparams["weight_decay"],
            )
        elif optimizer_name == "sgd":
            optimizer = torch.optim.SGD(
                self.model.parameters(),
                lr=self.hyperparams["learning_rate"],
                weight_decay=self.hyperparams["weight_decay"],
            )
        else:
            # Fail with a clear message rather than an UnboundLocalError below.
            raise ValueError(
                f"Optimizer '{self.hyperparams['optimizer']}' is not valid."
            )

        if self.hyperparams["lr_scheduler"] == "on_plateau":
            scheduler = ReduceLROnPlateau(
                optimizer,
                patience=self.hyperparams["learning_rate_schedule_patience"],
            )
        else:
            scheduler = StepLR(
                optimizer,
                step_size=self.hyperparams["learning_rate_schedule_patience"],
            )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",
            },
        }

## Train

In [None]:
# W&B entity and project passed to wandb.init / WandbLogger / wandb.agent below.
# Left blank in this copy of the notebook — fill in before running.
wandb_namespace = ""
wandb_project_name = ""

In [None]:
# Uncomment the next line to register a new sweep from `sweep_config`;
# otherwise set `sweep_id` to the id of an existing sweep (blank in this copy).
# sweep_id = wandb.sweep(sweep_config, entity=wandb_namespace, project=wandb_project_name)
sweep_id = ""

In [None]:
def train():
  """Run one training job; used as the function executed by the W&B sweep agent.

  Starts a W&B run (sweep-provided values override ``config_defaults``),
  builds the data module and the classification task from ``wandb.config``,
  and fits with early stopping on the validation AUROC.
  """
  config_defaults = {
    "epochs": 100,
    "batch_size": 32,
    "dropout": 0.,
    "weight_decay": 0,
    "learning_rate": 1e-4,
    "optimizer": "adamw",
    "model_name": "resnet50",
    "weights": "imagenet",
    "num_layers_to_finetune": None,
    "learning_rate_schedule_patience": 5,
    "early_stop_patience": 10,
    "normalize": False,
    "base_size": 64,
    "dro": None,
    "lr_scheduler": "on_plateau",
  }
  # Initialize a new wandb run
  run = wandb.init(config=config_defaults, entity=wandb_namespace, project=wandb_project_name)
  wandb_logger = WandbLogger(
      # `id` is the logger's parameter for attaching to a given run id.
      entity=wandb_namespace, project=wandb_project_name, id=run.id,
      # log_model="all"
  )

  # Config is a variable that holds and saves hyperparameters and inputs
  config = wandb.config

  # The ViT variant is tied to 224x224 inputs, whatever `base_size` says.
  base_size = 224 if config.model_name == "vit_small_patch16_224" else config.base_size
  # No manual `data.setup()`: `trainer.fit` runs the datamodule's setup itself.
  data = DataModule(
      data_dir="/content/minefree-class-128",
      batch_size=config.batch_size,
      normalize=config.normalize,
      dro=config.dro,
      base_size=base_size
  )

  # Translate the sweep's shorthand into the torchgeo weight name for the
  # chosen backbone, and actually hand the result to the task below.
  # (Previously the resolved name was computed but the raw "sentinel2" string
  # was passed on.) The logged "weights" hyperparameter therefore shows the
  # resolved name.
  sentinel2_weight_names = {
      "resnet18": "ResNet18_Weights.SENTINEL2_RGB_SECO",
      "resnet50": "ResNet50_Weights.SENTINEL2_RGB_SECO",
      "vit_small_patch16_224": "ViTSmall16_Weights.SENTINEL2_ALL_SECO",
  }
  weights = config.weights
  if weights is None or weights == "sentinel2":
      weights = sentinel2_weight_names.get(config.model_name, weights)

  # The constructor already builds the model (via config_task -> config_model),
  # so no separate `task.config_model()` call is needed afterwards.
  task = ClassificationTask(
      model=config.model_name,
      weights=weights,
      loss="ce",
      in_channels=3,
      num_classes=2,
      batch_size=config.batch_size,
      learning_rate=config.learning_rate,
      learning_rate_schedule_patience=config.learning_rate_schedule_patience,
      num_layers_to_finetune=config.num_layers_to_finetune,
      optimizer=config.optimizer,
      weight_decay=config.weight_decay,
      dropout=config.dropout,
      lr_scheduler=config.lr_scheduler
  )

  early_stop_callback = EarlyStopping(
        monitor='val_AUROC',
        patience=config.early_stop_patience,
        verbose=False,
        mode='max'  # AUROC is better when higher
    )

  trainer = pl.Trainer(
    callbacks=[early_stop_callback],
    logger=wandb_logger,
    log_every_n_steps=5,
    gpus=-1,
    max_epochs=config.epochs,
  )

  trainer.fit(task, data)


## Sweep

In [None]:
# Start a sweep agent that repeatedly calls `train` for the sweep `sweep_id`.
wandb.agent(sweep_id, train,
            entity=wandb_namespace, project=wandb_project_name)

In [None]:
# Close the active W&B run, if any, once the agent has returned.
wandb.finish()