In [None]:
# Logging setup for the notebook: root logger at INFO, records written to stdout.
import logging
import importlib
# Reload the module before configuring it (presumably so that basicConfig below
# takes effect even if logging was already configured in this kernel).
importlib.reload(logging) # see https://stackoverflow.com/a/21475297/1469195
log = logging.getLogger()
log.setLevel('INFO')
import sys

# Timestamp, level and message for every record; stream is stdout instead of
# the default stderr.
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                     level=logging.INFO, stream=sys.stdout)

In [None]:
%%capture
import os
import site

# Local source checkouts to import from. Every entry is pushed to the FRONT of
# the import path in turn, so the last entry listed here ends up first.
# NOTE(review): these are absolute, user-specific paths.
for _code_dir in (
    '/home/schirrmr/code/reversible/',
    '/home/schirrmr/braindecode/code/braindecode/',
    '/home/schirrmr/code/explaining/reversible//',
):
    os.sys.path.insert(0, _code_dir)
del _code_dir


%load_ext autoreload
%autoreload 2
import numpy as np
# Root logger at INFO, records to stdout.
# NOTE(review): the first cell of this notebook already applies exactly this
# configuration; this repeat looks redundant — confirm before removing.
import logging
log = logging.getLogger()
log.setLevel('INFO')
import sys
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                     level=logging.INFO, stream=sys.stdout)
import matplotlib
from matplotlib import pyplot as plt
from matplotlib import cm
%matplotlib inline
%config InlineBackend.figure_format = 'png'
# Notebook-wide plotting defaults. The default figure is very flat (12 x 1);
# the plotting cells below pass their own figsize explicitly.
matplotlib.rcParams['figure.figsize'] = (12.0, 1.0)
matplotlib.rcParams['font.size'] = 14
import seaborn
seaborn.set_style('darkgrid')

from reversible2.sliced import sliced_from_samples
from numpy.random import RandomState

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import copy
import math

import itertools
import torch as th
from braindecode.torch_ext.util import np_to_var, var_to_np
from reversible2.splitter import SubsampleSplitter

from reversible2.view_as import ViewAs

from reversible2.affine import AdditiveBlock
from reversible2.plot import display_text, display_close
from reversible2.bhno import load_file, create_inputs
# Turn on the cuDNN benchmark flag for the whole session.
# NOTE(review): per the PyTorch docs this lets cuDNN pick algorithms by timing
# them, which may affect run-to-run determinism — confirm this is acceptable.
th.backends.cudnn.benchmark = True

In [None]:
# Load the training recording (file 4) and keep only channel C3.
# NOTE(review): absolute, machine-specific data path.
orig_train_cnt = load_file('/data/schirrmr/schirrmr/HGD-public/reduced/train/4.mat')
train_cnt = orig_train_cnt.reorder_channels(['C3',])

# Later cells index the result as train_inputs[i_class], i.e. it is used as a
# per-class sequence of input tensors (plots label the classes "Right"/"Rest").
train_inputs = create_inputs(train_cnt, final_hz=512, half_before=True)

In [None]:
# Load the test recording (file 4), keep only channel C3 and build the inputs
# with the same settings as for the training data.
# NOTE(review): absolute, machine-specific data path.
orig_test_cnt = load_file('/data/schirrmr/schirrmr/HGD-public/reduced/test/4.mat')
test_cnt = orig_test_cnt.reorder_channels(['C3', ])
test_inputs = create_inputs(test_cnt, final_hz=512, half_before=True)

In [None]:
from reversible2.branching import CatChans, ChunkChans, Select

In [None]:
# Flag that later cells also read (seeding, moving the model to the GPU).
cuda = True
if cuda:
    # Move every per-class input tensor of both sets onto the GPU.
    train_inputs, test_inputs = (
        [class_tensor.cuda() for class_tensor in set_tensors]
        for set_tensors in (train_inputs, test_inputs))

In [None]:
from reversible2.graph import Node
from reversible2.branching import CatChans, ChunkChans, Select

In [None]:
def invert(feature_model, samples):
    """Map `samples` back through `feature_model` using its own `invert` method."""
    inverted = feature_model.invert(samples)
    return inverted

In [None]:
# Scratch tensor with the nine values 0, 1, ..., 8.
# NOTE(review): `a` is not referenced by any other cell visible here.
a = th.linspace(0,8,9)

In [None]:
class PadChans(nn.Module):
    """Append `n_chans` all-zero channels (dim 1) to a 4-d input.

    `invert` drops the last `n_chans` channels again, so
    `invert(forward(x))` returns `x`.

    Parameters
    ----------
    n_chans: int
        Number of zero channels to append; 0 makes the module an identity.
    """
    def __init__(self, n_chans):
        super(PadChans, self).__init__()
        self.n_chans = n_chans

    def forward(self, x):
        # new_zeros inherits both device and dtype from x, so the
        # concatenation never mixes dtypes.
        pad = x.new_zeros((x.shape[0], self.n_chans, x.shape[2], x.shape[3]))
        return th.cat((x, pad), dim=1)

    def invert(self, y):
        # Use an explicit end index: `y[:, :-self.n_chans]` would be
        # `y[:, :0]` (empty) for n_chans == 0.
        return y[:, :y.shape[1] - self.n_chans]

