In [None]:
from pathlib import Path
import sys
sys.path.append(str(Path().home()/"code"/"XAI"/"models"))
sys.path.append(str(Path().home()/"code"/"XAI"/"models"/"relavance"))
sys.path.append(str(Path().home()/"code"/"XAI"/"models"/"mnist"))

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.gridspec as gridspec

from models.deconv.deconvnet import deconvMNIST
from models.relavance.lrp import lrpMNIST
from models.reshape import Reshape
from models.relavance.layers import relConv2d, relLinear, relMaxPool2d, relReLU
from models.mnist.MnistModels import MNISTmodel
from models.mnist.MnistTrain import build_dataset

## LRP Method

The goal of the LRP method is to calculate a relevance score for every neuron in every layer.

$\begin{aligned} 
r_i^{(L)} &= \begin{cases}S_i(x) & \text{if unit }i \text{ is the target unit of interest}\\ 0 & \text{otherwise}\end{cases}  \\ r_i^{(l)} &= \sum_j \dfrac{z_{ji}}{\sum_{i'}z_{ji'}+b_j+ \epsilon \cdot \operatorname{sign}(\sum_{i'}z_{ji'}+b_j)}r_j^{(l+1)}\\ &\text{where } z_{ji} = w_{ji}^{(l+1, l)}x_{i}^{(l)}
\end{aligned} $

* the weight fraction $\dfrac{z_{ji}}{\sum_{i'}z_{ji'}+b_j+ \epsilon \cdot \operatorname{sign}(\sum_{i'}z_{ji'}+b_j)}$ is the share of the output neuron $x_j$'s value that is contributed by the input neuron $x_i$
* the numerator is the contribution of input neuron $i$ to output neuron $j$
* the denominator is the value of output neuron $j$, stabilized by $\epsilon$ so that it never becomes zero

### calculate with the matrix

* input feature $(1, \cdots, i, \cdots N)$
* output feature $(1, \cdots, j, \cdots M)$
* note: PyTorch's `nn.Linear` stores the weight transposed, i.e. with shape $(M, N)$ instead of $(N, M)$

$\begin{aligned} X^{(l+1)} &= \begin{bmatrix}  x_1 & \cdots  &x_M \end{bmatrix}^T\\
X^{(l)} &= \begin{bmatrix}  x_1 & \cdots  & x_N \end{bmatrix}^T\\
W^{(l+1, l)} &= \begin{bmatrix} 
w_{11} & \cdots & w_{1i} & \cdots & w_{1N} \\ 
\vdots & \ddots & \ddots & \ddots & \vdots \\ 
w_{j1} & \ddots & w_{ji} & \ddots & w_{jN} \\
\vdots & \ddots & \ddots & \ddots & \vdots \\
w_{M1} & \cdots & w_{Mi} & \cdots & w_{MN}
\end{bmatrix}
\end{aligned}$

to see element-wise calculation ...

$\begin{aligned} 
R^{(l)} &= \begin{bmatrix} r_1 \\ \vdots \\ r_i \\ \vdots \\ r_N\end{bmatrix}^{(l)} 
= \begin{bmatrix} \sum_j^M a_{1j}r_j^{(l+1)} \\ \vdots \\ \sum_j^M a_{ij}r_j^{(l+1)} \\ \vdots \\ \sum_j^M a_{Nj}r_j^{(l+1)} \end{bmatrix} = \begin{bmatrix} a_{11}r_1^{(l+1)} + \cdots + a_{1M}r_M^{(l+1)} \\ \vdots \\ a_{i1}r_1^{(l+1)} + \cdots + a_{iM}r_M^{(l+1)} \\ \vdots \\ a_{N1}r_1^{(l+1)} + \cdots + a_{NM}r_M^{(l+1)} \end{bmatrix} 
\\
Z^{(l, l+1)} &= \begin{bmatrix} 
z_{11} & \cdots & z_{j1} & \cdots & z_{M1} \\ 
\vdots & \ddots & \ddots & \ddots & \vdots \\ 
z_{1i} & \ddots & z_{ji} & \ddots & z_{Mi} \\
\vdots & \ddots & \ddots & \ddots & \vdots \\
z_{1N} & \cdots & z_{jN} & \cdots & z_{MN}
\end{bmatrix} = \begin{bmatrix} 
w_{11}x_1^{(l)} & \cdots & w_{j1}x_1^{(l)} & \cdots & w_{M1}x_1^{(l)} \\ 
\vdots & \ddots & \ddots & \ddots & \vdots \\ 
w_{1i}x_i^{(l)} & \ddots & w_{ji}x_i^{(l)} & \ddots & w_{Mi}x_i^{(l)} \\
\vdots & \ddots & \ddots & \ddots & \vdots \\
w_{1N}x_N^{(l)} & \cdots & w_{jN}x_N^{(l)} & \cdots & w_{MN}x_N^{(l)}
\end{bmatrix}
\end{aligned}$

