In [1]:
from keras.models import model_from_json
from keras.utils import plot_model

import bokeh
from bokeh.plotting import figure, output_file, show

from plot_utils import make_html_plot

import lime
from lime import lime_tabular
import matplotlib.pyplot as plt
import numpy as np
import os
from skimage.segmentation import mark_boundaries
import wfdb

import seaborn as sns
sns.set_style('ticks')

Using TensorFlow backend.


In [2]:
MODEL_LOCATION = 'assets/model.json'
WEIGHTS_LOCATION = 'assets/final_weights.h5'

with open(MODEL_LOCATION, 'r') as f:
    model_json = f.read()

model_to_interpret = model_from_json(model_json)
model_to_interpret.load_weights(WEIGHTS_LOCATION)

In [3]:
# test_data_dir = input('Enter test data location: ')
test_data_dir = r'G:\Team Drives\Hacktech2019\mitdb'

In [4]:
rec = wfdb.rdrecord(os.path.join(test_data_dir, '107'))

In [5]:
def make_binary_pred(input_data):
    pred = model_to_interpret.predict(input_data)
    return np.array([1 - pred, pred]).reshape(-1, 2)

In [6]:
explainer = lime_tabular.RecurrentTabularExplainer(rec.p_signal.reshape(-1, 1300, 2),
                                                   feature_names=[str(x) for x in range(320)],
                                                   class_names=['No Arrhythmic Event', 'Arrhythmic Event'])

In [7]:
sample = rec.p_signal[9000:10300].reshape(-1, 1300, 2)

In [8]:
%%time
explanation = explainer.explain_instance(sample, make_binary_pred,
                                         num_features=320,
                                         labels=[0, 1],
                                         num_samples=1000)

                    Prediction probabilties do not sum to 1, and
                    thus does not constitute a probability space.
                    Check that you classifier outputs probabilities
                    (Not log probabilities, or actual class predictions).
                    
  """)


Wall time: 38.4 s




In [9]:
# len(explanation.local_exp[0])

In [10]:
# explanation.as_pyplot_figure()

In [11]:
explanation.local_exp

{0: [(688, -0.006819010628193731),
  (124, -0.006748709912158557),
  (104, 0.0066291513040340625),
  (728, -0.006514836362053261),
  (496, 0.00644792914288511),
  (169, -0.006332290037765078),
  (442, 0.006215383059631582),
  (657, 0.006045851408624754),
  (253, 0.006039254733730956),
  (369, 0.005959588483451983),
  (326, 0.0059076795107823816),
  (453, 0.005662234059886598),
  (541, -0.005591547022935768),
  (493, -0.005550075499942133),
  (420, 0.005518610417034518),
  (505, 0.005493991078110751),
  (898, -0.005407045979703103),
  (460, -0.005339611438089483),
  (218, 0.005323319717383524),
  (556, -0.005270638645710293),
  (128, -0.005239929202750952),
  (213, -0.005208538688268057),
  (767, 0.00515417185328065),
  (149, -0.0051090414485147045),
  (466, -0.005074897525588657),
  (260, -0.0050159709227493585),
  (926, -0.004942454597749759),
  (753, 0.004940723072211716),
  (625, 0.004691740723965483),
  (473, -0.004654626963483885),
  (146, -0.004572088075983863),
  (583, 0.0045692

In [12]:
a = [np.array(explan, dtype=np.dtype('int, float')) for explan in explanation.local_exp.values()]

In [70]:
import bokeh
from bokeh.plotting import figure, output_file
from bokeh.models import ColorBar
from bokeh.palettes import PiYG
from bokeh.resources import CDN
from bokeh.embed import file_html
from bokeh.transform import linear_cmap

import matplotlib.pyplot as plt
import numpy as np
import wfdb

def make_html_plot(signal: np.ndarray, explanations=None, title: str='', channel: int=0):    
    if explanations is not None:
        indices = np.array([elem[0] for elem in explanations[channel]], dtype=np.int)
        importances = np.array([elem[1] for elem in explanations[channel]], dtype=np.float)
        importances /= 0.1 * np.max(importances)

    p = figure()
    p.line(list(range(len(signal))), signal[:, channel])
    # p.circle(list(range(len(signal))), signal[:, 0])
    
    if explanations is not None:
        p.scatter(indices, signal[:, channel][indices], color='red', radius=importances, alpha=0.7)

    return file_html(p, CDN, title)


In [72]:
with open('pp2channel1.html', 'w') as f:
    f.write(make_html_plot(sample.reshape(1300, 2), explanation.local_exp, 'channel 1', 0))
with open('pp2channel2.html', 'w') as f:
    f.write(make_html_plot(sample.reshape(1300, 2), explanation.local_exp, 'channel 2', 1))