## Standard Deviation

This notebook is a follow-up to [numerical-instability.ipynb](./numerical-instability.ipynb) and [num-instability-standard-deviation](./num-instability-standard-deviation.ipynb).

Requirements:
- **Disable the heuristic to reproduce the numerical instability, and calculate the standard deviation in this case.**


## Debugging

Detect `NaN` in Tensor computations (e.g., as a result of 0/0).

Enable anomaly detection for autograd engine.
Any backward computation that generates a `NaN` value will raise an error.

- Discussion: https://discuss.pytorch.org/t/finding-source-of-nan-in-forward-pass/51153/3
- Docs: https://pytorch.org/docs/stable/autograd.html#torch.autograd.set_detect_anomaly

In [1]:
import torch

# Turn on autograd anomaly detection for the rest of the session. This call is
# the cell's last expression, so the object it returns is what the notebook
# displays as the cell output.
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x10fc34c10>

## LRP setup

In [2]:
%load_ext autoreload
%autoreload 2

from typing import Callable, Dict, List, Tuple, Union

import numpy
import torch
import torchvision
from matplotlib import pyplot as plt

import lrp.plot
import lrp.rules as rules
from lrp import image, norm
from lrp.core import LRP
from lrp.filter import LayerFilter
from lrp.rules import LrpEpsilonRule, LrpGammaRule, LrpZBoxRule, LrpZeroRule
from lrp.zennit.types import AvgPool, Linear

# Normalization
# Tensor -> tensor callable; lrp_workflow applies it to all-zeros and all-ones
# batches to build the `low`/`high` arguments of LrpZBoxRule.
norm_fn: Callable[[torch.Tensor], torch.Tensor] = norm.ImageNetNorm()

import numpy  # repeats the `import numpy` earlier in this cell
# Input data
from torchvision import transforms

from lrp import image  # `image` is already imported above via `from lrp import image, norm`

# Per-channel mean and standard deviation handed to transforms.Normalize in
# lrp_workflow.
# NOTE(review): these look like the usual ImageNet statistics — confirm they
# agree with what norm.ImageNetNorm applies.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

In [3]:
def lrp_workflow(img_path: str) -> torch.Tensor:
  """Compute LRP relevance scores for a single image with a pretrained VGG16.

  The image is loaded with ``image.load_normalized_img``, converted to a
  tensor, resized to 224x224, cast to ``torch.float`` and normalized with the
  module-level ``mean``/``std``. A pretrained torchvision VGG16 in eval mode is
  wrapped in ``LRP``, its layers are converted according to a mapping from
  layer-index ranges to LRP rules, and the relevance of the one-image batch is
  returned. No plotting is done here; the caller receives the raw relevance.

  Args:
    img_path: Path of the image file to explain.

  Returns:
    The tensor produced by ``LRP.relevance`` for the batch of one image.
    NOTE(review): presumably shaped like the input, (1, 3, 224, 224) — confirm
    against ``lrp.core``.
  """
  img: numpy.ndarray = image.load_normalized_img(img_path)

  preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
    transforms.ConvertImageDtype(torch.float),
    transforms.Normalize(mean, std),
  ])

  # Simulate a batch by adding a leading dimension of size one.
  X: torch.Tensor = preprocess(img).unsqueeze(0)

  # Model
  # NOTE(review): newer torchvision releases deprecate `pretrained=` in favor
  # of `weights=`; kept as-is so the notebook still runs on the version it was
  # written against.
  model = torchvision.models.vgg16(pretrained=True)
  model.eval()

  # Low and high parameters for the zB-rule: norm_fn applied to an all-zeros
  # and an all-ones batch of the input shape.
  batch_size: int = 1
  shape: Tuple[int, ...] = (batch_size, 3, 224, 224)
  low: torch.Tensor = norm_fn(torch.zeros(*shape))
  high: torch.Tensor = norm_fn(torch.ones(*shape))

  # Init layer filter, restricted to the target layer types.
  vgg16_target_types: Tuple[type, ...] = (Linear, AvgPool)
  filter_by_layer_index_type = LayerFilter(model)
  filter_by_layer_index_type.set_target_types(vgg16_target_types)

  # Each entry: (layers selected by index predicate, LRP rule class, rule kwargs).
  name_map: List[Tuple[List[str], rules.LrpRule, Dict[str, Union[torch.Tensor, float]]]] = [
    (filter_by_layer_index_type(lambda n: n == 0), LrpZBoxRule, {'low': low, 'high': high}),
    (filter_by_layer_index_type(lambda n: 1 <= n <= 16), LrpGammaRule, {'gamma': 0}),
    (filter_by_layer_index_type(lambda n: 17 <= n <= 30), LrpEpsilonRule, {'epsilon': 0.25}),
    (filter_by_layer_index_type(lambda n: 31 <= n), LrpZeroRule, {}),
  ]

  lrp_example = LRP(model)
  lrp_example.convert_layers(name_map)

  R: torch.Tensor = lrp_example.relevance(X)
  return R

## Calculate standard deviation of all castle images in ILSVRC 2012's validation set

In [4]:
import os

# Directory that receives every artifact written by this cell.
ARTIFACT_DIR = './artifacts/num-instability-standard-deviation'
# Neither torch.save nor open(..., "w") creates missing parent directories, so
# make sure the target exists before writing (needed on a fresh checkout).
os.makedirs(ARTIFACT_DIR, exist_ok=True)

# Image is in RGB mode with range [0,1]
img_path = '../data/castle.jpg'
# Extract image name without extension (splitext only strips the final suffix,
# so file names containing additional dots are kept intact)
img_name = os.path.splitext(os.path.basename(img_path))[0]

R = lrp_workflow(img_path)
torch.save(R, f'{ARTIFACT_DIR}/relevance_scores_{img_name}.pt')

# Relevance of the first batch element summed over its leading dimension.
pooled_relevance_scores = R[0].sum(dim=0)
# NOTE(review): torch.save writes PyTorch's serialized format, not CSV text,
# despite the `.csv` extension. The file name is left unchanged so anything
# that already reads this path keeps working — consider renaming to `.pt`.
torch.save(pooled_relevance_scores, f'{ARTIFACT_DIR}/pooled_relevance_scores_{img_name}.csv')

# Compute each statistic once and reuse it for both the files and the printout.
std_relevance = torch.std(R)
std_pooled = torch.std(pooled_relevance_scores)

with open(f'{ARTIFACT_DIR}/std_relevance_scores_{img_name}.txt', "w") as f:
  f.write(f'Standard deviation of relevance scores of {img_name}: {std_relevance}')

with open(f'{ARTIFACT_DIR}/std_pooled_relevance_scores_{img_name}.txt', "w") as f:
  f.write(f'Standard deviation of pooled relevance scores of {img_name}: {std_pooled}')

print(f'{img_name} std: {std_relevance} std_pooled: {std_pooled}')

castle std: 0.44248613715171814 std_pooled: 1.2392504215240479
