In [None]:
import importlib
import logging
import sys

# Reload the logging module so that basicConfig below really installs a
# stdout handler inside the notebook (see
# https://stackoverflow.com/a/21475297/1469195); presumably the kernel has
# already configured the root logger, which would otherwise make basicConfig
# a no-op.
logging = importlib.reload(logging)

# Root logger at INFO level, writing timestamped records to stdout.
log = logging.getLogger()
log.setLevel(logging.INFO)
logging.basicConfig(
    format='%(asctime)s %(levelname)s : %(message)s',
    level=logging.INFO,
    stream=sys.stdout,
)

In [None]:
%%capture
# Environment setup. %%capture suppresses this cell's output.
import os
import site
# NOTE(review): hard-coded, user-specific absolute paths to the project
# checkouts; the notebook only runs on a machine with exactly this layout.
os.sys.path.insert(0, '/home/schirrmr/code/reversible/')
os.sys.path.insert(0, '/home/schirrmr/braindecode/code/braindecode/')
os.sys.path.insert(0, '/home/schirrmr/code/explaining/reversible//')


# Re-import edited project modules automatically before each cell runs.
%load_ext autoreload
%autoreload 2
import numpy as np
# NOTE(review): repeats the logging setup of the first cell; basicConfig
# presumably has no effect here because the root logger already has a handler.
import logging
log = logging.getLogger()
log.setLevel('INFO')
import sys
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                     level=logging.INFO, stream=sys.stdout)
import matplotlib
from matplotlib import pyplot as plt
from matplotlib import cm
%matplotlib inline
%config InlineBackend.figure_format = 'png'
matplotlib.rcParams['figure.figsize'] = (12.0, 1.0)
matplotlib.rcParams['font.size'] = 14
import seaborn
seaborn.set_style('darkgrid')

from reversible2.sliced import sliced_from_samples
from numpy.random import RandomState

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import copy
import math

import itertools
import torch as th
from braindecode.torch_ext.util import np_to_var, var_to_np
from reversible2.splitter import SubsampleSplitter

from reversible2.view_as import ViewAs

from reversible2.affine import AdditiveBlock
from reversible2.plot import display_text, display_close
from reversible2.high_gamma import load_file, create_inputs
from reversible2.high_gamma import load_train_test
# Let cuDNN benchmark convolution algorithms and pick the fastest one; results
# may then not be bit-for-bit reproducible across runs.
th.backends.cudnn.benchmark = True
from reversible2.models import deep_invertible


In [None]:
# The 22 EEG sensor names, grouped by electrode row (frontal to parietal).
sensor_names = (
    'Fz '
    'FC3 FC1 FCz FC2 FC4 '
    'C5 C3 C1 Cz C2 C4 C6 '
    'CP3 CP1 CPz CP2 CP4 '
    'P1 Pz P2 '
    'POz'
).split()

In [None]:
# Load subject 4; the two returned splits become the train and test sets
# (each a per-class list, indexed by class in the cells below).
subject4_load_kwargs = dict(
    subject_id=4,
    car=True,
    n_sensors=22,
    final_hz=256,
    start_ms=500,
    stop_ms=1500,
    half_before=True,
    only_load_given_sensors=False,
)
train_inputs, test_inputs = load_train_test(**subject4_load_kwargs)

In [None]:
# Load subject 5 with the same preprocessing as subject 4. Its first split is
# used below as the out-of-distribution "Other" set (test_dist_less).
# NOTE(review): test_dist_inputs_2 appears unused in this notebook.
subject5_load_kwargs = dict(
    subject_id=5,
    car=True,
    n_sensors=22,
    final_hz=256,
    start_ms=500,
    stop_ms=1500,
    half_before=True,
    only_load_given_sensors=False,
)
test_dist_inputs, test_dist_inputs_2 = load_train_test(**subject5_load_kwargs)

In [None]:
# Keep the first 180 trials and channels 7:9 of every per-class tensor (C3 and
# C1 if the channel order follows sensor_names -- TODO confirm). Cloning keeps
# the in-place zeroing below from touching the full-size inputs.
train_less, test_less, test_dist_less = (
    [trials[:180, 7:9].clone().contiguous() for trials in per_class]
    for per_class in (train_inputs, test_inputs, test_dist_inputs)
)
# Silence the second kept channel in place.
for t in itertools.chain(train_less, test_less, test_dist_less):
    t.data[:, 1] = 0

In [None]:
from reversible2.models import larger_model

from reversible2.distribution import TwoClassIndependentDist

import ot

from reversible2.ot_exact import get_matched_samples


from reversible2.model_and_dist import ModelAndDist, set_dist_to_empirical
from reversible2.util import flatten_2d

In [None]:
n_chans = train_less[0].shape[1]
n_time = train_less[0].shape[2]

# NOTE(review): n_chan_pad is not used in this cell.
n_chan_pad = 0
filter_length_time = 11

# Invertible network; previously kernel_length was hard-coded to 11 while
# filter_length_time was declared but ignored -- use the declared constant.
model = larger_model(n_chans, n_time, final_fft=True,
                     kernel_length=filter_length_time, constant_memory=False)
model.cuda()
# Class-conditional independent distribution with one dimension per input
# value of a single trial (product of all non-batch dimensions).
dist = TwoClassIndependentDist(np.prod(train_less[0].size()[1:]))
dist.cuda()
model_and_dist = ModelAndDist(model, dist)
# Presumably initialises dist from the model outputs on the training data.
set_dist_to_empirical(model_and_dist.model, model_and_dist.dist, train_less)

# Distribution parameters use a 100x larger learning rate than the network.
optim = th.optim.Adam([{'params': dist.parameters(), 'lr': 1e-2},
                       {'params': list(model_and_dist.model.parameters()),
                        'lr': 1e-4}])

In [None]:
from reversible2.timer import Timer

# Class whose data is monitored below (NLL/OT report and plots). The training
# step itself always fits both classes.
i_class = 1
n_epochs = 2001
class_ins = train_less[i_class].cuda()
test_ins = test_less[i_class].cuda()
test_dist_ins = test_dist_less[i_class].cuda()
# Training inputs are jittered with uniform noise in [-0.5, 0.5) * noise_factor.
noise_factor = 1e-2
# Sampling rate of the loaded data (final_hz=256 in the loading cells); used
# only for the frequency axis of the spectrum plot.
sampling_hz = 256.0
n_time_samples = class_ins.shape[2]

for i_epoch in range(n_epochs):
    with Timer(verbose=False) as timer:
        optim.zero_grad()
        # Accumulate the NLL gradients of both classes, then take one step.
        # A separate loop variable keeps the monitored i_class / class_ins
        # from being overwritten here.
        for j_class in range(2):
            train_ins = train_less[j_class].cuda()
            log_probs = model_and_dist.get_total_log_prob(
                j_class, train_ins + (th.rand_like(train_ins) - 0.5) * noise_factor)
            loss = -th.mean(log_probs)
            loss.backward()
        optim.step()

    # Report roughly 20 times over the run.
    if i_epoch % (n_epochs // 20) == 0:
        with th.no_grad():
            print("Epoch {:d} of {:d}".format(i_epoch, n_epochs))
            print("Runtime {:.1E} ms".format(timer.elapsed))
            # Per set: NLL under the monitored class, and the mean L2 distance
            # between real trials and their matched model samples (20x as many
            # samples drawn as there are real trials). get_examples' first
            # argument is presumably the class index.
            text_strs = []
            for setname, inputs in (("Train", class_ins), ("Test", test_ins), ("Other", test_dist_ins)):
                examples = model_and_dist.get_examples(i_class, len(inputs) * 20)
                matched_examples = get_matched_samples(flatten_2d(inputs), flatten_2d(examples))
                OT = th.mean(th.norm(flatten_2d(inputs).unsqueeze(1) - matched_examples, p=2, dim=2))
                nll = -th.mean(model_and_dist.get_total_log_prob(i_class, inputs))
                text_strs.append("{:7s} NLL {:.1E}".format(setname, nll.item()))
                text_strs.append("{:7s} OT {:.1E}".format(setname, OT.item()))
            display_text("\n".join(text_strs))

            # First 10 training trials (one per subplot) overlaid with their
            # matched model samples; channel 0 / 1 in palette colours 0 / 1.
            examples = model_and_dist.get_examples(i_class, len(class_ins) * 20)
            matched_examples = get_matched_samples(flatten_2d(class_ins), flatten_2d(examples))
            fig, axes = plt.subplots(5, 2, figsize=(16, 12), sharex=True, sharey=True)
            for ax, signal, matched in zip(axes.flatten(), class_ins, matched_examples):
                ax.plot(var_to_np(signal).squeeze().T)
                for ex in var_to_np(matched.view(len(matched), class_ins.shape[1], class_ins.shape[2])):
                    ax.plot(ex[0], color=seaborn.color_palette()[0], lw=0.5, alpha=0.7)
                    ax.plot(ex[1], color=seaborn.color_palette()[1], lw=0.5, alpha=0.7)
            display_close(fig)

            # Learned stds of the monitored class (assumes class_log_stds is
            # indexed by class first -- TODO confirm).
            fig = plt.figure()
            plt.plot(var_to_np(th.exp(model_and_dist.dist.class_log_stds)[i_class]))
            display_close(fig)

            # Mean amplitude spectrum of channel 0: real trials vs. model samples.
            examples = model_and_dist.get_examples(i_class, len(class_ins) * 20)
            fake_bps = np.abs(np.fft.rfft(var_to_np(examples[:, 0]).squeeze()))
            real_bps = np.abs(np.fft.rfft(var_to_np(class_ins[:, 0]).squeeze()))
            freqs = np.fft.rfftfreq(n_time_samples, 1 / sampling_hz)
            fig = plt.figure(figsize=(8, 3))
            plt.plot(freqs, np.mean(real_bps, axis=0))
            plt.plot(freqs, np.mean(fake_bps, axis=0))
            display_close(fig)

In [None]:
from reversible2.model_and_dist import create_empirical_dist

In [None]:
def _correct_per_class(predict_log_probs, inputs_per_class):
    """Return a boolean array of correct predictions for classes 0 and 1.

    predict_log_probs is called on each class's inputs (moved to CUDA); the
    predicted label is the argmax of its output along axis 1.
    """
    return [np.argmax(var_to_np(predict_log_probs(inputs_per_class[label].cuda())), axis=1) == label
            for label in range(2)]


with th.no_grad():
    # Accuracy with the trained class distributions.
    print("Actual Model")
    for setname, inputs in (("Train", train_less), ("Test", test_less)):
        corrects = _correct_per_class(model_and_dist.log_softmax, inputs)
        acc = np.mean(np.concatenate(corrects))
        print("{:6s} Accuracy {:.1f}".format(setname, acc * 100))

    # Accuracy with empirical class distributions built from train,
    # train+test ("Combined") and test data.
    combined_less = [th.cat((train_less[label].cuda(), test_less[label].cuda()), dim=0)
                     for label in range(2)]
    for name, inputs in (("Train", train_less), ("Combined", combined_less), ("Test", test_less)):
        emp_dist = create_empirical_dist(model_and_dist.model, inputs)
        emp_model_dist = ModelAndDist(model_and_dist.model, emp_dist)
        print(name)
        for setname, inner_inputs in (("Train", train_less), ("Test", test_less)):
            corrects = _correct_per_class(emp_model_dist.log_softmax, inner_inputs)
            acc = np.mean(np.concatenate(corrects))
            print("{:6s} Accuracy {:.1f} ({:.1f}/{:.1f})".format(
                setname, acc * 100,
                np.mean(corrects[0]) * 100, np.mean(corrects[1]) * 100))

In [None]:
from reversible2.distribution import TwoClassIndependentDist

In [None]:
# Average the two classes' per-dimension stds of the trained distribution and
# order the latent dimensions by that average, largest first.
class_stds = [model_and_dist.dist.get_mean_std(c)[1] for c in range(2)]
mean_stds = th.stack(class_stds, dim=0).mean(dim=0).clone()
_, i_sorted = th.sort(mean_stds, descending=True)
# NOTE(review): overwritten by the n_dims loop in the next cell.
n_dims = 2

In [None]:
with th.no_grad():
    # For 1..6 dims: keep only the latent dimensions with the largest mean
    # class std (order from i_sorted) and classify with a two-class
    # distribution restricted to those dimensions.
    for n_dims in range(1, 7):
        print("Class dims", n_dims)
        i_this_dims = i_sorted[:n_dims]
        combined_less = [th.cat((train_less[c].cuda(), test_less[c].cuda()), dim=0)
                         for c in range(2)]
        # Source of the restricted distribution's mean/std: the trained
        # distribution ("Actual model") or statistics of the model outputs.
        stat_sources = (("Actual model", None),
                        ("Train", train_less),
                        ("Combined", combined_less),
                        ("Test", test_less))
        for name, inputs in stat_sources:
            this_dist = TwoClassIndependentDist(len(i_this_dims), truncate_to=None)
            this_dist.cuda()
            for i_class in range(2):
                if inputs is not None:
                    this_outs = model_and_dist.model(inputs[i_class].cuda())
                    mean = this_outs.mean(dim=0)
                    std = this_outs.std(dim=0)
                else:
                    mean, std = model_and_dist.dist.get_mean_std(i_class)
                this_dist.set_mean_std(i_class, mean[i_this_dims], std[i_this_dims])
            print(name)
            for setname, inner_inputs in (("Train", train_less), ("Test", test_less)):
                corrects = []
                for i_class in range(2):
                    outs = model_and_dist.model(inner_inputs[i_class].cuda())[:, i_this_dims]
                    preds = this_dist.log_softmax(outs)
                    pred_label = np.argmax(var_to_np(preds), axis=1)
                    corrects.append(pred_label == i_class)
                acc = np.mean(np.concatenate(corrects))
                print("{:6s} Accuracy {:.1f} ({:.1f}/{:.1f})".format(
                    setname, acc * 100,
                    np.mean(corrects[0]) * 100, np.mean(corrects[1]) * 100))



#### Let's investigate the Lipschitz constant around the train and test data

In [None]:
# NOTE(review): unfinished cell. It runs the model on each class of the train
# and test sets, but `corrects` is never filled. Only the global `outs` from
# the last iteration (test set, class 1) survives; the next cell uses it as
# the template for th.rand_like.
with th.no_grad():
    for setname, inputs in (("Train", train_less), ("Test", test_less)):
        corrects = []
        for i_class in range(2):
            outs = model_and_dist.model(inputs[i_class].cuda())
            

In [None]:
# Random perturbations shaped like `outs`, rescaled so every row has L2 norm
# n_avg_change * sqrt(n_columns), i.e. an RMS of n_avg_change per entry
# (assumes outs is (n_trials, n_latent) -- TODO confirm).
n_avg_change = 0.01
raw = th.rand_like(outs) - 0.5
norm = n_avg_change * np.sqrt(raw.shape[1])
perturbations = norm * (raw / th.norm(raw, dim=1, keepdim=True))


In [None]:
# NOTE(review): `lip_loss` is first assigned in the Lipschitz training cell
# further below, so on a fresh kernel (Restart & Run All) this cell raises
# NameError; it only displays a value left over from that later cell.
lip_loss

In [None]:
from reversible2.timer import Timer
from reversible2.distribution import TwoClassIndependentDist

# Class whose data is monitored below (NLL/OT report and plots). The training
# step itself always fits both classes.
i_class = 1
n_epochs = 2001
class_ins = train_less[i_class].cuda()
test_ins = test_less[i_class].cuda()
test_dist_ins = test_dist_less[i_class].cuda()
# Training inputs are jittered with uniform noise in [-0.5, 0.5) * noise_factor.
noise_factor = 1e-2
# Lipschitz penalty: latent perturbations have L2 norm
# lip_perturb_factor * sqrt(latent dim); the penalty becomes active once the
# input-space displacement exceeds lip_threshold times that norm.
lip_threshold = 1.3
lip_perturb_factor = 0.1
lip_loss_factor = 1000
# Sampling rate of the loaded data (final_hz=256 in the loading cells); used
# only for the frequency axis of the spectrum plot.
sampling_hz = 256.0
n_time_samples = class_ins.shape[2]

for i_epoch in range(n_epochs):
    with Timer(verbose=False) as timer:
        optim.zero_grad()
        # Separate loop variable so the monitored i_class / class_ins are not
        # overwritten here.
        for j_class in range(2):
            train_ins = train_less[j_class].cuda()
            log_probs = model_and_dist.get_total_log_prob(
                j_class, train_ins + (th.rand_like(train_ins) - 0.5) * noise_factor)
            loss = -th.mean(log_probs)
            loss.backward()

            # Penalise local expansion of the inverse: shift the latent outputs
            # by a random direction of fixed norm, invert, and compare the
            # input-space displacement with the latent perturbation norm.
            lip_ins = train_ins
            outs = model_and_dist.model(lip_ins)
            perturbations = th.rand_like(outs) - 0.5
            norm = (lip_perturb_factor * np.sqrt(perturbations.shape[1]))
            perturbations = norm * (
                perturbations / th.norm(perturbations, p=2, dim=1, keepdim=True))
            perturbed = outs + perturbations
            # Use the same handle as the rest of the notebook (ModelAndDist was
            # built from `model`, so this is presumably the same network).
            inverted = model_and_dist.model.invert(perturbed)
            diffs = th.norm(flatten_2d(lip_ins) - flatten_2d(inverted), dim=1, p=2)
            ratio = diffs / norm
            # Squared hinge: only ratios above lip_threshold are penalised.
            lip_loss = th.mean(F.relu(ratio - lip_threshold) ** 2)
            lip_loss = lip_loss * lip_loss_factor
            lip_loss.backward()

        optim.step()

    # Report roughly 20 times over the run.
    if i_epoch % (n_epochs // 20) == 0:
        with th.no_grad():
            print("Epoch {:d} of {:d}".format(i_epoch, n_epochs))
            print("Runtime {:.1E} ms".format(timer.elapsed))
            # Per set: NLL under the monitored class and mean L2 distance to
            # matched model samples. get_examples' first argument is
            # presumably the class index.
            text_strs = []
            for setname, inputs in (("Train", class_ins), ("Test", test_ins), ("Other", test_dist_ins)):
                examples = model_and_dist.get_examples(i_class, len(inputs) * 20)
                matched_examples = get_matched_samples(flatten_2d(inputs), flatten_2d(examples))
                OT = th.mean(th.norm(flatten_2d(inputs).unsqueeze(1) - matched_examples, p=2, dim=2))
                nll = -th.mean(model_and_dist.get_total_log_prob(i_class, inputs))
                text_strs.append("{:7s} NLL {:.1E}".format(setname, nll.item()))
                text_strs.append("{:7s} OT {:.1E}".format(setname, OT.item()))
            display_text("\n".join(text_strs))

            # Accuracy with the trained class distributions.
            print("Actual Model")
            for setname, inputs in (("Train", train_less), ("Test", test_less)):
                corrects = []
                for j_class in range(2):
                    outs = model_and_dist.log_softmax(inputs[j_class].cuda())
                    pred_label = np.argmax(var_to_np(outs), axis=1)
                    correct = pred_label == j_class
                    corrects.extend(correct)
                acc = np.mean(corrects)
                print("{:6s} Accuracy {:.1f}".format(setname, acc * 100))

            # Accuracy with empirical class distributions built from train,
            # train+test ("Combined") and test data.
            for name, inputs in (("Train", train_less),
                                 ("Combined", [th.cat((train_less[k].cuda(),
                                                       test_less[k].cuda()), dim=0)
                                               for k in range(2)]),
                                 ("Test", test_less)):
                emp_dist = create_empirical_dist(model_and_dist.model, inputs)
                emp_model_dist = ModelAndDist(model_and_dist.model, emp_dist)
                print(name)
                for setname, inner_inputs in (("Train", train_less), ("Test", test_less)):
                    corrects = []
                    for j_class in range(2):
                        outs = emp_model_dist.log_softmax(inner_inputs[j_class].cuda())
                        pred_label = np.argmax(var_to_np(outs), axis=1)
                        correct = pred_label == j_class
                        corrects.append(correct)
                    acc = np.mean(np.concatenate(corrects))
                    print("{:6s} Accuracy {:.1f} ({:.1f}/{:.1f})".format(
                        setname, acc * 100,
                        np.mean(corrects[0]) * 100, np.mean(corrects[1]) * 100))

            # First 10 training trials overlaid with their matched model samples.
            examples = model_and_dist.get_examples(i_class, len(class_ins) * 20)
            matched_examples = get_matched_samples(flatten_2d(class_ins), flatten_2d(examples))
            fig, axes = plt.subplots(5, 2, figsize=(16, 12), sharex=True, sharey=True)
            for ax, signal, matched in zip(axes.flatten(), class_ins, matched_examples):
                ax.plot(var_to_np(signal).squeeze().T)
                for ex in var_to_np(matched.view(len(matched), class_ins.shape[1], class_ins.shape[2])):
                    ax.plot(ex[0], color=seaborn.color_palette()[0], lw=0.5, alpha=0.7)
                    ax.plot(ex[1], color=seaborn.color_palette()[1], lw=0.5, alpha=0.7)
            display_close(fig)

            # Learned stds of the monitored class (assumes class_log_stds is
            # indexed by class first -- TODO confirm).
            fig = plt.figure()
            plt.plot(var_to_np(th.exp(model_and_dist.dist.class_log_stds)[i_class]))
            display_close(fig)

            # Mean amplitude spectrum of channel 0: real trials vs. model samples.
            examples = model_and_dist.get_examples(i_class, len(class_ins) * 20)
            fake_bps = np.abs(np.fft.rfft(var_to_np(examples[:, 0]).squeeze()))
            real_bps = np.abs(np.fft.rfft(var_to_np(class_ins[:, 0]).squeeze()))
            freqs = np.fft.rfftfreq(n_time_samples, 1 / sampling_hz)
            fig = plt.figure(figsize=(8, 3))
            plt.plot(freqs, np.mean(real_bps, axis=0))
            plt.plot(freqs, np.mean(fake_bps, axis=0))
            display_close(fig)

In [None]:
# Lipschitz penalty before scaling by lip_loss_factor, from the last inner
# training iteration of the Lipschitz training cell above.
lip_loss / lip_loss_factor

In [None]:
# Per-trial ratio of input-space displacement to latent perturbation norm,
# from the last inner training iteration above.
ratio

In [None]:
# Input-space displacement above which the Lipschitz penalty becomes active
# (ratio > lip_threshold).
norm * lip_threshold

In [None]:
# Per-trial L2 distance between the inputs and the inverse of their perturbed
# latent outputs, from the last inner training iteration above.
diffs

In [None]:
from reversible2.timer import Timer
from reversible2.distribution import TwoClassIndependentDist

# Class whose data is monitored below (NLL/OT report and plots). The training
# step itself always fits both classes.
i_class = 1
n_epochs = 2001
class_ins = train_less[i_class].cuda()
test_ins = test_less[i_class].cuda()
test_dist_ins = test_dist_less[i_class].cuda()
# Training inputs are jittered with uniform noise in [-0.5, 0.5) * noise_factor.
noise_factor = 1e-2
# Lipschitz penalty: latent perturbations have L2 norm
# lip_perturb_factor * sqrt(latent dim); the penalty becomes active once the
# input-space displacement exceeds lip_threshold times that norm.
lip_threshold = 1.3
lip_perturb_factor = 0.1
lip_loss_factor = 1e5
# Sampling rate of the loaded data (final_hz=256 in the loading cells); used
# only for the frequency axis of the spectrum plot.
sampling_hz = 256.0
n_time_samples = class_ins.shape[2]

for i_epoch in range(n_epochs):
    with Timer(verbose=False) as timer:
        optim.zero_grad()
        # Separate loop variable so the monitored i_class / class_ins are not
        # overwritten here.
        for j_class in range(2):
            train_ins = train_less[j_class].cuda()
            log_probs = model_and_dist.get_total_log_prob(
                j_class, train_ins + (th.rand_like(train_ins) - 0.5) * noise_factor)
            loss = -th.mean(log_probs)
            loss.backward()

            # Lipschitz penalty on train AND test trials of this class. Only
            # the inputs are used (no label enters this term).
            lip_ins = th.cat((train_ins, test_less[j_class].cuda()), dim=0)
            outs = model_and_dist.model(lip_ins)
            # (A dangling `outs = ` statement here was a SyntaxError that
            # stopped the whole cell from running; it has been removed.)
            perturbations = th.rand_like(outs) - 0.5
            norm = (lip_perturb_factor * np.sqrt(perturbations.shape[1]))
            perturbations = norm * (
                perturbations / th.norm(perturbations, p=2, dim=1, keepdim=True))
            perturbed = outs + perturbations
            # Use the same handle as the rest of the notebook (ModelAndDist was
            # built from `model`, so this is presumably the same network).
            inverted = model_and_dist.model.invert(perturbed)
            diffs = th.norm(flatten_2d(lip_ins) - flatten_2d(inverted), dim=1, p=2)
            ratio = diffs / norm
            # Squared hinge: only ratios above lip_threshold are penalised.
            lip_loss = th.mean(F.relu(ratio - lip_threshold) ** 2)
            lip_loss = lip_loss * lip_loss_factor
            lip_loss.backward()

        optim.step()

    # Report roughly 20 times over the run.
    if i_epoch % (n_epochs // 20) == 0:
        with th.no_grad():
            print("Epoch {:d} of {:d}".format(i_epoch, n_epochs))
            print("Runtime {:.1E} ms".format(timer.elapsed))
            # Per set: NLL under the monitored class and mean L2 distance to
            # matched model samples. get_examples' first argument is
            # presumably the class index.
            text_strs = []
            for setname, inputs in (("Train", class_ins), ("Test", test_ins), ("Other", test_dist_ins)):
                examples = model_and_dist.get_examples(i_class, len(inputs) * 20)
                matched_examples = get_matched_samples(flatten_2d(inputs), flatten_2d(examples))
                OT = th.mean(th.norm(flatten_2d(inputs).unsqueeze(1) - matched_examples, p=2, dim=2))
                nll = -th.mean(model_and_dist.get_total_log_prob(i_class, inputs))
                text_strs.append("{:7s} NLL {:.1E}".format(setname, nll.item()))
                text_strs.append("{:7s} OT {:.1E}".format(setname, OT.item()))
            display_text("\n".join(text_strs))

            # Accuracy with the trained class distributions.
            print("Actual Model")
            for setname, inputs in (("Train", train_less), ("Test", test_less)):
                corrects = []
                for j_class in range(2):
                    outs = model_and_dist.log_softmax(inputs[j_class].cuda())
                    pred_label = np.argmax(var_to_np(outs), axis=1)
                    correct = pred_label == j_class
                    corrects.extend(correct)
                acc = np.mean(corrects)
                print("{:6s} Accuracy {:.1f}".format(setname, acc * 100))

            # Accuracy with empirical class distributions built from train,
            # train+test ("Combined") and test data.
            for name, inputs in (("Train", train_less),
                                 ("Combined", [th.cat((train_less[k].cuda(),
                                                       test_less[k].cuda()), dim=0)
                                               for k in range(2)]),
                                 ("Test", test_less)):
                emp_dist = create_empirical_dist(model_and_dist.model, inputs)
                emp_model_dist = ModelAndDist(model_and_dist.model, emp_dist)
                print(name)
                for setname, inner_inputs in (("Train", train_less), ("Test", test_less)):
                    corrects = []
                    for j_class in range(2):
                        outs = emp_model_dist.log_softmax(inner_inputs[j_class].cuda())
                        pred_label = np.argmax(var_to_np(outs), axis=1)
                        correct = pred_label == j_class
                        corrects.append(correct)
                    acc = np.mean(np.concatenate(corrects))
                    print("{:6s} Accuracy {:.1f} ({:.1f}/{:.1f})".format(
                        setname, acc * 100,
                        np.mean(corrects[0]) * 100, np.mean(corrects[1]) * 100))

            # First 10 training trials overlaid with their matched model samples.
            examples = model_and_dist.get_examples(i_class, len(class_ins) * 20)
            matched_examples = get_matched_samples(flatten_2d(class_ins), flatten_2d(examples))
            fig, axes = plt.subplots(5, 2, figsize=(16, 12), sharex=True, sharey=True)
            for ax, signal, matched in zip(axes.flatten(), class_ins, matched_examples):
                ax.plot(var_to_np(signal).squeeze().T)
                for ex in var_to_np(matched.view(len(matched), class_ins.shape[1], class_ins.shape[2])):
                    ax.plot(ex[0], color=seaborn.color_palette()[0], lw=0.5, alpha=0.7)
                    ax.plot(ex[1], color=seaborn.color_palette()[1], lw=0.5, alpha=0.7)
            display_close(fig)

            # Learned stds of the monitored class (assumes class_log_stds is
            # indexed by class first -- TODO confirm).
            fig = plt.figure()
            plt.plot(var_to_np(th.exp(model_and_dist.dist.class_log_stds)[i_class]))
            display_close(fig)

            # Mean amplitude spectrum of channel 0: real trials vs. model samples.
            examples = model_and_dist.get_examples(i_class, len(class_ins) * 20)
            fake_bps = np.abs(np.fft.rfft(var_to_np(examples[:, 0]).squeeze()))
            real_bps = np.abs(np.fft.rfft(var_to_np(class_ins[:, 0]).squeeze()))
            freqs = np.fft.rfftfreq(n_time_samples, 1 / sampling_hz)
            fig = plt.figure(figsize=(8, 3))
            plt.plot(freqs, np.mean(real_bps, axis=0))
            plt.plot(freqs, np.mean(fake_bps, axis=0))
            display_close(fig)



## Older experiment: refit the distribution to the test-set outputs

In [None]:

# Optimizer over the distribution's parameters only (single param group, lr 1e-2).
optim_dist = th.optim.Adam([dict(params=dist.parameters(), lr=1e-2)])


In [None]:
# Model outputs for each test class, computed once without gradients and
# detached so later fitting treats them as constants.
with th.no_grad():
    test_outs = list(
        model_and_dist.model(test_less[k].cuda()).detach() for k in range(2))


In [None]:
from reversible2.timer import Timer

# Class whose data is monitored below (NLL/OT report and plots).
i_class = 1
n_epochs = 2001
class_ins = train_less[i_class].cuda()
test_ins = test_less[i_class].cuda()
test_dist_ins = test_dist_less[i_class].cuda()
# NOTE(review): unused in this cell -- no noise is added to test_outs.
noise_factor = 1e-2
# Sampling rate of the loaded data (final_hz=256 in the loading cells); used
# only for the frequency axis of the spectrum plot.
sampling_hz = 256.0
n_time_samples = class_ins.shape[2]

for i_epoch in range(n_epochs):
    with Timer(verbose=False) as timer:
        optim_dist.zero_grad()
        # Fit only the distribution (optim_dist) to the fixed, detached model
        # outputs of the test set; j_class keeps the monitored i_class intact.
        for j_class in range(2):
            log_probs = model_and_dist.dist.get_total_log_prob(
                j_class, test_outs[j_class])
            loss = -th.mean(log_probs)
            loss.backward()
        optim_dist.step()

    # Report roughly 20 times over the run.
    if i_epoch % (n_epochs // 20) == 0:
        with th.no_grad():
            print("Epoch {:d} of {:d}".format(i_epoch, n_epochs))
            print("Runtime {:.1E} ms".format(timer.elapsed))
            # Per set: NLL under the monitored class and mean L2 distance to
            # matched model samples. get_examples' first argument is
            # presumably the class index.
            text_strs = []
            for setname, inputs in (("Train", class_ins), ("Test", test_ins), ("Other", test_dist_ins)):
                examples = model_and_dist.get_examples(i_class, len(inputs) * 20)
                matched_examples = get_matched_samples(flatten_2d(inputs), flatten_2d(examples))
                OT = th.mean(th.norm(flatten_2d(inputs).unsqueeze(1) - matched_examples, p=2, dim=2))
                nll = -th.mean(model_and_dist.get_total_log_prob(i_class, inputs))
                text_strs.append("{:7s} NLL {:.1E}".format(setname, nll.item()))
                text_strs.append("{:7s} OT {:.1E}".format(setname, OT.item()))
            display_text("\n".join(text_strs))

            # First 10 training trials overlaid with their matched model samples.
            examples = model_and_dist.get_examples(i_class, len(class_ins) * 20)
            matched_examples = get_matched_samples(flatten_2d(class_ins), flatten_2d(examples))
            fig, axes = plt.subplots(5, 2, figsize=(16, 12), sharex=True, sharey=True)
            for ax, signal, matched in zip(axes.flatten(), class_ins, matched_examples):
                ax.plot(var_to_np(signal).squeeze().T)
                for ex in var_to_np(matched.view(len(matched), class_ins.shape[1], class_ins.shape[2])):
                    ax.plot(ex[0], color=seaborn.color_palette()[0], lw=0.5, alpha=0.7)
                    ax.plot(ex[1], color=seaborn.color_palette()[1], lw=0.5, alpha=0.7)
            display_close(fig)

            # Learned stds of the monitored class (assumes class_log_stds is
            # indexed by class first -- TODO confirm).
            fig = plt.figure()
            plt.plot(var_to_np(th.exp(model_and_dist.dist.class_log_stds)[i_class]))
            display_close(fig)

            # Mean amplitude spectrum of channel 0: real trials vs. model samples.
            examples = model_and_dist.get_examples(i_class, len(class_ins) * 20)
            fake_bps = np.abs(np.fft.rfft(var_to_np(examples[:, 0]).squeeze()))
            real_bps = np.abs(np.fft.rfft(var_to_np(class_ins[:, 0]).squeeze()))
            freqs = np.fft.rfftfreq(n_time_samples, 1 / sampling_hz)
            fig = plt.figure(figsize=(8, 3))
            plt.plot(freqs, np.mean(real_bps, axis=0))
            plt.plot(freqs, np.mean(fake_bps, axis=0))
            display_close(fig)

In [None]:
# Accuracy of the model with its current distribution on train and test.
with th.no_grad():
    for setname, inputs in (("Train", train_less), ("Test", test_less)):
        corrects = []
        for i_class in range(2):
            # log_softmax is a method of ModelAndDist (as in the earlier
            # evaluation cells); no free function of that name is defined or
            # imported above, so calling it bare raised NameError.
            outs = model_and_dist.log_softmax(inputs[i_class].cuda())
            pred_label = np.argmax(var_to_np(outs), axis=1)
            correct = pred_label == i_class
            corrects.extend(correct)
        acc = np.mean(corrects)
        print("{:6s} Accuracy {:.1f}".format(setname, acc * 100))
    

In [None]:
# Per-dimension std of the fitted distribution for each class (second element
# of get_mean_std, which is unpacked as `mean, std` elsewhere in the notebook).
fig, ax = plt.subplots()
for c in range(2):
    ax.plot(var_to_np(model_and_dist.dist.get_mean_std(c)[1]),
            label="Class {:d}".format(c))
ax.set_xlabel("Latent dimension")
ax.set_ylabel("Std")
ax.legend();

In [None]:
from reversible2.model_and_dist import ModelAndDist, create_empirical_dist

In [None]:
# Accuracy of the model with its fitted distribution, then with empirical
# class distributions built from train, train+test ("Combined") and test data.
print("Actual Model")
# One no_grad scope for the whole evaluation (empirical distributions were
# previously built outside it, unlike the earlier evaluation cell).
with th.no_grad():
    for setname, inputs in (("Train", train_less), ("Test", test_less)):
        corrects = []
        for i_class in range(2):
            # log_softmax is a method of ModelAndDist; no free function of
            # that name is defined or imported above (NameError before).
            outs = model_and_dist.log_softmax(inputs[i_class].cuda())
            pred_label = np.argmax(var_to_np(outs), axis=1)
            correct = pred_label == i_class
            corrects.extend(correct)
        acc = np.mean(corrects)
        print("{:6s} Accuracy {:.1f}".format(setname, acc * 100))

    for name, inputs in (("Train", train_less),
                         ("Combined", [th.cat((train_less[k].cuda(),
                                               test_less[k].cuda()), dim=0)
                                       for k in range(2)]),
                         ("Test", test_less)):
        emp_dist = create_empirical_dist(model_and_dist.model, inputs)
        emp_model_dist = ModelAndDist(model_and_dist.model, emp_dist)
        print(name)
        for setname, inner_inputs in (("Train", train_less), ("Test", test_less)):
            corrects = []
            for i_class in range(2):
                outs = emp_model_dist.log_softmax(inner_inputs[i_class].cuda())
                pred_label = np.argmax(var_to_np(outs), axis=1)
                correct = pred_label == i_class
                corrects.append(correct)
            acc = np.mean(np.concatenate(corrects))
            print("{:6s} Accuracy {:.1f} ({:.1f}/{:.1f})".format(
                setname, acc * 100,
                np.mean(corrects[0]) * 100, np.mean(corrects[1]) * 100))
    

### First, refit the distribution to the training-set outputs

In [None]:
# Model outputs for each training class, computed without building a graph.
with th.no_grad():
    train_outs = list(map(lambda class_trials: model_and_dist.model(class_trials.cuda()),
                          train_less))

In [None]:
from reversible2.timer import Timer

# Class whose data is monitored below (NLL/OT report and plots).
i_class = 1
n_epochs = 2001
class_ins = train_less[i_class].cuda()
test_ins = test_less[i_class].cuda()
test_dist_ins = test_dist_less[i_class].cuda()
# NOTE(review): unused in this cell -- no noise is added to train_outs.
noise_factor = 1e-2
# Sampling rate of the loaded data (final_hz=256 in the loading cells); used
# only for the frequency axis of the spectrum plot.
sampling_hz = 256.0
n_time_samples = class_ins.shape[2]

for i_epoch in range(n_epochs):
    with Timer(verbose=False) as timer:
        optim_dist.zero_grad()
        # Fit only the distribution (optim_dist) to the fixed model outputs of
        # the training set; j_class keeps the monitored i_class intact.
        for j_class in range(2):
            log_probs = model_and_dist.dist.get_total_log_prob(
                j_class, train_outs[j_class])
            loss = -th.mean(log_probs)
            loss.backward()
        optim_dist.step()

    # Report roughly 20 times over the run.
    if i_epoch % (n_epochs // 20) == 0:
        with th.no_grad():
            print("Epoch {:d} of {:d}".format(i_epoch, n_epochs))
            print("Runtime {:.1E} ms".format(timer.elapsed))
            # Per set: NLL under the monitored class and mean L2 distance to
            # matched model samples. get_examples' first argument is
            # presumably the class index.
            text_strs = []
            for setname, inputs in (("Train", class_ins), ("Test", test_ins), ("Other", test_dist_ins)):
                examples = model_and_dist.get_examples(i_class, len(inputs) * 20)
                matched_examples = get_matched_samples(flatten_2d(inputs), flatten_2d(examples))
                OT = th.mean(th.norm(flatten_2d(inputs).unsqueeze(1) - matched_examples, p=2, dim=2))
                nll = -th.mean(model_and_dist.get_total_log_prob(i_class, inputs))
                text_strs.append("{:7s} NLL {:.1E}".format(setname, nll.item()))
                text_strs.append("{:7s} OT {:.1E}".format(setname, OT.item()))
            display_text("\n".join(text_strs))

            # First 10 training trials overlaid with their matched model samples.
            examples = model_and_dist.get_examples(i_class, len(class_ins) * 20)
            matched_examples = get_matched_samples(flatten_2d(class_ins), flatten_2d(examples))
            fig, axes = plt.subplots(5, 2, figsize=(16, 12), sharex=True, sharey=True)
            for ax, signal, matched in zip(axes.flatten(), class_ins, matched_examples):
                ax.plot(var_to_np(signal).squeeze().T)
                for ex in var_to_np(matched.view(len(matched), class_ins.shape[1], class_ins.shape[2])):
                    ax.plot(ex[0], color=seaborn.color_palette()[0], lw=0.5, alpha=0.7)
                    ax.plot(ex[1], color=seaborn.color_palette()[1], lw=0.5, alpha=0.7)
            display_close(fig)

            # Learned stds of the monitored class (assumes class_log_stds is
            # indexed by class first -- TODO confirm).
            fig = plt.figure()
            plt.plot(var_to_np(th.exp(model_and_dist.dist.class_log_stds)[i_class]))
            display_close(fig)

            # Mean amplitude spectrum of channel 0: real trials vs. model samples.
            examples = model_and_dist.get_examples(i_class, len(class_ins) * 20)
            fake_bps = np.abs(np.fft.rfft(var_to_np(examples[:, 0]).squeeze()))
            real_bps = np.abs(np.fft.rfft(var_to_np(class_ins[:, 0]).squeeze()))
            freqs = np.fft.rfftfreq(n_time_samples, 1 / sampling_hz)
            fig = plt.figure(figsize=(8, 3))
            plt.plot(freqs, np.mean(real_bps, axis=0))
            plt.plot(freqs, np.mean(fake_bps, axis=0))
            display_close(fig)

## Semi-supervised training: fit the two-class model to the unlabeled test inputs

In [None]:
from reversible2.timer import Timer
# Semi-supervised phase: train only on the unlabeled test trials by maximizing
# the class-marginal likelihood log sum_c p(x | c); test labels are not used.
i_class = 1
n_epochs = 2001
both_test_ins = th.cat([test_less[0].cuda(), test_less[1].cuda()], dim=0).detach()
# The evaluation below reads `class_ins`. This cell previously never assigned it
# and silently reused whatever an earlier cell left behind, which breaks
# Restart & Run All. Pin it explicitly to the training trials of class `i_class`.
class_ins = train_less[i_class].cuda()

noise_factor = 1e-2
for i_epoch in range(n_epochs):
    with Timer(verbose=False) as timer:
        optim.zero_grad()
        # Uniform noise in [-noise_factor/2, noise_factor/2) added to every input value.
        noised_ins = both_test_ins + (th.rand_like(both_test_ins) - 0.5) * noise_factor
        log_probs_per_class = [model_and_dist.get_total_log_prob(
                        j_class, noised_ins) for j_class in range(2)]

        # Combine the two per-class log probabilities in log space.
        total_probs = th.logsumexp(th.stack(log_probs_per_class, dim=-1), dim=1)
        loss = -th.mean(total_probs)
        loss.backward()
        optim.step()

    if i_epoch % (n_epochs // 20) == 0:
        # Everything below is evaluation/plotting only; one no_grad context covers it.
        with th.no_grad():
            print("Epoch {:d} of {:d}".format(i_epoch, n_epochs))
            print("Runtime {:.1E} ms".format(timer.elapsed))
            # NLL and optimal-transport distance between real inputs and
            # matched generated examples, per data split.
            text_strs = []
            for setname, inputs in (("Train", class_ins), ("Test", test_ins), ("Other", test_dist_ins)):
                examples = model_and_dist.get_examples(1,len(inputs) * 20)
                matched_examples = get_matched_samples(flatten_2d(inputs), flatten_2d(examples))
                OT = th.mean(th.norm(flatten_2d(inputs).unsqueeze(1)  - matched_examples, p=2, dim=2))#
                nll = -th.mean(model_and_dist.get_total_log_prob(i_class, inputs))
                text_strs.append("{:7s} NLL {:.1E}".format(setname, nll.item()))
                text_strs.append("{:7s} OT {:.1E}".format(setname, OT.item()))
            display_text("\n".join(text_strs))

            # Classification accuracy with the model's learned class distributions.
            print("Actual Model")
            for setname, inputs in (("Train", train_less), ("Test", test_less)):
                corrects = []
                for i_class in range(2):
                    outs = log_softmax(model_and_dist, inputs[i_class].cuda())
                    pred_label = np.argmax(var_to_np(outs), axis=1)
                    correct = pred_label == i_class
                    corrects.extend(correct)
                acc = np.mean(corrects)
                print("{:6s} Accuracy {:.1f}".format(setname, acc * 100))

            # Accuracy with class distributions re-created by create_empirical_dist
            # from train, train+test ("Combined") or test trials.
            for name, inputs in (("Train", train_less),
                                 ("Combined", [th.cat((train_less[i_class].cuda(),
                                                    test_less[i_class].cuda()), dim=0)
                            for i_class in range(2)]),
                                 ("Test", test_less)):
                emp_dist = create_empirical_dist(model_and_dist.model, inputs)

                emp_model_dist = ModelAndDist(model_and_dist.model, emp_dist)
                print(name)
                for setname, inner_inputs in (("Train", train_less), ("Test", test_less)):
                    corrects = []
                    for i_class in range(2):
                        outs = emp_model_dist.log_softmax(inner_inputs[i_class].cuda())
                        pred_label = np.argmax(var_to_np(outs), axis=1)
                        correct = pred_label == i_class
                        corrects.append(correct)
                    acc = np.mean(np.concatenate(corrects))
                    print("{:6s} Accuracy {:.1f} ({:.1f}/{:.1f})".format(setname, acc * 100,
                                                                      np.mean(corrects[0]) * 100,
                                                                      np.mean(corrects[1]) * 100))
            # Real trials overlaid with their matched generated examples.
            examples = model_and_dist.get_examples(1,len(class_ins) * 20)
            matched_examples = get_matched_samples(flatten_2d(class_ins), flatten_2d(examples))
            fig, axes = plt.subplots(5,2, figsize=(16,12), sharex=True, sharey=True)
            for ax, signal, matched in zip(axes.flatten(), class_ins, matched_examples):
                ax.plot(var_to_np(signal).squeeze().T)
                for ex in var_to_np(matched.view(len(matched), class_ins.shape[1], class_ins.shape[2])):
                    ax.plot(ex[0], color=seaborn.color_palette()[0], lw=0.5, alpha=0.7)
                    ax.plot(ex[1], color=seaborn.color_palette()[1], lw=0.5, alpha=0.7)
            display_close(fig)
            # Learned per-dimension stds of class 1.
            fig = plt.figure()
            plt.plot(var_to_np(th.exp(model_and_dist.dist.class_log_stds)[1]))
            display_close(fig)
            # Mean amplitude spectrum of channel 0: real vs. generated.
            examples = model_and_dist.get_examples(1,len(class_ins) * 20)
            fake_bps = np.abs(np.fft.rfft(var_to_np(examples[:,0]).squeeze()))
            real_bps = np.abs(np.fft.rfft(var_to_np(class_ins[:,0]).squeeze()))
            fig = plt.figure(figsize=(8,3))
            plt.plot(np.fft.rfftfreq(256, 1/256.0), np.mean(real_bps, axis=0))
            plt.plot(np.fft.rfftfreq(256, 1/256.0), np.mean(fake_bps, axis=0))
            display_close(fig)


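## Weight decay added

**TODO:** the optimizer set up in the next cell passes no `weight_decay` to Adam; confirm where the decay is applied, or add it.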

In [None]:
# Re-create the flow model, the two-class latent distribution and the optimizer,
# so the training cell below starts from a fresh initialization.
n_chans = train_less[0].shape[1]
n_time = train_less[0].shape[2]

n_chan_pad = 0  # NOTE(review): unused in this cell; kept in case later cells read it
filter_length_time = 11

# Use the named constant instead of repeating the literal kernel length.
model = larger_model(n_chans, n_time, final_fft=True,
                     kernel_length=filter_length_time, constant_memory=False)
model.cuda()
# Sized to the number of values in one trial (product of all non-batch dims).
dist = TwoClassIndependentDist(np.prod(train_less[0].size()[1:]))
dist.cuda()
model_and_dist = ModelAndDist(model, dist)
# Presumably fits the class distributions to the model's outputs on train_less — verify.
set_dist_to_empirical(model_and_dist.model, model_and_dist.dist, train_less)

# Distribution parameters use a 100x larger learning rate than the network.
# NOTE(review): the section heading mentions weight decay, but no `weight_decay`
# is passed to Adam here — confirm where decay is applied.
optim = th.optim.Adam([{'params': dist.parameters(), 'lr': 1e-2},
                       {'params': list(model_and_dist.model.parameters()),
                        'lr': 1e-4}])

In [None]:
# TODO: iterate over an increasing number of dimensions and use each subset to classify

In [None]:
from reversible2.timer import Timer
# Joint training: supervised per-class NLL on the labeled training trials plus an
# unsupervised class-marginal NLL on the (unlabeled) test trials.
i_class = 1
n_epochs = 2001
both_test_ins = th.cat([test_less[0].cuda(), test_less[1].cuda()], dim=0).detach()

noise_factor = 1e-2
for i_epoch in range(n_epochs):
    if i_epoch > 0: # skip to see starting values
        with Timer(verbose=False) as timer:
            optim.zero_grad()

            # Supervised term: gradients of both classes accumulate before the
            # single optimizer step below.
            for i_class in range(2):
                class_ins = train_less[i_class].cuda()
                log_probs = model_and_dist.get_total_log_prob(
                    i_class, class_ins + (th.rand_like(class_ins) - 0.5) * noise_factor)
                loss = -th.mean(log_probs)
                loss.backward()

            # Unsupervised term: test labels are not used; per-class log
            # probabilities are combined with logsumexp over the class axis.
            noised_ins = both_test_ins + (th.rand_like(both_test_ins) - 0.5) * noise_factor
            noised_outs = model_and_dist.model(noised_ins)
            log_probs_per_class = [model_and_dist.dist.get_total_log_prob(
                            j_class,noised_outs) for j_class in range(2)]

            total_probs = th.logsumexp(th.stack(log_probs_per_class, dim=-1), dim=1)
            loss = -th.mean(total_probs)
            loss.backward()
            optim.step()

    if i_epoch % (n_epochs // 20) == 0:
        # At epoch 0 no training step ran, so `timer` and `class_ins` would be
        # undefined on a fresh kernel (or stale values from another cell).
        # Pin the evaluated class and its training trials explicitly.
        i_class = 1
        class_ins = train_less[i_class].cuda()
        with th.no_grad():
            print("Epoch {:d} of {:d}".format(i_epoch, n_epochs))
            if i_epoch > 0:
                print("Runtime {:.0f} ms".format(timer.elapsed))
            # NLL and optimal-transport distance between real inputs and
            # matched generated examples, per data split.
            text_strs = []
            for setname, inputs in (("Train", class_ins), ("Test", test_ins), ("Other", test_dist_ins)):
                examples = model_and_dist.get_examples(1,len(inputs) * 20)
                matched_examples = get_matched_samples(flatten_2d(inputs), flatten_2d(examples))
                OT = th.mean(th.norm(flatten_2d(inputs).unsqueeze(1)  - matched_examples, p=2, dim=2))#
                nll = -th.mean(model_and_dist.get_total_log_prob(i_class, inputs))
                text_strs.append("{:7s} NLL {:.1E}".format(setname, nll.item()))
                text_strs.append("{:7s} OT {:.1E}".format(setname, OT.item()))
            display_text("\n".join(text_strs))

        # Classification accuracy with the model's learned class distributions.
        print("Actual Model")
        with th.no_grad():
            for setname, inputs in (("Train", train_less), ("Test", test_less)):
                corrects = []
                for i_class in range(2):
                    outs = log_softmax(model_and_dist, inputs[i_class].cuda())
                    pred_label = np.argmax(var_to_np(outs), axis=1)
                    correct = pred_label == i_class
                    corrects.extend(correct)
                acc = np.mean(corrects)
                print("{:6s} Accuracy {:.1f}".format(setname, acc * 100))

            # Accuracy with class distributions re-created by create_empirical_dist
            # from train, train+test ("Combined") or test trials.
            for name, inputs in (("Train", train_less),
                                 ("Combined", [th.cat((train_less[i_class].cuda(),
                                                    test_less[i_class].cuda()), dim=0)
                            for i_class in range(2)]),
                                 ("Test", test_less)):
                emp_dist = create_empirical_dist(model_and_dist.model, inputs)

                emp_model_dist = ModelAndDist(model_and_dist.model, emp_dist)
                print(name)
                for setname, inner_inputs in (("Train", train_less), ("Test", test_less)):
                    corrects = []
                    for i_class in range(2):
                        outs = emp_model_dist.log_softmax(inner_inputs[i_class].cuda())
                        pred_label = np.argmax(var_to_np(outs), axis=1)
                        correct = pred_label == i_class
                        corrects.append(correct)
                    acc = np.mean(np.concatenate(corrects))
                    print("{:6s} Accuracy {:.1f} ({:.1f}/{:.1f})".format(setname, acc * 100,
                                                                      np.mean(corrects[0]) * 100,
                                                                      np.mean(corrects[1]) * 100))
            # Real trials overlaid with their matched generated examples.
            examples = model_and_dist.get_examples(1,len(class_ins) * 20)
            matched_examples = get_matched_samples(flatten_2d(class_ins), flatten_2d(examples))
            fig, axes = plt.subplots(5,2, figsize=(16,12), sharex=True, sharey=True)
            for ax, signal, matched in zip(axes.flatten(), class_ins, matched_examples):
                ax.plot(var_to_np(signal).squeeze().T)
                for ex in var_to_np(matched.view(len(matched), class_ins.shape[1], class_ins.shape[2])):
                    ax.plot(ex[0], color=seaborn.color_palette()[0], lw=0.5, alpha=0.7)
                    ax.plot(ex[1], color=seaborn.color_palette()[1], lw=0.5, alpha=0.7)
            display_close(fig)
            # Learned per-dimension stds of class 1.
            fig = plt.figure()
            plt.plot(var_to_np(th.exp(model_and_dist.dist.class_log_stds)[1]))
            display_close(fig)
            # Mean amplitude spectrum of channel 0: real vs. generated.
            examples = model_and_dist.get_examples(1,len(class_ins) * 20)
            fake_bps = np.abs(np.fft.rfft(var_to_np(examples[:,0]).squeeze()))
            real_bps = np.abs(np.fft.rfft(var_to_np(class_ins[:,0]).squeeze()))
            fig = plt.figure(figsize=(8,3))
            plt.plot(np.fft.rfftfreq(256, 1/256.0), np.mean(real_bps, axis=0))
            plt.plot(np.fft.rfftfreq(256, 1/256.0), np.mean(fake_bps, axis=0))
            display_close(fig)


In [None]:
# Classification accuracy of the trained flow: first with its own learned class
# distributions, then with class distributions re-created by
# create_empirical_dist from train, train+test ("Combined") or test trials.
with th.no_grad():
    print("Actual Model")
    for setname, inputs in (("Train", train_less), ("Test", test_less)):
        corrects = []
        for i_class in range(2):
            outs = log_softmax(model_and_dist, inputs[i_class].cuda())
            pred_label = var_to_np(outs).argmax(axis=1)
            corrects.extend(pred_label == i_class)
        acc = np.mean(corrects)
        print("{:6s} Accuracy {:.1f}".format(setname, acc * 100))

    # Per-class concatenation of training and test trials.
    combined_less = [th.cat((train_less[c].cuda(), test_less[c].cuda()), dim=0)
                     for c in range(2)]
    for name, inputs in (("Train", train_less),
                         ("Combined", combined_less),
                         ("Test", test_less)):
        emp_model_dist = ModelAndDist(
            model_and_dist.model,
            create_empirical_dist(model_and_dist.model, inputs))
        print(name)
        # Already inside no_grad, so no nested context is needed here.
        for setname, inner_inputs in (("Train", train_less), ("Test", test_less)):
            corrects = [
                var_to_np(emp_model_dist.log_softmax(inner_inputs[c].cuda())).argmax(axis=1) == c
                for c in range(2)
            ]
            acc = np.mean(np.concatenate(corrects))
            per_class_pct = [np.mean(hits) * 100 for hits in corrects]
            print("{:6s} Accuracy {:.1f} ({:.1f}/{:.1f})".format(
                setname, acc * 100, per_class_pct[0], per_class_pct[1]))

In [None]:
# Standard deviation across trials of each model output, per class
# (color = class, solid = train, dashed = test).
fig, ax = plt.subplots(figsize=(8, 3))
with th.no_grad():  # plotting only; no autograd graph needed for the forward pass
    for setname, inputs in (("Train", train_less), ("Test", test_less)):
        for i_class in range(2):
            outs = model_and_dist.model(inputs[i_class].cuda())
            std = th.std(outs, dim=0)
            ax.plot(
                var_to_np(std).squeeze(),
                color=seaborn.color_palette()[i_class],
                ls={'Train': '-', 'Test': '--'}[setname],
                label="{:s} class {:d}".format(setname, i_class))
ax.set_xlabel("Output dimension")
ax.set_ylabel("Std across trials")
ax.legend();

In [None]:
# Log of the standard deviation across trials of each model output, per class
# (color = class, solid = train, dashed = test).
fig, ax = plt.subplots(figsize=(8, 3))
with th.no_grad():  # plotting only; no autograd graph needed for the forward pass
    for setname, inputs in (("Train", train_less), ("Test", test_less)):
        for i_class in range(2):
            outs = model_and_dist.model(inputs[i_class].cuda())
            std = th.log(th.std(outs, dim=0))
            ax.plot(
                var_to_np(std).squeeze(),
                color=seaborn.color_palette()[i_class],
                ls={'Train': '-', 'Test': '--'}[setname],
                label="{:s} class {:d}".format(setname, i_class))
ax.set_xlabel("Output dimension")
ax.set_ylabel("log(std across trials)")
ax.legend();