# LRP ResNet Demonstration Notebook
In this notebook we apply our LRPEngine to the **ResNet image classification model**. As described in detail in one of the provided papers (see "literature/lrp_resnet"), implementing LRP (Layer-wise Relevance Propagation) for models with residual connections (such as ResNet) is non-trivial.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
from torchvision.transforms import v2
# import cv2
# import json

# import numpy as np
import matplotlib.pyplot as plt

from src.xai.lrp import LRPEngine

## Model Initialisation

In [None]:
# constant variables in capital letters:
OUTPUT_PATH = "../output"

Decide which picture to use:

In [None]:
img_path = LINEALS

### 1. Load Model

In [None]:
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)

### 2. Model Description
The backbone (ResNet 50) consists of
- 1 Conv - Batch normalisation - ReLU - MaxPool2d layer
- 4 layers with 3, 4, 6, 3 Bottleneck blocks, respectively. Each Bottleneck block consists of
- - 3 Conv - Batch normalisation blocks (1x1, 3x3 and 1x1)
- - 1 ReLU module, applied after the first two Conv - Batch normalisation blocks and again after the skip addition
- - and (in the first block of each layer) an additional "downsample" Conv - Batch normalisation block on the skip path
- - **Importantly, there is a skip connection, either an identity mapping or a linear projection (in downsample blocks), that adds the block input back in before the final ReLU of the block!**
- 1 average pooling layer
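
As a quick sanity check on the layout above, the block counts can be tallied in a few lines. This is a plain-Python sketch and not part of LRPEngine; the variable names are illustrative.

```python
# Bottleneck blocks per layer, as listed above.
blocks_per_layer = [3, 4, 6, 3]
convs_per_block = 3  # 1x1, 3x3 and 1x1 on the main path

total_blocks = sum(blocks_per_layer)

# Stem convolution plus the main-path convolutions of every block.
# The "downsample" projection convolutions are not counted here.
main_path_convs = 1 + total_blocks * convs_per_block

print(total_blocks)     # 16
print(main_path_convs)  # 49
```

Each of the 16 blocks contains exactly one skip addition, so a relevance pass through the backbone has to handle 16 addition nodes.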

The classifier consists of 
- 1 Flattening layer
- 2 Linear - ReLU blocks
- One Linear output layer

Why so?
- **Drop-out** layers are a regularisation tool to *prevent overfitting*
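- **Batch normalisation** layers stabilise training by normalising each layer's output with its mean and variance (batch statistics during training, running estimates at inference), which helps against vanishing gradients and has a mild regularising effect.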
- **Pooling layers** reduce the dimension of the feature space
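
To make the batch normalisation step concrete, here is a minimal plain-Python sketch of the normalisation applied to one feature over a small batch. The helper `batch_norm` is hypothetical and only mirrors the standard formula; it is not the PyTorch implementation.

```python
import math

def batch_norm(values, gamma=1.0, beta=0.0, eps=1e-5):
    """Normalise one feature across a batch, then scale and shift."""
    mean = sum(values) / len(values)
    # Biased (population) variance, as used for the normalisation itself.
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [gamma * (v - mean) / math.sqrt(var + eps) + beta for v in values]

normalised = batch_norm([2.0, 4.0, 6.0, 8.0])

# After normalisation the feature has (approximately) zero mean and unit variance.
out_mean = sum(normalised) / len(normalised)
out_var = sum((v - out_mean) ** 2 for v in normalised) / len(normalised)
```

At inference time the mean and variance are fixed running estimates, so the layer reduces to a per-channel affine transformation.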

This is the code snippet for the forward pass of a Bottleneck block (taken from torchvision):
        
        def forward(self, x: Tensor) -> Tensor:
            identity = x
    
            out = self.conv1(x)
            out = self.bn1(out)
            out = self.relu(out)
    
            out = self.conv2(out)
            out = self.bn2(out)
            out = self.relu(out)
    
            out = self.conv3(out)
            out = self.bn3(out)
    
            if self.downsample is not None:
                identity = self.downsample(x)
    
            out += identity
            out = self.relu(out)
    
            return out
            
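Such **Bottleneck blocks introduce additional complexity because of their skip connections**: the relevance arriving at the addition `out += identity` has to be divided between the main path and the skip path.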
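
One common way to treat such an addition node in LRP is to split the incoming relevance in proportion to each branch's contribution to the sum. The sketch below illustrates this idea in plain Python. Whether LRPEngine uses exactly this rule is not stated in this notebook, and `split_relevance_at_add` is a hypothetical helper.

```python
def split_relevance_at_add(branch_out, identity, relevance, eps=1e-9):
    """Split relevance element-wise at the node out = branch_out + identity."""
    r_branch, r_identity = [], []
    for a, b, r in zip(branch_out, identity, relevance):
        z = a + b
        # Stabiliser: keeps the division well-defined when z is close to zero.
        z += eps if z >= 0 else -eps
        r_branch.append(a / z * r)
        r_identity.append(b / z * r)
    return r_branch, r_identity

# Two output elements with relevance 0.8 and 0.4.
r_branch, r_identity = split_relevance_at_add(
    branch_out=[1.0, 3.0], identity=[3.0, 1.0], relevance=[0.8, 0.4]
)

# Relevance is (approximately) conserved across the split.
total = sum(r_branch) + sum(r_identity)
```

In this example the first element passes about 0.2 to the main path and 0.6 to the skip path, because the skip path contributed three quarters of the sum.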

![ResNet Architecture](../images/resnet50_architecture.png)

### 3. Demonstrate Model

In [None]:
medit_classifier.plot_prediction(model, img_path, PRED_PATH)

## LRP Implementation
All subsequent calculation steps are performed by the **LRPEngine class**.

In [None]:
_, img_tensor = medit_classifier.load_img_and_tensor(img_path)

In [None]:
lrp_engine = LRPEngine(model, img_tensor.unsqueeze(0), plot_output_dir=OUTPUT_PATH)

### Validate Model
Here, we again read out the prediction, verifying that our collected activations are accurate.

In [None]:
lrp_engine.layers

In [None]:
categories = ['Stage 1', 'Stage 2', 'Stage 3', 'Stage 4']
lrp_engine.print_results(categories, 4)

**Success!** The results match the model's prediction from above.

### Relevance Score Calculation
In the following, we run the LRP computation. Because implementing LRP on residual networks is non-trivial, we took inspiration from the implementation at https://github.com/keio-smilab24/LRP-for-ResNet/tree/main and adapted it in our LRPEngine class.

An important hyperparameter is **rel_filter_ratio**, which specifies the share of the largest relevance values that are kept during the LRP calculation; the remaining values are set to zero to avoid spurious contributions.
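
The filtering idea can be sketched in plain Python as follows. This is an illustration under assumptions, not the LRPEngine code: `filter_top_relevance` is a hypothetical helper, and ranking by raw value (rather than by magnitude) and rounding the kept count down are choices made only for this example.

```python
def filter_top_relevance(relevance, keep_ratio):
    """Keep the largest share of relevance values and zero out the rest."""
    if not 0.0 <= keep_ratio <= 1.0:
        raise ValueError("keep_ratio must lie in [0, 1]")
    k = int(len(relevance) * keep_ratio)  # number of values to keep (rounded down)
    # Indices of the k largest values.
    ranked = sorted(range(len(relevance)), key=lambda i: relevance[i], reverse=True)
    keep = set(ranked[:k])
    return [r if i in keep else 0.0 for i, r in enumerate(relevance)]

scores = [0.40, 0.05, 0.30, 0.25]
filtered = filter_top_relevance(scores, keep_ratio=0.75)
print(filtered)  # [0.4, 0.0, 0.3, 0.25]
```

With a ratio of 0.75, three of the four values survive and the smallest one is set to zero.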

In [None]:
lrp_engine.calculate_relevance_scores(rel_filter_ratio=.75)

### Result Visualisation
Here we can **read out the LRP values as a saliency map** from **different layers**. The plt_cmap parameter selects the matplotlib colormap used to draw the heatmap.

The default is *"afmhot"*, which yields a heatmap ranging from black to light fire colors. Other options can be found here: https://matplotlib.org/stable/users/explain/colors/colormaps.html. The cells below use *"bwr"*.
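
With a diverging colormap such as *"bwr"*, it usually helps to centre the color range at zero so that positive and negative relevance map to opposite colors. The sketch below computes such limits in plain Python. `symmetric_limits` is a hypothetical helper, and whether `plot_relevance_scores` does something similar internally is an assumption we have not verified.

```python
def symmetric_limits(relevance):
    """Return (vmin, vmax) centred at zero for a diverging colormap."""
    bound = max(abs(r) for r in relevance)
    # Fall back to 1.0 if all values are zero, to avoid a degenerate range.
    bound = bound or 1.0
    return -bound, bound

vmin, vmax = symmetric_limits([-0.2, 0.05, 0.6])
print(vmin, vmax)  # -0.6 0.6
```

These limits could then be passed as `vmin` and `vmax` to matplotlib's `imshow`, so that zero relevance lands on the white midpoint of *"bwr"*.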

In [None]:
layer_indices = range(20)
color_maps = ['bwr'] 

for layer_ind in layer_indices:
    for cmap in color_maps:
        lrp_engine.plot_relevance_scores(layer_ind, cmap)

In [None]:
lrp_engine.layers[20:]

In [None]:
layer_indices = range(20, len(lrp_engine.layers))
color_maps = ['bwr'] 

for layer_ind in layer_indices:
    for cmap in color_maps:
        lrp_engine.plot_relevance_scores(layer_ind, cmap)