In [1]:
%matplotlib inline
import numpy as np
import pandas as pd
import networkx as nx

import matplotlib.pyplot as plt
import seaborn as sns
sns.reset_orig()

from matplotlib import rc
rc('text', usetex=True)
rc('font', **{'family' : "sans-serif"})
params = {'text.latex.preamble' : [r'\usepackage{siunitx}', r'\usepackage{amsmath}']}
plt.rcParams.update(params)

from scipy.stats import bernoulli
from scipy.special import loggamma

from sklearn.preprocessing import OneHotEncoder

import tqdm

from rpy2.robjects import numpy2ri
from rpy2.robjects.packages import importr

In [2]:
import os
import pickle

In [3]:
from joblib import Parallel, delayed

In [4]:
import warnings
warnings.resetwarnings()
warnings.simplefilter('ignore', UserWarning)

In [5]:
# Machine epsilon of the built-in float type. The `np.float` alias was removed
# in NumPy 1.24, so the built-in `float` is used instead (same dtype).
EPS = np.finfo(float).eps

In [6]:
pd.options.display.max_rows = 200
pd.options.display.max_columns = 200

In [7]:
# Input/output locations, relative to the notebook's working directory.
indir = './data'
outdir = './output'
# exist_ok=True makes this a no-op when the directory already exists and
# avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(outdir, exist_ok=True)

In [8]:
# R packages
# Each importr call loads an R package through rpy2 and binds it to a Python name.
## blockmodels
# Provides BM_bernoulli, which the estimation cells call to fit the SBM.
blockmodels = importr("blockmodels")
## label.switching
# Provides ecr_iterative_2, which the relabeling cell calls.
label_switching = importr("label.switching")
## base
base = importr("base")

# R's `$` accessor looked up by name, so members of R objects can be read
# from Python as dollar(obj, "member").
dollar = base.__dict__["$"]

In [9]:
from functools import lru_cache

@lru_cache(maxsize = 10000)
def normterm_discrete(n, k):
    """Normalizing sum C(n, k) of the NML distribution for a k-category multinomial.

    C(n, k) = sum over (h_1..h_k) with h_1+...+h_k = n of
              n!/(h_1!...h_k!) * prod_j (h_j/n)^h_j,   with 0^0 taken as 1.

    The value itself is returned (not its logarithm); callers apply ``np.log``.

    Parameters
    ----------
    n : int
        Number of observations (hashable, because results are memoised).
    k : int
        Number of categories, k >= 1.

    Returns
    -------
    float
        C(n, k). Always >= 1, so its logarithm is non-negative.

    Notes
    -----
    Fixes relative to the previous version:
    * ``n == 1`` used to return ``log(k)`` although every other branch returns
      the normalizer itself; C(1, k) is k.
    * the ``k == 2`` sum skipped the t = 0 and t = n terms. Each of them equals
      1 (0^0 = 1); they cannot go through the log-domain formula because
      ``0 * log(0)`` is NaN, so they are added explicitly.
    * ``n == 0`` now returns 1 (the single empty count vector).
    """
    if n == 0 or k == 1:
        return 1.0
    if n == 1:
        # Exactly k count vectors, each contributing 1.
        return float(k)
    if k == 2:
        log_n = np.log(n)
        interior = [np.exp(loggamma(n+1) - loggamma(t+1) - loggamma(n-t+1) +
                           t*(np.log(t) - log_n) + (n-t)*(np.log(n-t) - log_n))
                    for t in range(1, n)]
        # Boundary terms t = 0 and t = n are both 1; summing in ascending order
        # limits floating-point round-off.
        return np.sum(sorted(interior + [1.0, 1.0]))
    # Recurrence C(n, k) = C(n, k-1) + n/(k-2) * C(n, k-2), valid for k >= 3.
    return normterm_discrete(n, k-1) + n/(k-2) * normterm_discrete(n, k-2)

In [10]:
def check_latent_index_variable(z):
    """Relabel cluster indices so that they are consecutive, starting at 0.

    Parameters
    ----------
    z : array_like of int
        Cluster index per item.

    Returns
    -------
    numpy.ndarray
        ``z`` itself when the number of distinct labels already equals
        ``max(z) + 1``; otherwise a new integer array of the same shape in
        which the i-th smallest distinct label is replaced by ``i``.
        An empty input is returned unchanged.
    """
    z = np.asarray(z)
    if z.size == 0:
        # np.max would raise on an empty array; there is nothing to relabel.
        return z
    unique_z, inverse = np.unique(z, return_inverse=True)
    if len(unique_z) == np.max(z) + 1:
        return z
    # `inverse` holds, for each element, the rank of its label among the sorted
    # distinct labels, which is exactly the compacted index. The built-in `int`
    # replaces the `np.int` alias that was removed in NumPy 1.24.
    return inverse.reshape(z.shape).astype(int)

