In [None]:
%matplotlib inline
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import edward as ed
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
import pickle
import six
import tensorflow as tf

from edward.models import (
    Dirichlet, Categorical, Empirical, ParamMixture)

plt.style.use('ggplot')

# Fix random seeds so the stochastic Gibbs run below is reproducible.
# tf.set_random_seed sets the graph-level seed, which only affects ops created
# after this call, so this cell must run before the model cell builds the graph.
SEED = 42
np.random.seed(SEED)
tf.set_random_seed(SEED)

In [None]:
def _parse_form_corpus_line(line):
    """Expands one form-corpus line into a flat int32 array of token ids.

    A line is whitespace-separated: a leading field (skipped; presumably the
    number of distinct tokens, as in LDA-C format -- verify) followed by
    ``token_id:count`` pairs. Each token id is repeated ``count`` times.
    """
    doc_tokenids = []
    # str.split() with no argument collapses repeated spaces and drops the
    # trailing newline, so a line ending in " \n" no longer yields an empty pair.
    for pair in line.split()[1:]:
        ti_c = pair.split(':')
        tokenid = int(ti_c[0])
        count = int(ti_c[1])

        # NOTE(review): singleton tokens are doubled (count 1 -> 2). This is
        # preserved from the original code; its intent is unclear -- confirm.
        if count == 1:
            count += 1

        doc_tokenids.extend([tokenid] * count)

    # int32 matches the default dtype of the Categorical observations fed later.
    return np.array(doc_tokenids, dtype=np.int32)


def load_mimic_data(data_dir, num_patients=50):
    """Loads a dataset extracted from the MIMIC-III critical care database.

    For each data type in ('note', 'lab', 'med') this reads
      * ``<data_dir>/dicts/<type>_dict.p``: a pickled token-id dictionary, and
      * ``<data_dir>/form_corpora/<type>_form_corpus.txt``: one line per patient.

    Args:
        data_dir: Directory containing the ``dicts`` and ``form_corpora``
            subdirectories.
        num_patients: Number of patients, i.e. the number of lines expected in
            every corpus file.

    Returns:
        A tuple ``(w, dicts, D, S)``:
          w - List of D lists of S int32 NumPy arrays; ``w[d][s]`` holds the
              token ids of patient d for data type s, each id repeated by its
              count.
          dicts - List of S token id to token dictionaries, one per data type.
          D - Number of patients (``num_patients``).
          S - Number of data types.

    Raises:
        ValueError: If a corpus file does not have exactly ``num_patients``
            lines.
    """
    data_types = ['note', 'lab', 'med']

    D = num_patients  # number of patients
    S = len(data_types)  # number of data types

    dicts = [None] * S
    w = [[None] * S for _ in range(D)]
    for s, dt in enumerate(data_types):
        dict_file = os.path.join(data_dir, 'dicts', dt + '_dict.p')
        form_corpus_file = os.path.join(data_dir, 'form_corpora',
                                        dt + '_form_corpus.txt')

        # NOTE(security): pickle.load can execute arbitrary code; only load
        # dictionary files from a trusted source.
        with open(dict_file, 'rb') as fh:
            dicts[s] = pickle.load(fh)

        with open(form_corpus_file, 'r') as fh:
            lines = fh.readlines()

        # Fail early instead of an IndexError (too many lines) or leaving None
        # entries that crash downstream (too few lines).
        if len(lines) != D:
            raise ValueError(
                '{} has {} lines; expected one per patient ({})'.format(
                    form_corpus_file, len(lines), D))

        for d, line in enumerate(lines):
            w[d][s] = _parse_form_corpus_line(line)

    return w, dicts, D, S



In [None]:
###############
# DATA
###############

# Get MIMIC data (per-patient token arrays plus one token dictionary per data type)
DATA_DIR = '../data'
w_train, dicts, D, S = load_mimic_data(DATA_DIR)

# Vocabulary size of each data type = number of entries in its token dictionary
V = [len(token_dict) for token_dict in dicts]
    

In [None]:
###############
# MODEL
###############
K = 5  # number of topics shared by all data types

# Concentrations of the symmetric Dirichlet priors.
TOPIC_CONCENTRATION = 0.01  # prior on each patient's topic proportions
TOKEN_CONCENTRATION = 0.01  # prior on each topic's token distribution

alpha = tf.ones(K) * TOPIC_CONCENTRATION

# phi[s] draws K rows from Dirichlet(beta[s]): one token distribution per
# topic over the V[s]-token vocabulary of data type s.
beta = [None] * S
phi = [None] * S
for s in range(S):
    beta[s] = tf.ones(V[s]) * TOKEN_CONCENTRATION
    phi[s] = Dirichlet(concentration=beta[s], sample_shape=K)

theta = [None] * D
w = [[None] * S for _ in range(D)]
z = [[None] * S for _ in range(D)]
for d in range(D):
    # Topic proportions of patient d, shared across all data types.
    theta[d] = Dirichlet(concentration=alpha)

    for s in range(S):
        # One mixture draw per observed token of patient d / data type s:
        # mixing weights theta[d], Categorical components parameterised by phi[s].
        w[d][s] = ParamMixture(
            mixing_weights=theta[d],
            component_params={'probs': phi[s]},
            component_dist=Categorical,
            sample_shape=len(w_train[d][s]),
            validate_args=True,
        )
        # Keep the mixture's component-assignment variable as its own latent.
        z[d][s] = w[d][s].cat
        

In [None]:
import time

####################
# INFERENCE
####################

overall_time = time.time()


def _report_time(start):
    """Prints seconds elapsed since the cell started and since `start`.

    Returns the current time so the caller can keep it as `end`.
    """
    now = time.time()
    print('Overall time: {}, Iteration time: {}'.format(now - overall_time,
                                                        now - start))
    return now


# Data vars: bind each observed mixture variable to its token-id array.
data_dict = {w[d][s]: w_train[d][s] for d in range(D) for s in range(S)}


# Latent vars: each latent is approximated by an Empirical of T stored samples.
latent_vars_dict = {}

T = 1000  # number of samples
qphi = [None] * S
for s in range(S):
    print('Building latents for phi {} of {}'.format(s + 1, S))
    # Start every phi row at the uniform distribution over the vocabulary
    # (mirroring qtheta below). An all-zeros start is not a valid Dirichlet
    # point: log(0) makes the complete conditional of z degenerate if z is
    # updated before phi in the Gibbs scan.
    qphi[s] = Empirical(tf.Variable(tf.ones([T, K, V[s]]) / V[s]))
    latent_vars_dict[phi[s]] = qphi[s]

qtheta = [None] * D
qz = [[None] * S for _ in range(D)]
for d in range(D):
    print('Building latents for doc {} of {}'.format(d + 1, D))
    qtheta[d] = Empirical(tf.Variable(tf.ones([T, K]) / K))
    latent_vars_dict[theta[d]] = qtheta[d]

    for s in range(S):
        N = len(w_train[d][s])
        # Every token's topic assignment starts at topic 0.
        qz[d][s] = Empirical(tf.Variable(tf.zeros([T, N], dtype=tf.int32)))
        latent_vars_dict[z[d][s]] = qz[d][s]
print()

# Proposal vars: each latent is proposed from its complete conditional.
# NOTE(review): building these dominates the notebook's runtime (minutes per
# patient in a recorded run); consider vectorising the model across patients.
proposal_vars_dict = {}

phi_cond = [None] * S
for s in range(S):
    iteration_time = time.time()
    print('Building proposals for phi {} of {}'.format(s + 1, S))
    phi_cond[s] = ed.complete_conditional(phi[s])
    proposal_vars_dict[phi[s]] = phi_cond[s]
    end = _report_time(iteration_time)

theta_cond = [None] * D
z_cond = [[None] * S for _ in range(D)]
for d in range(D):
    iteration_time = time.time()
    print('Building proposals for doc {} of {}'.format(d + 1, D))

    theta_cond[d] = ed.complete_conditional(theta[d])
    proposal_vars_dict[theta[d]] = theta_cond[d]

    for s in range(S):
        z_cond[d][s] = ed.complete_conditional(z[d][s])
        proposal_vars_dict[z[d][s]] = z_cond[d][s]

    end = _report_time(iteration_time)



In [None]:
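# Recorded output of the latent/proposal-building cell above, from a previous
# run, kept for reference. Building the complete conditionals dominated the
# runtime: ~55-59 s per phi, ~136-587 s per patient document, and ~9,986 s
# (~2.8 h) in total.
# Building latents for phi 1 of 3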
# Building latents for phi 2 of 3
# Building latents for phi 3 of 3
# Building latents for doc 1 of 50
# Building latents for doc 2 of 50
# Building latents for doc 3 of 50
# Building latents for doc 4 of 50
# Building latents for doc 5 of 50
# Building latents for doc 6 of 50
# Building latents for doc 7 of 50
# Building latents for doc 8 of 50
# Building latents for doc 9 of 50
# Building latents for doc 10 of 50
# Building latents for doc 11 of 50
# Building latents for doc 12 of 50
# Building latents for doc 13 of 50
# Building latents for doc 14 of 50
# Building latents for doc 15 of 50
# Building latents for doc 16 of 50
# Building latents for doc 17 of 50
# Building latents for doc 18 of 50
# Building latents for doc 19 of 50
# Building latents for doc 20 of 50
# Building latents for doc 21 of 50
# Building latents for doc 22 of 50
# Building latents for doc 23 of 50
# Building latents for doc 24 of 50
# Building latents for doc 25 of 50
# Building latents for doc 26 of 50
# Building latents for doc 27 of 50
# Building latents for doc 28 of 50
# Building latents for doc 29 of 50
# Building latents for doc 30 of 50
# Building latents for doc 31 of 50
# Building latents for doc 32 of 50
# Building latents for doc 33 of 50
# Building latents for doc 34 of 50
# Building latents for doc 35 of 50
# Building latents for doc 36 of 50
# Building latents for doc 37 of 50
# Building latents for doc 38 of 50
# Building latents for doc 39 of 50
# Building latents for doc 40 of 50
# Building latents for doc 41 of 50
# Building latents for doc 42 of 50
# Building latents for doc 43 of 50
# Building latents for doc 44 of 50
# Building latents for doc 45 of 50
# Building latents for doc 46 of 50
# Building latents for doc 47 of 50
# Building latents for doc 48 of 50
# Building latents for doc 49 of 50
# Building latents for doc 50 of 50

# Building proposals for phi 1 of 3
# Overall time: 69.31266689300537, Iteration time: 54.86830019950867
# Building proposals for phi 2 of 3
# Overall time: 123.89829397201538, Iteration time: 54.58547306060791
# Building proposals for phi 3 of 3
# Overall time: 182.77833604812622, Iteration time: 58.87992310523987
# Building proposals for doc 1 of 50
# Overall time: 381.61259174346924, Iteration time: 198.8338508605957
# Building proposals for doc 2 of 50
# Overall time: 567.233188867569, Iteration time: 185.62039589881897
# Building proposals for doc 3 of 50
# Overall time: 771.7069668769836, Iteration time: 204.47345495224
# Building proposals for doc 4 of 50
# Overall time: 936.0917587280273, Iteration time: 164.38445901870728
# Building proposals for doc 5 of 50
# Overall time: 1120.0477230548859, Iteration time: 183.95552325248718
# Building proposals for doc 6 of 50
# Overall time: 1276.690170764923, Iteration time: 156.6422679424286
# Building proposals for doc 7 of 50
# Overall time: 1444.1476140022278, Iteration time: 167.45653009414673
# Building proposals for doc 8 of 50
# Overall time: 1735.4187908172607, Iteration time: 291.2701859474182
# Building proposals for doc 9 of 50
# Overall time: 1968.6613538265228, Iteration time: 233.24120998382568
# Building proposals for doc 10 of 50
# Overall time: 2185.009156703949, Iteration time: 216.34635066986084
# Building proposals for doc 11 of 50
# Overall time: 2463.9241259098053, Iteration time: 278.9145920276642
# Building proposals for doc 12 of 50
# Overall time: 2678.157154083252, Iteration time: 214.228657245636
# Building proposals for doc 13 of 50
# Overall time: 2890.1912257671356, Iteration time: 212.03397393226624
# Building proposals for doc 14 of 50
# Overall time: 3122.7205917835236, Iteration time: 232.52895283699036
# Building proposals for doc 15 of 50
# Overall time: 3485.069458961487, Iteration time: 362.3487198352814
# Building proposals for doc 16 of 50
# Overall time: 3644.913102865219, Iteration time: 159.84318804740906
# Building proposals for doc 17 of 50
# Overall time: 3796.687983751297, Iteration time: 151.77477407455444
# Building proposals for doc 18 of 50
# Overall time: 3950.3377928733826, Iteration time: 153.64937901496887
# Building proposals for doc 19 of 50
# Overall time: 4101.002737760544, Iteration time: 150.66468381881714
# Building proposals for doc 20 of 50
# Overall time: 4395.566714763641, Iteration time: 294.5638659000397
# Building proposals for doc 21 of 50
# Overall time: 4544.476539850235, Iteration time: 148.89263892173767
# Building proposals for doc 22 of 50
# Overall time: 4694.441290855408, Iteration time: 149.9646019935608
# Building proposals for doc 23 of 50
# Overall time: 4848.94672203064, Iteration time: 154.50405025482178
# Building proposals for doc 24 of 50
# Overall time: 5003.179519891739, Iteration time: 154.2320339679718
# Building proposals for doc 25 of 50
# Overall time: 5161.905777931213, Iteration time: 158.72613787651062
# Building proposals for doc 26 of 50
# Overall time: 5534.618554830551, Iteration time: 372.71096301078796
# Building proposals for doc 27 of 50
# Overall time: 5679.782926797867, Iteration time: 145.11774802207947
# Building proposals for doc 28 of 50
# Overall time: 5831.451673746109, Iteration time: 151.66864371299744
# Building proposals for doc 29 of 50
# Overall time: 5986.792187929153, Iteration time: 155.3397068977356
# Building proposals for doc 30 of 50
# Overall time: 6134.088886976242, Iteration time: 147.29659509658813
# Building proposals for doc 31 of 50
# Overall time: 6277.987422943115, Iteration time: 143.89844512939453
# Building proposals for doc 32 of 50
# Overall time: 6423.601721763611, Iteration time: 145.61409378051758
# Building proposals for doc 33 of 50
# Overall time: 6568.684260845184, Iteration time: 145.08171796798706
# Building proposals for doc 34 of 50
# Overall time: 7045.336709737778, Iteration time: 476.6523468494415
# Building proposals for doc 35 of 50
# Overall time: 7213.730815887451, Iteration time: 168.30789494514465
# Building proposals for doc 36 of 50
# Overall time: 7378.473023891449, Iteration time: 164.73943185806274
# Building proposals for doc 37 of 50
# Overall time: 7534.124243974686, Iteration time: 155.65112495422363
# Building proposals for doc 38 of 50
# Overall time: 7677.08373093605, Iteration time: 142.95938110351562
# Building proposals for doc 39 of 50
# Overall time: 7813.287255048752, Iteration time: 136.20341610908508
# Building proposals for doc 40 of 50
# Overall time: 7951.467782735825, Iteration time: 138.1804027557373
# Building proposals for doc 41 of 50
# Overall time: 8091.492294788361, Iteration time: 140.02440977096558
# Building proposals for doc 42 of 50
# Overall time: 8243.043648958206, Iteration time: 151.55125617980957
# Building proposals for doc 43 of 50
# Overall time: 8830.350775003433, Iteration time: 587.3069970607758
# Building proposals for doc 44 of 50
# Overall time: 9049.290264844894, Iteration time: 218.93365597724915
# Building proposals for doc 45 of 50
# Overall time: 9207.876252889633, Iteration time: 158.58503985404968
# Building proposals for doc 46 of 50
# Overall time: 9364.491986989975, Iteration time: 156.61562514305115
# Building proposals for doc 47 of 50
# Overall time: 9530.133072853088, Iteration time: 165.6409628391266
# Building proposals for doc 48 of 50
# Overall time: 9684.557219028473, Iteration time: 154.42402720451355
# Building proposals for doc 49 of 50
# Overall time: 9836.190227985382, Iteration time: 151.63291597366333
# Building proposals for doc 50 of 50
# Overall time: 9986.481297969818, Iteration time: 150.29096603393555

In [None]:
# Gibbs sampling over all latent variables, using the complete conditionals
# in proposal_vars_dict as proposals and the token arrays in data_dict as data.
inference = ed.Gibbs(
    latent_vars=latent_vars_dict,
    proposal_vars=proposal_vars_dict,
    data=data_dict,
)
inference.initialize(n_iter=T, n_print=10, logdir='log')

# Initialise all graph variables, including the Empirical sample buffers.
tf.global_variables_initializer().run()

for _ in range(inference.n_iter):
    info_dict = inference.update()
    inference.print_progress(info_dict)

inference.finalize()

In [None]:
# Heat map of the last stored sample of phi for data type index 1
# (rows: topics, columns: token ids of that vocabulary).
fig, ax = plt.subplots()
mesh = ax.pcolormesh(qphi[1].params[-1].eval())
fig.colorbar(mesh, ax=ax)

In [None]:
# Log of the per-entry posterior variance of phi for data type index 1,
# estimated across the stored samples (axis 0 of `params` indexes the sample).
# Use `params` rather than the Empirical variable `qphi[1]` itself: the latter
# evaluates to a single random draw, which gives one squared deviation rather
# than a variance.
burn_in = 0  # number of leading samples to discard; increase to drop burn-in
phi_samples = qphi[1].params[burn_in:]
var = tf.reduce_mean(
    tf.square(phi_samples - tf.reduce_mean(phi_samples, axis=0)), axis=0)
plt.pcolormesh(np.log(var.eval()))
plt.colorbar()