### 1st way

To get $r_i^{(l)}$, where $z_{ji} = w_{ji}^{(l+1, l)} x_i^{(l)}$, there are 4 steps in a linear layer:

$\begin{aligned} 
(1) & Z^{(l, l+1)} = W^{(l, l+1)} \times X^{(l)}\\
(2) & S^{(l+1)} = X^{(l+1)} + \epsilon \cdot sign(X^{(l+1)}) \\
(3) & A^{(l, l+1)} = \dfrac{Z^{(l, l+1)}}{S^{(l+1)}} \\
(4) & R^{(l)} = A^{(l, l+1)}R^{(l+1)}  \\
\end{aligned}$

### 2nd way

The same calculation in a different order, as introduced at http://heatmapping.org/tutorial/

$\begin{aligned} 
(1) & S^{(l+1)} = X^{(l+1)} + \epsilon \cdot sign(X^{(l+1)}) \\
(2) & E^{(l+1)} = \dfrac{R^{(l+1)}}{S^{(l+1)}} \\
(3) & C^{(l)} = W^{(l, l+1)} E^{(l+1)} \\
(4) & R^{(l)} = X^{(l)} \times C^{(l)}  \\
\end{aligned}$

In [None]:
# Demo of the imported `relLinear` (from `models.relavance.layers`) on a batch of 5 samples.
# NOTE(review): the `relLinear` class defined later in this notebook shadows this import and
# has no `relprop` method, so this cell only works before that later cell has been run.
a = nn.Linear(3, 2)
b = relLinear(a)
x = torch.rand(5, 3)
output = b(x)
# One-hot relevance of shape (5, 2): each row selects the target output unit of one sample.
r = torch.zeros(5, 2).scatter(1, torch.LongTensor([[1], [0], [0], [1], [0]]), 1)
# Propagate the relevance back through the wrapped layer.
r_next = b.relprop(r)
r_next

In a convolutional layer there are also 4 steps to get $r_i^{(l)}$, but step 3 computes the gradient of the convolution with respect to its input, which can be implemented with a transposed convolutional layer (= fractionally strided convolutional layer).

$\begin{aligned} 
(1) & S^{(l+1)} = X^{(l+1)} + \epsilon \cdot sign(X^{(l+1)}) \\
(2) & E^{(l+1)} = \dfrac{R^{(l+1)}}{S^{(l+1)}} \\
(3) & C^{(l)} = \triangledown (\sum S^{(l+1)} \times E^{(l+1)}) \\
(4) & R^{(l)} = X^{(l)} \times C^{(l)}  \\
\end{aligned}$

In [None]:
# Demo of the imported `relConv2d` (from `models.relavance.layers`) on two 1x28x28 inputs.
# NOTE(review): the `relConv2d` class defined later in this notebook shadows this import and
# has no `relprop` method, so this cell only works before that later cell has been run.
a = nn.Conv2d(1, 32, 3)
b = relConv2d(a)
x = torch.randn(2, 1, 28, 28)
output = b(x)
# Use the rectified layer output as the relevance to propagate back.
r = torch.relu(output)
r_next = b.relprop(r)
# Displayed to check the shape of the propagated relevance.
r_next.size()

Maxpooling layer

In [None]:
# Demo of the imported `relMaxPool2d` (from `models.relavance.layers`).
# NOTE(review): the `relMaxPool2d` class defined later in this notebook shadows this import and
# has no `relprop` method, so this cell only works before that later cell has been run.
a = nn.MaxPool2d(2, return_indices=True)
b = relMaxPool2d(a)
x = torch.randn(2, 32, 26, 26)
# The pooling layer was built with `return_indices=True`, so the call yields two values;
# `swtiches` (sic) receives the second one.
output, swtiches = b(x)
# The pooled output itself is used as the relevance to propagate back.
r_next = b.relprop(output)
# Displayed to check the shape of the propagated relevance.
r_next.size()