class RepeatTime(nn.Module):
    """Repeat every entry along dim 2 of a 4-d input `n_times` times.

    An input of shape (batch, chans, time, width) becomes
    (batch, chans, time * n_times, width), where each original time step
    appears `n_times` times in a row. `invert` averages each group of
    `n_times` consecutive steps, which recovers the input exactly for
    outputs produced by `forward`.

    Parameters
    ----------
    n_times: int
        Repetition factor along dim 2.
    """
    def __init__(self, n_times):
        super(RepeatTime, self).__init__()
        self.n_times = n_times

    def forward(self, x):
        n_batch, n_chans, n_steps, n_width = x.shape
        # (B, C, T, 1, W) -> (B, C, T, n_times, W): copies of one time step
        # are adjacent, so flattening dims 2 and 3 interleaves them in time.
        repeated = x.unsqueeze(3).repeat(1, 1, 1, self.n_times, 1)
        return repeated.view(n_batch, n_chans, n_steps * self.n_times, n_width)

    def invert(self, y):
        # Kernel and stride must match the repetition factor used in forward.
        return th.nn.functional.avg_pool2d(
            y, kernel_size=(self.n_times, 1), stride=(self.n_times, 1))
    

In [None]:
from copy import deepcopy
from reversible2.graph import Node
from reversible2.distribution import TwoClassDist

from reversible2.blocks import dense_add_block, conv_add_block_3x3
from reversible2.rfft import RFFT, Interleave
from reversible2.util import set_random_seeds
from torch.nn import ConstantPad2d
import torch as th
from reversible2.splitter import SubsampleSplitter

# Seed first: the modules below are created right after seeding, so their
# creation order should be kept as is to reproduce the same initial weights
# (assuming the block constructors draw their initial weights from the seeded
# RNG — TODO confirm).
set_random_seeds(2019011641, cuda)

# Shared trunk. The trailing "C x T" comments are the author's shape notes
# (presumably channels x time steps after the respective splitter).
base_model = nn.Sequential(
    RepeatTime(2 ),
    SubsampleSplitter(stride=[2,1],chunk_chans_first=False),# 2 x 512
    conv_add_block_3x3(2,32),
    conv_add_block_3x3(2,32),
    SubsampleSplitter(stride=[2,1],chunk_chans_first=True), # 4 x 256
    conv_add_block_3x3(4,32),
    conv_add_block_3x3(4,32),
    SubsampleSplitter(stride=[2,1],chunk_chans_first=True), # 8 x 128
    conv_add_block_3x3(8,32),
    conv_add_block_3x3(8,32))
base_model.cuda();

# First-level branch; the graph cell feeds each copy one channel chunk of the
# trunk output (ChunkChans(2) + Select).
branch_1_a =  nn.Sequential(
    SubsampleSplitter(stride=[2,1],chunk_chans_first=False), # 8 x 64
    conv_add_block_3x3(8,32),
    conv_add_block_3x3(8,32),
    SubsampleSplitter(stride=[2,1],chunk_chans_first=True), # 16 x 32
    conv_add_block_3x3(16,32),
    conv_add_block_3x3(16,32),
    SubsampleSplitter(stride=[2,1],chunk_chans_first=True), # 32 x 16
    conv_add_block_3x3(32,32),
    conv_add_block_3x3(32,32),
)
# branch_1_b = independent deep copy of branch_1_a's layers, followed by a
# flattening view to 512 features and four dense blocks.
branch_1_b = nn.Sequential(
    *(list(deepcopy(branch_1_a).children()) + [
    ViewAs((-1, 32,16,1), (-1,512)),
    dense_add_block(512,32),
    dense_add_block(512,32),
    dense_add_block(512,32),
    dense_add_block(512,32),
]))
branch_1_a.cuda();
branch_1_b.cuda();

# Second-level branch: subsamples down to a single time step, flattens to 256
# features and applies four dense blocks.
branch_2_a = nn.Sequential(
    SubsampleSplitter(stride=[2,1], chunk_chans_first=False), # 32 x 8
    conv_add_block_3x3(32,32),
    conv_add_block_3x3(32,32),
    SubsampleSplitter(stride=[2,1],chunk_chans_first=True), # 64 x 4
    conv_add_block_3x3(64,32),
    conv_add_block_3x3(64,32),
    SubsampleSplitter(stride=[2,1],chunk_chans_first=True), # 128 x 2
    conv_add_block_3x3(128,32),
    conv_add_block_3x3(128,32),
    SubsampleSplitter(stride=[2,1],chunk_chans_first=True), # 256 x 1
    ViewAs((-1, 256,1,1), (-1,256)),
    dense_add_block(256,64),
    dense_add_block(256,64),
    dense_add_block(256,64),
    dense_add_block(256,64),
)