In [11]:
def calc_dnml(X, Z1, Z2, K=3, L=3):
    """Code length of ``X`` given hard assignments, plus code length of ``Z1``.

    Parameters
    ----------
    X : 2-D array
        Entries equal to 1 and 0 are counted; other values are ignored.
    Z1, Z2 : 1-D integer arrays
        Row and column cluster index of each node.
    K, L : int
        Number of row / column clusters that are iterated over.

    Returns
    -------
    (codelen, codelen_x_z, codelen_z) : tuple of float
        ``codelen_x_z`` accumulates, per (k, l) block, the maximised Bernoulli
        code length ``n log n - n1 log n1 - n0 log n0`` plus
        ``log normterm_discrete(n, 2)``; ``codelen_z`` accumulates
        ``n_k (log N - log n_k)`` plus ``log normterm_discrete(N, K)``;
        ``codelen`` is their sum.
    """
    n_nodes = X.shape[0]

    data_part = 0.0
    latent_part = 0.0

    for row_label in range(K):
        row_mask = Z1 == row_label

        for col_label in range(L):
            block = X[row_mask, :][:, Z2 == col_label]
            ones = np.sum(block == 1)
            zeros = np.sum(block == 0)
            cells = ones + zeros

            # Terms with a count below 2 are skipped (n log n is 0 for n in {0, 1}).
            if cells >= 2:
                data_part += cells * np.log(cells)
                data_part += np.log(normterm_discrete(cells, 2))

            if ones >= 2:
                data_part -= ones * np.log(ones)
            if zeros >= 2:
                data_part -= zeros * np.log(zeros)

        members = np.sum(row_mask)
        if members >= 1:
            latent_part += members * (np.log(n_nodes) - np.log(members))

    latent_part += np.log(normterm_discrete(n_nodes, K))

    return data_part + latent_part, data_part, latent_part

In [12]:
def calc_dnml_by_prob(X, Z, alpha, theta, K=3, L=3, eps=1e-12):
    """Code length of ``X`` and of the hard assignment under given parameters.

    Parameters
    ----------
    X : 2-D array
        Entries equal to 1 and 0 are counted.
    Z : 2-D array
        Membership scores, one row per node; the hard assignment is the
        per-row argmax, compacted with ``check_latent_index_variable``.
        The same assignment is used for rows and columns of ``X``.
    alpha : 1-D array
        Cluster probabilities; ``alpha[k]`` is read for k in range(K).
    theta : 2-D array
        Block connection probabilities; ``theta[k, l]`` is read for
        k in range(K), l in range(L). The caller's array is NOT modified.
    K, L : int
        Number of row / column clusters that are iterated over.
    eps : float
        ``theta`` is clipped to ``[eps, 1 - eps]`` before taking logarithms.

    Returns
    -------
    (codelen, codelen_x_z, codelen_z) : tuple of float
        ``codelen_x_z`` is the Bernoulli negative log-likelihood of every block
        plus ``log normterm_discrete(n, 2)`` for blocks with at least 2 cells;
        ``codelen_z`` is ``-sum_k n_k log alpha[k]`` plus
        ``log normterm_discrete(N, K)``; ``codelen`` is their sum.
    """
    N = X.shape[0]

    Z1 = check_latent_index_variable(np.argmax(Z, axis=1))
    Z2 = Z1

    # Clip on a private float copy. The previous version wrote the clipped
    # values back into the caller's array (and would have truncated `eps` to 0
    # had that array been of integer dtype).
    theta = np.clip(np.asarray(theta, dtype=float), eps, 1.0 - eps)

    codelen_x_z = 0.0
    codelen_z = 0.0

    for k in range(K):
        row_mask = Z1 == k
        for l in range(L):
            block = X[row_mask, :][:, Z2 == l]
            n_pos = np.sum(block == 1)
            n_neg = np.sum(block == 0)
            n_all = n_pos + n_neg

            codelen_x_z += -n_pos * np.log(theta[k, l]) - n_neg * np.log(1.0 - theta[k, l])

            if n_all >= 2:
                codelen_x_z += np.log(normterm_discrete(n_all, 2))

        n_k = np.sum(row_mask)
        codelen_z += -n_k * np.log(alpha[k])

    codelen_z += np.log(normterm_discrete(N, K))

    codelen = codelen_x_z + codelen_z

    return codelen, codelen_x_z, codelen_z

