In [26]:
import numpy as np
from math import sqrt

w_dim = 3
a_dim = int(w_dim*(w_dim - 1)//2)
W = np.random.randn(w_dim)
H = 1/12 * np.random.randn(w_dim)
K = 1/720 * np.random.randn(w_dim)
C = np.random.exponential(8/15, size=(w_dim,))

print(W)
print(H)
print(K)
print(C)

[1.93908841 0.03832014 1.51146853]
[-0.12554142 -0.06621361  0.03543413]
[ 0.00099471 -0.00210933  0.00122218]
[0.26314663 0.11830186 0.39329631]


In [27]:
def a_idx(i: int, j: int, _w_dim):
    if i == j:
        return None
    idx = 0
    for k in range(_w_dim):
        for l in range(k+1, _w_dim):
            if (i == k and j == l) or (j == k and i == l):
                return idx
            else:
                idx += 1

def w_indices(a_i: int, _w_dim):
    if a_i >= int(_w_dim*(_w_dim - 1)//2) or a_i < 0:
        return None
    idx = 0
    for k in range(_w_dim):
        for l in range(k+1, _w_dim):
            if idx == a_i:
                return k, l
            else:
                idx += 1

def list_pairs(_w_dim: int):
    lst = []
    for k in range(_w_dim):
        for l in range(k+1, _w_dim):
            lst.append((k,l))
    return lst
pair_lists = [list_pairs(wd) for wd in range(30)]


def fast_w_indices(a_i: int, _w_dim: int):
    if a_i >= int(_w_dim*(_w_dim - 1)//2) or a_i < 0 or _w_dim < 2:
        return None
    if _w_dim < len(pair_lists):
        return pair_lists[_w_dim][a_i]
    else:
        return w_indices(a_i, _w_dim)

print([fast_w_indices(i, 5) for i in range(13)])

[(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4), None, None, None]


In [28]:
# Generate ksi
p = 21130/25621
c = sqrt(1/3) - 8/15

ber = np.random.binomial(1, p = p, size=(a_dim,))
uni = np.random.uniform(-sqrt(3), sqrt(3), size=(a_dim,))
rademacher = np.ones(a_dim) - 2* np.random.binomial(1, 0.5, size= (a_dim,))
ksi = ber*uni + (1-ber)*rademacher


In [29]:
def sigma(i: int, j: int):
    return sqrt(3/28*(C[i] + c)*(C[j] + c) + 1/28*((12*K[i])**2 + (12*K[j])**2))

sig = []
for k in range(w_dim):
    for l in range(k+1, w_dim):
        sig.append(sigma(k,l))

print(sig)

[0.07327990241783074, 0.12002048438583897, 0.08738426389767474]


In [30]:
res = np.array(sig) * ksi
print(res)

[ 0.0549302  -0.12002048 -0.10986279]


In [31]:
print(ksi)

[ 0.74959432 -1.         -1.25723765]


In [41]:
def gen_4mom_approx(_w_dim: int, _batch_size: int, _W: np.ndarray = None, _K: np.ndarray = None, _H: np.ndarray = None):
    _a_dim = int(_w_dim*(_w_dim - 1)//2)

    lst = []
    for k in range(_w_dim):
        for l in range(k+1, _w_dim):
            lst.append((k,l))

    if _W == None:
        __W = np.random.randn(_batch_size, _w_dim)
    else:
        __W = _W

    if _H == None:
        __H = 1/12 * np.random.randn(_batch_size, _w_dim)
    else:
        __H = _H

    if _K == None:
        __K = 1/720 * np.random.randn(_batch_size, _w_dim)
    else:
        __K = _K

    squared_K = np.square(__K)
    C = np.random.exponential(8/15, size=(_batch_size, _w_dim))

    p = 21130/25621
    c = sqrt(1/3) - 8/15

    ber = np.random.binomial(1, p = p, size=(_batch_size, a_dim))
    uni = np.random.uniform(-sqrt(3), sqrt(3), size=(_batch_size, a_dim))
    rademacher = np.ones(a_dim) - 2* np.random.binomial(1, 0.5, size= (_batch_size, a_dim))
    ksi = ber*uni + (1-ber)*rademacher
    print(ksi)
    def sigma(i: int, j: int):
        return np.sqrt(3/28*(C[:,i] + c)*(C[:,j] + c) + 144/28*(squared_K[:,i] + squared_K[:,j]))

    idx = 0
    for k in range(_w_dim):
        for l in range(k+1, _w_dim):
            sig = sigma(k, l)
            print(f"shape: {sig.size}, k: {k}, l: {l}, sig: {sig}")

            # now calculate a from ksi and sigma (but store a in ksi)
            ksi[:, idx] *= sig

            # calculate the whole thing
            ksi[:, idx] += __H[:,k]*__W[:,l] - __W[:,k]*__H[:,l] + 12*(__K[:,k]*__H[:,l] - __H[:,k]*__K[:,l])
            idx += 1

    return ksi

print(gen_4mom_approx(3, 5))

[[-1.10530845  1.4897569   0.31536049]
 [-0.40357178  0.37900206 -1.47201362]
 [-0.03689588 -0.19652217 -0.61886711]
 [-1.         -1.41862029  0.23846889]
 [-1.36851659  1.41591402  1.1431877 ]]
shape: 5, k: 0, l: 1, sig: [0.0283069  0.46541773 0.13462885 0.06288323 0.11041406]
shape: 5, k: 0, l: 2, sig: [0.11408845 0.17922834 0.12450243 0.27475238 0.09972826]
shape: 5, k: 1, l: 2, sig: [0.09784813 0.21429375 0.13404342 0.14695043 0.08380169]
[[-0.16987286  0.16646962  0.26070803]
 [-0.21429492  0.05468135 -0.36318466]
 [-0.10393036 -0.02059426 -0.07821499]
 [-0.06012626 -0.43156355 -0.0047571 ]
 [-0.15444398  0.07362423  0.15012622]]