# Independent copy with its own parameters (starts from identical weights).
# NOTE(review): branch_2_b is already moved to the GPU here, so the second
# .cuda() call on it below is redundant.
branch_2_b = deepcopy(branch_2_a).cuda()
branch_2_a.cuda();
branch_2_b.cuda();

# Head applied to the concatenated branch outputs:
# 512 (branch_1_b) + 256 (branch_2_a) + 256 (branch_2_b) = 1024 features.
final_model = nn.Sequential(
    dense_add_block(1024,256),
    dense_add_block(1024,256),
    dense_add_block(1024,256),
    dense_add_block(1024,256),
    RFFT()
)
final_model.cuda();

In [None]:
# Wire the parts into a branching graph of Nodes:
#   base_model -> ChunkChans(2) -> chunk 0 -> branch_1_a -> ChunkChans(2)
#                                                   chunk 0 -> branch_2_a
#                                                   chunk 1 -> branch_2_b
#                               -> chunk 1 -> branch_1_b
#   CatChans([branch_1_b, branch_2_a, branch_2_b]) -> final_model
o = Node(None, base_model)
o = Node(o, ChunkChans(2))
o1a = Node(o, Select(0))
o1b = Node(o, Select(1))
o1a = Node(o1a, branch_1_a)
o1b = Node(o1b, branch_1_b)
o2 = Node(o1a, ChunkChans(2))
o2a = Node(o2, Select(0))
o2b = Node(o2, Select(1))
o2a = Node(o2a, branch_2_a)
o2b = Node(o2b, branch_2_b)
o = Node([o1b,o2a,o2b], CatChans())
o = Node(o, final_model)
feature_model = o
if cuda:
    feature_model.cuda()
# NOTE(review): the model is put into eval mode here and no visible cell
# switches it back to train mode before the training loop.
feature_model.eval()
# Check that forward + inverse is really identical
t_out = feature_model(train_inputs[0][:2])
inverted = invert(feature_model, t_out)
assert th.allclose(train_inputs[0][:2], inverted, rtol=1e-3,atol=1e-4)
# NOTE(review): `device` is not used by any cell visible here.
device = list(feature_model.parameters())[0].device
from reversible2.ot_exact import ot_euclidean_loss_for_samples
# Class distribution constructed with the arguments 2 and (2 * input size - 2);
# presumably 2 class-discriminative dims plus the remaining dims, with the
# factor 2 matching RepeatTime(2) doubling the input length — TODO confirm.
class_dist = TwoClassDist(2, 2*np.prod(train_inputs[0].size()[1:]) - 2)
class_dist.cuda()
# Separate optimizers: Adam defaults for the network, lr=1e-2 for the
# distribution parameters.
optim_model = th.optim.Adam(feature_model.parameters())
optim_dist = th.optim.Adam(class_dist.parameters(), lr=1e-2)

In [None]:
%%writefile plot.py
import torch as th
import matplotlib.pyplot as plt
import numpy as np
from reversible2.util import var_to_np
from reversible2.plot import display_close
from matplotlib.patches import Ellipse
import seaborn

