In [1]:
%matplotlib inline
import numpy as np
import pandas as pd
import networkx as nx

import matplotlib.pyplot as plt
import seaborn as sns
sns.reset_orig()

from matplotlib import rc
rc('text', usetex=True)
rc('font', **{'family' : "sans-serif"})
# Matplotlib >= 3.3 expects 'text.latex.preamble' to be one string, so the
# packages are joined with newlines instead of being passed as a list.
params = {'text.latex.preamble' : '\n'.join([r'\usepackage{siunitx}', r'\usepackage{amsmath}'])}
plt.rcParams.update(params)

from scipy.stats import bernoulli
from scipy.special import loggamma

from sklearn.preprocessing import OneHotEncoder

import tqdm

from rpy2.robjects import numpy2ri
from rpy2.robjects.packages import importr

In [2]:
import os
import pickle

In [3]:
from joblib import Parallel, delayed

In [4]:
import warnings
# Clear every warning filter installed so far, then silence UserWarning only;
# other categories (e.g. RuntimeWarning) stay visible.
warnings.resetwarnings()
warnings.simplefilter('ignore', UserWarning)

In [5]:
# Machine epsilon of double precision.  The builtin `float` replaces the
# `np.float` alias, which NumPy deprecated in 1.20 and removed in 1.24.
EPS = np.finfo(float).eps

In [6]:
# Show up to 200 rows and 200 columns before pandas truncates a display.
pd.set_option('display.max_rows', 200)
pd.set_option('display.max_columns', 200)

In [7]:
# Directory the input pickles are read from, and directory that receives
# every pickle written by this notebook.
indir = './data'
outdir = './output'
# exist_ok=True makes the creation idempotent and avoids the race between a
# separate os.path.exists() check and the makedirs() call.
os.makedirs(outdir, exist_ok=True)

In [8]:
# R packages, loaded through rpy2's importr
## blockmodels: provides BM_bernoulli, used below to fit the block models
blockmodels = importr("blockmodels")
## label.switching: provides ecr_iterative_2, used below for relabeling
label_switching = importr("label.switching")
## base
base = importr("base")

# The `$` entry of the base package, looked up by name because `$` is not a
# valid Python attribute.  Used below as dollar(obj, "field") to read a field
# of an R object.
dollar = base.__dict__["$"]

In [9]:
from functools import lru_cache

@lru_cache(maxsize = 10000)
def normterm_discrete(n, k):
    """Normalising term C(n, k) of the NML code for a k-category multinomial.

    C(n, k) is the sum, over every count vector (h_1, ..., h_k) with
    h_1 + ... + h_k = n, of the maximised likelihood
    n! / (h_1! ... h_k!) * prod_j (h_j / n) ** h_j, using 0 ** 0 = 1.
    The value is returned on the linear scale; callers take its logarithm.

    Parameters
    ----------
    n : int
        Sample size (n >= 0).
    k : int
        Number of categories (k >= 1).

    Returns
    -------
    float
        C(n, k).  Results are memoised by ``lru_cache``, so both arguments
        must be hashable.
    """
    if n == 0:
        # Only the all-zero count vector exists and its likelihood is 1.
        return 1.0
    if n == 1:
        # A single observation is fitted perfectly by each of the k outcomes,
        # so C(1, k) = k (linear scale, like every other branch).
        return float(k)
    if k == 1:
        return 1.0
    elif k == 2:
        # Interior terms t = 1 .. n-1, evaluated in log space and summed in
        # ascending order to limit round-off.
        interior = np.sum(sorted([ np.exp(loggamma(n+1) - loggamma(t+1) - loggamma(n-t+1) + 
                                   t*(np.log(t) - np.log(n)) + (n-t)*(np.log(n-t) - np.log(n))
                            )
                            for t in range(1, n)]))
        # The boundary terms t = 0 and t = n each equal exactly 1.  They are
        # added separately because t * log(t) cannot be evaluated at t = 0.
        return interior + 2.0
    else:
        # Recurrence C(n, k) = C(n, k-1) + n / (k-2) * C(n, k-2).
        return normterm_discrete(n, k-1) + n/(k-2) * normterm_discrete(n, k-2)

