In [None]:
from pathlib import Path
import sys
# Make the project's helper directories importable: ~/code/XAI/models itself
# plus its "relavance" and "mnist" sub-directories.  "relavance" (sic) is the
# spelling the package imports below use, so it must not be "corrected" here.
# NOTE(review): the imports below use the `models.` package prefix, which
# resolves from the parent of `models`, not from these directories — confirm
# that parent is importable (e.g. via the notebook's working directory).
for _subdir in ((), ("relavance",), ("mnist",)):
    _entry = str(Path.home().joinpath("code", "XAI", "models", *_subdir))
    # Guard keeps the cell idempotent: re-running it must not pile up
    # duplicate sys.path entries.
    if _entry not in sys.path:
        sys.path.append(_entry)

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.gridspec as gridspec

from models.deconv.deconvnet import deconvMNIST
from models.relavance.lrp import lrpMNIST
from models.reshape import Reshape
from models.relavance.layers import relConv2d, relLinear, relMaxPool2d, relReLU
from models.mnist.MnistModels import MNISTmodel
from models.mnist.MnistTrain import build_dataset

## LRP Method

$\begin{aligned} 
r_i^{(L)} &= \begin{cases}S_i(x) & \text{if unit }i \text{ is the target unit of interest}\\ 0 & \text{otherwise}\end{cases}  \\ r_i^{(l)} &= \sum_j \dfrac{z_{ji}}{\sum_{i'}(z_{ji'}+b_j)+ \epsilon \cdot sign(\sum_{i'}(z_{ji'}+b_j))}r_j^{(l+1)}\\ &\text{where } z_{ji} = w_{ji}^{(l+1, l)}x_{i}^{(l)}
\end{aligned} $


input feature $(1, \cdots, i, \cdots N)$, output feature $(1, \cdots, j, \cdots M)$

the weight of shape $(N, M)$ is stored transposed, as $(M, N)$, in PyTorch

$\begin{aligned} X^{(l+1)} &= \begin{bmatrix}  x_1 & \cdots  &x_M \end{bmatrix}^T\\
X^{(l)} &= \begin{bmatrix}  x_1 & \cdots  & x_N \end{bmatrix}^T\\
W^{(l+1, l)} &= \begin{bmatrix} 
w_{11} & \cdots & w_{1i} & \cdots & w_{1N} \\ 
\vdots & \ddots & \ddots & \ddots & \vdots \\ 
w_{j1} & \ddots & w_{ji} & \ddots & w_{jN} \\
\vdots & \ddots & \ddots & \ddots & \vdots \\
w_{M1} & \cdots & w_{Mi} & \cdots & w_{MN}
\end{bmatrix}
\end{aligned}$

Written out element-wise, the calculation is:

$\begin{aligned} 
R^{(l)} &= \begin{bmatrix} r_1 \\ \vdots \\ r_i \\ \vdots \\ r_N\end{bmatrix}^{(l)} = \begin{bmatrix} \sum_j^M a_{1j}r_j^{(l+1)} \\ \vdots \\ \sum_j^M a_{ij}r_j^{(l+1)} \\ \vdots \\ \sum_j^M a_{Nj}r_j^{(l+1)} \end{bmatrix} = \begin{bmatrix} a_{11}r_1^{(l+1)} + \cdots + a_{1M}r_M^{(l+1)} \\ \vdots \\ a_{i1}r_1^{(l+1)} + \cdots + a_{iM}r_M^{(l+1)} \\ \vdots \\ a_{N1}r_1^{(l+1)} + \cdots + a_{NM}r_M^{(l+1)} \end{bmatrix}
\\
Z^{(l, l+1)} &= \begin{bmatrix} 
z_{11} & \cdots & z_{1j} & \cdots & z_{1M} \\ 
\vdots & \ddots & \ddots & \ddots & \vdots \\ 
z_{i1} & \ddots & z_{ij} & \ddots & z_{iM} \\
\vdots & \ddots & \ddots & \ddots & \vdots \\
z_{N1} & \cdots & z_{Nj} & \cdots & z_{NM}
\end{bmatrix} = \begin{bmatrix} 
w_{11}x_1^{(l)} & \cdots & w_{1j}x_1^{(l)} & \cdots & w_{1M}x_1^{(l)} \\ 
\vdots & \ddots & \ddots & \ddots & \vdots \\ 
w_{i1}x_i^{(l)} & \ddots & w_{ij}x_i^{(l)} & \ddots & w_{iM}x_i^{(l)} \\
\vdots & \ddots & \ddots & \ddots & \vdots \\
w_{N1}x_N^{(l)} & \cdots & w_{Nj}x_N^{(l)} & \cdots & w_{NM}x_N^{(l)}
\end{bmatrix}
\end{aligned}$

