In [2]:
import pymc3

In [3]:
import numpy as np

In [4]:
T = 1000
k = 5
d1 = 20
d2 = 30

In [5]:
W1 = np.random.randn(d1, k)
W2 = np.random.randn(d2, k)
mu1 = np.random.randn(d1, 1)
mu2 = np.random.randn(d2, 1)
pre_Psi1 = np.random.randn(2 * d1, d1)
Psi1 = np.dot(pre_Psi1.T, pre_Psi1)
pre_Psi2 = np.random.randn(2 * d2, d2)
Psi2 = np.dot(pre_Psi2.T, pre_Psi2)

In [6]:
Psi1_sqrt = np.linalg.cholesky(Psi1)
Psi2_sqrt = np.linalg.cholesky(Psi2)

In [7]:
Z = np.random.randn(T, k)
Z_lift1 = np.dot(W1, Z.T).T
Z_lift2 = np.dot(W2, Z.T).T
Z_shift1 = Z_lift1 + mu1.T
Z_shift2 = Z_lift2 + mu2.T
X1_noise = np.random.randn(T, d1)
X2_noise = np.random.randn(T, d2)
X1 = np.dot(X1_noise, Psi1_sqrt) + Z_shift1
X2 = np.dot(X2_noise, Psi2_sqrt) + Z_shift2

In [49]:
print(X1)
print(X2)

[[ -2.32398445  -1.59255581  -2.44672855 ...,   0.56385745  -4.37264267
    1.46990814]
 [ -6.45837612  -3.22049761  -5.83065864 ...,  -1.03853988   5.00340449
   -0.94257034]
 [ 19.94063297  -7.47967805  19.15488059 ...,   8.85027492   5.93933109
    2.84588542]
 ..., 
 [ -1.30626168  -4.93008179  -1.24859542 ...,   8.04830819   4.47257084
   -2.76817863]
 [-10.84327631   5.19879051   2.08167292 ...,   2.36658216   2.7558005
   -6.95496141]
 [-10.19472603  -9.13277559 -14.8106589  ...,   3.63930029  -5.73431122
   -0.697028  ]]
[[  1.51276053   2.71087285   3.64242034 ...,  -1.1244023   -2.27810949
    8.10003107]
 [  5.40314739  -5.3855255    6.67720216 ...,  -7.82624976 -11.56927164
   -2.59496409]
 [  0.59916453   4.38378787  11.62947227 ...,  -0.95346828  -2.78547041
   11.91888584]
 ..., 
 [ -6.7906873   -6.51774093   0.68981859 ...,  10.28594799   8.18885287
    2.14527006]
 [  0.05703094   8.97030929   7.72073323 ...,   7.51423953 -10.72141609
   -1.91146604]
 [ -4.36858784   0

In [54]:
import pymc3 as pm

model = pm.Model()

with model:
    # Priors on Z-lifting matrices
    W1_var = pm.Normal('W1', shape=(d1,k))
    W2_var = pm.Normal('W2', shape=(d2,k))
    
    # Priors on X means
    mu1_var = pm.Normal('mu1', shape=(d1,1))
    mu2_var = pm.Normal('mu2', shape=(d2,1))
    
    # Priors on covariance matrices
    Psi1_var = pm.LKJCholeskyCov(
        'Psi1', n=d1, eta=2., sd_dist=pm.HalfCauchy.dist(2.5))
    Psi2_var = pm.LKJCholeskyCov(
        'Psi2', n=d2, eta=2., sd_dist=pm.HalfCauchy.dist(2.5))
    
    # SDs for the Xs
    L1 = pm.expand_packed_triangular(d1, Psi1_var)
    L2 = pm.expand_packed_triangular(d2, Psi2_var)
    
    # Likelihood on Z
    #Z_var = pm.Normal('Z', shape=(T,k))
    
    # Means for conditional likelihood on Xs conditioned on Z
    Z_lift1_var = W1_var.dot(Z.T).T
    Z_lift2_var = W2_var.dot(Z.T).T
    Z_shift1_var = Z_lift1_var + mu1_var.T
    Z_shift2_var = Z_lift2_var + mu2_var.T
    
    # Conditional likelihoods of Xs conditioned on Z
    X1_var = pm.MvNormal(name='X1', mu=Z_shift1_var, chol=L1, observed=X1)
    X2_var = pm.MvNormal(name='X2', mu=Z_shift2_var, chol=L2, observed=X2)

In [55]:
map_estimate = pm.find_MAP(model=model)

logp = -1.6271e+05, ||grad|| = 0.54034: 100%|██████████| 105/105 [00:00<00:00, 340.20it/s] 


In [56]:
map_estimate

{'Psi1': array([ 8.23352659, -0.4974707 ,  6.68318435, -0.27107963,  0.5533524 ,
         7.87796597, -0.548861  ,  0.10309506,  0.21137883,  7.63900147,
        -0.42639308, -1.33802307,  0.66984676,  0.74457594,  6.58328144,
        -0.41168378, -0.76130569, -0.52102885, -0.13741384, -0.81485936,
         5.93123751, -1.00358377,  0.3183883 , -0.62483799,  0.87340505,
        -0.28410642,  1.49533832,  7.04123931,  0.56966502, -0.62736365,
        -0.7581008 , -0.30347098, -0.16371627,  1.07194017, -0.55776282,
         6.3643898 , -1.02811902,  0.86848564,  1.0998887 , -0.74102524,
         0.35636023, -0.69021109,  1.46089727, -0.18544521,  5.30115028,
         0.8075009 ,  0.69648054, -0.54773099,  1.49684192,  0.10469941,
         0.35329524, -1.12599178, -1.99306734, -0.87195502,  4.44061317,
        -0.05880398, -0.78453752,  1.03680714, -0.08461877, -1.24931722,
        -1.20526577,  0.78502887, -0.44202894, -0.32260434, -2.16847919,
         5.37378083,  0.22342509,  0.727052