In [None]:
%%capture
# Environment setup: put local source checkouts on the import path and change
# the working directory. The %%capture cell magic hides this cell's output.
# NOTE(review): all paths below are absolute and specific to one machine.
import os
# NOTE(review): `site` is not referenced in any of the visible cells.
import site
# Inserting at index 0 gives these directories priority over installed packages.
os.sys.path.insert(0, '/home/schirrmr/braindecode/code/braindecode/')
os.sys.path.insert(0, '/home/schirrmr/code/explaining/reversible//')
%cd /home/schirrmr/


In [None]:
# Notebook-wide configuration: module autoreload, logging and plotting defaults.
%load_ext autoreload
%autoreload 2
import numpy as np
import logging
# Root logger at INFO level, written to stdout with timestamp and level name.
log = logging.getLogger()
log.setLevel('INFO')
import sys
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                     level=logging.INFO, stream=sys.stdout)
import matplotlib
from matplotlib import pyplot as plt
from matplotlib import cm
# Render figures inline as PNG images.
%matplotlib inline
%config InlineBackend.figure_format = 'png' 
# Default figure size and font size; the plotting cells below pass their own figsize.
matplotlib.rcParams['figure.figsize'] = (12.0, 1.0)
matplotlib.rcParams['font.size'] = 14
import seaborn
seaborn.set_style('darkgrid')

In [None]:

from torch.nn.parameter import Parameter
from torch.nn.utils.weight_norm import _norm
class WeightNorm(object):
    """Forward pre-hook that reparameterizes a weight as ``g * v / ||v||``.

    The parameter ``<name>`` of a module is replaced by a direction parameter
    ``<name>_v`` and, unless ``fixed_norm`` is given, a magnitude parameter
    ``<name>_g``. The effective weight is recomputed from them before every
    forward call of the module.

    Args:
        name (str): name of the parameter to reparameterize (e.g. ``'weight'``).
        dim (int or None): passed to ``_norm`` to select how the norm of ``v``
            is computed.
        fixed_norm: if not ``None``, this value is used as the magnitude in
            place of a learnable ``<name>_g`` parameter.
    """

    def __init__(self, name, dim, fixed_norm):
        self.name = name
        self.dim = dim
        self.fixed_norm = fixed_norm
        # Handle of the registered forward pre-hook; filled in by apply() so
        # that remove() can unregister the hook again.
        self._hook_handle = None

    def compute_weight(self, module):
        """Return the effective weight ``v * (g / ||v||)`` for ``module``."""
        v = getattr(module, self.name + '_v')
        if self.fixed_norm is None:
            g = getattr(module, self.name + '_g')
        else:
            g = self.fixed_norm
        return v * (g / _norm(v, self.dim))

    @staticmethod
    def apply(module, name, dim, fixed_norm):
        """Install the reparameterization on ``module`` and return the hook.

        Args:
            module (nn.Module): module owning the parameter ``name``.
            name (str): name of the parameter to replace.
            dim (int or None): norm dimension forwarded to ``_norm``.
            fixed_norm: constant magnitude, or ``None`` for a learnable one.

        Returns:
            WeightNorm: the hook object registered on ``module``.
        """
        fn = WeightNorm(name, dim, fixed_norm)

        weight = getattr(module, name)

        # remove w from parameter list
        del module._parameters[name]

        # add g and v as new parameters and express w as g/||v|| * v
        if fixed_norm is None:
            module.register_parameter(name + '_g', Parameter(_norm(weight, dim).data))
        module.register_parameter(name + '_v', Parameter(weight.data))
        setattr(module, name, fn.compute_weight(module))

        # recompute weight before every forward(); keep the handle so that
        # remove() can detach the hook later
        fn._hook_handle = module.register_forward_pre_hook(fn)

        return fn

    def remove(self, module):
        """Undo the reparameterization and restore ``<name>`` as a parameter.

        The current effective weight becomes the value of the restored
        parameter. The forward pre-hook installed by :meth:`apply` is
        unregistered as well; otherwise the next forward call would try to
        read the deleted ``<name>_v`` parameter and fail.
        """
        weight = self.compute_weight(module)

        # Detach the hook first so a later forward() no longer calls us.
        handle = getattr(self, '_hook_handle', None)
        if handle is not None:
            handle.remove()
            self._hook_handle = None

        delattr(module, self.name)
        if self.fixed_norm is None:
            del module._parameters[self.name + '_g']
        del module._parameters[self.name + '_v']
        module.register_parameter(self.name, Parameter(weight.data))

    def __call__(self, module, inputs):
        """Forward pre-hook entry point: refresh the effective weight."""
        setattr(module, self.name, self.compute_weight(module))


