In [None]:
import pandas as pd
from scipy import stats
import itertools

In [None]:
import os
import sys
sys.path.append(os.path.join("..",".."))
from discotec import *

In [None]:
from scipy.optimize import linear_sum_assignment
import seaborn as sns
import matplotlib as mpl
#mpl.use("pgf")

import matplotlib.pyplot as plt
import scienceplots

plt.style.use('science')

# Generating partitions

In [None]:
@partial(jax.jit, static_argnums=[1, 2])
def generate_reference_partition(random_key, n_samples, n_clusters):
    """Draw a random reference partition.

    Every sample receives a label drawn with ``jax.random.choice`` among
    ``n_clusters`` values (0 to ``n_clusters - 1``), independently of the
    other samples.

    Args:
        random_key: JAX PRNG key driving the draw.
        n_samples: Number of labels to draw (static under ``jit``).
        n_clusters: Number of distinct labels available (static under ``jit``).

    Returns:
        An integer array of shape ``(n_samples,)`` holding the labels.
    """
    return jax.random.choice(random_key, n_clusters, shape=(n_samples,))

In [None]:
@partial(jax.jit, static_argnums=[3])
def generate_fixedK_partition(key, reference, conservation_prob, n_clusters):
    """Perturb a reference partition while keeping the same label set.

    Each label of ``reference`` is kept with probability
    ``conservation_prob``; otherwise it is replaced by a label that is
    guaranteed to differ from the reference one.

    Args:
        key: JAX PRNG key, split internally into two independent keys.
        reference: Integer array of reference labels in ``[0, n_clusters)``.
        conservation_prob: Probability of keeping each reference label.
        n_clusters: Number of distinct labels (static under ``jit``).

    Returns:
        An array shaped like ``reference`` containing the perturbed labels.
    """
    keep_key, shift_key = jax.random.split(key, 2)

    # Mask of the positions whose reference label survives.
    keep_mask = jax.random.bernoulli(keep_key, p=conservation_prob, shape=reference.shape)

    # A shift of 1..n_clusters-1 (modulo n_clusters) can never map a label
    # onto itself, so a replaced label always disagrees with the reference.
    # This keeps the accuracy within the expected bounds.
    raw_shift = jax.random.choice(shift_key, n_clusters-1, shape=reference.shape)
    replacement = (raw_shift+reference+1)%n_clusters

    return jnp.where(keep_mask, reference, replacement)

In [None]:
@partial(jax.jit, static_argnums=[1,2,3,4,5])
def generate_scenario(random_key, n_samples, n_models, n_clusters, min_accuracy, max_accuracy):
    """Simulate one reference partition and a set of noisy copies of it.

    A reference partition is drawn first. Each of the ``n_models`` candidate
    partitions then keeps every reference label with its own conservation
    probability and replaces the others with a different label.

    Args:
        random_key: JAX PRNG key for the whole scenario.
        n_samples: Number of samples in every partition.
        n_models: Number of candidate partitions to generate.
        n_clusters: Number of distinct labels.
        min_accuracy: Lower bound of the per-model conservation probability.
        max_accuracy: Upper bound of the per-model conservation probability.

    All arguments except ``random_key`` are static under ``jit``, which is
    what allows the plain Python comparison of the two bounds below.

    Returns:
        A tuple ``(y_true, y_pred)`` where ``y_true`` has shape
        ``(n_samples,)`` and ``y_pred`` has shape ``(n_models, n_samples)``.
    """
    reference_key, switch_key, models_key = jax.random.split(random_key, 3)

    # We start by generating the labels of this scenario
    y_true = generate_reference_partition(reference_key, n_samples, n_clusters)

    if min_accuracy==max_accuracy:
        # Degenerate interval: every model conserves labels with exactly that
        # probability. The conservation probability IS the accuracy (replaced
        # labels always differ from the reference), so it must not be
        # complemented; this matches what the uniform draw below would return.
        conservation_probs = jnp.ones(n_models)*min_accuracy
    else:
        conservation_probs = jax.random.uniform(switch_key, minval=min_accuracy, maxval=max_accuracy, shape=(n_models,))

    # One key and one conservation probability per model; the reference
    # partition and the number of clusters are shared by all models.
    model_sampler = jax.vmap(generate_fixedK_partition, in_axes=[0, None, 0, None])
    models_keys = jax.random.split(models_key, n_models)
    y_pred = model_sampler(models_keys, y_true, conservation_probs, n_clusters)

    return y_true, y_pred