In [13]:
def calc_loglik_by_prob(X, Z, alpha, theta, K=3, L=3, eps=1e-12):
    """Negative log-likelihood of ``X`` and of the hard assignment.

    Despite the name, every accumulated term carries a minus sign, so the
    returned values are negative log-likelihoods (smaller is better).

    Parameters
    ----------
    X : 2-D array
        Entries equal to 1 and 0 are counted.
    Z : 2-D array
        Membership scores, one row per node; the hard assignment is the
        per-row argmax, compacted with ``check_latent_index_variable``.
        The same assignment is used for rows and columns of ``X``.
    alpha : 1-D array
        Cluster probabilities; ``alpha[k]`` is read for k in range(K).
    theta : 2-D array
        Block connection probabilities; ``theta[k, l]`` is read for
        k in range(K), l in range(L). The caller's array is NOT modified.
    K, L : int
        Number of row / column clusters that are iterated over.
    eps : float
        ``theta`` is clipped to ``[eps, 1 - eps]`` before taking logarithms.

    Returns
    -------
    (loglik, loglik_x_z, loglik_z) : tuple of float
        ``loglik_x_z = -sum n1 log theta - sum n0 log(1 - theta)``,
        ``loglik_z = -sum_k n_k log alpha[k]`` and their sum.
    """
    Z1 = check_latent_index_variable(np.argmax(Z, axis=1))
    Z2 = Z1

    # Clip on a private float copy instead of writing into the caller's array.
    theta = np.clip(np.asarray(theta, dtype=float), eps, 1.0 - eps)

    loglik_x_z = 0.0
    loglik_z = 0.0

    for k in range(K):
        row_mask = Z1 == k
        for l in range(L):
            block = X[row_mask, :][:, Z2 == l]
            n_pos = np.sum(block == 1)
            n_neg = np.sum(block == 0)

            loglik_x_z += -n_pos * np.log(theta[k, l]) - n_neg * np.log(1.0 - theta[k, l])

        n_k = np.sum(row_mask)
        loglik_z += -n_k * np.log(alpha[k])

    loglik = loglik_x_z + loglik_z

    return loglik, loglik_x_z, loglik_z

In [14]:
def calc_lsc(X, Z1, Z2, K=3, L=3):
    """Single code-length score for ``X`` under hard assignments ``Z1`` / ``Z2``.

    For every row cluster ``k`` in range(K) the score accumulates:
    * per column cluster, ``n log n - n1 log n1 - n0 log n0`` over the cells
      of the (k, l) block that equal 1 or 0;
    * ``n_k (log N - log n_k)`` for the cluster size ``n_k``;
    * a closed-form penalty in ``k`` and ``N`` built from ``loggamma``.

    NOTE(review): the penalty is added once per loop iteration using the loop
    index ``k`` rather than once with ``K`` -- confirm this is intended.

    Parameters
    ----------
    X : 2-D array
    Z1, Z2 : 1-D integer arrays
        Row and column cluster index of each node.
    K, L : int
        Number of row / column clusters that are iterated over.

    Returns
    -------
    float
    """
    N = X.shape[0]
    total = 0.0

    for k in range(K):
        row_mask = Z1 == k

        for col_label in range(L):
            block = X[row_mask, :][:, Z2 == col_label]
            ones = np.sum(block == 1)
            zeros = np.sum(block == 0)
            cells = ones + zeros

            # Counts below 2 contribute nothing (n log n is 0 for n in {0, 1}).
            if cells >= 2:
                total += cells * np.log(cells)
            if ones >= 2:
                total -= ones * np.log(ones)
            if zeros >= 2:
                total -= zeros * np.log(zeros)

        members = np.sum(row_mask)
        if members >= 1:
            total += members * (np.log(N) - np.log(members))

        total += (k + (k+1)*(k+2))/2 * np.log(N/(2.0*np.pi)) -(k+1)/2 * np.log(2.0) + \
           (k+1) * loggamma((k+3)/2) - loggamma((k+1)*(k+3)/2) + (k+1)*(k+2)/2 * np.log(np.pi)

    return total