def weight_norm(module, name='weight', dim=0, fixed_norm=None):
    r"""Reparameterize one parameter of ``module`` with weight normalization.

    The parameter called ``name`` is rewritten as

    .. math::
         \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}

    so that its magnitude and its direction are stored separately: the
    direction lives in ``<name>_v`` (e.g. ``weight_v``) and the magnitude in
    ``<name>_g`` (e.g. ``weight_g``). A hook on the module rebuilds the weight
    tensor from these before every :meth:`~Module.forward` call.

    With the default ``dim=0`` one norm is computed per output channel/plane;
    pass ``dim=None`` to use a single norm over the whole weight tensor.

    When ``fixed_norm`` is given, no ``<name>_g`` parameter is created and the
    given value is used as the magnitude instead.

    Reference: https://arxiv.org/abs/1602.07868

    Args:
        module (nn.Module): module that owns the parameter
        name (str, optional): name of the parameter to reparameterize
        dim (int, optional): dimension over which to compute the norm
        fixed_norm (optional): constant magnitude replacing the learnable
            ``<name>_g``; ``None`` (default) keeps the magnitude learnable

    Returns:
        The same ``module`` object, now carrying the weight norm hook

    Example::

        >>> m = weight_norm(nn.Linear(20, 40), name='weight')
        >>> m.weight_g.size()
        torch.Size([40, 1])
        >>> m.weight_v.size()
        torch.Size([40, 20])

    """
    # The hook object returned by apply() is not needed by callers here.
    WeightNorm.apply(module, name, dim, fixed_norm)
    return module

In [None]:
# 1D example: fake inputs spread uniformly over [-2, 2], real inputs are the two points -1 and 1

In [None]:
from numpy.random import RandomState
from braindecode.torch_ext.util import np_to_var, var_to_np
from IPython.display import display

In [None]:
# Inputs fed to the model as "real" samples in the training loops: -1 and 1.
x = np_to_var(
    [-1, 1],
    dtype=np.float32,
    requires_grad=True,
)
# Inputs fed as "fake" samples: 400 evenly spaced values from -2 to 2.
x_fake = np_to_var(
    np.linspace(-2, 2, num=400),
    dtype=np.float32,
    requires_grad=True,
)


In [None]:
from torch.nn import ConstantPad2d
from reversible.revnet import ReversibleBlockOld, SubsampleSplitter, ViewAs
import torch as th
from torch import nn
class ConcatReLU(nn.Module):
    """Concatenate ``relu(x)`` and ``relu(-x)`` along dimension 1.

    The output has twice as many entries along dimension 1 as the input: first
    the rectified positive part, then the rectified negative part.
    """

    def __init__(self):
        super(ConcatReLU, self).__init__()

    def forward(self, x):
        positive_part = nn.functional.relu(x)
        negative_part = nn.functional.relu(-x)
        return th.cat((positive_part, negative_part), dim=1)

class Clip(nn.Module):
    """Clamp every element of the input to the range [0, 1.2]."""

    def __init__(self):
        super(Clip, self).__init__()

    def forward(self, x):
        # Values below 0 become 0, values above 1.2 become 1.2.
        return x.clamp(min=0, max=1.2)
    

In [None]:
from braindecode.torch_ext.util import set_random_seeds
set_random_seeds(0, False) # worked with 102398213

# Small network: Linear(1->2) -> ConcatReLU (4 features) -> Clip -> Linear(4->1).
model = nn.Sequential(
    nn.Linear(1, 2),
    ConcatReLU(),
    Clip(),
    nn.Linear(4, 1),
)

# Overwrite the random initialization with hand-picked values.
# First layer: both weights are 1, biases are -1 and 1.
model[0].weight.data.fill_(1)
model[0].bias.data.copy_(th.tensor([-1.0, 1.0]))
# Output layer: every weight is 0.25, bias is 0.
model[3].weight.data.fill_(0.25)
model[3].bias.data.zero_()