In [10]:
def check_latent_index_variable(z):
    """Relabel cluster indices so that they form the contiguous range 0 .. m-1.

    Parameters
    ----------
    z : numpy.ndarray of int
        Cluster index assigned to each item.

    Returns
    -------
    numpy.ndarray
        ``z`` itself when it is empty or its distinct labels are already
        exactly 0 .. m-1; otherwise a new integer array of the same shape in
        which the i-th smallest label of ``z`` is replaced by ``i``.
    """
    if z.size == 0:
        # Nothing to relabel; np.unique(...)[0] below would fail on empty input.
        return z
    unique_z = np.unique(z)  # distinct labels in ascending order
    # Distinct integers with minimum 0 and maximum m-1 are exactly 0 .. m-1.
    # Checking the minimum too keeps inputs with negative labels from being
    # mistaken for already-contiguous ones.
    if unique_z[0] == 0 and unique_z[-1] == len(unique_z) - 1:
        return z
    # Builtin int replaces the np.int alias removed in NumPy 1.24.
    new_z = np.zeros(z.shape, dtype=int)
    for index, current in enumerate(unique_z):
        new_z[z == current] = index
    return new_z

In [11]:
def calc_dnml(X, Z1, Z2, K=3, L=3):
    """Code length of a binary matrix under a hard block assignment.

    Parameters
    ----------
    X : numpy.ndarray
        Two-dimensional array whose entries equal to 1 and 0 are counted.
    Z1 : numpy.ndarray
        Block label of every row of ``X`` (values 0 .. K-1 are visited).
    Z2 : numpy.ndarray
        Block label of every column of ``X`` (values 0 .. L-1 are visited).
    K, L : int
        Number of row and column labels to iterate over.

    Returns
    -------
    tuple
        ``(total, data_part, latent_part)`` where ``data_part`` is accumulated
        from the entries of each (row block, column block) pair,
        ``latent_part`` from the sizes of the row blocks of ``Z1``, and
        ``total`` is their sum.
    """
    n_items = X.shape[0]

    data_part = 0.0
    latent_part = 0.0

    for row_label in range(K):
        in_row_block = Z1 == row_label
        for col_label in range(L):
            block = X[in_row_block, :][:, Z2 == col_label]
            ones = np.sum(block == 1)
            zeros = np.sum(block == 0)
            cells = ones + zeros

            # n * log(n) is 0 for n in {0, 1}; the guards skip those cases
            # (and with them log(0)).
            if cells >= 2:
                data_part += cells * np.log(cells)
                data_part += np.log(normterm_discrete(cells, 2))
            if ones >= 2:
                data_part -= ones * np.log(ones)
            if zeros >= 2:
                data_part -= zeros * np.log(zeros)

        block_size = np.sum(in_row_block)
        if block_size >= 1:
            latent_part += block_size * (np.log(n_items) - np.log(block_size))

    latent_part += np.log(normterm_discrete(n_items, K))

    return data_part + latent_part, data_part, latent_part

In [12]:
def calc_dnml_by_prob(X, Z, alpha, theta, K=3, L=3, eps=1e-12):
    """Code length of a binary matrix under given block parameters.

    Parameters
    ----------
    X : numpy.ndarray
        Two-dimensional array whose entries equal to 1 and 0 are counted.
    Z : numpy.ndarray
        Membership matrix with one row per item; each item is assigned to the
        column holding its row maximum, and the resulting labels are then
        renumbered to a contiguous 0 .. m-1 range.
    alpha : sequence of float
        Probability used to code the membership of block ``k`` (``alpha[k]``).
    theta : array-like
        ``theta[k, l]`` is the probability of a 1 between blocks ``k`` and
        ``l``.  It is clipped to ``[eps, 1 - eps]`` on a copy, so the caller's
        array is left unmodified.
    K, L : int
        Number of row and column labels to iterate over.
    eps : float
        Clipping margin that keeps both logarithms finite.

    Returns
    -------
    tuple
        ``(total, data_part, latent_part)`` with ``total`` the sum of the
        other two.
    """
    N = X.shape[0]

    Z1 = check_latent_index_variable(np.argmax(Z, axis=1))
    # The same labels are used for rows and columns.
    Z2 = Z1

    # np.clip returns a new array; NaN entries stay NaN, as they would under
    # explicit `<` / `>` comparisons.
    theta = np.clip(np.asarray(theta, dtype=float), eps, 1.0 - eps)

    codelen_x_z = 0.0
    codelen_z = 0.0

    for k in range(K):
        for l in range(L):
            block = X[Z1 == k, :][:, Z2 == l]
            n_pos = np.sum(block == 1)
            n_neg = np.sum(block == 0)
            n_all = n_pos + n_neg

            codelen_x_z += -n_pos * np.log(theta[k, l]) - n_neg * np.log(1.0 - theta[k, l])

            if n_all >=2:
                codelen_x_z += np.log(normterm_discrete(n_all, 2))

        n_k = np.sum(Z1 == k)
        codelen_z += -n_k * np.log(alpha[k])

    codelen_z += np.log(normterm_discrete(N, K))

    codelen = codelen_x_z + codelen_z

    return codelen, codelen_x_z, codelen_z