### 1st way

To get $r_i^{(l)}$, where $z_{ji}^{(l+1)} = w_{ji}^{(l+1, l)} x_i^{(l)}$, there are 4 steps in a Linear layer:

$\begin{aligned} 
(1) & Z^{(l, l+1)} = W^{(l, l+1)} \times X^{(l)}\\
(2) & S^{(l+1)} = X^{(l+1)} + \epsilon \cdot sign(X^{(l+1)}) \\
(3) & A^{(l, l+1)} = \dfrac{Z^{(l, l+1)}}{S^{(l+1)}} \\
(4) & R^{(l)} = A^{(l, l+1)}R^{(l+1)}  \\
\end{aligned}$

### 2nd way

The same calculation in a different order, as introduced at http://heatmapping.org/tutorial/

$\begin{aligned} 
(1) & S^{(l+1)} = X^{(l+1)} + \epsilon \cdot sign(X^{(l+1)}) \\
(2) & E^{(l+1)} = \dfrac{R^{(l+1)}}{S^{(l+1)}} \\
(3) & C^{(l)} = W^{(l, l+1)} E^{(l+1)} \\
(4) & R^{(l)} = X^{(l)} \times C^{(l)}  \\
\end{aligned}$

In [None]:
# Wrap a small 3 -> 2 linear layer in the project's relLinear so that it
# exposes `relprop` (presumably the LRP backward rule described above).
a = nn.Linear(in_features=3, out_features=2)
b = relLinear(a)

# Forward pass on a batch of 5 random inputs.
x = torch.rand(5, 3)
output = b(x)

# One-hot relevance per sample: column index 1, 0, 0, 1, 0 receives value 1,
# every other entry stays 0.
r = torch.zeros(5, 2).scatter(
    1,
    torch.tensor([[1], [0], [0], [1], [0]], dtype=torch.long),
    1,
)

# Propagate the relevance back through the wrapped layer and display it.
r_next = b.relprop(r)
r_next

For a convolutional layer, $r_i^{(l)}$ is obtained with the same 4 steps, except that step 3 computes the gradient of the convolution with respect to its input, which can be implemented as a transposed convolutional layer (= fractionally strided convolutional layer).

$\begin{aligned} 
(1) & S^{(l+1)} = X^{(l+1)} + \epsilon \cdot sign(X^{(l+1)}) \\
(2) & E^{(l+1)} = \dfrac{R^{(l+1)}}{S^{(l+1)}} \\
(3) & C^{(l)} = \nabla_{X^{(l)}} (\sum S^{(l+1)} \times E^{(l+1)}) \\
(4) & R^{(l)} = X^{(l)} \times C^{(l)}  \\
\end{aligned}$

In [None]:
# Wrap a 1 -> 32 channel, 3x3 convolution in the project's relConv2d.
a = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3)
b = relConv2d(a)

# Forward pass on two random single-channel 28x28 inputs.
x = torch.randn(2, 1, 28, 28)
output = b(x)

# Use the rectified forward output as the incoming relevance.
r = torch.relu(output)

# Propagate relevance back through the layer and show the resulting size.
r_next = b.relprop(r)
r_next.shape

Maxpooling layer

In [None]:
# Wrap a 2x2 max-pool (configured to also return the argmax indices) in the
# project's relMaxPool2d.
a = nn.MaxPool2d(kernel_size=2, return_indices=True)
b = relMaxPool2d(a)

# Forward pass on a random (2, 32, 26, 26) batch; the wrapper returns a pair.
# The second name is kept exactly as spelled ("swtiches") in case other cells
# refer to it.
x = torch.randn(2, 32, 26, 26)
output, swtiches = b(x)

# Feed the pooled output back as relevance and show the resulting size.
r_next = b.relprop(output)
r_next.shape

In [None]:
# Exercise the project's Reshape module in both directions.
rs = Reshape()

# Forward pass on a random (2, 1, 12, 12) batch.
x = torch.rand(2, 1, 12, 12)
output = rs(x)

# Compare the forward output size with the size after relprop
# (presumably relprop restores the input shape — confirm against Reshape).
(output.shape, rs.relprop(output).shape)