In [38]:
from pathlib import Path
import os,sys
import pickle
import pandas as pd
import numpy as np
import importlib
import yaml
from tqdm.notebook import tqdm
from itertools import product

sys.path.insert(0, str(Path().resolve().parents[1]))

from gower import gower_matrix

import fusemix.clustering as clust_utils 
import fusemix.mige as migeClust
from fusemix.mica import compute_MICA
from fusemix.mixture_missing import run_mghm, run_mcnm

from fusemix.evaluation_metrics import *

importlib.reload(migeClust)
importlib.reload(clust_utils)

from scipy.sparse import csr_matrix

import warnings
warnings.filterwarnings('ignore')
import seaborn as sns 


In [11]:
def read_pickle(path):
    with open(path, 'rb') as f:
        return pickle.load(f)
    
def write_pickle(var, path):
    with open(path, 'wb') as f:
       pickle.dump(var, f)

In [12]:
with open("../../test_data/simulation_config.yaml", "r") as f:
    cfg = yaml.safe_load(f)

In [13]:
for conf in product(
    cfg['dataset_ids'],
    cfg['md_param_grid']['props'],
    cfg['md_param_grid']['mf_proportions'],
    cfg['md_param_grid']['mnar_proportions'],
    range(cfg['n_runs'])):

    print(conf)

(33, 0.75, 0.75, 0.5, 0)
(33, 0.75, 0.75, 0.5, 1)
(33, 0.75, 0.75, 0.5, 2)
(33, 0.75, 0.75, 0.5, 3)
(33, 0.75, 0.75, 0.5, 4)
(33, 0.75, 0.75, 0.5, 5)
(33, 0.75, 0.75, 0.5, 6)
(33, 0.75, 0.75, 0.5, 7)
(33, 0.75, 0.75, 0.5, 8)
(33, 0.75, 0.75, 0.5, 9)


In [110]:
test_data_complete = read_pickle("../../test_data/fetched/dataset_17.pkl")
test_data_missing = read_pickle("../../test_data/missing_data/17/0.75_0.75_0.5/data_pipeline_0.pkl")
test_data = read_pickle("../../test_data/imputed_data/17/0.75_0.75_0.5/data_imputed_0.pkl")

In [111]:

incomplete_data = test_data_missing.amputer.incomplete_dataset
complete_data = test_data_complete['X_complete']
true_labels = test_data_complete['y_complete'].values.flatten()
cat_mask = test_data_complete['cat_mask']
num_classes = test_data_complete['num_classes']
multiple_imputed_data = test_data
seed = 0

In [112]:
mige_labels = migeClust.mige(
    multiple_imputed_data,
    n_clusters=num_classes,
    cat_mask=cat_mask,
    seed=seed,
    p_min = 0.75,
    p_max = 1,
    num_projections = 1,
    k_nn = 20,
    co_threshold = 0.5
    )

PERFORMING FINAL CONSENSUS


In [113]:
mica_labels = compute_MICA(
    multiple_imputed_data,
    num_clusters=num_classes,
    seed=seed
)

In [114]:
kpod_labels = clust_utils.compute_kpod(
    incomplete_data,
    num_clusters=num_classes,
    seed=seed
    )

In [115]:
try:
    mghm_labels = run_mcnm(
        incomplete_data,
        G=num_classes,
        seed=seed
        )
except RuntimeError:
    mghm_labels = None 

try:
    mcnm_labels = run_mcnm(
        incomplete_data,
        G=num_classes,
        seed=seed
        )
except RuntimeError:
    mcnm_labels = None 



In [116]:
sc_si_knn_labels = clust_utils.compute_spectral_si_knn(
    incomplete_data,
    seed=seed,
    num_clusters=num_classes,
    cat_mask=cat_mask
)
sc_si_mi_labels = clust_utils.compute_spectral_si_mi(
    multiple_imputed_data,
    seed=seed,
    num_clusters=num_classes,
    cat_mask=cat_mask
)
km_si_knn_labels = clust_utils.compute_kmeans_si_knn(
    incomplete_data,
    num_clusters=num_classes,
    seed=seed
)
km_si_mi_labels = clust_utils.compute_kmeans_si_mi(
    multiple_imputed_data,
    num_clusters=num_classes,
    seed=seed
)