def plot_outs(feature_model, train_inputs, test_inputs, class_dist):
    """Plot the first two encoding dimensions of train and test data.

    For each set ("Train", then "Test") one 5x5 inch figure is shown via
    `display_close`. It contains:

    * a scatter of the encodings (dims 0 and 1) per class, where trials that
      the class distributions of `class_dist` assign to the wrong class are
      marked with a black 'x',
    * ellipses at 0.5, 1, 2 and 3 standard deviations around the mean of each
      class distribution of `class_dist`,
    * a triangle at the mean encoding of each class.

    The title reports two accuracies: one using the class distributions of
    `class_dist`, one using diagonal Gaussians fitted to the mean/std of the
    TRAIN encodings. The `class_dist` accuracy is also printed.

    Parameters
    ----------
    feature_model: callable
        Maps an input batch to encodings; only columns 0 and 1 are used.
    train_inputs, test_inputs: sequence of tensors
        One input batch per class. The class names used in the legend are
        hard-coded to "Right" and "Rest", so two classes are expected.
    class_dist:
        Object with `get_mean_std(i_class)` returning (mean, std) tensors.
    """
    # Pure evaluation/plotting: no autograd graph is needed for any of it.
    with th.no_grad():
        # Diagonal Gaussians fitted to mean/std of the train encodings.
        data_cls_dists = []
        for i_class in range(len(train_inputs)):
            this_class_outs = feature_model(train_inputs[i_class])[:,:2]
            data_cls_dists.append(
                th.distributions.MultivariateNormal(
                    th.mean(this_class_outs, dim=0),
                    covariance_matrix=th.diag(th.std(this_class_outs, dim=0) ** 2)))

        for setname, set_inputs in (("Train", train_inputs), ("Test", test_inputs)):
            outs = [feature_model(ins) for ins in set_inputs]
            c_outs = [o[:,:2] for o in outs]

            # Gaussians given by the learned class distributions (first 2 dims).
            cls_dists = []
            for i_class in range(len(c_outs)):
                mean, std = class_dist.get_mean_std(i_class)
                cls_dists.append(
                    th.distributions.MultivariateNormal(
                        mean[:2], covariance_matrix=th.diag(std[:2] ** 2)))

            # Per true class: (n_trials, n_classes) log-probabilities.
            preds_per_class = [
                th.stack([cls_dists[i_cls].log_prob(c_out)
                          for i_cls in range(len(cls_dists))], dim=-1)
                for c_out in c_outs]
            pred_labels_per_class = [np.argmax(var_to_np(preds), axis=1)
                                     for preds in preds_per_class]

            labels = np.concatenate(
                [np.ones(len(set_inputs[i_cls])) * i_cls
                 for i_cls in range(len(train_inputs))])
            acc = np.mean(labels == np.concatenate(pred_labels_per_class))

            # Same classification, but with the Gaussians fitted to train data.
            data_preds_per_class = [
                th.stack([data_cls_dists[i_cls].log_prob(c_out)
                          for i_cls in range(len(cls_dists))], dim=-1)
                for c_out in c_outs]
            data_pred_labels_per_class = [
                np.argmax(var_to_np(data_preds), axis=1)
                for data_preds in data_preds_per_class]
            data_acc = np.mean(
                labels == np.concatenate(data_pred_labels_per_class))

            print("{:s} Accuracy: {:.1f}%".format(setname, acc * 100))
            fig = plt.figure(figsize=(5,5))
            ax = plt.gca()
            for i_class in range(len(c_outs)):
                o = var_to_np(c_outs[i_class]).squeeze()
                incorrect_pred_mask = pred_labels_per_class[i_class] != i_class
                plt.scatter(o[:,0], o[:,1], s=20, alpha=0.75,
                            label=["Right", "Rest"][i_class])
                assert len(incorrect_pred_mask) == len(o)
                plt.scatter(o[incorrect_pred_mask,0], o[incorrect_pred_mask,1],
                            marker='x', color='black', alpha=1, s=5)
                means, stds = class_dist.get_mean_std(i_class)
                means = var_to_np(means)[:2]
                stds = var_to_np(stds)[:2]
                for sigma in [0.5,1,2,3]:
                    # Ellipse takes the TOTAL width/height (diameters), so a
                    # contour at `sigma` standard deviations needs 2*sigma*std.
                    ellipse = Ellipse(means, 2 * stds[0] * sigma,
                                      2 * stds[1] * sigma)
                    ax.add_artist(ellipse)
                    ellipse.set_edgecolor(seaborn.color_palette()[i_class])
                    ellipse.set_facecolor("None")
            for i_class in range(len(c_outs)):
                o = var_to_np(c_outs[i_class]).squeeze()
                plt.scatter(np.mean(o[:,0]), np.mean(o[:,1]),
                            color=seaborn.color_palette()[i_class+2], s=80,
                            marker="^",
                            label=["Right Mean", "Rest Mean"][i_class])

            plt.title("{:6s} Accuracy:        {:.1f}%\n"
                      "From data mean/std: {:.1f}%".format(
                          setname, acc * 100, data_acc * 100))
            plt.legend(bbox_to_anchor=(1,1,0,0))
            display_close(fig)

In [None]:
from reversible2.timer import Timer
from plot import plot_outs

