- load datasets
- for each dataset
  - for each projection
      - for each n_clusters (true_clusters/2 : true_clusters*2)
          - for i in 100 trials
              - run KMeans
              - compute silhouette score
              - compute NMI, v-measure

In [1]:
import numpy as np
import pandas as pd
from tqdm.autonotebook import tqdm
import matplotlib.pyplot as plt



In [2]:
from sklearn.cluster import KMeans
from sklearn.metrics import homogeneity_completeness_v_measure

In [3]:
from tfumap.paths import ensure_dir, MODEL_DIR, DATA_DIR

In [4]:
from tfumap.silhouette import silhouette_score_block

In [5]:
### load datasets and projections

In [6]:
from tfumap.load_datasets import load_MNIST, load_FMNIST, load_CIFAR10, load_MACOSKO, load_CASSINS

##### load dataset

In [7]:
# Root folder under which each dataset's saved projections (z.npy files) live.
output_dir = MODEL_DIR / "projections"

In [8]:
# Names of the projection methods; each is a sub-directory of output_dir / <dataset>.
classes = [
    'umap-learn',
    'direct',
    'network',
    'autoencoder',
    'PCA',
    'TSNE',
    'parametric-tsne',
]

In [9]:
# Build one row per (dataset, projection class, embedding dim) for every saved projection.
# Projections that are missing or unreadable are reported (path + error type) and skipped.

# Each loader returns (X_train, X_test, X_valid, Y_train, Y_test, Y_valid); only Y_train is used.
label_loaders = {
    'mnist': load_MNIST,
    'fmnist': load_FMNIST,
    'cifar10': load_CIFAR10,
    'macosko2015': load_MACOSKO,
    'cassins': load_CASSINS,
    'cassins_dtw': load_CASSINS,
}
# Extra path components (between dataset and class_) for each embedding dimensionality.
dim_subdirs = {2: (), 64: ('64',)}

projection_records = []
for dataset in ['mnist', 'fmnist', 'cifar10', 'macosko2015', 'cassins_dtw']:
    # Fail loudly instead of silently reusing the previous dataset's labels.
    if dataset not in label_loaders:
        raise ValueError(f"no label loader registered for dataset {dataset!r}")
    _, _, _, Y_train, _, _ = label_loaders[dataset]()
    # load projections
    for class_ in classes:
        for dim, subdirs in dim_subdirs.items():
            loc = output_dir.joinpath(dataset, *subdirs, class_, 'z.npy')
            try:
                z = np.load(loc)
            except (OSError, ValueError, EOFError) as err:
                # Projection not computed (or not loadable) for this combination: report and skip.
                print(f"{loc} ({type(err).__name__})")
                continue
            projection_records.append(
                {'dataset': dataset, 'class_': class_, 'train_z': z, 'train_label': Y_train, 'dim': dim}
            )

# Build the frame once rather than growing it row by row.
projection_df = pd.DataFrame(
    projection_records, columns=['dataset', 'class_', 'train_z', 'train_label', 'dim']
)

/mnt/cube/tsainbur/Projects/github_repos/umap_tf_networks/models/projections/mnist/64/TSNE/z.npy
/mnt/cube/tsainbur/Projects/github_repos/umap_tf_networks/models/projections/fmnist/64/TSNE/z.npy
/mnt/cube/tsainbur/Projects/github_repos/umap_tf_networks/models/projections/cifar10/64/TSNE/z.npy
/mnt/cube/tsainbur/Projects/github_repos/umap_tf_networks/models/projections/macosko2015/64/TSNE/z.npy
/mnt/cube/tsainbur/Projects/github_repos/umap_tf_networks/models/projections/cassins_dtw/64/TSNE/z.npy


In [10]:
# Preview the first few (dataset, class_, dim) rows.
projection_df.head(3)

