# Configuration and imports

In [None]:
# Directory where fitted model results are written (created in the next cell if missing);
# `reproducible` selects a seeded generator for the data sampler (True, False).
model_results_dir, reproducible = '.', True

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from scipy.optimize import linear_sum_assignment
from sklearn.metrics import confusion_matrix, adjusted_rand_score

# Create the results directory if needed. Unlike `os.path.isdir` + `os.mkdir`, this
# also creates missing parent directories and does not race between check and create.
os.makedirs(model_results_dir, exist_ok=True)

%matplotlib inline

from mbcah import (
    sampling,
    plotting,
    two_steps_fit,
    model_selection,
)
from mbcah.utils import similarity_matrices

# Sampling data

In [None]:
# --- Dimensions of the synthetic data set ---
N = 200   # number of curves
T = 100   # number of timesteps
K = 2     # number of clusters
S = np.arange(0, T // 4, step=3)  # set of shifts: 0, 3, 6, ... strictly below T // 4

# --- Dirichlet concentrations ---
alpha_rho = 1.0    # parameter of the dirichlet used to sample parameter rho
alpha_gamma = 1.0  # parameter of the dirichlet used to sample parameter gamma

# --- Smoothness and noise ---
sigma2_gp_mu = 0.001          # controls the smoothness of the mu_k
sigma2_gp_sigma2 = 100_000.0  # controls the smoothness of the sigma2_k; a high value gives constant variances sigma2_jk = sigma2_k
s2 = 0.1                      # observation noise

# --- Mask for missing data ---
n_intervals = 20
p_missing = 0.2

In [None]:
# Randomness source for the synthetic data: a seeded Generator when `reproducible`,
# otherwise numpy's global `np.random` module.
if reproducible:
    random_state = np.random.default_rng(1234)
else:
    random_state = np.random
dtype = 'float64'

# Draw the synthetic data set from the generative model (positional arguments, as
# expected by sampling.Sampler).
sampler = sampling.Sampler(
    N, T, K, S,
    alpha_rho, alpha_gamma,
    sigma2_gp_mu, sigma2_gp_sigma2, s2,
    random_state, dtype,
)
sampler.sample()

# Hide random intervals of the observations: masked entries of sampler.X are
# overwritten with NaN in place.
mask = sampling.sample_intervals_mask(
    X=sampler.X, n_intervals=n_intervals, p_missing=p_missing,
    random_state=sampler.random_state,
)
sampler.X[mask] = np.nan

# Visual checks of the sampled data: curves, cluster parameters, missing-value mask,
# per-cluster curves and per-cluster shift distributions.
f, ax = plt.subplots(figsize=(8, 4))
ax.plot(sampler.X.T)
ax.set_title('Observed curves')

f, ax = plt.subplots(figsize=(8, 4))
ax.plot(sampler.mu)
ax.set_title('Cluster means mu_jk')

f, ax = plt.subplots(figsize=(8, 4))
ax.plot(sampler.sigma2)
ax.set_title('Cluster variances sigma2_jk')

f, ax = plt.subplots(figsize=(8, 4))
sns.heatmap(mask, ax=ax, cbar=False)
ax.set_title('Mask of missing values');

# NOTE(review): assumes `sampler.z` is a one-hot (N x K) membership matrix consistent
# with the label vector `sampler.Z` used below — confirm.
sizes = sampler.z.sum(0)
# squeeze=False keeps `ax` 2-D even when sampler.K == 1, so ax[k, 0] is always valid.
f, ax = plt.subplots(sampler.K, 1, figsize=(8, 4 * sampler.K), squeeze=False)
# Iterate over the sampler's own clusters (not the global K), as the next loop does.
for k in range(sampler.K):
    ax[k, 0].plot(sampler.X[sampler.Z == k].T)
    ax[k, 0].set_title(f'Observed curves in cluster {k} of size {sizes[k]}')

f, ax = plt.subplots(sampler.K, 1, figsize=(4, 2 * sampler.K), tight_layout=True, squeeze=False)
for k in range(sampler.K):
    ax[k, 0].bar(sampler.S, sampler.gamma[k])
    ax[k, 0].set_title(f'Distribution of shifts in cluster K{k} gamma_k');

# Model

## Build semi-supervision

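Run exactly one of the two subsections below: "Unsupervised" (no constraint matrix, `C = None`) or "Semi-supervised" (builds a constraint matrix `C` from the true labels `sampler.Z`).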

### Unsupervised

In [None]:
# Unsupervised: no constraint matrix. The "Run" cell then fits a single eta (etas = [0]).
C = None

### Semi-supervised

In [None]:
# Inputs of the constraint-matrix builders in the next cell.
frac, frac_noise = 0.01, 0.0  # NOTE(review): presumably fraction of constrained pairs / of noisy constraints — confirm
stratified = False  # True -> build_C_strat, False -> build_C
path_only = True    # only forwarded to build_C_strat

In [None]:
# Build the constraint matrix from the true labels, stratified or not.
if stratified:
    C = similarity_matrices.build_C_strat(sampler.Z, frac, frac_noise, path_only)
else:
    C = similarity_matrices.build_C(sampler.Z, frac, frac_noise)

# Percentage of nonzero entries; the variable name is echoed by the `=` f-string below.
sparsity = 100 * (C != 0).sum() / C.size
plt.hist(C[C != 0].ravel());
plt.title(f'Distribution of the nonzero entries of the constraint matrix C with {sparsity = :.2f}%');

## Fit model

### Set parameters

In [None]:
# ---- Range for the number of clusters ----
Ks = [1, 2, 3]  # number of clusters tested in model selection

# ---- Semi-supervision ----
etas_min, etas_max, n_steps = 100., 1000., 5
etas = np.linspace(etas_min, etas_max, num=n_steps)  # eta grid; replaced by [0] in "Run" when C is None
init_partition_with_C = False  # True, False -- NOTE(review): not passed to the fit in the "Run" cell
ss_mode = 'mixture'  # mixture, all
damping_factor = None  # None or float in (0., 1.)

# ---- Model parameters ----
constrained_sigma2 = 'sigma_k'  # sigma_jk, sigma_k
constrained_gamma = 'gamma_km'  # gamma, gamma_m, gamma_km

# ---- Priors ----
alpha_zero = beta_zero = 1.1  # set both to 1.1 for ML estimation
prior_sigma2 = 'minimum_dof'  # minimum_dof, no_prior

# ---- EM algorithm ----
em_type_first_step, em_type_second_step = 'CEM', 'VEM'  # each one of SEM, CEM, VEM, EM
n_init_first_step = 10
n_selected_inits = 1
n_init_second_step_per_selected_init = 5
n_loops_e_step_max = 1
full_sequential = True  # True, False
max_iter_em = 40
atol_iter_em, rtol_iter_em = 1e-8, 1e-5

# ---- Parallelism ----
n_parallel_runs = n_processes = 1
joblib_parallel_backend = 'loky'

# ---- Initialization ----
n_init_km = 20
algo_init_partition = 'kmeans'
algo_init_tau = 'sim'  # naive, sim
coarsening_factor = 3
proba_init = 0.5
cluster_perturbation_rate = 0.3

# ---- Misc ----
dtype = 'float64'
min_float = 1e-30
shallow = False
write_params = True
random_state = 123  # NOTE(review): not passed to the fit in the "Run" cell — confirm intended
debug = False
verbose = 0

### Run

In [None]:
# Degrees of freedom of the sigma2 prior: only set for the 'minimum_dof' prior.
if prior_sigma2 == 'minimum_dof':
    nu_zero = T + sampler.S.max() + 2
else:
    nu_zero = None

# Without a constraint matrix the eta grid is irrelevant: fit a single eta = 0.
if C is None:
    etas = [0]

# Keyword arguments grouped by role, mirroring the parameter cell.
model_results = two_steps_fit.parallel_fit_over_etas(
    # data, cluster range and semi-supervision
    X=sampler.X,
    S=S,
    Ks=Ks,
    C=C,
    etas=etas,
    ss_mode=ss_mode,
    damping_factor=damping_factor,
    # two-step EM schedule and stopping rules
    em_type_first_step=em_type_first_step,
    n_init_first_step=n_init_first_step,
    n_selected_inits=n_selected_inits,
    em_type_second_step=em_type_second_step,
    n_init_second_step_per_selected_init=n_init_second_step_per_selected_init,
    n_loops_e_step_max=n_loops_e_step_max,
    full_sequential=full_sequential,
    max_iter_em=max_iter_em,
    atol_iter_em=atol_iter_em,
    rtol_iter_em=rtol_iter_em,
    # model constraints and priors
    constrained_gamma=constrained_gamma,
    constrained_sigma2=constrained_sigma2,
    alpha_zero=alpha_zero,
    beta_zero=beta_zero,
    nu_zero=nu_zero,
    # initialization
    algo_init_partition=algo_init_partition,
    algo_init_tau=algo_init_tau,
    n_init_km=n_init_km,
    coarsening_factor=coarsening_factor,
    cluster_perturbation_rate=cluster_perturbation_rate,
    proba_init=proba_init,
    # numerics and parallelism
    dtype=dtype,
    min_float=min_float,
    n_processes=n_processes,
    n_parallel_runs=n_parallel_runs,
    joblib_parallel_backend=joblib_parallel_backend,
    # output and diagnostics
    model_results_dir=model_results_dir,
    write_params=write_params,
    shallow=shallow,
    verbose=verbose,
    debug=debug,
)

## Convergence and criteria

In [None]:
# EM variants present in the results, sorted so the column order is deterministic
# (the iteration order of a set of strings can change between interpreter sessions).
em_types = sorted({mr.em_type for mr in model_results})
# True when constraints were given and at least one eta was fitted. len() works whether
# `etas` is the numpy grid or the list [0] left by an earlier unsupervised run.
semi_supervised = C is not None and len(etas) > 0
n_variable_params = len(etas) * len(Ks)

f, ax = plt.subplots(
    n_variable_params, len(em_types),
    figsize=(6 * len(em_types), 3 * n_variable_params),
    sharey=False, squeeze=False, tight_layout=True
)

# One row per (K, eta) pair and one column per EM type; each panel overlays the
# criterion traces of every matching run (unsupervised runs match any eta).
row = 0
for k in Ks:
    for eta in etas:
        for j, em_type in enumerate(em_types):
            for mr in model_results:
                if mr.em_type == em_type and mr.K == k and (not mr.semi_supervised or mr.eta == eta):
                    ax[row, j].plot(mr.iter_criterions)
                    ax[row, j].set_title(f'K = {k}, {eta = :.2f}\n{em_type} criterions')
        row += 1

In [None]:
# Plot produced by mbcah's `plotting.plot_similarity_scores` over all fitted runs.
plotting.plot_similarity_scores(model_results)

## Model selection

In [None]:
# Minimum 'scw' score a Pareto-efficient run needs to stay a candidate in
# semi-supervised model selection.
min_R = 0.0

In [None]:
# Model-selection criteria: crit1 is the 'bic' column, crit2 the 'scw' column (plotted as R).
crit1, crit2 = 'bic', 'scw'
df_res = model_selection.get_metrics_df(model_results)
df_res = model_selection.assign_pareto_efficient(df_res, crit1, crit2)

# Plot styling.
fs = 8  # font size of the per-K labels
fs_selected = 8  # font size of the label of the selected model
scatter_size = 50.
selected_size = 800.

y_offset = .0002  # vertical offset of the text labels above their point
lw = 2.

if not semi_supervised:
    # Unsupervised: take the K whose best run maximizes crit1, then that K's best run.
    # idxmax returns the K label itself, so this does not assume `Ks` is sorted or that
    # every K in `Ks` appears in df_res.
    K_found = df_res.groupby('K')[crit1].max().idxmax()
    df_bic = df_res.query("K == @K_found")

    mr_id = df_bic['mr_id'].iloc[np.argmax(df_bic[crit1].values)]
    model_result = [mr for mr in model_results if mr.mr_id == mr_id].pop()
else:
    # Semi-supervised: keep the per-K Pareto-efficient runs whose crit2 reaches min_R.
    df_res_filtered = (
        df_res
        .query("K_pareto_efficient")
        .loc[lambda d: d[crit2] >= min_R]
    )
    # Count rows (models), not cells: DataFrame.size is n_rows * n_cols.
    assert len(df_res_filtered) >= 2, 'value of min_R is too high'
    X_crits_eff = df_res_filtered[[crit1, crit2]].values
    # Positional index (into df_res_filtered) of the chosen point on the front.
    selected_point = model_selection.heuristic_model_selection_on_pareto_front(
        X_crits_eff, method='farthest'
    )
    selected_row = df_res_filtered.iloc[selected_point]
    selected_mr_id = selected_row.mr_id
    model_result = [mr for mr in model_results if mr.mr_id == selected_mr_id].pop()

    # Coordinates and K of the selected model. Its colour must come from its own K,
    # not from a loop variable `k` left over by an earlier cell.
    x_plot, y_plot = selected_row[crit1], selected_row[crit2]
    k_plot = int(selected_row['K'])
    first_model = {}  # K -> mr_id of that K's lowest-crit1 point, where its label is drawn

    f, ax = plt.subplots(1, 1, figsize=(10, 5))
    # Qualitative palette (a list of colours); a given K is drawn with palette[K % len(palette)].
    palette = sns.color_palette('colorblind')

    for k, g in df_res_filtered.groupby('K'):
        color = palette[int(k) % len(palette)]
        x, y = g[crit1].values, g[crit2].values
        order = np.argsort(x)
        ax.scatter(x, y, s=scatter_size, alpha=.6, color=color, zorder=1)
        # Join this K's points in increasing crit1 order to draw its front.
        ax.plot(x[order], y[order], color=color, zorder=1, lw=lw)
        first_model[k] = g.mr_id.values[order][0]
    # Highlight the selected model with a larger diamond.
    ax.scatter(
        x_plot, y_plot,
        s=2 * scatter_size, alpha=.9, color=palette[k_plot % len(palette)],
        edgecolor='k', linewidth=1.2, zorder=2, marker='D'
    )

    for pos, (x, y, k, eta, mr_id) in enumerate(zip(
            df_res_filtered[crit1],
            df_res_filtered[crit2],
            df_res_filtered.K,
            df_res_filtered.eta,
            df_res_filtered.mr_id
    )):
        if pos == selected_point:
            text = 'K=' + str(k) + '\n' + 'eta=' + str(eta).split('.')[0]
            ax.text(
                x, y + y_offset, text, fontsize=fs_selected,
                verticalalignment='bottom', horizontalalignment='left', zorder=2, weight='bold'
            )
        if mr_id == first_model[k]:
            ax.text(
                x, y + y_offset, 'K=' + str(k), fontsize=fs,
                verticalalignment='center', horizontalalignment='center', zorder=2
            )

    ax.set_xlabel('BIC');
    ax.set_ylabel('R');

# NOTE(review): judging by its name this unpickles previously written results —
# only run it on trusted result directories.
model_result.load_shallow_from_pickle()
if not model_result.write_params:
    print('Model is empty')

print(f'Found K = {model_result.K} and eta = {model_result.eta:.2f}')

# Metrics

In [None]:
# Agreement between the selected partition and the constraints: positive entries
# (printed as ML), negative entries (printed as CL), then the full matrix (ALL).
if C is not None:
    constraint_parts = []
    if (C > 0.).any():
        constraint_parts.append(('ML', np.clip(C, 0, None)))
    if (C < 0.).any():
        constraint_parts.append(('CL', np.clip(C, None, 0)))
    constraint_parts.append(('ALL', C))
    for part_name, C_part in constraint_parts:
        # `R` is echoed by name through the `=` f-string specifier.
        _, R = similarity_matrices.similarity_concordance(model_result.Z, C_part)
        print(f'{part_name} {R = :.4f}')
    print()

# Clustering quality against the sampler's ground truth.
ari = adjusted_rand_score(sampler.Z, model_result.Z)
# NOTE(review): element-wise comparison of tau; assumes both have the same shape and
# a comparable cluster labelling — confirm.
mae = np.mean(np.abs(sampler.tau - model_result.tau))

# Reorder the confusion matrix so matched clusters lie on the diagonal
# (assignment on -cmat maximizes the matched counts).
cmat = confusion_matrix(sampler.Z, model_result.Z)
true_idx, pred_idx = linear_sum_assignment(-cmat)
cmat = cmat[np.ix_(true_idx, pred_idx)]

print(f'ARI = {ari}', f'MAE tau = {mae}', sep='\n')
print('clustering confusion matrix: ')
print('C_ij is equal to the number of observations known to be in group i and predicted to be in group j.')
print('C = ', cmat, sep='\n')
print()

# Plots

In [None]:
# Panel groups passed to `plotting.combined_plot` in the next cell.
all_plots = [('X_unshifted',), ('X_shifted',), ('mu', 'sigma2'), ('shifts',)]
# Size of each sub-figure, and the `n_sigma` argument of combined_plot.
subfigsize, n_sigma = (5, 3), 3

In [None]:
# Variance reduction factors of the selected model, displayed by the combined plot.
fitted = model_result
variance_reduction_factors = model_selection.compute_variance_reduction_factors(
    fitted.Z, fitted.K, fitted.tau_inds, fitted.X, fitted.S,
    fitted.constrained_sigma2, fitted.sigma2_zero, fitted.nu_zero,
    fitted.min_den, fitted.min_float,
)

# Layout options come from the previous cell; model quantities from the selected fit.
plotting.combined_plot(
    all_plots=all_plots,
    subfigsize=subfigsize,
    n_sigma=n_sigma,
    X=fitted.X,
    S=fitted.S,
    Z=fitted.Z,
    tau=fitted.tau,
    VS_tau=fitted.VS_tau,
    mu=fitted.mu,
    sigma2=fitted.sigma2,
    variance_reduction_factors=variance_reduction_factors,
)