In [2]:
import torchdata.datapipes as dp
from torch.utils.data import DataLoader
import torch.nn as nn
import torch
import numpy as np
import ot
import matplotlib.pyplot as plt
import timeit
import copy
from math import sqrt

from torch.autograd import grad as torch_grad

In [22]:
# Global configuration shared by every cell below.
device = torch.device('cpu')

noise_size = 62    # number of leading noise columns in a generator input row
batch_size = 1024  # number of rows in each sampled batch
w_dim = 4          # number of columns of W (and of H)

# One coordinate per index pair (i, j) with i < j.
a_dim = (w_dim * (w_dim - 1)) // 2
# One entry per +/-1 sign pattern of length w_dim.
s_dim = 2 ** w_dim

In [17]:
# calculate T and M

def generate_signs(n: int):
    """Enumerate every +/-1 sign pattern of length ``n``.

    Row ``r`` is the ``n``-bit binary expansion of ``r`` (most significant bit
    first) with each 0 mapped to -1.0 and each 1 mapped to +1.0.

    Args:
        n: Length of each sign pattern.

    Returns:
        A float64 ``numpy.ndarray`` of shape ``(2**n, n)`` whose entries are
        -1.0 or +1.0.
    """
    # Column k holds bit (n - 1 - k) of the row index, so rows count upwards in
    # binary when read left to right.  Working on integers avoids building
    # lists that mix int padding with '0'/'1' characters.
    bit_positions = np.arange(n - 1, -1, -1)
    bits = (np.arange(2**n)[:, None] >> bit_positions) & 1
    return 2.0 * bits - 1.0

# All +/-1 sign patterns of length w_dim, one per row.
signs = generate_signs(w_dim)
n_signs = len(signs)

# Index pairs (i, j) with i < j in row-major order; a pair's position in this
# list is its index along the a_dim axis of T and M.
index_pairs = [(i, j) for i in range(w_dim) for j in range(i + 1, w_dim)]

# COO coordinates (one list per axis of T) and the matching non-zero values.
first_dim, second_dim, third_dim, fourth_dim = [], [], [], []
values = []
# M_list[s][idx] = signs[s, i] * signs[s, j] for the idx-th pair (i, j).
M_list = []

for s in range(n_signs):
    M_row = []
    for idx, (i, j) in enumerate(index_pairs):
        sign_i = signs[s, i].item()
        sign_j = signs[s, j].item()
        # Each pair contributes two entries of opposite sign:
        #   T[s, idx, i, j] = -signs[s, j]   and   T[s, idx, j, i] = +signs[s, j].
        # NOTE(review): both entries are scaled by signs[s, j], while M uses
        # signs[s, i] * signs[s, j]. If T is meant to flip H coordinate-wise
        # (H_i -> signs[s, i] * H_i), the (j, i) entry would be expected to use
        # signs[s, i] instead -- confirm before relying on this.
        first_dim.extend((s, s))
        second_dim.extend((idx, idx))
        third_dim.extend((i, j))
        fourth_dim.extend((j, i))
        values.extend((-1 * sign_j, sign_j))
        M_row.append(sign_j * sign_i)
    M_list.append(M_row)

indices = [first_dim, second_dim, third_dim, fourth_dim]
# Assemble sparsely, then densify: T has shape (n_signs, a_dim, w_dim, w_dim).
T = torch.sparse_coo_tensor(
    indices=indices,
    values=values,
    size=(n_signs, a_dim, w_dim, w_dim),
).to_dense().contiguous()

# M has shape (n_signs, 1, a_dim); the singleton axis broadcasts over a batch.
M = torch.tensor(M_list).unsqueeze(1).contiguous()


# wthmb takes W, H and B (B is the Lévy area of the Brownian bridge) and computes A = WTH + MB for every sign pattern in T and M.
# where the hell is maribor anyway?
def wthmb(w_in: torch.Tensor, h_in: torch.Tensor, b_in: torch.Tensor):
    """Compute ``W T H + M * B`` for every sign pattern at once.

    Uses the module-level tensors ``T`` (shape ``(n_signs, a_dim, w_dim, w_dim)``)
    and ``M`` (shape ``(n_signs, 1, a_dim)``).

    Args:
        w_in: Tensor of shape ``(bsz, w_dim)``.
        h_in: Tensor of shape ``(bsz, w_dim)``.
        b_in: Tensor of shape ``(bsz, a_dim)``.

    Returns:
        Tensor of shape ``(n_signs * bsz, a_dim)``. Rows are grouped by sign
        pattern: row ``s * bsz + k`` is sample ``k`` under sign pattern ``s``.

    Raises:
        AssertionError: If an input does not have the shape listed above.
    """
    _bsz = w_in.shape[0]
    assert w_in.shape == (_bsz, w_dim)
    assert h_in.shape == (_bsz, w_dim)
    assert b_in.shape == (_bsz, a_dim)

    # reshape (rather than view) also accepts non-contiguous column slices.
    _W = w_in.reshape(1, 1, _bsz, w_dim)
    _H = h_in.reshape(_bsz, w_dim, 1)
    _B = b_in.reshape(1, _bsz, a_dim)

    # (1, 1, bsz, w) @ (S, a, w, w) -> (S, a, bsz, w); reorder to (S, bsz, a, w).
    WT = torch.matmul(_W, T).permute(0, 2, 1, 3)
    # (S, bsz, a, w) @ (bsz, w, 1) -> (S, bsz, a, 1). Drop ONLY the trailing
    # singleton: a bare squeeze() would also remove the batch axis when
    # bsz == 1 (or the pair axis when a_dim == 1) and break the sum below.
    WTH = torch.matmul(WT, _H).squeeze(-1)
    # (S, 1, a) * (1, bsz, a) -> (S, bsz, a)
    MB = torch.mul(M, _B)
    return torch.flatten(WTH + MB, start_dim=0, end_dim=1)

