In [1]:
import torchdata.datapipes as dp
from torch.utils.data import DataLoader
import torch.nn as nn
import torch
import numpy as np
import ot
import matplotlib.pyplot as plt
import timeit
import copy
from math import sqrt



In [2]:
device = torch.device('cpu')

noise_size = 62

# Number of training epochs
num_epochs = 15

# Learning rate for optimizers
lr = 0.0002

# Beta1 hyperparam for Adam optimizers
beta1 = 0.5

ngpu = 0

# To keep the criterion Lipschitz
weight_cliping_limit = 0.01

batch_size = 1024

test_batch_size = 65536

w_dim = 5

data_dim = int(w_dim*(w_dim - 1)//2)

In [3]:
# create dataloader for samples

def row_processer(row):
    return np.array(row, dtype= np.float32)

filename = f"samples/samples_{w_dim}-dim.csv"
datapipe = dp.iter.FileOpener([filename], mode='b')
datapipe = datapipe.parse_csv(delimiter=',')
datapipe = datapipe.map(row_processer)
dataloader = DataLoader(dataset=datapipe, batch_size=batch_size, num_workers=2)

# Check if the dimensions match
d = next(iter(dataloader))
if d.size(1) != data_dim + w_dim:
    print("!!!!!!!!!!!!!!!!!!!!!!!!! WRONG DATA DIMENSIONS !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")

In [4]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    if classname.find("BatchNorm") != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0.0)

In [5]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(w_dim+noise_size,512),
            nn.BatchNorm1d(512),
            nn.ReLU(),

            nn.Linear(512,1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),

            nn.Linear(1024,1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),

            nn.Linear(1024,256),
            nn.BatchNorm1d(256),
            nn.ReLU(),

            nn.Linear(256,data_dim)
        )

    def forward(self, input):
        return self.main(input)

In [6]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(w_dim + data_dim,512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2),

            nn.Linear(512,1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),

            nn.Linear(1024,1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),

            nn.Linear(1024,256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2),

            nn.Linear(128,1),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)

In [7]:
# initialize nets
netD = Discriminator().to(device)
netD.apply(weights_init)
netG = Generator().to(device)
netG.apply(weights_init)

Generator(
  (main): Sequential(
    (0): Linear(in_features=67, out_features=512, bias=True)
    (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=512, out_features=1024, bias=True)
    (4): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Linear(in_features=1024, out_features=1024, bias=True)
    (7): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
    (9): Linear(in_features=1024, out_features=256, bias=True)
    (10): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU()
    (12): Linear(in_features=256, out_features=10, bias=True)
  )
)

In [8]:
# Initialise optimiser

#optG = torch.optim.Adam(netG.parameters(),lr = lr, betas=(beta1,0.999))
#optD = torch.optim.Adam(netD.parameters(), lr = lr, betas=(beta1,0.999))
optG = torch.optim.RMSprop(netG.parameters(), lr = lr)
optD = torch.optim.RMSprop(netD.parameters(), lr = lr)

# A fixed W increment for testing purposes
W_fixed: torch.Tensor = torch.tensor([1.0,-0.5,-1.2,-0.3,0.7,0.2,-0.9,0.1,1.7])


W_fixed = W_fixed[:w_dim].unsqueeze(1).transpose(1,0)
W_fixed = W_fixed.expand((test_batch_size,w_dim))

# Load "true" samples generated from this fixed W increment
test_filename = f"samples/fixed_samples_{w_dim}-dim.csv"
A_fixed_true = np.genfromtxt(test_filename,dtype=float,delimiter=',',)
A_fixed_true = A_fixed_true[:,w_dim:(w_dim+data_dim)]

wass_errors = []

iters = 0

one = torch.FloatTensor([1])
mone = one * -1

In [9]:
# Early stopping setup

# Will have two backup points:
# One where the sum of Wasserstein errors was minimal and one where the max was minimal

min_sum = float('inf')
min_sum_errors = [1.0 for i in range(data_dim)]
min_sum_paramsG = copy.deepcopy(netG.state_dict())
min_sum_paramsD = copy.deepcopy(netD.state_dict())

# min_max_err = float('inf')
# min_max_errors = [1.0 for i in range(data_dim)]
# min_max_paramsG = copy.deepcopy(netG.state_dict())
# min_max_paramsD = copy.deepcopy(netD.state_dict())

