In [None]:
# Reload `logging` before configuring it (workaround from
# https://stackoverflow.com/a/21475297/1469195), then route INFO-level
# records to stdout with a timestamped format.
import importlib
import logging
import sys

importlib.reload(logging)

log = logging.getLogger()
log.setLevel('INFO')

logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format='%(asctime)s %(levelname)s : %(message)s',
)

In [None]:
%%capture
import os
import site

# Prepend the local project checkouts to the module search path. Every entry
# is inserted at index 0 in turn, so the last path listed ends up first.
# NOTE(review): these are absolute, user-specific paths.
for _checkout in (
    '/home/schirrmr/code/reversible/',
    '/home/schirrmr/code/auto-diagnosis//',
    '/home/schirrmr/braindecode/code/braindecode/',
    '/home/schirrmr/code/explaining/reversible//',
):
    os.sys.path.insert(0, _checkout)


# Reload edited modules automatically before code is executed.
%load_ext autoreload
%autoreload 2
import numpy as np
import logging
# Same root-logger setup as in the first cell. logging.basicConfig does
# nothing when the root logger already has handlers, so repeating it is harmless.
log = logging.getLogger()
log.setLevel('INFO')
import sys
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                     level=logging.INFO, stream=sys.stdout)
import matplotlib
from matplotlib import pyplot as plt
from matplotlib import cm
# Inline PNG figures; the default figure size is a wide 12 x 1 strip and
# cells that need another size pass their own figsize.
%matplotlib inline
%config InlineBackend.figure_format = 'png'
matplotlib.rcParams['figure.figsize'] = (12.0, 1.0)
matplotlib.rcParams['font.size'] = 14
import seaborn
seaborn.set_style('darkgrid')

from reversible2.sliced import sliced_from_samples
from numpy.random import RandomState

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import copy
import math

import itertools
import torch as th
from braindecode.torch_ext.util import np_to_var, var_to_np
from reversible2.splitter import SubsampleSplitter

from reversible2.view_as import ViewAs

from reversible2.affine import AdditiveBlock
from reversible2.plot import display_text, display_close
from reversible2.bhno import load_file, create_inputs
# Let cuDNN benchmark its convolution algorithms and pick the fastest one.
th.backends.cudnn.benchmark = True

In [None]:
from autodiag.dataset import DiagnosisSet

In [None]:
import resampy
sampling_freq = 100  # target sampling rate handed to resampy
sec_to_cut = 60  # seconds removed from the start AND the end of each recording
duration_recording_mins = 0.05  # NOTE(review): not read anywhere in this cell
max_recording_mins = None
n_recordings = 440
max_abs_val = 800  # clipping bound; None skips the clipping step
divisor = 20  # data is divided by this value; None skips the scaling step

cuda = True


# Every preprocessing step takes (data, fs) and returns (data, fs), with the
# time axis on axis 1. The steps read the settings above when they are called.

def _cut_edges(data, fs):
    """Remove `sec_to_cut` seconds from both ends of the recording."""
    n_cut = int(sec_to_cut * fs)
    # Explicit stop index: `data[:, n_cut:-n_cut]` would silently return an
    # empty array when n_cut == 0 (because -0 == 0).
    stop = max(n_cut, data.shape[1] - n_cut)
    return data[:, n_cut:stop], fs


def _keep_first_samples(data, fs):
    """Keep the first round(256.25 * fs / 100) samples, i.e. about 2.56 s."""
    return data[:, :int(np.round(256.25 * (fs/100)))], fs


def _clip(data, fs):
    """Clip values to [-max_abs_val, max_abs_val]."""
    return np.clip(data, -max_abs_val, max_abs_val), fs


def _resample(data, fs):
    """Resample along axis 1 from `fs` to `sampling_freq`."""
    resampled = resampy.resample(data, fs, sampling_freq, axis=1,
                                 filter='kaiser_fast')
    return resampled, sampling_freq


def _scale(data, fs):
    """Divide the data by `divisor`."""
    return data / divisor, fs


# The order matters: trim, crop, (clip), resample, (scale).
preproc_functions = [_cut_edges, _keep_first_samples]
if max_abs_val is not None:
    preproc_functions.append(_clip)