Unnamed: 0,dataset,class_,train_z,train_label,dim
0,mnist,umap-learn,"[[4.918474, 7.342497], [14.501698, 5.7053337],...","[5, 0, 4, 1, 9, 2, 1, 3, 1, 4, 3, 5, 3, 6, 1, ...",2
1,mnist,umap-learn,"[[5.931274, 4.8182616, 3.7269907, 9.074417, 2....","[5, 0, 4, 1, 9, 2, 1, 3, 1, 4, 3, 5, 3, 6, 1, ...",64
2,mnist,direct,"[[1.8560303, 2.3066766], [14.330162, 0.5864011...","[5, 0, 4, 1, 9, 2, 1, 3, 1, 4, 3, 5, 3, 6, 1, ...",2


In [17]:
def cluster_count_range(n_labels):
    """Return the cluster counts to sweep for a labelling with ``n_labels`` classes.

    Covers ``[n_labels - n_labels // 2, n_labels + n_labels // 2)`` (upper bound
    exclusive). The lower end is clamped to 2 because the silhouette coefficient is
    undefined for a single cluster.
    """
    half = n_labels // 2
    return np.arange(max(2, n_labels - half), n_labels + half)


def get_cluster_metrics(row, n_init=5):
    """Cluster one projection with KMeans over a sweep of cluster counts and score it.

    Results are cached as a pickle under ``DATA_DIR / 'clustering_metric_df'``. If the
    cache file for this (class_, dim, dataset) already exists, it is loaded and returned
    without recomputing.

    Parameters
    ----------
    row : row of ``projection_df``
        Must expose ``dataset``, ``class_``, ``dim``, ``train_z`` (the data passed to
        KMeans) and ``train_label`` (ground-truth labels).
    n_init : int
        Number of KMeans fits per cluster count; fit ``i`` uses ``random_state=i``.

    Returns
    -------
    pandas.DataFrame
        One row per (n_clusters, init_) with silhouette, homogeneity, completeness,
        v_measure and the fitted KMeans model.
    """
    columns = [
        "dataset",
        "class_",
        "dim",
        "silhouette",
        "homogeneity",
        "completeness",
        "v_measure",
        "init_",
        "n_clusters",
        "model",
    ]

    # load cluster information
    save_loc = DATA_DIR / 'clustering_metric_df' / ('_'.join([row.class_, str(row.dim), row.dataset]) + '.pickle')
    print(save_loc)
    if save_loc.is_file():
        # Cache written by a previous run of this function. Unpickling is only safe
        # because these files are produced locally by this notebook.
        return pd.read_pickle(save_loc)

    y = row.train_label
    z = row.train_z
    n_labels = len(np.unique(y))

    records = []
    for n_clusters in tqdm(cluster_count_range(n_labels), leave=False, desc='n_clusters'):
        for init_ in tqdm(range(n_init), leave=False, desc='init'):
            kmeans = KMeans(n_clusters=n_clusters, random_state=init_).fit(z)
            clustered_y = kmeans.labels_
            homogeneity, completeness, v_measure = homogeneity_completeness_v_measure(
                y, clustered_y
            )
            # The first returned element is stored as the silhouette score; the second is unused here.
            ss, _ = silhouette_score_block(z, clustered_y)
            records.append(
                {
                    "dataset": row.dataset,
                    "class_": row.class_,
                    "dim": row.dim,
                    "silhouette": ss,
                    "homogeneity": homogeneity,
                    "completeness": completeness,
                    "v_measure": v_measure,
                    "init_": init_,
                    "n_clusters": n_clusters,
                    "model": kmeans,
                }
            )

    # Build the frame once rather than growing it row by row.
    cluster_df = pd.DataFrame(records, columns=columns)

    # Save the cluster df so an interrupted sweep over many projections can resume.
    ensure_dir(save_loc)
    cluster_df.to_pickle(save_loc)
    return cluster_df