In [13]:
def calc_lsc(X, Z1, Z2, K=3, L=3):
    """Code length of a binary matrix under a hard block assignment, with a
    closed-form penalty term.

    Parameters
    ----------
    X : numpy.ndarray
        Two-dimensional array whose entries equal to 1 and 0 are counted.
    Z1, Z2 : numpy.ndarray
        Block label of every row / column of ``X``.
    K, L : int
        Number of row and column labels to iterate over.

    Returns
    -------
    float
        Accumulated code length.
    """
    n_items = X.shape[0]
    total_len = 0.0

    for k in range(K):
        rows_k = X[Z1 == k, :]
        for l in range(L):
            block = rows_k[:, Z2 == l]
            ones = np.sum(block == 1)
            zeros = np.sum(block == 0)
            cells = ones + zeros

            # n * log(n) is 0 for n in {0, 1}; the guards skip those cases.
            if cells >= 2:
                total_len += cells * np.log(cells)
            if ones >= 2:
                total_len -= ones * np.log(ones)
            if zeros >= 2:
                total_len -= zeros * np.log(zeros)

        members = np.sum(Z1 == k)
        if members >= 1:
            total_len += members * (np.log(n_items) - np.log(members))

        # NOTE(review): this penalty is evaluated with the loop index k and
        # added on every pass of the loop -- confirm that this is intended
        # rather than a single evaluation after the loop.
        penalty = (k + (k+1)*(k+2))/2 * np.log(n_items/(2.0*np.pi)) -(k+1)/2 * np.log(2.0) + \
            (k+1) * loggamma((k+3)/2) - loggamma((k+1)*(k+3)/2) + (k+1)*(k+2)/2 * np.log(np.pi)
        total_len += penalty

    return total_len

