In [1]:
%matplotlib inline

import pystan
import pandas as pd
import numpy as np
import scipy.io as sio
import pickle

In [2]:
model_code="""
data{
    int I;
    int R;
    int K;
    real eta;
    real zeta;
    int J;
    real x[I, J];
    vector[R] b[I];
    matrix[K,R] lambda;
}
parameters{
    matrix[K,R] omega;
    vector<lower=0>[K] phi_alpha;
    vector<lower=0>[K] phi_beta;
}
transformed parameters{
    simplex[K] theta[I];
    for (i in 1:I){
        theta[i] = softmax(omega*b[i]);
    }
}
model{
    for (k in 1:K){
        for (r in 1:R){
            omega[k,r] ~ normal(0, lambda[k,r]);    //prior
        }
    }
    for (k in 1:K){
        phi_alpha[k] ~ exponential(eta);    //prior
    }
    for (k in 1:K){
        phi_beta[k] ~ exponential(zeta);    //prior
    }
    for (i in 1:I){
        for (j in 1:J){
            x[i,j] ~ gamma(dot_product(phi_alpha, theta[i]),dot_product(phi_beta, theta[i]));
        }
    }
    for (i in 1:I){
        for (j in 1:J){
            real prob_state[K];
            for (k in 1:K){
                prob_state[k] = log(theta[i, k]);
            }
            increment_log_prob(log_sum_exp(prob_state));
        }
    }
}
"""

In [3]:
model = pystan.StanModel(model_code=model_code, verbose=True)

INFO:pystan:COMPILING THE C++ CODE FOR MODEL anon_model_f63f0f63d888d9d2bb963c008cef4695 NOW.
INFO:pystan:OS: linux, Python: 3.6.1 |Continuum Analytics, Inc.| (default, May 11 2017, 13:09:58) 
[GCC 4.4.7 20120313 (Red Hat 4.4.7-1)], Cython 0.25.2


Compiling /tmp/tmp2ur4kxc_/stanfit4anon_model_f63f0f63d888d9d2bb963c008cef4695_6495836358651909336.pyx because it changed.
[1/1] Cythonizing /tmp/tmp2ur4kxc_/stanfit4anon_model_f63f0f63d888d9d2bb963c008cef4695_6495836358651909336.pyx
building 'stanfit4anon_model_f63f0f63d888d9d2bb963c008cef4695_6495836358651909336' extension
creating /tmp/tmp2ur4kxc_/tmp
creating /tmp/tmp2ur4kxc_/tmp/tmp2ur4kxc_
gcc -pthread -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -fPIC -DBOOST_RESULT_OF_USE_TR1 -DBOOST_NO_DECLTYPE -DBOOST_DISABLE_ASSERTS -I/tmp/tmp2ur4kxc_ -I/home/ayman/anaconda3/envs/ati-dsg/lib/python3.6/site-packages/pystan -I/home/ayman/anaconda3/envs/ati-dsg/lib/python3.6/site-packages/pystan/stan/src -I/home/ayman/anaconda3/envs/ati-dsg/lib/python3.6/site-packages/pystan/stan/lib/stan_math_2.15.0 -I/home/ayman/anaconda3/envs/ati-dsg/lib/python3.6/site-packages/pystan/stan/lib/stan_math_2.15.0/lib/eigen_3.2.9 -I/home/ayman/anaconda3/envs/ati-dsg/lib/python3.6/site-package

In [19]:
with open('./model_data.pickle', 'rb') as f:
    model_data = pickle.load(f)

In [20]:
with open('./init_dict.pickle', 'rb') as f:
    init_dict = pickle.load(f)

In [21]:
fit = model.sampling(data=model_data, verbose=True)

In [22]:
fit

Inference for Stan model: anon_model_f63f0f63d888d9d2bb963c008cef4695.
4 chains, each with iter=2000; warmup=1000; thin=1; 
post-warmup draws per chain=1000, total post-warmup draws=4000.

               mean se_mean     sd   2.5%    25%    50%    75%  97.5%  n_eff   Rhat
omega[0,0]     -0.2  3.3e-3   0.18  -0.54  -0.32   -0.2  -0.08   0.16   2901    1.0
omega[1,0]     0.22  3.3e-3   0.18  -0.12    0.1   0.21   0.33   0.58   2856    1.0
omega[0,1]     0.02  1.3e-3   0.07  -0.11  -0.02   0.02   0.07   0.15   2614    1.0
omega[1,1]    -0.31  1.6e-3   0.07  -0.45  -0.35  -0.31  -0.26  -0.17   2184    1.0
omega[0,2]     0.79  2.8e-3   0.11   0.57   0.71   0.78   0.86    1.0   1563    1.0
omega[1,2]    -0.07  1.8e-3   0.08  -0.24  -0.13  -0.07  -0.01    0.1   2216    1.0
phi_alpha[0]   3.88  3.5e-3   0.14   3.62   3.79   3.88   3.97   4.18   1604    1.0
phi_alpha[1]   1.22  3.2e-3   0.12   0.96   1.15   1.23    1.3   1.43   1482    1.0
phi_beta[0]  5.9e-6  3.6e-8 1.2e-6 3.2e-6 5.1e-6 5.9e-6

In [23]:
init_dict

{'omega': array([[ 0.00690881, -0.00881324, -0.81415286],
        [-0.33992712,  0.2416064 , -0.09427549]]),
 'phi_alpha': array([   30.22552776,  1840.90205194]),
 'phi_beta': array([  36.97634206,  199.69323767]),
 'theta': array([[ 0.50056703,  0.49943297],
        [ 0.52876474,  0.47123526],
        [ 0.27523373,  0.72476627],
        [ 0.54292619,  0.45707381],
        [ 0.75291747,  0.24708253],
        [ 0.32817784,  0.67182216],
        [ 0.52222185,  0.47777815],
        [ 0.35171769,  0.64828231],
        [ 0.29633767,  0.70366233],
        [ 0.15961563,  0.84038437],
        [ 0.67989739,  0.32010261],
        [ 0.6873702 ,  0.3126298 ],
        [ 0.29109447,  0.70890553],
        [ 0.49502973,  0.50497027],
        [ 0.42967348,  0.57032652],
        [ 0.62742333,  0.37257667],
        [ 0.22918787,  0.77081213],
        [ 0.69908914,  0.30091086],
        [ 0.71175992,  0.28824008],
        [ 0.65290825,  0.34709175],
        [ 0.42846908,  0.57153092],
        [ 0.5911160