In [None]:
# Demo of the imported `Reshape` layer: compare the size after the forward pass with the
# size after passing that output back through `relprop`.
rs = Reshape()
x = torch.rand(2, 1, 12, 12)
output = rs(x)
output.size(), rs.relprop(output).size()

---

In [1]:
import torch
import torch.nn as nn
from copy import deepcopy
from collections import defaultdict, OrderedDict

class XaiHook(nn.Module):
    """
    Hook handler module wrapping a single `module`.

    supported register `module` hooks
    - Activations
    - Linear
    - Convd

    like RNN have to use `register_hook` to `torch.nn.Parameter` directly

    * Ref: https://pytorch.org/docs/master/nn.html#torch.nn.Module.register_backward_hook
    [Warnings]
    The current implementation will not have the presented behavior
    for complex Module that perform many operations. In some failure cases,
    `grad_input` and `grad_output` will only contain the gradients for a subset
    of the inputs and outputs. For such `Module`, you should use
    `torch.Tensor.register_hook()` directly on a specific input or
    output to get the required gradients.

    Attributes set by this class:
    - module: the wrapped module the hooks are attached to
    - hook_fn: the hook function used by the most recent `register_hook` call
    - hook: the handle returned by the most recent `register_hook` call
    """
    def __init__(self, module):
        super(XaiHook, self).__init__()
        self.module = module
        # Every handle returned by a `register_*_hook` call. `register_hook` may be
        # called several times (e.g. once forward, once backward), so `close()`
        # needs all of them, not only the last one.
        self._handles = []

    def zero_grad(self):
        """Reset the gradients of the wrapped module."""
        self.module.zero_grad()

    def register_hook(self, backward=False, hook_fn=None):
        """
        Attach a forward or backward hook to the wrapped module.

        The default hook function saves (module, input, output) to
        (self.m, self.i, self.o). Pass `hook_fn` to use a custom function.
        If the hook function returns `None` the original input or output
        flows into the next / previous layer unchanged; it may instead return
        a modified output/gradient to replace the original one.
        for a Conv2d layer example
        - forward: a `Tensor` type output
        - backward: (gradient_input, weight, bias)

        args:
        - backward: bool, register with `register_backward_hook` when True,
          otherwise with `register_forward_hook`
        - hook_fn: callable taking (m, i, o), or None for the default hook
        """
        def default_hook_fn(m, i, o):
            """
            forward
             - m: module class
             - i: forward input from previous layer
             - o: forward output to next layer
            backward
             - m: module class
             - i: gradient input to next layer (backward out)
             - o: gradient output from previous layer (backward in)

            args:
             * i, o: tuple type
            """
            self.m = m
            self.i = i
            self.o = o

        self.hook_fn = default_hook_fn if hook_fn is None else hook_fn
        if backward:
            handle = self.module.register_backward_hook(self.hook_fn)
        else:
            handle = self.module.register_forward_hook(self.hook_fn)
        # `self.hook` keeps pointing at the most recent handle, as before.
        self.hook = handle
        self._handles.append(handle)

    def close(self):
        """
        Remove every hook registered through this handler.

        Safe to call when nothing has been registered and safe to call twice.
        """
        while self._handles:
            self._handles.pop().remove()