In [14]:
def calc_stats(X, #z, 
               scores, scores_f, scores_l, h, delta, K=10):
    """Per-time-step change statistic and model-order estimates.

    Parameters
    ----------
    X : object
        Not used; kept so that existing calls stay valid.
    scores, scores_f, scores_l : array-like, shape (n_trials, T, n_models)
        Code lengths computed on the whole window, its former half and its
        latter half.  Index ``k`` along the last axis is charged the integer
        code length of ``k + 1``.
    h : int
        Time steps ``t < h`` and ``t >= T - h`` are skipped (left as NaN);
        the statistic is scaled by ``0.5 / h``.
    delta : object
        Not used; kept so that existing calls stay valid.
    K : int
        Ignored: it is replaced by ``scores.shape[2]``.

    Returns
    -------
    tuple of numpy.ndarray
        ``(stats_complete, models_estimated, models_former, models_latter)``,
        each a float array of shape (n_trials, T).
    """
    scores = np.array(scores)
    scores_f = np.array(scores_f)
    scores_l = np.array(scores_l)

    N_trial, T, K = scores.shape[0], scores.shape[1], scores.shape[2]

    codelens = np.array([ codelen_integer(k) for k in range(1, K+1)])

    # Broadcasting adds the per-model code length along the last axis.
    idxes_all = np.argmin(scores + codelens, axis=2)

    # Builtin float (the np.full default) replaces the np.float alias removed
    # in NumPy 1.24.
    models_estimated = np.full((N_trial, T), np.nan)
    models_former = np.full((N_trial, T), np.nan)
    models_latter = np.full((N_trial, T), np.nan)
    stats_complete = np.full((N_trial, T), np.nan)

    for trial in range(N_trial):
        n_change = 0  # number of changes so far.
        for t in range(h, T-h):
            # Running estimate of the probability that the model changes.
            alpha = (n_change+1/2) / (t+1+1)
            m_estimated = idxes_all[trial, t]

            # Lv.3 change (Model change)
            stats_half_t = np.zeros((K, K), dtype=float)
            for k1 in range(K):
                stats_former = scores_f[trial, t, k1]
                for k2 in range(K):
                    # Staying on the same model costs -log(1 - alpha); moving
                    # to one of the K-1 other models costs -log(alpha/(K-1)).
                    p = 1.0 - alpha if k1 == k2 else alpha/(K-1)
                    stats_latter = scores_l[trial, t, k2]
                    stats_half_t[k1, k2] = (stats_former + stats_latter) + codelens[k1] - np.log(p)
            m_former_estimated, m_latter_estimated = np.unravel_index(np.nanargmin(stats_half_t), (K, K))
            models_former[trial, t] = m_former_estimated
            models_latter[trial, t] = m_latter_estimated

            stat = 0.5 / h *(scores[trial, t, m_estimated] + codelens[m_estimated] - stats_half_t[m_former_estimated, m_latter_estimated])
            stats_complete[trial, t] = stat

            # The model recorded for time t is the latter-half model.
            model_t = m_latter_estimated

            if t >= 1:
                # NOTE(review): at t == h (with h >= 1) the previous entry is
                # still NaN, so this comparison is True and the first visited
                # step always counts as a change -- confirm this is intended.
                model_prev = models_estimated[trial, t-1]
                if model_t != model_prev:
                    n_change += 1
            models_estimated[trial, t] = model_t

    return stats_complete, models_estimated, models_former, models_latter

In [15]:
def calc_stats_with_modelidx(scores, scores_f, scores_l, idxes_model, idxes_model_f, idxes_model_l, h):
    """Change statistic for model indices that are already chosen.

    Parameters
    ----------
    scores, scores_f, scores_l : array-like, indexed as [trial, t, model]
        Code lengths of the whole window, its former half and its latter half.
    idxes_model, idxes_model_f, idxes_model_l : array-like, shape (n_trials, T)
        Model index to read from ``scores``, ``scores_f`` and ``scores_l``
        at every (trial, t); values are cast with ``int``.
    h : int
        Time steps ``t < h`` and ``t >= T - h`` are skipped; the statistic is
        scaled by ``0.5 / h``.

    Returns
    -------
    numpy.ndarray
        Float array of shape (n_trials, T) holding
        ``0.5 / h * (whole - (former + latter))``; skipped steps are NaN.
    """
    scores = np.array(scores)
    scores_f = np.array(scores_f)
    scores_l = np.array(scores_l)
    # Accept plain nested lists as well as arrays for the index arguments.
    idxes_model = np.asarray(idxes_model)
    idxes_model_f = np.asarray(idxes_model_f)
    idxes_model_l = np.asarray(idxes_model_l)

    n_trials, n_steps = idxes_model.shape[0], idxes_model.shape[1]

    # np.full defaults to float64, replacing the np.float alias removed in
    # NumPy 1.24.
    stats_complete = np.full((n_trials, n_steps), np.nan)
    for trial in range(n_trials):
        for t in range(h, n_steps - h):
            whole = scores[trial, t, int(idxes_model[trial, t])]
            former = scores_f[trial, t, int(idxes_model_f[trial, t])]
            latter = scores_l[trial, t, int(idxes_model_l[trial, t])]
            stats_complete[trial, t] = 0.5/h * (whole - (former + latter))

    return stats_complete

In [16]:
def codelen_integer(k):
    """Code length charged to the positive integer ``k``.

    Starts from ``log(2.865)`` and adds ``k``, ``log(k)``, ``log(log(k))``,
    ... for as long as the current term is positive.

    NOTE(review): the first term added is ``k`` itself rather than
    ``log(k)`` -- confirm against the intended integer code.
    """
    codelen = np.log(2.865)
    # Strict inequality: a term that is exactly 0 (reached for k = 1) adds
    # nothing, and taking its logarithm would only produce -inf together
    # with a divide-by-zero RuntimeWarning.
    while k > 0.0:
        codelen += k
        k = np.log(k)

    return codelen

