In [1]:
import torchdata.datapipes as dp
from torch.utils.data import DataLoader
import torch.nn as nn
import torch
import numpy as np
import ot
import matplotlib.pyplot as plt
import timeit
import copy
from math import sqrt



To cite:

J. M. C. Clark and R. J. Cameron. The maximum rate of convergence of discrete approximations for Stochastic differential equations. in Stochastic Differential Systems Filtering and Control, ed. by Grigelionis (Springer, Berlin), 1980.
A. S. Dickinson. Optimal Approximation of the Second Iterated Integral of Brownian Motion. Stochastic Analysis and Applications, 25(5):1109–1128, 2007.

F. Kastner, A. Rößler. "An Analysis of Approximation Algorithms for Iterated Stochastic Integrals and a Julia and Matlab Simulation Toolbox". arXiv:2201.08424

Foster, J. M. Numerical Approximations for Stochastic Differential Equations. University of Oxford, 2020.

In [2]:
# Device on which the networks and the noise tensors are created
device = torch.device('cpu')

# Length of the random noise vector that is concatenated with W to form the generator input
noise_size = 32

# Number of training epochs using classical training
num_epochs = 15

# Number of iterations of Chen training
num_Chen_iters = 5000

# 'Adam' or 'RMSProp'
which_optimizer = 'RMSProp'

# Learning rate for optimizers
lr = 0.00002

# Beta1 hyperparam for Adam optimizers
beta1 = 0.5

# NOTE(review): not read by any cell visible in this notebook - presumably a leftover; confirm before removing
ngpu = 0

# To keep the criterion Lipschitz
weight_cliping_limit = 0.01

# Batch size of the training DataLoader
batch_size = 1024

# Number of samples generated when testing against the fixed-W reference samples
test_batch_size = 65536

# Dimension of the driving Brownian motion increment W
w_dim = 4

