In [None]:
%matplotlib inline
import numpy as np
import pandas as pd
import networkx as nx

import matplotlib.pyplot as plt
import seaborn as sns
sns.reset_orig()

from matplotlib import rc
# NOTE: usetex=True requires a working LaTeX installation on the machine running the notebook.
rc('text', usetex=True)
rc('font', **{'family' : "sans-serif"})
# `text.latex.preamble` is a single string in current matplotlib (list values were
# deprecated in 3.3); a newline-joined string is also accepted by older versions.
params = {'text.latex.preamble': '\n'.join([r'\usepackage{siunitx}', r'\usepackage{amsmath}'])}
plt.rcParams.update(params)

from scipy.stats import bernoulli
from scipy.special import loggamma

from sklearn.preprocessing import OneHotEncoder

import tqdm

from rpy2.robjects import numpy2ri
from rpy2.robjects.packages import importr

In [None]:
import os
import pickle

In [None]:
from joblib import Parallel, delayed

In [None]:
import warnings
warnings.resetwarnings()
warnings.simplefilter('ignore', UserWarning)

In [None]:
# Machine epsilon of float64. `np.float` was only an alias of the builtin `float`
# and was removed in NumPy 1.24.
EPS = np.finfo(float).eps

In [None]:
# Show up to 200 rows and 200 columns before pandas truncates a frame's repr.
pd.set_option('display.max_rows', 200, 'display.max_columns', 200)

In [None]:
indir = './data'    # input pickles are read from here
outdir = './output' # fitted parameters are written here
# exist_ok=True avoids the check-then-create race of `if not exists: makedirs`,
# and raises instead of silently continuing if `outdir` exists as a regular file.
os.makedirs(outdir, exist_ok=True)

In [None]:
# R packages loaded through rpy2's importr.
blockmodels = importr("blockmodels")          # provides BM_bernoulli (SBM fitting)
label_switching = importr("label.switching")
base = importr("base")

# R's `$` member accessor is not a valid Python identifier, so fetch it by name
# from the imported base package; used as dollar(obj, "field").
dollar = getattr(base, "$")

In [None]:
from functools import lru_cache

@lru_cache(maxsize = 10000)
def normterm_discrete(n, k):
    """Parametric complexity C(n, k) of a k-valued multinomial model over n samples.

    This returns the NML normalising sum itself (callers take ``np.log`` of it):

        C(n, k) = sum_{h_1+...+h_k = n} n! / (h_1! ... h_k!) * prod_j (h_j / n)^{h_j}

    For k >= 3 it uses the recursion C(n, k) = C(n, k-1) + n/(k-2) * C(n, k-2),
    starting from the base cases C(n, 1) = 1 and C(n, 2).

    Parameters
    ----------
    n : int
        Number of samples (>= 0).
    k : int
        Number of categories (>= 1).

    Raises
    ------
    ValueError
        If k < 1.
    """
    if k < 1:
        raise ValueError("k must be >= 1, got %r" % (k,))
    if k == 1 or n == 0:
        # One category, or no samples: a single sequence with maximised likelihood 1.
        return 1.0
    if n == 1:
        # One sample: k possible outcomes, each with maximised likelihood 1.
        return float(k)
    if k == 2:
        # Interior terms t = 1..n-1 evaluated in log space so the binomial
        # coefficient cannot overflow. The endpoint terms t = 0 and t = n are each
        # exactly 1 (0^0 = 1) and are added explicitly, because t*log(t) is
        # undefined at t = 0.
        t = np.arange(1, n)
        log_terms = (loggamma(n + 1) - loggamma(t + 1) - loggamma(n - t + 1)
                     + t * (np.log(t) - np.log(n))
                     + (n - t) * (np.log(n - t) - np.log(n)))
        # Sum the small terms first for accuracy.
        return 2.0 + np.sum(np.sort(np.exp(log_terms)))
    return normterm_discrete(n, k - 1) + n / (k - 2) * normterm_discrete(n, k - 2)

In [10]:
def check_latent_index_variable(z):
    """Relabel integer cluster labels so that they are exactly 0..m-1.

    If the labels present in ``z`` already are 0..max(z) with none missing, ``z``
    is returned unchanged (same object). Otherwise every label is replaced by its
    rank among the sorted distinct labels, e.g. [0, 2, 2, 5] -> [0, 1, 1, 2].
    Empty input is returned unchanged.

    Parameters
    ----------
    z : array-like of int

    Returns
    -------
    ``z`` itself, or a new int array with the same shape as ``z``.
    """
    arr = np.asarray(z)
    if arr.size == 0:
        # np.max would raise on an empty array; there is nothing to relabel.
        return z
    unique_z, inverse = np.unique(arr, return_inverse=True)
    if len(unique_z) == np.max(arr) + 1:
        return z
    # `inverse` holds each element's index into the sorted unique labels; reshape
    # because older NumPy versions return it flattened.
    return inverse.reshape(arr.shape).astype(int)

In [11]:
def calc_dnml(X, Z1, Z2, K=3, L=3):
    """Decomposed NML code length of a binary matrix under a hard block assignment.

    For each (row block k, column block l) the entries of X equal to 1 and to 0
    are counted (n_pos, n_neg; n_all = n_pos + n_neg) and contribute

        n_all log n_all - n_pos log n_pos - n_neg log n_neg + log C(n_all, 2)

    to the data part (count*log(count) terms with count < 2 are zero and skipped;
    the complexity term is only added when n_all >= 2). The label part is
    sum_k n_k (log N - log n_k) + log C(N, K), with N = X.shape[0].

    Parameters
    ----------
    X : 2-D array; entries equal to 0 or 1 are counted, anything else is ignored.
    Z1, Z2 : integer label arrays selecting rows / columns of X.
    K, L : number of row / column labels to iterate over.

    Returns
    -------
    (codelen, codelen_x_z, codelen_z) with codelen = codelen_x_z + codelen_z.
    """
    N = X.shape[0]

    def _xlogx(count):
        # count * log(count); zero for counts below 2 (log 1 == 0).
        if count >= 2:
            return count * np.log(count)
        return 0.0

    codelen_x_z = 0.0
    codelen_z = 0.0

    for k in range(K):
        row_slice = X[Z1 == k, :]
        for l in range(L):
            block = row_slice[:, Z2 == l]
            n_pos = np.sum(block == 1)
            n_neg = np.sum(block == 0)
            n_all = n_pos + n_neg

            if n_all >= 2:
                codelen_x_z += _xlogx(n_all)
                codelen_x_z += np.log(normterm_discrete(n_all, 2))
            codelen_x_z -= _xlogx(n_pos)
            codelen_x_z -= _xlogx(n_neg)

        size_k = np.sum(Z1 == k)
        if size_k >= 1:
            codelen_z += size_k * (np.log(N) - np.log(size_k))

    codelen_z += np.log(normterm_discrete(N, K))

    total = codelen_x_z + codelen_z
    return total, codelen_x_z, codelen_z

In [12]:
def calc_dnml_by_prob(X, Z, alpha, theta, K=3, L=3, eps=1e-12):
    """Code length of a binary matrix under estimated SBM parameters.

    Hard labels are taken as the argmax of the membership matrix ``Z`` (one row
    per node). For every (row block k, column block l) the entries of X equal to
    1 / 0 are coded with probability theta[k, l] / 1 - theta[k, l], plus the
    complexity term log C(n_all, 2) when the block holds at least 2 entries.
    Labels are coded with -n_k log alpha[k] plus log C(N, K).

    Labels are deliberately NOT compacted to 0..m-1: they must keep indexing
    the rows/columns of ``theta`` and the entries of ``alpha`` even when some
    cluster is empty.

    Parameters
    ----------
    X : 2-D array; entries equal to 0 or 1 are counted.
    Z : 2-D array of membership weights; argmax over axis 1 gives the label.
    alpha : sequence of cluster proportions, indexed by label.
    theta : (K, L) array of Bernoulli probabilities. It is not modified.
    K, L : number of row / column labels to iterate over.
    eps : theta is clipped to [eps, 1 - eps] so the logs stay finite.

    Returns
    -------
    (codelen, codelen_x_z, codelen_z) with codelen = codelen_x_z + codelen_z.
    """
    N = X.shape[0]

    Z1 = np.argmax(Z, axis=1)
    Z2 = Z1

    # Clip a copy: the original clipped the caller's array in place.
    theta_clipped = np.clip(np.asarray(theta, dtype=float), eps, 1.0 - eps)

    codelen_x_z = 0.0
    codelen_z = 0.0

    for k in range(K):
        row_slice = X[Z1 == k, :]
        for l in range(L):
            block = row_slice[:, Z2 == l]
            n_pos = np.sum(block == 1)
            n_neg = np.sum(block == 0)
            n_all = n_pos + n_neg

            p = theta_clipped[k, l]
            codelen_x_z += -n_pos * np.log(p) - n_neg * np.log(1.0 - p)

            if n_all >= 2:
                codelen_x_z += np.log(normterm_discrete(n_all, 2))

        n_k = np.sum(Z1 == k)
        # Skip empty clusters: 0 * log(0) would otherwise produce NaN.
        if n_k > 0:
            codelen_z += -n_k * np.log(alpha[k])

    codelen_z += np.log(normterm_discrete(N, K))

    codelen = codelen_x_z + codelen_z

    return codelen, codelen_x_z, codelen_z

In [13]:
def calc_stats(X, scores, scores_f, scores_l, h, delta, K=10):
    """Model-change statistics over sliding windows of half-width h.

    ``scores``, ``scores_f`` and ``scores_l`` are indexed [trial, t, model]. Here
    scores_f / scores_l are the scores of the former / latter half windows. The
    number of candidate models K is taken from ``scores.shape[2]`` (the ``K``
    argument is ignored, as are ``X`` and ``delta``; they are kept for
    signature compatibility).

    At each t in [h, T - h) the function:
      * picks the whole-window model minimising score + integer code length;
      * picks the (former, latter) model pair minimising
        score_f[k1] + score_l[k2] + codelen(k1) - log p(k1 -> k2). Here p uses a
        switching rate alpha = (n_change + 1/2) / (t + 2), where n_change counts
        the model changes seen so far in this trial;
      * stores stat = (best whole-window length - best split length) / (2h).

    Returns
    -------
    stats_complete, models_estimated, models_former, models_latter :
        float arrays of shape (N_trial, T), NaN outside [h, T - h).
    """
    scores = np.array(scores)
    scores_f = np.array(scores_f)
    scores_l = np.array(scores_l)

    N_trial, T, K = scores.shape

    codelens = np.array([codelen_integer(k) for k in range(1, K + 1)])

    # Whole-window model: score plus code length of the model index (broadcast over model axis).
    idxes_all = np.argmin(scores + codelens, axis=2)

    shape = (N_trial, T)
    models_estimated = np.full(shape, np.nan)
    models_former = np.full(shape, np.nan)
    models_latter = np.full(shape, np.nan)
    stats_complete = np.full(shape, np.nan)

    for trial in range(N_trial):
        n_change = 0  # number of model changes so far in this trial
        for t in range(h, T - h):
            alpha = (n_change + 0.5) / (t + 2)
            m_estimated = idxes_all[trial, t]

            # Lv.3 change (model change): log-probability of moving from model k1
            # to model k2; staying has prob 1 - alpha, each switch alpha / (K - 1).
            if K > 1:
                log_p = np.full((K, K), np.log(alpha / (K - 1)))
            else:
                log_p = np.zeros((K, K))
            np.fill_diagonal(log_p, np.log(1.0 - alpha))
            stats_half_t = (scores_f[trial, t, :, np.newaxis] + scores_l[trial, t, np.newaxis, :]) \
                + codelens[:, np.newaxis] - log_p

            m_former_estimated, m_latter_estimated = np.unravel_index(np.nanargmin(stats_half_t), (K, K))
            models_former[trial, t] = m_former_estimated
            models_latter[trial, t] = m_latter_estimated

            stat = 0.5 / h * (scores[trial, t, m_estimated] + codelens[m_estimated]
                              - stats_half_t[m_former_estimated, m_latter_estimated])
            stats_complete[trial, t] = stat

            # The current model is the one selected for the latter half window.
            model_t = m_latter_estimated

            # Only compare against a previous estimate that exists: at t == h the
            # slot t-1 is still NaN, and `x != NaN` would count a phantom change.
            if t > h and model_t != models_estimated[trial, t - 1]:
                n_change += 1
            models_estimated[trial, t] = model_t

    return stats_complete, models_estimated, models_former, models_latter