In [None]:
def unsupervised_accuracy(y_true,y_pred):
    """Accuracy of a clustering under the best one-to-one label matching.

    The confusion matrix between ``y_true`` and ``y_pred`` is built with
    ``metrics.confusion_matrix`` (presumably scikit-learn's; it is brought in
    by the star import). ``linear_sum_assignment`` then picks the pairing of
    rows and columns that maximises the number of matched samples.

    Returns:
        The matched count divided by the total count of the confusion matrix.
    """
    contingency = metrics.confusion_matrix(y_true, y_pred)
    rows, cols = linear_sum_assignment(contingency, maximize=True)

    matched = contingency[rows, cols].sum()
    return matched/contingency.sum()

# Run the simulations

In [None]:
# Number of samples in every simulated partition.
n_samples = 200
# Number of candidate partitions ("models") generated per scenario.
n_models = 5
# Number of independent scenarios simulated for each value of max_accuracies.
n_runs = 50
# Number of distinct cluster labels in the simulated partitions.
n_clusters = 10
# Upper bounds tried, one at a time, for the per-model conservation probability.
max_accuracies = [0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9]
# Lower bound of the per-model conservation probability, shared by all runs.
min_accuracy = 0.1
# CSV file caching the simulation results; the ensemble size is part of its name.
filename =  f"bounded_accuracy_results_{n_models}.csv"

In [None]:
# Simulate n_runs scenarios for every upper accuracy bound and record how well
# each internal ranking score correlates with the true ARI of the models.
# The results are cached in `filename`: when that file already exists the
# simulations are skipped altogether and nothing is overwritten.
master_key = jax.random.key(0)
all_scores = []

# Correlation measures reported for every score, in the order they are stored.
correlation_functions = {
    "Pearson": stats.pearsonr,
    "Spearman": stats.spearmanr,
    "Kendall": stats.kendalltau,
}

# The cache check does not depend on the loop, so it is evaluated once.
if os.path.exists(filename):
    print(f"{filename} already exists, skipping the simulations")
else:
    for max_accuracy in max_accuracies:
        print(max_accuracy)
        for i in range(n_runs):
            print(i, end=" ")
            # Split the random key
            master_key, dataset_key = jax.random.split(master_key)

            # Generate the scenario
            y_true, y_pred = generate_scenario(dataset_key, n_samples, n_models, n_clusters, min_accuracy, max_accuracy)

            # Compute the consensus matrix
            centroid = compute_consensus_matrix(y_pred)

            # External validity index: ARI of every model against the reference
            true_aris = jnp.array([metrics.adjusted_rand_score(y_true, y) for y in y_pred])

            # Consensus matrix binarised at its mean value
            quantised_centroid = (centroid>centroid.mean()).astype(float)

            # Every score is stored next to its name so that the two cannot
            # get out of step. The connectivity-based scores are negated to
            # obtain a quantity to maximise (instead of minimising).
            named_scores = {
                "DISCO_TV": -compute_tv_ranking(y_pred, centroid),
                "DISCO_WTV": -compute_weighted_tv(y_pred, centroid),
                "DISCO_KL": -compute_kl_ranking(y_pred, centroid),
                "DISCO_WKL": -compute_weighted_kl(y_pred, centroid),
                "DISCO_H": -compute_hellinger_ranking(y_pred, centroid),
                "DISCO_WH": -compute_weighted_hellinger(y_pred, centroid),
                "AARI": pairwise_score(y_pred),
                "ANMI": pairwise_score(y_pred, method="nmi"),
                "DISCO_Q": -compute_tv_ranking(y_pred, quantised_centroid),
                "DISCO_WQ": -compute_weighted_tv(y_pred, quantised_centroid),
            }

            for name, scores in named_scores.items():
                for corr_name, corr_fct in correlation_functions.items():
                    all_scores.append({
                        "Score":name,
                        "Correlation":corr_name,
                        "Value":corr_fct(true_aris, scores).statistic,
                        "Run":i,
                        "Max_acc":max_accuracy
                    })

    df = pd.DataFrame(all_scores)
    df.to_csv(filename)