# Number of Levy-area components: one per unordered pair of W coordinates
a_dim = int(w_dim*(w_dim - 1)//2)

# if 1 use GAN1, if 2 use GAN2, etc.
which_model = 2

# slope for LeakyReLU
leakyReLU_slope = 0.2

# this gives the option to run the training process multiple times with differently initialised GANs
num_trials = 1

In [3]:
# CHEN RELATION
# Levy-area satisfies a version of the Chen relation (see Chen_relation.pdf) and is the unique distribution which satisfies this version of the relation

def chen_combine(w_a_in: torch.Tensor) -> torch.Tensor:
    """Combine pairs of (W, A) samples into one rescaled sample via the Chen relation.

    The batch is split into two halves along dimension 0; row ``j`` of the first
    half is treated as the increment over the first half-interval and row ``j`` of
    the second half as the increment over the second half-interval. The combined
    sample is rescaled back to a unit-length interval (W by ``sqrt(0.5)``, A by
    ``0.5``).

    Args:
        w_a_in: tensor of shape ``(2*n, w_dim + a_dim)`` whose first ``w_dim``
            columns hold W and whose remaining ``a_dim`` columns hold the
            Levy-area components, ordered by pairs ``(k, l)`` with ``k < l``.

    Returns:
        A new tensor of shape ``(n, w_dim + a_dim)``; the input is not modified.

    Raises:
        AssertionError: if the batch size is odd or the column count is not
            ``w_dim + a_dim``.
    """
    # the batch dimension of the input is halved
    out_size = w_a_in.size(0)//2
    assert 2*out_size == w_a_in.size(0), "batch size must be even"
    assert w_a_in.size(1) == w_dim + a_dim, "expected w_dim + a_dim columns"

    # w_0_s is from 0 to t/2 and w_s_t is from t/2 to t
    w_0_s, w_s_t = w_a_in.chunk(2)
    result = torch.clone(w_0_s + w_s_t)
    result[:, :w_dim] = sqrt(0.5)*result[:, :w_dim]
    result[:, w_dim:(w_dim+a_dim)] = 0.5*result[:, w_dim:(w_dim+a_dim)]

    # Index pairs (k, l) with k < l in row-major order, i.e. the same order in
    # which the Levy-area columns are laid out.
    first, second = torch.triu_indices(w_dim, w_dim, offset=1, device=w_a_in.device)
    # Antisymmetric cross term of the Chen relation, already rescaled (0.5 * 0.5).
    correction_term = 0.25*(w_0_s[:, first]*w_s_t[:, second] - w_0_s[:, second]*w_s_t[:, first])
    result[:, w_dim:(w_dim+a_dim)] += correction_term

    return result

# prints the 2-Wasserstein distances (in each of the Levy-area dimensions) between the input and chen_combine(chen_combine(input))
# The idea behind this is that Levy-area is the unique distribution which is close to chen_combine of itself
# Indeed this is experimentally confirmed in test.ipynb

def chen_error_2step(w_a_in: torch.Tensor):
    """Measure how far a batch of (W, A) samples is from being Chen-invariant.

    Applies ``chen_combine`` twice (so the batch size must be divisible by 4) and
    compares, for every Levy-area column, the combined samples with the original
    ones.

    Args:
        w_a_in: tensor of shape ``(4*n, w_dim + a_dim)``.

    Returns:
        A list of ``a_dim`` floats: the square root of ``ot.wasserstein_1d(..., p=2)``
        between the twice-combined and the original samples, per Levy-area column.
    """
    combined_data = chen_combine(w_a_in)
    combined_data = chen_combine(combined_data)
    return [sqrt(ot.wasserstein_1d(combined_data[:,w_dim+i],w_a_in[:,w_dim+i],p=2)) for i in range(a_dim)]

In [4]:
# create dataloader for samples

def row_processer(row):
    """Turn one parsed CSV row into a float32 NumPy array."""
    as_float32 = np.array(row, dtype=np.float32)
    return as_float32

# Training samples are read from a CSV file chosen by the Brownian dimension
filename = f"samples/samples_{w_dim}-dim.csv"
# Pipeline: open the file -> parse comma-separated rows -> convert each row to float32
datapipe = dp.iter.FileOpener([filename], mode='b')
datapipe = datapipe.parse_csv(delimiter=',')
datapipe = datapipe.map(row_processer)
dataloader = DataLoader(dataset=datapipe, batch_size=batch_size, num_workers=2)

# Check if the dimensions match
# Pulls one batch to inspect its column count. A mismatch only prints a warning;
# it does not stop the notebook.
d = next(iter(dataloader))
if d.size(1) != a_dim + w_dim:
    print("!!!!!!!!!!!!!!!!!!!!!!!!! WRONG DATA DIMENSIONS !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")

In [5]:
def weights_init(m):
    """Re-initialise a module in place based on its class name.

    Intended for use with ``nn.Module.apply``. Modules whose class name contains
    "Conv" get weights drawn from N(0, 0.02); modules whose class name contains
    "BatchNorm" get weights drawn from N(1, 0.02) and zero bias. All other
    modules are left untouched.
    """
    layer_kind = type(m).__name__
    if "Conv" in layer_kind:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    if "BatchNorm" in layer_kind:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0.0)

In [6]:
# GAN 1

class Generator1(nn.Module):
    """Fully connected generator (GAN 1).

    Takes an input of width ``w_dim + noise_size`` and produces ``a_dim`` outputs.
    Hidden blocks are Linear -> BatchNorm1d -> ReLU with widths 512, 512, 128.
    """

    def __init__(self):
        super().__init__()
        widths = [w_dim + noise_size, 512, 512, 128]
        layers = []
        # One Linear/BatchNorm/ReLU block per consecutive pair of widths
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.BatchNorm1d(fan_out))
            layers.append(nn.ReLU())
        # Final projection has no normalisation or activation
        layers.append(nn.Linear(widths[-1], a_dim))
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)


class Discriminator1(nn.Module):
    """Fully connected discriminator (GAN 1).

    Takes an input of width ``w_dim + a_dim`` and returns one value per sample,
    squashed into (0, 1) by a final Sigmoid. Hidden blocks are
    Linear -> BatchNorm1d -> LeakyReLU with widths 512, 512, 128.
    """

    def __init__(self):
        super().__init__()
        widths = [w_dim + a_dim, 512, 512, 128]
        layers = []
        # One Linear/BatchNorm/LeakyReLU block per consecutive pair of widths
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.BatchNorm1d(fan_out))
            layers.append(nn.LeakyReLU(leakyReLU_slope))
        # Scalar score followed by Sigmoid
        layers.append(nn.Linear(widths[-1], 1))
        layers.append(nn.Sigmoid())
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)

In [7]:
# GAN 2