In [14]:
def calc_stats_with_modelidx(scores, scores_f, scores_l, idxes_model, idxes_model_f, idxes_model_l, h):
    """Change statistics for externally chosen model indices.

    For each trial and each t in [h, T - h):

        stat = (scores[t, m] - (scores_f[t, m_f] + scores_l[t, m_l])) / (2h)

    where m, m_f, m_l are read from idxes_model / idxes_model_f / idxes_model_l
    (float-valued indices are cast with int()). Score arrays are indexed
    [trial, t, model]; index arrays are indexed [trial, t].

    Returns
    -------
    float array of shape (idxes_model.shape[0], idxes_model.shape[1]), NaN
    outside [h, T - h).
    """
    scores = np.array(scores)
    scores_f = np.array(scores_f)
    scores_l = np.array(scores_l)
    idxes_model = np.asarray(idxes_model)

    n_trial, T = idxes_model.shape[:2]
    stats_complete = np.full((n_trial, T), np.nan)
    for trial in range(n_trial):
        for t in range(h, T - h):
            whole = scores[trial, t, int(idxes_model[trial, t])]
            former = scores_f[trial, t, int(idxes_model_f[trial, t])]
            latter = scores_l[trial, t, int(idxes_model_l[trial, t])]
            stats_complete[trial, t] = 0.5 / h * (whole - (former + latter))

    return stats_complete

In [15]:
def codelen_integer(k):
    """Universal code length of a positive integer k (Rissanen's log* code).

    Returns log(2.865) + log k + log log k + ..., summing only the positive
    terms of the iterated logarithm (natural logs are used throughout).

    NOTE(review): 2.865 is the normalising constant usually quoted for base-2
    iterated logarithms, while this function uses natural logs — confirm intended.
    """
    codelen = np.log(2.865)
    term = np.log(k)
    # Stop at the first non-positive term. Zero terms add nothing, and stopping
    # there avoids np.log(0) (a divide-by-zero RuntimeWarning), e.g. for k == 1.
    while term > 0.0:
        codelen += term
        term = np.log(term)

    return codelen

In [16]:
# Adjacency tensors and ground-truth labels produced by the data-generation step.
# NOTE: pickle.load can execute arbitrary code — only load files you created.
x_path = os.path.join(indir, 'X_gradual.pkl')
z_path = os.path.join(indir, 'Z_gradual.pkl')
with open(x_path, 'rb') as x_file:
    X_all = pickle.load(x_file)
with open(z_path, 'rb') as z_file:
    Z_true_all = pickle.load(z_file)

In [17]:
# Indexed later as X_all[trial, t, :, :] — (n_trials, n_timesteps, n_nodes, n_nodes).
np.shape(X_all)

(8, 90, 100, 100)

In [18]:
def estimate_sbm_each_trial(X, trial, K, T):
    """Fit Bernoulli SBMs with the R package blockmodels at every time step of one trial.

    Parameters
    ----------
    X : array indexed as X[trial, t, :, :] giving the adjacency matrix at time t.
    trial : int
        Index of the trial to process.
    K : int
        blockmodels is run with explore_min = explore_max = K, and the first K
        fitted models (list positions 0..K-1) are extracted.
    T : int
        Number of time steps to process.

    Returns
    -------
    (pi_list_trial, theta_list_trial, z_list_trial), each a list over t of lists
    over the K fitted models:
      pi    -- posterior column sums of Z plus 10*EPS, normalised to sum to 1
      theta -- the R field `pi` of model_parameters (presumably the block
               connection-probability matrix)
      z     -- the R field `Z` of memberships (posterior membership matrix)

    NOTE(review): no random seed is passed to R; if blockmodels' initialisation
    is stochastic, results may differ between runs.
    """
    pi_list_trial = []
    theta_list_trial = []
    z_list_trial = []

    numpy2ri.activate()
    try:
        for t in tqdm.tqdm(range(T)):
            sbm = blockmodels.BM_bernoulli(membership_type="SBM",
                                           adj=np.array(X[trial, t, :, :]),
                                           verbosity=0,
                                           exploration_factor=1.5,
                                           explore_min=K,
                                           explore_max=K)

            estimate = dollar(sbm, "estimate")
            estimate()

            pi_list = []
            theta_list = []
            z_posterior_list = []
            for k in range(K):
                theta = np.array(dollar(dollar(sbm, "model_parameters")[k], "pi"))
                z_posterior = np.array(dollar(dollar(sbm, "memberships")[k], "Z"))
                # Smooth with 10*EPS so no proportion is exactly zero, then normalise.
                pi = np.sum(z_posterior, axis=0) + 10 * EPS
                pi /= np.sum(pi)

                theta_list.append(theta)
                z_posterior_list.append(z_posterior)
                pi_list.append(pi)

            pi_list_trial.append(pi_list)
            theta_list_trial.append(theta_list)
            z_list_trial.append(z_posterior_list)
    finally:
        # Restore rpy2's default conversion even if an R call fails.
        numpy2ri.deactivate()

    return pi_list_trial, theta_list_trial, z_list_trial

In [19]:
# Machine epsilon of float64 (`np.float` was removed in NumPy 1.24).
EPS = np.finfo(float).eps

N_trial = X_all.shape[0]  # number of trials
T = X_all.shape[1]        # number of time steps per trial

K = 10  # blockmodels is run with explore_min = explore_max = K; models 0..K-1 are kept

# NOTE(review): pi1, pi2, a0, b0 and ratio are not used in this cell; kept in case
# later cells rely on them.
pi1 = None
pi2 = None
a0 = 1.0
b0 = 1.0
ratio = 0.02

pi_all = []
theta_all = []
z_all = []