def wth(w_in: torch.Tensor, h_in: torch.Tensor):
    """Compute ``W T H`` for every sign pattern, detached from autograd.

    Uses the module-level tensor ``T`` (shape ``(n_signs, a_dim, w_dim, w_dim)``).

    Args:
        w_in: Tensor of shape ``(bsz, w_dim)``.
        h_in: Tensor of shape ``(bsz, w_dim, 1)`` (already a column vector per
            sample).

    Returns:
        Detached tensor of shape ``(n_signs * bsz, a_dim)``. Rows are grouped
        by sign pattern: row ``s * bsz + k`` is sample ``k`` under pattern ``s``.

    Raises:
        AssertionError: If an input does not have the shape listed above.
    """
    _bsz = w_in.shape[0]
    assert w_in.shape == (_bsz, w_dim)
    assert h_in.shape == (_bsz, w_dim, 1)

    # reshape (rather than view) also accepts non-contiguous column slices.
    _W = w_in.reshape(1, 1, _bsz, w_dim)
    # (1, 1, bsz, w) @ (S, a, w, w) -> (S, a, bsz, w); reorder to (S, bsz, a, w).
    WT = torch.matmul(_W, T).permute(0, 2, 1, 3)
    # Drop ONLY the trailing singleton produced by the matmul; a bare squeeze()
    # would also remove the batch axis when bsz == 1 (or the pair axis when
    # a_dim == 1) and make the flatten below fail or merge the wrong axes.
    product = torch.matmul(WT, h_in).squeeze(-1)
    WTH = torch.flatten(product, start_dim=0, end_dim=1).detach()
    return WTH


In [18]:
# H has the same number of columns as W.
h_dim = w_dim

# Standard-normal sample batches, drawn in the order W, H, B.
# B has one column per index pair (a_dim columns).
W = torch.randn(batch_size, w_dim, dtype=torch.float32)
H = torch.randn(batch_size, w_dim, dtype=torch.float32)
B = torch.randn(batch_size, a_dim, dtype=torch.float32)

# print(W)
# print(H)
# print(B)

In [25]:
# Micro-benchmark: wall-clock time for 10000 draws of batch_size distinct
# indices out of the batch_size * s_dim rows via a full random permutation.
pool_size = batch_size * s_dim
start_time = timeit.default_timer()
for i in range(10000):
    a = torch.randperm(pool_size)[:batch_size]
elapsed = timeit.default_timer() - start_time

# Elapsed time in seconds.
print(elapsed)


2.7808924890000526


In [None]:
class symGenerator1(nn.Module):
    """MLP generator that predicts both ``h`` and ``b`` and combines them with ``w``.

    The network ``self.main`` maps a row of ``w_dim + noise_size`` features
    through three Linear -> BatchNorm1d -> ReLU stages (widths 512, 512, 128)
    to ``w_dim + a_dim`` outputs.
    """

    def __init__(self):
        super().__init__()
        # Feature widths of the successive hidden stages, starting at the input.
        widths = (w_dim + noise_size, 512, 512, 128)
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack += [nn.Linear(fan_in, fan_out), nn.BatchNorm1d(fan_out), nn.ReLU()]
        # Final projection: the first w_dim outputs are h, the remaining a_dim are b.
        stack.append(nn.Linear(widths[-1], w_dim + a_dim))
        self.main = nn.Sequential(*stack)

    def forward(self, input):
        """Run the network on ``input`` and return ``wthmb(w, h, b)``.

        ``w`` is read from columns ``[noise_size, noise_size + w_dim)`` of
        ``input``; the whole of ``input`` is fed to the network, whose output
        is split into ``h`` (first ``w_dim`` columns) and ``b`` (the rest).
        """
        w_start = noise_size
        w = input[:, w_start:w_start + w_dim]
        features = self.main(input)
        h = features[:, :w_dim]
        b = features[:, w_dim:w_dim + a_dim]
        return wthmb(w, h, b)



In [None]:
class HsymGenerator1(nn.Module):
    """MLP generator that samples ``h`` itself and predicts only ``b``.

    ``self.main`` maps ``noise_size + w_dim`` features (noise followed by the
    sampled ``h``) through three Linear -> BatchNorm1d -> ReLU stages (widths
    512, 512, 128) to ``a_dim`` outputs.
    """

    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            nn.Linear(w_dim + noise_size, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),

            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),

            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128, a_dim)
        )

    def forward(self, input):
        """Sample ``h``, predict ``b`` from (noise, h) and return ``wthmb(w, h, b)``.

        Args:
            input: Tensor whose first ``noise_size`` columns are noise and whose
                columns ``[noise_size, noise_size + w_dim)`` are ``w``.

        Returns:
            Whatever ``wthmb(w, h, b)`` returns for the sliced ``w``, the freshly
            sampled ``h`` and the predicted ``b``.
        """
        w = input[:, noise_size:w_dim + noise_size]
        bsz = input.shape[0]
        # h ~ N(0, 1/12) per entry. Sample on the input's device so the
        # torch.cat below does not mix devices when the input is not on the
        # default device; on CPU this is identical to the device-less call.
        h = sqrt(1/12) * torch.randn((bsz, w_dim), dtype=torch.float, device=input.device)
        noise = input[:, :noise_size]
        # The network sees the noise and the sampled h, but not w.
        x = torch.cat((noise, h), dim=1)
        b = self.main(x)
        return wthmb(w, h, b)