In [15]:
def calc_stats(X, #z, 
               scores, scores_f, scores_l, h, delta, K=10):
    """Change statistic and selected model indices for every trial and time step.

    Parameters
    ----------
    X, delta :
        Accepted for interface compatibility; not used by the computation.
    scores, scores_f, scores_l : array_like, shape (N_trial, T, K)
        Code lengths of the whole / former / latter window per candidate model.
    h : int
        Half window width; only t in ``range(h, T - h)`` is evaluated.
    K : int
        Ignored: the number of candidate models is taken from
        ``scores.shape[2]``.

    Returns
    -------
    stats_complete, models_estimated, models_former, models_latter
        Four float arrays of shape (N_trial, T); entries outside
        ``range(h, T - h)`` stay NaN.
    """
    scores = np.array(scores)
    scores_f = np.array(scores_f)
    scores_l = np.array(scores_l)

    K = scores.shape[2]
    N_trial = scores.shape[0]
    T = scores.shape[1]

    codelens = np.array([ codelen_integer(k) for k in range(1, K+1)])

    # Broadcasting adds the per-model code length along the last axis.
    idxes_all = np.argmin(scores + codelens, axis=2)

    # Built-in `float` replaces the `np.float` alias removed in NumPy 1.24.
    models_estimated = np.full((N_trial, T), np.nan, dtype=float)
    models_former = np.full((N_trial, T), np.nan, dtype=float)
    models_latter = np.full((N_trial, T), np.nan, dtype=float)
    stats_complete = np.full((N_trial, T), np.nan, dtype=float)

    for trial in range(N_trial):
        n_change = 0  # number of changes so far.
        for t in range(h, T-h):
            # Running estimate of the probability that the model changes.
            alpha = (n_change+1/2) / (t+1+1)
            m_estimated = idxes_all[trial, t]

            # Lv.3 change (Model change)
            stats_half_t = np.zeros((K, K), dtype=float)
            for k1 in range(K):
                stats_former = scores_f[trial, t, k1]
                for k2 in range(K):
                    # Keeping the same model has probability 1 - alpha; the
                    # remaining mass is split evenly over the other models.
                    if k1 == k2:
                        p = 1.0 - alpha
                    else:
                        p = alpha/(K-1)
                    stats_latter = scores_l[trial, t, k2]
                    stats_half_t[k1, k2] = (stats_former + stats_latter) + codelens[k1] - np.log(p)
            m_former_estimated, m_latter_estimated = np.unravel_index(np.nanargmin(stats_half_t), (K, K))
            models_former[trial, t] = m_former_estimated
            models_latter[trial, t] = m_latter_estimated

            stat = 0.5 / h *(scores[trial, t, m_estimated] + codelens[m_estimated] - stats_half_t[m_former_estimated, m_latter_estimated])
            stats_complete[trial, t] = stat

            # The previous if/else picked m_estimated only when it equalled
            # m_latter_estimated, so the result is always the latter model.
            model_t = m_latter_estimated

            if t >= 1:
                # NOTE(review): at the first evaluated t the previous entry is
                # still NaN, so this comparison is True and n_change is
                # incremented -- confirm that is intended.
                model_prev = models_estimated[trial, t-1]
                if model_t != model_prev:
                    n_change += 1
            models_estimated[trial, t] = model_t

    return stats_complete, models_estimated, models_former, models_latter

In [16]:
def calc_stats_with_modelidx(scores, scores_f, scores_l, idxes_model, idxes_model_f, idxes_model_l, h):
    """Change statistic for already-chosen model indices.

    For every trial and every t in ``range(h, T - h)``::

        stat = 0.5 / h * (scores[trial, t, m] - (scores_f[trial, t, mf] + scores_l[trial, t, ml]))

    where ``m``, ``mf`` and ``ml`` are the entries of ``idxes_model``,
    ``idxes_model_f`` and ``idxes_model_l`` at ``[trial, t]`` cast to int.

    Parameters
    ----------
    scores, scores_f, scores_l : array_like, indexed as [trial, t, model]
    idxes_model, idxes_model_f, idxes_model_l : 2-D arrays, indexed as [trial, t]
    h : int
        Half window width.

    Returns
    -------
    numpy.ndarray
        Float array with the first two dimensions of ``idxes_model``; entries
        outside ``range(h, T - h)`` are NaN.
    """
    scores = np.array(scores)
    scores_f = np.array(scores_f)
    scores_l = np.array(scores_l)

    n_trial, n_steps = idxes_model.shape[0], idxes_model.shape[1]
    # Built-in `float` replaces the `np.float` alias removed in NumPy 1.24.
    stats_complete = np.full((n_trial, n_steps), np.nan, dtype=float)
    for trial in range(n_trial):
        for t in range(h, n_steps-h):
            whole = scores[trial, t, int(idxes_model[trial, t])]
            former = scores_f[trial, t, int(idxes_model_f[trial, t])]
            latter = scores_l[trial, t, int(idxes_model_l[trial, t])]
            stats_complete[trial, t] = 0.5/h * (whole - (former + latter))

    return stats_complete

In [17]:
def codelen_integer(k):
    """Code length assigned to the integer ``k``.

    Starts from ``log(2.865)`` and adds ``k``, ``log k``, ``log log k``, ...
    for as long as the current term is positive.

    The loop condition is a strict ``> 0``: a term equal to 0 adds nothing to
    the sum, and stopping there avoids evaluating ``np.log(0)`` (which for
    ``k = 1`` produced ``-inf`` and a RuntimeWarning). The returned value is
    the same as with the previous ``>= 0`` condition.

    NOTE(review): the first term added is ``k`` itself, not ``log k`` --
    confirm this is the intended code; the behaviour is left unchanged here.

    Parameters
    ----------
    k : int or float

    Returns
    -------
    float
    """
    codelen = np.log(2.865)
    while k > 0.0:
        codelen += k
        k = np.log(k)

    return codelen

