# Outline
This is a tutorial showing how to create, train and test a CNN classifier for astronomical image data using PyTorch:

- `Dataset`: Galaxy MNIST, galaxies of different morphologies observed in 3 optical bands
- `CNN models`: custom & ResNet
- `Task`: classify input images into 4 possible classes

The tutorial will cover these steps:

1) Set up the environment
2) Download the dataset and create data loaders and transformers/augmenters
3) Create a configurable CNN classifier
4) Train the classifier
5) Evaluate the classifier on test data

# Configuring the environment

## Module installation
Let's first install the Python modules required for this tutorial.

In [None]:
###########################
##   IMGPROC MODULES
###########################
%pip install -q pillow opencv-python

###########################
##   ML MODULES
###########################
%pip install -q torch torchvision torchmetrics torchsummary scikit-learn tqdm
%pip install -q wandb
%pip install -q grad-cam

###########################
##   OTHER MODULES
###########################
%pip install -q shortuuid
%pip install -q gdown # gDrive
%pip install -q matplotlib

## Import modules

In [None]:
###########################
##   STANDARD MODULES
###########################
import os
from pathlib import Path
import shutil
import gdown
import tarfile
import numpy as np
import json
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from itertools import islice
import random
import matplotlib.pyplot as plt
from tqdm import tqdm
import shortuuid

###########################
##   IMGPROC/TORCH MODULES
###########################
# - Image proc
import PIL
from PIL import Image
import cv2

# - Torch modules
import torch
from torch.utils.data import Dataset, Subset, random_split
import torchvision
from torchvision.datasets.vision import VisionDataset
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from torchvision.transforms import ToTensor
import torchmetrics
from torchsummary import summary

# - GradCAM
from pytorch_grad_cam import (
  GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus,
  AblationCAM, XGradCAM, EigenCAM, EigenGradCAM,
  LayerCAM, FullGrad, GradCAMElementWise, KPCA_CAM
)
from pytorch_grad_cam.utils.find_layers import find_layer_types_recursive
from pytorch_grad_cam.utils.image import show_cam_on_image

## Project folders
We create a working directory in which to run the tutorial.

In [None]:
topdir= os.getcwd()
rundir= os.path.join(topdir, "run-gmnist_classifier")
path = Path(rundir)
path.mkdir(parents=True, exist_ok=True)

# Dataset
For this tutorial, we are going to use the [Galaxy MNIST dataset](https://github.com/mwalmsley/galaxy_mnist). 

The dataset currently contains 10,000 images of galaxies in three optical bands (grz), either 64x64x3 (low resolution) or 224x224x3 (high resolution), taken from the Dark Energy Camera Legacy Survey (DECaLS) Galaxy Zoo project.
DECaLS uses the Dark Energy Camera (DECam) at the 4m Blanco telescope in Chile. Fluxes in the grz bands were converted to RGB colours (see Section 2.3 of the reference paper) and a PNG image was created for each sample galaxy.

The dataset is split into two subsets: 

- train: 8000 images
- test: 2000 images

The dataset contains 4 possible classes of galaxy morphologies:

- SMOOTH_ROUND: smooth and round galaxy. Should not show signs of spiral arms.
- SMOOTH_CIGAR: smooth and cigar-shaped galaxy that looks as if seen edge-on. Should not show signs of the spiral arms of a spiral galaxy.
- EDGE_ON_DISK: edge-on disk/spiral galaxy. This disk galaxy should show signs of spiral arms, as seen from an edge-on perspective.
- UNBARRED_SPIRAL: unbarred spiral galaxy. Shows signs of a disk and/or spiral arms.

Note that the SMOOTH_CIGAR and EDGE_ON_DISK classes tend to look very similar to each other. To tell them apart, ask yourself the following questions: Is this galaxy very smooth, maybe with a small bulge? Then it belongs to class SMOOTH_CIGAR. Does it have irregularities or signs of structure? Then it belongs to class EDGE_ON_DISK.

In this tutorial we are going to use the high-resolution images (224x224). The original dataset format was slightly modified. The modified dataset is available for download from Google Drive.

More details on the observational data and labelling are available in these references:

- [Galaxy Zoo DECaLS paper](https://ui.adsabs.harvard.edu/abs/2022MNRAS.509.3966W/abstract)    
- [Galaxy Zoo DECaLS data](https://zenodo.org/records/4573248)

## Dataset Download
We download the dataset from a Google Drive URL and extract it in the working directory.

In [None]:
# - Set dataset URL & paths
dataset_name= "galaxy_mnist-dataset"
dataset_dir= os.path.join(rundir, dataset_name)
dataset_tar= 'galaxy_mnist-dataset.tar.gz'
dataset_tar_fullpath= os.path.join(rundir, dataset_tar)
dataset_url= 'https://drive.google.com/uc?export=download&id=1OprJ_NQIFyQSRWqjGLFQsAMumHvJ-tMB'

# - Download dataset tar file into rundir (if not previously downloaded)
if not os.path.isfile(dataset_tar_fullpath):
  print("Downloading file from url %s ..." % (dataset_url))
  gdown.download(dataset_url, dataset_tar_fullpath, quiet=False)
  print("DONE!")

# - Untar dataset into rundir (if not previously extracted)
#   NB: The tar file is expected to contain a top-level directory named dataset_name
if not os.path.isdir(dataset_dir):
  print("Extracting dataset file %s ..." % (dataset_tar_fullpath))
  with tarfile.open(dataset_tar_fullpath) as fp:
    fp.extractall(rundir)
  print("DONE!")

The dataset provides datalists for the train and test samples in JSON format, for 1-channel (channel-averaged data) or 3-channel images. We are going to use the 3-channel data samples:

- `train/3chan/datalist_train.json`
- `test/3chan/datalist_test.json`

Datalists have this format:

```json
{    
  "data": [    
    {    
      "filepaths": [
        "galaxy_mnist-dataset/train/3chan/train_1.png"
      ],
      "sname": "S1",
      "id": 1,
      "label": "smooth_cigar"
    },    
    ...
    ...
  ]
}
```
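
As a quick check of this layout, the sketch below parses a small in-memory datalist with the same fields and counts the entries per label. The two entries and their values are made up for illustration (the second file name is hypothetical); only the field names come from the format above.

```python
import json
from collections import Counter

# Hypothetical two-entry datalist mimicking the format above (values are made up)
datalist_json = """
{
  "data": [
    {"filepaths": ["galaxy_mnist-dataset/train/3chan/train_1.png"],
     "sname": "S1", "id": 1, "label": "smooth_cigar"},
    {"filepaths": ["galaxy_mnist-dataset/train/3chan/train_2.png"],
     "sname": "S2", "id": 3, "label": "unbarred_spiral"}
  ]
}
"""

entries = json.loads(datalist_json)["data"]

# First file path and class id of each entry, the two fields the dataset class reads
samples = [(e["filepaths"][0], e["id"]) for e in entries]

# Number of entries per label
label_counts = Counter(e["label"] for e in entries)

print(samples)
print(label_counts)
```

Running the same counting on the real `datalist_train.json` is a cheap way to check the class balance before training.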

## Create PyTorch Dataset
We create a custom PyTorch dataset for the Galaxy MNIST data by subclassing the PyTorch `Dataset` base class. For this we need to override these base methods:

- `__len__`: returns the size of the dataset.
- `__getitem__`: returns the i-th dataset sample (image and target).
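
These two methods are ordinary Python special methods, so the idea can be sketched without PyTorch. The `ToyDataset` class below is a hypothetical stand-in (plain numbers play the role of images), not the dataset class used in this tutorial.

```python
class ToyDataset:
    """Hypothetical stand-in for a map-style dataset (no torch needed)."""

    def __init__(self, images, targets, transform=None):
        self.images = images
        self.targets = targets
        self.transform = transform

    def __len__(self):
        # Number of samples in the dataset
        return len(self.images)

    def __getitem__(self, idx):
        # Return the idx-th (image, target) pair, applying the transform if set
        img = self.images[idx]
        if self.transform is not None:
            img = self.transform(img)
        return img, self.targets[idx]


# Toy "images" are plain numbers here; the transform doubles them
ds = ToyDataset(images=[1.0, 2.0, 3.0], targets=[0, 3, 1], transform=lambda x: 2 * x)

print(len(ds))  # 3
print(ds[1])    # (4.0, 3)
```

A data loader only needs `len(ds)` and `ds[i]` to build batches, which is why overriding these two methods is sufficient.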

In [None]:
class GMNISTDataset(Dataset):
  """ Galaxy MNIST dataset """

  def __init__(
      self, 
      metadata_file: Optional[Union[str, Path]] = "",
      subset: Optional[Subset] = None,
      transform: Optional[Callable] = None,
      target_transform: Optional[Callable] = None,
      data_path: Optional[Union[str, Path]] = "",
  ):
    # - Read metadata
    self.data_path= data_path
    self.subset= subset
    if self.subset is None:
        print("Reading dataset metadata from file %s ..." % (metadata_file))
        self.__read_metadata(metadata_file)
        
    # - Set pars
    self.transform = transform
    self.target_transform = target_transform
    self.pil2tensor = T.Compose([T.PILToTensor()]) # no normalization
    
    self.target2label= {
      0: "smooth_round",  
      1: "smooth_cigar",
      2: "edge_on_disk",
      3: "unbarred_spiral"
    }

  def __read_metadata(self, filename):
    """ Read json metadata """
    
    with open(filename, "r") as f:
      self.datalist= json.load(f)["data"]
  
  def __len__(self):
    """ Return size of dataset """ 
    if self.subset is not None:
      return len(self.subset) 
    else:
      return len(self.datalist)
   
  def __load_item(self, idx):
    """ Load dataset item """
    
    # - Read image path & class id
    img_path= self.datalist[idx]['filepaths'][0]
    if self.data_path!="" and os.path.isdir(self.data_path):
      img_path= os.path.join(self.data_path, img_path)
    
    target= self.datalist[idx]['id'] # class id
    
    # - Read PIL image as RGB
    img = Image.open(img_path).convert("RGB")
    
    return img, target
    
  def __load_subset_item(self, idx):
    """ Load dataset subset item """
    
    # - Get item from subset
    #   NB: img is a tensor
    return self.subset[idx]

    
  def __getitem__(self, idx):
    """ Return dataset item """
    
    # - Load image/label
    if self.subset is None:
      img, target= self.__load_item(idx)
    else:
      img, target= self.subset[idx]

    # - Convert PIL to tensor?
    if isinstance(img, PIL.Image.Image):
      img= self.pil2tensor(img)
    
    # - Transform img/tensor?
    if self.transform is not None:
      img = self.transform(img)

    # - Transform target?
    if self.target_transform is not None:
      target = self.target_transform(target)
       
    return img, target

## Create custom data transforms
We define here a series of custom image transformations that we will apply to the data as augmentations. To do so, subclass `torch.nn.Module` and override the `forward` method, as in the examples below:

### Random flip
A transform that flips the image horizontally or vertically, or leaves it unchanged, chosen at random.

In [None]:
class RandomFlip(torch.nn.Module):
  """ Flip image """

  def __init__(self):
    super().__init__()

  def forward(self, img):
    op= random.choice([1,2,3])
    if op==1:
      return TF.hflip(img)
    elif op==2:
      return TF.vflip(img)
    else:
      return img

### Random rotate
A transform that randomly rotates the image in steps of 90 degrees.

In [None]:
class RandomRotate90(torch.nn.Module):
  """ Rotate image by a random multiple of 90 degrees (0, 90, 180 or 270) """

  def __init__(self):
    super().__init__()

  def forward(self, img):
    op= random.choice([1,2,3,4])
    if op==1:
      return TF.rotate(img, 90)
    elif op==2:
      return TF.rotate(img, 180)
    elif op==3:
      return TF.rotate(img, 270)
    elif op==4:
      return img

### Sanitization
A transform that sets NaN/inf pixels to 0.

In [None]:
class Sanitization(torch.nn.Module):
  """ Set NaN/inf pixels to 0 """

  def __init__(self):
    super().__init__()
   
  def forward(self, img):
    # - Create mask of non-nans pixels    
    cond= torch.isfinite(img)
    
    # - Set nans to 0
    img[~cond]= 0
    
    return img

### Absolute Channel Maximum Scaling
This transform finds, for each image, the maximum pixel value across all channels and divides every channel by it. Because all channels share the same scale factor, the flux ratios between bands are preserved and remain available as a potentially sensitive classification variable.
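
To see why a single scale factor matters, the following sketch compares scaling by the global maximum with scaling each channel by its own maximum. It uses plain Python lists and made-up pixel values rather than tensors, so it only illustrates the arithmetic.

```python
# Toy 3-channel "image": each channel is a flat list of pixel values (made up)
image = {
    "g": [2.0, 4.0],
    "r": [5.0, 10.0],
    "z": [1.0, 20.0],
}

# Global maximum over all channels and pixels
global_max = max(max(pixels) for pixels in image.values())

# Scale every channel by the same global maximum
scaled_global = {band: [p / global_max for p in pixels] for band, pixels in image.items()}

# For comparison: scale each channel by its own maximum
scaled_per_chan = {band: [p / max(pixels) for p in pixels] for band, pixels in image.items()}

# Ratio r/g at the first pixel, before and after scaling
ratio_orig = image["r"][0] / image["g"][0]
ratio_global = scaled_global["r"][0] / scaled_global["g"][0]
ratio_per_chan = scaled_per_chan["r"][0] / scaled_per_chan["g"][0]
print(ratio_orig, ratio_global, ratio_per_chan)
```

The global scaling keeps the r/g ratio of the original image, while per-channel scaling changes it, which would discard colour information.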

In [None]:
class AbsChanMaxScaling(torch.nn.Module):
  """ Scale tensor by absolute channel maximum """

  def __init__(self):    
    super().__init__()
   
  def forward(self, img):
    
    # - Compute absolute image max across channels
    ndim= img.ndim
    if ndim==4: # [BATCH,CHAN,Ny,Nx]
      img_absmax= torch.amax(img, dim=(1,2,3), keepdim=True)
    elif ndim==3: # [CHAN,Ny,Nx]
      img_absmax= torch.amax(img, dim=(0,1,2), keepdim=True)
    else:
      print("WARN: Unexpected ndim (%d), returning same image ..." % (ndim))
      return img
    
    # - Scale image by absmax
    img_scaled= img/img_absmax
    
    return img_scaled

### My custom transform

In [None]:
## DEFINE YOUR OWN TRANSFORM HERE!

### Define composite transforms
Let's define two composite transforms: one for training data, having standard plus additional augmenter transforms, and the other for validation/test data, having standard transforms.

Standard transforms are:

- Sanitization
- Image resize
- Intra-channel normalization
- Image sample normalization (optional)

Augmenter transforms are:

- Random flipping
- Random rotation 90 deg
- Random crop and resize
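
The optional sample normalization is done in this tutorial with `T.Normalize`, which computes `(x - mean) / std` per channel. As a rough illustration, the sketch below applies that formula in plain Python to the ImageNet statistics used in the `*_imagenet` transforms, assuming inputs already scaled to [0, 1]; the `normalize` helper is hypothetical.

```python
# ImageNet channel statistics, as used in the *_imagenet transforms of this tutorial
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

def normalize(pixel, chan):
    """Apply (x - mean) / std for one pixel value of channel `chan`."""
    return (pixel - mean[chan]) / std[chan]

# Output range per channel for inputs already scaled to [0, 1]
for chan, name in enumerate("RGB"):
    lo, hi = normalize(0.0, chan), normalize(1.0, chan)
    print(f"{name}: [{lo:.3f}, {hi:.3f}]")
```

After this step pixel values are no longer confined to [0, 1], which is expected when feeding a network pre-trained on ImageNet.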

In [None]:
# - Define dataset transforms
img_resize= 224

transform_train= T.Compose(
  [  
    Sanitization(),  
    T.Resize(img_resize, interpolation=T.InterpolationMode.BICUBIC),
    RandomFlip(),
    RandomRotate90(),  
    T.RandomResizedCrop(img_resize, scale=(0.5, 1.0), ratio=(1., 1.), interpolation=T.InterpolationMode.BICUBIC),
    AbsChanMaxScaling(), 
    #T.ToTensor(),# convert to tensor in range [0,1]
  ]
)

transform_imagenet_train= T.Compose(
  [  
    Sanitization(),  
    T.Resize(img_resize, interpolation=T.InterpolationMode.BICUBIC),
    RandomFlip(),
    RandomRotate90(), 
    T.RandomResizedCrop(img_resize, scale=(0.5, 1.0), ratio=(1., 1.), interpolation=T.InterpolationMode.BICUBIC),
    AbsChanMaxScaling(),  
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))  
  ]
)

transform= T.Compose(
  [
    Sanitization(),  
    T.Resize(img_resize, interpolation=T.InterpolationMode.BICUBIC),
    AbsChanMaxScaling()
  ]
)

transform_imagenet= T.Compose(
  [
    Sanitization(),  
    T.Resize(img_resize, interpolation=T.InterpolationMode.BICUBIC),
    AbsChanMaxScaling(),
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))   
  ]
)

## Create datasets
Load GMNIST train/test dataset using the GMNISTDataset class created above. Then, split the train dataset into two subsets, one to be used as training set (70% of the original train sample) and the other as validation set (the remaining 30%).
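
The idea behind a seeded random split can be sketched in plain Python as follows. This is not how `torch.utils.data.random_split` is implemented, and the selected indices will differ from the ones it produces; the `split_indices` helper is hypothetical and only shows the resulting sizes and the reproducibility given a fixed seed.

```python
import random

def split_indices(n_samples, train_frac, seed):
    """Shuffle indices with a fixed seed and split them into two disjoint lists."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    n_train = int(round(train_frac * n_samples))
    return indices[:n_train], indices[n_train:]

# 8000 train images split 70/30, as in this tutorial
idx_train, idx_val = split_indices(8000, 0.7, seed=42)
print(len(idx_train), len(idx_val))  # 5600 2400
```

Fixing the seed means the same samples end up in the validation set on every run, so validation metrics stay comparable across experiments.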

In [None]:
# - Set train/test datalists
filename_train_3chan= os.path.join(dataset_dir, "train/3chan/datalist_train.json")
filename_test_3chan= os.path.join(dataset_dir, "test/3chan/datalist_test.json")

# - Read traincv dataset
print("Read train-cv dataset from file %s ..." % (filename_train_3chan))
dataset_traincv= GMNISTDataset(
  metadata_file=filename_train_3chan,
  data_path=rundir
)

# - Read test dataset
print("Read test dataset from file %s ..." % (filename_test_3chan))
dataset_test= GMNISTDataset(
  metadata_file=filename_test_3chan,
  transform=transform,
  data_path=rundir
)
dataset_imagenet_test= GMNISTDataset(
  metadata_file=filename_test_3chan,
  transform=transform_imagenet,
  data_path=rundir
)

# - Split train-cv dataset into train & validation samples
print("Splitting train-cv dataset in 70% train/30% val subsets...")
generator= torch.Generator().manual_seed(42)
subset_train, subset_val= random_split(dataset_traincv, [0.7, 0.3], generator=generator)

# - Create train & val datasets from subsets
print("Creating train & val datasets from subsets ...")
dataset_train= GMNISTDataset(subset=subset_train, transform=transform_train, data_path=rundir)
dataset_imagenet_train= GMNISTDataset(subset=subset_train, transform=transform_imagenet_train, data_path=rundir)
dataset_val= GMNISTDataset(subset=subset_val, transform=transform, data_path=rundir)
dataset_imagenet_val= GMNISTDataset(subset=subset_val, transform=transform_imagenet, data_path=rundir)

print("#%d entries in train set ..." % (len(dataset_train)))
print("#%d entries in validation set ..." % (len(dataset_val)))
print("#%d entries in test set ..." % (len(dataset_test)))

### Draw sample images
Let's draw some sample images from the train set.

In [None]:
# - Plot images
fig = plt.figure(figsize=(15, 15))
for i, (tensor_image, target) in islice(enumerate(dataset_train), 16):
  label= dataset_train.target2label[target]  
  ax = fig.add_subplot(4, 4, i+1)
  ax.set_xticks([]); ax.set_yticks([])
  im= ax.imshow(tensor_image.permute(1, 2, 0))
  ax.set_title(f'{label}', size=15)
  
plt.show()

## Create dataloaders
We are going to create a dataloader for train, validation and test data.
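
For orientation, the number of batches per epoch follows directly from the dataset size and the batch size. The sketch below assumes the sample sizes implied by the 70/30 split of the 8000 train images (5600 and 2400) and the 2000 test images; the `n_batches` helper is hypothetical.

```python
import math

def n_batches(n_samples, batch_size, drop_last=False):
    """Number of batches per epoch for a given dataset and batch size."""
    if drop_last:
        return n_samples // batch_size
    return math.ceil(n_samples / batch_size)

# Sizes assumed from the 70/30 split of 8000 train images and the 2000 test images
print(n_batches(5600, 64))  # 88 (87 full batches + 1 batch of 32)
print(n_batches(2400, 64))  # 38 (37 full batches + 1 batch of 32)
print(n_batches(2000, 8))   # 250
```

Since `DataLoader` keeps the last incomplete batch by default, the final train and validation batches hold 32 samples instead of 64.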

In [None]:
###############################
##    CREATE DATA LOADERS
###############################
# - Create data loaders
torch.manual_seed(1)
batch_size= 64
dataloader_train= torch.utils.data.DataLoader(
  dataset_train, 
  batch_size=batch_size,
  shuffle=True, 
  num_workers=1
)
dataloader_imagenet_train= torch.utils.data.DataLoader(
  dataset_imagenet_train, 
  batch_size=batch_size,
  shuffle=True, 
  num_workers=1
)

dataloader_val= torch.utils.data.DataLoader(
  dataset_val, 
  batch_size=batch_size,
  shuffle=False, 
  num_workers=1
)
dataloader_imagenet_val= torch.utils.data.DataLoader(
  dataset_imagenet_val, 
  batch_size=batch_size,
  shuffle=False, 
  num_workers=1
)

dataloader_test= torch.utils.data.DataLoader(
  dataset_test, 
  batch_size=8,
  shuffle=False, 
  num_workers=1
)
dataloader_imagenet_test= torch.utils.data.DataLoader(
  dataset_imagenet_test, 
  batch_size=8,
  shuffle=False, 
  num_workers=1
)

# - Test min/max
imgs, targets = next(iter(dataloader_test))
print("type(imgs)")
print(type(imgs))
print("imgs.shape")
print(imgs.shape)

data_min= torch.amin(imgs, dim=(2,3))
data_max= torch.amax(imgs, dim=(2,3))
data_absmax= torch.amax(imgs, dim=(1,2,3))
print("min: ", data_min)
print("max: ", data_max)
print("absmax: ", data_absmax)

# CNN classifier
We will create two CNN classifiers to perform image classification with the loaded dataset:

- ResNet architecture
- Custom architecture

We will show a complete example for the first architecture. For the custom architecture, we will provide some implementations and hints to allow the user to fully complete the exercise.

## ResNet classifier
Let's define a classifier class to perform GMNIST image classification using a pre-trained model, based on the ResNet architecture. For this, we are going to use the predefined models in `torchvision.models` and the torch `Sequential` class.

In [None]:
class ResNetClassifier():
  """ Build a ResNet classifier """  

  def __init__(
    self,
    nn_arch: Optional[str] = "resnet18",
    pretrained_weights: Optional[str] = None,
    num_classes: Optional[int] = 4,  
    n_dense_layers: Optional[int] = 1,
    dense_layer_sizes: Optional[Union[int, list]] = [64],
    add_dropout: Optional[bool] = True,
    dropout_prob: Optional[float] = 0.5
  ):
    """ Initialize class """
    
    self.model= None
    self.device= torch.device("cuda" if torch.cuda.is_available() else "cpu")
    self.nn_arch= nn_arch
    self.pretrained_weights= pretrained_weights
    self.num_classes= num_classes
    self.n_dense_layers= n_dense_layers
    self.add_dropout= add_dropout
    self.dropout_prob= dropout_prob
    
    # - Set dense layer size per layer
    if isinstance(dense_layer_sizes, list):
      if len(dense_layer_sizes)!=self.n_dense_layers:
        raise Exception("dense_layer_sizes list must have length equal to n_dense_layers!")  
      else:
        self.dense_layer_sizes= dense_layer_sizes
    else:
      self.dense_layer_sizes= [dense_layer_sizes]*self.n_dense_layers
    
    # - Build network
    if self.__build_model()<0:
      print("ERROR: Failed to build model!")
      raise Exception("Failed to build model!")
    
    # - Move model to device
    print("Moving model to device %s ..." % (self.device))
    self.model.to(self.device)
    
  def __build_model(self, pretrained_weights=None):
    """ Create network from pre-defined architecture (e.g. resnet) """
    
    # - Load predefined arch
    #   NB: Supported weights for resnet18/34: 'IMAGENET1K_V1'
    #       Supported weights for resnet50/101: {'IMAGENET1K_V1','IMAGENET1K_V2'}
    if self.nn_arch=="resnet18":
      self.model = torchvision.models.resnet18(weights=self.pretrained_weights) 
    elif self.nn_arch=="resnet34":
      self.model = torchvision.models.resnet34(weights=self.pretrained_weights)  
    elif self.nn_arch=="resnet50":
      self.model = torchvision.models.resnet50(weights=self.pretrained_weights)
    elif self.nn_arch=="resnet101":
      self.model = torchvision.models.resnet101(weights=self.pretrained_weights)
    else:
      print("ERROR: Unsupported nn arch (%s) specified, see torch supported arch below and add it yourself!" % (self.nn_arch))
      print(torchvision.models.list_models(module=torchvision.models)) 
      return -1

    # - Define classification head
    class_head= torch.nn.Sequential()
    
    for i in range(self.n_dense_layers):
      # - Add dense layer
      layer_name= "fc" + str(i+1)
      class_head.add_module(layer_name, torch.nn.LazyLinear(self.dense_layer_sizes[i]))
    
      # - Add activation
      layer_name= "relu_fc" + str(i+1)  
      class_head.add_module(layer_name, torch.nn.ReLU())
    
      # - Add dropout?
      if self.add_dropout:
        layer_name= "dropout" + str(i+1)  
        class_head.add_module(layer_name, torch.nn.Dropout(p=self.dropout_prob))
    
    # - Add dropout if no dense layer specified?
    if self.n_dense_layers<=0 and self.add_dropout:
      class_head.add_module("dropout", torch.nn.Dropout(p=self.dropout_prob))  
    
    # - Add output layer
    class_head.add_module("output", torch.nn.LazyLinear(self.num_classes))
    
    # - Override head
    self.model.fc = class_head
    
    return 0    

Let's create an instance of the ResNet classifier using a ResNet18 architecture with ImageNet pre-trained weights. The classification head consists of 1 dense hidden layer with 64 neurons. Dropout with 0.5 probability is added in dense layers. Change these settings as you wish.
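
As a back-of-the-envelope check, the number of trainable parameters in this classification head can be computed by hand. The sketch assumes the standard ResNet18 backbone, which outputs 512 features after global average pooling; `LazyLinear` infers this input size at the first forward pass. The `linear_params` helper is hypothetical.

```python
def linear_params(in_features, out_features):
    """Weights plus biases of a fully connected layer."""
    return in_features * out_features + out_features

# ResNet18 backbone outputs 512 features after global average pooling
backbone_features = 512
dense_layer_sizes = [64]
num_classes = 4

n_params = 0
in_features = backbone_features
for size in dense_layer_sizes:
    n_params += linear_params(in_features, size)
    in_features = size
n_params += linear_params(in_features, num_classes)

print(n_params)  # 33092
```

ReLU and dropout layers add no parameters, so the head is tiny compared with the backbone itself.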

In [None]:
# - Create model
nn_arch= "resnet18"
pretrained_weights="DEFAULT"
n_dense_layers= 1
dense_layer_sizes= [64]
add_dropout= True
dropout_prob= 0.5

classifier_resnet= ResNetClassifier(
  nn_arch=nn_arch,
  pretrained_weights=pretrained_weights,
  num_classes=4,
  n_dense_layers=n_dense_layers,  
  dense_layer_sizes=dense_layer_sizes,
  add_dropout= add_dropout,
  dropout_prob= dropout_prob 
)

# - Print model architecture
input_shape= (imgs.shape[1], imgs.shape[2], imgs.shape[3])
summary(classifier_resnet.model, input_shape)

### Train model
We define below some methods to run model training.
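
The `AverageMeter` helper in the next cell accumulates `value * n` and divides by the total count, i.e. it computes a sample-weighted mean of the per-batch losses. The sketch below, with made-up loss values, shows how this differs from a plain mean of batch means when the last batch is smaller.

```python
# Hypothetical per-batch mean losses and batch sizes (last batch is smaller)
batch_losses = [0.90, 0.70, 0.10]
batch_sizes = [64, 64, 32]

# Plain mean of batch means: over-weights the small last batch
naive_avg = sum(batch_losses) / len(batch_losses)

# Sample-weighted mean, i.e. sum(value * n) / sum(n)
weighted_avg = (
    sum(loss * n for loss, n in zip(batch_losses, batch_sizes)) / sum(batch_sizes)
)

print(round(naive_avg, 4), round(weighted_avg, 4))
```

The weighted value is the true mean loss per sample over the epoch, which is what the training and validation loops report.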

In [None]:
class AverageMeter:
  def __init__(self):
    self.reset()

  def reset(self):
    self.sum = 0
    self.count = 0

  def update(self, value, n=1):
    self.sum += value * n
    self.count += n

  @property
  def avg(self):
    return self.sum / self.count if self.count > 0 else 0

def run_train(
  classifier,
  train_dl,
  val_dl= None,
  num_epochs: Optional[int] = 1, 
  loss_fn= None,
  optimizer= None,  
  lr: Optional[float] = 1e-4,
  outfile_model="model.pth",
):
  """ Train network """

  # - Get model from classifier
  model= classifier.model  

  # - Set loss
  if loss_fn is None:  
    print("Setting default CE loss ...")
    loss_fn= torch.nn.CrossEntropyLoss()
  
  # - Set optimizer
  if optimizer is None:
    print("Setting default Adam optimizer with lr=%f ..." % (lr))
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    
  # - Set output model/weights filenames
  outfile_model_basenoext= os.path.splitext(os.path.basename(outfile_model))[0]
  outfile_model_dir= os.path.dirname(os.path.abspath(outfile_model))
  outfile_model_best= os.path.join(outfile_model_dir, outfile_model_basenoext + '_best.pth')
  outfile_weights= os.path.join(outfile_model_dir, outfile_model_basenoext + '_weights.pth')
  outfile_weights_best= os.path.join(outfile_model_dir, outfile_model_basenoext + '_weights_best.pth')
    
  # - Init metrics
  train_accuracy_metric = torchmetrics.Accuracy(
    task="multiclass", 
    num_classes=classifier.num_classes
  ).to(classifier.device)
    
  train_f1score_metric = torchmetrics.F1Score(
    task="multiclass", 
    num_classes=classifier.num_classes, 
    average="macro"
  ).to(classifier.device)

  val_accuracy = None
  val_f1score = None
  if val_dl is not None:
    val_accuracy_metric= torchmetrics.Accuracy(
      task="multiclass", 
      num_classes=classifier.num_classes
    ).to(classifier.device)
    
    val_f1score_metric= torchmetrics.F1Score(
      task="multiclass", 
      num_classes=classifier.num_classes, 
      average="macro"
    ).to(classifier.device)

  loss_hist_train = [0] * num_epochs
  accuracy_hist_train = [0] * num_epochs
  f1score_hist_train = [0] * num_epochs
  loss_hist_val = [0] * num_epochs
  accuracy_hist_val = [0] * num_epochs
  f1score_hist_val = [0] * num_epochs 
        
  # - Training loop
  best_val_acc = 0.0
    
  for epoch in range(num_epochs):
    # - Run train batch loop
    train_loss, train_acc, train_f1score = train_epoch(
      classifier,  
      train_dl, 
      loss_fn, 
      optimizer,
      epoch, 
      train_accuracy_metric,
      train_f1score_metric  
    )
    loss_hist_train[epoch]= train_loss
    accuracy_hist_train[epoch]= train_acc
    f1score_hist_train[epoch]= train_f1score
        
    # - Run validation batch loop?
    val_loss= 0.
    val_acc= 0.
    val_f1score= 0.
        
    if val_dl is not None:
      val_loss, val_acc, val_f1score = validate_epoch(
        classifier,  
        val_dl, 
        loss_fn,
        epoch, 
        val_accuracy_metric,
        val_f1score_metric  
      )
      loss_hist_val[epoch]= val_loss
      accuracy_hist_val[epoch]= val_acc
      f1score_hist_val[epoch]= val_f1score

    # - Print metrics (epoch shown as 1-based, consistently with progress bars)
    if val_dl is not None:
      print("Epoch [%d/%d]: loss=%.4f (val=%.4f), acc=%.4f (val=%.4f), f1=%.4f (val=%.4f)" % (epoch+1, num_epochs, train_loss, val_loss, train_acc, val_acc, train_f1score, val_f1score))
    else:
      print("Epoch [%d/%d]: loss=%.4f, acc=%.4f, f1=%.4f" % (epoch+1, num_epochs, train_loss, train_acc, train_f1score))
        
    # - Save best model  
    if val_dl is not None and val_acc > best_val_acc:  
      best_val_acc = val_acc
      print("Saving best model at epoch %d (acc_val=%.4f) ..." % (epoch+1, best_val_acc))  
      torch.save(model.state_dict(), outfile_weights_best)  
      torch.save(model, outfile_model_best)
          
  # - Save final model
  print("Saving final model ...")  
  torch.save(model.state_dict(), outfile_weights) 
  torch.save(model, outfile_model)
    
  # - Set metric history
  metric_hist= {
    "loss_train": loss_hist_train,
    "acc_train": accuracy_hist_train,
    "f1score_train": f1score_hist_train,
    "loss_val": loss_hist_val,
    "acc_val": accuracy_hist_val,
    "f1score_val": f1score_hist_val
  }
    
  print("END TRAIN RUN")
    
  return metric_hist


def train_epoch(
  classifier,
  dataloader, 
  criterion, 
  optimizer, 
  epoch, 
  accuracy_metric,
  f1score_metric
):
  """ Train one epoch """
    
  # - Retrieve model
  model= classifier.model

  # - Init metrics
  model.train()
  loss_meter = AverageMeter()
  accuracy_metric.reset()
  f1score_metric.reset()
  progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1} [Training]", leave=False)

  # - Run batch loop
  for X_batch, y_batch in progress_bar:
    X_batch, y_batch = X_batch.to(classifier.device), y_batch.to(classifier.device)

    # - Compute prediction and loss   
    outputs = model(X_batch)
    loss = criterion(outputs, y_batch)

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # - Update loss and accuracy
    loss_meter.update(loss.item(), X_batch.size(0))
    preds = outputs.argmax(dim=1)
    accuracy_metric.update(preds, y_batch)
    f1score_metric.update(preds, y_batch)

    # - Update progress bar
    progress_bar.set_postfix(
      loss=loss_meter.avg, 
      accuracy=accuracy_metric.compute().item(),
      f1score=f1score_metric.compute().item()
    )

  avg_loss = loss_meter.avg
  avg_accuracy = accuracy_metric.compute().item()
  avg_f1score = f1score_metric.compute().item()

  return avg_loss, avg_accuracy, avg_f1score

def validate_epoch(
  classifier,
  dataloader,
  criterion,
  epoch,
  accuracy_metric,
  f1score_metric  
):
  """ Run validation loop """      
    
  # - Retrieve model
  model= classifier.model
    
  # - Init metrics
  model.eval()
  loss_meter = AverageMeter()
  accuracy_metric.reset()
  f1score_metric.reset()
  progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1} [Validation]", leave=False)
    
  with torch.no_grad():
    for X_batch, y_batch in progress_bar:
      X_batch, y_batch = X_batch.to(classifier.device), y_batch.to(classifier.device)

      # - Compute prediction and loss   
      outputs = model(X_batch)
      loss = criterion(outputs, y_batch)

      # - Update loss and accuracy
      loss_meter.update(loss.item(), X_batch.size(0))
      preds = outputs.argmax(dim=1)
      accuracy_metric.update(preds, y_batch)
      f1score_metric.update(preds, y_batch)

      # - Update progress bar
      progress_bar.set_postfix(
        loss=loss_meter.avg
      )
        
  avg_loss = loss_meter.avg
  avg_accuracy = accuracy_metric.compute().item()
  avg_f1score = f1score_metric.compute().item()
    
  return avg_loss, avg_accuracy, avg_f1score

Let's set some train parameters:

- learning rate
- number of training epochs
- batch size (was defined in data loaders)
- loss function

You can change them as you prefer.
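
To get a feeling for the loss values printed during training, the cross-entropy of a single sample can be worked out by hand. The sketch below is a plain-Python illustration of the formula, not the PyTorch implementation (`torch.nn.CrossEntropyLoss` takes raw logits and, by default, averages over the batch); the `cross_entropy` helper and the logits are hypothetical.

```python
import math

def cross_entropy(logits, target):
    """Cross-entropy of one sample: -log(softmax(logits)[target])."""
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - shift) for z in logits]
    prob_target = exps[target] / sum(exps)
    return -math.log(prob_target)

# Uniform logits over 4 classes give loss = log(4), the "random guess" level
loss_uniform = cross_entropy([0.0, 0.0, 0.0, 0.0], target=2)
print(loss_uniform)  # ~1.3863

# A confident correct prediction gives a loss close to 0
loss_confident = cross_entropy([0.0, 0.0, 8.0, 0.0], target=2)
print(loss_confident)
```

A training loss that stays near log(4) ≈ 1.39 therefore suggests the model is not doing better than guessing among the 4 classes.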

In [None]:
num_epochs= 10
lr= 1e-4
loss_fn= torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(classifier_resnet.model.parameters(), lr=lr)

We now train the classifier using methods and parameters defined above.

In [None]:
# - Run train
torch.manual_seed(1)
outfile_model= os.path.join(rundir, "resnet_model.pth")

metric_hist_resnet= run_train(
  classifier=classifier_resnet,
  train_dl=dataloader_imagenet_train,
  val_dl=dataloader_imagenet_val,  
  num_epochs=num_epochs, 
  loss_fn=loss_fn,
  optimizer=optimizer,  
  lr=lr,
  outfile_model=outfile_model
)

Let's plot some metrics after the training run has completed.

In [None]:
def draw_metric_hist(metric_hist):
  
  # - Draw losses
  x_arr = np.arange(len(metric_hist["loss_train"])) + 1
  fig = plt.figure(figsize=(12, 4))
  ax = fig.add_subplot(1, 2, 1)
  ax.plot(x_arr, metric_hist["loss_train"], '-o', label='Train loss')
  ax.plot(x_arr, metric_hist["loss_val"], '--<', label='Validation loss')
    
  ax.legend(fontsize=8)
  ax = fig.add_subplot(1, 2, 2)
  
  # - Draw accuracy/f1score
  ax.set_ylim(0,1)  
  ax.plot(x_arr, metric_hist["acc_train"], '-o', label='Train acc.')
  ax.plot(x_arr, metric_hist["acc_val"], '--<', label='Validation acc.')
  ax.plot(x_arr, metric_hist["f1score_train"], '-*', label='Train F1-score')
  ax.plot(x_arr, metric_hist["f1score_val"], '-->', label='Validation F1-score')  
  ax.legend(fontsize=8)
  ax.set_xlabel('Epoch', size=15)
  ax.set_ylabel('Accuracy/F1-score', size=15)

  plt.show()  

# - Print & plot metrics
print("== metrics ==")
print(metric_hist_resnet)
draw_metric_hist(metric_hist_resnet)

### Evaluate model
We evaluate the trained classifier on test data by:

- computing classification metrics (accuracy, F1-score, confusion matrix)
- plotting feature maps
- plotting activation heatmaps
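
The classification metrics can be illustrated on a toy example in plain Python. This is a sketch of the definitions, assuming rows of the confusion matrix are true classes and columns are predicted classes, with made-up targets and predictions; it does not reproduce the edge-case handling of `torchmetrics`, and the helper functions are hypothetical.

```python
def confusion_matrix(targets, preds, num_classes):
    """Counts with rows = true class, columns = predicted class."""
    cm = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(targets, preds):
        cm[t][p] += 1
    return cm

def macro_f1(cm):
    """Unweighted mean of per-class F1 scores."""
    n = len(cm)
    scores = []
    for k in range(n):
        tp = cm[k][k]
        fn = sum(cm[k]) - tp
        fp = sum(cm[i][k] for i in range(n)) - tp
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom > 0 else 0.0)
    return sum(scores) / n

# Hypothetical targets/predictions for 8 samples and 4 classes
targets = [0, 0, 1, 1, 2, 2, 3, 3]
preds   = [0, 0, 1, 2, 2, 1, 3, 3]

cm = confusion_matrix(targets, preds, num_classes=4)
accuracy = sum(cm[k][k] for k in range(4)) / len(targets)
cm_row_norm = [[c / sum(row) for c in row] for row in cm]  # rows sum to 1

print(accuracy)        # 0.75
print(macro_f1(cm))    # 0.75
print(cm_row_norm[1])  # [0.0, 0.5, 0.5, 0.0]
```

Normalizing each row by the number of true samples, as done with `normalize="true"`, makes the diagonal read as per-class recall.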

In [None]:
def run_test(
  classifier,
  dataloader, 
  modelfile="", 
  weightfile=""
):
  """ Compute model performances on test data """

  # - Load model from file?
  if modelfile=="":
    model= classifier.model  
  else:
    print("Loading model from file %s ..." % (modelfile))    
    model= torch.load(modelfile, weights_only=False)
      
  # - Check for model/dataloader
  if model is None:
    print("ERROR: No model present, cannot run prediction on test set!")
    return None
    
  # - Load model weights
  if weightfile!="":
    print("Loading model weights from file %s ..." % (weightfile))
    model.load_state_dict(torch.load(weightfile, weights_only=True))
      
  model.to(classifier.device).eval()
    
  # - Init metrics    
  accuracy_metric= torchmetrics.Accuracy(
    task="multiclass", 
    num_classes=classifier.num_classes
  ).to(classifier.device)
  accuracy_metric.reset()

  f1score_metric= torchmetrics.F1Score(
    task="multiclass", 
    num_classes=classifier.num_classes, 
    average="macro"
  ).to(classifier.device)
  f1score_metric.reset()

  confusion_matrix_metric= torchmetrics.ConfusionMatrix(
    task="multiclass", 
    num_classes=classifier.num_classes, 
    normalize="true"
  ).to(classifier.device)
  confusion_matrix_metric.reset()
    
  progress_bar = tqdm(dataloader, desc="[Test]", leave=False)
    
  with torch.no_grad():
    for X_batch, y_batch in progress_bar:
      X_batch, y_batch = X_batch.to(classifier.device), y_batch.to(classifier.device)

      # - Compute predictions
      outputs = model(X_batch)
        
      # - Update metrics
      preds = outputs.argmax(dim=1)
      accuracy_metric.update(preds, y_batch)
      f1score_metric.update(preds, y_batch)
      confusion_matrix_metric.update(preds, y_batch)
        
  avg_accuracy = accuracy_metric.compute().item()
  avg_f1score = f1score_metric.compute().item()
  confusion_matrix= confusion_matrix_metric.compute().numpy()
    
  metrics= {
    "acc": avg_accuracy,
    "avg_f1score": avg_f1score,
    "cm": confusion_matrix,
    "cm_metric": confusion_matrix_metric 
  }
    
  return metrics  

Run evaluation and compute metrics.

In [None]:
weightfile= os.path.join(rundir, "resnet_model_weights.pth")

metrics_resnet_test= run_test(
  classifier_resnet,  
  dataloader=dataloader_imagenet_test,
  weightfile=weightfile 
)

#### Visualizing metrics

In [None]:
# - Print metrics
print("== metrics (TEST) ==")
print(metrics_resnet_test)

# - Draw confusion matrix
fig_, ax_ = metrics_resnet_test["cm_metric"].plot()

#### Visualizing feature maps
Feature maps (also called activation maps) are the outputs obtained by applying the convolutional filters of a layer to its input, which is either the input image or the feature maps of the previous layer.

Visualizing the feature maps for a specific input image helps to understand which features of the input are detected or preserved at each stage of the network. Feature maps close to the input are expected to capture small, fine-grained details, whereas feature maps close to the model output capture more general features.
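
As a toy illustration of what a single feature map is, the NumPy sketch below slides one hand-written 3x3 filter over a small synthetic image. The image, the vertical-edge filter and the helper `conv2d_valid` are assumptions chosen for illustration. A trained CNN learns its filters from data, and a torch `Conv2d` layer additionally sums over the input channels and adds a bias.

```python
import numpy as np

def conv2d_valid(image, kernel):
    """Cross-correlate a 2D image with a 2D kernel (no padding, stride 1)."""
    kh, kw = kernel.shape
    out_h = image.shape[0] - kh + 1
    out_w = image.shape[1] - kw + 1
    out = np.zeros((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            # Response of the filter on the patch whose top-left corner is (i, j)
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out

# Synthetic 6x6 image: dark left half, bright right half
image = np.zeros((6, 6))
image[:, 3:] = 1.0

# Hand-written filter responding to dark-to-bright vertical edges
kernel = np.array([[-1.0, 0.0, 1.0],
                   [-1.0, 0.0, 1.0],
                   [-1.0, 0.0, 1.0]])

feature_map = conv2d_valid(image, kernel)
print(feature_map.shape)
print(feature_map[0])  # non-zero only where the patch straddles the edge
```

The resulting map is zero over the flat regions and responds only where the filter overlaps the edge, which is the sense in which a feature map "detects" a feature of its input.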

We define below a method to extract feature maps from a model.

In [None]:
def extract_feature_maps(
  classifier,
  image,  
  modelfile="", 
  weightfile="",
  return_avg_maps=False
):
  """ Extract a list of feature maps for a model and input image """
    
  # - Load image on device
  #   NB: Transforms are expected to be already applied
  #   NB2: On failure return a (None, None) pair, as callers unpack two values
  if image is None:
    print("ERROR: Input image is None!")
    return None, None
  image = image.to(classifier.device)
    
  # - Load model from file?
  if modelfile=="":
    model= classifier.model  
  else:
    print("Loading model from file %s ..." % (modelfile))    
    model= torch.load(modelfile, weights_only=False)
      
  # - Check for model
  if model is None:
    print("ERROR: No model present, cannot extract feature maps!")
    return None, None
    
  # - Load model weights
  if weightfile!="":
    print("Loading model weights from file %s ..." % (weightfile))
    model.load_state_dict(torch.load(weightfile, weights_only=True))
    
  model.to(classifier.device).eval()
    
  # - Extract conv layers
  print("Extracting all model conv layers ...")
  conv_layers= []
  conv_layer_names= []
  for name, layer in model.named_modules():
    if isinstance(layer, torch.nn.Conv2d) and "downsample" not in name:
      conv_layers.append(layer)
      conv_layer_names.append(name)
     
  # - Define activations hooks
  activations = {}
  def get_activation(name):
    def hook(model, input, output):
      activations[name] = output.detach()
    return hook
    
  hook_handles= []
  for name, layer in zip(conv_layer_names, conv_layers):
    hook_handles.append(layer.register_forward_hook(get_activation(name)))

  # - Forward pass with hooks (no gradients needed)
  with torch.no_grad():
    output = model(image)

  # - Remove hooks so that they do not accumulate across calls
  for handle in hook_handles:
    handle.remove()
    
  # - Get activations
  feature_maps= []
  layer_names= []
  print("--> fm.shape")    
  for layer_name in conv_layer_names:
    layer_output= activations[layer_name].squeeze(0)
    print(layer_output.shape)
        
    if return_avg_maps:
      gray_scale = torch.sum(layer_output,0)
      gray_scale = gray_scale / layer_output.shape[0]
      feature_maps.append(gray_scale.data.cpu().numpy())
    else:
      feature_maps.append(layer_output.data.cpu().numpy())
    
    layer_names.append(layer_name)
     
  return feature_maps, layer_names

We define below a method to draw feature maps from a model.

In [None]:
def draw_feature_maps(
  classifier,
  image,
  modelfile="", 
  weightfile="",
  images_per_row= 4
):
  """ Extract and plot feature maps for an input image """      

  # - Retrieve feature maps
  feature_maps, layer_names= extract_feature_maps(  
    classifier,
    image,  
    modelfile=modelfile, 
    weightfile=weightfile  
  )
  if feature_maps is None:
    print("ERROR: Failed to compute feature maps!")
    return -1
  if not feature_maps:
    print("ERROR: Empty list of feature maps!")
    return -1
    
  # - Draw feature maps
  for layer_name, feature_map in zip(layer_names, feature_maps):
    n_features = feature_map.shape[0]
    size = feature_map.shape[1]
    # - A layer may have fewer channels than images_per_row
    n_per_row = min(images_per_row, n_features)
    n_rows = n_features // n_per_row
    display_grid = np.zeros((size * n_rows, n_per_row * size))
        
    for row in range(n_rows):
      for col in range(n_per_row):
        index= row * n_per_row + col
        # - Standardize the channel and rescale it to [0,255] for display
        channel_image = feature_map[index, :, :].copy()
        channel_image -= channel_image.mean()
        channel_image /= (channel_image.std() + 1.e-5) # avoid division by zero for constant maps
        channel_image *= 64
        channel_image += 128
        channel_image = np.clip(channel_image, 0, 255).astype('uint8')
        display_grid[row * size : (row + 1) * size, col * size : (col + 1) * size] = channel_image

    scale = 1. / size
    plt.figure(figsize=(scale * display_grid.shape[1], scale * display_grid.shape[0]))
    plt.title(layer_name)
    plt.grid(False)
    plt.imshow(display_grid, aspect='auto', cmap='viridis')
    
  return 0

Let's plot the feature maps for a sample image.

In [None]:
# - Take a sample image from the test dataset
data_index= 0 # take the first
image, target= dataset_imagenet_test[data_index]
label= dataset_imagenet_test.target2label[target]
image= image.unsqueeze(0)

print("image")
print(type(image))
print(image.shape)

# - Draw feature maps
weightfile= os.path.join(rundir, "resnet_model_weights.pth")

draw_feature_maps(
  classifier_resnet,  
  image,
  weightfile=weightfile,
  images_per_row=16  
)

#### Visualizing heatmaps of class activation
Visualizing class activation maps (CAM) is useful to understand which parts of a given image led the model to its final classification decision. A class activation heatmap is a 2D grid of scores associated with a specific output class, computed for every location of an input image, indicating how important each location is with respect to the class under consideration.

The specific implementation we are going to use is the one described in “Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization”, as implemented in the `grad-cam` Python package installed at the beginning of this tutorial. It takes the output feature map of a convolution layer, given an input image, and weighs every channel in that feature map by the gradient of the class score with respect to the channel. Intuitively, you are weighting a spatial map of "how intensely the input image activates different channels" by "how important each channel is with regard to the class", resulting in a spatial map of "how intensely the input image activates the class".
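
The channel weighting described above can be sketched in a few lines of NumPy. This is only an illustration of the idea, not the code of the `grad-cam` package: the helper name `gradcam_from_arrays` and the synthetic activations and gradients are assumptions, whereas in practice both arrays come from a forward and a backward pass through the trained model.

```python
import numpy as np

def gradcam_from_arrays(activations, gradients):
    """Combine conv activations and class-score gradients into a heatmap.

    activations, gradients: arrays of shape (channels, height, width).
    """
    # Channel importance: gradient averaged over the spatial dimensions
    weights = gradients.mean(axis=(1, 2))
    # Weighted sum of the channels, keeping only positive evidence (ReLU)
    cam = np.maximum((weights[:, None, None] * activations).sum(axis=0), 0.0)
    # Rescale to [0, 1] for display (skipped if the map is all zeros)
    if cam.max() > 0:
        cam = cam / cam.max()
    return cam

# Synthetic stand-ins for the last conv layer outputs and their gradients
rng = np.random.default_rng(0)
activations = rng.random((8, 7, 7))
gradients = rng.normal(size=(8, 7, 7))

heatmap = gradcam_from_arrays(activations, gradients)
print(heatmap.shape, heatmap.min(), heatmap.max())
```

The heatmap has the spatial resolution of the chosen layer, so it has to be upsampled to the input image size before being overlaid on the image.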

Let's plot the heatmaps relative to the last conv layer with respect to the predicted class for a sample of test images.

In [None]:
def plot_sample_predictions(
  classifier, 
  dataset,
  dataset_gradcam=None,
  modelfile="", 
  weightfile="",
  plot_gradcam=True, 
  gradcam_method="gradcam",
  layer_names=[],   
  aug_smooth=False,
  eigen_smooth=False,
  gradcam_alpha=0.5,
  apply_heatmap_thr=False,
  heatmap_thr=0.7,
  plot_gradcam_only=False
):
  """ Plot gradCAM on some images """
    
  # - Load model from file?
  if modelfile=="":
    model= classifier.model  
  else:
    print("Loading model from file %s ..." % (modelfile))    
    model= torch.load(modelfile, weights_only=False)
      
  # - Check for model/dataloader
  if model is None:
    print("ERROR: No model present, cannot run prediction on test set!")
    return None
    
  # - Load model weights
  if weightfile!="":
    print("Loading model weights from file %s ..." % (weightfile))
    model.load_state_dict(torch.load(weightfile, weights_only=True))
    
  model.to(classifier.device).eval()
    
  # - Set dataset to be used for gradcam
  if dataset_gradcam is None:
    dataset_gradcam= dataset
    
  # - Init gradCAM 
  methods = {
    "gradcam": GradCAM,
    "hirescam": HiResCAM,
    "scorecam": ScoreCAM,
    "gradcam++": GradCAMPlusPlus,
    "ablationcam": AblationCAM,
    "xgradcam": XGradCAM,
    "eigencam": EigenCAM,
    "eigengradcam": EigenGradCAM,
    "layercam": LayerCAM,
    "fullgrad": FullGrad,
    #"fem": FEM,
    "gradcamelementwise": GradCAMElementWise,
    "kpcacam": KPCA_CAM,
    #"shapleycam": ShapleyCAM
  }
  cam_algorithm = methods[gradcam_method]
  
  # - Find RELU layers in model  
  print("Printing all relu layers in model ...")
  #print(model)
  print(find_layer_types_recursive(model, [torch.nn.ReLU]))
  target_layers= []
  for item in layer_names:
    layer= model.get_submodule(item) # public API, also supports nested names (e.g. "layer4.1")
    print("layer")
    print(layer)
    target_layers.append(layer)
    
  # - Set target
  #   If targets is None, the highest scoring category (for every member in the batch) will be used
  targets = None
    
  # - Plot images
  fig = plt.figure(figsize=(15, 15))
      
  for i, ((input_img, target), (input_img_gradcam, target_gradcam)) in islice(enumerate(zip(dataset,dataset_gradcam)), 16):    
    with torch.enable_grad():
      # - Compute model prediction
      label= dataset.target2label[target] 
      print("Computing model prediction for image (label=%s, target=%d) ..." % (label, target))
      # - Move inputs to the same device as the model
      input_tensor= input_img.unsqueeze(0).to(classifier.device)
      input_tensor_gradcam= input_img_gradcam.unsqueeze(0).to(classifier.device)
          
      pred = model(input_tensor) # logits  
      soft_outputs = torch.nn.functional.softmax(pred, dim=1) # pass through softmax
      prob_pred, target_pred = soft_outputs.topk(1, dim = 1) # select top probability as prediction
      prob_pred= prob_pred.item()
      target_pred= target_pred.item()  
      label_pred= dataset.target2label[target_pred]
        
      # - Create image for plot
      rgb_img= input_img_gradcam.permute(1, 2, 0).numpy()                                                       
      grayscale_img = cv2.cvtColor(rgb_img, cv2.COLOR_RGB2GRAY)
        
      # - Compute cam
      if plot_gradcam:
        with cam_algorithm(model=model, target_layers=target_layers) as cam:
          grayscale_cam = cam(
            input_tensor=input_tensor_gradcam, 
            targets=targets,
            aug_smooth=aug_smooth,
            eigen_smooth=eigen_smooth
          )
          grayscale_cam = grayscale_cam[0, :]
          # - With use_rgb=True the returned image is already RGB, as expected by matplotlib
          cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True, image_weight=gradcam_alpha)
        
          # - Create heatmap
          colormap= cv2.COLORMAP_JET
          mask= np.copy(grayscale_cam)
          heatmap= cv2.applyColorMap(np.uint8(255 * mask), colormap)
          heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
          heatmap = np.float32(heatmap) / 255
            
          alphas= np.ones(grayscale_img.shape)*gradcam_alpha
          if apply_heatmap_thr:
            alphas[grayscale_cam<heatmap_thr]= 0 # set invisible    
    
      # - Plot image
      ax = fig.add_subplot(4, 4, i+1)
      ax.set_xticks([]); ax.set_yticks([])
      if not plot_gradcam_only:
        ax.imshow(grayscale_img, cmap="gray")
            
      if plot_gradcam:
        if plot_gradcam_only:
          ax.imshow(cam_image)  
        else:
          ax.imshow(heatmap, alpha=alphas)
            
      ax.set_title(f'{label} \n pred: {label_pred} (p={prob_pred:.2f})', size=12)
  
  plt.show()

In [None]:
# - Run gradCAM on some test data
target_layers= ["layer4"]
weightfile= os.path.join(rundir, "resnet_model_weights.pth")

plot_sample_predictions(
  classifier_resnet,    
  dataset=dataset_imagenet_test,
  dataset_gradcam=dataset_test,
  weightfile=weightfile, 
  plot_gradcam=True,
  layer_names=target_layers,
  gradcam_method="gradcam",
  aug_smooth=False,
  eigen_smooth=False,
  gradcam_alpha=0.3,
  apply_heatmap_thr=True,
  heatmap_thr=0.5 
)

## Custom classifier
Let's implement a class that uses the torch `Sequential` class to define a custom network architecture.

In [None]:
class CustomClassifier():
  """ Build a custom CNN network """  

  def __init__(
    self,
    nn_arch: Optional[str] = "custom",
    num_classes: Optional[int] = 4,
    n_conv_layers: Optional[int] = 3,
    n_filters: Optional[Union[int, list]] = [8,16,32],
    kern_sizes: Optional[Union[int, list]] = [3,5,5],
    strides: Optional[Union[int, list]] = [1,1,1],
    add_maxpool: Optional[bool] = True,
    pool_sizes: Optional[Union[int, list]] = 2,
    add_batchnorm: Optional[bool] = True,
    n_dense_layers: Optional[int] = 1,
    dense_layer_sizes: Optional[Union[int, list]] = [64],
    add_dropout: Optional[bool] = True,
    dropout_prob: Optional[float] = 0.5
  ):
    self.model= None
    self.device= torch.device("cuda" if torch.cuda.is_available() else "cpu")
    self.nn_arch= nn_arch
    self.num_classes= num_classes
    self.n_conv_layers= n_conv_layers
    self.n_dense_layers= n_dense_layers
    self.add_maxpool= add_maxpool
    self.add_batchnorm= add_batchnorm
    self.add_dropout= add_dropout
    self.dropout_prob= dropout_prob
    
    # - Set number of conv filters per layer
    if isinstance(n_filters, list):
      if len(n_filters)!=self.n_conv_layers:
        raise Exception("n_filters list must have length equal to n_conv_layers!")  
      else:
        self.n_filters= n_filters
    else:
      self.n_filters= [n_filters]*self.n_conv_layers 
    
    # - Set conv filter kernel size per layer
    if isinstance(kern_sizes, list):
      if len(kern_sizes)!=self.n_conv_layers:
        raise Exception("kern_sizes list must have length equal to n_conv_layers!")  
      else:
        self.kern_sizes= kern_sizes
    else:
      self.kern_sizes= [kern_sizes]*self.n_conv_layers
    
    # - Set conv filter stride size per layer
    if isinstance(strides, list):
      if len(strides)!=self.n_conv_layers:
        raise Exception("strides list must have length equal to n_conv_layers!")  
      else:
        self.strides= strides
    else:
      self.strides= [strides]*self.n_conv_layers
    
    # - Set max pooling size per layer
    if isinstance(pool_sizes, list):
      if len(pool_sizes)!=self.n_conv_layers:
        raise Exception("pool_sizes list must have length equal to n_conv_layers!")  
      else:
        self.pool_sizes= pool_sizes
    else:
      self.pool_sizes= [pool_sizes]*self.n_conv_layers
    
    # - Set dense layer size per layer
    if isinstance(dense_layer_sizes, list):
      if len(dense_layer_sizes)!=self.n_dense_layers:
        raise Exception("dense_layer_sizes list must have length equal to n_dense_layers!")  
      else:
        self.dense_layer_sizes= dense_layer_sizes
    else:
      self.dense_layer_sizes= [dense_layer_sizes]*self.n_dense_layers
    
    # - Build network
    print("Building NN architecture ...")
    if self.__build_model()<0:
      print("ERROR: Failed to build nn!")
      raise Exception("Failed to build nn!")
    
    # - Move model to device
    print("Moving model to device %s ..." % (self.device))
    self.model.to(self.device)

  def __build_model(self):  
    """ Create network """
    
    # - Create model using nn.Sequential class
    self.model = torch.nn.Sequential()
    
    # - Add CNN layers
    for i in range(self.n_conv_layers):
      # - Add convolution layer  
      layer_name= 'conv' + str(i+1)
      self.model.add_module(
        layer_name,
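        torch.nn.LazyConv2d(
          out_channels=self.n_filters[i],
          kernel_size=self.kern_sizes[i],
          # - padding="same" is not supported by torch for strided convolutions
          padding="same" if self.strides[i]==1 else self.kern_sizes[i]//2,
          stride=self.strides[i]
        )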
      )
    
      # - Add activation
      layer_name= 'relu' + str(i+1)
      self.model.add_module(layer_name, torch.nn.ReLU())
      
      # - Add batch normalization?
      if self.add_batchnorm:
        layer_name= "bn" + str(i+1)
        self.model.add_module(layer_name, torch.nn.LazyBatchNorm2d())
    
      # - Add max pool layer?
      if self.add_maxpool:
        layer_name= 'pool' + str(i+1)
        self.model.add_module(layer_name, torch.nn.MaxPool2d(kernel_size=self.pool_sizes[i]))
        
    # - Flatten layer
    self.model.add_module('flatten', torch.nn.Flatten())
    
    # - Add dense layers
    for i in range(self.n_dense_layers):
      # - Add dense layer
      layer_name= "fc" + str(i+1)
      self.model.add_module(layer_name, torch.nn.LazyLinear(self.dense_layer_sizes[i]))
    
      # - Add activation
      layer_name= "relu_fc" + str(i+1)  
      self.model.add_module(layer_name, torch.nn.ReLU())
    
      # - Add dropout?
      if self.add_dropout:
        layer_name= "dropout" + str(i+1)  
        self.model.add_module(layer_name, torch.nn.Dropout(p=self.dropout_prob))
    
    # - Add dropout if no dense layer specified?
    if self.n_dense_layers<=0 and self.add_dropout:
      self.model.add_module("dropout", torch.nn.Dropout(p=self.dropout_prob))  
    
    # - Add output layer
    self.model.add_module("output", torch.nn.LazyLinear(self.num_classes))
    
    return 0

It's time for you to create your custom model.

For example, the following code creates an instance of the custom classifier with 3 conv layers and a classification head with 1 hidden dense layer of 64 neurons:

```
classifier= CustomClassifier(
  nn_arch=nn_arch,
  n_conv_layers= 3,
  n_filters= [8,16,32],
  kern_sizes= [3,5,7],
  strides= [1,1,1],
  add_maxpool= True,
  pool_sizes= 2,
  add_batchnorm= True,
  n_dense_layers= 1,
  dense_layer_sizes= [64],
  add_dropout= True,
  dropout_prob = 0.5
)
summary(classifier.model, input_shape)
```

Change layer configuration (filters, stride, add/remove layer) as you wish. 
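
The class uses lazy layers, so you do not have to compute the input size of the first dense layer yourself. It is still useful to see how your configuration affects that size, and the sketch below estimates it. The helper name `flattened_size` and the 64x64 input size are assumptions for illustration, and the formula only holds for square inputs and stride-1 convolutions with `padding="same"`, where only the max pooling layers change the spatial size.

```python
def flattened_size(input_size, n_filters, pool_sizes, add_maxpool=True):
    """Number of features entering the first dense layer.

    Assumes square inputs and stride-1 convolutions with padding="same",
    so that only the max pooling layers change the spatial size.
    """
    size = input_size
    for pool in pool_sizes:
        if add_maxpool:
            # MaxPool2d(kernel_size=pool) uses stride=pool and floors the result
            size = size // pool
    # Channels of the last conv layer times the remaining spatial area
    return n_filters[-1] * size * size

# Hypothetical 64x64 input with the example configuration above
print(flattened_size(64, n_filters=[8, 16, 32], pool_sizes=[2, 2, 2]))  # 2048
```

Removing the pooling layers, or adding filters to the last conv layer, quickly increases this number and therefore the number of weights in the first dense layer.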

In [None]:
##### ADD YOUR CODE HERE ######
# ...
# ...

### Train model
Train the custom classifier using the class `CustomClassifier` defined above and following the steps done with the ResNet classifier.

You can re-use the methods defined before:

- training method `run_train()`
- data loaders (`dataloader_train`, `dataloader_val`)
- plotting methods (e.g. `draw_metric_hist`)

In [None]:
##### ADD YOUR CODE HERE ######
# ...
# ...

Let's plot some metrics after the training run completed.

In [None]:
##### ADD YOUR CODE HERE ######
# ...
# ...

### Evaluate model
Let's load the saved trained model and run inference on test data.
You can re-use the methods defined before:

- evaluation `run_test()`
- plotting methods (`plot_sample_predictions()`, `draw_feature_maps()`)

In [None]:
##### ADD YOUR CODE HERE ######
# ...
# ...

#### Visualizing metrics

In [None]:
##### ADD YOUR CODE HERE ######
# ...
# ...

#### Visualizing feature maps

In [None]:
##### ADD YOUR CODE HERE ######
# ...
# ...

#### Visualizing heatmaps of class activation
Visualize heatmaps using the ReLU layers of your model, by passing the name of one of them to `plot_sample_predictions` through the `layer_names` argument, e.g. `target_layers= ["relu3"]` if you defined at least three convolutional layers.

In [None]:
##### ADD YOUR CODE HERE ######
# ...
# ...