class XaiBase(nn.Module):
    """
    Base class for XAI models built on top of `XaiHook` handlers.

    - need to define XaiHook class to use
    - the default hook function saves (module, input, output) to (m, i, o);
      pass `hook_fn` to the register helpers to use a custom function

    Attributes:
    - maps: OrderedDict mapping a layer name to whatever was saved for it
    """
    def __init__(self):
        super(XaiBase, self).__init__()
        self._reset_maps()

    @staticmethod
    def _as_hook_list(hooks):
        """Return `hooks` as a list; a single hook becomes a one-element list."""
        # Tuples and `nn.ModuleList`s are collections of hooks too. Wrapping them as a
        # single element would make the later `.register_hook` / `.close` call fail.
        if isinstance(hooks, (list, tuple, nn.ModuleList)):
            return list(hooks)
        return [hooks]

    def _reset_maps(self):
        """Drop every saved map."""
        self.maps = OrderedDict()

    def _save_maps(self, layer_name, x):
        """Store `x` under `layer_name`, replacing any previous entry."""
        self.maps[layer_name] = x

    def _register(self, hooks, backward=False, hook_fn=None):
        """
        Call `register_hook(backward=..., hook_fn=...)` on one hook or on each
        hook of a list / tuple / `nn.ModuleList`.

        - need to define XaiHook class to use
        - the default hook function saves (module, input, output) to (m, i, o);
          pass `hook_fn` to use a custom function
        """
        for hook in self._as_hook_list(hooks):
            hook.register_hook(backward=backward, hook_fn=hook_fn)

    def _register_forward(self, hooks, hook_fn=None):
        """Register forward hooks on `hooks`."""
        self._register(hooks, backward=False, hook_fn=hook_fn)

    def _register_backward(self, hooks, hook_fn=None):
        """Register backward hooks on `hooks`."""
        self._register(hooks, backward=True, hook_fn=hook_fn)

    def _reset_hooks(self, hooks):
        """Call `close()` on one hook or on each hook of a collection."""
        for hook in self._as_hook_list(hooks):
            hook.close()

    def _return_indices(self, layers, on=True):
        """
        support for cnn layer which have `nn.MaxPool2d`,
        you can turn on/off pooling indices.
        please define a forward function to use it in your model
        '''
        # in your model
        def forward_switch(self, x):
            switches = OrderedDict()
            self.return_indices(on=True)
            for idx, layer in enumerate(self.convs):
                if isinstance(layer, nn.MaxPool2d):
                    x, indices = layer(x)
                    switches[idx] = indices
                else:
                    x = layer(x)
            self.return_indices(on=False)
            return x, switches
        '''

        args:
        - layers: iterable of layers; only `nn.MaxPool2d` instances are touched
        - on: truthy to set `return_indices = True`, falsy to set it to False
        """
        flag = bool(on)
        for layer in layers:
            if isinstance(layer, nn.MaxPool2d):
                layer.return_indices = flag
                    
                    
class XaiModel(XaiBase):
    """
    XAI wrapper holding a private copy of `model`.

    The model is deep-copied, moved to CPU and put into eval mode, so the
    model passed in is not modified by this constructor.
    """
    def __init__(self, model):
        super(XaiModel, self).__init__()
        self.model = deepcopy(model)
        self.model.cpu()
        self.model.eval()

    def _one_hot(self, targets, module_name):
        """
        one hot vectorize the target tensor for classification purpose.
        the `module` with respect to `module_name` must have `out_features` attribution.
        args:
        - targets: torch.LongTensor, target classes that have size of mini-batch
        - module_name: str, feature name for Fully-Connected Network or any Task-specific Network
        return:
        - one hot vector of targets, shape (B, out_features)
        raises:
        - AttributeError: the last layer of the module has no `out_features`
        """
        assert isinstance(targets, torch.LongTensor), "`targets` must be `torch.LongTensor` type"
        assert isinstance(module_name, str), "`module_name` must be `str` type"
        modules = self.model._modules[module_name]
        last_layer = modules[-1] if isinstance(modules, nn.Sequential) else modules
        # Fail right here with one explanatory error instead of printing hints and then
        # hitting the same AttributeError on the next attribute access.
        if not hasattr(last_layer, "out_features"):
            raise AttributeError(
                f"last layer of module `{module_name}` doesn't have `out_features` attribute; "
                f"type of the last layer is `{type(last_layer)}`, "
                "create an `.out_features` attribute in the custom module"
            )

        target_size = last_layer.out_features
        batch_size = targets.size(0)
        one_hot = torch.zeros((batch_size, target_size))
        one_hot.scatter_(1, targets.unsqueeze(1), 1.0)
        return one_hot.to(targets.device)

    def _find_target_layer_idx(self, module_name, layer_names):
        """
        Find the positions of layers of the given types inside a `nn.Sequential`.

        args:
        - module_name: str, name of a `nn.Sequential` child of the model
        - layer_names: list or tuple of class names, matched case-insensitively
          against `type(layer).__name__`
        return:
        - defaultdict(list) mapping the lower-cased class name to the list of
          positions of the matching layers inside the `nn.Sequential`
        """
        assert isinstance(layer_names, (list, tuple)), "use list for `layer_names`"
        wanted = {name.lower() for name in layer_names}
        modules = self.model._modules[module_name]
        assert isinstance(modules, nn.Sequential), "use this function for `nn.Sequential` type modules"
        idxes = defaultdict(list)
        # Enumerate positions instead of converting child names with `int(...)`:
        # a Sequential built from an OrderedDict has non-numeric child names.
        for idx, layer in enumerate(modules):
            l_name = type(layer).__name__.lower()
            if l_name in wanted:
                idxes[l_name].append(idx)

        return idxes

## Test for New LRP Layer

### relLinear

