In [1]:
import torchdata.datapipes as dp
from torch.utils.data import DataLoader
import torch.nn as nn
import torch
import numpy as np
import ot



In [4]:
device = torch.device('cpu')

noise_size = 62

# Number of training epochs
num_epochs = 3

# Learning rate for optimizers
lr = 0.0002

# Beta1 hyperparam for Adam optimizers
beta1 = 0.5

ngpu = 0

weight_cliping_limit = 0.01

batch_size = 1024

test_batch_size = 65536

w_dim = 9

data_dim = int(w_dim*(w_dim - 1)//2)

In [3]:
def row_processer(row):
    return np.array(row, dtype= np.float32)

filename = f"samples/samples_{w_dim}-dim.csv"
datapipe = dp.iter.FileOpener([filename], mode='b')
datapipe = datapipe.parse_csv(delimiter=',')
datapipe = datapipe.map(row_processer)

In [4]:
dataloader = DataLoader(dataset=datapipe, batch_size=batch_size, num_workers=2)

In [5]:
d = next(iter(dataloader))
print(d)
print(d.shape)
if d.size(1) != data_dim + w_dim:
    print("!!!!!!!!!!!!!!!!!!!!!!!!! WRONG DATA DIMENSIONS !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")

tensor([[-1.4574e-01, -4.5351e-01, -4.4477e-01,  ..., -9.9213e-05,
          1.1470e-02,  1.0801e-02],
        [ 2.3683e-01, -4.9418e-02,  2.8505e-02,  ..., -2.5169e-02,
          1.3289e-02, -1.4461e-02],
        [ 1.5292e+00, -6.0493e-01,  2.3344e-01,  ...,  5.1749e-03,
          1.6578e-02,  9.9204e-03],
        ...,
        [-9.5029e-01,  8.7256e-01, -1.4139e+00,  ...,  2.3413e-02,
          1.3653e-02,  1.4888e-03],
        [ 7.5400e-01,  1.5456e-01,  2.3939e+00,  ..., -7.3247e-03,
          1.7868e-03, -2.0897e-03],
        [-2.4747e-01, -3.0479e-01, -1.4944e+00,  ..., -2.0072e-02,
          1.2993e-02, -2.9737e-03]])
torch.Size([1024, 15])


In [6]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    if classname.find("BatchNorm") != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0.0)

In [7]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(w_dim+noise_size,512),
            nn.BatchNorm1d(512),
            nn.ReLU(),

            nn.Linear(512,512),
            nn.BatchNorm1d(512),
            nn.ReLU(),

            nn.Linear(512,128),
            nn.BatchNorm1d(128),
            nn.ReLU(),

            nn.Linear(128,data_dim)
        )

    def forward(self, input):
        return self.main(input)

In [8]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(w_dim + data_dim,512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2),

            nn.Linear(512,512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2),

            nn.Linear(512,128),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(0.2),

            nn.Linear(128,1),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)

In [9]:
netD = Discriminator().to(device)
netD.apply(weights_init)
netG = Generator().to(device)
netG.apply(weights_init)