for trial in range(N_trial):
    pi_list_trial = []
    theta_list_trial = []
    z_list_trial = []

    numpy2ri.activate()
    try:
        for t in tqdm.tqdm(range(T)):
            X = X_all[trial, t, :, :]

            # NOTE(review): no random seed is passed to R for this fit.
            sbm = blockmodels.BM_bernoulli(membership_type="SBM", adj=np.array(X),
                                           verbosity=0,
                                           exploration_factor=1.5,
                                           explore_min=K,
                                           explore_max=K)

            estimate = dollar(sbm, "estimate")
            estimate()

            theta_list = []
            pi_list = []
            z_posterior_list = []
            for k in range(K):
                theta = np.array(dollar(dollar(sbm, "model_parameters")[k], "pi"))
                z_posterior = np.array(dollar(dollar(sbm, "memberships")[k], "Z"))
                # Smooth with 10*EPS so no proportion is exactly zero, then normalise.
                pi = np.sum(z_posterior, axis=0) + 10 * EPS
                pi /= np.sum(pi)

                theta_list.append(theta)
                z_posterior_list.append(z_posterior)
                pi_list.append(pi)

            pi_list_trial.append(pi_list)
            theta_list_trial.append(theta_list)
            z_list_trial.append(z_posterior_list)
    finally:
        # Restore rpy2's default conversion even if an R fit fails.
        numpy2ri.deactivate()

    pi_all.append(pi_list_trial)
    theta_all.append(theta_list_trial)
    z_all.append(z_list_trial)

    # Rewrite the cumulative results after every trial (checkpoint).
    with open(os.path.join(outdir, 'pi_gradual.pkl'), 'wb') as f:
        pickle.dump(pi_all, f)
    with open(os.path.join(outdir, 'theta_gradual.pkl'), 'wb') as f:
        pickle.dump(theta_all, f)
    with open(os.path.join(outdir, 'z_gradual.pkl'), 'wb') as f:
        pickle.dump(z_all, f)

  0%|          | 0/90 [00:00<?, ?it/s]




  1%|          | 1/90 [00:10<16:14, 10.94s/it]




  2%|▏         | 2/90 [00:19<15:12, 10.37s/it]




  3%|▎         | 3/90 [00:28<14:22,  9.91s/it]




  4%|▍         | 4/90 [00:36<13:02,  9.10s/it]




  6%|▌         | 5/90 [00:43<12:02,  8.50s/it]




  7%|▋         | 6/90 [00:51<11:58,  8.56s/it]




  8%|▊         | 7/90 [00:59<11:26,  8.28s/it]




  9%|▉         | 8/90 [01:04<10:06,  7.40s/it]




 10%|█         | 9/90 [01:11<09:53,  7.32s/it]




 11%|█         | 10/90 [01:18<09:15,  6.95s/it]




 12%|█▏        | 11/90 [01:25<09:22,  7.12s/it]




 13%|█▎        | 12/90 [01:33<09:35,  7.37s/it]




 14%|█▍        | 13/90 [01:44<10:58,  8.56s/it]




 16%|█▌        | 14/90 [01:49<09:27,  7.47s/it]




 17%|█▋        | 15/90 [01:56<09:06,  7.29s/it]




 18%|█▊        | 16/90 [02:02<08:18,  6.74s/it]




 19%|█▉        | 17/90 [02:05<07:08,  5.87s/it]




 20%|██        | 18/90 [02:11<06:54,  5.76s/it]




 21%|██        | 19/90 [02:15<06:12,  5.24s/it]




 22%|██▏       | 20/90 [02:19<05:51,  5.03s/it]




 23%|██▎       | 21/90 [02:24<05:37,  4.88s/it]




 24%|██▍       | 22/90 [02:29<05:38,  4.97s/it]




 26%|██▌       | 23/90 [02:37<06:36,  5.92s/it]




 27%|██▋       | 24/90 [02:44<06:51,  6.23s/it]




 28%|██▊       | 25/90 [02:50<06:30,  6.00s/it]




 29%|██▉       | 26/90 [02:55<06:18,  5.92s/it]




 30%|███       | 27/90 [03:02<06:20,  6.03s/it]




 31%|███       | 28/90 [03:07<06:03,  5.86s/it]




 32%|███▏      | 29/90 [03:14<06:10,  6.08s/it]




 33%|███▎      | 30/90 [03:19<05:55,  5.93s/it]




 34%|███▍      | 31/90 [03:26<06:05,  6.20s/it]




 36%|███▌      | 32/90 [03:36<06:57,  7.20s/it]




 37%|███▋      | 33/90 [03:42<06:40,  7.02s/it]




 38%|███▊      | 34/90 [03:49<06:29,  6.95s/it]




 39%|███▉      | 35/90 [03:54<05:45,  6.28s/it]




 40%|████      | 36/90 [03:58<05:07,  5.69s/it]




 41%|████      | 37/90 [04:03<04:41,  5.31s/it]




 42%|████▏     | 38/90 [04:08<04:29,  5.19s/it]




 43%|████▎     | 39/90 [04:12<04:14,  5.00s/it]




 44%|████▍     | 40/90 [04:18<04:25,  5.31s/it]




 46%|████▌     | 41/90 [04:25<04:48,  5.89s/it]




 47%|████▋     | 42/90 [04:31<04:35,  5.74s/it]




 48%|████▊     | 43/90 [04:38<04:57,  6.33s/it]




 49%|████▉     | 44/90 [04:44<04:46,  6.23s/it]




 50%|█████     | 45/90 [04:50<04:29,  5.98s/it]




 51%|█████     | 46/90 [04:55<04:07,  5.62s/it]




 52%|█████▏    | 47/90 [05:03<04:40,  6.51s/it]




 53%|█████▎    | 48/90 [05:08<04:12,  6.02s/it]




 54%|█████▍    | 49/90 [05:13<03:57,  5.80s/it]




 56%|█████▌    | 50/90 [05:19<03:48,  5.71s/it]




 57%|█████▋    | 51/90 [05:25<03:45,  5.79s/it]




 58%|█████▊    | 52/90 [05:32<03:52,  6.11s/it]




 59%|█████▉    | 53/90 [05:36<03:30,  5.68s/it]




 60%|██████    | 54/90 [05:40<03:05,  5.16s/it]




 61%|██████    | 55/90 [05:45<03:00,  5.16s/it]




 62%|██████▏   | 56/90 [05:53<03:16,  5.79s/it]




 63%|██████▎   | 57/90 [06:00<03:29,  6.35s/it]




 64%|██████▍   | 58/90 [06:06<03:12,  6.02s/it]




 66%|██████▌   | 59/90 [06:11<03:02,  5.90s/it]




 67%|██████▋   | 60/90 [06:17<02:53,  5.78s/it]




 68%|██████▊   | 61/90 [06:21<02:32,  5.28s/it]




 69%|██████▉   | 62/90 [06:27<02:38,  5.67s/it]




 70%|███████   | 63/90 [06:31<02:19,  5.17s/it]




 71%|███████   | 64/90 [06:36<02:12,  5.11s/it]




 72%|███████▏  | 65/90 [06:43<02:18,  5.56s/it]




 73%|███████▎  | 66/90 [06:51<02:28,  6.19s/it]




 74%|███████▍  | 67/90 [06:59<02:39,  6.94s/it]




 76%|███████▌  | 68/90 [07:04<02:19,  6.33s/it]




 77%|███████▋  | 69/90 [07:09<02:05,  5.98s/it]




 78%|███████▊  | 70/90 [07:14<01:50,  5.54s/it]




 79%|███████▉  | 71/90 [07:20<01:47,  5.64s/it]




 80%|████████  | 72/90 [07:26<01:43,  5.74s/it]




 81%|████████  | 73/90 [07:32<01:40,  5.90s/it]




 82%|████████▏ | 74/90 [07:37<01:31,  5.74s/it]




 83%|████████▎ | 75/90 [07:43<01:24,  5.61s/it]




 84%|████████▍ | 76/90 [07:49<01:22,  5.88s/it]




 86%|████████▌ | 77/90 [07:55<01:14,  5.74s/it]




 87%|████████▋ | 78/90 [08:03<01:19,  6.60s/it]




 88%|████████▊ | 79/90 [08:07<01:01,  5.62s/it]




 89%|████████▉ | 80/90 [08:13<00:59,  5.94s/it]




 90%|█████████ | 81/90 [08:18<00:49,  5.48s/it]




 91%|█████████ | 82/90 [08:23<00:43,  5.43s/it]




 92%|█████████▏| 83/90 [08:28<00:36,  5.25s/it]




 93%|█████████▎| 84/90 [08:32<00:28,  4.76s/it]




 94%|█████████▍| 85/90 [08:36<00:23,  4.66s/it]




 96%|█████████▌| 86/90 [08:42<00:19,  4.99s/it]




 97%|█████████▋| 87/90 [08:48<00:16,  5.45s/it]




 98%|█████████▊| 88/90 [08:54<00:10,  5.43s/it]




 99%|█████████▉| 89/90 [09:00<00:05,  5.61s/it]




100%|██████████| 90/90 [09:08<00:00,  6.10s/it]
  0%|          | 0/90 [00:00<?, ?it/s]




  1%|          | 1/90 [00:08<13:03,  8.81s/it]




  2%|▏         | 2/90 [00:14<11:43,  7.99s/it]




  3%|▎         | 3/90 [00:24<12:24,  8.55s/it]




  4%|▍         | 4/90 [00:30<10:51,  7.58s/it]




  6%|▌         | 5/90 [00:35<09:46,  6.89s/it]




  7%|▋         | 6/90 [00:43<10:18,  7.36s/it]




  8%|▊         | 7/90 [00:47<08:46,  6.34s/it]




  9%|▉         | 8/90 [00:54<08:52,  6.50s/it]




 10%|█         | 9/90 [00:58<07:33,  5.60s/it]




 11%|█         | 10/90 [01:04<07:57,  5.97s/it]




 12%|█▏        | 11/90 [01:11<08:14,  6.26s/it]




 13%|█▎        | 12/90 [01:17<08:02,  6.18s/it]




 14%|█▍        | 13/90 [01:23<07:49,  6.10s/it]




 16%|█▌        | 14/90 [01:30<07:50,  6.19s/it]




 17%|█▋        | 15/90 [01:34<07:04,  5.67s/it]




 18%|█▊        | 16/90 [01:38<06:14,  5.06s/it]




 19%|█▉        | 17/90 [01:44<06:42,  5.51s/it]




 20%|██        | 18/90 [01:50<06:49,  5.68s/it]




 21%|██        | 19/90 [01:59<07:51,  6.65s/it]




 22%|██▏       | 20/90 [02:09<08:45,  7.51s/it]




 23%|██▎       | 21/90 [02:18<09:14,  8.03s/it]




 24%|██▍       | 22/90 [02:27<09:21,  8.25s/it]




 26%|██▌       | 23/90 [02:40<10:49,  9.70s/it]




 27%|██▋       | 24/90 [02:45<09:04,  8.24s/it]




 28%|██▊       | 25/90 [02:54<09:07,  8.42s/it]




 29%|██▉       | 26/90 [02:58<07:49,  7.34s/it]




 30%|███       | 27/90 [03:04<07:09,  6.82s/it]




 31%|███       | 28/90 [03:10<06:48,  6.59s/it]




 32%|███▏      | 29/90 [03:22<08:20,  8.21s/it]




 33%|███▎      | 30/90 [03:29<07:46,  7.78s/it]




 34%|███▍      | 31/90 [03:40<08:40,  8.83s/it]




 36%|███▌      | 32/90 [03:47<07:51,  8.12s/it]




 37%|███▋      | 33/90 [03:54<07:24,  7.79s/it]




 38%|███▊      | 34/90 [04:06<08:27,  9.07s/it]




 39%|███▉      | 35/90 [04:12<07:38,  8.34s/it]




 40%|████      | 36/90 [04:19<07:00,  7.78s/it]




 41%|████      | 37/90 [04:24<06:17,  7.12s/it]




 42%|████▏     | 38/90 [04:29<05:26,  6.27s/it]




 43%|████▎     | 39/90 [04:36<05:28,  6.45s/it]




 44%|████▍     | 40/90 [04:42<05:17,  6.36s/it]




 46%|████▌     | 41/90 [04:50<05:38,  6.91s/it]




 47%|████▋     | 42/90 [04:55<05:02,  6.30s/it]




 48%|████▊     | 43/90 [05:02<05:07,  6.55s/it]




 49%|████▉     | 44/90 [05:10<05:19,  6.95s/it]




 50%|█████     | 45/90 [05:19<05:45,  7.68s/it]




 51%|█████     | 46/90 [05:29<06:00,  8.18s/it]




 52%|█████▏    | 47/90 [05:38<06:05,  8.50s/it]




 53%|█████▎    | 48/90 [05:42<05:01,  7.19s/it]




 54%|█████▍    | 49/90 [05:46<04:14,  6.20s/it]




 56%|█████▌    | 50/90 [05:50<03:38,  5.46s/it]




 57%|█████▋    | 51/90 [05:57<03:55,  6.05s/it]




 58%|█████▊    | 52/90 [06:05<04:15,  6.71s/it]




 59%|█████▉    | 53/90 [06:10<03:50,  6.24s/it]




 60%|██████    | 54/90 [06:17<03:45,  6.26s/it]




 61%|██████    | 55/90 [06:25<04:03,  6.95s/it]




 62%|██████▏   | 56/90 [06:33<04:05,  7.22s/it]




 63%|██████▎   | 57/90 [06:43<04:27,  8.11s/it]




 64%|██████▍   | 58/90 [06:50<04:04,  7.65s/it]




 66%|██████▌   | 59/90 [06:55<03:32,  6.84s/it]




 67%|██████▋   | 60/90 [06:58<02:56,  5.88s/it]




 68%|██████▊   | 61/90 [07:04<02:44,  5.68s/it]




 69%|██████▉   | 62/90 [07:12<03:03,  6.55s/it]




 70%|███████   | 63/90 [07:19<03:00,  6.68s/it]




 71%|███████   | 64/90 [07:25<02:46,  6.39s/it]




 72%|███████▏  | 65/90 [07:30<02:26,  5.85s/it]




 73%|███████▎  | 66/90 [07:36<02:28,  6.17s/it]




 74%|███████▍  | 67/90 [07:43<02:24,  6.27s/it]




 76%|███████▌  | 68/90 [07:48<02:10,  5.93s/it]




 77%|███████▋  | 69/90 [07:53<01:58,  5.62s/it]




 78%|███████▊  | 70/90 [07:59<01:52,  5.63s/it]




 79%|███████▉  | 71/90 [08:03<01:37,  5.13s/it]




 80%|████████  | 72/90 [08:07<01:27,  4.89s/it]




 81%|████████  | 73/90 [08:11<01:19,  4.67s/it]




 82%|████████▏ | 74/90 [08:16<01:14,  4.67s/it]




 83%|████████▎ | 75/90 [08:22<01:17,  5.14s/it]




 84%|████████▍ | 76/90 [08:28<01:13,  5.26s/it]




 86%|████████▌ | 77/90 [08:32<01:05,  5.07s/it]




 87%|████████▋ | 78/90 [08:40<01:11,  5.97s/it]




 88%|████████▊ | 79/90 [08:46<01:04,  5.83s/it]




 89%|████████▉ | 80/90 [08:50<00:53,  5.40s/it]




 90%|█████████ | 81/90 [08:55<00:47,  5.25s/it]




 91%|█████████ | 82/90 [08:59<00:39,  4.88s/it]




 92%|█████████▏| 83/90 [09:04<00:33,  4.84s/it]




 93%|█████████▎| 84/90 [09:11<00:34,  5.68s/it]




 94%|█████████▍| 85/90 [09:18<00:29,  5.92s/it]




 96%|█████████▌| 86/90 [09:28<00:28,  7.12s/it]




 97%|█████████▋| 87/90 [09:32<00:19,  6.35s/it]




 98%|█████████▊| 88/90 [09:37<00:11,  5.90s/it]




 99%|█████████▉| 89/90 [09:42<00:05,  5.58s/it]




