## Standard Deviation

Apply the heuristic to many images to see if the scores are outside of the standard deviation and to ensure the heuristic works for any given image. This notebook is a follow-up to [numerical-instability.ipynb](./numerical-instability.ipynb).

## Debugging

Detect `NaN` in Tensor computations (e.g., as a result of 0/0).

Enable anomaly detection for the autograd engine.
Any backward computation that generates a `NaN` value will raise an error.

- Discussion: https://discuss.pytorch.org/t/finding-source-of-nan-in-forward-pass/51153/3
- Docs: https://pytorch.org/docs/stable/autograd.html#torch.autograd.set_detect_anomaly

In [1]:
import torch
# Turn on autograd anomaly detection for the rest of the notebook so that a
# backward computation producing NaN raises an error instead of passing silently
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x10b962730>

## LRP setup

In [2]:
%load_ext autoreload
%autoreload 2

from typing import Callable
import torchvision
import numpy
import torch
from lrp import norm, image

from lrp.rules import LrpZBoxRule, LrpGammaRule, LrpEpsilonRule, LrpZeroRule

from typing import List, Dict, Union, Tuple
from lrp.filter import LayerFilter
from lrp.zennit.types import AvgPool, Linear
import lrp.rules as rules

from lrp.core import LRP

import lrp.plot
from matplotlib import pyplot as plt

# Normalization
# norm_fn is applied in lrp_workflow to all-zero and all-one tensors to obtain
# the low/high bounds of the zB-rule.
# NOTE(review): presumably applies ImageNet mean/std normalization — confirm in lrp.norm
norm_fn: Callable[[torch.Tensor], torch.Tensor] = norm.ImageNetNorm()

# Input data
from torchvision import transforms

import numpy
from lrp import image

# Per-channel mean and standard deviation handed to transforms.Normalize in
# lrp_workflow (one entry per input channel)
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

In [3]:
def lrp_workflow(img_path: str) -> torch.Tensor:
  """Compute LRP relevance scores for a single image with a pretrained VGG16.

  The image is loaded from disk, resized to 224x224, normalized with the
  module-level ``mean``/``std`` and wrapped in a batch of size one. A freshly
  loaded VGG16 is then converted using a layer-index -> LRP-rule mapping
  (zB-rule, gamma, epsilon, zero) and the relevance of the input is computed.

  Args:
    img_path: Path of the image file to explain.

  Returns:
    The value returned by ``LRP.relevance`` for the single-image batch.
  """
  img: numpy.ndarray = image.load_normalized_img(img_path)

  transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
    transforms.ConvertImageDtype(torch.float),
    transforms.Normalize(mean, std),
  ])

  # Simulate a batch by adding a new leading dimension
  X = torch.unsqueeze(transform(img), 0)

  # Model: loaded anew on every call and switched to evaluation mode
  model = torchvision.models.vgg16(pretrained=True)
  model.eval()

  # Low and high parameters for zB-rule: the all-zeros and all-ones inputs
  # passed through the same normalization callable
  batch_size: int = 1
  shape: Tuple[int, ...] = (batch_size, 3, 224, 224)

  low: torch.Tensor = norm_fn(torch.zeros(*shape))
  high: torch.Tensor = norm_fn(torch.ones(*shape))

  # Init layer filter: only Linear and AvgPool layers are candidates; the
  # lambdas below then select them by layer index
  vgg16_target_types: Tuple[type, ...] = (Linear, AvgPool)
  filter_by_layer_index_type = LayerFilter(model)
  filter_by_layer_index_type.set_target_types(vgg16_target_types)

  # Each entry: (selected layer names, LRP rule, keyword arguments for the rule)
  name_map: List[Tuple[List[str], rules.LrpRule, Dict[str, Union[torch.Tensor, float]]]]

  name_map = [
      (filter_by_layer_index_type(lambda n: n == 0), LrpZBoxRule, {'low': low, 'high': high}),
      (filter_by_layer_index_type(lambda n: 1 <= n <= 16), LrpGammaRule, {'gamma': 0}),
      (filter_by_layer_index_type(lambda n: 17 <= n <= 30), LrpEpsilonRule, {'epsilon': 0.25}),
      (filter_by_layer_index_type(lambda n: 31 <= n), LrpZeroRule, {}),
  ]

  lrp_example = LRP(model)
  lrp_example.convert_layers(name_map)

  R: torch.Tensor = lrp_example.relevance(X)

  return R

