In [None]:
%config Completer.use_jedi=False

%load_ext autoreload
%autoreload 2

import os
os.sys.path.insert(0, '/home/schirrmr/code/invertible-public/')

# Invertible Networks for EEG Decoding Tutorial

## Discriminative and Generative Methods

Discriminative classification models learn $p(y|x)$, to predict the class given an input:

Data | $p(y|x)$
- | - 
![](./data.png) | ![](./relclass.png)


Explicit generative class-conditional models learn $p(y,x)$ to predict the joint density of a given class and a given input:



$p(y=0,x)$ | $p(y=1,x)$
- | -
![](./class0.png) | ![](./class1.png)

We will see what advantages invertible networks as generative models offer.
First, we show a simplified EEG example to illustrate how invertible networks work: we use the alpha power from 2 electrodes to distinguish right hand movement from resting state.

## Load High-Gamma Dataset

We will load the high-gamma dataset ([paper](http://onlinelibrary.wiley.com/doi/10.1002/hbm.23730/full), [data](https://gin.g-node.org/robintibor/high-gamma-dataset)) through the EEG deep learning library [Braindecode](https://braindecode.org/).
We will use a simple 2d EEG example with right hand movements and resting state as the two signal classes. We will extract the alpha power from the two sensors C3 and C4, which is very informative for distinguishing resting state from right hand movement.


In [None]:
import braindecode
from braindecode.datasets.moabb import MOABBDataset
from braindecode.datautil.preprocess import MNEPreproc, NumpyPreproc, preprocess


subject_id = 4

# Load one subject of the High-Gamma dataset, which MOABB calls "Schirrmeister2017".
dataset = MOABBDataset(
    dataset_name="Schirrmeister2017",
    subject_ids=[subject_id],
)

In [None]:
from copy import deepcopy
from braindecode.datautil.preprocess import exponential_moving_demean
# Work on a copy so the raw `dataset` stays untouched and this preprocessing
# can be rerun without reloading the data.
preproced_set = deepcopy(dataset)
low_cut_hz = 7.  # low cut frequency for filtering (alpha band)
high_cut_hz = 14.  # high cut frequency for filtering (alpha band)

# Define preprocessing steps
preprocessors = [
    # convert from volt to microvolt, directly modifying the numpy array
    NumpyPreproc(fn=lambda x: x * 1e6),
    # exponential moving demean (braindecode's exponential_moving_demean)
    NumpyPreproc(fn=exponential_moving_demean, init_block_size=1000, factor_new=1e-3),
    # resample to 50 Hz
    MNEPreproc(fn='resample', sfreq=50),
    # keep only the C3 and C4 sensors, in that order
    MNEPreproc(fn='pick_channels', ch_names=['C3', 'C4'], ordered=True),
    # bandpass filter to the alpha band
    MNEPreproc(fn='filter', l_freq=low_cut_hz, h_freq=high_cut_hz),
    # final resample to 32 Hz
    MNEPreproc(fn='resample', sfreq=32),
]

# Preprocess the data; `preprocess` updates `preproced_set` (its return value is not used)
preprocess(preproced_set, preprocessors)

In [None]:
from braindecode.datautil.windowers import create_windows_from_events
# Extract the 4-second trials from the continuous recordings. Zero start/stop
# offsets use each trial exactly as delimited by its events.
class_names = ['Right Hand', 'Rest']  # human-readable labels for the plots below
class_mapping = dict(right_hand=0, rest=1)

windows_dataset = create_windows_from_events(
    preproced_set,
    mapping=class_mapping,
    preload=True,
    trial_start_offset_samples=0,
    trial_stop_offset_samples=0,
)

Look inside the dataset to see what it can be split by, then split the training runs into a training and a validation set.

In [None]:
# Display the dataset's description table; its columns (e.g. 'run', used below) can be passed to `split`.
windows_dataset.description

In [None]:
from torch.utils.data import Subset
import numpy as np
# Split by run and keep the 'train' runs. Their first 80% of trials (in dataset
# order) become the training set and the remaining 20% the validation set.
splitted = windows_dataset.split('run')
train_runs = splitted['train']
n_trials = len(train_runs)
n_split = int(np.round(0.8 * n_trials))
valid_set = Subset(train_runs, range(n_split, n_trials))
train_set = Subset(train_runs, range(n_split))

Now we extract mean squared values to have an estimate of the power.

In [None]:
from skorch.utils import to_numpy
from skorch.utils import to_tensor

# Stack all training trials into one array and their labels into a parallel array.
train_X = np.stack([to_numpy(trial) for trial, _, _ in train_set], axis=0)
train_y = np.stack([label for _, label, _ in train_set], axis=0)

# Mean of the squared samples over axis 2 (presumably time) as a simple power estimate.
mean_squared_X = np.square(train_X).mean(axis=2)

## Plot distribution

In [None]:
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import cm
import seaborn
seaborn.set_palette('colorblind')
seaborn.set_style('darkgrid')

%matplotlib inline
%config InlineBackend.figure_format = 'png'
#matplotlib.rcParams['figure.figsize'] = (12.0, 1.0)
matplotlib.rcParams['font.size'] = 14

We see the data is well-discriminable.

In [None]:
# Scatter the per-trial mean squared alpha signal of C3 vs. C4, one colour per class.
# The signals were scaled to microvolts during preprocessing, so the mean squared
# values are in µV² (not mV).
plt.figure(figsize=(4,4))

for i_class in range(2):
    mask = train_y == i_class
    plt.scatter(mean_squared_X[mask][:,0], mean_squared_X[mask][:,1],
                label=class_names[i_class], s=22, alpha=.7)
plt.legend()
plt.xlabel("C3 Mean Squared Alpha Signal [µV²]")
plt.ylabel("C4 Mean Squared Alpha Signal [µV²]")
plt.title("Real Data");

## Create Invertible Network

We will now create a small invertible network to train on this data and learn the underlying distribution. Invertible networks work by transforming the input data invertibly/bijectively to a predefined distribution, e.g., a Gaussian distribution.

One prominent example for an invertible transformation is a coupling block:

1. Split input $x$ by dimensions into disjoint parts $x_1$, $x_2$
2. Compute $x_1' = f(x_2) * x_1 + g(x_2)$ with $f, g$ arbitrary functions (e.g. neural networks)
3. Concatenate to $x' = (x_1', x_2)$

Inversion is possible via $x_1 = \frac{x_1' - g(x_2)}{f(x_2)}$, since $x_2$ is passed through unchanged (this requires $f(x_2) \neq 0$).

Find further information in these blog articles:
* https://lilianweng.github.io/lil-log/2018/10/13/flow-based-deep-generative-models.html
* https://blog.evjang.com/2018/01/nf1.html

In [None]:
from invertible.affine import AffineCoefs, AffineModifier
from invertible.coupling import  CouplingLayer
from invertible.split_merge import ChunkChansIn2, EverySecondChan
from invertible.sequential import InvertibleSequential
from invertible.permute import InvPermute
from invertible.actnorm import ActNorm
from torch import nn
from skorch.utils import to_tensor
import torch as th
from invertible.init import init_all_modules
from invertible.distribution import NClassIndependentDist
from braindecode.util import set_random_seeds


def get_X_y_th(dataset):
    """Convert a windows dataset into mean-squared features and one-hot labels.

    For every trial, the squared signal is averaged over axis 2 and the result
    is divided by 100.

    Parameters
    ----------
    dataset : iterable of (X, y, i)
        Trials as yielded by the braindecode windows dataset / Subset.

    Returns
    -------
    X_th : torch.Tensor
        Mean squared signal / 100 on CPU, presumably (n_trials, n_channels).
    y_th : torch.Tensor
        One-hot labels with 2 classes on CPU, shape (n_trials, 2).
    """
    trials, labels = [], []
    for trial, label, _ in dataset:
        trials.append(to_numpy(trial))
        labels.append(label)
    X_np = np.stack(trials, axis=0)
    y_np = np.stack(labels, axis=0)
    features = np.mean(np.square(X_np), axis=2) / 100
    X_th = to_tensor(features, 'cpu')
    y_th = th.nn.functional.one_hot(to_tensor(y_np, 'cpu'), num_classes=2)
    return X_th, y_th

def flow_block():
    """Build one 2-channel flow step: ActNorm -> learned permutation -> coupling.

    The coupling layer uses a small MLP (1 -> 512 -> 2 units, ELU) to produce
    its coefficients. The modules are created in the same order as the
    original nested expression so seeded initialisation is unchanged.
    """
    norm = ActNorm(2, 'exp', )
    permute = InvPermute(2, fixed=False, use_lu=True)
    splitter = ChunkChansIn2(swap_dims=False)
    coef_net = nn.Sequential(
        nn.Linear(1, 512),
        nn.ELU(),
        nn.Linear(512, 2),
    )
    coefs = AffineCoefs(coef_net, EverySecondChan())
    modifier = AffineModifier('sigmoid', add_first=True, eps=0)
    coupling = CouplingLayer(splitter, coefs, modifier)
    return InvertibleSequential(norm, permute, coupling)

# Seed before constructing the modules so their initialisation is reproducible.
set_random_seeds(20200617, False)

# Four flow steps followed by the 2-class distribution over the 2 latent dims.
# The list comprehension builds the blocks in the same order as writing them out.
net = InvertibleSequential(
    *[flow_block() for _ in range(4)],
    NClassIndependentDist(2, 2),
).cuda()

init_all_modules(net, None)
optim = th.optim.Adam(net.parameters(), lr=1e-3, weight_decay=5e-5)

train_X_th, train_y_th = get_X_y_th(train_set)

## Train Invertible Network

Now we train the network to maximize the likelihood of the data under our prior.

In [None]:
n_epochs = 2000
log_every = n_epochs // 10
for i_epoch in range(n_epochs + 1):
    # Add small Gaussian noise to the inputs at every full-batch step.
    noise = th.randn_like(train_X_th) * 1e-3
    noised = train_X_th + noise
    # Log probabilities given the true labels; minimise their negative mean.
    z, lp = net(noised.cuda(), fixed=dict(y=train_y_th.cuda()))
    nll = -lp.mean()
    nll.backward()
    optim.step()
    optim.zero_grad()
    if i_epoch % log_every == 0:
        print(f"Negative Log Likelihood: {nll.item():.2f}")

## Plot Results

After training, we can generate samples from the learned distribution.

In [None]:
# Draw 200 points per class from the trained model (net.invert with n_samples)
# and scatter them. The network inputs were divided by 100 in `get_X_y_th`, so
# the axes show mean squared signal / 100 in µV² (not µV).
set_random_seeds(2, False)
plt.figure(figsize=(4,4))

# Inference only: no autograd graph is needed.
with th.no_grad():
    for i_class in range(2):
        inv_x, _ = net.invert(None, fixed=dict(n_samples=200, y=i_class))
        inv_x = to_numpy(inv_x)
        plt.scatter(inv_x[:,0], inv_x[:,1], s=22, alpha=.7, label=class_names[i_class])
plt.legend()
plt.xlabel("C3 Mean Squared Alpha Signal / 100 [µV²]")
plt.ylabel("C4 Mean Squared Alpha Signal / 100 [µV²]")
plt.title("Generated Data");

We can also plot the class-conditional distributions $p(x|y)$ and the predictions $p(y|x)$ for any point:

In [None]:
# Evaluate the trained flow on an n_grid x n_grid grid covering the plotted
# range (same ÷100 units as the network inputs).
x_start = 0
x_stop = 1
y_start = 0
y_stop = 1.25
n_grid = 30
points = th.stack(th.meshgrid(th.linspace(x_start, x_stop, n_grid),
                              th.linspace(y_start, y_stop, n_grid)), dim=-1)
# Inference only: no autograd graph is needed.
with th.no_grad():
    lprobs = net(points.view(-1, points.shape[-1]).cuda())[1]

# exp of each class's log probability, and the softmax over the two classes
# (probability of class 1, Rest) for every grid point.
probs_grid_0 = th.exp(lprobs[:,0].view(points.shape[:2]))
probs_grid_1 = th.exp(lprobs[:,1].view(points.shape[:2]))
rel_probs_grid = th.softmax(lprobs, dim=1)[:,1].view(points.shape[:2])

In [None]:
from matplotlib.colors import LinearSegmentedColormap
def create_cmap(c1, c2, n_bins):
    """Create a colormap blending linearly from colour ``c1`` to colour ``c2``.

    Parameters
    ----------
    c1, c2 : sequence of float
        RGB colours (components in [0, 1]) at the low and high end.
    n_bins : int
        Number of entries in the colormap's lookup table (``N`` of
        ``LinearSegmentedColormap.from_list``).

    Returns
    -------
    matplotlib.colors.LinearSegmentedColormap
    """
    # `from_list` already interpolates linearly between the given colours, so
    # the two endpoints suffice; n_bins controls the lookup-table size.
    colors = np.array([c1, c2], dtype=float)
    return LinearSegmentedColormap.from_list('', colors, N=n_bins)
cmap_bl = create_cmap([1,1,1], seaborn.color_palette()[0], 100)
cmap_or = create_cmap([1,1,1], seaborn.color_palette()[1], 100)
cmap_blor = create_cmap(seaborn.color_palette()[0], seaborn.color_palette()[1], 100)

In [None]:
# Left/middle: exp of each class's log probability (colour scale clipped at 10);
# right: softmax probability of Rest. Axes are in the ÷100 input units.
fig, axes = plt.subplots(1,3, figsize=(14,4), sharex=True, sharey=True)
# origin='lower' puts the first grid row at the bottom; 'lower left' is not an
# accepted value in current matplotlib (only 'upper' or 'lower').
imshow_kwargs = dict(origin='lower', extent=(x_start, x_stop, y_start, y_stop),
                     aspect='auto', interpolation='bilinear')
# Transpose: the grid's first axis runs along C3 (x), but imshow rows run along y.
axes[0].imshow(to_numpy(probs_grid_0).T, cmap=cmap_bl, vmin=0, vmax=10, **imshow_kwargs)
axes[1].imshow(to_numpy(probs_grid_1).T, cmap=cmap_or, vmin=0, vmax=10, **imshow_kwargs)
axes[2].imshow(to_numpy(rel_probs_grid).T, cmap=cmap_blor, vmin=0, vmax=1, **imshow_kwargs)
axes[0].set_title("Right Hand Learned Distribution")
axes[1].set_title("Rest Learned Distribution")
axes[2].set_title("Conditional Class Prob")
axes[1].set_xlabel("C3 Mean Squared Alpha Signal / 100 [µV²]")
axes[0].set_ylabel("C4 Mean Squared Alpha Signal / 100 [µV²]");

We also show the transformation the invertible network has learned: it maps the input data to two Gaussian distributions.

In [None]:
from matplotlib.patches import ConnectionPatch
# Left: training inputs; right: their outputs under the trained flow.
# A grey line connects every input point to its output point.
fig, axes = plt.subplots(1,2, figsize=(8,4))
with th.no_grad():
    z, _ = net(train_X_th.cuda())
x_np = to_numpy(train_X_th)
z_np = to_numpy(z)
for i_class in range(2):
    mask = train_y == i_class
    color = seaborn.color_palette()[i_class]
    axes[0].scatter(x_np[mask][:,0], x_np[mask][:,1],
                    label=class_names[i_class], s=22, alpha=.7, color=color)
    axes[1].scatter(z_np[mask][:,0], z_np[mask][:,1], color=color, s=22, alpha=.7)
    # Iterate over (n, 2) rows directly; no squeeze, so a class with a
    # single trial still yields one (x, y) pair.
    for xb, xa in zip(x_np[mask], z_np[mask]):
        con = ConnectionPatch(xyA=xa, xyB=xb, coordsA="data", coordsB="data",
                              axesA=axes[1], axesB=axes[0], color="grey", lw=0.2)
        axes[1].add_artist(con)
axes[0].set_title("Input Data")
axes[1].set_title("Output Data")
fig.suptitle("Learned Invertible Transformation", y=1.08, fontsize=20);

## Evaluate accuracy

Finally we can also evaluate the accuracy of this model.

In [None]:
valid_X_th, valid_y_th = get_X_y_th(valid_set)

evaluation_sets = dict(train=(train_X_th, train_y_th),
                       valid=(valid_X_th, valid_y_th))
for setname, (set_X, set_y) in evaluation_sets.items():
    with th.no_grad():
        z, lp = net(set_X.cuda())
        # Predicted class = class with the highest log probability.
        predicted = lp.argmax(dim=1)
        acc = np.mean(to_numpy(predicted == set_y.cuda().argmax(dim=1)))
    print(f"{setname.capitalize()} Acc: {acc:.1%}")

## Raw data

We will now go to raw data, still simplified: 32 Hz, 2 second windows, single-sensor data, bandpassed between 7 and 14 Hz (alpha band).

In [None]:
from invertible.affine import AdditiveCoefs
from invertible.view_as import Flatten2d
from invertible.datautil import PreprocessedLoader
from invertible.pure_model import NoLogDet
from invertible.noise import GaussianNoise
from invertible.subsample_split import SubsampleSplitter

def conv_flow_block(n_chans):
    """Build one convolutional flow step: ActNorm -> learned permutation -> coupling.

    The coupling's coefficient network is Conv1d(n_chans//2 -> 128, k=7) -> ELU ->
    Conv1d(128 -> n_chans//2, k=7). Padding 3 with kernel 7 keeps the time length.
    Modules are created in the same order as the original nested expression so
    seeded initialisation is unchanged.
    """
    half = n_chans // 2
    norm = ActNorm(n_chans, 'exp', )
    permute = InvPermute(n_chans, fixed=False, use_lu=True)
    splitter = ChunkChansIn2(swap_dims=False)
    coef_net = nn.Sequential(
        nn.Conv1d(half, 128, 7, padding=3),
        nn.ELU(),
        nn.Conv1d(128, half, 7, padding=3),
    )
    coefs = AdditiveCoefs(coef_net)
    modifier = AffineModifier('sigmoid', add_first=True, eps=0)
    return InvertibleSequential(norm, permute, CouplingLayer(splitter, coefs, modifier))

def dense_flow_block(n_chans):
    """Build one fully-connected flow step: ActNorm -> learned permutation -> coupling.

    The coupling's coefficient network is Linear(n_chans//2 -> 512) -> ELU ->
    Linear(512 -> n_chans//2). Modules are created in the same order as the
    original nested expression so seeded initialisation is unchanged.
    """
    half = n_chans // 2
    norm = ActNorm(n_chans, 'exp', )
    permute = InvPermute(n_chans, fixed=False, use_lu=True)
    splitter = ChunkChansIn2(swap_dims=False)
    coef_net = nn.Sequential(
        nn.Linear(half, 512),
        nn.ELU(),
        nn.Linear(512, half),
    )
    coefs = AdditiveCoefs(coef_net)
    modifier = AffineModifier('sigmoid', add_first=True, eps=0)
    return InvertibleSequential(norm, permute, CouplingLayer(splitter, coefs, modifier))

In [None]:
from copy import deepcopy
from braindecode.datautil.preprocess import exponential_moving_demean
# Fresh copy of the raw data; this time only C3 is kept.
preproced_set = deepcopy(dataset)
low_cut_hz = 7.  # alpha band, lower edge
high_cut_hz = 14.  # alpha band, upper edge

to_microvolt = NumpyPreproc(fn=lambda x: x * 1e6)
moving_demean = NumpyPreproc(fn=exponential_moving_demean, init_block_size=1000, factor_new=1e-3)
preprocessors = [
    to_microvolt,
    moving_demean,
    MNEPreproc(fn='resample', sfreq=50),
    MNEPreproc(fn='pick_channels', ch_names=['C3'], ordered=True),
    MNEPreproc(fn='filter', l_freq=low_cut_hz, h_freq=high_cut_hz),
    MNEPreproc(fn='resample', sfreq=32),
]

preprocess(preproced_set, preprocessors)

In [None]:
windows_dataset = create_windows_from_events(
    preproced_set,
    mapping={'right_hand': 0, 'rest': 1},
    preload=True,
    trial_start_offset_samples=0,
    trial_stop_offset_samples=0,
    window_size_samples=None,
    window_stride_samples=None,
)

from torch.utils.data import Subset
# Keep the 'train' runs: first 75% of their trials train, the rest validate.
splitted = windows_dataset.split('run')
full_train = splitted['train']
n_split = int(np.round(len(full_train) * 0.75))
valid_set = Subset(full_train, range(n_split, len(full_train)))
train_set = Subset(full_train, range(n_split))

## Convolutional Invertible Network

For the "raw" time-series data, we will use a convolutional invertible network inspired by the popular [Glow](https://openai.com/blog/glow/) invertible network architecture. We have a special block in the end to make it easier for the network to separate amplitude and phase into different latent dimensions.

In [None]:
from invertible.amp_phase import AmplitudePhase
# Seed before building the network; module construction order below matters
# for reproducible initialisation.
set_random_seeds(20200718, True)
n_chans = 1
# Each SubsampleSplitter presumably moves every second time sample into new
# channels (channels x2, time /2): 1 -> 2 -> 8 -> 32 channels. Flatten2d then
# gives 64 features for the 64-sample crops used below.
net = InvertibleSequential(
    SubsampleSplitter((2,),chunk_chans_first=False),
    conv_flow_block(n_chans*2),
    conv_flow_block(n_chans*2),
    conv_flow_block(n_chans*2),
    conv_flow_block(n_chans*2),
    SubsampleSplitter((2,),chunk_chans_first=False),
    SubsampleSplitter((2,),chunk_chans_first=False),
    conv_flow_block(n_chans*8),
    conv_flow_block(n_chans*8),
    conv_flow_block(n_chans*8),
    conv_flow_block(n_chans*8),
    SubsampleSplitter((2,),chunk_chans_first=False),
    SubsampleSplitter((2,),chunk_chans_first=False),
    Flatten2d(),
    dense_flow_block(n_chans*64),
    dense_flow_block(n_chans*64),
    dense_flow_block(n_chans*64),
    dense_flow_block(n_chans*64),
    AmplitudePhase(),
    # 2 classes over 64 latent dims; class means are optimised, stds are not.
    NClassIndependentDist(2, n_chans*64, optimize_mean=True, optimize_std=False),

).cuda()

# drop_last=True: every training batch has exactly 50 trials.
train_loader = th.utils.data.DataLoader(
    train_set,
    batch_size=50,
    shuffle=True,
    num_workers=0,
    drop_last=True)
# The whole validation set as a single batch.
valid_loader = th.utils.data.DataLoader(
    valid_set,
    batch_size=len(valid_set),
    shuffle=False,
    num_workers=0)
# Data-dependent initialisation on the central 64-sample crop (samples 32..95)
# of all training batches, with GaussianNoise(1e-2) applied by the loader wrapper.
preproced_loader = PreprocessedLoader(train_loader, GaussianNoise(1e-2), False)
init_all_modules(net, th.cat([x[:,:,32:96] for x,y,i in preproced_loader], dim=0).cuda())

optim = th.optim.Adam(net.parameters(), lr=5e-4, weight_decay=1e-4)

In [None]:
n_epochs = 2000
rng = np.random.RandomState(394834)
for i_epoch in range(n_epochs+1):
    if i_epoch > 0:
        for X_th, y, _ in train_loader:
            # Random 64-sample crop as data augmentation.
            start_ind = rng.randint(0,64)
            X_th = X_th[:,:,start_ind:start_ind+64]
            y_th = th.nn.functional.one_hot(y, num_classes=2).cuda()
            noise = th.randn_like(X_th) * 1e-2
            noised = X_th + noise
            z, lp = net(noised.cuda(), fixed=dict(y=None))
            # Discriminative term (weighted x10) plus NLL of the true class.
            cross_ent = th.nn.functional.cross_entropy(
                lp, y_th.argmax(dim=1),)
            nll = -th.mean(th.sum(lp * y_th, dim=1))
            loss = cross_ent * 10 + nll
            loss.backward()
            optim.step()
            optim.zero_grad()
    if (i_epoch % (n_epochs // 10) == 0) or (i_epoch == n_epochs):
        print(i_epoch)
        # Evaluate on the central, noise-free crop, aggregated over *all* batches
        # of each loader (not just the last one).
        for name, loader in (('Train', train_loader), ('Valid', valid_loader)):
            sum_lp, n_correct, n_total = 0., 0, 0
            with th.no_grad():
                for X_th, y, _ in loader:
                    X_th = X_th[:,:,32:32+64]
                    y_th = th.nn.functional.one_hot(y, num_classes=2).cuda()
                    z, lp = net(X_th.cuda())
                    sum_lp += th.sum(lp * y_th).item()
                    n_correct += (y.cuda() == lp.argmax(dim=1)).sum().item()
                    n_total += len(y)
            print(f"{name} NLL: {-sum_lp / n_total:.1f}")
            print(f"{name} Acc: {n_correct / n_total:.1%}")

In [None]:
#th.save(net.state_dict(), "netstatetutorialeeg.pth")

In [None]:
# Load previously trained weights from disk.
net.load_state_dict(th.load('netstatetutorialeeg.pth'))
init_all_modules(net, None)
# One dummy forward pass after loading (presumably to finish any lazy setup).
# A fixed (batch=1, channel=1, time=64) input matches the 64-sample crops and
# avoids depending on `X_th` left over from the training loop.
_ = net(th.zeros(1, 1, 64, device='cuda'))

### Evaluate accuracy and Negative Log Likelihood

We see improved accuracies.

In [None]:
# NLL and accuracy over all batches (central 64-sample crop, no noise).
# Note: train_loader uses drop_last=True, so a final partial batch is skipped.
for name, loader in (('Train', train_loader), ('Valid', valid_loader)):
    all_lps = []
    all_corrects = []
    # Inference only: no autograd graph is needed.
    with th.no_grad():
        for X_th, y, _ in loader:
            X_th = X_th[:,:,32:32+64]
            y_th = th.nn.functional.one_hot(y, num_classes=2).cuda()
            z, lp = net(X_th.cuda())
            corrects = to_numpy(y.cuda() == lp.argmax(dim=1))
            # Log probability of each trial's true class.
            lps = to_numpy(th.sum(lp * y_th, dim=1))
            all_lps.extend(lps)
            all_corrects.extend(corrects)
    acc = np.mean(all_corrects)
    nll = -np.mean(all_lps)
    print(f"{name} NLL: {nll:.1f}")
    print(f"{name} Acc: {acc:.1%}")

## Visualize most likely inputs per class

We can also visualize the most likely inputs per class, and the input directly in the middle of the two classes.

In [None]:
dist = net.sequential[-1]
# Neutral latent point: the average of the two class means.
overall_mean = dist.class_means.mean(dim=0)
legend_labels = ['Right', 'Rest', 'Neutral']

plt.figure(figsize=(12,2))
plt.plot(to_numpy(dist.class_means.squeeze()).T)
plt.plot(to_numpy(overall_mean.squeeze()))
plt.legend(legend_labels, bbox_to_anchor=(1,1,0,0))
plt.title("In Output Space")

plt.figure(figsize=(12,2))
# Map the neutral point and both class means back to input space.
inved_mean, _ = net.invert(overall_mean.unsqueeze(0))
inved_means, _ = net.invert(dist.class_means)

plt.plot(to_numpy(inved_means.squeeze()).T)
plt.plot(to_numpy(inved_mean.squeeze()))
plt.legend(legend_labels, bbox_to_anchor=(1,1,0,0))
plt.title("Inverted to Input Space")

## Data without bandpass

Now let's do the same for data without the bandpass and do more analysis on this task.

In [None]:
# making a copy just to be able to rerun preprocessing without
# waiting later
preproced_set = deepcopy(dataset)

# Same steps as before minus the bandpass filter and the intermediate 50 Hz
# resample: microvolts -> exponential moving demean -> C3 only -> 32 Hz.
to_microvolt = NumpyPreproc(fn=lambda x: x * 1e6)
moving_demean = NumpyPreproc(fn=exponential_moving_demean, init_block_size=1000, factor_new=1e-3)
keep_c3 = MNEPreproc(fn='pick_channels', ch_names=['C3'], ordered=True)
to_32_hz = MNEPreproc(fn='resample', sfreq=32)
preprocessors = [to_microvolt, moving_demean, keep_c3, to_32_hz]

preprocess(preproced_set, preprocessors)

In [None]:
windows_dataset = create_windows_from_events(
    preproced_set,
    mapping=class_mapping,
    preload=True,
    window_size_samples=None,
    window_stride_samples=None,
    trial_start_offset_samples=0,
    trial_stop_offset_samples=0,
)
# Keep the 'train' runs: first 75% of their trials train, the rest validate.
splitted = windows_dataset.split('run')
full_train = splitted['train']
n_split = int(np.round(len(full_train) * 0.75))
valid_set = Subset(full_train, range(n_split, len(full_train)))
train_set = Subset(full_train, range(n_split))

In [None]:
from invertible.subsample_split import SubsampleSplitter
# Same architecture, seed and initialisation as for the bandpassed data.
set_random_seeds(20200718, True)
n_chans = 1
# Each SubsampleSplitter presumably moves every second time sample into new
# channels (channels x2, time /2); Flatten2d then gives 64 features.
net = InvertibleSequential(
    SubsampleSplitter((2,),chunk_chans_first=False),
    conv_flow_block(n_chans*2),
    conv_flow_block(n_chans*2),
    conv_flow_block(n_chans*2),
    conv_flow_block(n_chans*2),
    SubsampleSplitter((2,),chunk_chans_first=False),
    SubsampleSplitter((2,),chunk_chans_first=False),
    conv_flow_block(n_chans*8),
    conv_flow_block(n_chans*8),
    conv_flow_block(n_chans*8),
    conv_flow_block(n_chans*8),
    SubsampleSplitter((2,),chunk_chans_first=False),
    SubsampleSplitter((2,),chunk_chans_first=False),
    Flatten2d(),
    dense_flow_block(n_chans*64),
    dense_flow_block(n_chans*64),
    dense_flow_block(n_chans*64),
    dense_flow_block(n_chans*64),
    AmplitudePhase(),
    # 2 classes over 64 latent dims; class means are optimised, stds are not.
    NClassIndependentDist(2, n_chans*64, optimize_mean=True, optimize_std=False),

).cuda()

train_loader = th.utils.data.DataLoader(
    train_set,
    batch_size=50,
    shuffle=True,
    num_workers=0,
    drop_last=True)

# The whole validation set as a single batch.
valid_loader = th.utils.data.DataLoader(
    valid_set,
    batch_size=len(valid_set),
    shuffle=False,
    num_workers=0)
# Data-dependent initialisation on the central 64-sample crop of all training batches.
preproced_loader = PreprocessedLoader(train_loader, GaussianNoise(1e-2), False)
init_all_modules(net, th.cat([x[:,:,32:96] for x,y,i in preproced_loader], dim=0).cuda())

optim = th.optim.Adam(net.parameters(), lr=5e-4, weight_decay=1e-4)

In [None]:
n_epochs = 2000
rng = np.random.RandomState(394834)
for i_epoch in range(n_epochs+1):
    if i_epoch > 0:
        for X_th, y, _ in train_loader:
            # Random 64-sample crop as data augmentation.
            start_ind = rng.randint(0,64)
            X_th = X_th[:,:,start_ind:start_ind+64]
            y_th = th.nn.functional.one_hot(y, num_classes=2).cuda()
            noise = th.randn_like(X_th) * 1e-3
            noised = X_th + noise
            z, lp = net(noised.cuda(), fixed=dict(y=None))
            # Discriminative term (weighted x10) plus NLL of the true class.
            cross_ent = th.nn.functional.cross_entropy(
                lp, y_th.argmax(dim=1),)
            nll = -th.mean(th.sum(lp * y_th, dim=1))
            loss = cross_ent * 10 + nll
            loss.backward()
            optim.step()
            optim.zero_grad()
    if (i_epoch % (n_epochs // 10) == 0) or (i_epoch == n_epochs):
        print(i_epoch)
        # Evaluate on the central, noise-free crop, aggregated over *all* batches
        # of each loader (not just the last one).
        for name, loader in (('Train', train_loader), ('Valid', valid_loader)):
            sum_lp, n_correct, n_total = 0., 0, 0
            with th.no_grad():
                for X_th, y, _ in loader:
                    X_th = X_th[:,:,32:32+64]
                    y_th = th.nn.functional.one_hot(y, num_classes=2).cuda()
                    z, lp = net(X_th.cuda())
                    sum_lp += th.sum(lp * y_th).item()
                    n_correct += (y.cuda() == lp.argmax(dim=1)).sum().item()
                    n_total += len(y)
            print(f"{name} NLL: {-sum_lp / n_total:.1f}")
            print(f"{name} Acc: {n_correct / n_total:.1%}")

In [None]:
#th.save(net.state_dict(), "netstatetutorialeegnolowpass.pth")

In [None]:
# Restore the trained weights, then run one dummy forward pass on a
# (batch=1, channel=1, time=64) zero input.
state_dict = th.load('netstatetutorialeegnolowpass.pth')
net.load_state_dict(state_dict)
init_all_modules(net, None)
dummy_input = th.zeros(1, 1, 64, device='cuda')
_ = net(dummy_input)

In [None]:
# NLL and accuracy over all batches (central 64-sample crop, no noise).
# Note: train_loader uses drop_last=True, so a final partial batch is skipped.
for name, loader in (('Train', train_loader), ('Valid', valid_loader)):
    all_lps = []
    all_corrects = []
    # Inference only: no autograd graph is needed.
    with th.no_grad():
        for X_th, y, _ in loader:
            X_th = X_th[:,:,32:32+64]
            y_th = th.nn.functional.one_hot(y, num_classes=2).cuda()
            z, lp = net(X_th.cuda())
            corrects = to_numpy(y.cuda() == lp.argmax(dim=1))
            # Log probability of each trial's true class.
            lps = to_numpy(th.sum(lp * y_th, dim=1))
            all_lps.extend(lps)
            all_corrects.extend(corrects)
    acc = np.mean(all_corrects)
    nll = -np.mean(all_lps)
    print(f"{name} NLL: {nll:.1f}")
    print(f"{name} Acc: {acc:.1%}")

We see that now the most likely inputs differ not only in alpha frequency but also in slower frequency/the trend over the window (higher start and lower end for the resting state class).

In [None]:
dist = net.sequential[-1]
class_means = dist.class_means
# Neutral latent point: the average of the two class means.
overall_mean = th.mean(class_means, dim=0)
legend_labels = ['Right', 'Rest', 'Neutral']

plt.figure(figsize=(12,2))
plt.plot(to_numpy(class_means.squeeze()).T)
plt.plot(to_numpy(overall_mean.squeeze()))
plt.legend(legend_labels, bbox_to_anchor=(1,1,0,0))
plt.title("In Output Space")

plt.figure(figsize=(12,2))
# Map the neutral point and both class means back to input space.
inved_mean, _ = net.invert(overall_mean.unsqueeze(0))
inved_means, _ = net.invert(class_means)

for curves in (to_numpy(inved_means.squeeze()).T, to_numpy(inved_mean.squeeze())):
    plt.plot(curves)
plt.legend(legend_labels, bbox_to_anchor=(1,1,0,0))
plt.title("Inverted to Input Space")

## Individual example analysis

We can also try to analyze individual examples to understand what the network is using to predict the class.

In [None]:

dist = net.sequential[-1]
overall_mean = th.mean(dist.class_means, dim=0)
# First validation batch, central 64-sample crop.
X_th, y, _ = next(iter(valid_loader))
X_th = X_th[:,:,32:96]
y_th = th.nn.functional.one_hot(y, num_classes=2).cuda()

with th.no_grad():
    z, lp = net(X_th.cuda())
    # Per-dimension log probabilities for each class (indexed [:, class] below).
    lp_per_dim = dist.log_probs_per_class(z, sum_dims=False)
    # Positive where a latent dim favours class 1 (Rest) over class 0 (Right Hand) ...
    diffs = lp_per_dim[:,1] - lp_per_dim[:,0]
    # ... sign-flipped for class-0 trials, so positive = favours the true class.
    directed_diffs = (diffs * ((y * 2 - 1).unsqueeze(1).cuda()))
    threshold = 0
    mask = (directed_diffs > threshold).type_as(z)
    # Keep the dims supporting the true class; set the others to the neutral mean.
    z_class = z * mask + overall_mean.unsqueeze(0) * (1-mask)
    inved_class, _ = net.invert(z_class)
    # Re-encode the inverted signal; z_class is replaced by the re-encoded output.
    z_class, inved_lp = net(inved_class)
    # Complement: keep the dims supporting the other class.
    z_nonclass = z * (1-mask) + overall_mean.unsqueeze(0) * (mask)
    inved_nonclass, _ = net.invert(z_nonclass)
# True where the summed per-dim evidence favours the true class
# (presumably equivalent to a correct argmax prediction).
correct = to_numpy(th.sum(directed_diffs, dim=1) > 0)

For that, the cell above does the following:

1. Compute the output of the invertible network
2. Compute for each output dimension, which class is more likely for this example
3. Keep only the output dimensions that indicate one class, and reset the other output dimensions to neutral values: the mean between both class gaussians
4. Invert to visualize signal parts indicative of the different classes

In [None]:
# Plots for the first example of the batch. Class-dependent labels and colours
# are derived from its true label rather than assuming a Rest (class 1) trial.
i_true = int(y[0])
i_other = 1 - i_true
palette = seaborn.color_palette()

fig, axes = plt.subplots(1,2, figsize=(16,2))
axes[0].plot(to_numpy(z[0].squeeze()), marker='o', color='black')
axes[0].plot(to_numpy(dist.class_means[0].squeeze()), marker='o')
axes[0].plot(to_numpy(dist.class_means[1].squeeze()), marker='o')
# Per-dimension softmax over classes, shown for the true class.
axes[1].plot(to_numpy(th.softmax(lp_per_dim[0], dim=0)[i_true]).squeeze(), marker='o',
        color='black')
axes[0].set_title('Example Output and Class Means')
axes[1].set_title('Relative Probability Correct Class')
axes[1].axhline(y=0.5, color='darkgrey')

fig, axes = plt.subplots(1,2, figsize=(16,2))
axes[0].plot(to_numpy(z_class[0].squeeze()), marker='o', color=palette[i_true])
axes[0].plot(to_numpy(z_nonclass[0].squeeze()), marker='o', color=palette[i_other])
axes[0].plot(to_numpy(overall_mean.squeeze()), marker='o', color=palette[2])
axes[1].plot(to_numpy(X_th[0].squeeze()), color='black', label="Real")
axes[1].plot(to_numpy(inved_nonclass[0].squeeze()), color=palette[i_other],
            label=class_names[i_other])
axes[1].plot(to_numpy(inved_class[0].squeeze()), color=palette[i_true],
            label=class_names[i_true])
axes[1].plot(to_numpy(inved_mean.squeeze()), color=palette[2], label="Neutral")
axes[0].set_title('Output parts indicative of different classes')
axes[1].set_title('Class-indicative parts inverted to input space')
plt.legend(bbox_to_anchor=(1,-0.2,0,0));



Let's look at this for more examples, with correctness of prediction written as well.

In [None]:
panel_titles = ["Original", "Indicate Right", "Indicate Rest", "All"]
for i_class in range(2):
    fig, axes = plt.subplots(6,4, figsize=(16,8), sharex=True, sharey=True)
    mask = y == i_class
    # Hand-picked starting offsets into the examples of each class.
    i_start = 30 if i_class == 0 else 40
    own_color = seaborn.color_palette()[i_class]
    other_color = seaborn.color_palette()[1 - i_class]
    neutral_color = seaborn.color_palette()[2]
    neutral = to_numpy(inved_mean.squeeze())
    rows = zip(X_th[mask][i_start:], inved_class[mask][i_start:],
               inved_nonclass[mask][i_start:], correct[mask][i_start:], axes)
    for x_orig, x_class, x_nonclass, is_correct, row_axes in rows:
        ax_orig, ax_all = row_axes[0], row_axes[3]
        # The "Indicate Right"/"Indicate Rest" columns swap with the true class.
        ax_own, ax_other = row_axes[1 + i_class], row_axes[2 - i_class]
        orig_np = to_numpy(x_orig.squeeze())
        class_np = to_numpy(x_class.squeeze())
        nonclass_np = to_numpy(x_nonclass.squeeze())

        ax_orig.set_ylabel(["wrong", "correct"][int(is_correct)], rotation=0)
        ax_orig.plot(orig_np, color='black')
        ax_own.plot(class_np, color=own_color)
        ax_own.plot(neutral, color=neutral_color, lw=1)
        ax_other.plot(nonclass_np, color=other_color)
        ax_other.plot(neutral, color=neutral_color, lw=1)
        ax_all.plot(orig_np, color='black', lw=1)
        ax_all.plot(class_np, color=own_color, lw=1)
        ax_all.plot(nonclass_np, color=other_color, lw=1)
        ax_all.plot(neutral, color=neutral_color, lw=1)
    for ax, title in zip(axes[0], panel_titles):
        ax.set_title(title)
    plt.ylim(-25,25)

Let's look at mispredicted inputs. For this right hand signal, it seems that the increase in amplitude towards the end causes the misprediction as resting state:

In [None]:
# Right Hand example #30: real signal plus its class-indicative decompositions.
i_example = 30
is_right = y == 0
x_try = X_th[is_right][i_example:i_example + 1].clone()
plt.figure(figsize=(10,3))
plt.plot(x_try.squeeze(), color='black', label='Real')
for curve, curve_label in ((inved_class[is_right][i_example], "Right Hand"),
                           (inved_nonclass[is_right][i_example], "Rest"),
                           (inved_mean, 'Neutral')):
    plt.plot(curve.detach().cpu().squeeze(), label=curve_label)
plt.legend(bbox_to_anchor=(1,1,0,0))

We can try to check that by scaling down the later part of the signal and evaluating the prediction change:

In [None]:
# Shrink the amplitude of the later part (from sample i_time_step on) of Right
# Hand example #30 around its mean, and compare the probability of the true
# class (index 0 = Right Hand) before and after.
plt.figure(figsize=(10,3))
x_try = X_th[y==0][30:31].clone()
x_try_new = x_try.clone()
i_time_step = 32
amplitude_factor = 0.7
block = x_try_new[:,:,i_time_step:].clone()
x_try_new[:,:,i_time_step:] = (block - block.mean()) * amplitude_factor + block.mean()
with th.no_grad():
    lp_orig = net(x_try.cuda())[1].squeeze()
    lp_fake = net(x_try_new.cuda())[1].squeeze()
plt.plot(x_try.squeeze(), color='black', label=f'Real (Correct pred: {th.softmax(lp_orig, dim=0)[0].item():.1%})')
plt.plot(x_try_new.squeeze(), color=seaborn.color_palette()[4],
         label=f'Manipulated (Correct pred: {th.softmax(lp_fake, dim=0)[0].item():.1%})')

plt.legend()

For another example, the low frequencies also seem to cause the misclassification:

In [None]:
# Right Hand example #35: real signal plus its class-indicative decompositions.
i_example = 35
is_right = y == 0
x_try = X_th[is_right][i_example:i_example + 1].clone()
plt.figure(figsize=(10,3))
plt.plot(x_try.squeeze(), color='black', label='Real')
for curve, curve_label in ((inved_class[is_right][i_example], "Right Hand"),
                           (inved_nonclass[is_right][i_example], "Rest"),
                           (inved_mean, 'Neutral')):
    plt.plot(curve.detach().cpu().squeeze(), label=curve_label)
plt.legend(bbox_to_anchor=(1,1,0,0))

We can check by attenuating the amplitude only in the low frequencies:

In [None]:
# Attenuate the lowest n_low_bins rFFT bins of Right Hand example #35 by a factor
# of 10 and compare the probability of the true class (index 0) before and after.
# At 32 Hz with 64 samples the bin spacing is 0.5 Hz, so 12 bins cover 0-5.5 Hz.
import torch.fft  # on torch 1.7 this import is needed to expose the th.fft module
plt.figure(figsize=(10,3))
x_try = X_th[y==0][35:36].clone()
n_times = x_try.shape[-1]
n_low_bins = 12
# th.rfft / th.irfft were removed in torch 1.8; use the torch.fft module instead.
ffted = th.fft.rfft(x_try.clone(), dim=-1)
ffted[:,:,:n_low_bins] = ffted[:,:,:n_low_bins] * 0.1
x_try_new = th.fft.irfft(ffted, n=n_times, dim=-1)
with th.no_grad():
    lp_orig = net(x_try.cuda())[1].squeeze()
    lp_fake = net(x_try_new.cuda())[1].squeeze()
plt.plot(x_try.squeeze(), color='black', label=f'Real (Correct pred: {th.softmax(lp_orig, dim=0)[0].item():.1%})')
plt.plot(x_try_new.squeeze(), color=seaborn.color_palette()[4],
         label=f'Manipulated (Correct pred: {th.softmax(lp_fake, dim=0)[0].item():.1%})')

plt.legend()