100%|██████████| 90/90 [09:48<00:00,  6.54s/it]
  0%|          | 0/90 [00:00<?, ?it/s]




  1%|          | 1/90 [00:04<06:41,  4.51s/it]




  2%|▏         | 2/90 [00:10<07:16,  4.96s/it]




  3%|▎         | 3/90 [00:14<06:50,  4.72s/it]




  4%|▍         | 4/90 [00:31<11:56,  8.34s/it]




  6%|▌         | 5/90 [00:38<11:25,  8.07s/it]




  7%|▋         | 6/90 [00:43<10:00,  7.14s/it]




  8%|▊         | 7/90 [00:49<09:24,  6.80s/it]




  9%|▉         | 8/90 [00:57<09:49,  7.19s/it]




 10%|█         | 9/90 [01:03<09:05,  6.73s/it]




 11%|█         | 10/90 [01:09<08:31,  6.39s/it]




 12%|█▏        | 11/90 [01:15<08:32,  6.49s/it]




 13%|█▎        | 12/90 [01:20<07:45,  5.97s/it]




 14%|█▍        | 13/90 [01:27<07:47,  6.07s/it]




 16%|█▌        | 14/90 [01:34<08:08,  6.43s/it]




 17%|█▋        | 15/90 [01:41<08:14,  6.59s/it]




 18%|█▊        | 16/90 [01:47<08:07,  6.59s/it]




 19%|█▉        | 17/90 [01:55<08:28,  6.97s/it]




 20%|██        | 18/90 [02:00<07:40,  6.40s/it]




 21%|██        | 19/90 [02:09<08:14,  6.97s/it]




 22%|██▏       | 20/90 [02:14<07:45,  6.64s/it]




 23%|██▎       | 21/90 [02:19<06:57,  6.06s/it]




 24%|██▍       | 22/90 [02:23<06:13,  5.49s/it]




 26%|██▌       | 23/90 [02:29<06:18,  5.65s/it]




 27%|██▋       | 24/90 [02:35<06:11,  5.63s/it]




 28%|██▊       | 25/90 [02:40<06:05,  5.62s/it]




 29%|██▉       | 26/90 [02:50<07:07,  6.68s/it]




 30%|███       | 27/90 [02:58<07:33,  7.20s/it]




 31%|███       | 28/90 [03:07<07:53,  7.64s/it]




 32%|███▏      | 29/90 [03:16<08:17,  8.15s/it]




 33%|███▎      | 30/90 [03:23<07:38,  7.64s/it]




 34%|███▍      | 31/90 [03:32<08:06,  8.25s/it]




 36%|███▌      | 32/90 [03:37<06:58,  7.21s/it]




 37%|███▋      | 33/90 [03:45<07:08,  7.51s/it]




 38%|███▊      | 34/90 [03:52<06:50,  7.32s/it]




 39%|███▉      | 35/90 [03:59<06:41,  7.30s/it]




 40%|████      | 36/90 [04:04<05:57,  6.62s/it]




 41%|████      | 37/90 [04:10<05:33,  6.29s/it]




 42%|████▏     | 38/90 [04:22<06:59,  8.06s/it]




 43%|████▎     | 39/90 [04:30<06:45,  7.96s/it]




 44%|████▍     | 40/90 [04:39<06:57,  8.35s/it]




 46%|████▌     | 41/90 [04:48<06:55,  8.48s/it]




 47%|████▋     | 42/90 [04:53<06:00,  7.50s/it]




 48%|████▊     | 43/90 [05:01<06:04,  7.75s/it]




 49%|████▉     | 44/90 [05:05<05:03,  6.60s/it]




 50%|█████     | 45/90 [05:10<04:35,  6.13s/it]




 51%|█████     | 46/90 [05:17<04:31,  6.17s/it]




 52%|█████▏    | 47/90 [05:25<04:47,  6.70s/it]




 53%|█████▎    | 48/90 [05:32<04:57,  7.07s/it]




 54%|█████▍    | 49/90 [05:38<04:31,  6.63s/it]




 56%|█████▌    | 50/90 [05:47<04:51,  7.30s/it]




 57%|█████▋    | 51/90 [05:54<04:39,  7.18s/it]




 58%|█████▊    | 52/90 [06:00<04:22,  6.92s/it]




 59%|█████▉    | 53/90 [06:07<04:13,  6.86s/it]




 60%|██████    | 54/90 [06:13<03:54,  6.51s/it]




 61%|██████    | 55/90 [06:23<04:27,  7.65s/it]




 62%|██████▏   | 56/90 [06:30<04:12,  7.43s/it]




 63%|██████▎   | 57/90 [06:35<03:42,  6.75s/it]




 64%|██████▍   | 58/90 [06:40<03:23,  6.34s/it]




 66%|██████▌   | 59/90 [06:48<03:32,  6.86s/it]




 67%|██████▋   | 60/90 [06:53<03:02,  6.10s/it]




 68%|██████▊   | 61/90 [06:59<02:57,  6.11s/it]




 69%|██████▉   | 62/90 [07:04<02:38,  5.67s/it]




 70%|███████   | 63/90 [07:09<02:30,  5.57s/it]




 71%|███████   | 64/90 [07:15<02:32,  5.87s/it]




 72%|███████▏  | 65/90 [07:21<02:22,  5.69s/it]




 73%|███████▎  | 66/90 [07:25<02:06,  5.27s/it]




 74%|███████▍  | 67/90 [07:30<02:00,  5.26s/it]




 76%|███████▌  | 68/90 [07:36<01:55,  5.27s/it]




 77%|███████▋  | 69/90 [07:42<01:56,  5.57s/it]




 78%|███████▊  | 70/90 [07:49<02:03,  6.17s/it]




 79%|███████▉  | 71/90 [07:57<02:03,  6.52s/it]




 80%|████████  | 72/90 [08:01<01:45,  5.88s/it]




 81%|████████  | 73/90 [08:11<02:01,  7.15s/it]




 82%|████████▏ | 74/90 [08:15<01:37,  6.09s/it]




 83%|████████▎ | 75/90 [08:26<01:54,  7.60s/it]




 84%|████████▍ | 76/90 [08:33<01:44,  7.46s/it]




 86%|████████▌ | 77/90 [08:37<01:22,  6.38s/it]




 87%|████████▋ | 78/90 [08:44<01:19,  6.65s/it]




 88%|████████▊ | 79/90 [08:55<01:26,  7.85s/it]




 89%|████████▉ | 80/90 [09:04<01:22,  8.25s/it]




 90%|█████████ | 81/90 [09:12<01:13,  8.20s/it]




 91%|█████████ | 82/90 [09:25<01:17,  9.75s/it]




 92%|█████████▏| 83/90 [09:33<01:04,  9.20s/it]




 93%|█████████▎| 84/90 [09:40<00:50,  8.36s/it]




 94%|█████████▍| 85/90 [09:49<00:42,  8.56s/it]




 96%|█████████▌| 86/90 [09:55<00:31,  7.78s/it]




 97%|█████████▋| 87/90 [10:01<00:22,  7.41s/it]




 98%|█████████▊| 88/90 [10:07<00:13,  6.83s/it]




 99%|█████████▉| 89/90 [10:11<00:05,  5.98s/it]