In [10]:
for epoch in range(num_epochs):

    for i, data in enumerate(dataloader):
        netD.zero_grad()

        # weight clipping so critic is lipschitz
        for p in netD.parameters():
            p.data.clamp_(-weight_cliping_limit, weight_cliping_limit)

        # check actual batch size (last batch could be shorter)
        b_size = data.size(0)

        # Train Discriminator
        # first on real data
        out_D_real = netD(data)
        lossDr = out_D_real.mean(0).view(1)
        lossDr.backward(one)

        # then on fake data

        # data has shape (b_size, w_dim + data_dim) where w_dim are the dimensions of the driving BM and data_dim is the dim of Levy Areas
        W = data[:,:w_dim]
        A_real = data[:,w_dim:(w_dim + data_dim)]
        noise = torch.randn((b_size,noise_size), dtype=torch.float, device=device)
        gen_in = torch.cat((noise,W),1)
        # generate fake data
        generated_A = netG(gen_in)
        fake_in = torch.cat((W,generated_A.detach()),1)

        lossDf = netD(fake_in)
        lossDf = lossDf.mean(0).view(1)
        lossDf.backward(mone)
        lossD = lossDr - lossDf
        optD.step()

        # train Generator with probability 1/5
        if np.random.randint(1,6) == 5:
            netG.zero_grad()

            fake_in = torch.cat((W,generated_A),1)
            lossG = netD(fake_in)
            lossG = lossG.mean(0).view(1)
            lossG.backward(one)
            optG.step()

        if iters%100 == 0:
            # Test Wasserstein error for fixed W
            noise = torch.randn((test_batch_size,noise_size), dtype=torch.float, device=device)
            g_in = torch.cat((noise,W_fixed),1)
            A_fixed_gen = netG(g_in).detach().numpy()
            errors = [sqrt(ot.wasserstein_1d(A_fixed_true[:,i],A_fixed_gen[:,i],p=2)) for i in range(data_dim)]

            # Print out partial results
            pretty_errors = ["{0:0.5f}".format(i) for i in errors]
            print(f"epoch: {epoch}/{num_epochs}, iter: {iters},\n errors: {pretty_errors}")
            # Save for plotting
            wass_errors.append(errors)

            # Early stopping checkpoint
            error_sum = sum(errors)
            if error_sum <= min_sum:
                min_sum = error_sum
                min_sum_errors = errors
                min_sum_paramsG = copy.deepcopy(netG.state_dict())
                min_sum_paramsD = copy.deepcopy(netD.state_dict())
                print("Saved parameters")


        iters += 1

RuntimeError: running_mean should contain 1024 elements not 512

W_fixed = [1.0,-0.5,-1.2,-0.3,0.7,0.2,-0.9,0.1,1.7]
list_pairs(5) = [(1, 2), (1, 3), (1, 4), (1, 5), (2, 3), (2, 4), (2, 5), (3, 4), (3, 5), (4, 5)]


In [None]:
# Return to early stopping checkpoint

best_netG = Generator()
best_netG.load_state_dict(min_sum_paramsG)

torch.save(min_sum_paramsG, f'model_saves/GAN2_{w_dim}d_24epochs_generator.pt')
torch.save(min_sum_paramsD, f'model_saves/GAN2_{w_dim}d_24epochs_discriminator.pt')
# best_netD = Discriminator()
# best_netD.load_state_dict(min_sum_paramsD)

In [None]:
# Test Wasserstein error for fixed W
noise = torch.randn((test_batch_size,noise_size), dtype=torch.float, device=device)
g_in = torch.cat((noise,W_fixed),1)
A_fixed_gen = best_netG(g_in).detach().numpy()
errors = [sqrt(ot.wasserstein_1d(A_fixed_true[:,i],A_fixed_gen[:,i],p=2)) for i in range(data_dim)]

# Print out partial results
pretty_errors = ["{0:0.5f}".format(i) for i in errors]
print(f"best net errors: {pretty_errors}")

In [None]:
# Draw errors through iterations

plt.figure(figsize=(10,5))
plt.title("2-Wasserstein distance of generated samples from the original samples for fixed W increment")
plt.plot(wass_errors)
plt.xlabel("iterations")
plt.ylabel("Wasserstein distance")
plt.legend()
plt.show()

In [None]:
# Time measurements

W_fixed: torch.Tensor = torch.tensor([1.0,-0.5,-1.2,-0.3,0.7,0.2,-0.9,0.1,1.7])
W_fixed = W_fixed[:w_dim].unsqueeze(1).transpose(1,0)
W_fixed = W_fixed.expand((test_batch_size,w_dim))
noise = torch.randn((test_batch_size,noise_size), dtype=torch.float, device=device)
g_in = torch.cat((noise,W_fixed),1)
netG.eval()
start_time = timeit.default_timer()
for i in range(100):
    A_fixed_out=netG(g_in)

elapsed = timeit.default_timer() - start_time
print(elapsed)

Takes 34.2s to generate 6553600 samples (original GAN)
Calling iterated_integrals(h = 1.0, err = 0.0005) 6553600-times takes 100.5s