In [2]:
import torchdata.datapipes as dp
from torch.utils.data import DataLoader
import torch.nn as nn
import torch
import numpy as np
import ot
import matplotlib.pyplot as plt
import timeit
import copy
from math import sqrt

from torch.autograd import grad as torch_grad

In [5]:
device = torch.device('cpu')

noise_size = 62

# Number of training epochs using classical training
num_epochs = 20

# Number of iterations of Chen training
num_Chen_iters = 5000

# 'Adam' of 'RMSProp'
which_optimizer = 'Adam'

# Learning rate for optimizers
lrG = 0.000005
lrD = 0.00005

# Beta1 hyperparam for Adam optimizers
beta1 = 0

beta2 = 0.99

ngpu = 0

# for gradient penalty
gp_weight = 10.0

batch_size = 7

test_batch_size = 16384

w_dim = 4

a_dim = int(w_dim*(w_dim - 1)//2)

# if 1 use GAN1, if 2 use GAN2, etc.
which_model = 2

# slope for LeakyReLU
leakyReLU_slope = 0.2

# this gives the option to rum the training process multiple times with differently initialised GANs
#num_trials = 1

num_tests_for2d = 8

In [11]:
# calculate T and M

def generate_signs(n: int):
    lst = []
    for i in range(2**n):
        binary_exp = list(bin(i)[2:])
        lst.append((n-len(binary_exp))*[0]+binary_exp)

    res = 2*np.array(lst, dtype=float) - np.ones((2**n,n), dtype=float)
    return res

signs = generate_signs(w_dim)

first_dim = []
second_dim = []
third_dim = []
fourth_dim = []
values = []
M_list = []

for s in range(len(signs)):
    idx = 0
    M_row = []
    for i in range(w_dim):
        for j in range(i+1,w_dim):
            first_dim.append(s)
            second_dim.append(idx)
            third_dim.append(i)
            fourth_dim.append(j)
            values.append(-1*signs[s,j].item())
            first_dim.append(s)
            second_dim.append(idx)
            third_dim.append(j)
            fourth_dim.append(i)
            values.append(signs[s,j].item())
            idx+=1
            M_row.append(signs[s,j].item() * signs[s,i].item())
    M_list.append(M_row)

indices = [first_dim,second_dim,third_dim,fourth_dim]
T = torch.sparse_coo_tensor(indices=indices,values=values, size = (len(signs),a_dim,w_dim,w_dim)).to_dense()

M = torch.tensor(M_list).unsqueeze(1)


# A function that takes W, H and B (B is the Levy Area of the Brownian Bridge) and computes A = WTH+MB
# where the hell is maribor anyway?
def wthmb(w_in: torch.Tensor, h_in: torch.Tensor, b_in: torch.Tensor):
    assert w_in.shape == (batch_size,w_dim) and h_in.shape == (batch_size, w_dim) and b_in.shape == (batch_size, a_dim)
    _W = w_in.view(1,1,batch_size,w_dim)
    _H = h_in.view(batch_size,w_dim,1)
    _B = b_in.view(1,batch_size,a_dim)
    WT = torch.matmul(_W, T).permute(0,2,1,3)
    WTH = torch.matmul(WT, _H).squeeze()
    MB = torch.mul(M,_B)
    return torch.flatten(WTH + MB, start_dim=0,end_dim=1)



In [12]:
h_dim = w_dim
W = torch.arange(start = 1, end = w_dim*batch_size+1, dtype = torch.float).view(batch_size,w_dim)
H = torch.arange(start = 5, end = h_dim*batch_size+5, dtype = torch.float).view(batch_size,h_dim)
B = torch.arange(a_dim*batch_size, dtype=torch.float).view(batch_size,a_dim)
print(W)
print(H)
print(B)

tensor([[ 1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.],
        [13., 14., 15., 16.],
        [17., 18., 19., 20.],
        [21., 22., 23., 24.],
        [25., 26., 27., 28.]])
tensor([[ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.],
        [13., 14., 15., 16.],
        [17., 18., 19., 20.],
        [21., 22., 23., 24.],
        [25., 26., 27., 28.],
        [29., 30., 31., 32.]])
tensor([[ 0.,  1.,  2.,  3.,  4.,  5.],
        [ 6.,  7.,  8.,  9., 10., 11.],
        [12., 13., 14., 15., 16., 17.],
        [18., 19., 20., 21., 22., 23.],
        [24., 25., 26., 27., 28., 29.],
        [30., 31., 32., 33., 34., 35.],
        [36., 37., 38., 39., 40., 41.]])


In [13]:
A = wthmb(W,H,B)


tensor([[ -4.,  -7., -10.,  -1.,  -4.,   1.],
        [  2.,  -1.,  -4.,   5.,   2.,   7.],
        [  8.,   5.,   2.,  11.,   8.,  13.],
        [ 14.,  11.,   8.,  17.,  14.,  19.],
        [ 20.,  17.,  14.,  23.,  20.,  25.],
        [ 26.,  23.,  20.,  29.,  26.,  31.],
        [ 32.,  29.,  26.,  35.,  32.,  37.],
        [ -4.,  -7.,  10.,  -1.,   4.,  -1.],
        [  2.,  -1.,   4.,   5.,  -2.,  -7.],
        [  8.,   5.,  -2.,  11.,  -8., -13.],
        [ 14.,  11.,  -8.,  17., -14., -19.],
        [ 20.,  17., -14.,  23., -20., -25.],
        [ 26.,  23., -20.,  29., -26., -31.],
        [ 32.,  29., -26.,  35., -32., -37.],
        [ -4.,   7., -10.,   1.,  -4.,  -9.],
        [  2.,   1.,  -4.,  -5.,   2., -15.],
        [  8.,  -5.,   2., -11.,   8., -21.],
        [ 14., -11.,   8., -17.,  14., -27.],
        [ 20., -17.,  14., -23.,  20., -33.],
        [ 26., -23.,  20., -29.,  26., -39.],
        [ 32., -29.,  26., -35.,  32., -45.],
        [ -4.,   7.,  10.,   1.,  

In [15]:
blip = torch.tensor([[1,2,3],[4,5,6]])
blup = blip.repeat(4,1)
print(blup)

tensor([[1, 2, 3],
        [4, 5, 6],
        [1, 2, 3],
        [4, 5, 6],
        [1, 2, 3],
        [4, 5, 6],
        [1, 2, 3],
        [4, 5, 6]])


In [None]:
class symGenerator1(nn.Module):
    def __init__(self):
        super(symGenerator1, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(w_dim+noise_size,512),
            nn.BatchNorm1d(512),
            nn.ReLU(),

            nn.Linear(512,512),
            nn.BatchNorm1d(512),
            nn.ReLU(),

            nn.Linear(512,128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128,w_dim+a_dim)
        )

    def forward(self, input):
        w = input[:,noise_size:w_dim+noise_size]
        x = self.main(input)
        h = x[:,:w_dim]
        b = x[:,w_dim:w_dim+a_dim]
        return wthmb(w,h,b)



In [None]:
class HsymGenerator1(nn.Module):
    def __init__(self):
        super(HsymGenerator1, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(w_dim +noise_size,512),
            nn.BatchNorm1d(512),
            nn.ReLU(),

            nn.Linear(512,512),
            nn.BatchNorm1d(512),
            nn.ReLU(),

            nn.Linear(512,128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128,a_dim)
        )

    def forward(self, input):
        w = input[:,noise_size:w_dim+noise_size]
        bsz = input.shape[0]
        h = sqrt(1/12) * torch.randn((bsz,w_dim), dtype=torch.float)
        noise = input[:,:noise_size]
        x = torch.cat((noise,h),dim=1)
        b = self.main(x)
        return wthmb(w,h,b)