class Generator2(nn.Module):
    """Fully connected generator (GAN 2), wider than GAN 1.

    Takes an input of width ``w_dim + noise_size`` and produces ``a_dim`` outputs.
    Hidden blocks are Linear -> BatchNorm1d -> ReLU with widths 1024, 1024, 512.
    """

    def __init__(self):
        super().__init__()
        widths = [w_dim + noise_size, 1024, 1024, 512]
        layers = []
        # One Linear/BatchNorm/ReLU block per consecutive pair of widths
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.BatchNorm1d(fan_out))
            layers.append(nn.ReLU())
        # Final projection has no normalisation or activation
        layers.append(nn.Linear(widths[-1], a_dim))
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)



class Discriminator2(nn.Module):
    """Fully connected discriminator (GAN 2), wider than GAN 1.

    Takes an input of width ``w_dim + a_dim`` and returns one value per sample,
    squashed into (0, 1) by a final Sigmoid. Hidden blocks are
    Linear -> BatchNorm1d -> LeakyReLU with widths 1024, 1024, 256.
    """

    def __init__(self):
        super().__init__()
        widths = [w_dim + a_dim, 1024, 1024, 256]
        layers = []
        # One Linear/BatchNorm/LeakyReLU block per consecutive pair of widths
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.BatchNorm1d(fan_out))
            layers.append(nn.LeakyReLU(leakyReLU_slope))
        # Scalar score followed by Sigmoid
        layers.append(nn.Linear(widths[-1], 1))
        layers.append(nn.Sigmoid())
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)

In [None]:
# GAN 3

class Generator3(nn.Module):
    """Fully connected generator (GAN 3), the shallowest of the three variants.

    Takes an input of width ``w_dim + noise_size`` and produces ``a_dim`` outputs.
    Hidden blocks are Linear -> BatchNorm1d -> ReLU with widths 512, 128.
    """

    def __init__(self):
        # Must initialise *this* class: super(Generator1, self) raises TypeError
        # because a Generator3 instance is not a Generator1.
        super().__init__()
        self.main = nn.Sequential(
            nn.Linear(w_dim+noise_size,512),
            nn.BatchNorm1d(512),
            nn.ReLU(),

            nn.Linear(512,128),
            nn.BatchNorm1d(128),
            nn.ReLU(),

            nn.Linear(128,a_dim)
        )

    def forward(self, input):
        return self.main(input)