100%|██████████| 90/90 [10:15<00:00,  6.84s/it]
  0%|          | 0/90 [00:00<?, ?it/s]




  1%|          | 1/90 [00:06<10:11,  6.87s/it]




  2%|▏         | 2/90 [00:11<09:16,  6.32s/it]




  3%|▎         | 3/90 [00:19<09:35,  6.62s/it]




  4%|▍         | 4/90 [00:25<09:32,  6.66s/it]




  6%|▌         | 5/90 [00:35<10:28,  7.39s/it]




  7%|▋         | 6/90 [00:42<10:14,  7.32s/it]




  8%|▊         | 7/90 [00:46<08:59,  6.50s/it]




  9%|▉         | 8/90 [00:51<08:16,  6.05s/it]




 10%|█         | 9/90 [00:57<08:01,  5.94s/it]




 11%|█         | 10/90 [01:05<08:37,  6.47s/it]




 12%|█▏        | 11/90 [01:09<07:46,  5.90s/it]




 13%|█▎        | 12/90 [01:15<07:25,  5.71s/it]




 14%|█▍        | 13/90 [01:22<08:05,  6.30s/it]




 16%|█▌        | 14/90 [01:30<08:21,  6.60s/it]




 17%|█▋        | 15/90 [01:40<09:34,  7.66s/it]




 18%|█▊        | 16/90 [01:45<08:31,  6.91s/it]




 19%|█▉        | 17/90 [01:53<08:54,  7.32s/it]




 20%|██        | 18/90 [01:59<08:20,  6.95s/it]




 21%|██        | 19/90 [02:13<10:42,  9.05s/it]




 22%|██▏       | 20/90 [02:18<09:10,  7.87s/it]




 23%|██▎       | 21/90 [02:26<09:09,  7.96s/it]




 24%|██▍       | 22/90 [02:35<09:21,  8.26s/it]




 26%|██▌       | 23/90 [02:45<09:46,  8.75s/it]




 27%|██▋       | 24/90 [02:52<09:06,  8.28s/it]




 28%|██▊       | 25/90 [02:56<07:25,  6.85s/it]




 29%|██▉       | 26/90 [03:03<07:13,  6.77s/it]




 30%|███       | 27/90 [03:10<07:27,  7.10s/it]




 31%|███       | 28/90 [03:16<06:54,  6.68s/it]




 32%|███▏      | 29/90 [03:24<07:14,  7.12s/it]




 33%|███▎      | 30/90 [03:32<07:23,  7.39s/it]




 34%|███▍      | 31/90 [03:37<06:34,  6.68s/it]




 36%|███▌      | 32/90 [03:45<06:37,  6.86s/it]




 37%|███▋      | 33/90 [03:49<05:56,  6.25s/it]




 38%|███▊      | 34/90 [03:54<05:22,  5.75s/it]




 39%|███▉      | 35/90 [04:02<05:50,  6.38s/it]




 40%|████      | 36/90 [04:08<05:35,  6.22s/it]




 41%|████      | 37/90 [04:13<05:13,  5.92s/it]




 42%|████▏     | 38/90 [04:19<05:04,  5.85s/it]




 43%|████▎     | 39/90 [04:24<04:49,  5.68s/it]




 44%|████▍     | 40/90 [04:29<04:37,  5.54s/it]




 46%|████▌     | 41/90 [04:35<04:30,  5.53s/it]




 47%|████▋     | 42/90 [04:43<05:06,  6.39s/it]




 48%|████▊     | 43/90 [04:48<04:44,  6.06s/it]




 49%|████▉     | 44/90 [04:54<04:37,  6.04s/it]




 50%|█████     | 45/90 [05:04<05:15,  7.02s/it]




 51%|█████     | 46/90 [05:12<05:27,  7.44s/it]




 52%|█████▏    | 47/90 [05:17<04:51,  6.79s/it]




 53%|█████▎    | 48/90 [05:28<05:31,  7.90s/it]




 54%|█████▍    | 49/90 [05:34<04:58,  7.28s/it]




 56%|█████▌    | 50/90 [05:40<04:42,  7.05s/it]




 57%|█████▋    | 51/90 [05:48<04:39,  7.16s/it]




 58%|█████▊    | 52/90 [05:56<04:42,  7.42s/it]




 59%|█████▉    | 53/90 [06:01<04:11,  6.81s/it]




 60%|██████    | 54/90 [06:08<04:04,  6.78s/it]




 61%|██████    | 55/90 [06:11<03:25,  5.88s/it]




 62%|██████▏   | 56/90 [06:20<03:52,  6.83s/it]




 63%|██████▎   | 57/90 [06:27<03:39,  6.65s/it]




 64%|██████▍   | 58/90 [06:32<03:23,  6.37s/it]




 66%|██████▌   | 59/90 [06:40<03:24,  6.58s/it]




 67%|██████▋   | 60/90 [06:44<03:02,  6.09s/it]




 68%|██████▊   | 61/90 [06:50<02:47,  5.78s/it]




 69%|██████▉   | 62/90 [06:54<02:27,  5.27s/it]




 70%|███████   | 63/90 [06:58<02:16,  5.05s/it]




 71%|███████   | 64/90 [07:02<02:01,  4.68s/it]




 72%|███████▏  | 65/90 [07:07<02:01,  4.87s/it]




 73%|███████▎  | 66/90 [07:13<02:06,  5.27s/it]




 74%|███████▍  | 67/90 [07:19<02:00,  5.23s/it]




 76%|███████▌  | 68/90 [07:23<01:50,  5.04s/it]




 77%|███████▋  | 69/90 [07:27<01:38,  4.67s/it]




 78%|███████▊  | 70/90 [07:32<01:38,  4.91s/it]




 79%|███████▉  | 71/90 [07:38<01:37,  5.13s/it]




 80%|████████  | 72/90 [07:41<01:22,  4.58s/it]




 81%|████████  | 73/90 [07:45<01:13,  4.31s/it]




 82%|████████▏ | 74/90 [07:53<01:25,  5.37s/it]




 83%|████████▎ | 75/90 [07:58<01:20,  5.39s/it]




 84%|████████▍ | 76/90 [08:02<01:09,  4.99s/it]




 86%|████████▌ | 77/90 [08:06<01:00,  4.63s/it]




 87%|████████▋ | 78/90 [08:13<01:04,  5.39s/it]




 88%|████████▊ | 79/90 [08:17<00:54,  4.92s/it]




 89%|████████▉ | 80/90 [08:23<00:50,  5.08s/it]




 90%|█████████ | 81/90 [08:28<00:46,  5.20s/it]




 91%|█████████ | 82/90 [08:33<00:40,  5.11s/it]




 92%|█████████▏| 83/90 [08:38<00:36,  5.21s/it]




 93%|█████████▎| 84/90 [08:43<00:29,  4.86s/it]




 94%|█████████▍| 85/90 [08:48<00:25,  5.12s/it]




 96%|█████████▌| 86/90 [08:56<00:24,  6.04s/it]




 97%|█████████▋| 87/90 [09:05<00:20,  6.86s/it]




 98%|█████████▊| 88/90 [09:15<00:15,  7.79s/it]




 99%|█████████▉| 89/90 [09:22<00:07,  7.44s/it]




100%|██████████| 90/90 [09:30<00:00,  6.34s/it]
  0%|          | 0/90 [00:00<?, ?it/s]




  1%|          | 1/90 [00:08<12:44,  8.60s/it]




  2%|▏         | 2/90 [00:17<12:37,  8.60s/it]




  3%|▎         | 3/90 [00:26<12:47,  8.82s/it]




  4%|▍         | 4/90 [00:31<11:09,  7.79s/it]




  6%|▌         | 5/90 [00:43<12:43,  8.98s/it]




  7%|▋         | 6/90 [00:49<11:06,  7.93s/it]




  8%|▊         | 7/90 [00:54<09:56,  7.19s/it]




  9%|▉         | 8/90 [01:06<11:47,  8.63s/it]




 10%|█         | 9/90 [01:13<11:03,  8.19s/it]




 11%|█         | 10/90 [01:18<09:19,  7.00s/it]




 12%|█▏        | 11/90 [01:22<08:22,  6.36s/it]




 13%|█▎        | 12/90 [01:28<07:49,  6.01s/it]




 14%|█▍        | 13/90 [01:35<08:07,  6.33s/it]




 16%|█▌        | 14/90 [01:41<07:56,  6.28s/it]




 17%|█▋        | 15/90 [01:47<07:41,  6.16s/it]




 18%|█▊        | 16/90 [01:55<08:24,  6.82s/it]




 19%|█▉        | 17/90 [02:00<07:42,  6.34s/it]




 20%|██        | 18/90 [02:08<08:05,  6.74s/it]




 21%|██        | 19/90 [02:13<07:18,  6.18s/it]




 22%|██▏       | 20/90 [02:21<07:52,  6.74s/it]




 23%|██▎       | 21/90 [02:26<07:16,  6.33s/it]




 24%|██▍       | 22/90 [02:35<07:52,  6.95s/it]




 26%|██▌       | 23/90 [02:43<08:05,  7.25s/it]




 27%|██▋       | 24/90 [02:49<07:51,  7.14s/it]




 28%|██▊       | 25/90 [02:58<08:04,  7.45s/it]




 29%|██▉       | 26/90 [03:06<08:20,  7.82s/it]




 30%|███       | 27/90 [03:12<07:35,  7.23s/it]




 31%|███       | 28/90 [03:19<07:18,  7.08s/it]




 32%|███▏      | 29/90 [03:24<06:34,  6.47s/it]




 33%|███▎      | 30/90 [03:33<07:05,  7.10s/it]




 34%|███▍      | 31/90 [03:39<06:39,  6.77s/it]




 36%|███▌      | 32/90 [03:45<06:36,  6.83s/it]




 37%|███▋      | 33/90 [03:55<07:14,  7.61s/it]




 38%|███▊      | 34/90 [04:01<06:35,  7.06s/it]




 39%|███▉      | 35/90 [04:05<05:35,  6.11s/it]




 40%|████      | 36/90 [04:14<06:18,  7.01s/it]




 41%|████      | 37/90 [04:25<07:11,  8.15s/it]




 42%|████▏     | 38/90 [04:36<07:51,  9.07s/it]




 43%|████▎     | 39/90 [04:42<06:53,  8.11s/it]




 44%|████▍     | 40/90 [04:49<06:33,  7.86s/it]




 46%|████▌     | 41/90 [04:54<05:47,  7.10s/it]




 47%|████▋     | 42/90 [05:02<05:51,  7.32s/it]




 48%|████▊     | 43/90 [05:07<05:04,  6.49s/it]




 49%|████▉     | 44/90 [05:11<04:29,  5.86s/it]




 50%|█████     | 45/90 [05:16<04:06,  5.47s/it]




 51%|█████     | 46/90 [05:23<04:23,  5.98s/it]




 52%|█████▏    | 47/90 [05:32<04:53,  6.83s/it]




 53%|█████▎    | 48/90 [05:41<05:14,  7.49s/it]




 54%|█████▍    | 49/90 [05:49<05:16,  7.71s/it]




 56%|█████▌    | 50/90 [05:54<04:36,  6.91s/it]




 57%|█████▋    | 51/90 [06:00<04:22,  6.73s/it]




 58%|█████▊    | 52/90 [06:07<04:12,  6.66s/it]




 59%|█████▉    | 53/90 [06:15<04:27,  7.22s/it]




 60%|██████    | 54/90 [06:23<04:22,  7.28s/it]




 61%|██████    | 55/90 [06:35<05:10,  8.86s/it]




 62%|██████▏   | 56/90 [06:41<04:26,  7.85s/it]




 63%|██████▎   | 57/90 [06:49<04:19,  7.86s/it]




 64%|██████▍   | 58/90 [07:01<04:56,  9.25s/it]




 66%|██████▌   | 59/90 [07:07<04:17,  8.31s/it]




 67%|██████▋   | 60/90 [07:11<03:28,  6.94s/it]




 68%|██████▊   | 61/90 [07:17<03:13,  6.67s/it]




 69%|██████▉   | 62/90 [07:25<03:18,  7.08s/it]




 70%|███████   | 63/90 [07:35<03:38,  8.09s/it]




 71%|███████   | 64/90 [07:42<03:22,  7.77s/it]




 72%|███████▏  | 65/90 [07:50<03:10,  7.60s/it]




 73%|███████▎  | 66/90 [07:55<02:49,  7.08s/it]




 74%|███████▍  | 67/90 [08:02<02:39,  6.95s/it]




 76%|███████▌  | 68/90 [08:14<03:04,  8.37s/it]




 77%|███████▋  | 69/90 [08:19<02:34,  7.35s/it]




 78%|███████▊  | 70/90 [08:29<02:45,  8.27s/it]




 79%|███████▉  | 71/90 [08:34<02:19,  7.35s/it]




 80%|████████  | 72/90 [08:38<01:53,  6.32s/it]




 81%|████████  | 73/90 [08:43<01:36,  5.70s/it]




 82%|████████▏ | 74/90 [08:50<01:40,  6.30s/it]




 83%|████████▎ | 75/90 [08:56<01:31,  6.07s/it]




 84%|████████▍ | 76/90 [09:01<01:21,  5.82s/it]




 86%|████████▌ | 77/90 [09:06<01:10,  5.43s/it]




 87%|████████▋ | 78/90 [09:11<01:04,  5.35s/it]




 88%|████████▊ | 79/90 [09:15<00:56,  5.13s/it]




 89%|████████▉ | 80/90 [09:19<00:47,  4.71s/it]




 90%|█████████ | 81/90 [09:26<00:47,  5.26s/it]




 91%|█████████ | 82/90 [09:32<00:45,  5.70s/it]




 92%|█████████▏| 83/90 [09:47<00:57,  8.26s/it]




 93%|█████████▎| 84/90 [09:54<00:48,  8.08s/it]




 94%|█████████▍| 85/90 [10:02<00:39,  7.93s/it]




 96%|█████████▌| 86/90 [10:08<00:29,  7.42s/it]




 97%|█████████▋| 87/90 [10:16<00:22,  7.54s/it]




 98%|█████████▊| 88/90 [10:23<00:14,  7.35s/it]




 99%|█████████▉| 89/90 [10:27<00:06,  6.53s/it]




