In [None]:
import importlib
import logging
import sys

# Reload the logging module so the basicConfig call below is not swallowed by
# handlers the notebook kernel may already have installed
# (see https://stackoverflow.com/a/21475297/1469195).
importlib.reload(logging)

# Root logger at INFO; messages go to stdout with a timestamp + level prefix.
log = logging.getLogger()
log.setLevel('INFO')
logging.basicConfig(
    format='%(asctime)s %(levelname)s : %(message)s',
    level=logging.INFO,
    stream=sys.stdout,
)

In [None]:
%%capture
# Notebook setup cell: make local project checkouts importable, enable autoreload,
# configure logging and plotting, and import the names used by later cells.
# %%capture (must stay the first line) suppresses everything this cell prints/displays.
import os
import site
# NOTE(review): hard-coded, user-specific absolute paths. Inserting at position 0
# presumably makes these local checkouts of reversible2/braindecode take priority
# over any installed versions. Consider a configurable base directory so the
# notebook runs on other machines.
os.sys.path.insert(0, '/home/schirrmr/code/reversible/')
os.sys.path.insert(0, '/home/schirrmr/braindecode/code/braindecode/')
os.sys.path.insert(0, '/home/schirrmr/code/explaining/reversible//')


# Re-import edited project modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2
import numpy as np
# Logging setup repeated from the first cell; logging.basicConfig does nothing once
# the root logger already has handlers, so in a Run-All this is largely redundant.
import logging
log = logging.getLogger()
log.setLevel('INFO')
import sys
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                     level=logging.INFO, stream=sys.stdout)
# Plot defaults: inline PNG figures, wide/short default figure size, 14pt font,
# seaborn "darkgrid" style.
import matplotlib
from matplotlib import pyplot as plt
from matplotlib import cm
%matplotlib inline
%config InlineBackend.figure_format = 'png'
matplotlib.rcParams['figure.figsize'] = (12.0, 1.0)
matplotlib.rcParams['font.size'] = 14
import seaborn
seaborn.set_style('darkgrid')

# NOTE(review): several imports in this cell (e.g. site, cm, sliced_from_samples,
# RandomState, nn, F, Variable, copy, math, itertools, SubsampleSplitter, ViewAs,
# AdditiveBlock, display_text, display_close, load_file, create_inputs,
# deep_invertible) are not used by any cell in this notebook.
from reversible2.sliced import sliced_from_samples
from numpy.random import RandomState

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import copy
import math

import itertools
import torch as th
from braindecode.torch_ext.util import np_to_var, var_to_np
from reversible2.splitter import SubsampleSplitter

from reversible2.view_as import ViewAs

# NOTE(review): later cells call flatten_2d, which is neither imported nor defined
# anywhere in this notebook -- a fresh kernel raises NameError there.
# TODO: import it here (presumably from a reversible2 module).
from reversible2.affine import AdditiveBlock
from reversible2.plot import display_text, display_close
from reversible2.high_gamma import load_file, create_inputs
from reversible2.high_gamma import load_train_test
# Let cuDNN benchmark and select the fastest algorithms for the observed input sizes.
th.backends.cudnn.benchmark = True
from reversible2.models import deep_invertible


In [None]:
# EEG electrode labels used as plot labels, grouped row by row by label prefix
# (Fz, FC*, C*, CP*, P*, POz) -- 22 names in total.
_electrode_rows = (
    ('Fz',),
    ('FC3', 'FC1', 'FCz', 'FC2', 'FC4'),
    ('C5', 'C3', 'C1', 'Cz', 'C2', 'C4', 'C6'),
    ('CP3', 'CP1', 'CPz', 'CP2', 'CP4'),
    ('P1', 'Pz', 'P2'),
    ('POz',),
)
sensor_names = [name for row in _electrode_rows for name in row]

In [None]:
# Load subject 4: train_inputs / test_inputs are later indexed per class
# (presumably one tensor per class). Data is resampled to 256 Hz and cut to
# 500-1500 ms; car=True presumably enables common-average re-referencing.
train_inputs, test_inputs = load_train_test(
    subject_id=4,
    n_sensors=22, only_load_given_sensors=False, car=True,
    final_hz=256, start_ms=500, stop_ms=1500, half_before=True,
)


In [None]:
# Load a second subject (5) with exactly the same preprocessing as subject 4.
# NOTE(review): test_dist_inputs / test_dist_inputs_2 are not used by any later
# cell in this notebook.
test_dist_inputs, test_dist_inputs_2 = load_train_test(
    subject_id=5,
    n_sensors=22, only_load_given_sensors=False, car=True,
    final_hz=256, start_ms=500, stop_ms=1500, half_before=True,
)

In [None]:
from reversible2.mixture import GaussianMixture, TwoClassMixture

In [None]:
def split_into_train_valid(train_inputs):
    """Split each per-class tensor into a first (train) and second (valid) half.

    Rows are split along dim 0 in their existing order (no shuffling). When a
    tensor has an odd number of rows, the train half gets the extra row. This
    matches the chunk sizes of ``th.chunk(t, 2, dim=0)``.

    Parameters
    ----------
    train_inputs : sequence of torch.Tensor
        One tensor per class; dim 0 is split.

    Returns
    -------
    train_inputs : list of torch.Tensor
        First halves, in the same class order as the input.
    valid_inputs : list of torch.Tensor
        Second halves, in the same class order. A tensor with fewer than two
        rows yields an empty (zero-row) validation half instead of raising.
    """
    train_parts = []
    valid_parts = []
    for class_inputs in train_inputs:
        # ceil(n / 2) rows go to train -- the first-chunk size th.chunk would use.
        # Explicit slicing always yields two parts. th.chunk returns a single
        # chunk for n < 2, which made the old (train, valid) unpacking raise.
        n_train = (class_inputs.shape[0] + 1) // 2
        train_parts.append(class_inputs[:n_train])
        valid_parts.append(class_inputs[n_train:])
    return train_parts, valid_parts


In [None]:
# Per-class train/test tensors on the GPU.
tr_ins = [class_ins[:, :].cuda() for class_ins in train_inputs]
te_ins = [class_ins[:, :].cuda() for class_ins in test_inputs]

# One trainable log-std per element of each class's flattened training data,
# starting at 0 (i.e. std 1, assuming std = exp(log_std)).
log_stds = [
    th.zeros_like(flatten_2d(class_ins), requires_grad=True)
    for class_ins in tr_ins
]

# Per-class mixtures over all training rows; TwoClassMixture combines them for
# classification (used via mixture.log_softmax in the training cell).
mixtures = [
    GaussianMixture(flatten_2d(class_ins), class_log_stds)
    for class_ins, class_log_stds in zip(tr_ins, log_stds)
]
mixture = TwoClassMixture(mixtures)

optim_log_stds = th.optim.Adam(log_stds, lr=1e-3)


In [None]:
n_epochs = 10001
rand_noise_factor = 1e-2
# Report ~20 times during training; max(1, ...) avoids modulo-by-zero for n_epochs < 20.
print_every = max(1, n_epochs // 20)
# Fit the per-sample log-stds. Each epoch, every class's training trials are split
# randomly in half. One half (with the matching rows of log_stds) defines a
# GaussianMixture -- presumably one component per row. The other half is scored
# under that mixture, so rows are not scored under their own component.
for i_epoch in range(n_epochs):
    optim_log_stds.zero_grad()
    for i_class in range(2):
        tr_inds, val_inds = th.chunk(th.randperm(len(tr_ins[i_class])), 2)
        this_ins = flatten_2d(tr_ins[i_class][val_inds])
        # Uniform jitter in [-rand_noise_factor / 2, rand_noise_factor / 2).
        this_ins = this_ins + (th.rand_like(this_ins) - 0.5) * rand_noise_factor

        mix = GaussianMixture(flatten_2d(tr_ins[i_class][tr_inds]),
                              log_stds[i_class][tr_inds])
        nll = -th.mean(mix.log_probs(this_ins))
        # Gradients accumulate into the selected rows of log_stds[i_class].
        nll.backward()
        del mix
        del this_ins
    optim_log_stds.step()
    if i_epoch % print_every == 0:
        print("Epoch {:d} of {:d}".format(i_epoch, n_epochs))
        # Evaluation only: no_grad stops these forward passes from building autograd
        # graphs through log_stds. This saves GPU memory; printed values are unchanged.
        with th.no_grad():
            # NOTE(review): mixtures[j] is built from all rows of tr_ins[j], so the
            # i == j entries score training rows under a mixture that contains them.
            for i_class in range(2):
                for j_class in range(2):
                    nll = -th.mean(mixtures[j_class].log_probs(flatten_2d(tr_ins[i_class])))
                    print("NLL {:d}->{:d} {:.1E}".format(i_class, j_class, nll.item(),))

            for setname, ins in (("Train", tr_ins), ("Test", te_ins)):
                corrects = []
                for i_class in range(2):
                    # Predicted class = argmax over the per-class log-softmax scores.
                    corrects.extend(np.argmax(var_to_np(mixture.log_softmax(flatten_2d(ins[i_class]))), axis=1) == i_class)
                acc = np.mean(corrects)
                print("{:6s} Accuracy: {:.1f}".format(setname, acc * 100))

In [None]:
from reversible2.mixture import GaussianMixture

# Rebuild one mixture per class from tr_ins and the current (trained) log_stds.
mixtures = []
for class_ins, class_log_stds in zip(tr_ins, log_stds):
    mixtures.append(GaussianMixture(flatten_2d(class_ins), class_log_stds))

In [None]:
# Compare the mean amplitude spectrum of generated samples with real training data.
n_samples = 5000
sampling_rate_hz = 256.0  # same value as final_hz=256 in the load_train_test calls
class_labels = ["Right", "Rest"]

plt.figure(figsize=(8,3))
for i_class in range(2):
    color = seaborn.color_palette()[i_class]
    # Draw n_samples from the class mixture (previously a hard-coded 5000 that
    # ignored n_samples), reshaped to the per-trial shape of the real data.
    # `samples` is kept as a global: later cells plot it.
    samples = mixtures[i_class].sample(n_samples).view(-1, *train_inputs[0].shape[1:]).squeeze()
    samples_np = var_to_np(samples)
    bps_fake = np.abs(np.fft.rfft(samples_np))
    # The frequency axis is derived from the transformed (last-axis) length, so it
    # always matches the rfft output length instead of assuming 256 time samples.
    plt.plot(np.fft.rfftfreq(samples_np.shape[-1], 1 / sampling_rate_hz),
             np.mean(np.mean(bps_fake, axis=0), axis=0),
             color=color, ls="--",
             label="Fake {:s}".format(class_labels[i_class]))
    real_np = var_to_np(train_inputs[i_class].squeeze())
    bps_real = np.abs(np.fft.rfft(real_np))
    plt.plot(np.fft.rfftfreq(real_np.shape[-1], 1 / sampling_rate_hz),
             np.mean(np.mean(bps_real, axis=0), axis=0),
             color=color, ls="-",
             label="Real {:s}".format(class_labels[i_class]))

plt.legend()
plt.title("Spectrum of real and generated data")
plt.xlabel("Freq [Hz]")
plt.ylabel("Amplitude");
    

In [None]:
from reversible2.plot import plot_head_signals_tight

# Quick look at a single generated trace: index [0, 0] of the last class's samples
# (presumably first sample, first channel).
first_generated_trace = var_to_np(samples)[0, 0]
plt.plot(first_generated_trace)

In [None]:
# Head-layout plot of the first generated sample, labelled with sensor_names
# (presumably one trace per sensor).
first_generated_sample = var_to_np(samples)[0]
fig = plot_head_signals_tight(
    first_generated_sample,
    sensor_names=sensor_names,
    figsize=(16, 12),
)


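## Earlier approach: fixed train/valid split

The cells below split the training trials once into fixed train and validation halves, instead of drawing a new random split every epoch. They then fit the log-stds by minimizing the negative log-likelihood of the validation half under mixtures built from the train half.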

In [None]:
# Fixed split of each class's training trials: tr_ins defines the mixtures,
# val_ins is what the next cell scores to fit the log-stds.
train_half, valid_half = split_into_train_valid(train_inputs)

tr_ins = [class_ins[:, :].cuda() for class_ins in train_half]
val_ins = [class_ins[:, :].cuda() for class_ins in valid_half]
te_ins = [class_ins[:, :].cuda() for class_ins in test_inputs]

# One trainable log-std per element of each class's flattened train half, starting at 0.
log_stds = [
    th.zeros_like(flatten_2d(class_ins), requires_grad=True)
    for class_ins in tr_ins
]

mixtures = [
    GaussianMixture(flatten_2d(class_ins), class_log_stds)
    for class_ins, class_log_stds in zip(tr_ins, log_stds)
]
mixture = TwoClassMixture(mixtures)
optim_log_stds = th.optim.Adam(log_stds, lr=1e-3)

In [None]:
n_epochs = 10001
rand_noise_factor = 1e-2
# Report ~20 times during training; max(1, ...) avoids modulo-by-zero for n_epochs < 20.
print_every = max(1, n_epochs // 20)
# Fit log_stds by minimizing the NLL of the (jittered) fixed validation half under
# each class's mixture, which is built from the fixed train half.
for i_epoch in range(n_epochs):
    optim_log_stds.zero_grad()
    for i_class in range(2):
        this_ins = flatten_2d(val_ins[i_class])
        # Uniform jitter in [-rand_noise_factor / 2, rand_noise_factor / 2).
        this_ins = this_ins + (th.rand_like(this_ins) - 0.5) * rand_noise_factor
        nll = -th.mean(mixtures[i_class].log_probs(this_ins))
        nll.backward()
    optim_log_stds.step()
    if i_epoch % print_every == 0:
        print("Epoch {:d} of {:d}".format(i_epoch, n_epochs))
        # Evaluation only: no_grad stops these forward passes from building autograd
        # graphs through log_stds. This saves GPU memory; printed values are unchanged.
        with th.no_grad():
            for i_class in range(2):
                for j_class in range(2):
                    nll = -th.mean(mixtures[j_class].log_probs(flatten_2d(val_ins[i_class])))
                    print("NLL {:d}->{:d} {:.1E}".format(i_class, j_class, nll.item(),))

            for setname, ins in (("Train", tr_ins), ("Valid", val_ins), ("Test", te_ins)):
                corrects = []
                for i_class in range(2):
                    # Predicted class = argmax over the per-class log-softmax scores.
                    corrects.extend(np.argmax(var_to_np(mixture.log_softmax(flatten_2d(ins[i_class]))), axis=1) == i_class)
                acc = np.mean(corrects)
                print("{:6s} Accuracy: {:.1f}".format(setname, acc * 100))

In [None]:
# Inspect the fitted log-stds of class 0 in three separate figures:
# raw log-stds (transposed), their exponentials (transposed), and the
# exponentials averaged over axis 0.
class0_log_stds = var_to_np(log_stds[0])
class0_stds = np.exp(class0_log_stds)

for curves in (class0_log_stds.T, class0_stds.T, np.mean(class0_stds, axis=0)):
    plt.figure(figsize=(12,4))
    plt.plot(curves, lw=0.5, color='black')