In [17]:
class relLinear(XaiHook):
    """
    LRP wrapper around a linear module.

    A forward hook records the layer input and output. A backward hook replaces
    the gradient flowing to the layer input with the propagated relevance, so a
    plain `.backward()` call performs the relevance propagation.
    """
    def __init__(self, module, use_rho=False, eps=1e-6):
        """
        forward
        > input: (B, in_f)
        > output: (B, out_f)

        args:
        - module: the wrapped linear module (its `weight` and `in_features` are read)
        - use_rho: if True, negative weights are clamped to 0 before propagation
        - eps: stabilizer added to the denominator
        """
        super(relLinear, self).__init__(module)
        self.use_rho = use_rho
        self.eps = eps
        self.register_hook(backward=False, hook_fn=self.f_hook)
        self.register_hook(backward=True, hook_fn=self.b_hook)

    def __call__(self, x):
        """Run the wrapped module; the registered hooks fire as usual."""
        return self.module(x)

    def f_hook(self, m, i, o):
        """
        forward hook
        i: (input,)
        o: output

        save detached copies of the forward input and output
        """
        self.input = i[0].clone().data
        self.output = o.clone().data

    @staticmethod
    def _stabilize(z, eps):
        """
        Return `z + eps * sign(z)` with `sign(0)` taken as +1.

        `torch.sign` returns 0 at 0, which would leave a zero denominator, while a
        plain `z + eps` becomes exactly zero at `z == -eps` and shrinks the magnitude
        of negative values. Treating 0 as positive avoids both problems.
        """
        sign = (z >= 0).to(z.dtype) * 2 - 1
        return z + eps * sign

    def b_hook(self, m, i, o):
        """
        backward hook
        i: (grad_bias, input, grad_weight.T) -> backward output
        o: (output,) -> backward input

        NOTE(review): the three-element layout of `i` is what the original author
        documented for `register_backward_hook`; confirm it for the installed PyTorch
        version and for layers without bias.

        ### implementation
        [Step 1]: s = output + eps * sign(output)        (B, out_f)
        [Step 2]: e = r / s                              (B, out_f)
        [Step 3]: c = e @ rho(W)                         (B, out_f) x (out_f, in_f) = (B, in_f)
        [Step 4]: r_next = input * c                     (B, in_f)

        return:
        - (grad_bias, r_next, grad_weight): the relevance replaces the gradient
          with respect to the layer input
        """
        grad_bias, _, grad_weight = i
        r = o[0]
        w = self.rho(self.module.weight).data
        # Step 1
        s = self._stabilize(self.output, self.eps)
        # Step 2
        e = r / s
        # Step 3: same product as the batched `w.T @ e`, without expanding the weight
        c = torch.matmul(e, w)
        # Step 4
        r_next = self.input * c
        assert r_next.size(-1) == self.module.in_features, "size of `r_next` is not correct"
        # for debugging
        self.r = r
        self.r_next = r_next
        return (grad_bias, r_next, grad_weight)

    def rho(self, w):
        """Return `w` clamped to be non-negative when `use_rho` is set, else `w` unchanged."""
        if self.use_rho:
            return torch.clamp(w, min=0)
        else:
            return w

In [20]:
def hook(m, i, o):
    """
    Debug hook: print the module, the type of `i` followed by each of its items,
    then the type of `o` and `o` itself, and finish with an empty line.
    """
    lines = [m, type(i)]
    lines.extend(i)
    lines += [type(o), o, ""]
    print(*lines, sep="\n")
# Build a two-layer test network: Linear(5, 4) -> ReLU -> Linear(4, 3), with both linear
# layers wrapped in `relLinear`. The commented-out lines are optional debugging aids:
# they attach the printing `hook` or pin the parameters to fixed values.
relu = nn.ReLU()
# relu.register_backward_hook(hook)
linear_p = nn.Linear(5, 4)
# linear_p.weight.data = torch.FloatTensor([[1, 2, 3], [1, 2, 3]])
# linear_p.bias.data = torch.FloatTensor([1, 2])

# `a` wraps the first (previous) layer
a = relLinear(linear_p)
# linear_p.register_backward_hook(hook)

linear = nn.Linear(4, 3)
# linear.weight.data = torch.FloatTensor([[2, 3]])
# linear.bias.data = torch.FloatTensor([1])

# `b` wraps the last layer
b = relLinear(linear)
# linear.register_backward_hook(hook)

In [21]:
# Forward pass through the wrapped network; `x` requires grad so that the backward pass
# reaches the input and `x.grad` gets populated.
x = torch.randn(1, 5).requires_grad_(True)
o1 = a(x)
o2 = relu(o1)
o3 = b(o2)
o3.size()

