In [1]:
import numpy as np
import scipy as sp
import scipy.stats
from scipy.special import digamma
from hmcfa.model import HiddenMarkovFA

In [2]:
import logging
logging.basicConfig(level=logging.DEBUG)

In [3]:
SEED = 0  # fixed seed so the synthetic data (and any draws from numpy's global RNG) are reproducible
T = 5
G = 100
N = 2000
K = 3

np.random.seed(SEED)
# Synthetic uniform(0, 1) data of shape (T, G, N); K is passed as n_factors.
# NOTE(review): later shape outputs in this notebook show G=10 and N=20
# (e.g. product.shape == (5, 10, 3, 3, 20)), so those cells were run with
# different sizes than these — confirm which values are intended.
data = np.random.random(size=(T, G, N))
fa = HiddenMarkovFA(data=data, n_factors=K, hyperparameters={})

In [5]:
# Run a single M step by hand. The recorded output is a tuple of arrays
# (all 0.5 in the first, all 0.25 in the second, in the visible part).
fa.M_step()

(array([[[0.5, 0.5, 0.5],
         [0.5, 0.5, 0.5],
         [0.5, 0.5, 0.5],
         ...,
         [0.5, 0.5, 0.5],
         [0.5, 0.5, 0.5],
         [0.5, 0.5, 0.5]],
 
        [[0.5, 0.5, 0.5],
         [0.5, 0.5, 0.5],
         [0.5, 0.5, 0.5],
         ...,
         [0.5, 0.5, 0.5],
         [0.5, 0.5, 0.5],
         [0.5, 0.5, 0.5]],
 
        [[0.5, 0.5, 0.5],
         [0.5, 0.5, 0.5],
         [0.5, 0.5, 0.5],
         ...,
         [0.5, 0.5, 0.5],
         [0.5, 0.5, 0.5],
         [0.5, 0.5, 0.5]],
 
        [[0.5, 0.5, 0.5],
         [0.5, 0.5, 0.5],
         [0.5, 0.5, 0.5],
         ...,
         [0.5, 0.5, 0.5],
         [0.5, 0.5, 0.5],
         [0.5, 0.5, 0.5]],
 
        [[0.5, 0.5, 0.5],
         [0.5, 0.5, 0.5],
         [0.5, 0.5, 0.5],
         ...,
         [0.5, 0.5, 0.5],
         [0.5, 0.5, 0.5],
         [0.5, 0.5, 0.5]]]),
 array([[[[[0.25, 0.25],
           [0.25, 0.25]],
 
          [[0.25, 0.25],
           [0.25, 0.25]],
 
          [[0.25, 0.25],
    

In [4]:
# Fit the model with a progress bar.
# NOTE(review): the recorded run was interrupted (KeyboardInterrupt) after
# ~15 minutes at 0/1000 iterations, during "Performing V step". Consider
# smaller data or fewer iterations while experimenting in this notebook.
fa.run(progress_bar=True)

  0%|          | 0/1000 [00:00<?, ?it/s]

2024-09-06 14:28:48,035 - hmcfa - DEBUG - Performing global updates
2024-09-06 14:28:48,080 - hmcfa - DEBUG - Updated L
2024-09-06 14:28:48,087 - hmcfa - DEBUG - Updated F
2024-09-06 14:28:48,092 - hmcfa - DEBUG - Updated tau
2024-09-06 14:28:48,094 - hmcfa - DEBUG - Updated alpha
2024-09-06 14:28:48,094 - hmcfa - DEBUG - Updated A
2024-09-06 14:28:48,095 - hmcfa - DEBUG - Performing local updates
2024-09-06 14:28:48,096 - hmcfa - DEBUG - Performing V step


  0%|          | 0/1000 [15:14<?, ?it/s]


KeyboardInterrupt: 

In [4]:
# Elementwise digamma of fa.dirchlet_A (presumably Dirichlet concentrations, per the name).
# NOTE(review): this recorded output (~3 everywhere) does not match the later
# fa.dirchlet_A values (e.g. digamma(A)[0] ~ -35.6 further down), so the model
# state changed between these cells.
digamma(fa.dirchlet_A)

array([[[2.95218353, 3.12076202],
        [2.94524192, 2.94426541]],

       [[3.06207842, 3.00082486],
        [2.7597074 , 3.05854622]],

       [[3.11717495, 2.85048565],
        [3.22934584, 2.78897192]]])

In [5]:
np.shape(fa.dirchlet_A)  # dimensions of the Dirichlet parameter array

(3, 2, 2)

In [7]:
# Total concentration along the last axis (axis 2) of fa.dirchlet_A.
np.sum(fa.dirchlet_A, axis=2)

array([[0.12165041, 1.70807934],
       [0.62854274, 1.9431701 ],
       [3.25366362, 3.90681624]])

In [10]:
# Display the raw fa.dirchlet_A values (shape (3, 2, 2) per the .shape cell).
fa.dirchlet_A

array([[[0.02848131, 0.0931691 ],
        [0.24515632, 1.46292302]],

       [[0.58982179, 0.03872095],
        [0.99887223, 0.94429787]],

       [[2.51121222, 0.7424514 ],
        [1.43428295, 2.47253329]]])

In [14]:
# digamma of the total concentration along axis 2.
digamma(np.sum(fa.dirchlet_A, axis=2))

array([[-8.61342887,  0.21493704],
       [-1.44071461,  0.38546472],
       [ 1.01830822,  1.22931631]])

In [16]:
# digamma is elementwise, so slicing first gives the same values as slicing the full result.
digamma(fa.dirchlet_A[0])

array([[-3.56420662e+01, -1.11667615e+01],
       [-4.31229862e+00,  1.24840337e-03]])

In [17]:
# digamma(alpha) - digamma(sum of alpha along axis 2); keepdims keeps the summed axis so it broadcasts.
# This matches the Dirichlet identity E[log a] = psi(alpha) - psi(sum alpha), assuming each
# last-axis slice of fa.dirchlet_A is one Dirichlet parameter vector — confirm.
digamma(fa.dirchlet_A) - digamma(fa.dirchlet_A.sum(axis=2, keepdims=True))

array([[[-27.02863731,  -2.55333262],
        [ -4.52723566,  -0.21368863]],

       [[ -0.13743922, -24.90036383],
        [ -0.96453702,  -1.05823399]],

       [[ -0.30966837,  -2.12350923],
        [ -1.25611792,  -0.53971806]]])

In [9]:
# exp(digamma(alpha) - digamma(sum of alpha along axis 2)) for every entry of fa.dirchlet_A.
# keepdims=True keeps the summed axis as size 1 so it broadcasts against the
# (3, 2, 2) parameter array; no hand-written slice (or undefined slice bound) needed.
np.exp(digamma(fa.dirchlet_A) - digamma(fa.dirchlet_A.sum(axis=2, keepdims=True)))

array([[[1.82646756e-12, 7.78218824e-02],
        [1.08105188e-02, 8.07599807e-01]],

       [[8.71587322e-01, 1.53429684e-11],
        [3.81159628e-01, 3.47068195e-01]],

       [[7.33690228e-01, 1.19611148e-01],
        [2.84757334e-01, 5.82912576e-01]]])

In [3]:
# Inspect fa.a_alpha (a 5 x 3 array in the recorded output).
fa.a_alpha

array([[1.88122239, 1.40639841, 0.01562732],
       [3.30244876, 2.65323527, 0.51166638],
       [0.64757283, 0.33458648, 1.86384593],
       [0.86599957, 2.49086952, 1.15389524],
       [1.32879466, 0.46421577, 1.13173046]])

In [7]:
# product[t, d, i, j, n] = eta[t, d, j]**2 * mu_F[t, i, n] * mu_F[t, j, n]
# (eta is indexed by j in both of its operands).
# np.einsum evaluates the same contraction as opt_einsum's `contract`, which is
# not imported in any cell of this notebook (NameError on a fresh kernel).
# NOTE(review): eta enters as eta_j * eta_j rather than eta_i * eta_j — confirm intended.
product = np.einsum('tdj,tdj,tin,tjn->tdijn', fa.eta, fa.eta, fa.mu_F, fa.mu_F)

In [9]:
np.shape(product)  # expected (T, D, K, K, N)

(5, 10, 3, 3, 20)

In [27]:
np.shape(product)  # repeated shape check of `product`

(5, 10, 3, 3, 20)

In [28]:
# Off-diagonal mask: True where i != j. It is defined here so this cell does
# not depend on a later cell having been run first.
mask = ~np.eye(fa.K, dtype=bool)
mask

array([[False,  True,  True],
       [ True, False,  True],
       [ True,  True, False]])

In [6]:
# Boolean K x K mask that is False on the diagonal and True everywhere else.
mask = np.logical_not(np.eye(fa.K, dtype=bool))

In [29]:
np.shape(product)  # shape check before applying the off-diagonal mask

(5, 10, 3, 3, 20)

In [35]:
# Manual cross-check of the masked einsum: sum product[0, 0, 0, j, 2] over j != 0.
# The recorded value (-0.0882...) matches result[0, 0, 0, 2].
# Builtin `sum` gives the same value without the DeprecationWarning that
# np.sum(<generator>) emits.
sum(product[0, 0, 0, j, 2] for j in range(3) if j != 0)

  np.sum(product[0,0,0,i,2] for i in range(3) if i != 0)


-0.08822841485608768

In [30]:
# result[t, d, i, n] = sum over j != i of product[t, d, i, j, n].
# mask[:, :, np.newaxis] is (K, K, 1), so it broadcasts over the trailing n axis.
result = (product * mask[:, :, np.newaxis]).sum(axis=3)


In [32]:
# Display the masked off-diagonal sum computed above (output truncated in the recording).
result

array([[[[ 2.26908881e-02, -5.44246163e-01, -8.82284149e-02, ...,
          -3.70099667e-01, -1.49646639e-01, -3.14113110e-01],
         [ 2.20621199e-03, -6.08793383e-02, -2.54808662e-03, ...,
          -2.41270639e-02, -2.10253423e-02, -2.35939023e-02],
         [ 1.72528358e-02, -2.96790824e-01,  2.53755170e-01, ...,
           5.11464286e-01, -6.83885064e-02,  4.53959734e-02]],

        [[ 3.41816830e-02, -6.75965508e-01, -1.16301518e-01, ...,
          -4.75184748e-01, -1.23680713e-01, -4.80236709e-01],
         [ 1.03852750e-02, -2.84874951e-01, -1.71822463e-02, ...,
          -1.25316190e-01, -9.71146770e-02, -1.16226537e-01],
         [ 5.48303343e-02, -3.42121832e-01,  2.83690530e-01, ...,
           5.70837429e-01,  2.79675210e-01, -4.43869149e-01]],

        [[ 2.00049014e-02, -2.56550631e-01, -5.20169134e-02, ...,
          -1.98531388e-01,  2.59477473e-02, -2.87877633e-01],
         [ 8.26767153e-04, -3.21956159e-02,  2.76457748e-02, ...,
           5.57008619e-02, -1.8120

In [31]:
np.shape(result)  # (T, D, K, N) after summing out j

(5, 10, 3, 20)

In [11]:
# FIXME: `final_result` is not assigned in any visible cell before this one, so
# this fails on a fresh kernel. Its recorded shape (5, 10, 3, 20) matches
# `result`; confirm which array was meant.
final_result.shape

(5, 10, 3, 20)

In [17]:
 # Sum over the last axis, then pick element (0, 0, 0) with a single index tuple.
 # FIXME: depends on `final_result`, which no visible cell defines.
 np.sum(final_result, axis=-1)[0, 0, 0]

2.1707300949453425

In [18]:
# Display the full `product` array so the masked sum `result` can be checked by hand
# (e.g. product[0, 0, 0, 1:, 2] summed equals result[0, 0, 0, 2] when K == 3).

product

array([[[[[ 2.47752070e-02,  1.48747773e-02,  3.13877519e-03, ...,
            1.35204670e-02,  1.04453957e-01,  2.13423589e-01],
          [ 2.05326197e-02, -5.46147174e-01, -8.60302194e-02, ...,
           -3.65606431e-01, -1.73362377e-01, -2.81604578e-01],
          [ 2.15826845e-03,  1.90101086e-03, -2.19819550e-03, ...,
           -4.49323573e-03,  2.37157374e-02, -3.25085321e-02]],

         [[ 2.02942092e-03, -5.39805695e-02, -8.50312966e-03, ...,
           -3.61361265e-02, -1.71349414e-02, -2.78334783e-02],
          [ 1.68189626e-03,  1.98196820e+00,  2.33061008e-01, ...,
            9.77155613e-01,  2.84388860e-02,  3.67252513e-02],
          [ 1.76791061e-04, -6.89876878e-03,  5.95504304e-03, ...,
            1.20090626e-02, -3.89040094e-03,  4.23957599e-03]],

         [[ 9.43418854e-03,  8.30966831e-03, -9.60871704e-03, ...,
           -1.96407602e-02,  1.03665852e-01, -1.42100776e-01],
          [ 7.81864728e-03, -3.05100492e-01,  2.63363887e-01, ...,
            5.31105

In [19]:
np.shape(product)  # repeated shape check of `product`

(5, 10, 3, 3, 20)

In [25]:
# NOTE(review): this recorded `result` has shape (5, 10, 3, 20, 3) (see the next cell),
# unlike the (5, 10, 3, 20) masked-einsum `result`. It appears to come from an
# earlier definition of `result` that is no longer in the notebook.
result

array([[[[[ 2.26908881e-02,  2.69334755e-02,  4.53078267e-02],
          [-5.44246163e-01,  1.67757881e-02, -5.31272397e-01],
          [-8.82284149e-02,  9.40579691e-04, -8.28914442e-02],
          ...,
          [-3.70099667e-01,  9.02723127e-03, -3.52085964e-01],
          [-1.49646639e-01,  1.28169694e-01, -6.89084200e-02],
          [-3.14113110e-01,  1.80915056e-01, -6.81809897e-02]],

         [[ 1.85868732e-03,  2.20621199e-03,  3.71131718e-03],
          [ 1.97506943e+00, -6.08793383e-02,  1.92798763e+00],
          [ 2.39016051e-01, -2.54808662e-03,  2.24557878e-01],
          ...,
          [ 9.89164676e-01, -2.41270639e-02,  9.41019487e-01],
          [ 2.45484851e-02, -2.10253423e-02,  1.13039446e-02],
          [ 4.09648273e-02, -2.35939023e-02,  8.89177298e-03]],

         [[ 8.64049758e-03,  1.02560388e-02,  1.72528358e-02],
          [-3.04038508e-01,  9.37165192e-03, -2.96790824e-01],
          [ 2.70093212e-01, -2.87939198e-03,  2.53755170e-01],
          ...,
      

In [26]:
np.shape(result)  # recorded as (5, 10, 3, 20, 3): a different `result` than the masked sum

(5, 10, 3, 20, 3)

In [3]:
# outer_prod[t, d, j, i] = eta[t,d,j] * eta[t,d,i] * mu_L[t,d,i] * mu_L[t,d,j],
# built by broadcasting, with factors multiplied in the einsum's operand order.
outer_prod = (
    fa.eta[:, :, :, np.newaxis]
    * fa.eta[:, :, np.newaxis, :]
    * fa.mu_L[:, :, np.newaxis, :]
    * fa.mu_L[:, :, :, np.newaxis]
)

In [4]:
np.shape(outer_prod)  # (T, D, K, K)

(5, 10, 3, 3)

In [41]:
# Elementwise a_tau / b_tau (presumably the mean of a Gamma(shape=a, rate=b) precision — confirm).
np.divide(fa.a_tau, fa.b_tau)

array([[9.11988395e+00, 1.41191541e+00, 6.49568860e-02, 3.20663906e-01,
        2.45446237e+00, 1.75807861e+00, 6.13988277e-02, 6.92174308e-01,
        7.87443566e+00, 1.40026433e+01],
       [7.37341784e-01, 5.40121651e-01, 1.70155397e+01, 1.81612575e-01,
        8.30309400e-01, 2.47143951e-01, 5.19397345e+01, 4.53488731e+00,
        8.38808512e+00, 8.55834014e-02],
       [4.68759448e+00, 1.13720480e+01, 3.50783905e+00, 1.27718410e+00,
        3.24005742e-01, 3.24682889e+01, 7.34596086e+00, 1.83801826e+01,
        4.91462469e-02, 9.22550547e-02],
       [3.80065872e+00, 4.03559712e-01, 2.24235996e+00, 3.54883916e+00,
        1.09683915e+00, 6.61136392e-02, 1.11542884e+00, 2.72214079e-01,
        2.55021374e-01, 3.41576495e+00],
       [1.44762678e-01, 8.28433395e-01, 1.35332034e+00, 3.72771354e+00,
        9.49708928e-01, 2.15225362e-01, 2.38089903e+00, 3.47103240e-01,
        9.46890081e-01, 1.62301469e-02]])

In [49]:
# Manual check of the tau weighting for (t, d) = (0, 0).
(fa.a_tau / fa.b_tau)[0, 0] * outer_prod[0, 0]

array([[0.22687133, 1.17349312, 0.04658714],
       [1.17349312, 6.06989908, 0.24097221],
       [0.04658714, 0.24097221, 0.00956649]])

In [5]:
# Scale every (t, d) K x K matrix by the scalar a_tau[t, d] / b_tau[t, d].
outer_prod_weighted = (fa.a_tau / fa.b_tau)[:, :, np.newaxis, np.newaxis] * outer_prod

In [11]:
outer_prod[0, 0]  # the K x K matrix for (t, d) = (0, 0)

array([[0.00033509, 0.00788604, 0.00530355],
       [0.00788604, 0.18559159, 0.12481488],
       [0.00530355, 0.12481488, 0.08394105]])

In [19]:
# Manual check of the tau weighting for (t, d) = (1, 1); compare with outer_prod_weighted[1, 1].
(fa.a_tau / fa.b_tau)[1, 1] * outer_prod[1, 1]

array([[0.00313452, 0.00845953, 0.03506041],
       [0.00845953, 0.02283079, 0.09462192],
       [0.03506041, 0.09462192, 0.39215927]])

In [20]:
outer_prod_weighted[1, 1]  # should equal the manual check in the previous cell

array([[0.00313452, 0.00845953, 0.03506041],
       [0.00845953, 0.02283079, 0.09462192],
       [0.03506041, 0.09462192, 0.39215927]])

In [6]:
np.shape(outer_prod_weighted)  # (T, D, K, K)

(5, 10, 3, 3)

In [21]:
# Sum the weighted K x K matrices over axis 1 (d), giving one (K, K) matrix per t.
outer_prod_weighted_sum = np.sum(outer_prod_weighted, axis=1)

In [22]:
np.shape(outer_prod_weighted_sum)  # (T, K, K)

(5, 3, 3)

In [86]:
outer_prod_weighted_sum[0]  # summed matrix for t = 0

array([[ 4.23538902,  1.96929541, -0.33522576],
       [ 1.96929541, 12.76209685,  1.87337993],
       [-0.33522576,  1.87337993,  9.77263238]])

In [79]:
np.shape(fa.mu_F)  # recorded as (5, 3, 20)

(5, 3, 20)

In [35]:
fa.mu_F[0][:, 1]  # all K factor means for t = 0, n = 1

array([ 1.32474286, -0.25130425,  1.20813723])

In [25]:
# Same display as an earlier cell, but the recorded values differ, so
# outer_prod_weighted_sum was recomputed from changed state in between.
outer_prod_weighted_sum[0]

array([[ 18.73253416,  -0.64822334,   6.95281101],
       [ -0.64822334,  14.54303982, -11.98835733],
       [  6.95281101, -11.98835733,  40.89011432]])

In [36]:
# final_weighted_product[t, n, k, i] = mu_F[t, i, n] * outer_prod_weighted_sum[t, k, i].
# mu_F is moved to (t, n, 1, i) and the summed matrices to (t, 1, k, i), then broadcast.
final_weighted_product = (
    fa.mu_F.transpose(0, 2, 1)[:, :, np.newaxis, :]
    * outer_prod_weighted_sum[:, np.newaxis, :, :]
)

In [37]:
final_weighted_product[0, 1]  # (K, K) slice for t = 0, n = 1

array([[ 24.81579079,   0.16290128,   8.39994985],
       [ -0.85872924,  -3.65472771, -14.48358085],
       [  9.21068671,   3.01272514,  49.40086957]])

In [84]:
np.shape(final_weighted_product)  # (T, N, K, K)

(5, 20, 3, 3)

In [38]:
# Re-create the off-diagonal mask (duplicate of the earlier definition):
# all True, then False along the diagonal.
mask = np.ones((fa.K, fa.K), dtype=bool)
np.fill_diagonal(mask, False)

In [4]:
# For each (t, n, k): sum final_weighted_product[t, n, k, i] over i != k,
# i.e. the off-diagonal entries selected by `mask`.
(final_weighted_product * mask).sum(axis=-1)

NameError: name 'final_weighted_product' is not defined

In [5]:
# The same off-diagonal sum, only for t = 0, n = 0.
(final_weighted_product[0, 0] * mask).sum(axis=-1)

NameError: name 'final_weighted_product' is not defined

In [8]:
# Place each sigma2_L[t, g, :] vector on the diagonal of a K x K matrix.
fa.sigma2_L[:, :, :, np.newaxis] * np.eye(fa.K)

array([[[[2.19224907e-01, 0.00000000e+00, 0.00000000e+00],
         [0.00000000e+00, 1.39379341e+00, 0.00000000e+00],
         [0.00000000e+00, 0.00000000e+00, 2.12444459e+00]],

        [[1.65216336e+00, 0.00000000e+00, 0.00000000e+00],
         [0.00000000e+00, 8.95437975e-01, 0.00000000e+00],
         [0.00000000e+00, 0.00000000e+00, 1.02100307e+00]],

        [[5.48596925e-01, 0.00000000e+00, 0.00000000e+00],
         [0.00000000e+00, 1.17486040e+00, 0.00000000e+00],
         [0.00000000e+00, 0.00000000e+00, 7.85945339e-01]],

        [[4.37890486e-01, 0.00000000e+00, 0.00000000e+00],
         [0.00000000e+00, 7.90151515e-01, 0.00000000e+00],
         [0.00000000e+00, 0.00000000e+00, 1.25979992e+00]],

        [[1.52051259e+00, 0.00000000e+00, 0.00000000e+00],
         [0.00000000e+00, 2.27389987e-02, 0.00000000e+00],
         [0.00000000e+00, 0.00000000e+00, 6.83300843e-01]],

        [[9.06663004e-02, 0.00000000e+00, 0.00000000e+00],
         [0.00000000e+00, 1.38904519e+00, 0.00

In [13]:
# Outer product of mu_L[0, 0] with itself, computed only for the (t, g) = (0, 0) slice.
fa.mu_L[0, 0, :, np.newaxis] * fa.mu_L[0, 0, np.newaxis, :]

array([[ 0.74719755,  0.97032006, -0.7503577 ],
       [ 0.97032006,  1.26006974, -0.97442388],
       [-0.7503577 , -0.97442388,  0.75353123]])

In [12]:
# Cross-check of the previous cell using the ufunc outer method.
np.multiply.outer(fa.mu_L[0, 0], fa.mu_L[0, 0])

array([[ 0.74719755,  0.97032006, -0.7503577 ],
       [ 0.97032006,  1.26006974, -0.97442388],
       [-0.7503577 , -0.97442388,  0.75353123]])

In [14]:
# Per-(t, n) outer products of the factor means: [t, k, i, n] = mu_F[t, k, n] * mu_F[t, i, n].
fa.mu_F[:, :, np.newaxis, :] * fa.mu_F[:, np.newaxis, :, :]

array([[[[ 9.04950680e-01,  1.18026420e+00,  2.85808998e-01,
           5.21863986e-01,  1.77403369e-01,  3.74184328e-01,
           2.24797754e-01,  1.58023707e+00,  2.24224153e-03,
           3.18305766e-04,  5.70986882e-02,  4.07812121e+00,
           1.77322204e+00,  1.79231385e-02,  1.19722804e-01,
           1.27822674e+00,  1.41940405e-01,  1.68045361e-03,
           3.36570147e-01,  3.56808537e-01],
         [ 7.63501214e-02,  2.62143548e-01, -3.09792568e-01,
           5.10596050e-01, -9.87348270e-03,  4.34410412e-01,
           2.04307654e-01,  1.40815670e+00,  8.34077251e-02,
           1.70868619e-02,  3.86435556e-01, -1.01402611e+00,
           1.77595716e+00, -1.16592089e-01, -3.27851356e-01,
           2.51382811e+00, -1.47771727e-01,  6.36263516e-03,
          -9.03038238e-01,  1.08666619e+00],
         [ 1.03582581e+00, -7.46132892e-02,  8.81012028e-01,
          -5.35517596e-01,  5.42003460e-02,  3.35688008e-01,
          -5.68702990e-01, -1.62443924e+00,  4.43558957e

In [18]:
# Per t, sum over n of (mu_F[t,k,n] * mu_F[t,i,n] + sigma2_F[t,k,n] * I[k,i]).
# Assuming sigma2_F holds per-entry posterior variances, this is sum_n E[F_kn F_in] — confirm.
thing = (
    fa.mu_F[:, :, np.newaxis, :] * fa.mu_F[:, np.newaxis, :, :]
    + fa.sigma2_F[:, :, np.newaxis, :] * np.eye(fa.K)[np.newaxis, :, :, np.newaxis]
).sum(axis=-1)

In [26]:
thing[0]  # (K, K) matrix for t = 0

array([[30.28361827,  5.93676315,  4.7584712 ],
       [ 5.93676315, 35.34752254, -2.4027517 ],
       [ 4.7584712 , -2.4027517 , 47.31297217]])

In [28]:
# NOTE(review): this rebinds `outer_prod`, which earlier held the eta-weighted
# (t, d, j, i) product; here it is the plain mu_L outer product
# [t, g, k, i] = mu_L[t, g, k] * mu_L[t, g, i]. Consider distinct names.
outer_prod = fa.mu_L[:, :, :, np.newaxis] * fa.mu_L[:, :, np.newaxis, :]

In [31]:
outer_prod[0, 0]  # mu_L outer product for (t, g) = (0, 0)

array([[ 0.74719755,  0.97032006, -0.7503577 ],
       [ 0.97032006,  1.26006974, -0.97442388],
       [-0.7503577 , -0.97442388,  0.75353123]])

In [33]:
# Elementwise product of the t = 0 factor moment matrix with the (t, g) = (0, 1) loading outer product.
thing[0] * outer_prod[0, 1]

array([[ 1.21700675e+00, -2.47716272e+00,  5.65454506e-02],
       [-2.47716272e+00,  1.53138144e+02,  2.96454949e-01],
       [ 5.65454506e-02,  2.96454949e-01,  1.66247627e-01]])

In [1]:
# Shape of the mu_L outer products scaled elementwise by `thing` (broadcast over g).
# The recorded NameError came from running this cell before the import cell.
np.shape(fa.mu_L[:, :, :, np.newaxis] * fa.mu_L[:, :, np.newaxis, :] * thing[:, np.newaxis])

NameError: name 'np' is not defined

In [38]:
# [t, d, k, i] = eta[t, d, k] ** I[k, i]: eta[t, d, k] on the diagonal, 1.0 elsewhere.
np.where(np.eye(fa.K, dtype=bool), fa.eta[:, :, :, np.newaxis], 1.0)

array([[[[0.69589129, 1.        , 1.        ],
         [1.        , 0.23903612, 1.        ],
         [1.        , 1.        , 0.89180936]],

        [[0.02076982, 1.        , 1.        ],
         [1.        , 0.25523604, 1.        ],
         [1.        , 1.        , 0.70278511]],

        [[0.42648275, 1.        , 1.        ],
         [1.        , 0.36774706, 1.        ],
         [1.        , 1.        , 0.56622588]],

        [[0.41490725, 1.        , 1.        ],
         [1.        , 0.54665498, 1.        ],
         [1.        , 1.        , 0.61518569]],

        [[0.16268739, 1.        , 1.        ],
         [1.        , 0.11746629, 1.        ],
         [1.        , 1.        , 0.18403808]],

        [[0.00269822, 1.        , 1.        ],
         [1.        , 0.86720587, 1.        ],
         [1.        , 1.        , 0.70652838]],

        [[0.25359522, 1.        , 1.        ],
         [1.        , 0.09918234, 1.        ],
         [1.        , 1.        , 0.14895406]],


In [40]:
# Shape of eta_k ** I[k, i] * eta_i over all (t, d); only the shape is displayed.
# NOTE(review): entry [k, i] is eta_k**2 on the diagonal and eta_i off it. If
# E[z_k z_i] for independent Bernoulli(eta) indicators was intended (eta_k on
# the diagonal, eta_k * eta_i off it), the exponent should be 1 - I — confirm.
np.shape(np.where(np.eye(fa.K, dtype=bool), fa.eta[:, :, :, np.newaxis], 1.0) * fa.eta[:, :, np.newaxis, :])

(5, 10, 3, 3)

In [37]:
# Inspect fa.eta. Its shape is (5, 10, 3), consistent with the (5, 10, 3, 3)
# shape above; the visible values lie in [0, 1].
fa.eta

array([[[0.69589129, 0.23903612, 0.89180936],
        [0.02076982, 0.25523604, 0.70278511],
        [0.42648275, 0.36774706, 0.56622588],
        [0.41490725, 0.54665498, 0.61518569],
        [0.16268739, 0.11746629, 0.18403808],
        [0.00269822, 0.86720587, 0.70652838],
        [0.25359522, 0.09918234, 0.14895406],
        [0.92633998, 0.09849913, 0.31424852],
        [0.85344917, 0.05368053, 0.26694605],
        [0.87080994, 0.20652907, 0.5137484 ]],

       [[0.71035773, 0.99435095, 0.94835039],
        [0.83648882, 0.23968579, 0.87914821],
        [0.36755708, 0.20653794, 0.70879402],
        [0.18283397, 0.07951155, 0.56609532],
        [0.52137018, 0.71067139, 0.98481578],
        [0.19841884, 0.85908937, 0.57826   ],
        [0.60400196, 0.20525928, 0.56478612],
        [0.91057637, 0.81594498, 0.70610707],
        [0.48287254, 0.62009766, 0.57586121],
        [0.18155122, 0.66434287, 0.65363524]],

       [[0.96200667, 0.30569229, 0.86877777],
        [0.76005904, 0.4079717