# Training configuration:
# * the output-space OT losses are only switched on from epoch
#   `i_start_epoch_out`; before that they are constant zeros,
# * progress is reported 20 times per run; max(1, ...) keeps the modulo
#   below valid when n_epochs < 20.
i_start_epoch_out = 200
n_epochs = 1001
i_report_every = max(1, n_epochs // 20)
for i_epoch in range(n_epochs):
    with Timer(name='EpochLoop', verbose=False) as loop_time:
        optim_model.zero_grad()
        optim_dist.zero_grad()
        # Gradients of both classes are accumulated before one optimizer step.
        for i_class in range(len(train_inputs)):
            class_ins = train_inputs[i_class]
            # Draw five times as many samples as there are real trials.
            samples = class_dist.get_samples(i_class, len(train_inputs[i_class]) * 5)
            inverted = feature_model.invert(samples)
            outs = feature_model(class_ins)
            if i_epoch < i_start_epoch_out:
                ot_loss_out = th.zeros(1, device=class_ins.device)
            else:
                # Output space: only the first two dims are matched.
                ot_loss_out = ot_euclidean_loss_for_samples(outs[:,:2].squeeze(), samples[:,:2].squeeze())
            # Input space: real signals vs. inverted samples.
            ot_loss_in = ot_euclidean_loss_for_samples(class_ins.squeeze(), inverted.squeeze())

            # Move this class's encodings to the other class, invert them and
            # match them to the real inputs of the other class.
            other_class_ins = train_inputs[1-i_class]
            changed_to_other_class = class_dist.change_to_other_class(outs, i_class_from=i_class, i_class_to=1-i_class)
            other_inverted = feature_model.invert(changed_to_other_class)
            ot_transformed_in = ot_euclidean_loss_for_samples(other_class_ins.squeeze(), other_inverted.squeeze())
            if i_epoch < i_start_epoch_out:
                ot_transformed_out = th.zeros(1, device=class_ins.device)
            else:
                other_samples = class_dist.get_samples(1-i_class, len(train_inputs[i_class]) * 5)
                ot_transformed_out = ot_euclidean_loss_for_samples(changed_to_other_class[:,:2].squeeze(),
                                                                   other_samples[:,:2].squeeze(),)
            loss = ot_loss_in + ot_loss_out + ot_transformed_in + ot_transformed_out
            loss.backward()
        optim_model.step()
        optim_dist.step()
    if i_epoch % i_report_every == 0:
        # The printed losses are those of the last class processed in the
        # epoch, not sums over both classes.
        print("Epoch {:d} of {:d}".format(i_epoch, n_epochs))
        print("Loss: {:.2E}".format(loss.item()))
        print("OT Loss In: {:.2E}".format(ot_loss_in.item()))
        print("OT Loss Out: {:.2E}".format(ot_loss_out.item()))
        print("Transformed OT Loss In: {:.2E}".format(ot_transformed_in.item()))
        print("Transformed OT Loss Out: {:.2E}".format(ot_transformed_out.item()))
        print("Loop Time: {:.0f} ms".format(loop_time.elapsed_secs * 1000))
        plot_outs(feature_model, train_inputs, test_inputs,
                 class_dist)
        # Learned standard deviations: class dims first, then non-class dims.
        fig = plt.figure(figsize=(8,2))
        plt.plot(var_to_np(th.cat((th.exp(class_dist.class_log_stds),
                                 th.exp(class_dist.non_class_log_stds)))),
                marker='o')
        display_close(fig)

In [None]:
# OT distance between generated (inverted) samples and real inputs, in the
# time domain ("Loss") and between their real FFTs ("SpecLoss"), averaged over
# the two classes. The number of generated samples is always derived from the
# size of the TRAIN set (times n_samples_factor), also for the test set.
for n_samples_factor in [5,1]:
    for setname, set_inputs in (("Train", train_inputs), ("Test", test_inputs)):
        total_loss = 0
        total_spec_ot = 0
        for i_class in range(2):
            class_ins = set_inputs[i_class]
            samples = class_dist.get_samples(i_class, len(train_inputs[i_class]) * n_samples_factor)
            inverted = feature_model.invert(samples)
            loss = ot_euclidean_loss_for_samples(class_ins.squeeze(), inverted.squeeze())
            total_loss += loss
            # Flatten real/imaginary parts into one feature vector per trial.
            spec_ot = ot_euclidean_loss_for_samples(
                th.rfft(class_ins.squeeze(), signal_ndim=1, normalized=True).view(class_ins.shape[0],-1),
                th.rfft(inverted.squeeze(), signal_ndim=1, normalized=True).view(inverted.shape[0],-1))
            total_spec_ot += spec_ot
        print("{:d} Samples ". format(len(samples)) + setname + " Loss: {:.1f}".format(total_loss.item() / 2))
        # Report the class average (total / 2), not just the last class's value.
        print("{:d} Samples ". format(len(samples)) + setname + " SpecLoss: {:.1f}".format(total_spec_ot.item() / 2))

# Show losses between sets: the same two OT distances as a reference, but
# computed between the real train and the real test inputs of each class
# (no generated samples involved), averaged over the two classes.
for setname, set_inputs in (("Train", train_inputs),):
    total_loss = 0
    total_spec_ot = 0
    for i_class in range(2):
        class_ins = set_inputs[i_class]
        test_class_ins = test_inputs[i_class]
        loss = ot_euclidean_loss_for_samples(class_ins.squeeze(), test_class_ins.squeeze())
        total_loss += loss
        spec_ot = ot_euclidean_loss_for_samples(
            th.rfft(class_ins.squeeze(), signal_ndim=1, normalized=True).view(class_ins.shape[0],-1),
            th.rfft(test_class_ins.squeeze(), signal_ndim=1, normalized=True).view(test_class_ins.shape[0],-1))
        total_spec_ot += spec_ot
    print(setname + " Loss: {:.1f}".format(total_loss.item() / 2))
    # Report the class average (total / 2), not just the last class's value.
    print(setname + " SpecLoss: {:.1f}".format(total_spec_ot.item() / 2))


In [None]:
# One 5x4 grid per class showing the first real training signals
# (as many as there are axes).
for i_class, cls_inputs in enumerate(train_inputs):
    fig, axes = plt.subplots(5,4, figsize=(16,12), sharex=True, sharey=True)

    flat_axes = axes.flatten()
    first_curves = var_to_np(cls_inputs[:len(flat_axes)]).squeeze()
    for ax, curve in zip(flat_axes, first_curves):
        ax.plot(curve, color=seaborn.color_palette()[i_class])
    display_close(fig)
    

In [None]:
# One 5x4 grid per class showing 20 generated signals: samples drawn from the
# class distribution and inverted through the model.
for i_class in range(len(train_inputs)):
    samples = class_dist.get_samples(i_class, 20)
    inverted = feature_model.invert(samples)
    fig, axes = plt.subplots(5,4, figsize=(16,12), sharex=True, sharey=True)

    generated_curves = var_to_np(inverted).squeeze()
    for ax, curve in zip(axes.flatten(), generated_curves):
        ax.plot(curve, color=seaborn.color_palette()[i_class])
    display_close(fig)
    

In [None]:
# Compare the mean amplitude spectrum of real training signals with that of
# 2000 generated signals, per class. First figure: both spectra; second
# figure: log of their ratio (fake / real).
# The frequency axis assumes 512 samples at 512 Hz (inputs were created with
# final_hz=512) — TODO confirm the signal length if the inputs change.
freqs = np.fft.rfftfreq(512, d=1/512.0)
for i_class in range(len(train_inputs)):
    # Evaluation only, no gradients needed for 2000 inverted samples.
    with th.no_grad():
        samples = class_dist.get_samples(i_class, 2000)
        inverted = feature_model.invert(samples)
        # th.rfft returns real and imaginary part in a trailing dim of size 2;
        # the amplitude is the complex magnitude sqrt(re^2 + im^2), not
        # |re| + |im|.
        spec_inv = th.rfft(inverted.squeeze(), 1, )
        spec_real = th.rfft(train_inputs[i_class].squeeze(), 1, )
        amps_inv = th.sqrt(th.sum(spec_inv ** 2, dim=-1))
        amps_real = th.sqrt(th.sum(spec_real ** 2, dim=-1))
    mean_amps_real = var_to_np(th.mean(amps_real, dim=0))
    mean_amps_inv = var_to_np(th.mean(amps_inv, dim=0))
    fig = plt.figure(figsize=(8,3))
    plt.plot(freqs, mean_amps_real)
    plt.plot(freqs, mean_amps_inv)
    plt.legend(("Real", "Fake"))
    display_close(fig)
    fig = plt.figure(figsize=(8,3))
    plt.plot(freqs, np.log(mean_amps_inv / mean_amps_real))
    # Single curve: label it as the ratio instead of reusing "Real"/"Fake".
    plt.legend(("log(Fake / Real)",))
    display_close(fig)
    

In [None]:
def create_th_grid(dim_0_vals, dim_1_vals):
    """Build the 2-d grid of all (dim_0, dim_1) value pairs.

    Returns a tensor of shape (len(dim_0_vals), len(dim_1_vals), 2) where
    entry [i, j] is the pair (dim_0_vals[i], dim_1_vals[j]).
    """
    rows = [
        th.stack([th.stack((val_0, val_1)) for val_1 in dim_1_vals])
        for val_0 in dim_0_vals
    ]
    return th.stack(rows)

In [None]:
# Sum the stds of both classes over all dims from index 2 onwards and sort
# them in descending order.
vals, i_sorted = th.sort(class_dist.get_mean_std(0,)[1][2:] + class_dist.get_mean_std(1,)[1][2:], descending=True)

# Eight largest summed stds, and their indices in the full encoding
# (+ 2 undoes the offset introduced by slicing with [2:]).
print(var_to_np(vals[:8]))
print(var_to_np(i_sorted[:8]) + 2)

In [None]:
# The two encoding dimensions shown in the following plots.
i_dim_1 = 0
i_dim_2 = 1
plot_dims = [i_dim_1, i_dim_2]
class_ids = range(len(train_inputs))
# One row per class, one column per plotted dimension.
means = th.stack(
    [class_dist.get_mean_std(i_cls,)[0][plot_dims] for i_cls in class_ids])
stds = th.stack(
    [class_dist.get_mean_std(i_cls,)[1][plot_dims] for i_cls in class_ids])
# Plot range per dimension: from the smallest (mean - 2 std) to the largest
# (mean + 2 std) over the classes.
lower_bounds = means - stds * 2
upper_bounds = means + stds * 2
mins = th.min(lower_bounds, dim=0)[0]
maxs = th.max(upper_bounds, dim=0)[0]


In [None]:
from matplotlib.patches import Rectangle
# Scatter the train encodings of both classes in the two chosen dimensions and
# outline the rectangle spanned by mins/maxs.
outs = [feature_model(ins) for ins in train_inputs]
plt.figure(figsize=(5,5))
for i_cls, c_out in enumerate(outs):
    class_label = 'Encodings {:s}'.format(["Right", "Rest"][i_cls])
    plt.scatter(var_to_np(c_out[:,i_dim_1]),
                var_to_np(c_out[:,i_dim_2]),
                s=50, alpha=0.75, label=class_label)
ax = plt.gca()
rect_width = maxs[0].item() - mins[0].item()
rect_height = maxs[1].item() - mins[1].item()
rect = Rectangle(var_to_np(mins), rect_width, rect_height)
ax.add_artist(rect)
rect.set_edgecolor('black')
rect.set_facecolor("None")

In [None]:
# n_vals evenly spaced values per plotted dimension, spanning [mins, maxs],
# created on the same device as `mins`.
n_vals = 30
dim_0_vals = th.linspace(mins[0].item(), maxs[0].item(),n_vals, device=mins.device)
dim_1_vals = th.linspace(mins[1].item(), maxs[1].item(),n_vals, device=mins.device)


# Shape (n_vals, n_vals, 2): every combination of the two value ranges.
grid = create_th_grid(dim_0_vals, dim_1_vals)

In [None]:
# Start every grid point from class_dist.non_class_means, then overwrite the
# two plotted dimensions with the grid values.
# NOTE(review): this assumes `non_class_means` has one entry per encoding
# dimension (including i_dim_1/i_dim_2) — confirm against TwoClassDist.
full_grid = class_dist.non_class_means.repeat(grid.shape[0],grid.shape[1],1)

# Written through .data, i.e. in place and outside of autograd tracking.
full_grid.data[:,:,i_dim_1] = grid[:,:,0]
full_grid.data[:,:,i_dim_2] = grid[:,:,1]
# Flatten the grid to a batch, invert all points at once.
inverted = invert(feature_model, full_grid.view(-1, full_grid.shape[-1])).squeeze()

# Back to grid layout: [i_dim_0_val, i_dim_1_val, time].
inverted_grid = inverted.view(len(dim_0_vals), len(dim_1_vals),-1)

# Size of one grid cell in encoding units (plot range / number of grid values).
x_len, y_len = var_to_np(maxs - mins) / full_grid.shape[:-1]

max_abs = th.max(th.abs(inverted_grid))

# Scale so that the largest absolute signal value spans 90% of half a cell height.
y_factor =  (y_len / (2*max_abs)).item() * 0.9

In [None]:
# Draw the inverted signal of every grid point inside its grid cell, then
# overlay the train encodings of both classes.
plt.figure(figsize=(32,32))

n_grid_x, n_grid_y = full_grid.shape[0], full_grid.shape[1]
for i_x, i_y in itertools.product(range(n_grid_x), range(n_grid_y)):
    # The curve occupies the middle 80% of the cell width, centred vertically.
    x_cell_start = mins[0].item() + x_len * i_x
    x_start = x_cell_start + 0.1 * x_len
    x_end = x_cell_start + 0.9 * x_len
    y_center = mins[1].item() + y_len * i_y + 0.5 * y_len

    curve = var_to_np(inverted_grid[i_x][i_y])
    # Only the very first curve carries a legend entry.
    label = 'Generated data' if (i_x == 0 and i_y == 0) else ''
    plt.plot(np.linspace(x_start, x_end, len(curve)),
             curve * y_factor + y_center,
             color='black', label=label)
outs = [feature_model(ins) for ins in train_inputs]

# Freeze the axis limits so the scatter does not rescale the grid.
plt.gca().set_autoscale_on(False)
for i_cls, c_out in enumerate(outs):
    class_label = 'Encodings {:s}'.format(["Right", "Rest"][i_cls])
    plt.scatter(var_to_np(c_out[:,i_dim_1]),
                var_to_np(c_out[:,i_dim_2]),
                s=200, alpha=0.75, label=class_label)

plt.legend(fontsize=14)

In [None]:
# Invert the mean of each class distribution and plot the two resulting
# signals in one figure.
fig = plt.figure(figsize=(8,2))
for i_class in range(2):
    # cur_std is not used below.
    cur_mean, cur_std = class_dist.get_mean_std(i_class, )
    # unsqueeze(0): the mean is inverted as a batch of one.
    inverted = invert(feature_model, cur_mean.unsqueeze(0))
    plt.plot(var_to_np(inverted.squeeze()))
plt.legend(("Right Hand", "Rest"))
display(fig)
plt.close(fig)

In [None]:
# Draw every real training signal centred at the position of its own encoding
# (first two dims), then overlay the encodings themselves.
plt.figure(figsize=(32,24))

outs = [feature_model(ins) for ins in train_inputs]
c_outs = [o[:,:2] for o in outs]

half_width = x_len * 0.45
for i_cls, c_out in enumerate(c_outs):
    class_color = seaborn.color_palette()[i_cls]
    for a_c_in, a_c_out in zip(train_inputs[i_cls], c_out):
        curve = var_to_np(a_c_in.squeeze())
        x_center = a_c_out[0].item()
        y_center = a_c_out[1].item()
        x_positions = np.linspace(
            x_center - half_width, x_center + half_width, len(curve))
        plt.plot(x_positions, curve * y_factor + y_center, color=class_color)
plt.xlim(mins[0].item(), maxs[0].item())
plt.ylim(mins[1].item(), maxs[1].item())
# Keep the limits set above when adding the scatter.
plt.gca().set_autoscale_on(False)
for i_cls, c_out in enumerate(c_outs):
    class_label = 'Encodings {:s}'.format(["Right", "Rest"][i_cls])
    plt.scatter(var_to_np(c_out[:,0]),
                var_to_np(c_out[:,1]),
                s=200, alpha=0.75, label=class_label)

plt.legend(fontsize=14)

In [None]:
# Draw every real TEST signal centred at the position of its own encoding
# (first two dims), then overlay the test encodings themselves.
plt.figure(figsize=(32,24))

outs = [feature_model(ins) for ins in test_inputs]
c_outs = [o[:,:2] for o in outs]

for i_cls, c_out in enumerate(c_outs):
    # The encodings come from test_inputs, so the curves drawn at those
    # positions must be the test signals as well (not the train signals).
    for a_c_in, a_c_out in zip(test_inputs[i_cls], c_out):
        curve = var_to_np(a_c_in.squeeze())
        x_start = a_c_out[0].item() - x_len * 0.45
        x_end = a_c_out[0].item() + x_len * 0.45
        y_center = a_c_out[1].item()
        plt.plot(np.linspace(x_start, x_end, len(curve)),
                     curve * y_factor + y_center, color=seaborn.color_palette()[i_cls])
plt.xlim(mins[0].item(), maxs[0].item())
plt.ylim(mins[1].item(), maxs[1].item())
# Keep the limits set above when adding the scatter.
plt.gca().set_autoscale_on(False)
for i_cls, c_out in enumerate(c_outs):
    plt.scatter(var_to_np(c_out[:,0]),
               var_to_np(c_out[:,1]), s=200, alpha=0.75, label='Encodings {:s}'.format(
               ["Right", "Rest"][i_cls]))
    
plt.legend(fontsize=14)

In [None]:
# For each class: encode the real signals, move the encodings to the other
# class, invert them, and plot original vs. changed signal for the first 24
# trials. Each title shows, in percent, the exponentiated class log-probs of
# the original encoding (own class vs. other class) and, on the second line,
# of the changed encoding.
for i_class in range(len(train_inputs)):
    i_other_class = 1 - i_class
    class_ins = train_inputs[i_class]
    outs = feature_model(class_ins)
    cor_probs = th.exp(class_dist.get_class_log_prob(i_class, outs))
    incor_probs = th.exp(class_dist.get_class_log_prob(i_other_class, outs))
    changed_outs = class_dist.change_to_other_class(
        outs, i_class_from=i_class, i_class_to=i_other_class)
    changed_inverted = invert(feature_model, changed_outs)
    fig, axes = plt.subplots(6,4, figsize=(16,16), sharex=True, sharey=True)
    cor_changed_probs = th.exp(
        class_dist.get_class_log_prob(i_class, changed_outs))
    incor_changed_probs = th.exp(
        class_dist.get_class_log_prob(i_other_class, changed_outs))

    per_trial = zip(
        axes.flatten(),
        var_to_np(class_ins).squeeze(),
        var_to_np(changed_inverted).squeeze(),
        var_to_np(cor_probs),
        var_to_np(incor_probs),
        var_to_np(cor_changed_probs),
        var_to_np(incor_changed_probs),
    )
    for ax, original, changed, p1, p2, p3, p4 in per_trial:
        ax.plot(original, color=seaborn.color_palette()[i_class])
        ax.plot(changed, color=seaborn.color_palette()[i_other_class], alpha=0.75)
        title_text = "{:.1f} vs. {:.1f}\n{:.1f} vs. {:.1f}".format(
            p1*100,p2*100,p3*100,p4*100)
        ax.set_title(title_text)
    plt.subplots_adjust(hspace=0.5)
    display_close(fig)