Generator(
  (main): Sequential(
    (0): Linear(in_features=67, out_features=512, bias=True)
    (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=512, out_features=512, bias=True)
    (4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Linear(in_features=512, out_features=128, bias=True)
    (7): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
    (9): Linear(in_features=128, out_features=10, bias=True)
  )
)

In [10]:
optG = torch.optim.Adam(netG.parameters(),lr = lr, betas=(beta1,0.999))
optD = torch.optim.Adam(netD.parameters(), lr = lr, betas=(beta1,0.999))

D_losses = []
G_losses = []

iters = 0

one = torch.FloatTensor([1])
mone = one * -1

In [11]:
for epoch in range(num_epochs):

    for i, data in enumerate(dataloader):
        netD.zero_grad()

        for p in netD.parameters():
            p.data.clamp_(-weight_cliping_limit, weight_cliping_limit)


        b_size = data.size(0)

        out_D_real = netD(data)
        lossDr = out_D_real.mean(0).view(1)
        lossDr.backward(one)

        W = data[:,:w_dim]
        A_real = data[:,w_dim:(w_dim + data_dim)]
        noise = torch.randn((b_size,noise_size), dtype=torch.float, device=device)
        gen_in = torch.cat((noise,W),1)
        generated_A = netG(gen_in)
        fake_in = torch.cat((W,generated_A.detach()),1)

        lossDf = netD(fake_in)
        lossDf = lossDf.mean(0).view(1)
        lossDf.backward(mone)
        lossD = lossDr - lossDf
        optD.step()

        if i%10==0:
            netG.zero_grad()

            fake_in = torch.cat((W,generated_A),1)
            lossG = netD(fake_in)
            lossG = lossG.mean(0).view(1)
            lossG.backward(one)
            optG.step()

        if iters%100 == 0:
            print(f"epoch: {epoch}/{num_epochs}, iter: {iters},\nlossD_fake: {lossDf.item()}, lossD_real: {lossDr.item()} lossG: {lossG.item()}")
            G_losses.append(lossG.item())
            D_losses.append(lossD.item())

        iters += 1

epoch: 0/3, iter: 0,
lossD_fake: 0.5025027990341187, lossD_real: 0.5025032758712769 lossG: 0.5025137066841125
epoch: 0/3, iter: 100,
lossD_fake: 0.5003982782363892, lossD_real: 0.49962788820266724 lossG: 0.5004345774650574
epoch: 0/3, iter: 200,
lossD_fake: 0.5004305839538574, lossD_real: 0.4995853900909424 lossG: 0.5004522800445557
epoch: 0/3, iter: 300,
lossD_fake: 0.5004465579986572, lossD_real: 0.49955469369888306 lossG: 0.5004684329032898
epoch: 0/3, iter: 400,
lossD_fake: 0.5004550814628601, lossD_real: 0.49953901767730713 lossG: 0.5004541873931885
epoch: 0/3, iter: 500,
lossD_fake: 0.5003911852836609, lossD_real: 0.4995414614677429 lossG: 0.5004033446311951
epoch: 0/3, iter: 600,
lossD_fake: 0.5003912448883057, lossD_real: 0.4995378255844116 lossG: 0.5004229545593262
epoch: 0/3, iter: 700,
lossD_fake: 0.5004489421844482, lossD_real: 0.49953314661979675 lossG: 0.500440776348114
epoch: 0/3, iter: 800,
lossD_fake: 0.5004638433456421, lossD_real: 0.4995088577270508 lossG: 0.50047433

In [5]:
W_fixed: torch.Tensor = torch.tensor([1.0,-0.5,-1.2,-0.3,0.7,0.2,-0.9,0.1,1.7])
W_fixed = W_fixed[:w_dim].unsqueeze(1).transpose(1,0)
W_fixed = W_fixed.expand((test_batch_size,w_dim))
print(W_fixed)

tensor([[ 1.0000, -0.5000, -1.2000,  ..., -0.9000,  0.1000,  1.7000],
        [ 1.0000, -0.5000, -1.2000,  ..., -0.9000,  0.1000,  1.7000],
        [ 1.0000, -0.5000, -1.2000,  ..., -0.9000,  0.1000,  1.7000],
        ...,
        [ 1.0000, -0.5000, -1.2000,  ..., -0.9000,  0.1000,  1.7000],
        [ 1.0000, -0.5000, -1.2000,  ..., -0.9000,  0.1000,  1.7000],
        [ 1.0000, -0.5000, -1.2000,  ..., -0.9000,  0.1000,  1.7000]])


In [18]:
noise = torch.randn((test_batch_size,noise_size), dtype=torch.float, device=device)
g_in = torch.cat((noise,W_fixed),1)
A_fixed_gen = netG(g_in).detach().numpy()
print(A_fixed_gen.shape)

(65536, 10)


In [19]:
#output = torch.cat((W_fixed,A_fixed_gen), 1)
#print(output.shape)
# np.savetxt("fixed_GAN_out_2d.csv", output.detach()[:,2], delimiter=",")

In [20]:
A_fixed_gen = A_fixed_gen
test_filename = f"samples/fixed_samples_{w_dim}-dim.csv"
samples = np.genfromtxt(test_filename,dtype=float,delimiter=',',)
A_fixed_true = samples[:,w_dim:(w_dim+data_dim)]
for i in range(data_dim):
    true_col = A_fixed_true[:,i]
    generated_col = A_fixed_gen[:,i]
    dist = ot.wasserstein_1d(true_col,generated_col,p=2)
    print(dist)

0.003986282303522681
0.015288523266045864
0.01639928861573666
0.014736991408361905
0.012501982121504222
0.018853810314580038
0.002594853676512707
0.006951561533124371
0.0139525881855331
0.006613752276877162
