In [None]:
import importlib
import logging
import sys

# In Jupyter the root logger may already carry a handler, which would turn
# basicConfig() into a no-op; reloading the logging module starts from a
# fresh root logger. See https://stackoverflow.com/a/21475297/1469195
importlib.reload(logging)

log = logging.getLogger()
log.setLevel('INFO')
logging.basicConfig(
    format='%(asctime)s %(levelname)s : %(message)s',
    level=logging.INFO,
    stream=sys.stdout,
)

In [None]:
%%capture
# Imports and notebook configuration; %%capture suppresses this cell's output.
import os
import site
# NOTE(review): hardcoded absolute, user-specific paths -- the notebook only
# runs on a machine with exactly this layout. Consider making them configurable.
os.sys.path.insert(0, '/home/schirrmr/code/reversible/')
os.sys.path.insert(0, '/home/schirrmr/braindecode/code/braindecode/')
os.sys.path.insert(0, '/home/schirrmr/code/explaining/reversible//')


%load_ext autoreload
%autoreload 2
import numpy as np
import logging
log = logging.getLogger()
log.setLevel('INFO')
import sys
# basicConfig() does nothing if the root logger already has handlers (e.g. if
# the logging cell above has run), so this only takes effect on its own.
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                     level=logging.INFO, stream=sys.stdout)
import matplotlib
from matplotlib import pyplot as plt
from matplotlib import cm
%matplotlib inline
%config InlineBackend.figure_format = 'png'
# Default figures are wide and very short (12 x 1 inches).
matplotlib.rcParams['figure.figsize'] = (12.0, 1.0)
matplotlib.rcParams['font.size'] = 14
import seaborn
seaborn.set_style('darkgrid')

from reversible2.sliced import sliced_from_samples
from numpy.random import RandomState

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
# NOTE(review): duplicate of the numpy import earlier in this cell.
import numpy as np
import copy
import math

import itertools
import torch as th
from braindecode.torch_ext.util import np_to_var, var_to_np
from reversible2.splitter import SubsampleSplitter

from reversible2.view_as import ViewAs
from reversible2.invert import invert

from reversible2.affine import AdditiveBlock
from reversible2.plot import display_text, display_close
from reversible2.bhno import load_file, create_inputs

In [None]:
from reversible2.graph import Node

In [None]:
# Two stacked subsample splitters: the first with chunk_chans_first=False,
# the second with chunk_chans_first=True.
o1 = Node(None, SubsampleSplitter(stride=[2, 1], chunk_chans_first=False))
o = Node(o1, SubsampleSplitter(stride=[2, 1], chunk_chans_first=True))

# Test input: the values 0..7 laid out as a (1, 1, 8, 1) tensor.
ins = th.arange(8.0).view(1, 1, 8, 1)

In [None]:
%%time
# Push the test input through the graph and invert the result; the round trip
# is compared against `ins` afterwards.
outs = o.forward(ins)
inverted = o.invert(outs)

In [None]:
# th.allclose broadcasts its arguments, so also require identical shapes to be
# sure the inversion really reconstructs the input.
assert inverted.shape == ins.shape, (inverted.shape, ins.shape)
assert th.allclose(inverted, ins)

In [None]:
import numpy as np
class SplitEveryNth(nn.Module):
    """Split a tensor along dim 1 into ``n_parts`` interleaved parts.

    Part ``i`` holds channels ``i, i + n_parts, i + 2 * n_parts, ...``;
    ``invert`` interleaves the parts back into a single tensor.
    """
    def __init__(self, n_parts):
        super(SplitEveryNth, self).__init__()
        self.n_parts = n_parts

    def forward(self, x):
        """Return a tuple of ``n_parts`` strided views ``x[:, i::n_parts]``."""
        xs = tuple([x[:, i::self.n_parts] for i in range(self.n_parts)])
        return xs

    def invert(self, y):
        """Re-interleave the parts produced by ``forward`` along dim 1.

        The output channel count is the sum of the parts' channel counts, so
        this also works when the original channel count was not divisible by
        ``n_parts`` (the leading parts then have one channel more than the
        trailing ones). The result has the dtype and device of ``y[0]``.
        """
        # Summing the parts (instead of y[0].shape[1] * n_parts) avoids
        # allocating extra channels that a smaller trailing part would
        # silently be broadcast into.
        n_chans = sum(part.shape[1] for part in y)
        new_y = y[0].new_zeros((y[0].shape[0], n_chans) + tuple(y[0].shape[2:]))
        for i in range(self.n_parts):
            new_y[:, i::self.n_parts] = y[i]
        return new_y
    