In [117]:
cca_spectral_labels = clust_utils.compute_spectral_complete(
    complete_data,
    cat_mask=cat_mask,
    num_clusters=num_classes,
    seed=seed
)

cca_kmeans_labels = clust_utils.compute_kmeans_complete(
    complete_data,
    num_clusters=num_classes,
    seed=seed
)

In [118]:
predicted_labels = {
    'mige': mige_labels,
    'mica': mica_labels,
    'kpod': kpod_labels,
    'mcnm': mcnm_labels,
    'mghm': mghm_labels,
    'sc_knn': sc_si_knn_labels,
    'sc_mi': sc_si_mi_labels,
    'km_knn': km_si_knn_labels,
    'km_si': km_si_mi_labels
}

In [119]:
int_metrics = dict.fromkeys(predicted_labels.keys())
ext_metrics = dict.fromkeys(predicted_labels.keys())
ext_metrics_cca_sc = dict.fromkeys(predicted_labels.keys())
ext_metrics_cca_km = dict.fromkeys(predicted_labels.keys())

In [120]:
for method,comp in zip(predicted_labels.keys(),predicted_labels.values()):
    try:
        int_metrics[method] = internal_metrics(comp, complete_data, cat_mask)
        ext_metrics[method] = external_metrics(true_labels, comp)
        ext_metrics_cca_sc[method] = external_metrics(cca_spectral_labels, comp)
        ext_metrics_cca_km[method] = external_metrics(cca_kmeans_labels, comp)

    except:
        int_metrics[method] = np.nan
        ext_metrics[method] = np.nan
        ext_metrics_cca_sc[method] = np.nan
        ext_metrics_cca_km[method] = np.nan

        

In [121]:
pd.DataFrame(ext_metrics)

Unnamed: 0,mige,mica,kpod,mcnm,mghm,sc_knn,sc_mi,km_knn,km_si
ari,0.792122,0.481057,0.069728,0.623275,0.623275,0.594111,0.61711,0.481057,0.48623
ami,0.701832,0.455904,0.109172,0.541416,0.541416,0.558722,0.555607,0.455904,0.459944
vm,0.702238,0.456708,0.111064,0.542058,0.542058,0.559354,0.556235,0.456708,0.46074
cs,0.715835,0.510052,0.253062,0.569823,0.569823,0.601878,0.590782,0.510052,0.513419


In [122]:
pd.DataFrame(ext_metrics_cca_sc)

Unnamed: 0,mige,mica,kpod,mcnm,mghm,sc_knn,sc_mi,km_knn,km_si
ari,0.725082,0.64915,0.167913,0.635001,0.635001,0.829169,0.899159,0.64915,0.65569
ami,0.660368,0.492618,0.168669,0.483933,0.483933,0.704672,0.81322,0.492618,0.499313
vm,0.660864,0.493427,0.170654,0.484711,0.484711,0.705127,0.813504,0.493427,0.50011
cs,0.62836,0.510892,0.348781,0.474275,0.474275,0.705127,0.803677,0.510892,0.516741


In [123]:
pd.DataFrame(int_metrics)

Unnamed: 0,mige,mica,kpod,mcnm,mghm,sc_knn,sc_mi,km_knn,km_si
sh,0.404314,0.385033,0.322172,0.370765,0.370765,0.404535,0.415979,0.385033,0.386055
ch,695.825636,1295.919101,183.588836,934.34722,934.34722,981.511427,650.959774,1295.919101,1294.415921
db,0.698456,0.503305,0.619719,0.60188,0.60188,0.583313,0.720567,0.503305,0.504381


In [124]:
pd.DataFrame(ext_metrics_cca_km)

Unnamed: 0,mige,mica,kpod,mcnm,mghm,sc_knn,sc_mi,km_knn,km_si
ari,0.598008,0.98464,0.207421,0.762409,0.762409,0.752249,0.604355,0.98464,0.977029
ami,0.526874,0.962115,0.193563,0.672164,0.672164,0.614442,0.452212,0.962115,0.941039
vm,0.527585,0.962177,0.195586,0.672673,0.672673,0.615055,0.453071,0.962177,0.941136
cs,0.487687,0.966022,0.382029,0.639315,0.639315,0.597019,0.434633,0.966022,0.943001