100%|██████████| 90/90 [10:32<00:00,  7.03s/it]
  0%|          | 0/90 [00:00<?, ?it/s]




  1%|          | 1/90 [00:08<13:11,  8.90s/it]




  2%|▏         | 2/90 [00:15<12:13,  8.33s/it]




  3%|▎         | 3/90 [00:22<11:08,  7.69s/it]




  4%|▍         | 4/90 [00:27<09:57,  6.95s/it]




  6%|▌         | 5/90 [00:32<08:57,  6.33s/it]




  7%|▋         | 6/90 [00:38<08:44,  6.24s/it]




  8%|▊         | 7/90 [00:44<08:47,  6.35s/it]




  9%|▉         | 8/90 [00:49<08:05,  5.92s/it]




 10%|█         | 9/90 [00:57<08:52,  6.58s/it]




 11%|█         | 10/90 [01:03<08:20,  6.25s/it]




 12%|█▏        | 11/90 [01:07<07:24,  5.62s/it]




 13%|█▎        | 12/90 [01:11<06:43,  5.17s/it]




 14%|█▍        | 13/90 [01:20<08:01,  6.25s/it]




 16%|█▌        | 14/90 [01:29<09:01,  7.12s/it]




 17%|█▋        | 15/90 [01:34<08:06,  6.49s/it]




 18%|█▊        | 16/90 [01:43<08:55,  7.23s/it]




 19%|█▉        | 17/90 [01:49<08:15,  6.79s/it]




 20%|██        | 18/90 [01:53<07:20,  6.12s/it]




 21%|██        | 19/90 [02:00<07:25,  6.27s/it]




 22%|██▏       | 20/90 [02:04<06:31,  5.59s/it]




 23%|██▎       | 21/90 [02:11<06:47,  5.90s/it]




 24%|██▍       | 22/90 [02:22<08:29,  7.49s/it]




 26%|██▌       | 23/90 [02:29<08:24,  7.52s/it]




 27%|██▋       | 24/90 [02:37<08:08,  7.40s/it]




 28%|██▊       | 25/90 [02:44<08:05,  7.47s/it]




 29%|██▉       | 26/90 [02:53<08:32,  8.00s/it]




 30%|███       | 27/90 [02:59<07:43,  7.36s/it]




 31%|███       | 28/90 [03:08<07:55,  7.68s/it]




 32%|███▏      | 29/90 [03:17<08:22,  8.24s/it]




 33%|███▎      | 30/90 [03:24<07:55,  7.92s/it]




 34%|███▍      | 31/90 [03:33<07:52,  8.01s/it]




 36%|███▌      | 32/90 [03:41<07:59,  8.27s/it]




 37%|███▋      | 33/90 [03:48<07:21,  7.74s/it]




 38%|███▊      | 34/90 [03:52<06:14,  6.68s/it]




 39%|███▉      | 35/90 [03:57<05:36,  6.11s/it]




 40%|████      | 36/90 [04:03<05:24,  6.01s/it]




 41%|████      | 37/90 [04:08<05:01,  5.69s/it]




 42%|████▏     | 38/90 [04:13<04:53,  5.65s/it]




 43%|████▎     | 39/90 [04:18<04:34,  5.38s/it]




 44%|████▍     | 40/90 [04:24<04:34,  5.49s/it]




 46%|████▌     | 41/90 [04:30<04:36,  5.64s/it]




 47%|████▋     | 42/90 [04:35<04:31,  5.66s/it]




 48%|████▊     | 43/90 [04:44<05:00,  6.40s/it]




 49%|████▉     | 44/90 [04:52<05:22,  7.01s/it]




 50%|█████     | 45/90 [04:57<04:53,  6.52s/it]




 51%|█████     | 46/90 [05:02<04:18,  5.88s/it]




 52%|█████▏    | 47/90 [05:06<03:47,  5.28s/it]




 53%|█████▎    | 48/90 [05:13<04:13,  6.03s/it]




 54%|█████▍    | 49/90 [05:19<04:03,  5.94s/it]




 56%|█████▌    | 50/90 [05:27<04:22,  6.56s/it]




 57%|█████▋    | 51/90 [05:41<05:36,  8.64s/it]




 58%|█████▊    | 52/90 [05:47<05:04,  8.01s/it]




 59%|█████▉    | 53/90 [05:54<04:42,  7.64s/it]




 60%|██████    | 54/90 [06:02<04:38,  7.73s/it]




 61%|██████    | 55/90 [06:07<04:01,  6.91s/it]




 62%|██████▏   | 56/90 [06:13<03:44,  6.61s/it]




 63%|██████▎   | 57/90 [06:22<04:05,  7.44s/it]




 64%|██████▍   | 58/90 [06:29<03:53,  7.31s/it]




 66%|██████▌   | 59/90 [06:35<03:29,  6.76s/it]




 67%|██████▋   | 60/90 [06:41<03:17,  6.59s/it]




 68%|██████▊   | 61/90 [06:46<02:55,  6.05s/it]




 69%|██████▉   | 62/90 [06:52<02:49,  6.04s/it]




 70%|███████   | 63/90 [06:56<02:27,  5.46s/it]




 71%|███████   | 64/90 [07:03<02:39,  6.12s/it]




 72%|███████▏  | 65/90 [07:09<02:30,  6.02s/it]




 73%|███████▎  | 66/90 [07:14<02:18,  5.76s/it]




 74%|███████▍  | 67/90 [07:20<02:14,  5.83s/it]




 76%|███████▌  | 68/90 [07:27<02:10,  5.94s/it]




 77%|███████▋  | 69/90 [07:32<02:01,  5.79s/it]




 78%|███████▊  | 70/90 [07:36<01:44,  5.21s/it]




 79%|███████▉  | 71/90 [07:40<01:33,  4.91s/it]




 80%|████████  | 72/90 [07:44<01:22,  4.61s/it]




 81%|████████  | 73/90 [07:50<01:25,  5.05s/it]




 82%|████████▏ | 74/90 [07:56<01:24,  5.26s/it]




 83%|████████▎ | 75/90 [08:00<01:13,  4.92s/it]




 84%|████████▍ | 76/90 [08:06<01:14,  5.30s/it]




 86%|████████▌ | 77/90 [08:12<01:11,  5.50s/it]




 87%|████████▋ | 78/90 [08:16<01:01,  5.12s/it]




 88%|████████▊ | 79/90 [08:21<00:55,  5.05s/it]




 89%|████████▉ | 80/90 [08:25<00:46,  4.69s/it]




 90%|█████████ | 81/90 [08:30<00:41,  4.63s/it]




 91%|█████████ | 82/90 [08:34<00:35,  4.45s/it]




 92%|█████████▏| 83/90 [08:38<00:32,  4.57s/it]




 93%|█████████▎| 84/90 [08:45<00:30,  5.16s/it]




 94%|█████████▍| 85/90 [08:49<00:24,  4.83s/it]




 96%|█████████▌| 86/90 [08:54<00:19,  4.88s/it]




 97%|█████████▋| 87/90 [08:59<00:14,  4.81s/it]




 98%|█████████▊| 88/90 [09:06<00:10,  5.47s/it]




 99%|█████████▉| 89/90 [09:13<00:06,  6.14s/it]