In [17]:
# Load the input arrays from `indir`.  X_all is indexed below as
# X_all[trial, t, :, :].
# NOTE: pickle.load can execute arbitrary code while unpickling, so these
# files must come from a trusted source.
with open(os.path.join(indir, 'X_abrupt.pkl'), 'rb') as f:
    X_all = pickle.load(f)
with open(os.path.join(indir, 'Z_abrupt.pkl'), 'rb') as f:
    Z_true_all = pickle.load(f)

In [18]:
X_all.shape

(20, 80, 100, 100)

In [19]:
def estimate_sbm_each_trial(X, trial, K, T):
    """Fit Bernoulli block models to every snapshot of one trial.

    Parameters
    ----------
    X : numpy.ndarray
        Array indexed as ``X[trial, t, :, :]``; each slice is passed to
        ``blockmodels.BM_bernoulli`` as the adjacency matrix.
    trial : int
        Index of the trial to process.
    K : int
        Passed as both ``explore_min`` and ``explore_max``; the first ``K``
        entries of the fitted object's ``model_parameters`` and
        ``memberships`` are read (entry ``k`` presumably being the fit with
        ``k + 1`` groups -- confirm against the blockmodels documentation).
    T : int
        Number of snapshots to process.

    Returns
    -------
    tuple of list
        ``(pi_list_trial, theta_list_trial, z_list_trial)``; element ``[t][k]``
        holds, respectively, the normalised column sums of the membership
        matrix, the ``pi`` field of the model parameters, and the membership
        matrix ``Z``.
    """
    pi_list_trial = []
    theta_list_trial = []
    z_list_trial = []

    # NOTE(review): no random seed is passed to the R fit, so repeated runs
    # may differ if the fitting procedure is stochastic.
    numpy2ri.activate()
    try:
        for t in tqdm.tqdm(range(T)):
            sbm = blockmodels.BM_bernoulli(membership_type="SBM", 
                                           adj=np.array(X[trial, t, :, :]),
                                           verbosity=0,
                                           exploration_factor=1.5,
                                           explore_min=K,
                                           explore_max=K)

            estimate = dollar(sbm, "estimate")
            estimate()

            pi_list = []
            theta_list = []
            z_posterior_list = []
            for k in range(K):
                theta = np.array(dollar(dollar(sbm, "model_parameters")[k], "pi"))
                z_posterior = np.array(dollar(dollar(sbm, "memberships")[k], "Z"))
                # Column sums of the memberships, offset by 10 * EPS so that
                # no entry is exactly zero, then normalised to sum to 1.
                pi = np.sum(z_posterior, axis=0) + 10 * EPS
                pi /= np.sum(pi)

                theta_list.append(theta)
                z_posterior_list.append(z_posterior)
                pi_list.append(pi)

            pi_list_trial.append(pi_list)
            theta_list_trial.append(theta_list)
            z_list_trial.append(z_posterior_list)
    finally:
        # Always switch the numpy <-> R conversion off again, even when a
        # fit raises.
        numpy2ri.deactivate()

    return pi_list_trial, theta_list_trial, z_list_trial

In [20]:
# Builtin float replaces the np.float alias removed in NumPy 1.24.
EPS = np.finfo(float).eps

N_trial = X_all.shape[0]
T = X_all.shape[1]

K = 10

# NOTE(review): pi1, pi2, a0, b0 and ratio are not read anywhere in this
# cell; they are kept in case other cells rely on them.
pi1 = None
pi2 = None
a0 = 1.0
b0 = 1.0
ratio = 0.02

pi_all = []
theta_all = []
z_all = []