In [None]:
# Show the hand-initialized parameters of the model as the cell output.
list(model.parameters())

In [None]:
# Model output over the whole fake input grid, before any training.
fake_out = model(x_fake.unsqueeze(1))
fig = plt.figure(figsize=(12, 3))
plt.plot(
    var_to_np(x_fake),
    var_to_np(fake_out).squeeze(),
)
plt.xlabel('Input')
plt.ylabel('Discriminator Score')
display(fig)
plt.close(fig)

# Activations after the first three layers (Linear, ConcatReLU, Clip).
part_model = nn.Sequential(
    model[0],
    model[1],
    model[2],
)
out_part = part_model(x_fake.unsqueeze(1))

fig = plt.figure(figsize=(12, 3))
# Columns 0 and 2 are the relu(x) / relu(-x) halves of the first hidden unit,
# columns 1 and 3 those of the second one; each pair is summed into one curve.
plt.plot(
    var_to_np(x_fake),
    var_to_np(out_part[:, 0] + out_part[:, 2]).squeeze(),
)
plt.plot(
    var_to_np(x_fake),
    var_to_np(out_part[:, 1] + out_part[:, 3]).squeeze(),
)
display(fig)
plt.close(fig)

In [None]:

# Adam over all registered parameters of the model, learning rate 5e-3.
optimizer = th.optim.Adam(model.parameters(), lr=5e-3)

In [None]:


n_epochs = 2000
for i_epoch in range(n_epochs):
    # Scores for the two real points and for the fake input grid.
    real_out = model(x.unsqueeze(1))
    fake_out = model(x_fake.unsqueeze(1))

    # Minimizing this raises the mean real score relative to the mean fake score.
    loss = -(real_out.mean() - fake_out.mean())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Report only on every (n_epochs // 20)-th epoch.
    if i_epoch % (n_epochs // 20) != 0:
        continue

    # Score curve; fake_out was computed before this epoch's update.
    fig = plt.figure(figsize=(12, 3))
    plt.plot(
        var_to_np(x_fake),
        var_to_np(fake_out).squeeze(),
    )
    plt.xlabel('Input')
    plt.ylabel('Discriminator Score')
    display(fig)
    plt.close(fig)

    print(loss.item())
    print(model[0].weight)
    print(model[0].bias)

    # Clipped hidden activations, one summed curve per hidden unit.
    part_model = nn.Sequential(
        model[0],
        model[1],
        model[2],
    )
    out_part = part_model(x_fake.unsqueeze(1))
    fig = plt.figure(figsize=(12, 3))
    plt.plot(
        var_to_np(x_fake),
        var_to_np(out_part[:, 0] + out_part[:, 2]).squeeze(),
    )
    plt.plot(
        var_to_np(x_fake),
        var_to_np(out_part[:, 1] + out_part[:, 3]).squeeze(),
    )
    display(fig)
    plt.close(fig)

#### Now with weight-norm regularization on validation points at -0.5 and 0.5

In [None]:
class PointWiseMultLayer(nn.Module):
    """Scale each of four input features by its own learnable factor.

    ``weights`` starts as four ones and tracks gradients. It is stored as a
    plain tensor attribute rather than an ``nn.Parameter``, so it is not
    returned by ``parameters()`` and has to be handed to an optimizer
    explicitly.
    """

    def __init__(self):
        super().__init__()
        self.weights = th.ones(4, requires_grad=True)

    def forward(self, x):
        # Add a leading axis so the four factors broadcast over the rows of x.
        scale_row = self.weights.unsqueeze(0)
        return x * scale_row

# Two extra input points, used only for the update of the scaling weights.
valid_x = np_to_var(
    [-0.5, 0.5],
    dtype=np.float32,
    requires_grad=True,
)

from braindecode.torch_ext.util import set_random_seeds
set_random_seeds(0, False) # worked with 102398213

# Same network as before, with a per-feature scaling layer inserted after Clip.
model = nn.Sequential(
    nn.Linear(1, 2),
    ConcatReLU(),
    Clip(),
    PointWiseMultLayer(),
    nn.Linear(4, 1),
)

# Hand-picked initial values: first layer weights 1 with biases -1 and 1,
# output layer weights 0.25 with bias 0.
model[0].weight.data.fill_(1)
model[0].bias.data.copy_(th.tensor([-1.0, 1.0]))
model[4].weight.data.fill_(0.25)
model[4].bias.data.zero_()

# Model output over the fake input grid before training.
fake_out = model(x_fake.unsqueeze(1))
fig = plt.figure(figsize=(12, 3))
plt.plot(
    var_to_np(x_fake),
    var_to_np(fake_out).squeeze(),
)
display(fig)
plt.close(fig)

# Clipped hidden activations (layers 0-2), one summed curve per hidden unit.
part_model = nn.Sequential(
    model[0],
    model[1],
    model[2],
)
out_part = part_model(x_fake.unsqueeze(1))
fig = plt.figure(figsize=(12, 3))
plt.plot(
    var_to_np(x_fake),
    var_to_np(out_part[:, 0] + out_part[:, 2]).squeeze(),
)
plt.plot(
    var_to_np(x_fake),
    var_to_np(out_part[:, 1] + out_part[:, 3]).squeeze(),
)
display(fig)
plt.close(fig)

# Separate optimizers: one for the scaling weights only, one for model.parameters().
optim_wnorm = th.optim.Adam([model[3].weights], lr=5e-3)
optimizer = th.optim.Adam(model.parameters(), lr=5e-3)

In [None]:
x_fake = np_to_var(
    np.linspace(-2, 2, num=400),
    dtype=np.float32,
    requires_grad=True,
)

n_epochs = 2000
for i_epoch in range(n_epochs):
    # Step 1: update model.parameters() on the real-vs-fake objective.
    real_out = model(x.unsqueeze(1))
    fake_out = model(x_fake.unsqueeze(1))

    loss = -(real_out.mean() - fake_out.mean())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Step 2: update only the scaling weights to raise the mean output on valid_x.
    valid_out = model(valid_x.unsqueeze(1))
    loss = -valid_out.mean()
    optim_wnorm.zero_grad()
    loss.backward()
    optim_wnorm.step()

    # Report only on every (n_epochs // 20)-th epoch.
    if i_epoch % (n_epochs // 20) != 0:
        continue

    fig = plt.figure(figsize=(12, 3))
    plt.plot(
        var_to_np(x_fake),
        var_to_np(fake_out).squeeze(),
    )
    plt.xlabel('Input')
    plt.ylabel('Discriminator Score')
    display(fig)
    plt.close(fig)

    # `loss` now holds the step-2 value computed on valid_x.
    print("loss {:.3f}".format(loss.item()))
    print("wnorms", [model[3].weights])
    print(model[0].weight)
    print(model[0].bias)

    # Clipped hidden activations, one summed curve per hidden unit.
    part_model = nn.Sequential(
        model[0],
        model[1],
        model[2],
    )
    out_part = part_model(x_fake.unsqueeze(1))
    fig = plt.figure(figsize=(12, 3))
    plt.plot(
        var_to_np(x_fake),
        var_to_np(out_part[:, 0] + out_part[:, 2]).squeeze(),
    )
    plt.plot(
        var_to_np(x_fake),
        var_to_np(out_part[:, 1] + out_part[:, 3]).squeeze(),
    )
    display(fig)
    plt.close(fig)

### optimize generated points

In [None]:
# Two extra input points, used only for the update of the scaling weights.
valid_x = np_to_var(
    [-0.5, 0.5],
    dtype=np.float32,
    requires_grad=True,
)

from braindecode.torch_ext.util import set_random_seeds
set_random_seeds(0, False) # worked with 102398213

# Network with the per-feature scaling layer between Clip and the output layer.
model = nn.Sequential(
    nn.Linear(1, 2),
    ConcatReLU(),
    Clip(),
    PointWiseMultLayer(),
    nn.Linear(4, 1),
)

# Hand-picked initial values: first layer weights 1 with biases -1 and 1,
# output layer weights 0.25 with bias 0.
model[0].weight.data.fill_(1)
model[0].bias.data.copy_(th.tensor([-1.0, 1.0]))
model[4].weight.data.fill_(0.25)
model[4].bias.data.zero_()

# Generated points: start on the same 400-point grid, then optimized directly.
x_gen = np_to_var(
    np.linspace(-2, 2, num=400),
    dtype=np.float32,
    requires_grad=True,
)

# Model output over the fixed fake grid before training.
fake_out = model(x_fake.unsqueeze(1))
fig = plt.figure(figsize=(12, 3))
plt.plot(
    var_to_np(x_fake),
    var_to_np(fake_out).squeeze(),
)
display(fig)
plt.close(fig)

# Clipped hidden activations (layers 0-2), one summed curve per hidden unit.
part_model = nn.Sequential(
    model[0],
    model[1],
    model[2],
)
out_part = part_model(x_fake.unsqueeze(1))
fig = plt.figure(figsize=(12, 3))
plt.plot(
    var_to_np(x_fake),
    var_to_np(out_part[:, 0] + out_part[:, 2]).squeeze(),
)
plt.plot(
    var_to_np(x_fake),
    var_to_np(out_part[:, 1] + out_part[:, 3]).squeeze(),
)
display(fig)
plt.close(fig)

# Three optimizers: scaling weights, model.parameters(), and the generated points.
optim_wnorm = th.optim.Adam([model[3].weights], lr=5e-3)
optimizer = th.optim.Adam(model.parameters(), lr=5e-3)
optim_gen = th.optim.Adam([x_gen], lr=5e-3)

In [None]:
# Fixed evaluation grid, used only for the plots inside the loop.
x_fake = np_to_var(np.linspace(-2,2, 400), dtype=np.float32, requires_grad=True)

n_epochs = 2000
for i_epoch in range(n_epochs):
    # Scores for the real points and for the current generated points.
    real_out = model(x.unsqueeze(1))
    fake_out = model(x_gen.unsqueeze(1))

    loss = -(th.mean(real_out) - th.mean(fake_out))
    optimizer.zero_grad()
    # x_gen is not part of model.parameters(), so optimizer.zero_grad() does
    # not clear its gradient. Without this call every backward() would add to
    # the (already negated) gradient of the previous epoch, and the generator
    # step would follow an alternating-sign sum of stale gradients.
    optim_gen.zero_grad()
    loss.backward()

    # Model step: minimize the loss.
    optimizer.step()
    # Generator step: flip the sign so that stepping x_gen maximizes the same
    # loss, i.e. moves the generated points towards higher scores.
    x_gen.grad.data.neg_()
    optim_gen.step()

    # Update only the scaling weights to raise the mean output on valid_x.
    valid_out = model(valid_x.unsqueeze(1))
    loss = -th.mean(valid_out)
    optim_wnorm.zero_grad()
    loss.backward()
    optim_wnorm.step()
    if i_epoch % (n_epochs // 20) == 0:
        # Score curve on the fixed grid with the freshly updated model.
        fake_out = model(x_fake.unsqueeze(1))
        fig = plt.figure(figsize=(12,3))
        plt.ylabel('Discriminator Score')
        plt.xlabel('Input')
        plt.plot(var_to_np(x_fake), var_to_np(fake_out).squeeze())

        display(fig)
        plt.close(fig)

        # Current positions of the generated points, drawn on the line y = 0.
        fig = plt.figure(figsize=(12,3))
        plt.plot(var_to_np(x_gen).squeeze(), var_to_np(x_gen).squeeze() * 0, ls='', marker='o', alpha=0.5)
        plt.xlabel('Input (generated)')

        display(fig)
        plt.close(fig)
        # `loss` now holds the value computed on valid_x.
        print("loss {:.3f}".format(loss.item()))
        print("wnorms", [model[3].weights])
        print(model[0].weight)
        print(model[0].bias)
        # Clipped hidden activations, one summed curve per hidden unit.
        part_model = nn.Sequential(model[0], model[1], model[2])
        out_part = part_model(x_fake.unsqueeze(1))
        fig = plt.figure(figsize=(12,3))

        plt.plot(var_to_np(x_fake), var_to_np(out_part[:,0] + out_part[:,2]).squeeze())

        plt.plot(var_to_np(x_fake), var_to_np(out_part[:,1] + out_part[:,3]).squeeze())
        display(fig)
        plt.close(fig)

In [None]:
# Print the registered parameters of the model after training.
print(list(model.parameters()))

In [None]:
# Evaluate the trained model on the two real points.
real_out = model(x.unsqueeze(1))

In [None]:
# Evaluate the trained model on the optimized generated points.
fake_out = model(x_gen.unsqueeze(1))

In [None]:
# Display the most recently computed fake_out tensor as the cell output.
fake_out