In [18]:
# Load the input data from `indir`. X_all is later passed to the SBM fit as the
# adjacency data; Z_true_all is presumably the matching ground-truth
# assignment (it is not used in the visible cells) -- TODO confirm.
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open(os.path.join(indir, 'X_abrupt.pkl'), 'rb') as f:
    X_all = pickle.load(f)
with open(os.path.join(indir, 'Z_abrupt.pkl'), 'rb') as f:
    Z_true_all = pickle.load(f)

In [19]:
X_all.shape

(10, 80, 100, 100)

In [20]:
def estimate_sbm_each_trial(X, trial, K, T):
    """Fit a Bernoulli SBM to every snapshot of one trial with R's ``blockmodels``.

    Parameters
    ----------
    X : array indexed as ``X[trial, t, :, :]``
        Adjacency matrices.
    trial : int
        Trial index into ``X``.
    K : int
        Passed as both ``explore_min`` and ``explore_max``; results are read
        for ``k`` in range(K).
    T : int
        Number of snapshots to process.

    Returns
    -------
    pi_list_trial, theta_list_trial, z_list_trial : list of list
        Indexed ``[t][k]``. ``theta`` is the ``pi`` field of the k-th entry of
        the R object's ``model_parameters``; ``z`` is the ``Z`` field of the
        k-th entry of ``memberships``; ``pi`` is the column sum of ``z`` plus
        ``10 * EPS``, normalised to sum to 1.

    Notes
    -----
    No random seed is passed to the R fit, so repeated runs are not seeded
    from Python.
    """
    pi_list_trial = []
    theta_list_trial = []
    z_list_trial = []

    # Activate the numpy <-> R conversion once and always restore it, even if
    # the R call raises (previously it stayed active after an exception).
    numpy2ri.activate()
    try:
        for t in tqdm.tqdm(range(T)):
            sbm = blockmodels.BM_bernoulli(membership_type="SBM", 
                                           adj=np.array(X[trial, t, :, :]),
                                           verbosity=0,
                                           exploration_factor=1.5,
                                           explore_min=K,
                                           explore_max=K)

            estimate = dollar(sbm, "estimate")
            estimate()

            pi_list = []
            theta_list = []
            z_posterior_list = []
            for k in range(K):
                theta = np.array(dollar(dollar(sbm, "model_parameters")[k], "pi"))
                z_posterior = np.array(dollar(dollar(sbm, "memberships")[k], "Z"))
                # The small offset keeps every proportion strictly positive.
                pi = np.sum(z_posterior, axis=0) + 10 * EPS
                pi /= np.sum(pi)

                theta_list.append(theta)
                z_posterior_list.append(z_posterior)
                pi_list.append(pi)

            pi_list_trial.append(pi_list)
            theta_list_trial.append(theta_list)
            z_list_trial.append(z_posterior_list)
    finally:
        numpy2ri.deactivate()

    return pi_list_trial, theta_list_trial, z_list_trial

In [21]:
# Fit the SBM to every snapshot of every trial and checkpoint the results.
# Built-in `float` replaces the `np.float` alias removed in NumPy 1.24.
EPS = np.finfo(float).eps

N_trial = X_all.shape[0]
T = X_all.shape[1]

K = 10

# These names are not read by this cell; they are kept in case other cells use them.
pi1 = None
pi2 = None
a0 = 1.0
b0 = 1.0
ratio = 0.02

pi_all = []
theta_all = []
z_all = []

for trial in tqdm.tqdm(range(N_trial)):
    pi1 = None
    pi2 = None
    theta = None
    
    pi_list_trial = []
    theta_list_trial = []
    z_list_trial = []
    
    # Restore the rpy2 conversion state even if the R fit raises.
    numpy2ri.activate()
    try:
        for t in range(T):
            # Computed but not passed to the R fit, so runs are not seeded.
            seed = trial*T + t
            
            X = X_all[trial, t, :, :]
                
            sbm = blockmodels.BM_bernoulli(membership_type="SBM", adj=np.array(X),
                                           verbosity=0,
                                           exploration_factor=1.5,
                                           explore_min=K,
                                           explore_max=K)

            estimate = dollar(sbm, "estimate")
            estimate()

            theta_list = []
            pi_list = []
            z_posterior_list = []
            for k in range(K):
                n_clusters = k + 1
                theta = np.array(dollar(dollar(sbm, "model_parameters")[k], "pi"))
                z_posterior = np.array(dollar(dollar(sbm, "memberships")[k], "Z"))
                # The small offset keeps every proportion strictly positive.
                pi = np.sum(z_posterior, axis=0) + 10 * EPS
                pi /= np.sum(pi)

                theta_list.append(theta)
                z_posterior_list.append(z_posterior)
                pi_list.append(pi)
           
            pi_list_trial.append(pi_list)
            theta_list_trial.append(theta_list)
            z_list_trial.append(z_posterior_list)
    finally:
        numpy2ri.deactivate()

    pi_all.append(pi_list_trial)
    theta_all.append(theta_list_trial)
    z_all.append(z_list_trial)

    # Rewrite the pickles after every trial so a partial run can be recovered.
    with open(os.path.join(outdir, 'pi_abrupt.pkl'), 'wb') as f:
        pickle.dump(pi_all, f)
    with open(os.path.join(outdir, 'theta_abrupt.pkl'), 'wb') as f:
        pickle.dump(theta_all, f)
    with open(os.path.join(outdir, 'z_abrupt.pkl'), 'wb') as f:
        pickle.dump(z_all, f)

  0%|          | 0/10 [00:00<?, ?it/s]



















































































 10%|█         | 1/10 [07:59<1:11:58, 479.78s/it]



















































































 20%|██        | 2/10 [18:45<1:10:35, 529.50s/it]



















































































 30%|███       | 3/10 [27:29<1:01:36, 528.04s/it]



















































































 40%|████      | 4/10 [35:13<50:52, 508.68s/it]  



















































































 50%|█████     | 5/10 [43:11<41:38, 499.62s/it]



















































































 60%|██████    | 6/10 [49:30<30:52, 463.21s/it]



















































































 70%|███████   | 7/10 [57:37<23:31, 470.58s/it]



















































































 80%|████████  | 8/10 [1:07:51<17:06, 513.47s/it]



















































































 90%|█████████ | 9/10 [1:18:07<09:04, 544.35s/it]



















































































