In [1]:
import numpy as np
from math import sqrt
import ot
import torch
from TheGAN import LevyGAN
from aux_functions import *
import timeit

## Demonstration of the random variable which matches the first 4 moments of Levy Area

There is also a simple version which just matches the variances conditional on the W increment

In [2]:
w_dim = 3
a_dim = int((w_dim * (w_dim - 1)) // 2)
bsz = 262144
data = np.genfromtxt(f'samples/fixed_samples_{w_dim}-dim_big.csv', dtype=float, delimiter=',')
a_true = data[:, w_dim:(w_dim + a_dim)]
W = data[:, :w_dim]
W_torch = torch.tensor(W, dtype=torch.float)
print(data.shape)

(262144, 6)


In [None]:
mid_prec_samples = np.genfromtxt(f'samples/mid_prec_fixed_samples_{w_dim}-dim.csv', dtype=float,delimiter=',') # 0.68s
a_mid_prec = mid_prec_samples[:, w_dim:(w_dim + a_dim)]

In [6]:
start_time = timeit.default_timer()
generated2mom = gen_2mom_approx(w_dim, bsz, _W = W)
a_generated2mom = generated2mom[:, w_dim:(w_dim + a_dim)]
elapsed = timeit.default_timer() - start_time
print(elapsed)

1.296301570999276


In [4]:
start_time = timeit.default_timer()
generated4mom = gen_4mom_approx(w_dim, bsz, _W=W)
a_generated4mom = generated4mom[:, w_dim:(w_dim + a_dim)]
elapsed = timeit.default_timer() - start_time
print(elapsed)

0.14728145300068718


In [4]:
T, M, S = generate_tms(w_dim, torch.device('cpu'))
start_time = timeit.default_timer()
h = sqrt(1 / 12) * torch.randn((bsz, w_dim), dtype=torch.float)
wth = aux_compute_wth(W_torch, h, S, T, w_dim).detach()
b = sqrt(1 / 12) * torch.randn((bsz, w_dim), dtype=torch.float)
a_wthmb = aux_compute_wthmb(wth, b, M, w_dim)
elapsed = timeit.default_timer() - start_time
a_wthmb_np = a_wthmb.numpy()
print(elapsed)

0.07527714800016838


In [3]:
levG = LevyGAN()
levG.do_timeing = True
levG.load_dicts(serial_num_to_load=1, descriptor="_max_scr")
a_gan = levG.eval(W_torch)

EVAL TIME: 0.5236519610007235


In [3]:
err = [sqrt(ot.wasserstein_1d(a_true[:,i], a_generated2mom[:, i], p=2)) for i in range(a_dim)]
print(make_pretty(err))

[0.0314, 0.0277, 0.0289]


In [4]:
err = [sqrt(ot.wasserstein_1d(a_true[:,i], a_generated4mom[:, i], p=2)) for i in range(a_dim)]
print(make_pretty(err))

[0.0037, 0.0035, 0.0039]


In [5]:
err = [sqrt(ot.wasserstein_1d(a_true[:,i], a_mid_prec[:, i], p=2)) for i in range(a_dim)]
print(make_pretty(err))

[0.0031, 0.0028, 0.0034]


In [18]:
err = [sqrt(ot.wasserstein_1d(a_true[:,i], a_wthmb_np[:, i], p=2)) for i in range(a_dim)]
print(make_pretty(err))

[0.03, 0.0273, 0.029]


In [None]:
err = [sqrt(ot.wasserstein_1d(a_true[:,i], a_gan[:, i], p=2)) for i in range(a_dim)]
print(make_pretty(err))

In [3]:
joint_err = joint_wass_dist(a_true[:16384], a_generated2mom[:16384])
print(joint_err)

0.227711139218486


In [3]:
joint_err = joint_wass_dist(a_true[:16384], a_generated4mom[:16384])
print(joint_err)

0.08777318220273997


In [3]:
joint_err = joint_wass_dist(a_true[:16384], a_mid_prec[:16384])
print(joint_err)

0.08548501584219308


In [15]:
joint_err = joint_wass_dist(a_true[:16384], a_wthmb_np[:16384])
print(joint_err)

0.10378455706401561


In [None]:
joint_err = joint_wass_dist(a_true[:16384], a_gan[:16384])
print(joint_err)

## A utility for calculating the empirical fourth moments of a set of samples

In [None]:
def four_combos(n: int):
    lst = []
    for i in range(n):
        for j in range(i,n):
            for k in range(j,n):
                for l in range(k,n):
                    lst.append((i,j,k,l))
    return lst

def fourth_moments(input_samples: np.ndarray):
    dim = input_samples.shape[1]
    lst = four_combos(dim)
    res = []
    for i,j,k,l in lst:
        col = input_samples[:, i] * input_samples[:, j] * input_samples[:, k] * input_samples[:, l]
        res.append(col.mean())
    return res

combo_list = four_combos(6)
moms = fourth_moments(a_generated4mom)
moms2 = fourth_moments(a_true)

In [32]:
combo_list = four_combos(6)
abs_sum = 0
for i in range(len(combo_list)):
    abs_sum += abs(moms[i] - moms2[i])
    print(f"mom: {combo_list[i]}, 4_mom_RV: {moms[i] :.7f}, samples: {moms2[i] :.7f}")
print(abs_sum)
print(abs_sum/len(combo_list))

mom: (0, 0, 0, 0), 4_mom_RV: 0.1346635, samples: 0.1347481
mom: (0, 0, 0, 1), 4_mom_RV: 0.0335340, samples: 0.0328427
mom: (0, 0, 0, 2), 4_mom_RV: 0.0083241, samples: 0.0084791
mom: (0, 0, 0, 3), 4_mom_RV: 0.0659578, samples: 0.0663669
mom: (0, 0, 0, 4), 4_mom_RV: 0.0166638, samples: 0.0164522
mom: (0, 0, 0, 5), 4_mom_RV: 0.0000592, samples: -0.0001937
mom: (0, 0, 1, 1), 4_mom_RV: 0.0717887, samples: 0.0713913
mom: (0, 0, 1, 2), 4_mom_RV: 0.0079290, samples: 0.0079163
mom: (0, 0, 1, 3), 4_mom_RV: 0.0011992, samples: 0.0007711
mom: (0, 0, 1, 4), 4_mom_RV: 0.0025173, samples: 0.0024212
mom: (0, 0, 1, 5), 4_mom_RV: 0.0054178, samples: 0.0054886
mom: (0, 0, 2, 2), 4_mom_RV: 0.0419814, samples: 0.0421878
mom: (0, 0, 2, 3), 4_mom_RV: 0.0025456, samples: 0.0026153
mom: (0, 0, 2, 4), 4_mom_RV: -0.0085327, samples: -0.0084858
mom: (0, 0, 2, 5), 4_mom_RV: -0.0219468, samples: -0.0219955
mom: (0, 0, 3, 3), 4_mom_RV: 0.0722696, samples: 0.0732025
mom: (0, 0, 3, 4), 4_mom_RV: 0.0115913, samples: 0.

In [29]:
print(make_pretty(moms - moms2))

TypeError: unsupported operand type(s) for -: 'list' and 'list'

In [20]:
w2 = np.concatenate((W,W,W,W), axis=0)
print(w2.shape)

(1048576, 4)


In [24]:
generated4mom = gen_4mom_approx(3, bsz*4, _W=w2)
a_generated4mom = generated4mom[:, 4:10]
moms = fourth_moments(a_generated4mom)
for i in range(len(combo_list)):
    print(f"moment: {combo_list[i]}, 4_mom_RV: {moms[i] :.7f}, samples: {moms2[i] :.7f}")

moment: (0, 0, 0, 0), 4_mom_RV: 0.1342995, samples: 0.1360724
moment: (0, 0, 0, 1), 4_mom_RV: 0.0331013, samples: 0.0338491
moment: (0, 0, 0, 2), 4_mom_RV: 0.0661390, samples: 0.0082963
moment: (0, 0, 1, 1), 4_mom_RV: 0.0711489, samples: 0.0661575
moment: (0, 0, 1, 2), 4_mom_RV: 0.0008355, samples: 0.0169508
moment: (0, 0, 2, 2), 4_mom_RV: 0.0727376, samples: 0.0004610
moment: (0, 1, 1, 1), 4_mom_RV: 0.0478692, samples: 0.0721430
moment: (0, 1, 1, 2), 4_mom_RV: 0.0274638, samples: 0.0084095
moment: (0, 1, 2, 2), 4_mom_RV: 0.0046850, samples: 0.0007835
moment: (0, 2, 2, 2), 4_mom_RV: 0.0768439, samples: 0.0026474
moment: (1, 1, 1, 1), 4_mom_RV: 0.2929890, samples: 0.0055236
moment: (1, 1, 1, 2), 4_mom_RV: -0.0394454, samples: 0.0430429
moment: (1, 1, 2, 2), 4_mom_RV: 0.0813562, samples: 0.0021761
moment: (1, 2, 2, 2), 4_mom_RV: -0.0318231, samples: -0.0087013
moment: (2, 2, 2, 2), 4_mom_RV: 0.1861091, samples: -0.0224802