In [19]:
# Compute (or load from cache) the clustering-metric DataFrame for every projection row.
projection_rows = tqdm(projection_df.iterrows(), total=len(projection_df), desc='metric')
metric_dfs = []
for _, projection_row in projection_rows:
    metric_dfs.append(get_cluster_metrics(projection_row, n_init=5))

HBox(children=(IntProgress(value=0, description='metric', max=65, style=ProgressStyle(description_width='initi…

/mnt/cube/tsainbur/Projects/github_repos/umap_tf_networks/data/clustering_metric_df/umap-learn_2_mnist.pickle
/mnt/cube/tsainbur/Projects/github_repos/umap_tf_networks/data/clustering_metric_df/umap-learn_64_mnist.pickle
/mnt/cube/tsainbur/Projects/github_repos/umap_tf_networks/data/clustering_metric_df/direct_2_mnist.pickle
/mnt/cube/tsainbur/Projects/github_repos/umap_tf_networks/data/clustering_metric_df/direct_64_mnist.pickle
/mnt/cube/tsainbur/Projects/github_repos/umap_tf_networks/data/clustering_metric_df/network_2_mnist.pickle
/mnt/cube/tsainbur/Projects/github_repos/umap_tf_networks/data/clustering_metric_df/network_64_mnist.pickle
/mnt/cube/tsainbur/Projects/github_repos/umap_tf_networks/data/clustering_metric_df/autoencoder_2_mnist.pickle
/mnt/cube/tsainbur/Projects/github_repos/umap_tf_networks/data/clustering_metric_df/autoencoder_64_mnist.pickle
/mnt/cube/tsainbur/Projects/github_repos/umap_tf_networks/data/clustering_metric_df/PCA_2_mnist.pickle
/mnt/cube/tsainbur/Projec

HBox(children=(IntProgress(value=0, description='n_clusters', max=20, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

/mnt/cube/tsainbur/Projects/github_repos/umap_tf_networks/data/clustering_metric_df/umap-learn_64_cassins_dtw.pickle


HBox(children=(IntProgress(value=0, description='n_clusters', max=20, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

/mnt/cube/tsainbur/Projects/github_repos/umap_tf_networks/data/clustering_metric_df/direct_2_cassins_dtw.pickle


HBox(children=(IntProgress(value=0, description='n_clusters', max=20, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

/mnt/cube/tsainbur/Projects/github_repos/umap_tf_networks/data/clustering_metric_df/direct_64_cassins_dtw.pickle


HBox(children=(IntProgress(value=0, description='n_clusters', max=20, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

/mnt/cube/tsainbur/Projects/github_repos/umap_tf_networks/data/clustering_metric_df/network_2_cassins_dtw.pickle


HBox(children=(IntProgress(value=0, description='n_clusters', max=20, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

/mnt/cube/tsainbur/Projects/github_repos/umap_tf_networks/data/clustering_metric_df/network_64_cassins_dtw.pickle


HBox(children=(IntProgress(value=0, description='n_clusters', max=20, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

/mnt/cube/tsainbur/Projects/github_repos/umap_tf_networks/data/clustering_metric_df/autoencoder_2_cassins_dtw.pickle


HBox(children=(IntProgress(value=0, description='n_clusters', max=20, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

/mnt/cube/tsainbur/Projects/github_repos/umap_tf_networks/data/clustering_metric_df/autoencoder_64_cassins_dtw.pickle


HBox(children=(IntProgress(value=0, description='n_clusters', max=20, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

/mnt/cube/tsainbur/Projects/github_repos/umap_tf_networks/data/clustering_metric_df/PCA_2_cassins_dtw.pickle


HBox(children=(IntProgress(value=0, description='n_clusters', max=20, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

/mnt/cube/tsainbur/Projects/github_repos/umap_tf_networks/data/clustering_metric_df/PCA_64_cassins_dtw.pickle


HBox(children=(IntProgress(value=0, description='n_clusters', max=20, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

/mnt/cube/tsainbur/Projects/github_repos/umap_tf_networks/data/clustering_metric_df/TSNE_2_cassins_dtw.pickle


HBox(children=(IntProgress(value=0, description='n_clusters', max=20, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

/mnt/cube/tsainbur/Projects/github_repos/umap_tf_networks/data/clustering_metric_df/parametric-tsne_2_cassins_dtw.pickle


HBox(children=(IntProgress(value=0, description='n_clusters', max=20, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

/mnt/cube/tsainbur/Projects/github_repos/umap_tf_networks/data/clustering_metric_df/parametric-tsne_64_cassins_dtw.pickle


HBox(children=(IntProgress(value=0, description='n_clusters', max=20, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='init', max=5, style=ProgressStyle(description_width='initial'…

In [None]:
# Stack every frame in `metric_dfs` row-wise (axis=0) into one metrics table;
# the original row indices of each frame are kept as-is.
metric_df = pd.concat(metric_dfs, axis=0)

In [None]:
### inspect the clustering-metric table (used to choose the best-fit model by silhouette score for each dataset)

In [None]:
# rows 0-49 of the metrics table (positional slice on the fresh RangeIndex)
metric_df.reset_index().iloc[:50]

In [None]:
# rows 50-99 of the metrics table (positional slice on the fresh RangeIndex)
metric_df.reset_index().iloc[50:100]

In [None]:
# rows 100 onward of the metrics table (positional slice on the fresh RangeIndex)
metric_df.reset_index().iloc[100:]

In [None]:
# for each embedding dimension, plot V-measure against the number of clusters for every projection method

In [None]:
import seaborn as sns

In [None]:
# Six colours, one per entry of the catplot hue_order used below:
palette = (
    sns.color_palette('Reds', 3)[1:]      # 2 reds:   TSNE, parametric-tsne
    + sns.color_palette('Blues', 6)[3:]   # 3 blues:  umap-learn, network, autoencoder
    + sns.color_palette('Greens', 1)      # 1 green:  PCA
)
sns.palplot(palette)

In [None]:
# Preview only the first rows: the complete table is already displayed in
# 50-row chunks by the `reset_index()` slicing cells above, so dumping the
# whole frame again only bloats the notebook output.
metric_df.head()

In [None]:
# Bar plot of KMeans V-measure against the number of clusters: one bar colour per
# projection method (hue) and one panel per embedding dimension (col).
fg = sns.catplot(
    x="n_clusters",
    y="v_measure",
    hue="class_",
    hue_order=["TSNE", "parametric-tsne", "umap-learn", "network", "autoencoder", "PCA"],
    col="dim",
    height=2.5,
    aspect=1.75,
    data=metric_df,
    kind="bar",
    palette=palette,
)
# The y values are V-measure (see `y=` above), not silhouette score.
fg.despine(bottom=True).set_axis_labels("Number of clusters", "V-measure")

# Remove tick marks and title each panel. Like the original fixed indexing, this
# assumes the `dim` columns are ordered [2, 64] -- NOTE(review): confirm the dtype
# of `dim` in metric_df if more dimensions are added.
for ax, title in zip(fg.axes.flat, ["2 Dimensions", "64 Dimensions"]):
    ax.tick_params(axis="both", which="both", length=0)
    ax.set_title(title)

# Hide the legend title (public API instead of editing the private title Text).
fg._legend.set_title("")

# Replace raw class_ identifiers in the legend with display names; labels not
# listed here (e.g. "TSNE", "PCA") are left unchanged.
display_names = {
    "umap-learn": "UMAP-learn",
    "network": "P. UMAP",
    "autoencoder": "UMAP AE",
    "parametric-tsne": "P. TSNE",
}
for legend_text in fg._legend.texts:
    label = legend_text.get_text()
    legend_text.set_text(display_names.get(label, label))

# TODO: save_fig / FIGURE_DIR are not imported in the visible cells -- confirm before enabling.
#save_fig(FIGURE_DIR/'v-measure-test', save_pdf=True, dpi=300)