100%|██████████| 10/10 [1:27:37<00:00, 525.79s/it]


In [22]:
# Reload the estimates written by the estimation cell, so the cells below can
# run without repeating the fit.
# Note: the memberships are bound to `Z_all` (capital Z) here, while the
# estimation cell accumulates them in `z_all`.
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open(os.path.join(outdir, 'pi_abrupt.pkl'), 'rb') as f:
    pi_all = pickle.load(f)
with open(os.path.join(outdir, 'theta_abrupt.pkl'), 'rb') as f:
    theta_all = pickle.load(f)
with open(os.path.join(outdir, 'z_abrupt.pkl'), 'rb') as f:
    Z_all = pickle.load(f)

In [23]:
# relabeling
# Align cluster labels across time with label.switching::ecr_iterative_2.
# NOTE: Z_all is permuted in place, so running this cell twice applies the
# permutations twice; reload Z_all from disk before re-running.
N_trial = X_all.shape[0]
K = 10
T = X_all.shape[1]

# Half-open [start, end) time windows, processed in order. Each window begins
# on the last snapshot of the previous one, so consecutive windows share one
# snapshot. The boundaries are hard-coded for a series of 80 snapshots.
RELABEL_WINDOWS = [(0, 20), (19, 38), (37, 41), (40, 58), (57, 61), (60, 80)]

numpy2ri.activate()
try:
    for trial in tqdm.tqdm(range(N_trial)):
        for k in range(1, K):
            for start, end in RELABEL_WINDOWS:
                run = label_switching.ecr_iterative_2(
                    # R labels are 1-based, hence the +1 here and the -1 below.
                    z=np.vstack([
                          np.argmax(Z_all[trial][t][k], axis=1).reshape(1, -1) for t in range(start, end)
                    ]) + 1,
                    K=k+1,
                    p=np.stack([Z_all[trial][t][k] for t in range(start, end)])
                )
                permutations = np.array(dollar(run, "permutations"))

                for i, t in enumerate(range(start, end)):
                    if i == 0:
                        # The permutation found for the window's first snapshot
                        # is applied to that snapshot and to every earlier one.
                        # For the first window this is snapshot 0 only.
                        for tt in range(start + 1):
                            Z_all[trial][tt][k] = Z_all[trial][tt][k][:, permutations[i, :]-1]
                    else:
                        Z_all[trial][t][k] = Z_all[trial][t][k][:, permutations[i, :]-1]
finally:
    # Restore the rpy2 conversion state even if the R call raises.
    numpy2ri.deactivate()

100%|██████████| 10/10 [00:18<00:00,  1.84s/it]


In [24]:
# Overwrite z_abrupt.pkl with the current (relabelled) Z_all. The original,
# un-permuted memberships written by the estimation cell are replaced.
with open(os.path.join(outdir, 'z_abrupt.pkl'), 'wb') as f:
    pickle.dump(Z_all, f)

In [25]:
from itertools import product