## Calculate standard deviation of the relevance scores of all castle images in ILSVRC 2012's validation set

In [4]:
import os

# Input folder with the castle images and output folder for per-image artifacts
CASTLE_DIR = '../data/castle'
ARTIFACT_DIR = './artifacts/heuristic-standard-deviation'

# torch.save and open(..., "w") do not create missing parent directories,
# so make sure the output folder exists before the first write
os.makedirs(ARTIFACT_DIR, exist_ok=True)

# Get all the castle images in the data folder
castle_filenames = os.listdir(CASTLE_DIR)

for i, fname in enumerate(castle_filenames, start=1):
  img_path = f'{CASTLE_DIR}/{fname}'
  # Extract image name without extension; splitext only strips the final
  # suffix, so file names containing additional dots are kept intact
  img_name = os.path.splitext(fname)[0]

  R = lrp_workflow(img_path)

  # Sum over the first dimension of the single batch item
  # NOTE(review): presumably the channel axis, if R has the input's shape — confirm
  pooled_relevance_scores = R[0].sum(dim=0)

  # Compute each statistic once and reuse it for the files and the log line
  std_all = torch.std(R)
  std_pooled = torch.std(pooled_relevance_scores)

  torch.save(R, f'{ARTIFACT_DIR}/relevance_scores_{img_name}.pt')
  # NOTE(review): this file is written with torch.save, so despite the .csv
  # suffix it is not a text CSV and has to be read back with torch.load
  torch.save(pooled_relevance_scores, f'{ARTIFACT_DIR}/pooled_relevance_scores_{img_name}.csv')

  with open(f'{ARTIFACT_DIR}/std_relevance_scores_{img_name}.txt', "w") as f:
    f.write(f'Standard deviation of relevance scores of {fname}: {std_all}')

  with open(f'{ARTIFACT_DIR}/std_pooled_relevance_scores_{img_name}.txt', "w") as f:
    f.write(f'Standard deviation of pooled relevance scores of {fname}: {std_pooled}')

  print(f'{i} -> {img_name} std: {std_all} std_pooled: {std_pooled}')


1 -> ILSVRC2012_val_00041354 std: 0.000979261938482523 std_pooled: 0.0027281129732728004
2 -> ILSVRC2012_val_00003916 std: 0.001780993421562016 std_pooled: 0.004910173825919628
3 -> ILSVRC2012_val_00002487 std: 0.0011073173955082893 std_pooled: 0.0030731644947081804
4 -> ILSVRC2012_val_00019377 std: 0.0014270085375756025 std_pooled: 0.003981767687946558
5 -> ILSVRC2012_val_00039597 std: 0.001965473871678114 std_pooled: 0.0053999838419258595
6 -> ILSVRC2012_val_00005338 std: 0.0010363674955442548 std_pooled: 0.002890840405598283
7 -> ILSVRC2012_val_00029345 std: 0.001532531576231122 std_pooled: 0.004246695898473263
8 -> ILSVRC2012_val_00044012 std: 0.001191089628264308 std_pooled: 0.0033124773763120174
9 -> ILSVRC2012_val_00041139 std: 0.0013129771687090397 std_pooled: 0.0036860639229416847
10 -> ILSVRC2012_val_00004071 std: 0.0010650560725480318 std_pooled: 0.0029643001034855843
11 -> ILSVRC2012_val_00037341 std: 0.0016428256640210748 std_pooled: 0.004605838563293219
12 -> ILSVRC2012_v