In [1]:
import numpy as np
from math import sqrt
import ot
import torch
import aux_functions

In [2]:
def a_idx(i: int, j: int, _w_dim):
    """Return the flat index of the unordered pair ``{i, j}``.

    Pairs ``(k, l)`` with ``0 <= k < l < _w_dim`` are numbered in row-major
    order: ``(0, 1), (0, 2), ..., (0, _w_dim - 1), (1, 2), ...``.

    Args:
        i: First index of the pair.
        j: Second index of the pair; the order of ``i`` and ``j`` is irrelevant.
        _w_dim: Number of indices the pairs are drawn from.

    Returns:
        The position of the pair in the enumeration, or ``None`` when
        ``i == j`` or either index lies outside ``range(_w_dim)``.
    """
    if i == j:
        return None
    lo, hi = (i, j) if i < j else (j, i)
    if lo < 0 or hi >= _w_dim:
        return None
    # Rows 0 .. lo-1 hold (_w_dim - 1 - k) pairs each, which sums to
    # lo * (2 * _w_dim - lo - 1) / 2; (hi - lo - 1) is the offset inside row lo.
    return lo * (2 * _w_dim - lo - 1) // 2 + (hi - lo - 1)

def w_indices(a_i: int, _w_dim):
    """Return the pair ``(k, l)`` stored at flat index ``a_i``.

    Pairs ``(k, l)`` with ``0 <= k < l < _w_dim`` are numbered in row-major
    order: ``(0, 1), (0, 2), ..., (0, _w_dim - 1), (1, 2), ...``.

    Args:
        a_i: Flat index of the pair.
        _w_dim: Number of indices the pairs are drawn from.

    Returns:
        The tuple ``(k, l)`` with ``k < l``, or ``None`` when ``a_i`` is
        negative or not smaller than ``_w_dim * (_w_dim - 1) // 2``.
    """
    if a_i >= int(_w_dim*(_w_dim - 1)//2) or a_i < 0:
        return None
    # Skip whole rows instead of counting pair by pair: row k holds the
    # (_w_dim - 1 - k) pairs (k, k+1), ..., (k, _w_dim - 1).
    offset = a_i
    for k in range(_w_dim):
        row_len = _w_dim - 1 - k
        if offset < row_len:
            return k, k + 1 + offset
        offset -= row_len
    return None

def list_pairs(_w_dim: int):
    """Enumerate every index pair ``(k, l)`` with ``0 <= k < l < _w_dim``.

    The pairs are returned as a list of tuples in row-major order, i.e.
    sorted by ``k`` first and by ``l`` second.
    """
    return [
        (first, second)
        for first in range(_w_dim)
        for second in range(first + 1, _w_dim)
    ]


# Precomputed pair tables: pair_lists[d] is list_pairs(d) for every d below 30.
pair_lists = list(map(list_pairs, range(30)))


def fast_w_indices(a_i: int, _w_dim: int):
    """Look up the pair ``(k, l)`` stored at flat index ``a_i``.

    Dimensions covered by the precomputed ``pair_lists`` table are answered by
    a direct lookup; larger dimensions fall back to ``w_indices``.

    Returns ``None`` when ``_w_dim`` is below 2 or ``a_i`` is outside
    ``[0, _w_dim * (_w_dim - 1) // 2)``.
    """
    n_pairs = int(_w_dim * (_w_dim - 1) // 2)
    index_ok = 0 <= a_i < n_pairs
    if not (index_ok and _w_dim >= 2):
        return None
    if _w_dim >= len(pair_lists):
        # Dimension is not in the precomputed table.
        return w_indices(a_i, _w_dim)
    return pair_lists[_w_dim][a_i]

In [3]:
def gen_4mom_approx(_w_dim: int, _batch_size: int, _W: np.ndarray = None, _K: np.ndarray = None, _H: np.ndarray = None):
    """Sample a batch of values, one column per index pair ``k < l``.

    NOTE(review): judging by the name and by the moment comparisons elsewhere
    in this notebook, this is presumably a fourth-moment-matching
    approximation of Levy area conditional on ``W``; confirm the constants
    against the reference they were taken from.

    For the pair ``(k, l)`` (pairs are ordered row-major:
    ``(0, 1), (0, 2), ..., (1, 2), ...``) the output column is::

        sigma_kl * xi + H_k * W_l - W_k * H_l + 12 * (K_k * H_l - H_k * K_l)

    with ``sigma_kl = sqrt(3/28 * (C_k + c) * (C_l + c) + 144/28 * (K_k**2 + K_l**2))``,
    where ``C`` is exponential with scale 8/15, ``c = sqrt(1/3) - 8/15`` and
    ``xi`` is, with probability ``21130/25621``, uniform on
    ``[-sqrt(3), sqrt(3)]`` and otherwise ``+1`` or ``-1`` with equal probability.

    All randomness is drawn from the global ``np.random`` state, in the order
    W, H, K (each only when not supplied), C, Bernoulli, uniform, Rademacher.

    Args:
        _w_dim: Number of columns of ``W``, ``H`` and ``K``.
        _batch_size: Number of rows to generate.
        _W: Optional array used instead of standard normal draws; expected
            shape ``(_batch_size, _w_dim)``, matching the default draw.
        _K: Optional array used instead of normal draws with standard
            deviation ``sqrt(1/720)``; same expected shape.
        _H: Optional array used instead of normal draws with standard
            deviation ``sqrt(1/12)``; same expected shape.

    Returns:
        Array of shape ``(_batch_size, _w_dim * (_w_dim - 1) // 2)``.
    """
    _a_dim = int(_w_dim*(_w_dim - 1)//2)

    # Draw order (W, then H, then K) is kept so seeded runs stay reproducible.
    W = np.random.randn(_batch_size, _w_dim) if _W is None else _W
    H = sqrt(1/12) * np.random.randn(_batch_size, _w_dim) if _H is None else _H
    K = sqrt(1/720) * np.random.randn(_batch_size, _w_dim) if _K is None else _K

    squared_K = np.square(K)
    C = np.random.exponential(8/15, size=(_batch_size, _w_dim))

    p = 21130/25621
    c = sqrt(1/3) - 8/15

    # Mixture noise: uniform with probability p, otherwise a +/-1 coin flip.
    ber = np.random.binomial(1, p = p, size=(_batch_size, _a_dim))
    uni = np.random.uniform(-sqrt(3), sqrt(3), size=(_batch_size, _a_dim))
    rademacher = np.ones(_a_dim) - 2* np.random.binomial(1, 0.5, size= (_batch_size, _a_dim))
    ksi = ber*uni + (1-ber)*rademacher

    # Column indices of all pairs k < l in row-major order, so that column
    # `idx` of every array below corresponds to the idx-th pair.
    ks, ls = np.triu_indices(_w_dim, k=1)

    sig = np.sqrt(3/28*(C[:, ks] + c)*(C[:, ls] + c) + 144/28*(squared_K[:, ks] + squared_K[:, ls]))

    return ksi*sig + (H[:, ks]*W[:, ls] - W[:, ks]*H[:, ls] + 12*(K[:, ks]*H[:, ls] - H[:, ks]*K[:, ls]))

print(gen_4mom_approx(3, 5))

[[-0.4159377  -0.03736672  0.25547058]
 [-1.06641229 -0.14146951 -0.04425991]
 [-0.30824583  0.41563677 -0.68640744]
 [ 0.15223566 -0.03635092 -0.51462266]
 [-0.14455896 -0.40167325 -0.17983259]]


In [4]:
# Load the reference samples. The builtin `float` replaces the `np.float`
# alias, which NumPy deprecated in 1.20 and removed in 1.24.
samples = np.genfromtxt('samples/fixed_samples_4-dim.csv', dtype=float, delimiter=',')
# Columns 0-3 are used as W, columns 4-9 as the reference values `a_true`.
W = samples[:,:4]
a_true = samples[:,4:10]
print(samples.shape)
bsz = samples.shape[0]
# Generate one approximate row per reference row, reusing the same W.
generated = gen_4mom_approx(4,bsz,_W=W)


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  samples = np.genfromtxt('samples/fixed_samples_4-dim.csv', dtype=np.float, delimiter=',')


(65536, 10)


In [5]:
# Show the shape of the generated batch (rows, number of index pairs).
print(generated.shape)

(65536, 6)


In [6]:
# Per-column comparison of reference and generated values: square root of
# ot.wasserstein_1d(..., p=2) for each compared column.
# NOTE(review): range(4) covers only the first 4 columns, while `a_true` is
# sliced from 6 columns (4:10) — confirm whether all columns were meant,
# e.g. range(a_true.shape[1]).
err = [sqrt(ot.wasserstein_1d(a_true[:,i],generated[:,i],p=2)) for i in range(4)]
print(err)

[0.005918662187765701, 0.007572179703706825, 0.0049248601010889684, 0.007173870300729153]


In [8]:
# Joint distance between the first 10000 reference rows and the first 10000
# generated rows. `joint_wass_dist` comes from the project-local
# `aux_functions` module; its exact definition is not visible in this notebook.
jerr = aux_functions.joint_wass_dist(a_true[:10000], generated[:10000])
print(jerr)

0.3306195260239328


In [8]:
def four_combos(n: int):
    """List every non-decreasing 4-tuple of indices drawn from ``range(n)``.

    The tuples ``(i, j, k, l)`` satisfy ``i <= j <= k <= l`` and are returned
    in lexicographic order.
    """
    return [
        (first, second, third, fourth)
        for first in range(n)
        for second in range(first, n)
        for third in range(second, n)
        for fourth in range(third, n)
    ]

print(four_combos(3))

[(0, 0, 0, 0), (0, 0, 0, 1), (0, 0, 0, 2), (0, 0, 1, 1), (0, 0, 1, 2), (0, 0, 2, 2), (0, 1, 1, 1), (0, 1, 1, 2), (0, 1, 2, 2), (0, 2, 2, 2), (1, 1, 1, 1), (1, 1, 1, 2), (1, 1, 2, 2), (1, 2, 2, 2), (2, 2, 2, 2)]


In [9]:
def fourth_moments(input: np.ndarray):
    """Estimate the fourth-order mixed moments of the first three columns.

    For every index tuple ``(i, j, k, l)`` produced by ``four_combos(3)`` the
    row-wise product of the four selected columns is averaged over all rows.
    Columns beyond index 2 are ignored.

    Returns:
        A list of the sample means, in the order of ``four_combos(3)``.
    """
    return [
        (input[:, a] * input[:, b] * input[:, c] * input[:, d]).mean()
        for a, b, c, d in four_combos(3)
    ]


# Four copies of W stacked along the row axis.
# NOTE(review): the recorded output below shows 3 columns, although W is
# sliced from 4 columns of `samples` in the loading cell — the notebook was
# probably executed out of order; confirm with Restart & Run All.
w2 = np.concatenate([W] * 4, axis=0)
print(w2.shape)

# Fourth moments of the generated values next to those of the reference samples.
combo_list = four_combos(3)
moms = fourth_moments(generated)
moms2 = fourth_moments(a_true)
for combo, mom_generated, mom_samples in zip(combo_list, moms, moms2):
    print(f"combo: {combo}, 4_match_RV moment: {mom_generated :.7f}, samples moment: {mom_samples :.7f}")

(262144, 3)
combo: (0, 0, 0, 0), 4_match_RV moment: 0.1361759, samples moment: 0.1357276
combo: (0, 0, 0, 1), 4_match_RV moment: 0.0326981, samples moment: 0.0357896
combo: (0, 0, 0, 2), 4_match_RV moment: 0.0666188, samples moment: 0.0649001
combo: (0, 0, 1, 1), 4_match_RV moment: 0.0726994, samples moment: 0.0728633
combo: (0, 0, 1, 2), 4_match_RV moment: 0.0008516, samples moment: 0.0013768
combo: (0, 0, 2, 2), 4_match_RV moment: 0.0723352, samples moment: 0.0717662
combo: (0, 1, 1, 1), 4_match_RV moment: 0.0501095, samples moment: 0.0485544
combo: (0, 1, 1, 2), 4_match_RV moment: 0.0278908, samples moment: 0.0279942
combo: (0, 1, 2, 2), 4_match_RV moment: 0.0042910, samples moment: 0.0053913
combo: (0, 2, 2, 2), 4_match_RV moment: 0.0760783, samples moment: 0.0750950
combo: (1, 1, 1, 1), 4_match_RV moment: 0.3016073, samples moment: 0.2966960
combo: (1, 1, 1, 2), 4_match_RV moment: -0.0382996, samples moment: -0.0387191
combo: (1, 1, 2, 2), 4_match_RV moment: 0.0815085, samples mom

In [10]:
# Rebuilds the same stacked array as the previous cell: four copies of W
# concatenated along the row axis.
w2 = np.concatenate((W,W,W,W), axis=0)
print(w2.shape)

(262144, 3)


In [11]:
# Regenerate with a four-times-larger batch driven by the stacked W, then
# compare its fourth moments with the reference moments computed earlier.
generated = gen_4mom_approx(3, 4 * bsz, _W=w2)
moms = fourth_moments(generated)
for combo, mom_generated, mom_samples in zip(combo_list, moms, moms2):
    print(f"combo: {combo}, 4_match_RV moment: {mom_generated :.7f}, samples moment: {mom_samples :.7f}")

combo: (0, 0, 0, 0), 4_match_RV moment: 0.1350575, samples moment: 0.1357276
combo: (0, 0, 0, 1), 4_match_RV moment: 0.0336348, samples moment: 0.0357896
combo: (0, 0, 0, 2), 4_match_RV moment: 0.0655758, samples moment: 0.0649001
combo: (0, 0, 1, 1), 4_match_RV moment: 0.0722153, samples moment: 0.0728633
combo: (0, 0, 1, 2), 4_match_RV moment: 0.0007437, samples moment: 0.0013768
combo: (0, 0, 2, 2), 4_match_RV moment: 0.0717766, samples moment: 0.0717662
combo: (0, 1, 1, 1), 4_match_RV moment: 0.0481259, samples moment: 0.0485544
combo: (0, 1, 1, 2), 4_match_RV moment: 0.0280197, samples moment: 0.0279942
combo: (0, 1, 2, 2), 4_match_RV moment: 0.0045032, samples moment: 0.0053913
combo: (0, 2, 2, 2), 4_match_RV moment: 0.0753795, samples moment: 0.0750950
combo: (1, 1, 1, 1), 4_match_RV moment: 0.2946441, samples moment: 0.2966960
combo: (1, 1, 1, 2), 4_match_RV moment: -0.0395968, samples moment: -0.0387191
combo: (1, 1, 2, 2), 4_match_RV moment: 0.0827522, samples moment: 0.08177