In [26]:
def calc_sc(X_array, Z_array, K, L):
    """Brute-force accumulation over all per-snapshot count combinations.

    For every (k, l) cluster pair the block size ``n_T[t] = n_k * n_l`` is
    computed per snapshot, and every tuple ``n`` with ``0 <= n[t] <= n_T[t]``
    is enumerated. The number of tuples is the product of ``n_T[t] + 1`` over
    t, so the cost grows exponentially with T.

    NOTE(review): only the shape of ``X_array`` is used. The binomial term uses
    ``N`` (``X_array.shape[1]``) rather than ``n_T[t]``, and ``np.log(n_pos)``
    / ``np.log(n_neg)`` are evaluated at 0 for the extreme tuples -- confirm
    the intended formula before relying on the result.

    Parameters
    ----------
    X_array : array
        Only ``shape[0]`` (T) and ``shape[1]`` (N) are read.
    Z_array : 2-D integer array, indexed ``[t, node]``
    K, L : int
        Number of row / column clusters that are iterated over.

    Returns
    -------
    float
        The accumulated ``logp``.
    """
    T = X_array.shape[0]
    N = X_array.shape[1]
    
    print('T = %d' % (T))
    
    logp = 0.0
    for k in range(K):
        for l in range(L):
            # Built-in `int` replaces the `np.int` alias removed in NumPy 1.24.
            n_T = np.zeros(T, dtype=int)
            for t in range(T):
                n_k = np.sum(Z_array[t, :] == k)
                n_l = np.sum(Z_array[t, :] == l)
                n_T[t] = n_k * n_l

            # Does not depend on the enumerated tuple, so compute it once.
            n_all = np.sum(n_T)

            for n in tqdm.tqdm(product(*[np.arange(n_T[t]+1) for t in range(T)])):
                for t in range(T):
                    logp += loggamma(N+1) - loggamma(n[t]+1) - loggamma(N -n[t]+1)
                n_pos = np.sum(n)
                n_neg = n_all - n_pos
                logp += -n_pos * (np.log(n_pos) - np.log(n_all)) - n_neg * (np.log(n_neg) - np.log(n_all))
    return logp

In [27]:
n_trial = X_all.shape[0]
T = X_all.shape[1]
K = 10


def _window_codelengths(X_trial, Z_trial, k, tt_range):
    """Pool the snapshots in ``tt_range`` and score them with ``calc_dnml_by_prob``.

    ``X_trial[tt]`` is the adjacency matrix of snapshot ``tt`` and
    ``Z_trial[tt][k]`` its membership matrix for the k-th candidate model
    (labels 0..k after the per-row argmax). Cluster proportions and block
    densities are estimated from the pooled counts, clusters that are empty
    over the whole window are dropped, and the per-snapshot code lengths are
    summed.

    Returns ``(n_cluster, dnml, nml_x_z, nml_z)`` where ``n_cluster`` is the
    number of non-empty clusters.
    """
    n_labels = k + 1
    labels = [np.argmax(Z_trial[tt][k], axis=1) for tt in tt_range]

    # counts[i, v] = number of nodes with label v in the i-th snapshot.
    counts = np.array([[np.sum(lab == v) for v in range(n_labels)] for lab in labels])
    n_members = np.sum(counts, axis=0)

    # Pooled sum of X over each (k1, k2) block divided by the pooled number of
    # cells in that block. 0/0 yields NaN, as before.
    edge_sums = np.sum(
        [[[np.sum(X_trial[tt][lab == k1, :][:, lab == k2]) for k2 in range(n_labels)]
          for k1 in range(n_labels)]
         for tt, lab in zip(tt_range, labels)], axis=0)
    pair_counts = np.sum([np.outer(c, c) for c in counts], axis=0)
    theta_hat = edge_sums / pair_counts

    occupied = n_members != 0
    theta_hat = theta_hat[occupied, :][:, occupied]
    pi_hat = n_members[occupied] / np.sum(n_members)
    n_cluster = len(pi_hat)

    res = np.array([calc_dnml_by_prob(X_trial[tt], Z_trial[tt][k],
                                      pi_hat, theta_hat,
                                      n_cluster, n_cluster)
                    for tt in tt_range])

    return n_cluster, np.sum(res[:, 0]), np.sum(res[:, 1]), np.sum(res[:, 2])


def _keep_smaller(store, index, candidate):
    """Write ``candidate`` to ``store[index]`` if the slot is NaN or holds a larger finite value."""
    val = store[index]
    if np.isnan(val) or (np.isfinite(val) and candidate < val):
        store[index] = candidate


