# Investigate numerical instability

Recall the note on the stabilization heuristic in the algorithm presented in the paper "_LRP: An Overview_":

> The small additive term 1e-9 in the division simply enforces the behavior 0/0 = 0.


## Setup

1. The function `stabilize` in [lrp/zennit/core.py](../lrp/zennit/core.py) needs to be set to one of the implementations listed in the next section.
1. The image to load should be [castle.jpg](../data/castle.jpg).
1. The minimal working example should have the following rule-layer mapping:

```python
name_map = [
    (filter_by_layer_index_type(lambda n: n == 0), LrpZBoxRule, {'low': low, 'high': high}),
    (filter_by_layer_index_type(lambda n: 1 <= n <= 16), LrpGammaRule, {'gamma': 0}),
    (filter_by_layer_index_type(lambda n: 17 <= n <= 30), LrpEpsilonRule, {'epsilon': 0.25}),
    (filter_by_layer_index_type(lambda n: 31 <= n), LrpZeroRule, {}),
]
```

## Root cause of numerical instability

The issue depends on the implementation of the [`stabilize`](https://github.com/rodrigobdz/lrp/blob/c163c519599d0dd1320e8ed8cab5daac1978fe20/lrp/zennit/core.py#L16-L25) function and on the input: it does not appear if we use [castle2.jpg](../data/castle2.jpg) instead.
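
For context, the effect of the small additive term quoted from the paper can be sketched as follows. The tensors `numerator` and `denominator` are illustrative values chosen here (an assumption), not data from the LRP implementation.

```python
import torch

# Illustrative values (assumption): the first entry reproduces the 0/0 case.
numerator = torch.tensor([0.0, 1.0, -2.0])
denominator = torch.tensor([0.0, 4.0, 8.0])

# Plain division: 0/0 is undefined and yields NaN.
plain = numerator / denominator

# Adding the small term to the denominator turns 0/0 into 0/1e-9 = 0.
stabilized = numerator / (denominator + 1e-9)

print(plain)
print(stabilized)
```

The term only helps when the denominator is exactly zero or non-negative; a denominator equal to `-1e-9` would be shifted onto zero instead, which is the kind of input dependence examined below.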

The following implementations have been tested. The examples use these sample values:

```python
epsilon: float = 0.1
dividend: torch.Tensor = torch.Tensor([-epsilon, 5, -5, -10])
# tensor([ -0.1000,   5.0000,  -5.0000, -10.0000])
```

1. **Heuristic:** Add epsilon to the absolute value of the dividend, conserving its sign (zennit's approach):

    ```python
    dividend + ((dividend == 0.).to(dividend) + dividend.sign()) * epsilon
    ```

    Example:

    ```python
    dividend + ((dividend == 0.).to(dividend) + dividend.sign()) * epsilon
    # tensor([ -0.2000,   5.1000,  -5.1000, -10.1000])
    ```

2. **Heuristic:** Scale epsilon according to the dividend's magnitude using the quadratic mean (root mean square):

    ```python
    dividend + epsilon * (dividend**2).mean()**.5 + 1e-9
    ```

    Example:

    ```python
    dividend + epsilon * (dividend**2).mean()**.5 + 1e-9
    # tensor([ 0.5124,  5.6124, -4.3876, -9.3876])
    ```

3. **Vanilla:** Add epsilon to the dividend without heuristics. In the example, the entry `-epsilon` becomes exactly `0`, so the value is not stabilized:

    ```python
    dividend + epsilon
    ```

    Example:

    ```python
    dividend + epsilon
    # tensor([ 0.0000,  5.1000, -4.9000, -9.9000])
    ```
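
To compare the three variants side by side, the sketch below wraps each expression in a function and divides by the result. The function names and the all-ones `numerator` are hypothetical and chosen for illustration only; the name `dividend` follows the text above, even though the tensor is used as the denominator.

```python
import torch

def stabilize_sign(dividend, epsilon):
    # Variant 1: move every entry away from zero by epsilon, keeping its sign
    # (exact zeros are treated as positive).
    return dividend + ((dividend == 0.).to(dividend) + dividend.sign()) * epsilon

def stabilize_rms(dividend, epsilon):
    # Variant 2: shift by epsilon scaled with the quadratic mean of the tensor.
    return dividend + epsilon * (dividend ** 2).mean() ** .5 + 1e-9

def stabilize_vanilla(dividend, epsilon):
    # Variant 3: plain shift by epsilon.
    return dividend + epsilon

epsilon = 0.1
dividend = torch.tensor([-epsilon, 5., -5., -10.])
numerator = torch.ones_like(dividend)  # hypothetical numerator (assumption)

# Record for each variant whether every quotient is finite (no inf, no NaN).
all_finite = {}
for stabilize in (stabilize_sign, stabilize_rms, stabilize_vanilla):
    quotient = numerator / stabilize(dividend, epsilon)
    all_finite[stabilize.__name__] = bool(torch.isfinite(quotient).all())

print(all_finite)
```

With these sample values, only the vanilla variant yields a non-finite quotient, because its first entry is exactly zero.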

## Debugging

To detect `NaN` values in tensor computations (e.g., as a result of 0/0), enable anomaly detection for the autograd engine.
Any backward computation that generates a `NaN` value will then raise an error.

- Discussion: https://discuss.pytorch.org/t/finding-source-of-nan-in-forward-pass/51153/3
- Docs: https://pytorch.org/docs/stable/autograd.html#torch.autograd.set_detect_anomaly
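
As a minimal sketch of this behavior, the snippet below uses the context-manager form `torch.autograd.detect_anomaly()` on a toy 0/0 division. The tensors `a` and `b` are illustrative (an assumption) and unrelated to the LRP model.

```python
import torch

# Both operands are zero, so the backward pass of the division evaluates 0/0.
a = torch.tensor(0.0, requires_grad=True)
b = torch.tensor(0.0, requires_grad=True)

error_message = None
try:
    # Context-manager form: anomaly detection is active only inside the block.
    with torch.autograd.detect_anomaly():
        (a / b).backward()
except RuntimeError as error:
    error_message = str(error)

print(error_message)
```

Note that the check applies to the backward pass only; the `NaN` produced in the forward pass does not raise by itself.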

In [1]:
```python
import torch
torch.autograd.set_detect_anomaly(True)
```

```
<torch.autograd.anomaly_mode.set_detect_anomaly at 0x108a0a250>
```

## Reproducing LRP's "bug"

LRP setup (scaffolding):

In [2]:
```python
%load_ext autoreload
%autoreload 2

from typing import Callable, Dict, List, Tuple, Union

import numpy
import torch
import torchvision
from matplotlib import pyplot as plt
from torchvision import transforms

import lrp.plot
import lrp.rules as rules
from lrp import image, norm
from lrp.core import LRP
from lrp.filter import LayerFilter
from lrp.rules import LrpEpsilonRule, LrpGammaRule, LrpZBoxRule, LrpZeroRule
from lrp.zennit.types import AvgPool, Linear

# Normalization
norm_fn: Callable[[torch.Tensor], torch.Tensor] = norm.ImageNetNorm()

# Input data: ImageNet channel-wise mean and standard deviation
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

transform = transforms.Compose([
  transforms.ToTensor(),
  transforms.Resize((224, 224)),
  transforms.ConvertImageDtype(torch.float),
  transforms.Normalize(mean, std)
])

# Inverse of the normalization above
inv_norm = transforms.Normalize(
  mean=[-m/s for m, s in zip(mean, std)],
  std=[1/s for s in std]
)

# Image is in RGB mode with range [0,1]
img_path = '../data/castle.jpg'
img: numpy.ndarray = image.load_normalized_img(img_path)
X = transform(img)
# Simulate batch by adding a new dimension
X = torch.unsqueeze(X, 0)

# Model
model = torchvision.models.vgg16(pretrained=True)
model.eval()

# Low and high parameters for zB-rule
batch_size: int = 1
shape: Tuple[int, ...] = (batch_size, 3, 224, 224)

low: torch.Tensor = norm_fn(torch.zeros(*shape))
high: torch.Tensor = norm_fn(torch.ones(*shape))

# Init layer filter
vgg16_target_types: Tuple[type, ...] = (Linear, AvgPool)
filter_by_layer_index_type = LayerFilter(model)
filter_by_layer_index_type.set_target_types(vgg16_target_types)

# Type annotation only; the mapping itself is assigned in the next cell
name_map: List[Tuple[List[str], rules.LrpRule, Dict[str, Union[torch.Tensor, float]]]]
```

LRP rule-layer mapping and computation

In [3]:
```python
name_map = [
    (filter_by_layer_index_type(lambda n: n == 0), LrpZBoxRule, {'low': low, 'high': high}),
    (filter_by_layer_index_type(lambda n: 1 <= n <= 16), LrpGammaRule, {'gamma': 0}),
    (filter_by_layer_index_type(lambda n: 17 <= n <= 30), LrpEpsilonRule, {'epsilon': 0.25}),
    (filter_by_layer_index_type(lambda n: 31 <= n), LrpZeroRule, {}),
]

lrp_example = LRP(model)
lrp_example.convert_layers(name_map)

R: torch.Tensor = lrp_example.relevance(X)

# Heatmap of the relevance scores summed over the RGB channels
fig, ax = plt.subplots()
img: numpy.ndarray = image.load_normalized_img(img_path)
lrp.plot.heatmap(R[0].sum(dim=0).detach().numpy(), width=2, height=2, show_plot=False, fig=ax)

# Same preprocessing as before, but without normalization, for display
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
    transforms.ConvertImageDtype(torch.float),
])

# Overlay the input image on the heatmap
ax.imshow(transform(img).numpy().transpose(1,2,0), alpha=0.2)
# fig.savefig('/Users/rodrigobermudezschettino/Downloads/castle_lrp.png', dpi=150)
```

```
  File "/usr/local/Cellar/python@3.9/3.9.10/Frameworks/Python.framework/Versions/3.9/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/local/Cellar/python@3.9/3.9.10/Frameworks/Python.framework/Versions/3.9/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/Users/rodrigobermudezschettino/Documents/personal/unterlagen/bildung/uni/master/masterarbeit/code/lrp/venv/lib/python3.9/site-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "/Users/rodrigobermudezschettino/Documents/personal/unterlagen/bildung/uni/master/masterarbeit/code/lrp/venv/lib/python3.9/site-packages/traitlets/config/application.py", line 846, in launch_instance
    app.start()
  File "/Users/rodrigobermudezschettino/Documents/personal/unterlagen/bildung/uni/master/masterarbeit/code/lrp/venv/lib/python3.9/site-packages/ipykernel/kernelapp.py", line 677, in start
    self.io_loop.sta

RuntimeError: Function 'MulBackward0' returned nan values in its 0th output.
```
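
One way a multiplication can return `NaN` in its backward pass is when an infinite upstream gradient meets a zero factor (`inf * 0`). The sketch below reproduces this on toy tensors with an unstabilized division followed by a multiplication. It is an illustration under that assumption; whether this is the exact path taken inside the LRP code is not established here, and the names `weight` and `zero` are hypothetical.

```python
import torch

weight = torch.tensor(2.0, requires_grad=True)
zero = torch.tensor(0.0)

error_message = None
try:
    with torch.autograd.detect_anomaly():
        product = weight * zero      # forward: 0
        quotient = product / zero    # forward: 0/0 = NaN
        # Backward: the division passes on 1/0 = inf, and the multiplication
        # then computes inf * 0, which is NaN.
        quotient.backward()
except RuntimeError as error:
    error_message = str(error)

print(error_message)
```

The division itself passes the check because `inf` is not `NaN`; the error is raised one node later, at the multiplication.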

## Inspect results

Try to find the correlation between the patches in the heatmaps and the computed relevance scores.

### Display min. and max. values in relevance scores

Relevance scores indicating pixel-wise contributions:
```python
R[0].sum(dim=0)
```

> The relevance scores obtained in the pixel layer can now be summed over the RGB channels to indicate actual pixel-wise contributions.
> 
> Source: lrp-tutorial

In [None]:
```python
# Stabilizer under test: dividend + epsilon
a = R[0].sum(dim=0).detach()
print('Total', torch.numel(a))
print('negative', torch.numel(a[a < 0.]))
print('zero', torch.numel(a[a == 0.]))
print('positive', torch.numel(a[a > 0.]))

display(torch.aminmax(a))
```

### Map indices of relevance scores to pixel values in input image and its gradient

The input image is stored in `X` and the relevance scores are stored in `R`. We also look into the gradient with respect to the input image, `X.grad`.

In [None]:
```python
print(f'R {R.shape}')
print(f'X {X.shape}')
print(f'X.grad {X.grad.shape}')
print('\n')

# Position of the maximum relevance score
a = (R[0].detach()==torch.max(R[0].detach())).nonzero()
i, h, k = a[0, :].tolist()

print(f'R max indices: {i} {h} {k}')

print(f'X [0][{i} {h} {k}] {X[0][i, h, k].item()}')
print(f'X.grad [0][{i} {h} {k}] {X.grad[0][i, h, k].item()}')
print(f'R [0][{i} {h} {k}] {R[0].detach()[i, h, k].item()}')

print('\n')

# Position of the minimum relevance score
a = (R[0].detach()==torch.min(R[0].detach())).nonzero()
i, h, k = a[0, :].tolist()

print(f'R min indices: {i} {h} {k}')

print(f'X [0][{i} {h} {k}] {X[0][i, h, k].item()}')
print(f'X.grad [0][{i} {h} {k}] {X.grad[0][i, h, k].item()}')
print(f'R [0][{i} {h} {k}] {R[0].detach()[i, h, k].item()}')


print('\n')

# Position of the maximum of X.grad
a = (X.grad==torch.max(X.grad)).nonzero()
z, i, h, k = a[0, :].tolist()

print(f'X.grad max indices: {z} {i} {h} {k}')
print(f'X [{z} {i} {h} {k}]: {X[z, i, h, k].item()}')
print(f'X.grad [{z} {i} {h} {k}]: {X.grad[z, i, h, k].item()}')

print('\n')

# Position of the minimum of X.grad
a = (X.grad==torch.min(X.grad)).nonzero()
z, i, h, k = a[0, :].tolist()

print(f'X.grad min indices: {z} {i} {h} {k}')
print(f'X [{z} {i} {h} {k}]: {X[z, i, h, k].item()}')
print(f'X.grad [{z} {i} {h} {k}]: {X.grad[z, i, h, k].item()}')
```
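
The index lookups above repeat the same pattern, which could be factored into a small helper. The sketch below is one possible shape for it; `locate_extreme` and the toy tensor `scores` are hypothetical names introduced here, with `scores` standing in for `R[0]`.

```python
import torch

def locate_extreme(tensor, find_max=True):
    """Return (index_tuple, value) of the first max or min entry of `tensor`."""
    target = tensor.max() if find_max else tensor.min()
    # nonzero() lists the indices of all matching entries; take the first one.
    index = tuple((tensor == target).nonzero()[0].tolist())
    return index, tensor[index].item()

# Toy stand-in for R[0] with shape (channels, height, width).
scores = torch.tensor([[[0.5, -1.0], [2.0, 0.0]],
                       [[3.5, 0.25], [-4.0, 1.0]]])

max_index, max_value = locate_extreme(scores)
min_index, min_value = locate_extreme(scores, find_max=False)
print(f'max at {max_index}: {max_value}')
print(f'min at {min_index}: {min_value}')
```

If the tensor contains `NaN`, the equality comparison matches no entry and the helper fails with an `IndexError`, so it is best used after checking for `NaN` values.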


### Inspect X.grad minima and maxima

Locate the minimum of `X.grad` and show the value of `X` at that position, together with the minimum and maximum of the relevance scores.

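In [None]:
```python
# Location of the minimum of X.grad
print(torch.min(X.grad))
min_x_grad = (X.grad==torch.min(X.grad)).nonzero()
z, i, h, k = min_x_grad[0, :].tolist()

# Value of X and of X.grad at that location, and their product
a = X[z, i, h, k].item()
b = X.grad[z, i, h, k].item()
print(f'X at X.grad min {a} X.grad min {b}')
print(a*b)

# Range of the relevance scores per channel
r = R[0].detach()
print(r.shape)
display(torch.aminmax(r))

# Range of the relevance scores summed over the RGB channels
r = R[0].sum(dim=0).detach()
print(r.shape)
display(torch.aminmax(r))

print(f'X.grad min indices: {z} {i} {h} {k}')
print(f'X [{z} {i} {h} {k}]: {X[z, i, h, k].item()}')
```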

### Search for NaN values in relevance scores

- The numerical instability is reproducible with `lrp-tutorial`, where values turn to `NaN`.
  The results are saved to file and committed to git ([robust](https://git.tu-berlin.de/rodrigobdz/lrp-tutorial/-/blob/955e0d495ab022ca51174feb978e82073e6672af/tutorial.ipynb), [error reproduction](https://git.tu-berlin.de/rodrigobdz/lrp-tutorial/-/blob/36d6de8db6f440aa8eebed291e93bdfc2c83f74f/tutorial.ipynb)).
- This repo's LRP implementation is built on forward hooks. Inspect the patches in the heatmaps (values blown up in magnitude) and **check whether the relevance scores contain `NaN` values to verify that the behavior is consistent.**

In [None]:
```python
display(R.isnan().any())
display(X.isnan().any())
display(X.grad.isnan().any())
```
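
Beyond a yes/no answer, it can help to know how many `NaN` values there are and where they sit. A minimal sketch on a toy tensor follows; `relevance` is a hypothetical stand-in for `R`, not actual output.

```python
import torch

# Toy stand-in for a relevance tensor (assumption): two entries are NaN.
relevance = torch.tensor([[0.5, float('nan')],
                          [float('nan'), -1.0]])

nan_mask = relevance.isnan()
nan_count = int(nan_mask.sum())
nan_indices = nan_mask.nonzero().tolist()

print(f'{nan_count} NaN value(s) at indices {nan_indices}')
```

Applied to `R`, the indices could then be compared with the positions of the patches in the heatmap.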