In [None]:
import pandas as pd
from scipy import stats
import itertools

In [None]:
import os
import sys
from glob import glob
sys.path.append(os.path.join("..",".."))
from discotec import *

In [None]:
import seaborn as sns

import matplotlib.pyplot as plt
import scienceplots

plt.style.use('science')

# Generating partitions

In [None]:
@partial(jax.jit, static_argnums=(1, 2))
def generate_reference_partition(random_key, n_samples, n_clusters):
    """Draw a ground-truth partition by sampling each label uniformly at random.

    Args:
        random_key: JAX PRNG key.
        n_samples: number of points to label (static under jit).
        n_clusters: number of candidate labels; labels lie in ``[0, n_clusters)``
            (static under jit).

    Returns:
        Integer label array of shape ``(n_samples,)``.
    """
    label_shape = (n_samples,)
    return jax.random.choice(random_key, n_clusters, shape=label_shape)

In [None]:
@partial(jax.jit, static_argnums=[3])
def generate_fixedK_partition(key, reference, conservation_prob, n_clusters):
    """Perturb ``reference`` into a new partition over the same ``n_clusters`` labels.

    Each label is independently kept with probability ``conservation_prob``.
    Otherwise it is replaced by a label drawn uniformly among the
    ``n_clusters - 1`` labels that differ from the reference one. A replaced
    label therefore never matches ``reference``, so the expected fraction of
    labels agreeing with ``reference`` is exactly ``conservation_prob``.

    Args:
        key: JAX PRNG key.
        reference: integer label array (expected to hold values in
            ``[0, n_clusters)``).
        conservation_prob: probability of keeping each reference label. It may
            be a traced scalar, e.g. when mapped with ``jax.vmap``.
        n_clusters: number of labels (static under jit); must be at least 2.

    Returns:
        Integer array with the same shape as ``reference``.

    Raises:
        ValueError: if ``n_clusters < 2``, since no alternative label exists.
    """
    if n_clusters < 2:
        raise ValueError(
            f"n_clusters must be at least 2 to draw a different label, got {n_clusters}"
        )
    relabelling_key, new_cluster_key = jax.random.split(key, 2)
    to_conserve = jax.random.bernoulli(relabelling_key, p=conservation_prob, shape=reference.shape)
    # Shift each label by an offset in [1, n_clusters - 1] modulo n_clusters.
    # The result is uniform over every label except the reference one, which
    # keeps the accuracy within the expected bounds.
    offsets = jax.random.choice(new_cluster_key, n_clusters - 1, shape=reference.shape) + 1
    new_clusters = (reference + offsets) % n_clusters
    return jnp.where(to_conserve, reference, new_clusters)

In [None]:
@partial(jax.jit, static_argnums=[1, 2, 3, 4, 5])
def generate_scenario(random_key, n_samples, n_models, n_clusters, min_accuracy, max_accuracy):
    """Sample a reference partition and ``n_models`` noisy copies of it.

    Each model is produced by ``generate_fixedK_partition``. Its conservation
    probability, i.e. the expected fraction of labels kept from the reference,
    is drawn uniformly in ``[min_accuracy, max_accuracy)``. It is fixed to
    ``min_accuracy`` when both bounds are equal.

    Args:
        random_key: JAX PRNG key.
        n_samples: number of points (static).
        n_models: number of perturbed partitions to generate (static).
        n_clusters: number of labels (static).
        min_accuracy: lower bound on each model's conservation probability (static).
        max_accuracy: upper bound on each model's conservation probability (static).

    Returns:
        ``(y_true, y_pred)``: the reference labels of shape ``(n_samples,)`` and
        the model labels of shape ``(n_models, n_samples)``.

    Raises:
        ValueError: unless ``0 <= min_accuracy <= max_accuracy <= 1``.
    """
    if not 0.0 <= min_accuracy <= max_accuracy <= 1.0:
        raise ValueError(
            f"Expected 0 <= min_accuracy <= max_accuracy <= 1, got {min_accuracy} and {max_accuracy}"
        )
    reference_key, switch_key, models_key = jax.random.split(random_key, 3)

    # Start by generating the ground-truth labels of this scenario.
    y_true = generate_reference_partition(reference_key, n_samples, n_clusters)

    if min_accuracy == max_accuracy:
        # Degenerate range: every model gets exactly this accuracy. The uniform
        # branch below uses the bounds directly as conservation probabilities,
        # so this must NOT be ``1 - min_accuracy`` (that inverted the accuracy).
        conservation_probs = jnp.full((n_models,), min_accuracy)
    else:
        conservation_probs = jax.random.uniform(
            switch_key, minval=min_accuracy, maxval=max_accuracy, shape=(n_models,)
        )

    # One independent key and conservation probability per model; the
    # reference and n_clusters are shared.
    model_sampler = jax.vmap(generate_fixedK_partition, in_axes=[0, None, 0, None])
    models_keys = jax.random.split(models_key, n_models)
    y_pred = model_sampler(models_keys, y_true, conservation_probs, n_clusters)

    return y_true, y_pred

# Run the simulations