# NOTE(review): no random seed is passed to the R fit, so repeated runs may
# differ if the fitting procedure is stochastic.
for trial in tqdm.tqdm(range(N_trial)):
    pi_list_trial = []
    theta_list_trial = []
    z_list_trial = []

    numpy2ri.activate()
    try:
        for t in range(T):
            X = X_all[trial, t, :, :]

            sbm = blockmodels.BM_bernoulli(membership_type="SBM", adj=np.array(X),
                                           verbosity=0,
                                           exploration_factor=1.5,
                                           explore_min=K,
                                           explore_max=K)

            estimate = dollar(sbm, "estimate")
            estimate()

            theta_list = []
            pi_list = []
            z_posterior_list = []
            for k in range(K):
                theta = np.array(dollar(dollar(sbm, "model_parameters")[k], "pi"))
                z_posterior = np.array(dollar(dollar(sbm, "memberships")[k], "Z"))
                # Column sums of the memberships, offset by 10 * EPS so that
                # no entry is exactly zero, then normalised to sum to 1.
                pi = np.sum(z_posterior, axis=0) + 10 * EPS
                pi /= np.sum(pi)

                theta_list.append(theta)
                z_posterior_list.append(z_posterior)
                pi_list.append(pi)

            pi_list_trial.append(pi_list)
            theta_list_trial.append(theta_list)
            z_list_trial.append(z_posterior_list)
    finally:
        # Always switch the numpy <-> R conversion off again, even when a
        # fit raises.
        numpy2ri.deactivate()

    pi_all.append(pi_list_trial)
    theta_all.append(theta_list_trial)
    z_all.append(z_list_trial)

    # Checkpoint after every trial so that a crash during this long run
    # loses at most the trial in progress.
    for filename, payload in (('pi_abrupt.pkl', pi_all),
                              ('theta_abrupt.pkl', theta_all),
                              ('z_abrupt.pkl', z_all)):
        with open(os.path.join(outdir, filename), 'wb') as f:
            pickle.dump(payload, f)

  0%|          | 0/20 [00:00<?, ?it/s]





















































































  5%|▌         | 1/20 [05:11<1:38:37, 311.45s/it]





















































































 10%|█         | 2/20 [11:34<1:39:50, 332.82s/it]





















































































 15%|█▌        | 3/20 [16:41<1:32:09, 325.26s/it]





















































































 20%|██        | 4/20 [20:21<1:18:19, 293.72s/it]





















































































 25%|██▌       | 5/20 [48:50<2:59:33, 718.25s/it]





















































































 30%|███       | 6/20 [52:59<2:14:43, 577.38s/it]





















































































 35%|███▌      | 7/20 [1:22:39<3:23:16, 938.21s/it]





















































































 40%|████      | 8/20 [1:55:22<4:09:07, 1245.64s/it]





















































































 45%|████▌     | 9/20 [2:01:38<3:00:31, 984.69s/it] 





















































































 50%|█████     | 10/20 [2:07:06<2:11:18, 787.87s/it]





















































































 55%|█████▌    | 11/20 [2:14:26<1:42:29, 683.30s/it]





















































































 60%|██████    | 12/20 [2:20:52<1:19:13, 594.13s/it]





















































































 65%|██████▌   | 13/20 [2:25:14<57:42, 494.64s/it]  





















































































 70%|███████   | 14/20 [2:30:54<44:48, 448.11s/it]





















































































 75%|███████▌  | 15/20 [2:38:04<36:53, 442.67s/it]





















































































 80%|████████  | 16/20 [2:43:59<27:45, 416.44s/it]





















































































 85%|████████▌ | 17/20 [6:07:19<3:17:34, 3951.35s/it]





















































































 90%|█████████ | 18/20 [6:14:12<1:36:20, 2890.13s/it]





















































































 95%|█████████▌| 19/20 [6:22:14<36:07, 2167.47s/it]  





















































































100%|██████████| 20/20 [7:16:12<00:00, 2488.79s/it]


In [21]:
# Reload the estimates from `outdir`, so the cells below can run without
# repeating the fit.  The memberships are bound to `Z_all` (capital Z),
# which is the name the following cells use.
# NOTE: pickle.load can execute arbitrary code while unpickling, so these
# files must come from a trusted source.
with open(os.path.join(outdir, 'pi_abrupt.pkl'), 'rb') as f:
    pi_all = pickle.load(f)
with open(os.path.join(outdir, 'theta_abrupt.pkl'), 'rb') as f:
    theta_all = pickle.load(f)
with open(os.path.join(outdir, 'z_abrupt.pkl'), 'rb') as f:
    Z_all = pickle.load(f)

