In [2]:
#pip install pymc3

In [2]:
import numpy as np
import torch 
import scipy.stats
import math
from scipy.stats import multivariate_normal
from torch import distributions
import pymc3 as pm
import random
import seaborn as sns



$\ Z_{it}|Z_{i,t-1}$ ~ $\ N$($\Omega Z_{i,t-1}$, $\ I$\)\
$\ X_{it}|Z_{i,t}$ ~ $\ N$($\ Z_{i,t}$, $\ I$\)

Theoretical derivation gives:\
$\ Z_{it}|X_{i,t}$ ~ $\ N$($(1/2)*(\Omega Z_{i,t-1} + \ X_{i,t})$, $\ I/2$\)


In [3]:
N = 500 # num patients
T = 10 # num of hospital visits for each patients
random.seed(53)
zdim = 2
xdim = zdim

Z = torch.zeros((N, T, zdim)) # latent
logPZ = torch.zeros((N, T)) 
X = torch.zeros((N, T, xdim)) # observed
PX = torch.zeros((N, T, xdim)) 
logPX = torch.zeros((N, T))
d = torch.ones(zdim,requires_grad = True) 
omega = torch.diag(d) 
#Psi = torch.randn((xdim,zdim),requires_grad = True)



#mean = torch.zeros((zdim))
#cov = torch.eye(zdim)
#Zi0 = torch.distributions.multivariate_normal.MultivariateNormal(mean,cov)
#Z[0,0] = Zi0.sample()

for t in range(1,T):
    # Zit|Zi,t-1
    meanz = torch.matmul(Z[:,t-1], omega)
    cov = torch.eye(zdim)
    Zt = torch.distributions.multivariate_normal.MultivariateNormal(meanz, cov)
    Z[:,t] = Zt.sample()
    logPZ[:,t] = Zt.log_prob(Z[:,t])


    #Xit|Zit
    meanx = Z[:,t] 
    cov = torch.eye(xdim) 
    Xt = torch.distributions.multivariate_normal.MultivariateNormal(meanx, cov)
    X[:,t] = Xt.sample()
    logPX[:,t] = Xt.log_prob(X[:,t])


        
        
def posterior_i(x,omega):
    n,t,xdim = x.size()
    zdim = xdim
    z = torch.zeros((n,t,zdim))
    logpz = torch.zeros((n, t)) 

    for j in range(1,t):     
        mean = 1/2*(torch.matmul(Z[:,j-1], omega) + x[:,j])
        cov = torch.eye(zdim)*(1/2)
        zt = torch.distributions.multivariate_normal.MultivariateNormal(mean, cov)
        z[:,j] = zt.sample()
        logpz[:,j] = zt.log_prob(z[:,j])     
    return z, logpz
        

    
Z_posterior = torch.zeros((N, T, zdim))
Z_posterior,log_posterior = posterior_i(X,omega)

In [4]:
torch.mean((Z-Z_posterior)**2)

tensor(0.8961)

In [5]:
Z_posterior

tensor([[[ 0.0000e+00,  0.0000e+00],
         [-1.7798e-01,  1.8810e+00],
         [ 5.3781e-02,  3.7040e-02],
         ...,
         [-4.4548e-01,  2.6676e+00],
         [ 1.2291e+00,  1.4103e+00],
         [ 4.9175e-01,  5.9679e-01]],

        [[ 0.0000e+00,  0.0000e+00],
         [ 3.8345e-01, -4.1703e-03],
         [ 1.4259e+00,  1.3981e-01],
         ...,
         [ 2.4969e+00, -3.7381e+00],
         [ 7.5783e-01, -5.5789e+00],
         [ 3.0916e+00, -5.8369e+00]],

        [[ 0.0000e+00,  0.0000e+00],
         [-8.2840e-02, -9.3748e-03],
         [ 5.8908e-01, -1.2804e+00],
         ...,
         [-1.2019e+00,  3.6194e+00],
         [ 6.3056e-01,  3.5899e+00],
         [-2.9003e+00,  4.2952e+00]],

        ...,

        [[ 0.0000e+00,  0.0000e+00],
         [-4.3188e-02, -9.8304e-01],
         [ 6.1278e-01, -1.4190e+00],
         ...,
         [ 1.8112e-01,  1.9335e+00],
         [-1.0638e+00, -2.1785e-01],
         [ 1.5986e-01,  3.3148e-01]],

        [[ 0.0000e+00,  0.0000e+00

In [6]:
Z

tensor([[[ 0.0000,  0.0000],
         [-1.1254,  0.1623],
         [ 0.3141, -1.0375],
         ...,
         [ 1.2489,  1.0395],
         [ 0.0444,  1.0417],
         [ 0.4390,  0.1707]],

        [[ 0.0000,  0.0000],
         [ 0.3458, -0.3084],
         [ 1.0967, -1.2672],
         ...,
         [ 1.7955, -4.6554],
         [ 1.8791, -5.5217],
         [ 3.1513, -5.8138]],

        [[ 0.0000,  0.0000],
         [-0.1426,  0.0354],
         [-0.8688,  0.7505],
         ...,
         [-0.0373,  3.9404],
         [-1.1010,  4.5625],
         [-1.7324,  4.0058]],

        ...,

        [[ 0.0000,  0.0000],
         [ 0.6092, -1.4228],
         [ 0.6943, -0.1799],
         ...,
         [-2.1987,  0.7683],
         [ 0.2545, -0.2085],
         [ 1.0958,  0.3610]],

        [[ 0.0000,  0.0000],
         [ 1.4173,  0.4078],
         [ 2.7967,  0.8755],
         ...,
         [ 1.4952,  2.6349],
         [ 1.7544,  1.5022],
         [ 1.1315,  1.1198]],

        [[ 0.0000,  0.0000],
       