preproc_functions.append(_resample)
if divisor is not None:
    preproc_functions.append(_scale)

# NOTE(review): data_folders is defined but not passed to DiagnosisSet below;
# confirm where the dataset takes its folders from.
data_folders = ['/data/schirrmr//gemeinl/tuh-abnormal-eeg/raw/v2.0.0/edf/train/normal/',
                '/data/schirrmr//gemeinl/tuh-abnormal-eeg/raw/v2.0.0/edf/train/abnormal/']
sensor_types = ["EEG"]
dataset = DiagnosisSet(n_recordings=n_recordings,
                       max_recording_mins=max_recording_mins,
                       preproc_functions=preproc_functions,
                       train_or_eval='train',
                       sensor_types=sensor_types)

In [None]:
# Load recordings and labels from the DiagnosisSet configured above.
# NOTE(review): later cells index X with `y == 0`, which presumably requires
# y to be a NumPy array -- confirm what load() returns.
X,y = dataset.load()

In [None]:
# Convert the loaded recordings into one NumPy array.
# NOTE(review): the next cell reduces over axes (0, 2), so this presumably
# yields a 3-D array, which needs equally shaped recordings -- confirm.
X = np.array(X)

In [None]:
# Standardize every index along axis 1 with statistics pooled over axes 0 and 2.
# NOTE(review): the statistics include the recordings that are later held out
# as test inputs.
axis1_mean = X.mean(axis=(0,2), keepdims=True)
axis1_std = X.std(axis=(0,2), keepdims=True)
X = (X - axis1_mean) / axis1_std

In [None]:
# Split recordings and labels by class label (0 -> *_a, 1 -> *_b).
is_label_0 = (y == 0)
is_label_1 = (y == 1)
X_a, y_a = X[is_label_0], y[is_label_0]
X_b, y_b = X[is_label_1], y[is_label_1]

In [None]:
# One float32 tensor per class with a trailing singleton axis appended.
# The last 40 examples of each class are held out as test inputs.
n_test_per_class = 40
inputs_per_class = [np_to_var(X_cls[:,:,:,None], dtype=np.float32)
                    for X_cls in (X_a, X_b)]
test_inputs = [class_inputs[-n_test_per_class:] for class_inputs in inputs_per_class]
train_inputs = [class_inputs[:-n_test_per_class] for class_inputs in inputs_per_class]

In [None]:
cuda = True
if cuda:
    # Move the per-class tensors of both splits to the GPU.
    train_inputs = [class_inputs.cuda() for class_inputs in train_inputs]
    test_inputs = [class_inputs.cuda() for class_inputs in test_inputs]

In [None]:
from reversible2.graph import Node
from reversible2.branching import CatChans, ChunkChans, Select
from reversible2.constantmemory import sequential_to_constant_memory
from reversible2.constantmemory import graph_to_constant_memory
def invert(feature_model, out):
    """Return `feature_model.invert(out)`; thin functional wrapper around the method."""
    reconstruction = feature_model.invert(out)
    return reconstruction

from copy import deepcopy
from reversible2.graph import Node
from reversible2.distribution import TwoClassDist
from reversible2.wrap_invertible import WrapInvertible
from reversible2.blocks import dense_add_no_switch, conv_add_3x3_no_switch
from reversible2.rfft import RFFT, Interleave
from reversible2.util import set_random_seeds
from torch.nn import ConstantPad2d
import torch as th
from reversible2.splitter import SubsampleSplitter

# Build the invertible feature model as a branching graph.
# Seeding happens before any module is constructed, so the construction order
# below presumably determines the initial weights -- do not reorder.
# NOTE(review): most `.cuda()` calls in this cell are unconditional although a
# `cuda` flag exists; only the final move of feature_model is guarded by it.
set_random_seeds(2019011641, cuda)
# Sizes of axis 1 and axis 2 of the class-0 training tensor.
n_chans = train_inputs[0].shape[1]
n_time = train_inputs[0].shape[2]
# Shared trunk: three SubsampleSplitter stages, each followed by two conv blocks.
base_model = nn.Sequential(
    SubsampleSplitter(stride=[2,1],chunk_chans_first=False),
    conv_add_3x3_no_switch(2*n_chans,32),
    conv_add_3x3_no_switch(2*n_chans,32),
    SubsampleSplitter(stride=[2,1],chunk_chans_first=True), # 4 x 128
    conv_add_3x3_no_switch(4*n_chans,32),
    conv_add_3x3_no_switch(4*n_chans,32),
    SubsampleSplitter(stride=[2,1],chunk_chans_first=True), # 8 x 64
    conv_add_3x3_no_switch(8*n_chans,32),
    conv_add_3x3_no_switch(8*n_chans,32))