In [None]:
# Plot, for every ensemble size with cached results, the correlation between
# each score and the true ARI as a function of the upper accuracy bound.
# The loop variable is deliberately NOT called `n_models`: rebinding that
# configuration value here would make a later re-run of the simulation cell
# use 50 models while still writing to the file named after 5 models.
for ensemble_size in [5,25,50]:
    df = pd.read_csv(f"bounded_accuracy_results_{ensemble_size}.csv")
    # For the storytelling purposes, I dropped the weighted scores.
    # NOTE: DISCO_TV and DISCO_H are excluded as well, so their entries in the
    # renaming table below currently have no effect.
    filtered_df = df[~df.Score.isin(["DISCO_WQ","DISCO_WH","DISCO_WKL","DISCO_WTV", "DISCO_TV", "DISCO_H"])].replace({"DISCO_TV":"Total variation", "DISCO_H":"Hellinger", "DISCO_Q":"Binary", "DISCO_KL":"KL"})

    for correlation, subdf in filtered_df.groupby("Correlation"):
        # One explicit figure per plot, so that curves never pile up on a
        # shared implicit axes when the backend does not close figures on show.
        fig, ax = plt.subplots()
        sns.lineplot(data=subdf, x="Max_acc", y="Value", hue="Score", ax=ax)
        ax.set_ylim((-0.2,1.1))
        ax.set_ylabel("Ranking Correlation")
        ax.set_xlabel(r"$\rho_\text{max}$")
        fig.savefig(f"{correlation}_{ensemble_size}.pdf", bbox_inches="tight")
        plt.show()
        plt.close(fig)

In [None]:
# Example consensus matrices for three upper accuracy bounds: the raw matrix
# on the top row and its binarised version on the bottom row, with the samples
# sorted by their reference label.
master_key = jax.random.key(0)
fig, axes = plt.subplots(2, 3, figsize=(15,8))
for i, rho in enumerate([0.2, 0.5, 0.9]):
    raw_ax, binary_ax = axes[0, i], axes[1, i]

    # Fresh key for this scenario
    master_key, dataset_key = jax.random.split(master_key)

    # Simulate the scenario and sort the samples by reference label
    y_true, y_pred = generate_scenario(dataset_key, 200, 50, 10, 0.1, rho)
    order = jnp.argsort(y_true)

    # Consensus matrix and its version thresholded at the mean
    centroid = compute_consensus_matrix(y_pred)
    quantised_centroid = (centroid>centroid.mean()).astype(float)

    raw_ax.imshow(centroid[order][:,order])
    raw_ax.set_title(r"$\rho_\text{max}$ = "+f"{rho:.1f}")
    binary_ax.imshow(quantised_centroid[order][:,order])

    # Only the leftmost column carries the row labels
    if i==0:
        raw_ax.set_ylabel("Raw matrix")
        binary_ax.set_ylabel("Binarised matrix")
fig.tight_layout()
fig.savefig("example_consensus.pdf", bbox_inches="tight")
plt.show()

In [None]:
# Same consensus matrices as in the previous figure, but with the samples
# sorted by the labels of the model selected by the binarised score (the one
# with the smallest score against the binarised consensus matrix).
master_key = jax.random.key(0)
fig, axes = plt.subplots(2, 3, figsize=(15,8))
for i, rho in enumerate([0.2, 0.5, 0.9]):
    raw_ax, binary_ax = axes[0, i], axes[1, i]

    # Fresh key for this scenario
    master_key, dataset_key = jax.random.split(master_key)

    # Simulate the scenario
    y_true, y_pred = generate_scenario(dataset_key, 200, 50, 10, 0.1, rho)

    # Consensus matrix and its version thresholded at the mean
    centroid = compute_consensus_matrix(y_pred)
    quantised_centroid = (centroid>centroid.mean()).astype(float)

    # Select the model with the lowest score and sort the samples by its labels
    quantised_score = compute_tv_ranking(y_pred, quantised_centroid)
    best_model = jnp.argmin(quantised_score)
    order = jnp.argsort(y_pred[best_model])

    raw_ax.imshow(centroid[order][:,order])
    raw_ax.set_title(r"$\rho_\text{max}$ = "+f"{rho:.1f}")
    binary_ax.imshow(quantised_centroid[order][:,order])

    # Only the leftmost column carries the row labels
    if i==0:
        raw_ax.set_ylabel("Raw matrix")
        binary_ax.set_ylabel("Binarised matrix")
fig.tight_layout()
fig.savefig("example_selection.pdf", bbox_inches="tight")
plt.show()