In [22]:
# relabeling
N_trial = X_all.shape[0]
K = 10
T = X_all.shape[1]

# Half-open snapshot ranges [start, stop) that are relabeled one after the
# other.  Each range starts on a snapshot that also belongs to the previous
# range (19, 37, 40, 57, 60).
# NOTE(review): the boundaries are hard-coded for 80 snapshots and do not
# depend on T -- adjust them if the data changes.
RELABEL_SEGMENTS = [(0, 20), (19, 38), (37, 41), (40, 58), (57, 61), (60, 80)]

numpy2ri.activate()
try:
    for trial in tqdm.tqdm(range(N_trial)):
        for k in range(1, K):
            for start, stop in RELABEL_SEGMENTS:
                segment = range(start, stop)

                # z: hard labels (1-based for R), one row per snapshot;
                # p: the stacked membership matrices of the same snapshots.
                run = label_switching.ecr_iterative_2(
                    z=np.vstack([
                          np.argmax(Z_all[trial][t][k], axis=1).reshape(1, -1) for t in segment
                    ]) + 1,
                    K=k+1,
                    p=np.stack([Z_all[trial][t][k] for t in segment])
                )
                permutations = np.array(dollar(run, "permutations"))

                for i, t in enumerate(segment):
                    # The permutation found for the first snapshot of a
                    # segment is applied to that snapshot and to every
                    # earlier one (for the first segment that is snapshot 0
                    # only); every other permutation is applied to its own
                    # snapshot.
                    targets = range(start + 1) if i == 0 else (t,)
                    for tt in targets:
                        # R permutations are 1-based, hence the -1.
                        Z_all[trial][tt][k] = Z_all[trial][tt][k][:, permutations[i, :]-1]
finally:
    # Always switch the numpy <-> R conversion off again, even when the
    # relabeling raises.
    numpy2ri.deactivate()

100%|██████████| 20/20 [00:23<00:00,  1.17s/it]


In [23]:
# Save the relabeled memberships.
# NOTE(review): this writes to the same 'z_abrupt.pkl' that the cells above
# write and read, so reloading and relabeling again would start from
# already-relabeled data.
with open(os.path.join(outdir, 'z_abrupt.pkl'), 'wb') as f:
    pickle.dump(Z_all, f)

In [25]:
n_trial = X_all.shape[0]
T = X_all.shape[1]
K = 10


def _window_codelengths(X_trial, Z_trial, k, window):
    """Code lengths of the snapshots in `window`, with parameters pooled over it.

    X_trial[tt] is the matrix of snapshot tt and Z_trial[tt][k] its
    membership matrix for model index k (k + 1 columns are scanned).
    Returns (n_cluster, dnml, nml_x_z, nml_z): the number of groups that
    receive at least one node somewhere in the window, followed by the three
    values returned by calc_dnml_by_prob, each summed over the window.
    """
    groups = range(k + 1)
    hard_labels = [np.argmax(Z_trial[tt][k], axis=1) for tt in window]

    # Number of nodes assigned to each group, summed over the window.
    n_per_group = np.sum([[np.sum(z == v) for v in groups] for z in hard_labels], axis=0)

    # Pooled density of every group pair: sum of the entries between the two
    # groups divided by the number of node pairs (0/0 gives NaN for empty
    # groups, which are removed right below).
    entry_sums = np.sum(
        [[[np.sum(X_trial[tt][z == k1, :][:, z == k2]) for k2 in groups] for k1 in groups]
         for tt, z in zip(window, hard_labels)], axis=0)
    pair_counts = np.sum(
        [[[np.sum(z == k1) * np.sum(z == k2) for k2 in groups] for k1 in groups]
         for z in hard_labels], axis=0)
    theta_hat = entry_sums / pair_counts

    occupied = n_per_group != 0
    theta_hat = theta_hat[occupied, :][:, occupied]
    pi_hat = n_per_group[occupied] / np.sum(n_per_group)
    n_cluster = len(pi_hat)

    res = np.array([calc_dnml_by_prob(X_trial[tt],
                                      Z_trial[tt][k],
                                      pi_hat, theta_hat,
                                      n_cluster, n_cluster)
                    for tt in window])

    return n_cluster, np.sum(res[:, 0]), np.sum(res[:, 1]), np.sum(res[:, 2])