class Select(nn.Module):
    """Pick one element, by ``index``, out of an indexable input (e.g. a tuple).

    ``invert`` hands its argument back unchanged.
    """
    def __init__(self, index):
        super(Select, self).__init__()
        self.index = index

    def forward(self, x):
        """Return ``x[self.index]``."""
        chosen = x[self.index]
        return chosen

    def invert(self, y):
        """Identity: return ``y`` as-is."""
        return y
    
class Identity(nn.Module):
    """Return all positional inputs unchanged, packed as a tuple."""
    def forward(self, *inputs):
        return inputs

    
class CatChans(nn.Module):
    """Concatenate several tensors along dim 1 (channels), invertibly.

    The per-input channel counts are recorded on the first ``forward`` call;
    later calls must use the same counts, and ``invert`` splits along dim 1
    according to them.
    """
    def __init__(self):
        # nn.Module.__init__ must run before any attribute is assigned on the
        # module (assigning e.g. a Parameter/Module earlier raises).
        super(CatChans, self).__init__()
        self.n_chans = None

    def forward(self, *x):
        """Concatenate the inputs along dim 1 and remember their channel counts."""
        n_chans = tuple([a_x.size()[1] for a_x in x])
        if self.n_chans is None:
            self.n_chans = n_chans
        else:
            assert n_chans == self.n_chans
        return th.cat(x, dim=1)

    def invert(self, y):
        """Split ``y`` along dim 1 back into the recorded channel chunks.

        Returns a list with one tensor per original ``forward`` input.
        Raises AssertionError if ``forward`` was never called, and a
        RuntimeError (from ``torch.split``) if ``y``'s channel count does not
        equal the total recorded channel count.
        """
        assert self.n_chans is not None, "make forward first"
        # Explicit sections make torch.split validate the total channel count
        # instead of silently truncating or mis-slicing a mismatched ``y``.
        return list(th.split(y, list(self.n_chans), dim=1))

In [None]:
from reversible2.graph import Node
# Graph: two stacked subsample splitters, then split the channels into two
# interleaved parts (SplitEveryNth), route each part through its own Select
# node and concatenate the two branches again (CatChans).
o1 = Node(None, SubsampleSplitter(stride=[2, 1], chunk_chans_first=False))
o = Node(o1, SubsampleSplitter(stride=[2, 1], chunk_chans_first=True))
o = Node(o, SplitEveryNth(2))
o1, o2 = Node(o, Select(0)), Node(o, Select(1))
o = Node([o1, o2], CatChans())

# Test input: the values 0..7 laid out as a (1, 1, 8, 1) tensor.
ins = th.arange(8.0).view(1, 1, 8, 1)




In [None]:
%%time
outs = o.forward(ins)
inverted = o.invert(outs)

assert th.allclose(inverted, ins)

In [None]:
class ChunkChans(nn.Module):
    """Split a tensor along dim 1 into ``n_parts`` contiguous chunks.

    Chunks are produced by ``torch.chunk``; ``invert`` concatenates them back.
    """
    def __init__(self, n_parts):
        super(ChunkChans, self).__init__()
        self.n_parts = n_parts

    def forward(self, x):
        """Return a tuple of exactly ``n_parts`` chunks of ``x`` along dim 1.

        Raises:
            ValueError: if ``torch.chunk`` yields fewer than ``n_parts`` chunks.
                This happens e.g. for 5 channels and ``n_parts=4``: the chunk
                size is 2, so only 3 chunks are produced.
        """
        xs = th.chunk(x, chunks=self.n_parts, dim=1)
        if len(xs) != self.n_parts:
            raise ValueError(
                "Cannot chunk {} channels into {} parts (torch.chunk returned {})".format(
                    x.shape[1], self.n_parts, len(xs)))
        return xs

    def invert(self, y):
        """Concatenate the chunks in ``y`` back together along dim 1."""
        return th.cat(y, dim=1)
    

In [None]:
from reversible2.graph import Node
# Graph: two stacked subsample splitters, then split the channels into two
# contiguous chunks (ChunkChans), route each chunk through its own Select
# node and concatenate the two branches again (CatChans).
o1 = Node(None, SubsampleSplitter(stride=[2, 1], chunk_chans_first=False))
o = Node(o1, SubsampleSplitter(stride=[2, 1], chunk_chans_first=True))
o = Node(o, ChunkChans(2))
o1, o2 = Node(o, Select(0)), Node(o, Select(1))
o = Node([o1, o2], CatChans())

# Test input: the values 0..7 laid out as a (1, 1, 8, 1) tensor.
ins = th.arange(8.0).view(1, 1, 8, 1)

In [None]:
%%time
outs = o.forward(ins)
inverted = o.invert(outs)

assert th.allclose(inverted, ins)