torch.Size([1, 3])

In [23]:
# Start the backward pass with a one-hot vector that selects output unit 0.
o3.backward(torch.FloatTensor([[1, 0, 0]]))

In [24]:
# Evaluates to True when every element of `x.grad` equals `a.r_next`, the relevance
# stored by the backward hook of the first layer.
(x.grad == a.r_next).sum() == x.grad.view(-1).size(0)

tensor(True)

### relConv2d

In [2]:
class relConv2d(XaiHook):
    """
    LRP wrapper around a 2-d convolution module.

    A forward hook records the layer input and output. A backward hook replaces
    the gradient flowing to the layer input with the propagated relevance, so a
    plain `.backward()` call performs the relevance propagation.
    """
    def __init__(self, module, use_rho=False, eps=1e-6):
        """
        forward
        > input: (B, C_in, H_in, W_in)
        > output: (B, C_out, H_out, W_out)
        backward
        > lrp propagation with respect to previous input

        args:
        - module: the wrapped convolution module (its `weight`, `stride`, `padding`,
          `dilation`, `groups` and `kernel_size` are read)
        - use_rho: if True, negative weights are clamped to 0 before propagation
        - eps: stabilizer added to the denominator
        """
        super(relConv2d, self).__init__(module)
        self.use_rho = use_rho
        self.eps = eps
        self.register_hook(backward=False, hook_fn=self.f_hook)
        self.register_hook(backward=True, hook_fn=self.b_hook)

    def __call__(self, x):
        """Run the wrapped module; the registered hooks fire as usual."""
        return self.module(x)

    def f_hook(self, m, i, o):
        """
        forward hook
        i: (input,)
        o: output

        save detached copies of the forward input and output
        """
        self.input = i[0].clone().data
        self.output = o.clone().data

    @staticmethod
    def _stabilize(z, eps):
        """
        Return `z + eps * sign(z)` with `sign(0)` taken as +1.

        `torch.sign` returns 0 at 0, which would leave a zero denominator, while a
        plain `z + eps` becomes exactly zero at `z == -eps` and shrinks the magnitude
        of negative values. Treating 0 as positive avoids both problems.
        """
        sign = (z >= 0).to(z.dtype) * 2 - 1
        return z + eps * sign

    def b_hook(self, m, i, o):
        """
        backward hook
        i: (grad_input, grad_weight, grad_bias) -> backward output
        o: (grad_output,) -> backward input

        NOTE(review): the three-element layout of `i` is what the original author
        documented for `register_backward_hook`; confirm it for the installed PyTorch
        version and for layers without bias.

        ### implementation method
        [Step 1]: s = output + eps * sign(output)   (B, C_out, H_out, W_out)
        [Step 2]: e = r / s                         (B, C_out, H_out, W_out)
        [Step 3]: c = gradprop(e)                   (B, C_out, H_out, W_out) --> (B, C_in, H, W)
        [Step 4]: r_next = input * c                (B, C_in, H, W)

        return:
        - (r_next, grad_weight, grad_bias): the relevance replaces the gradient
          with respect to the layer input
        """
        _, grad_weight, grad_bias = i
        r = o[0]
        # Step 1
        s = self._stabilize(self.output, self.eps)
        # Step 2
        e = r / s
        # Step 3: `gradprop` applies `rho` to the weight itself
        c = self.gradprop(e)
        # Step 4
        r_next = self.input * c

        # for debugging
        self.r = r
        self.r_next = r_next
        return (r_next, grad_weight, grad_bias)

    def rho(self, w):
        """Return `w` clamped to be non-negative when `use_rho` is set, else `w` unchanged."""
        if self.use_rho:
            return torch.clamp(w, min=0)
        else:
            return w

    def gradprop(self, x):
        """
        `ConvTransposed2d` can be seen as the gradient of `Conv2d` with respect to its input.

        Uses the `rho`-transformed weight, so `use_rho=True` actually takes effect, and
        forwards the module's `groups` and `dilation` so that non-default convolutions
        are transposed correctly.
        """
        output_padding = self.cal_output_padding()
        c = torch.nn.functional.conv_transpose2d(x,
                                                 weight=self.rho(self.module.weight),
                                                 bias=None,
                                                 stride=self.module.stride,
                                                 padding=self.module.padding,
                                                 output_padding=output_padding,
                                                 groups=self.module.groups,
                                                 dilation=self.module.dilation)
        return c

    def cal_output_padding(self):
        """
        calculate output_padding size
        - size of height or width: floor((X_in + 2P - D*(K - 1) - 1) / S) + 1 = X_out
        - output_padding = X_in - ((X_out - 1) * S - 2P + D*(K - 1) + 1)
          (for D = 1 this is X_in - ((X_out - 1) * S + K - 2P))

        NOTE(review): assumes `module.padding` is a pair of ints; a string padding
        mode such as 'same' is not handled.

        * what is output_padding?
        from PyTorch Document:
        https://pytorch.org/docs/stable/nn.html#convtranspose2d

        The padding argument effectively adds `dilation * (kernel_size - 1) - padding` amount of zero padding to
        both sizes of the input. This is set so that when a `Conv2d` and a `ConvTranspose2d` are initialized with
        same parameters, they are inverses of each other in regard to the input and output shapes.
        However, when `stride > 1`, `Conv2d` maps multiple input shapes to the same output shape.
        `output_padding` is provided to resolve this ambiguity by effectively increasing
        the calculated output shape on one side. Note that output_padding is only used to find output shape,
        but does not actually add zero-padding to output.
        """
        H_in, W_in = self.input.size()[2:]
        H_out, W_out = self.output.size()[2:]
        S_h, S_w = self.module.stride
        K_h, K_w = self.module.kernel_size
        P_h, P_w = self.module.padding
        D_h, D_w = self.module.dilation
        H_output_padding = H_in - ((H_out - 1)*S_h - 2*P_h + D_h*(K_h - 1) + 1)
        W_output_padding = W_in - ((W_out - 1)*S_w - 2*P_w + D_w*(K_w - 1) + 1)
        return (H_output_padding, W_output_padding)