base_model.cuda();

# First-level branches; each one receives one half of the trunk output
# (see ChunkChans(2) / Select in the graph wiring below).
branch_1_a =  nn.Sequential(
    SubsampleSplitter(stride=[2,1],chunk_chans_first=False), # 8 x 32
    conv_add_3x3_no_switch(8*n_chans,32),
    conv_add_3x3_no_switch(8*n_chans,32),
    SubsampleSplitter(stride=[2,1],chunk_chans_first=True),# 16 x 16
    conv_add_3x3_no_switch(16*n_chans,32),
    conv_add_3x3_no_switch(16*n_chans,32),
    SubsampleSplitter(stride=[2,1],chunk_chans_first=True), # 32 x 8
    conv_add_3x3_no_switch(32*n_chans,32),
    conv_add_3x3_no_switch(32*n_chans,32),
)
# branch_1_b = independent deep copy of branch_1_a's layers, then a flatten to
# (n_time // 2) * n_chans features and four dense blocks.
branch_1_b = nn.Sequential(
    *(list(deepcopy(branch_1_a).children()) + [
    ViewAs((-1, 32*n_chans,n_time//64,1), (-1,(n_time // 2)*n_chans)),
    dense_add_no_switch((n_time // 2)*n_chans,32),
    dense_add_no_switch((n_time // 2)*n_chans,32),
    dense_add_no_switch((n_time // 2)*n_chans,32),
    dense_add_no_switch((n_time // 2)*n_chans,32),
]))
branch_1_a.cuda();
branch_1_b.cuda();

# Second-level branches; each ends in a flat vector of (n_time // 4) * n_chans features.
branch_2_a = nn.Sequential(
    SubsampleSplitter(stride=[2,1], chunk_chans_first=False),# 32 x 4
    conv_add_3x3_no_switch(32*n_chans,32),
    conv_add_3x3_no_switch(32*n_chans,32),
    SubsampleSplitter(stride=[2,1],chunk_chans_first=True),# 64 x 2
    conv_add_3x3_no_switch(64*n_chans,32),
    conv_add_3x3_no_switch(64*n_chans,32),
    ViewAs((-1, (n_time // 4)*n_chans,1,1), (-1,(n_time // 4)*n_chans)),
    dense_add_no_switch((n_time // 4)*n_chans,64),
    dense_add_no_switch((n_time // 4)*n_chans,64),
    dense_add_no_switch((n_time // 4)*n_chans,64),
    dense_add_no_switch((n_time // 4)*n_chans,64),
)


# The copy is already moved to the GPU here; the second .cuda() call on
# branch_2_b below is redundant.
branch_2_b = deepcopy(branch_2_a).cuda()
branch_2_a.cuda();
branch_2_b.cuda();

# Head on the concatenated branch outputs: n_time/2 + n_time/4 + n_time/4
# = n_time features per channel, matching the n_time*n_chans dense size.
final_model = nn.Sequential(
    dense_add_no_switch(n_time*n_chans,256),
    dense_add_no_switch(n_time*n_chans,256),
    dense_add_no_switch(n_time*n_chans,256),
    dense_add_no_switch(n_time*n_chans,256),
    RFFT(),
)
final_model.cuda();
# Graph wiring:
#   trunk -> chunk in 2 -> (half 0 -> branch_1_a, half 1 -> branch_1_b)
#   branch_1_a -> chunk in 2 -> (half 0 -> branch_2_a, half 1 -> branch_2_b)
#   [branch_1_b, branch_2_a, branch_2_b] -> CatChans -> final_model
o = Node(None, base_model)
o = Node(o, ChunkChans(2))
o1a = Node(o, Select(0))
o1b = Node(o, Select(1))
o1a = Node(o1a, branch_1_a)
o1b = Node(o1b, branch_1_b)
o2 = Node(o1a, ChunkChans(2))
o2a = Node(o2, Select(0))
o2b = Node(o2, Select(1))
o2a = Node(o2a, branch_2_a)
o2b = Node(o2b, branch_2_b)
o = Node([o1b,o2a,o2b], CatChans())
o = Node(o, final_model)
o = graph_to_constant_memory(o)
feature_model = o
if cuda:
    feature_model.cuda()
# The model is put in eval mode here and this cell never switches it back.
feature_model.eval();

In [None]:
# Initialise the model from data, verify invertibility, set up the class
# distribution from the model outputs, and create both optimizers.
from reversible2.constantmemory import clear_ctx_dicts
from reversible2.distribution import TwoClassDist

# data_init receives the training inputs of both classes concatenated along dim 0.
feature_model.data_init(th.cat((train_inputs[0], train_inputs[1]), dim=0))

# Check that forward + inverse is really identical
t_out = feature_model(train_inputs[0][:2])
inverted = invert(feature_model, t_out)
# Context dicts are cleared after every forward/invert use in this notebook.
clear_ctx_dicts(feature_model)
assert th.allclose(train_inputs[0][:2], inverted, rtol=1e-3,atol=1e-4)
# NOTE(review): `device` is not used again in this cell.
device = list(feature_model.parameters())[0].device
from reversible2.ot_exact import ot_euclidean_loss_for_samples
# Arguments: 2, the flattened per-example size minus 2, and dims [0, 1] as
# class dims. Presumably (n class dims, n non-class dims) -- confirm against
# TwoClassDist.
# NOTE(review): moved to the GPU unconditionally, regardless of the `cuda` flag.
class_dist = TwoClassDist(2, np.prod(train_inputs[0].size()[1:]) - 2, i_class_inds=[0,1])
class_dist.cuda()

# Start each class distribution at the mean/std of that class's model outputs.
for i_class in range(2):
    with th.no_grad():
        this_outs = feature_model(train_inputs[i_class])
        mean = th.mean(this_outs, dim=0)
        std = th.std(this_outs, dim=0)
        class_dist.set_mean_std(i_class, mean, std)
        # Just check
        setted_mean, setted_std = class_dist.get_mean_std(i_class)
        assert th.allclose(mean, setted_mean)
        assert th.allclose(std, setted_std)
clear_ctx_dicts(feature_model)

# Separate Adam optimizers; the distribution gets a 10x larger learning rate.
optim_model = th.optim.Adam(feature_model.parameters(), lr=1e-3, betas=(0.9,0.999))
optim_dist = th.optim.Adam(class_dist.parameters(), lr=1e-2, betas=(0.9,0.999))

In [None]:
%%writefile plot.py
import torch as th
import matplotlib.pyplot as plt
import numpy as np
from reversible2.util import var_to_np
from reversible2.plot import display_close
from matplotlib.patches import Ellipse
import seaborn

def plot_outs(feature_model, train_inputs, test_inputs, class_dist,
              class_names=("Right", "Rest")):
    """Scatter-plot the class dimensions of the model outputs with accuracies.

    For the train and the test set, the outputs of `feature_model` are
    restricted to `class_dist.i_class_inds`, classified by the larger Gaussian
    log-probability and plotted together with sigma ellipses of `class_dist`.

    Parameters
    ----------
    feature_model : callable
        Maps an input tensor to a 2-D output tensor (examples x features).
    train_inputs, test_inputs : list of tensors
        One tensor per class, in the same class order.
    class_dist : object
        Must provide `i_class_inds` and `get_mean_std(i_class)` returning
        `(mean, std)` tensors over all output dimensions.
    class_names : sequence of str, optional
        Legend label per class. The default keeps the previously hard-coded
        labels; pass the real class names for other datasets.

    Two accuracies are shown in the title: one using the mean/std stored in
    `class_dist`, and one using the mean/std of the training outputs
    ("From data mean/std"). The first one is also printed. The scatter plot
    uses the first two class dimensions.
    """
    with th.no_grad():
        # Compute dist for mean/std of encodings
        data_cls_dists = []
        for i_class in range(len(train_inputs)):
            this_class_outs = feature_model(train_inputs[i_class])[:,class_dist.i_class_inds]
            data_cls_dists.append(
                th.distributions.MultivariateNormal(th.mean(this_class_outs, dim=0),
                covariance_matrix=th.diag(th.std(this_class_outs, dim=0) ** 2)))
        for setname, set_inputs in (("Train", train_inputs), ("Test", test_inputs)):

            outs = [feature_model(ins) for ins in set_inputs]
            c_outs = [o[:,class_dist.i_class_inds] for o in outs]

            # Gaussians defined by the parameters stored in class_dist.
            cls_dists = []
            for i_class in range(len(c_outs)):
                mean, std = class_dist.get_mean_std(i_class)
                cls_dists.append(
                    th.distributions.MultivariateNormal(mean[class_dist.i_class_inds],
                                                        covariance_matrix=th.diag(std[class_dist.i_class_inds] ** 2)))

            # Log-probability of every example under every class Gaussian;
            # the predicted label is the argmax over classes.
            preds_per_class = [th.stack([cls_dists[i_cls].log_prob(c_out)
                             for i_cls in range(len(cls_dists))],
                            dim=-1) for c_out in c_outs]

            pred_labels_per_class = [np.argmax(var_to_np(preds), axis=1)
                           for preds in preds_per_class]

            # True labels in the same (class-by-class) order as the predictions.
            labels = np.concatenate([np.ones(len(set_inputs[i_cls])) * i_cls
             for i_cls in range(len(set_inputs))])

            acc = np.mean(labels == np.concatenate(pred_labels_per_class))

            # Same classification, but with Gaussians fitted to the training outputs.
            data_preds_per_class = [th.stack([data_cls_dists[i_cls].log_prob(c_out)
                             for i_cls in range(len(data_cls_dists))],
                            dim=-1) for c_out in c_outs]
            data_pred_labels_per_class = [np.argmax(var_to_np(data_preds), axis=1)
                                for data_preds in data_preds_per_class]
            data_acc = np.mean(labels == np.concatenate(data_pred_labels_per_class))

            print("{:s} Accuracy: {:.1f}%".format(setname, acc * 100))
            fig = plt.figure(figsize=(5,5))
            ax = plt.gca()
            for i_class in range(len(c_outs)):
                o = var_to_np(c_outs[i_class]).squeeze()
                incorrect_pred_mask = pred_labels_per_class[i_class] != i_class
                plt.scatter(o[:,0], o[:,1], s=20, alpha=0.75, label=class_names[i_class])
                assert len(incorrect_pred_mask) == len(o)
                # Mark misclassified examples with a black cross.
                plt.scatter(o[incorrect_pred_mask,0], o[incorrect_pred_mask,1], marker='x', color='black',
                           alpha=1, s=5)
                means, stds = class_dist.get_mean_std(i_class)
                means = var_to_np(means)[class_dist.i_class_inds]
                stds = var_to_np(stds)[class_dist.i_class_inds]
                for sigma in [0.5,1,2,3]:
                    # Ellipse expects the FULL width/height (diameters), so a
                    # contour at `sigma` standard deviations spans 2 * sigma * std.
                    ellipse = Ellipse(means, 2 * stds[0] * sigma, 2 * stds[1] * sigma)
                    ax.add_artist(ellipse)
                    ellipse.set_edgecolor(seaborn.color_palette()[i_class])
                    ellipse.set_facecolor("None")
            # Empirical mean of each class's outputs.
            for i_class in range(len(c_outs)):
                o = var_to_np(c_outs[i_class]).squeeze()
                plt.scatter(np.mean(o[:,0]), np.mean(o[:,1]),
                           color=seaborn.color_palette()[i_class+2], s=80, marker="^",
                           label="{:s} Mean".format(class_names[i_class]))

            plt.title("{:6s} Accuracy:        {:.1f}%\n"
                      "From data mean/std: {:.1f}%".format(setname, acc * 100, data_acc * 100))
            plt.legend(bbox_to_anchor=(1,1,0,0))
            display_close(fig)
    return

In [None]:
import pandas as pd
from reversible2.training import OTTrainer

# Empty frame that the training cells extend with one row per epoch.
# Re-running this cell discards the rows collected so far.
df = pd.DataFrame()

# The trainer receives the invertible model, the class distribution and one
# optimizer for each of them.
trainer = OTTrainer(feature_model, class_dist, optim_model, optim_dist)

In [None]:
from reversible2.constantmemory import clear_ctx_dicts
from reversible2.timer import Timer
from plot import plot_outs
from reversible2.gradient_penalty import gradient_penalty


# Training phase. `loss_on_outs` becomes True from epoch `i_start_epoch_out`
# on; with 2001 > n_epochs it stays False for this whole cell.
# Epoch counting restarts at 0 in every training cell, while rows keep being
# appended to the shared `df`.
i_start_epoch_out = 2001
n_epochs = 1001
gen_frequency = 10
# Report (OT loss on inputs, table, plots) 20 times per run; max(1, ...)
# avoids a modulo-by-zero when n_epochs < 20.
report_every = max(1, n_epochs // 20)
for i_epoch in range(n_epochs):
    epoch_row = {}
    with Timer(name='EpochLoop', verbose=False) as loop_time:
        # NOTE(review): gen_update is computed but never passed to the trainer.
        gen_update = (i_epoch % gen_frequency) == (gen_frequency-1)
        loss_on_outs = i_epoch >= i_start_epoch_out
        result = trainer.train(train_inputs, loss_on_outs=loss_on_outs)

    epoch_row.update(result)
    epoch_row['runtime'] = loop_time.elapsed_secs * 1000
    is_report_epoch = (i_epoch % report_every) == 0
    if is_report_epoch:
        # Also record the OT loss in input space between the real inputs and
        # inverted samples of each class distribution.
        for i_class in range(len(train_inputs)):
            with th.no_grad():
                class_ins = train_inputs[i_class]
                # Four times as many samples are drawn and inverted as are
                # compared; only the first len(class_ins) enter the loss.
                samples = class_dist.get_samples(i_class, len(train_inputs[i_class]) * 4)
                inverted = feature_model.invert(samples)
                clear_ctx_dicts(feature_model)
                ot_loss_in = ot_euclidean_loss_for_samples(class_ins.view(class_ins.shape[0], -1),
                                                           inverted.view(inverted.shape[0], -1)[:(len(class_ins))])
                epoch_row['ot_loss_in_{:d}'.format(i_class)] = ot_loss_in.item()
    # DataFrame.append was removed in pandas 2.0; pd.concat works on old and
    # new pandas versions alike.
    df = pd.concat([df, pd.DataFrame([epoch_row])], ignore_index=True)
    if is_report_epoch:
        print("Epoch {:d} of {:d}".format(i_epoch, n_epochs))
        print("Loop Time: {:.0f} ms".format(loop_time.elapsed_secs * 1000))
        display(df.iloc[-3:])
        plot_outs(feature_model, train_inputs, test_inputs,
                 class_dist)
        # Current standard deviations of the class and non-class dimensions.
        fig = plt.figure(figsize=(8,2))
        plt.plot(var_to_np(th.cat((th.exp(class_dist.class_log_stds),
                                 th.exp(class_dist.non_class_log_stds)))),
                marker='o')
        display_close(fig)
    


In [None]:

# Names of the 21 channels, passed as `sensor_names` to the head plot.
ch_names = [
    'A1', 'A2', 'C3', 'C4', 'CZ', 'F3', 'F4',
    'F7', 'F8', 'FP1', 'FP2', 'FZ', 'O1', 'O2',
    'P3', 'P4', 'PZ', 'T3', 'T4', 'T5', 'T6',
]

In [None]:
# 5 x 9 grid of sensor names passed as `sensor_map` to the head plot;
# presumably '' marks an unused grid position.
# NOTE(review): this map spells 'Fz', 'Cz', 'Pz' while ch_names uses
# 'FZ', 'CZ', 'PZ' -- confirm that the plotting helper matches names
# case-insensitively.
sensor_map = [
    ['', '', '', 'FP1', '', 'FP2', '', '', ''],
    ['F7', '', 'F3', '', 'Fz', '', 'F4', '', 'F8'],
    # ['', '', 'FC3', 'FC1', 'FCz', 'FC2', 'FC4', '', ''],
    ['A1', 'T3', 'C3', 'C1', 'Cz', 'C2', 'C4', 'T4', 'A2'],
    # ['', '', 'CP3', 'CP1', 'CPz', 'CP2', 'CP4', '', ''],
    ['', 'T5', 'P3', 'P1', 'Pz', 'P2', 'P4', 'T6', ''],
    ['', '', '', 'O1', '', 'O2', '', '', ''],
]

In [None]:
from reversible2.plot import plot_head_signals_tight
from reversible2.constantmemory import clear_ctx_dicts

# Invert the mean of each class distribution back to input space.
inverted_per_class = []
for i_class in range(2):
    # Mirrors the other invert call sites in this notebook: no gradients are
    # needed for a visualisation, and the context dicts are cleared after
    # every inversion.
    with th.no_grad():
        # get_mean_std returns (mean, std); the mean gets a leading batch axis.
        samples = class_dist.get_mean_std(i_class)[0].unsqueeze(0)
        inverted = feature_model.invert(samples)
        clear_ctx_dicts(feature_model)
    inverted_per_class.append(var_to_np(inverted)[0].squeeze())

In [None]:
# Stack the inverted class means along a new last axis and plot them per
# sensor, arranged according to sensor_map.
stacked_class_means = np.stack(inverted_per_class, axis=-1)
fig = plot_head_signals_tight(
    stacked_class_means,
    sensor_names=ch_names,
    sensor_map=sensor_map,
    figsize=(20,12),
);

In [None]:
from reversible2.constantmemory import clear_ctx_dicts
from reversible2.timer import Timer
from plot import plot_outs
from reversible2.gradient_penalty import gradient_penalty


# Training phase. `loss_on_outs` is False for epochs 0-800 of this cell and
# True from epoch 801 on.
# Epoch counting restarts at 0 in every training cell, while rows keep being
# appended to the shared `df`.
i_start_epoch_out = 801
n_epochs = 2001
gen_frequency = 10
# Report (OT loss on inputs, table, plots) 20 times per run; max(1, ...)
# avoids a modulo-by-zero when n_epochs < 20.
report_every = max(1, n_epochs // 20)
for i_epoch in range(n_epochs):
    epoch_row = {}
    with Timer(name='EpochLoop', verbose=False) as loop_time:
        # NOTE(review): gen_update is computed but never passed to the trainer.
        gen_update = (i_epoch % gen_frequency) == (gen_frequency-1)
        loss_on_outs = i_epoch >= i_start_epoch_out
        result = trainer.train(train_inputs, loss_on_outs=loss_on_outs)

    epoch_row.update(result)
    epoch_row['runtime'] = loop_time.elapsed_secs * 1000
    is_report_epoch = (i_epoch % report_every) == 0
    if is_report_epoch:
        # Also record the OT loss in input space between the real inputs and
        # inverted samples of each class distribution.
        for i_class in range(len(train_inputs)):
            with th.no_grad():
                class_ins = train_inputs[i_class]
                # Four times as many samples are drawn and inverted as are
                # compared; only the first len(class_ins) enter the loss.
                samples = class_dist.get_samples(i_class, len(train_inputs[i_class]) * 4)
                inverted = feature_model.invert(samples)
                clear_ctx_dicts(feature_model)
                ot_loss_in = ot_euclidean_loss_for_samples(class_ins.view(class_ins.shape[0], -1),
                                                           inverted.view(inverted.shape[0], -1)[:(len(class_ins))])
                epoch_row['ot_loss_in_{:d}'.format(i_class)] = ot_loss_in.item()
    # DataFrame.append was removed in pandas 2.0; pd.concat works on old and
    # new pandas versions alike.
    df = pd.concat([df, pd.DataFrame([epoch_row])], ignore_index=True)
    if is_report_epoch:
        print("Epoch {:d} of {:d}".format(i_epoch, n_epochs))
        print("Loop Time: {:.0f} ms".format(loop_time.elapsed_secs * 1000))
        display(df.iloc[-3:])
        plot_outs(feature_model, train_inputs, test_inputs,
                 class_dist)
        # Current standard deviations of the class and non-class dimensions.
        fig = plt.figure(figsize=(8,2))
        plt.plot(var_to_np(th.cat((th.exp(class_dist.class_log_stds),
                                 th.exp(class_dist.non_class_log_stds)))),
                marker='o')
        display_close(fig)
    


In [None]:
# NOTE(review): leftover scratch cell; it only echoes the literal 1.
1

In [None]:
from reversible2.plot import plot_head_signals_tight
from reversible2.constantmemory import clear_ctx_dicts

# Invert the mean of each class distribution back to input space
# (repeated after the additional training above).
inverted_per_class = []
for i_class in range(2):
    # Mirrors the other invert call sites in this notebook: no gradients are
    # needed for a visualisation, and the context dicts are cleared after
    # every inversion.
    with th.no_grad():
        # get_mean_std returns (mean, std); the mean gets a leading batch axis.
        samples = class_dist.get_mean_std(i_class)[0].unsqueeze(0)
        inverted = feature_model.invert(samples)
        clear_ctx_dicts(feature_model)
    inverted_per_class.append(var_to_np(inverted)[0].squeeze())

In [None]:
# Stack the inverted class means along a new last axis and plot them per
# sensor, arranged according to sensor_map.
stacked_class_means = np.stack(inverted_per_class, axis=-1)
fig = plot_head_signals_tight(
    stacked_class_means,
    sensor_names=ch_names,
    sensor_map=sensor_map,
    figsize=(20,12),
);

In [None]:
from reversible2.constantmemory import clear_ctx_dicts
from reversible2.timer import Timer
from plot import plot_outs
from reversible2.gradient_penalty import gradient_penalty


# Training phase. `loss_on_outs` is False for epochs 0-400 of this cell and
# True from epoch 401 on.
# Epoch counting restarts at 0 in every training cell, while rows keep being
# appended to the shared `df`.
i_start_epoch_out = 401
n_epochs = 1001
gen_frequency = 10
# Report (OT loss on inputs, table, plots) 20 times per run; max(1, ...)
# avoids a modulo-by-zero when n_epochs < 20.
report_every = max(1, n_epochs // 20)
for i_epoch in range(n_epochs):
    epoch_row = {}
    with Timer(name='EpochLoop', verbose=False) as loop_time:
        # NOTE(review): gen_update is computed but never passed to the trainer.
        gen_update = (i_epoch % gen_frequency) == (gen_frequency-1)
        loss_on_outs = i_epoch >= i_start_epoch_out
        result = trainer.train(train_inputs, loss_on_outs=loss_on_outs)

    epoch_row.update(result)
    epoch_row['runtime'] = loop_time.elapsed_secs * 1000
    is_report_epoch = (i_epoch % report_every) == 0
    if is_report_epoch:
        # Also record the OT loss in input space between the real inputs and
        # inverted samples of each class distribution.
        for i_class in range(len(train_inputs)):
            with th.no_grad():
                class_ins = train_inputs[i_class]
                # Four times as many samples are drawn and inverted as are
                # compared; only the first len(class_ins) enter the loss.
                samples = class_dist.get_samples(i_class, len(train_inputs[i_class]) * 4)
                inverted = feature_model.invert(samples)
                clear_ctx_dicts(feature_model)
                ot_loss_in = ot_euclidean_loss_for_samples(class_ins.view(class_ins.shape[0], -1),
                                                           inverted.view(inverted.shape[0], -1)[:(len(class_ins))])
                epoch_row['ot_loss_in_{:d}'.format(i_class)] = ot_loss_in.item()
    # DataFrame.append was removed in pandas 2.0; pd.concat works on old and
    # new pandas versions alike.
    df = pd.concat([df, pd.DataFrame([epoch_row])], ignore_index=True)
    if is_report_epoch:
        print("Epoch {:d} of {:d}".format(i_epoch, n_epochs))
        print("Loop Time: {:.0f} ms".format(loop_time.elapsed_secs * 1000))
        display(df.iloc[-3:])
        plot_outs(feature_model, train_inputs, test_inputs,
                 class_dist)
        # Current standard deviations of the class and non-class dimensions.
        fig = plt.figure(figsize=(8,2))
        plt.plot(var_to_np(th.cat((th.exp(class_dist.class_log_stds),
                                 th.exp(class_dist.non_class_log_stds)))),
                marker='o')
        display_close(fig)
    