def _keep_min(table, index, value):
    """Store `value` at table[index] if the slot is NaN or holds a larger finite value."""
    current = table[index]
    if np.isnan(current) or (np.isfinite(current) and value < current):
        table[index] = value


for h in tqdm.tqdm([1, 2, 3]):
    # One NaN-filled (trial, t, number of groups - 1) table per quantity and
    # per window type.
    dnml_whole_list, nml_x_z_whole_list, nml_z_whole_list = \
        (np.full((n_trial, T, K), np.nan) for _ in range(3))
    dnml_former_list, nml_x_z_former_list, nml_z_former_list = \
        (np.full((n_trial, T, K), np.nan) for _ in range(3))
    dnml_latter_list, nml_x_z_latter_list, nml_z_latter_list = \
        (np.full((n_trial, T, K), np.nan) for _ in range(3))

    for trial in tqdm.tqdm(range(n_trial)):
        for t in range(h, T-h):
            # whole = [t-h, t+h), former = [t-h, t), latter = [t, t+h)
            windows = (
                (range(t-h, t+h), (dnml_whole_list, nml_x_z_whole_list, nml_z_whole_list)),
                (range(t-h, t), (dnml_former_list, nml_x_z_former_list, nml_z_former_list)),
                (range(t, t+h), (dnml_latter_list, nml_x_z_latter_list, nml_z_latter_list)),
            )
            for k in range(K):
                for window, tables in windows:
                    n_cluster, *codelengths = _window_codelengths(X_all[trial], Z_all[trial], k, window)
                    # Several k can end up with the same number of occupied
                    # groups; each table keeps the smallest value seen, and
                    # each quantity is compared with its own stored value.
                    for table, value in zip(tables, codelengths):
                        _keep_min(table, (trial, t, n_cluster-1), value)

    results = (
        ('dnml_all_h', dnml_whole_list),
        ('dnml_f_h', dnml_former_list),
        ('dnml_l_h', dnml_latter_list),
        ('nml_x_z_all_h', nml_x_z_whole_list),
        ('nml_x_z_f_h', nml_x_z_former_list),
        ('nml_x_z_l_h', nml_x_z_latter_list),
        ('nml_z_all_h', nml_z_whole_list),
        ('nml_z_f_h', nml_z_former_list),
        ('nml_z_l_h', nml_z_latter_list),
    )
    for stem, table in results:
        with open(os.path.join(outdir, stem + str(h) + '.pkl'), 'wb') as f:
            pickle.dump(table, f)

  0%|          | 0/3 [00:00<?, ?it/s]

  5%|▌         | 1/20 [00:08<02:41,  8.52s/it][A
 10%|█         | 2/20 [00:18<02:41,  8.97s/it][A
 15%|█▌        | 3/20 [00:30<02:46,  9.77s/it][A
 20%|██        | 4/20 [00:38<02:29,  9.33s/it][A
 25%|██▌       | 5/20 [00:49<02:26,  9.79s/it][A
 30%|███       | 6/20 [00:59<02:16,  9.76s/it][A
 35%|███▌      | 7/20 [01:10<02:13, 10.24s/it][A
 40%|████      | 8/20 [01:20<02:02, 10.23s/it][A
 45%|████▌     | 9/20 [01:30<01:51, 10.13s/it][A
 50%|█████     | 10/20 [01:40<01:41, 10.11s/it][A
 55%|█████▌    | 11/20 [01:50<01:31, 10.15s/it][A
 60%|██████    | 12/20 [02:00<01:19,  9.92s/it][A
 65%|██████▌   | 13/20 [02:09<01:07,  9.61s/it][A
 70%|███████   | 14/20 [02:18<00:57,  9.51s/it][A
 75%|███████▌  | 15/20 [02:29<00:50, 10.03s/it][A
 80%|████████  | 16/20 [02:39<00:40, 10.14s/it][A
 85%|████████▌ | 17/20 [02:49<00:29,  9.88s/it][A
 90%|█████████ | 18/20 [03:00<00:20, 10.17s/it][A
 95%|█████████▌| 19/20 [03:10<00:10, 10.16s/it][A
1