In [3]:
def hook(m, i, o):
    """
    Debug hook: print the module, the type of `i` followed by each of its items,
    then the type of `o` and `o` itself, and finish with an empty line.
    """
    lines = [m, type(i)]
    lines.extend(i)
    lines += [type(o), o, ""]
    print(*lines, sep="\n")
# Build a two-layer test network: Conv2d(1, 3, 3) -> ReLU -> Conv2d(3, 5, 3), with both
# convolutions wrapped in `relConv2d`. The commented-out lines optionally attach the
# printing `hook` for debugging.
relu = nn.ReLU()
# relu.register_backward_hook(hook)
conv_p = nn.Conv2d(1, 3, 3)

# `a` wraps the first (previous) layer
a = relConv2d(conv_p)
# conv_p.register_backward_hook(hook)

conv = nn.Conv2d(3, 5, 3)

# `b` wraps the last layer
b = relConv2d(conv)
# conv.register_backward_hook(hook)

In [4]:
# Forward pass through the wrapped network; `x` requires grad so that the backward pass
# reaches the input and `x.grad` gets populated.
x = torch.randn(1, 1, 11, 11).requires_grad_(True)
o1 = a(x)
o2 = relu(o1)
o3 = b(o2)
o3.size()

torch.Size([1, 5, 7, 7])

In [5]:
# Reduce to a scalar and start the backward pass, which triggers the backward hooks.
o3.sum().backward()

In [6]:
# Evaluates to True when every element of `x.grad` equals `a.r_next`, the relevance
# stored by the backward hook of the first layer.
(x.grad == a.r_next).sum() == x.grad.view(-1).size(0)

tensor(True)

### relMaxPool2d