class Discriminator3(nn.Module):
    """Fully connected discriminator (GAN 3), the shallowest of the three variants.

    Takes an input of width ``w_dim + a_dim`` and returns one value per sample,
    squashed into (0, 1) by a final Sigmoid. Hidden blocks are
    Linear -> BatchNorm1d -> LeakyReLU with widths 512, 128.
    """

    def __init__(self):
        # Must initialise *this* class: super(Discriminator1, self) raises TypeError
        # because a Discriminator3 instance is not a Discriminator1.
        super().__init__()
        self.main = nn.Sequential(
            nn.Linear(w_dim + a_dim,512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(leakyReLU_slope),

            nn.Linear(512,128),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(leakyReLU_slope),

            nn.Linear(128,1),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)

In [8]:
# initialize nets
# Pick the architecture selected by which_model; fail immediately on an unknown
# value instead of hitting a NameError on netD/netG further down.
if which_model == 1:
    netD = Discriminator1().to(device)
    netG = Generator1().to(device)
elif which_model == 2:
    netD = Discriminator2().to(device)
    netG = Generator2().to(device)
elif which_model == 3:
    netD = Discriminator3().to(device)
    netG = Generator3().to(device)
else:
    raise ValueError(f"Unsupported which_model={which_model!r}; expected 1, 2 or 3")


# weights_init only touches Conv* and BatchNorm* modules
netD.apply(weights_init)
netG.apply(weights_init)

Generator2(
  (main): Sequential(
    (0): Linear(in_features=36, out_features=1024, bias=True)
    (1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=1024, out_features=1024, bias=True)
    (4): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Linear(in_features=1024, out_features=512, bias=True)
    (7): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
    (9): Linear(in_features=512, out_features=6, bias=True)
  )
)

In [9]:
# Initialise optimiser

# One optimiser per network, both with the same learning rate.
# NOTE(review): any other value of which_optimizer leaves optG/optD undefined.
if which_optimizer == 'Adam':
    optG = torch.optim.Adam(netG.parameters(),lr = lr, betas=(beta1,0.999))
    optD = torch.optim.Adam(netD.parameters(), lr = lr, betas=(beta1,0.999))
elif which_optimizer == 'RMSProp':
    optG = torch.optim.RMSprop(netG.parameters(), lr = lr)
    optD = torch.optim.RMSprop(netD.parameters(), lr = lr)

# A fixed W increment for testing purposes
W_fixed: torch.Tensor = torch.tensor([1.0,-0.5,-1.2,-0.3,0.7,0.2,-0.9,0.1,1.7])


# Keep the first w_dim entries as a single row of shape (1, w_dim) ...
W_fixed = W_fixed[:w_dim].unsqueeze(1).transpose(1,0)
# ... and repeat that row (as an expanded view) for every test sample
W_fixed = W_fixed.expand((test_batch_size,w_dim))

# Load "true" samples generated from this fixed W increment
test_filename = f"samples/fixed_samples_{w_dim}-dim.csv"
A_fixed_true = np.genfromtxt(test_filename,dtype=float,delimiter=',',)
# Keep only the Levy-area columns; the leading w_dim columns are dropped
A_fixed_true = A_fixed_true[:,w_dim:(w_dim+a_dim)]

# Error histories filled in during training, used for plotting
wass_errors = []
chen_errors = []

# Global iteration counter of the classical training loop
iters = 0

# Gradient arguments passed to .backward(): +1 descends on a loss term, -1 ascends
one = torch.FloatTensor([1])
mone = one * -1

In [10]:
# Early stopping setup

# Keeps one backup point: the parameters at which the sum of the per-component
# Wasserstein errors was minimal. A second backup point (minimal maximum error)
# is kept below as commented-out code and is currently disabled.

# Best error sum seen so far; starts at infinity so the first evaluation is saved
min_sum = float('inf')
# Per-component errors at the best checkpoint (placeholder values until first save)
min_sum_errors = [1.0 for i in range(a_dim)]
# Deep copies, so later training steps do not alter the stored checkpoint
min_sum_paramsG = copy.deepcopy(netG.state_dict())
min_sum_paramsD = copy.deepcopy(netD.state_dict())

# min_max_err = float('inf')
# min_max_errors = [1.0 for i in range(a_dim)]
# min_max_paramsG = copy.deepcopy(netG.state_dict())
# min_max_paramsD = copy.deepcopy(netD.state_dict())

In [11]:
# Classical training: the discriminator is updated on every batch, the generator
# on every 5th batch. Every 100 iterations the generator is tested against the
# fixed-W reference samples and the best parameters so far are checkpointed.
for epoch in range(num_epochs):

    for i, data in enumerate(dataloader):
        netD.zero_grad()

        # weight clipping so critic is lipschitz
        for p in netD.parameters():
            p.data.clamp_(-weight_cliping_limit, weight_cliping_limit)

        # check actual batch size (last batch could be shorter)
        b_size = data.size(0)

        # Train Discriminator
        # first on real data
        out_D_real = netD(data)
        lossDr = out_D_real.mean(0).view(1)
        # backward(one): the optimiser step lowers the mean score on real data
        lossDr.backward(one)

        # then on fake data

        # data has shape (b_size, w_dim + a_dim) where w_dim are the dimensions of the driving BM and a_dim is the dim of Levy Areas
        W = data[:,:w_dim]
        A_real = data[:,w_dim:(w_dim + a_dim)]
        noise = torch.randn((b_size,noise_size), dtype=torch.float, device=device)
        gen_in = torch.cat((noise,W),1)
        # generate fake data
        generated_A = netG(gen_in)
        # detach so the discriminator update does not back-propagate into the generator
        fake_in = torch.cat((W,generated_A.detach()),1)

        lossDf = netD(fake_in)
        lossDf = lossDf.mean(0).view(1)
        # backward(mone): the optimiser step raises the mean score on fake data
        lossDf.backward(mone)
        lossD = lossDr - lossDf
        optD.step()

        # train Generator on every 5th iteration (deterministic schedule)
        if iters%5 == 0:
            netG.zero_grad()

            # same generated samples, this time with the graph attached to netG
            fake_in = torch.cat((W,generated_A),1)
            lossG = netD(fake_in)
            lossG = lossG.mean(0).view(1)
            lossG.backward(one)
            optG.step()

        if iters%100 == 0:
            # Test Wasserstein error for fixed W
            # no_grad: this is evaluation only, so do not build an autograd graph
            # for the test_batch_size-sample forward pass
            with torch.no_grad():
                noise = torch.randn((test_batch_size,noise_size), dtype=torch.float, device=device)
                g_in = torch.cat((noise,W_fixed),1)
                A_fixed_gen = netG(g_in).detach().cpu().numpy()
            errors = [sqrt(ot.wasserstein_1d(A_fixed_true[:,i],A_fixed_gen[:,i],p=2)) for i in range(a_dim)]

            # Test Chen discrepancy
            # W = torch.randn((4*batch_size, w_dim), dtype= torch.float, device=device)
            # noise = torch.randn((4*batch_size,noise_size), dtype=torch.float, device=device)
            # gen_in = torch.cat((noise,W),1)
            # A_gen = netG(gen_in)
            # w_a = torch.cat((W,A_gen.detach()),1)
            # ch_err = chen_error_2step(w_a)

            # Print out partial results
            pretty_errors = ["{0:0.5f}".format(i) for i in errors]
            # pretty_chen_errors = ["{0:0.5f}".format(i) for i in ch_err]
            print(f"epoch: {epoch}/{num_epochs}, iter: {iters},\n errors: {pretty_errors}")
            # Save for plotting
            wass_errors.append(errors)
            # chen_errors.append(ch_err)

            # Early stopping checkpoint
            error_sum = sum(errors)
            if error_sum <= min_sum:
                min_sum = error_sum
                min_sum_errors = errors
                min_sum_paramsG = copy.deepcopy(netG.state_dict())
                min_sum_paramsD = copy.deepcopy(netD.state_dict())
                print("Saved parameters")

        iters += 1

epoch: 0/15, iter: 0,
 errors: ['0.05891', '0.40264', '0.02776', '0.18810', '0.40710', '0.04202']
Saved parameters
epoch: 0/15, iter: 100,
 errors: ['0.03545', '0.34840', '0.05786', '0.16455', '0.43127', '0.03444']
Saved parameters
epoch: 0/15, iter: 200,
 errors: ['0.11171', '0.33866', '0.13599', '0.16647', '0.42639', '0.09630']
epoch: 0/15, iter: 300,
 errors: ['0.13057', '0.34119', '0.11026', '0.18883', '0.44122', '0.13984']
epoch: 0/15, iter: 400,
 errors: ['0.12433', '0.34987', '0.15813', '0.18647', '0.41546', '0.09783']
epoch: 0/15, iter: 500,
 errors: ['0.12112', '0.35945', '0.10727', '0.19406', '0.42196', '0.09922']
epoch: 0/15, iter: 600,
 errors: ['0.12191', '0.34968', '0.10475', '0.21655', '0.43706', '0.11088']
epoch: 0/15, iter: 700,
 errors: ['0.10223', '0.34541', '0.10955', '0.18368', '0.42270', '0.10725']
epoch: 0/15, iter: 800,
 errors: ['0.07061', '0.34116', '0.10192', '0.16770', '0.41853', '0.08501']
epoch: 0/15, iter: 900,
 errors: ['0.06856', '0.35112', '0.09066', '

KeyboardInterrupt: 

W_fixed = [1.0,-0.5,-1.2,-0.3,0.7,0.2,-0.9,0.1,1.7]
list_pairs(5) = [(1, 2), (1, 3), (1, 4), (1, 5), (2, 3), (2, 4), (2, 5), (3, 4), (3, 5), (4, 5)]

GAN2 best: ['0.06348', '0.25953', '0.06187', '0.11594', '0.12005', '0.11992', '0.07918', '0.15956', '0.16242', '0.01383']

In [None]:
# Return to early stopping checkpoint
# Build a fresh generator of the selected architecture and load the best parameters.
# NOTE(review): there is no branch for which_model == 3, so best_netG is undefined in that case.
if which_model == 1:
    best_netG = Generator1().to(device)
elif which_model == 2:
    best_netG = Generator2().to(device)

best_netG.load_state_dict(min_sum_paramsG)

# NOTE(review): the file names hard-code "GAN2" and "24epochs" regardless of
# which_model and num_epochs - confirm they describe the run being saved.
torch.save(min_sum_paramsG, f'model_saves/GAN2_{w_dim}d_24epochs_generator.pt')
torch.save(min_sum_paramsD, f'model_saves/GAN2_{w_dim}d_24epochs_discriminator.pt')
# best_netD = Discriminator()
# best_netD.load_state_dict(min_sum_paramsD)

In [None]:
# Test Wasserstein error for fixed W
# no_grad: evaluation only, so no autograd graph is built for the large test batch
with torch.no_grad():
    noise = torch.randn((test_batch_size,noise_size), dtype=torch.float, device=device)
    g_in = torch.cat((noise,W_fixed),1)
    A_fixed_gen = best_netG(g_in).detach().cpu().numpy()
# One error per Levy-area component, compared against the reference samples
errors = [sqrt(ot.wasserstein_1d(A_fixed_true[:,i],A_fixed_gen[:,i],p=2)) for i in range(a_dim)]

# Print out partial results
pretty_errors = ["{0:0.5f}".format(i) for i in errors]
print(f"best net errors: {pretty_errors}")

best net errors: ['0.06348', '0.25953', '0.06187', '0.11594', '0.12005', '0.11992', '0.07918', '0.15956', '0.16242', '0.01383']

In [None]:
# Draw errors through iterations

plt.figure(figsize=(10,5))
plt.title("2-Wasserstein distance of generated samples from the original samples for fixed W increment")
# wass_errors gets one row per evaluation and evaluations happen every 100
# training iterations, so scale the index to make the "iterations" axis correct.
eval_iterations = [100*k for k in range(len(wass_errors))]
error_lines = plt.plot(eval_iterations, wass_errors)
plt.xlabel("iterations")
plt.ylabel("Wasserstein distance")
# Label one line per Levy-area component; a bare plt.legend() has nothing to show
if wass_errors:
    plt.legend(error_lines, [f"Levy area component {j}" for j in range(len(error_lines))])
plt.show()

In [None]:
plt.figure(figsize=(10,5))
plt.title("2-Wasserstein distances after 2-step Chen recombinations")
# NOTE(review): the only append to chen_errors in the classical training loop is
# commented out, so this figure stays empty unless that code is re-enabled.
# That append sits in the every-100-iterations evaluation branch, hence the scaling.
chen_eval_iterations = [100*k for k in range(len(chen_errors))]
chen_lines = plt.plot(chen_eval_iterations, chen_errors)
plt.xlabel("iterations")
plt.ylabel("2-Wasserstein distance")
# Label one line per Levy-area component; a bare plt.legend() has nothing to show
if chen_errors:
    plt.legend(chen_lines, [f"Levy area component {j}" for j in range(len(chen_lines))])
plt.show()

In [None]:
# Chen training: no real samples are used. The discriminator compares generated
# samples with samples obtained by recombining generated samples through
# chen_combine; the generator is updated on every iteration.
chen_iters = 0
chen_training_wass_errors = []
chen_training_chen_errors = []
for i in range(num_Chen_iters):
    netD.zero_grad()

    # weight clipping so critic is lipschitz
    for p in netD.parameters():
        p.data.clamp_(-weight_cliping_limit, weight_cliping_limit)

    # Train Discriminator
    # generate 4*batch_size of fake data
    W = torch.randn((4*batch_size, w_dim), dtype= torch.float, device=device)
    noise = torch.randn((4*batch_size,noise_size), dtype=torch.float, device=device)
    gen_in = torch.cat((noise,W),1)
    A_gen = netG(gen_in)
    # detach so the discriminator update does not back-propagate into the generator
    fake_in = torch.cat((W,A_gen.detach()),1)
    lossD_fake = netD(fake_in)
    lossD_fake = lossD_fake.mean(0).view(1)
    lossD_fake.backward(mone)

    # now use chen_combine to produce "true" data from the fake one
    # each chen_combine halves the batch, so three applications reduce the
    # batch dimension from 4*batch_size to batch_size//2
    true_data = chen_combine(fake_in.detach())
    true_data = chen_combine(true_data)
    true_data = chen_combine(true_data)
    assert true_data.size(0) == batch_size//2

    lossD_real = netD(true_data)
    # weight of 8 = ratio between the fake batch (4*batch_size) and this batch (batch_size//2)
    lossD_real = 8 * lossD_real.mean(0).view(1)
    lossD_real.backward(one)
    optD.step()

    # train Generator on every iteration
    # (a random 1-in-5 schedule, np.random.randint(1,6) == 5, was used previously)
    netG.zero_grad()

    # same generated samples, this time with the graph attached to netG
    fake_in = torch.cat((W,A_gen),1)
    lossG = netD(fake_in)
    lossG = lossG.mean(0).view(1)
    lossG.backward(one)
    optG.step()

    if chen_iters%100 == 0:
        # no_grad: evaluation only, so no autograd graphs are built for the test batches
        with torch.no_grad():
            # Test Wasserstein error for fixed W
            noise = torch.randn((test_batch_size,noise_size), dtype=torch.float, device=device)
            g_in = torch.cat((noise,W_fixed),1)
            A_fixed_gen = netG(g_in).detach().cpu().numpy()
            errors = [sqrt(ot.wasserstein_1d(A_fixed_true[:,i],A_fixed_gen[:,i],p=2)) for i in range(a_dim)]

            # Test Chen discrepancy
            W = torch.randn((4*batch_size, w_dim), dtype= torch.float, device=device)
            noise = torch.randn((4*batch_size,noise_size), dtype=torch.float, device=device)
            gen_in = torch.cat((noise,W),1)
            A_gen = netG(gen_in)
            w_a = torch.cat((W,A_gen.detach()),1)
            ch_err = chen_error_2step(w_a)

        # Print out partial results
        pretty_errors = ["{0:0.5f}".format(i) for i in errors]
        pretty_chen_errors = ["{0:0.5f}".format(i) for i in ch_err]
        print(f"iter: {chen_iters}/{num_Chen_iters},\n errors: {pretty_errors}, \n chen errors: {pretty_chen_errors}")
        # Save for plotting (exactly one entry per evaluation)
        chen_training_wass_errors.append(errors)
        chen_training_chen_errors.append(ch_err)


        # Early stopping checkpoint
        error_sum = sum(errors)
        if error_sum <= min_sum:
            min_sum = error_sum
            min_sum_errors = errors
            min_sum_paramsG = copy.deepcopy(netG.state_dict())
            min_sum_paramsD = copy.deepcopy(netD.state_dict())
            print("Saved parameters")

    chen_iters += 1

In [None]:
# Draw errors through iterations

plt.figure(figsize=(10,5))
plt.title("2-Wasserstein distance of generated samples from the original samples for fixed W increment")
# The x-axis is the index of the recorded entry in chen_training_wass_errors
chen_training_lines = plt.plot(chen_training_wass_errors)
plt.xlabel("iterations")
plt.ylabel("Wasserstein distance")
# Label one line per Levy-area component; a bare plt.legend() has nothing to show
if chen_training_wass_errors:
    plt.legend(chen_training_lines, [f"Levy area component {j}" for j in range(len(chen_training_lines))])
plt.show()

In [None]:
# Time measurements
# Times 100 generator forward passes on a batch of test_batch_size inputs.

W_fixed: torch.Tensor = torch.tensor([1.0,-0.5,-1.2,-0.3,0.7,0.2,-0.9,0.1,1.7])
W_fixed = W_fixed[:w_dim].unsqueeze(1).transpose(1,0)
W_fixed = W_fixed.expand((test_batch_size,w_dim))
noise = torch.randn((test_batch_size,noise_size), dtype=torch.float, device=device)
g_in = torch.cat((noise,W_fixed),1)
# Note: netG stays in eval mode after this cell
netG.eval()
start_time = timeit.default_timer()
# no_grad: measure pure inference, without the cost of recording an autograd graph
with torch.no_grad():
    for i in range(100):
        A_fixed_out=netG(g_in)

elapsed = timeit.default_timer() - start_time
print(elapsed)

Takes 34.2s to generate 6553600 samples (original GAN)
Calling iterated_integrals(h = 1.0, err = 0.0005) 6553600-times takes 100.5s

In [None]:
# a list that records trial results in the following form (lowest_errors, best_net_params)
trial_results = []

for trial in range(num_trials):
    # initialize nets
    if which_model == 1:
        netD = Discriminator1().to(device)
        netG = Generator1().to(device)
    elif which_model == 2:
        netD = Discriminator2().to(device)
        netG = Generator2().to(device)
    elif which_model == 3:
        netD = Discriminator3().to(device)
        netG = Generator3().to(device)

    netD.apply(weights_init)
    netG.apply(weights_init)


    # Initialise optimiser

    if which_optimizer == 'Adam':
        optG = torch.optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
        optD = torch.optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
    elif which_optimizer == 'RMSProp':
        optG = torch.optim.RMSprop(netG.parameters(), lr=lr)
        optD = torch.optim.RMSprop(netD.parameters(), lr=lr)