In [None]:
# Simulation settings shared by the cells below.
n_samples = 200   # points in each synthetic reference partition
n_clusters = 10   # number of labels in the reference partition
n_runs = 50       # independent scenarios per (ensemble size, max accuracy) pair

# Default ensemble size; the simulation sweep reassigns n_models for each CSV file.
n_models = 5

# Bounds on each model's conservation probability (its expected accuracy).
min_accuracy = 0.1
max_accuracies = [0.2, 0.5, 0.9]

In [None]:
# Sweep the ensemble size: M in {20, 24, ..., 40}, with 5*M models per ensemble.
# Each ensemble size is cached in its own CSV file, so re-running this cell only
# computes missing files. Delete a file to force its recomputation.
CORRELATIONS = {
    "Pearson": stats.pearsonr,
    "Spearman": stats.spearmanr,
    "Kendall": stats.kendalltau,
}


def _score_run(dataset_key, n_models, max_accuracy, run):
    """Simulate one scenario and correlate every candidate score with the true ARIs.

    Returns one record per (score, correlation) pair, in the CSV row format.
    """
    y_true, y_pred = generate_scenario(
        dataset_key, n_samples, n_models, n_clusters, min_accuracy, max_accuracy
    )

    # External validity index: ARI of every model against the ground truth.
    true_aris = jnp.array([metrics.adjusted_rand_score(y_true, y) for y in y_pred])

    # Connectivity-based rankings are minimised, so they are negated to get
    # scores to maximise, like the pairwise scores.
    centroid = compute_consensus_matrix(y_pred)
    quantised_centroid = (centroid > centroid.mean()).astype(float)
    scores_by_name = {
        "DISCO_TV": -compute_tv_ranking(y_pred, centroid),
        "DISCO_KL": -compute_kl_ranking(y_pred, centroid),
        "DISCO_H": -compute_hellinger_ranking(y_pred, centroid),
        "AARI": pairwise_score(y_pred),
        "ANMI": pairwise_score(y_pred, method="nmi"),
        "DISCO_Q": -compute_tv_ranking(y_pred, quantised_centroid),
    }

    return [
        {
            "Score": name,
            "Correlation": corr_name,
            "Value": corr_fct(true_aris, scores).statistic,
            "Run": run,
            "Max_acc": max_accuracy,
        }
        for name, scores in scores_by_name.items()
        for corr_name, corr_fct in CORRELATIONS.items()
    ]


for M in range(20, 41, 4):
    n_models = 5 * M
    filename = f"bounded_accuracy_results_{n_models}.csv"
    if os.path.exists(filename):
        continue

    # Use a fresh accumulator per file. Sharing one across the sweep made each
    # CSV also contain every previous ensemble size's rows, which the loader
    # then mislabelled with this file's model count.
    all_scores = []
    # Reset the key per file so each ensemble size is reproducible on its own.
    master_key = jax.random.key(0)
    for max_accuracy in max_accuracies:
        print(max_accuracy)
        for i in range(n_runs):
            print(i, end=" ")
            master_key, dataset_key = jax.random.split(master_key)
            all_scores += _score_run(dataset_key, n_models, max_accuracy, i)
    pd.DataFrame(all_scores).to_csv(filename)

In [None]:
# Gather the cached sweep results. Only files written by the simulation cell are
# matched, so unrelated CSVs in the working directory cannot break the
# filename parsing below.
result_files = sorted(glob("bounded_accuracy_results_*.csv"))
if not result_files:
    raise FileNotFoundError(
        "No bounded_accuracy_results_*.csv files found; run the simulation cell first."
    )

all_dataframes = []
for result_file in result_files:
    # The number of models is the last "_"-separated token of the file stem.
    stem = os.path.splitext(os.path.basename(result_file))[0]
    file_n_models = int(stem.rsplit("_", 1)[-1])
    file_df = pd.read_csv(result_file, index_col=0)
    file_df["M"] = file_n_models
    all_dataframes.append(file_df)
# Per-file row numbers carry no meaning, and duplicated labels can trip up
# seaborn, so build a fresh index.
df = pd.concat(all_dataframes, ignore_index=True)

# Display names for the scores kept in the plots.
SCORE_LABELS = {"DISCO_TV": "Total variation", "DISCO_H": "Hellinger", "DISCO_Q": "Binary", "DISCO_KL": "KL"}
filtered_df = df[~df.Score.isin(["DISCO_TV", "DISCO_H"])].assign(
    Score=lambda d: d.Score.replace(SCORE_LABELS)
)
filtered_df = filtered_df[(filtered_df.M % 20 == 0) | (filtered_df.M <= 10)]

In [None]:
# One figure per accuracy upper bound. Each panel shows, for one correlation
# measure, how each score's correlation with the true ARIs varies with M.
for max_acc, acc_df in filtered_df.groupby("Max_acc", as_index=False):
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, corr_name in zip(axes, ["Pearson", "Spearman", "Kendall"]):
        panel_df = acc_df[acc_df.Correlation == corr_name]
        sns.lineplot(data=panel_df, x="M", y="Value", hue="Score", ax=ax)
        ax.set_title("Correlation = " + corr_name)
    fig.suptitle(f"Max accuracy = {max_acc:.1%}")
    fig.tight_layout()
    plt.show()