In [9]:
import mxnet as     mx
from   mxnet import nd
from   mxnet import autograd
from mxnet.gluon.nn.basic_layers import Activation

from mxmodules import Sequential, Linear, Rect, Tanh
import types

import model_io

# Custom Layer Definition: LRP forward for Gluon `Dense` layers (from `mxlrp`)

In [2]:
from mxlrp import dense_hybrid_forward_lrp

# Model Definition

In [8]:
# Run everything on CPU.
ctx = mx.cpu()

# Seed MXNet's RNG so the random default parameter initialization of the Gluon
# model (and therefore the printed outputs/heatmaps) is reproducible across
# kernel restarts.
SEED = 42
mx.random.seed(SEED)

## a) Dummy Network

In [5]:
## ######### ##
# GLUON MODEL #
## ######### ##

# Layer sizes shared by both models, so the Gluon and the standalone network
# cannot silently drift apart.
N_IN, N_HIDDEN, N_OUT = 4, 12, 1

net_gl = mx.gluon.nn.HybridSequential()

dense_0 = mx.gluon.nn.Dense(units=N_HIDDEN, in_units=N_IN, activation='tanh')
dense_1 = mx.gluon.nn.Dense(units=N_OUT, in_units=N_HIDDEN)

net_gl.add(dense_0)
net_gl.add(dense_1)

net_gl.collect_params().initialize(ctx=ctx)

# extract variables from gluon model
# (Gluon stores Dense weights as (units, in_units) and biases as (units,).)
weight_0 = dense_0.weight.data()
bias_0   = dense_0.bias.data()

weight_1 = dense_1.weight.data()
bias_1   = dense_1.bias.data()

assert weight_0.shape == (N_HIDDEN, N_IN) and bias_0.shape == (N_HIDDEN,)
assert weight_1.shape == (N_OUT, N_HIDDEN) and bias_1.shape == (N_OUT,)

## ############## ##
# STANDALONE MODEL #
## ############## ##

net_sta = Sequential([Linear(N_IN, N_HIDDEN), Tanh(), Linear(N_HIDDEN, N_OUT)])

# Transpose the Gluon weights for the standalone Linear layers (presumably they
# expect W as (in, out) -- confirm against mxmodules). `.T` already returns a
# new array for 2-D weights; copy the biases as well so that *all* standalone
# parameters are independent snapshots, instead of only the biases aliasing
# Gluon's live parameter buffers.
net_sta.modules[0].W = weight_0.T
net_sta.modules[0].B = bias_0.copy()
net_sta.modules[2].W = weight_1.T
net_sta.modules[2].B = bias_1.copy()

## Patch the Gluon `Dense` layers so that their gradient computes LRP relevance

In [6]:
# Replace hybrid_forward on every Dense layer with the LRP variant from mxlrp.
# Binding with types.MethodType sets an instance attribute, which shadows the
# class's hybrid_forward for that layer only; the backward pass through the
# patched layers is then expected to produce LRP relevances (see mxlrp).
#
# Iterate the container through its public sequence interface rather than the
# private `_children` attribute: in newer MXNet releases `_children` is a
# name -> block mapping, and looping over it would yield strings and silently
# patch nothing.
n_patched = 0
for layer in net_gl:
    if isinstance(layer, mx.gluon.nn.Dense):
        layer.hybrid_forward = types.MethodType(dense_hybrid_forward_lrp, layer)
        n_patched += 1

# Fail loudly instead of comparing an unpatched network below.
assert n_patched > 0, 'no Dense layer was patched; LRP comparison would be meaningless'

## Comparison with standalone implementation

In [7]:
# forward and backward pass test
X = nd.arange(12).reshape((3, 4))

print('Input:')
print(X)

# Gluon: record the (patched) forward pass, then backprop using the output
# itself as the head gradient -- mirroring net_sta.lrp(linear_out) below.
X.attach_grad()
with autograd.record():
    dense_out = net_gl(X)

dense_out.backward(dense_out)
# Copy: X.grad is a buffer that any later backward() on X overwrites in place.
hm_gluon = X.grad.copy()

print('\nGLUON impl:')
print(dense_out)
print(hm_gluon)

print('\nStandalone impl:')
linear_out = net_sta.forward(X)
hm_standalone = net_sta.lrp(linear_out)
print(linear_out)
print(hm_standalone)

# Check agreement programmatically instead of by eye (float32 tolerance).
ATOL = 1e-5
max_out_diff = nd.abs(dense_out - linear_out).max().asscalar()
max_hm_diff = nd.abs(hm_gluon - hm_standalone).max().asscalar()
assert max_out_diff < ATOL, 'forward outputs differ by %g' % max_out_diff
assert max_hm_diff < ATOL, 'LRP heatmaps differ by %g' % max_hm_diff

Input:

[[  0.   1.   2.   3.]
 [  4.   5.   6.   7.]
 [  8.   9.  10.  11.]]
<NDArray 3x4 @cpu(0)>

GLUON impl:

[[-0.00355378]
 [-0.02960501]
 [-0.048284  ]]
<NDArray 3x1 @cpu(0)>

[[ 0.          0.00364889 -0.00259847 -0.0046042 ]
 [-0.03007531  0.01729889 -0.00631801 -0.01051059]
 [-0.05498287  0.02882986 -0.00622151 -0.01590948]]
<NDArray 3x4 @cpu(0)>

Standalone impl:

[[-0.00355378]
 [-0.02960501]
 [-0.048284  ]]
<NDArray 3x1 @cpu(0)>

[[ 0.          0.00364889 -0.00259847 -0.0046042 ]
 [-0.03007531  0.01729889 -0.00631801 -0.01051059]
 [-0.05498287  0.02882986 -0.00622151 -0.01590948]]
<NDArray 3x4 @cpu(0)>