100%|██████████| 90/90 [09:19<00:00,  6.21s/it]
  0%|          | 0/90 [00:00<?, ?it/s]




  1%|          | 1/90 [00:08<12:32,  8.45s/it]




  2%|▏         | 2/90 [00:14<11:21,  7.74s/it]




  3%|▎         | 3/90 [00:19<10:13,  7.05s/it]




  4%|▍         | 4/90 [00:28<10:42,  7.47s/it]




  6%|▌         | 5/90 [00:34<09:49,  6.94s/it]




  7%|▋         | 6/90 [00:39<08:54,  6.37s/it]




  8%|▊         | 7/90 [00:45<08:49,  6.39s/it]




  9%|▉         | 8/90 [00:54<09:40,  7.08s/it]




 10%|█         | 9/90 [01:03<10:20,  7.66s/it]




 11%|█         | 10/90 [01:08<09:05,  6.81s/it]




 12%|█▏        | 11/90 [01:14<08:39,  6.57s/it]




 13%|█▎        | 12/90 [01:22<09:23,  7.23s/it]




 14%|█▍        | 13/90 [01:29<08:51,  6.90s/it]




 16%|█▌        | 14/90 [01:36<08:49,  6.96s/it]




 17%|█▋        | 15/90 [01:43<08:55,  7.14s/it]




 18%|█▊        | 16/90 [01:47<07:43,  6.26s/it]




 19%|█▉        | 17/90 [01:54<07:48,  6.41s/it]




 20%|██        | 18/90 [02:01<07:41,  6.41s/it]




 21%|██        | 19/90 [02:06<07:23,  6.25s/it]




 22%|██▏       | 20/90 [02:16<08:26,  7.23s/it]




 23%|██▎       | 21/90 [02:23<08:19,  7.23s/it]




 24%|██▍       | 22/90 [02:29<07:44,  6.84s/it]




 26%|██▌       | 23/90 [02:34<06:51,  6.15s/it]




 27%|██▋       | 24/90 [02:42<07:27,  6.79s/it]




 28%|██▊       | 25/90 [02:47<06:39,  6.15s/it]




 29%|██▉       | 26/90 [02:56<07:29,  7.02s/it]




 30%|███       | 27/90 [03:01<06:48,  6.49s/it]




 31%|███       | 28/90 [03:08<06:52,  6.65s/it]




 32%|███▏      | 29/90 [03:13<06:15,  6.16s/it]




 33%|███▎      | 30/90 [03:18<05:51,  5.86s/it]




 34%|███▍      | 31/90 [03:23<05:26,  5.53s/it]




 36%|███▌      | 32/90 [03:29<05:32,  5.73s/it]




 37%|███▋      | 33/90 [03:36<05:45,  6.06s/it]




 38%|███▊      | 34/90 [03:40<05:02,  5.41s/it]




 39%|███▉      | 35/90 [03:45<04:58,  5.43s/it]




 40%|████      | 36/90 [03:52<05:17,  5.89s/it]




 41%|████      | 37/90 [03:57<04:55,  5.57s/it]




 42%|████▏     | 38/90 [04:01<04:26,  5.12s/it]




 43%|████▎     | 39/90 [04:06<04:18,  5.07s/it]




 44%|████▍     | 40/90 [04:10<03:55,  4.70s/it]




 46%|████▌     | 41/90 [04:16<04:07,  5.06s/it]




 47%|████▋     | 42/90 [04:25<04:59,  6.24s/it]




 48%|████▊     | 43/90 [04:34<05:37,  7.18s/it]




 49%|████▉     | 44/90 [04:39<04:59,  6.51s/it]




 50%|█████     | 45/90 [04:46<05:00,  6.67s/it]




 51%|█████     | 46/90 [04:50<04:13,  5.75s/it]




 52%|█████▏    | 47/90 [04:54<03:42,  5.17s/it]




 53%|█████▎    | 48/90 [04:59<03:40,  5.24s/it]




 54%|█████▍    | 49/90 [05:08<04:26,  6.51s/it]




 56%|█████▌    | 50/90 [05:14<04:05,  6.14s/it]




 57%|█████▋    | 51/90 [05:21<04:09,  6.40s/it]




 58%|█████▊    | 52/90 [05:28<04:08,  6.53s/it]




 59%|█████▉    | 53/90 [05:32<03:41,  5.99s/it]




 60%|██████    | 54/90 [05:39<03:44,  6.23s/it]




 61%|██████    | 55/90 [05:45<03:36,  6.20s/it]




 62%|██████▏   | 56/90 [05:51<03:27,  6.10s/it]




 63%|██████▎   | 57/90 [05:58<03:25,  6.24s/it]




 64%|██████▍   | 58/90 [06:03<03:12,  6.01s/it]




 66%|██████▌   | 59/90 [06:07<02:50,  5.49s/it]




 67%|██████▋   | 60/90 [06:14<02:51,  5.72s/it]




 68%|██████▊   | 61/90 [06:21<02:55,  6.06s/it]




 69%|██████▉   | 62/90 [06:26<02:48,  6.02s/it]




 70%|███████   | 63/90 [06:31<02:30,  5.57s/it]




 71%|███████   | 64/90 [06:40<02:50,  6.57s/it]




 72%|███████▏  | 65/90 [06:47<02:49,  6.76s/it]




 73%|███████▎  | 66/90 [06:52<02:28,  6.20s/it]




 74%|███████▍  | 67/90 [07:00<02:32,  6.63s/it]




 76%|███████▌  | 68/90 [07:04<02:12,  6.03s/it]




 77%|███████▋  | 69/90 [07:11<02:14,  6.39s/it]




 78%|███████▊  | 70/90 [07:17<02:03,  6.15s/it]




 79%|███████▉  | 71/90 [07:20<01:39,  5.23s/it]




 80%|████████  | 72/90 [07:25<01:29,  5.00s/it]




 81%|████████  | 73/90 [07:29<01:21,  4.80s/it]




 82%|████████▏ | 74/90 [07:34<01:17,  4.87s/it]




 83%|████████▎ | 75/90 [07:39<01:14,  5.00s/it]




 84%|████████▍ | 76/90 [07:44<01:07,  4.83s/it]




 86%|████████▌ | 77/90 [07:47<00:58,  4.51s/it]




 87%|████████▋ | 78/90 [07:52<00:55,  4.60s/it]




 88%|████████▊ | 79/90 [07:57<00:51,  4.68s/it]




 89%|████████▉ | 80/90 [08:01<00:43,  4.38s/it]




 90%|█████████ | 81/90 [08:06<00:40,  4.49s/it]




 91%|█████████ | 82/90 [08:10<00:35,  4.44s/it]




 92%|█████████▏| 83/90 [08:16<00:33,  4.82s/it]




 93%|█████████▎| 84/90 [08:24<00:34,  5.79s/it]




 94%|█████████▍| 85/90 [08:28<00:27,  5.42s/it]




 96%|█████████▌| 86/90 [08:34<00:22,  5.51s/it]




 97%|█████████▋| 87/90 [08:39<00:15,  5.26s/it]




 98%|█████████▊| 88/90 [08:43<00:10,  5.12s/it]




 99%|█████████▉| 89/90 [08:51<00:05,  5.88s/it]




100%|██████████| 90/90 [08:56<00:00,  5.97s/it]
  0%|          | 0/90 [00:00<?, ?it/s]




  1%|          | 1/90 [00:08<12:58,  8.75s/it]




  2%|▏         | 2/90 [00:17<12:57,  8.84s/it]




  3%|▎         | 3/90 [00:22<10:56,  7.55s/it]




  4%|▍         | 4/90 [00:28<10:20,  7.22s/it]




  6%|▌         | 5/90 [00:36<10:20,  7.30s/it]




  7%|▋         | 6/90 [00:45<10:51,  7.76s/it]




  8%|▊         | 7/90 [00:54<11:20,  8.20s/it]




  9%|▉         | 8/90 [01:03<11:47,  8.63s/it]




 10%|█         | 9/90 [01:07<09:39,  7.16s/it]




 11%|█         | 10/90 [01:14<09:20,  7.00s/it]




 12%|█▏        | 11/90 [01:21<09:24,  7.14s/it]




 13%|█▎        | 12/90 [01:29<09:37,  7.40s/it]




 14%|█▍        | 13/90 [01:36<09:18,  7.25s/it]




 16%|█▌        | 14/90 [01:41<08:12,  6.47s/it]




 17%|█▋        | 15/90 [01:47<07:51,  6.29s/it]




 18%|█▊        | 16/90 [01:52<07:31,  6.11s/it]




 19%|█▉        | 17/90 [01:59<07:44,  6.36s/it]




 20%|██        | 18/90 [02:08<08:24,  7.01s/it]




 21%|██        | 19/90 [02:15<08:11,  6.92s/it]




 22%|██▏       | 20/90 [02:20<07:37,  6.54s/it]




 23%|██▎       | 21/90 [02:32<09:25,  8.20s/it]




 24%|██▍       | 22/90 [02:37<08:13,  7.26s/it]




 26%|██▌       | 23/90 [02:45<08:14,  7.38s/it]




 27%|██▋       | 24/90 [02:52<07:57,  7.24s/it]




 28%|██▊       | 25/90 [02:59<07:45,  7.16s/it]




 29%|██▉       | 26/90 [03:03<06:43,  6.31s/it]




 30%|███       | 27/90 [03:12<07:15,  6.91s/it]




 31%|███       | 28/90 [03:17<06:45,  6.54s/it]




 32%|███▏      | 29/90 [03:25<07:01,  6.92s/it]




 33%|███▎      | 30/90 [03:30<06:27,  6.46s/it]




 34%|███▍      | 31/90 [03:38<06:46,  6.89s/it]




 36%|███▌      | 32/90 [03:45<06:30,  6.74s/it]




 37%|███▋      | 33/90 [03:55<07:20,  7.73s/it]




 38%|███▊      | 34/90 [04:05<07:52,  8.44s/it]




 39%|███▉      | 35/90 [04:14<07:57,  8.68s/it]




 40%|████      | 36/90 [04:21<07:20,  8.15s/it]




 41%|████      | 37/90 [04:25<06:08,  6.96s/it]




 42%|████▏     | 38/90 [04:34<06:31,  7.53s/it]




 43%|████▎     | 39/90 [04:44<07:03,  8.30s/it]




 44%|████▍     | 40/90 [04:50<06:13,  7.48s/it]




 46%|████▌     | 41/90 [04:54<05:25,  6.64s/it]




 47%|████▋     | 42/90 [05:02<05:26,  6.80s/it]




 48%|████▊     | 43/90 [05:07<05:00,  6.39s/it]




 49%|████▉     | 44/90 [05:14<05:07,  6.69s/it]




 50%|█████     | 45/90 [05:22<05:19,  7.11s/it]




 51%|█████     | 46/90 [05:27<04:37,  6.32s/it]




 52%|█████▏    | 47/90 [05:33<04:24,  6.16s/it]




 53%|█████▎    | 48/90 [05:40<04:28,  6.39s/it]




 54%|█████▍    | 49/90 [05:46<04:23,  6.42s/it]




 56%|█████▌    | 50/90 [05:55<04:42,  7.07s/it]




 57%|█████▋    | 51/90 [06:02<04:38,  7.14s/it]




 58%|█████▊    | 52/90 [06:07<04:09,  6.56s/it]




 59%|█████▉    | 53/90 [06:11<03:35,  5.81s/it]




 60%|██████    | 54/90 [06:17<03:31,  5.88s/it]




 61%|██████    | 55/90 [06:23<03:23,  5.81s/it]




 62%|██████▏   | 56/90 [06:29<03:20,  5.91s/it]




 63%|██████▎   | 57/90 [06:38<03:42,  6.74s/it]




 64%|██████▍   | 58/90 [06:45<03:42,  6.94s/it]




 66%|██████▌   | 59/90 [06:51<03:28,  6.73s/it]




 67%|██████▋   | 60/90 [06:56<02:59,  6.00s/it]




 68%|██████▊   | 61/90 [07:02<02:57,  6.10s/it]




 69%|██████▉   | 62/90 [07:12<03:22,  7.23s/it]




 70%|███████   | 63/90 [07:16<02:50,  6.32s/it]




 71%|███████   | 64/90 [07:23<02:51,  6.60s/it]




 72%|███████▏  | 65/90 [07:29<02:36,  6.24s/it]




 73%|███████▎  | 66/90 [07:35<02:32,  6.35s/it]




 74%|███████▍  | 67/90 [07:40<02:13,  5.79s/it]




 76%|███████▌  | 68/90 [07:47<02:15,  6.18s/it]




 77%|███████▋  | 69/90 [07:56<02:24,  6.88s/it]




 78%|███████▊  | 70/90 [08:01<02:07,  6.39s/it]




 79%|███████▉  | 71/90 [08:05<01:49,  5.77s/it]




 80%|████████  | 72/90 [08:12<01:48,  6.04s/it]




 81%|████████  | 73/90 [08:16<01:35,  5.64s/it]




 82%|████████▏ | 74/90 [08:22<01:29,  5.62s/it]




 83%|████████▎ | 75/90 [08:26<01:16,  5.07s/it]




 84%|████████▍ | 76/90 [08:31<01:09,  4.99s/it]




 86%|████████▌ | 77/90 [08:35<01:02,  4.81s/it]




 87%|████████▋ | 78/90 [08:39<00:54,  4.56s/it]




 88%|████████▊ | 79/90 [08:43<00:47,  4.34s/it]




 89%|████████▉ | 80/90 [08:51<00:53,  5.34s/it]




 90%|█████████ | 81/90 [08:57<00:51,  5.73s/it]




 91%|█████████ | 82/90 [09:02<00:43,  5.38s/it]




 92%|█████████▏| 83/90 [09:05<00:33,  4.85s/it]




 93%|█████████▎| 84/90 [09:09<00:26,  4.35s/it]




 94%|█████████▍| 85/90 [09:16<00:27,  5.40s/it]




 96%|█████████▌| 86/90 [09:21<00:21,  5.26s/it]




 97%|█████████▋| 87/90 [09:30<00:18,  6.17s/it]




 98%|█████████▊| 88/90 [09:35<00:11,  5.86s/it]




 99%|█████████▉| 89/90 [09:43<00:06,  6.59s/it]




