-
Notifications
You must be signed in to change notification settings - Fork 90
/
lrp_cnn_demo.py
142 lines (111 loc) · 5.42 KB
/
lrp_cnn_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
'''
@author: Sebastian Lapuschkin
@maintainer: Sebastian Lapuschkin
@contact: sebastian.lapuschkin@hhi.fraunhofer.de, wojciech.samek@hhi.fraunhofer.de
@date: 25.10.2016
@version: 1.2+
@copyright: Copyright (c) 2015-2017, Sebastian Lapuschkin, Alexander Binder, Gregoire Montavon, Klaus-Robert Mueller, Wojciech Samek
@license : BSD-2-Clause
The purpose of this module is to demonstrate the process of obtaining pixel-wise explanations for given data points at hand of the MNIST hand written digit data set
with CNN models, using the LeNet-5 architecture.
The module first loads a pre-trained neural network model and the MNIST test set with labels and transforms the data such that each pixel value is within the range of [-1 1].
The data is then randomly permuted and for the first 10 samples due to the permuted order, a prediction is computed by the network, which is then as a next step explained
by attributing relevance values to each of the input pixels.
finally, the resulting heatmap is rendered as an image and (over)written out to disk and displayed.
'''
# --- third-party imports ---
import time

import matplotlib.pyplot as plt
import numpy
import numpy as np
import importlib.util as imp

# use cupy as the array backend for GPU support, if it is installed.
# `numpy` keeps a handle on the CPU module so host transfers can be
# detected later via `np is not numpy`.
if imp.find_spec("cupy"):
    import cupy
    import cupy as np

# --- project-local modules ---
import model_io
import data_io
import render
# load a pre-trained neural network, as well as the MNIST test data and some labels
nn = model_io.read('../models/MNIST/LeNet-5.nn')  # 99.23% prediction accuracy
nn.drop_softmax_output_layer()  # drop softmax output layer for analyses

X = data_io.read('../data/MNIST/test_images.npy')
Y = data_io.read('../data/MNIST/test_labels.npy')

# transfer pixel values from [0 255] to [-1 1] to satisfy the expected
# input / training paradigm of the model
X = X / 127.5 - 1.

# reshape the vector representations in X to match the requirements of the CNN
# input: (N, 32, 32, 1), padding the 28x28 digits with the background value -1.
X = np.reshape(X, [X.shape[0], 28, 28, 1])
X = np.pad(X, ((0, 0), (2, 2), (2, 2), (0, 0)), 'constant', constant_values=(-1.,))

# transform numeric class labels to one-hot indicator vectors for uniformity.
# assumes presence of all classes within the label set.
I = Y[:, 0].astype(int)
Y = np.zeros([X.shape[0], np.unique(Y).size])
Y[np.arange(Y.shape[0]), I] = 1

# sanity-check the model on the full test set
acc = np.mean(np.argmax(nn.forward(X), axis=1) == np.argmax(Y, axis=1))
if np is not numpy:  # np is cupy: bring the scalar back to host memory
    acc = np.asnumpy(acc)
print('model test accuracy is: {:0.4f}'.format(acc))

# permute data order for demonstration. or not. your choice.
I = np.arange(X.shape[0])
#I = np.random.permutation(I)
# predict and perform LRP for the 10 first samples
for i in I[:10]:
    x = X[i:i+1, ...]

    # forward pass and prediction
    ypred = nn.forward(x)
    print('True Class: ', np.argmax(Y[i]))
    print('Predicted Class:', np.argmax(ypred), '\n')

    # prepare initial relevance to reflect the model's dominant prediction
    # (ie depopulate non-dominant output neurons)
    mask = np.zeros_like(ypred)
    mask[:, np.argmax(ypred)] = 1
    Rinit = ypred * mask

    # compute first layer relevance according to prediction
    #R = nn.lrp(Rinit)                  #as Eq(56) from DOI: 10.1371/journal.pone.0130140
    R = nn.lrp(Rinit, 'epsilon', 1.)    #as Eq(58) from DOI: 10.1371/journal.pone.0130140
    #R = nn.lrp(Rinit,'alphabeta',2)    #as Eq(60) from DOI: 10.1371/journal.pone.0130140

    #R = nn.lrp(ypred*Y[na,i],'epsilon',1.) #compute first layer relevance according to the true class label

    '''
    #compute first layer relevance for an arbitrarily selected class
    for yselect in range(10):
        yselect = (np.arange(Y.shape[1])[na,:] == yselect)*1.
        R = nn.lrp(ypred*yselect,'epsilon',0.1)
    '''

    '''
    # you may also specify different decompositions for each layer, e.g. as below:
    # first, set all layers (by calling set_lrp_parameters on the container module
    # of class Sequential) to perform alpha-beta decomposition with alpha = 1.
    # this causes the resulting relevance map to display excitation potential for the prediction
    #
    nn.set_lrp_parameters('alpha',1.)
    #
    # set the first layer (a convolutional layer) decomposition variant to 'w^2'. This may be especially
    # useful if input values are ranged [0 V], with 0 being a frequent occurrence, but one still wishes to know about
    # the relevance feedback propagated to the pixels below the filter
    # the result will display relevance in important areas despite zero input activation energy.
    #
    nn.modules[0].set_lrp_parameters('ww') # also try 'flat'
    # compute the relevance map
    R = nn.lrp(Rinit)
    '''

    # sum over the third (color channel) axis. not necessary here,
    # but for color images it would be.
    R = R.sum(axis=3)
    # same for input. create brightness image in [0,1].
    xs = ((x + 1.) / 2.).sum(axis=3)

    if np is not numpy:  # np is cupy: move arrays to host memory for rendering
        xs = np.asnumpy(xs)
        R = np.asnumpy(R)

    # render input and heatmap as rgb images
    digit = render.digit_to_rgb(xs, scaling=3)
    hm = render.hm_to_rgb(R, X=xs, scaling=3, sigma=2)
    digit_hm = render.save_image([digit, hm], '../heatmap.png')
    data_io.write(R, '../heatmap.npy')

    # display the image as written to file
    plt.imshow(digit_hm, interpolation='none')
    plt.axis('off')
    plt.show()
# note that modules.Sequential allows for batch processing inputs;
# flip the condition below to skip this timing demo.
if True:
    N = 256
    t_start = time.time()
    x = X[:N, ...]
    y = nn.forward(x)
    R = nn.lrp(y)  # relevance w.r.t. the full (masked-free) output scores
    data_io.write(R, '../Rbatch.npy')
    print('Computation of {} heatmaps using {} in {:.3f}s'.format(N, np.__name__, time.time() - t_start))