for h in tqdm.tqdm([1, 2, 3]):
    dnml_whole_list, nml_x_z_whole_list, nml_z_whole_list = \
        np.full((n_trial, T, K), np.nan), np.full((n_trial, T, K), np.nan), np.full((n_trial, T, K), np.nan)
    dnml_former_list, nml_x_z_former_list, nml_z_former_list = \
        np.full((n_trial, T, K), np.nan), np.full((n_trial, T, K), np.nan), np.full((n_trial, T, K), np.nan)
    dnml_latter_list, nml_x_z_latter_list, nml_z_latter_list = \
        np.full((n_trial, T, K), np.nan), np.full((n_trial, T, K), np.nan), np.full((n_trial, T, K), np.nan)

    # These arrays are allocated and saved but never filled by this cell,
    # so they are written out as all-NaN.
    theta_hat_whole_list_k3, theta_hat_former_list_k3, theta_hat_latter_list_k3 = \
        np.full((n_trial, T, 3, 3), np.nan), np.full((n_trial, T, 3, 3), np.nan), np.full((n_trial, T, 3, 3), np.nan)

    theta_hat_whole_list_k4, theta_hat_former_list_k4, theta_hat_latter_list_k4 = \
        np.full((n_trial, T, 4, 4), np.nan), np.full((n_trial, T, 4, 4), np.nan), np.full((n_trial, T, 4, 4), np.nan)

    for trial in tqdm.tqdm(range(n_trial)):
        for t in range(h, T-h):
            # Whole window [t-h, t+h), former half [t-h, t), latter half [t, t+h),
            # each paired with the arrays that receive its results.
            windows = (
                (range(t-h, t+h), dnml_whole_list, nml_x_z_whole_list, nml_z_whole_list),
                (range(t-h, t), dnml_former_list, nml_x_z_former_list, nml_z_former_list),
                (range(t, t+h), dnml_latter_list, nml_x_z_latter_list, nml_z_latter_list),
            )
            for k in range(K):
                for tt_range, dnml_store, nml_x_z_store, nml_z_store in windows:
                    n_cluster, dnml, nml_x_z, nml_z = _window_codelengths(
                        X_all[trial], Z_all[trial], k, tt_range)

                    # Results are filed under the number of non-empty clusters,
                    # keeping the smallest value seen for that slot. Each array
                    # is compared against itself: the previous code compared
                    # the latter nml_z against nml_z_whole_list by mistake.
                    slot = (trial, t, n_cluster - 1)
                    _keep_smaller(dnml_store, slot, dnml)
                    _keep_smaller(nml_x_z_store, slot, nml_x_z)
                    _keep_smaller(nml_z_store, slot, nml_z)

    # One pickle per array; the file name is <stem><h>.pkl.
    results_by_stem = [
        ('dnml_all_h', dnml_whole_list),
        ('dnml_f_h', dnml_former_list),
        ('dnml_l_h', dnml_latter_list),
        ('nml_x_z_all_h', nml_x_z_whole_list),
        ('nml_x_z_f_h', nml_x_z_former_list),
        ('nml_x_z_l_h', nml_x_z_latter_list),
        ('nml_z_all_h', nml_z_whole_list),
        ('nml_z_f_h', nml_z_former_list),
        ('nml_z_l_h', nml_z_latter_list),
        ('theta_hat_all_k3_h', theta_hat_whole_list_k3),
        ('theta_hat_former_k3_h', theta_hat_former_list_k3),
        ('theta_hat_latter_k3_h', theta_hat_latter_list_k3),
        ('theta_hat_all_k4_h', theta_hat_whole_list_k4),
        ('theta_hat_former_k4_h', theta_hat_former_list_k4),
        ('theta_hat_latter_k4_h', theta_hat_latter_list_k4),
    ]
    for stem, values in results_by_stem:
        with open(os.path.join(outdir, stem + str(h) + '.pkl'), 'wb') as f:
            pickle.dump(values, f)

  0%|          | 0/3 [00:00<?, ?it/s]

 10%|█         | 1/10 [00:22<03:22, 22.46s/it][A
 20%|██        | 2/10 [00:45<03:00, 22.60s/it][A
 30%|███       | 3/10 [01:05<02:34, 22.00s/it][A
 40%|████      | 4/10 [01:20<01:58, 19.75s/it][A
 50%|█████     | 5/10 [01:39<01:37, 19.53s/it][A
 60%|██████    | 6/10 [01:57<01:15, 18.95s/it][A
 70%|███████   | 7/10 [02:18<00:58, 19.56s/it][A
 80%|████████  | 8/10 [02:36<00:38, 19.30s/it][A
 90%|█████████ | 9/10 [02:56<00:19, 19.43s/it][A
100%|██████████| 10/10 [03:14<00:00, 19.46s/it][A
 33%|███▎      | 1/3 [03:14<06:29, 194.65s/it]
  0%|          | 0/10 [00:00<?, ?it/s][A
 10%|█         | 1/10 [00:30<04:36, 30.71s/it][A
 20%|██        | 2/10 [01:07<04:19, 32.45s/it][A
 30%|███       | 3/10 [01:44<03:56, 33.76s/it][A
 40%|████      | 4/10 [02:16<03:20, 33.44s/it][A
 50%|█████     | 5/10 [02:52<02:49, 33.99s/it][A
 60%|██████    | 6/10 [03:31<02:22, 35.61s/it][A
 70%|███████   | 7/10 [04:24<02:02, 40.83s/it][A
 80%|████████  | 8/10