100%|██████████| 90/90 [09:50<00:00,  6.57s/it]


In [20]:
def _read_output_pickle(name):
    """Unpickle `name` from `outdir`.

    pickle.load can execute arbitrary code: only use it on files this
    notebook wrote itself.
    """
    with open(os.path.join(outdir, name), 'rb') as fh:
        return pickle.load(fh)


# Z_all is later indexed as Z_all[trial][t][k]; pi_all / theta_all presumably
# hold the matching mixing-proportion / block-probability estimates (confirm).
pi_all, theta_all, Z_all = (
    _read_output_pickle(name)
    for name in ('pi_gradual.pkl', 'theta_gradual.pkl', 'z_gradual.pkl')
)

In [21]:
# Axes as used below: (trial, time, node, node) — node axes are masked by cluster labels.
np.shape(X_all)

(8, 90, 100, 100)

In [22]:
# relabeling
# Align cluster labels across time, for every trial and every membership model
# with k+1 clusters (k = 1..K-1; k = 0 has a single cluster, nothing to align),
# using the ECR iterative algorithm from R's label.switching package.
# Time is processed in overlapping windows: each window after the first starts
# at the last time point of the previous one, and the permutation found for
# that shared point is also applied to every earlier time point, so labels stay
# consistent along the whole sequence.
# NOTE(review): Z_all is mutated in place and the next cell overwrites
# z_gradual.pkl with it, so re-running from the load cell relabels
# already-relabeled memberships.
N_trial = X_all.shape[0]
K = 10
T = X_all.shape[1]

# [start, end) time windows; consecutive windows share exactly one time point.
# The last window runs to the end of the series (previously hard-coded as 90).
RELABEL_WINDOWS = [(0, 10), (9, 14), (13, 35), (34, 40), (39, 59), (58, 69), (68, T)]
assert all(start < end for start, end in RELABEL_WINDOWS), "series too short for RELABEL_WINDOWS"


def _relabel_window(Z_trial, k, t_begin, t_end, propagate_to_past):
    """Permute the columns of Z_trial[t][k] for t in [t_begin, t_end) via ECR.

    Z_trial[t][k] is treated as a (node x (k+1)) matrix whose argmax over
    axis 1 is each node's cluster label (presumably membership probabilities).
    When `propagate_to_past` is true, the permutation found for t_begin is
    applied to every time point 0..t_begin so earlier windows stay aligned.
    """
    times = range(t_begin, t_end)
    probs = [Z_trial[t][k] for t in times]
    run = label_switching.ecr_iterative_2(
        # label.switching works with 1-based labels
        z=np.vstack([np.argmax(p, axis=1).reshape(1, -1) for p in probs]) + 1,
        K=k+1,
        p=np.stack(probs),
    )
    # one row per time point in the window, converted back to 0-based columns
    permutations = np.array(dollar(run, "permutations")) - 1
    for i, t in enumerate(times):
        targets = range(t + 1) if (propagate_to_past and i == 0) else (t,)
        for tt in targets:
            Z_trial[tt][k] = Z_trial[tt][k][:, permutations[i, :]]


numpy2ri.activate()
try:
    for trial in tqdm.tqdm(range(N_trial)):
        for k in range(1, K):
            # windows must run in order: each one reads labels fixed by the previous
            for w, (t_begin, t_end) in enumerate(RELABEL_WINDOWS):
                _relabel_window(Z_all[trial], k, t_begin, t_end, propagate_to_past=(w > 0))
finally:
    # never leave the global numpy->R conversion switched on after an R error
    numpy2ri.deactivate()

100%|██████████| 8/8 [00:24<00:00,  3.10s/it]


In [23]:
# Persist the relabeled memberships. This overwrites z_gradual.pkl, the same
# file the load cell reads, so a later "Run All" starts from relabeled data.
z_path = os.path.join(outdir, 'z_gradual.pkl')
with open(z_path, 'wb') as fh:
    pickle.dump(Z_all, fh)

In [24]:
# DNML scores around every candidate time t, for window half-widths h = 1, 2, 3.
# For each trial, t and membership model k (k+1 clusters), three windows are scored:
#   'all' (whole):  t-h .. t+h-1
#   'f'   (former): t-h .. t-1
#   'l'   (latter): t   .. t+h-1
# Plug-in estimates are pooled over the window from hard (argmax) assignments
# and passed to calc_dnml_by_prob. For every number of non-empty clusters the
# smallest DNML seen over k is kept, and all results are pickled per h.
n_trial = X_all.shape[0]
T = X_all.shape[1]
K = 10


def _window_estimates(X_trial, Z_trial, k, times):
    """Pooled plug-in estimates over the time points in `times`.

    Each node at time tt is hard-assigned to argmax(Z_trial[tt][k], axis=1).
    Returns (pi_hat, theta_hat): the cluster proportions, and the matrix of
    summed X entries between clusters divided by the number of node pairs.
    Both are restricted to clusters that are non-empty somewhere in the window.
    """
    clusters = range(k + 1)
    edges, sizes = [], []
    for tt in times:
        labels = np.argmax(Z_trial[tt][k], axis=1)  # computed once per time point
        masks = [labels == v for v in clusters]
        sizes.append(np.array([np.sum(m) for m in masks]))
        X = X_trial[tt]
        edges.append([[np.sum(X[m1, :][:, m2]) for m2 in masks] for m1 in masks])
    n = np.sum(sizes, axis=0)
    pairs = np.sum([np.outer(s, s) for s in sizes], axis=0)
    # clusters empty over the whole window give 0/0 = nan; they are dropped below
    with np.errstate(divide='ignore', invalid='ignore'):
        theta_hat = np.sum(edges, axis=0) / pairs
    nonempty = n != 0
    theta_hat = theta_hat[nonempty, :][:, nonempty]
    pi_hat = n[nonempty] / np.sum(n)
    return pi_hat, theta_hat


def _window_dnml(X_trial, Z_trial, k, times):
    """Sum calc_dnml_by_prob's three outputs over `times` using window-pooled estimates.

    Returns (n_cluster, dnml, nml_x_z, nml_z); n_cluster is the number of
    clusters that are non-empty in the window.
    """
    pi_hat, theta_hat = _window_estimates(X_trial, Z_trial, k, times)
    n_cluster = len(pi_hat)
    res = np.array([calc_dnml_by_prob(X_trial[tt], Z_trial[tt][k],
                                      pi_hat, theta_hat, n_cluster, n_cluster)
                    for tt in times])
    return n_cluster, np.sum(res[:, 0]), np.sum(res[:, 1]), np.sum(res[:, 2])


def _keep_smallest(store, trial, t, n_cluster, dnml, nml_x_z, nml_z):
    """Store the window result at slot n_cluster-1 if it beats the stored DNML.

    `store` is a (dnml, nml_x_z, nml_z) tuple of arrays shaped (n_trial, T, K).
    """
    dnml_arr, nml_x_z_arr, nml_z_arr = store
    idx = (trial, t, n_cluster - 1)
    val = dnml_arr[idx]
    # NOTE(review): the isfinite guard means a stored +inf is never replaced,
    # even by a finite DNML; confirm this is intended.
    if np.isnan(val) or (np.isfinite(val) and dnml < val):
        dnml_arr[idx] = dnml
        nml_x_z_arr[idx] = nml_x_z
        nml_z_arr[idx] = nml_z


for h in [1, 2, 3]:
    shape = (n_trial, T, K)
    dnml_whole_list, nml_x_z_whole_list, nml_z_whole_list = (np.full(shape, np.nan) for _ in range(3))
    dnml_former_list, nml_x_z_former_list, nml_z_former_list = (np.full(shape, np.nan) for _ in range(3))
    dnml_latter_list, nml_x_z_latter_list, nml_z_latter_list = (np.full(shape, np.nan) for _ in range(3))
    # file-name suffix -> (dnml, nml_x_z, nml_z) result arrays
    stores = {
        'all': (dnml_whole_list, nml_x_z_whole_list, nml_z_whole_list),
        'f': (dnml_former_list, nml_x_z_former_list, nml_z_former_list),
        'l': (dnml_latter_list, nml_x_z_latter_list, nml_z_latter_list),
    }

    for trial in tqdm.tqdm(range(n_trial)):
        # NOTE(review): t stops at T-h-1, so time point T-1 is never part of any
        # window; confirm range(h, T-h) rather than range(h, T-h+1) is intended.
        for t in tqdm.tqdm(range(h, T-h)):
            windows = {'all': range(t-h, t+h), 'f': range(t-h, t), 'l': range(t, t+h)}
            for k in range(K):
                for name in ('all', 'f', 'l'):
                    result = _window_dnml(X_all[trial], Z_all[trial], k, windows[name])
                    _keep_smallest(stores[name], trial, t, *result)

    for i, prefix in enumerate(('dnml', 'nml_x_z', 'nml_z')):
        for suffix in ('all', 'f', 'l'):
            path = os.path.join(outdir, prefix + '_' + suffix + '_h' + str(h) + '.pkl')
            with open(path, 'wb') as f:
                pickle.dump(stores[suffix][i], f)

  0%|          | 0/8 [00:00<?, ?it/s]
  0%|          | 0/88 [00:00<?, ?it/s][A
  1%|          | 1/88 [00:01<02:44,  1.89s/it][A
  2%|▏         | 2/88 [00:02<02:06,  1.47s/it][A
  3%|▎         | 3/88 [00:02<01:37,  1.15s/it][A
  5%|▍         | 4/88 [00:03<01:21,  1.04it/s][A
  6%|▌         | 5/88 [00:03<01:10,  1.18it/s][A
  7%|▋         | 6/88 [00:04<01:02,  1.31it/s][A
  8%|▊         | 7/88 [00:04<00:55,  1.47it/s][A
  9%|▉         | 8/88 [00:05<00:47,  1.67it/s][A
 10%|█         | 9/88 [00:05<00:46,  1.68it/s][A
 11%|█▏        | 10/88 [00:06<00:43,  1.78it/s][A
 12%|█▎        | 11/88 [00:06<00:37,  2.03it/s][A
 14%|█▎        | 12/88 [00:07<00:33,  2.26it/s][A
 15%|█▍        | 13/88 [00:07<00:32,  2.34it/s][A
 16%|█▌        | 14/88 [00:07<00:32,  2.28it/s][A
 17%|█▋        | 15/88 [00:08<00:29,  2.44it/s][A
 18%|█▊        | 16/88 [00:08<00:27,  2.59it/s][A
 19%|█▉        | 17/88 [00:08<00:25,  2.80it/s][A
 20%|██        | 18/88 [00:09<00:23,  2.97it/s][A
 22%|██▏   