In [12]:
class relMaxPool2d(XaiHook):
    """
    LRP wrapper around a 2-d max-pooling module.

    A forward hook records the layer input and output. A backward hook replaces
    the gradient flowing to the layer input with the propagated relevance, so a
    plain `.backward()` call performs the relevance propagation.
    """
    def __init__(self, module, use_rho=False, eps=1e-6):
        """
        forward
        > input: (B, C, H_in, W_in)
        > output: (B, C, H_out, W_out)
        backward
        > lrp propagation with respect to previous input

        args:
        - module: the wrapped pooling module (its `kernel_size`, `stride`, `padding`,
          `dilation` and `ceil_mode` are read)
        - use_rho: stored but not used by this class (pooling has no weights)
        - eps: stabilizer added to the denominator
        """
        super(relMaxPool2d, self).__init__(module)
        self.use_rho = use_rho
        self.eps = eps
        self.register_hook(backward=False, hook_fn=self.f_hook)
        self.register_hook(backward=True, hook_fn=self.b_hook)

    def __call__(self, x):
        """Run the wrapped module; the registered hooks fire as usual."""
        return self.module(x)

    def f_hook(self, m, i, o):
        """
        forward hook
        i: (input,)
        o: output, or an (output, indices) tuple when the module returns indices

        save detached copies of the forward input and output
        """
        # A pooling module built with `return_indices=True` returns a tuple; only the
        # pooled values are needed here.
        pooled = o[0] if isinstance(o, tuple) else o
        self.input = i[0].clone().data
        self.output = pooled.clone().data

    @staticmethod
    def _stabilize(z, eps):
        """
        Return `z + eps * sign(z)` with `sign(0)` taken as +1.

        `torch.sign` returns 0 at 0, which would leave a zero denominator, while a
        plain `z + eps` becomes exactly zero at `z == -eps` and shrinks the magnitude
        of negative values. Treating 0 as positive avoids both problems.
        """
        sign = (z >= 0).to(z.dtype) * 2 - 1
        return z + eps * sign

    def b_hook(self, m, i, o):
        """
        backward hook
        i: (grad_input,) -> backward output
        o: (grad_output,) -> backward input

        ### implementation method
        [Step 1]: s = output + eps * sign(output)   (B, C, H_out, W_out)
        [Step 2]: e = r / s                         (B, C, H_out, W_out)
        [Step 3]: c = gradprop(e)                   (B, C, H_out, W_out) --> (B, C, H_in, W_in)
        [Step 4]: r_next = input * c                (B, C, H_in, W_in)

        return:
        - (r_next,): the relevance replaces the gradient with respect to the layer input
        """
        r = o[0]
        # Step 1
        s = self._stabilize(self.output, self.eps)
        # Step 2
        e = r / s
        # Step 3
        c = self.gradprop(e)
        # Step 4
        r_next = self.input * c

        # for debugging
        self.r = r
        self.r_next = r_next

        return (r_next,)

    def gradprop(self, x):
        """
        get maxpooled switches first then unpool

        The pooling is recomputed on the saved input to obtain the indices of the
        maxima, then `x` is scattered back to those positions.
        """
        _, switches = torch.nn.functional.max_pool2d(self.input,
                                                     self.module.kernel_size,
                                                     self.module.stride,
                                                     self.module.padding,
                                                     self.module.dilation,
                                                     self.module.ceil_mode,
                                                     return_indices=True)
        # Pass the saved input's spatial size explicitly: pooling maps several input
        # sizes to the same output size (e.g. 6x6 and 7x7 with kernel 2), and without it
        # the unpooled result would not match `self.input` in the following product.
        c = torch.nn.functional.max_unpool2d(x, switches,
                                             self.module.kernel_size,
                                             self.module.stride,
                                             self.module.padding,
                                             output_size=list(self.input.shape[-2:]))
        return c

In [13]:
def hook(m, i, o):
    """
    Debug hook: print the module, the type of `i` followed by each of its items,
    then the type of `o` and `o` itself, and finish with an empty line.
    """
    lines = [m, type(i)]
    lines.extend(i)
    lines += [type(o), o, ""]
    print(*lines, sep="\n")
# Build a two-layer test network: MaxPool2d(2) -> ReLU -> MaxPool2d(2), with both pooling
# layers wrapped in `relMaxPool2d`. The commented-out lines optionally attach the
# printing `hook` for debugging.
relu = nn.ReLU()
# relu.register_backward_hook(hook)
maxpool_p = nn.MaxPool2d(2)

# `a` wraps the first (previous) layer
a = relMaxPool2d(maxpool_p)
# maxpool_p.register_backward_hook(hook)

maxpool = nn.MaxPool2d(2)

# `b` wraps the last layer
b = relMaxPool2d(maxpool)
# maxpool.register_backward_hook(hook)

In [14]:
# Forward pass through the wrapped network; `x` requires grad so that the backward pass
# reaches the input and `x.grad` gets populated.
x = torch.randn(1, 1, 8, 8).requires_grad_(True)
o1 = a(x)
o2 = relu(o1)
o3 = b(o2)
o3.size()

torch.Size([1, 1, 2, 2])

In [15]:
# Reduce to a scalar and start the backward pass, which triggers the backward hooks.
o3.sum().backward()

In [16]:
# Evaluates to True when every element of `x.grad` equals `a.r_next`, the relevance
# stored by the backward hook of the first layer.
(x.grad == a.r_next).sum() == x.grad.view(-1).size(0)

tensor(True)