In [None]:
import os
import sys
import pathlib
import numpy as np
import pandas as pd
from numpy.random import multivariate_normal
from scipy.stats import mode
from scipy.spatial.distance import cdist
import sklearn
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import silhouette_score, calinski_harabasz_score, davies_bouldin_score
from MulticoreTSNE import MulticoreTSNE as TSNE
# from sklearn.manifold import TSNE
from sklearn.cluster import DBSCAN, KMeans
import pandas as pd
import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from lifelines import (KaplanMeierFitter, WeibullFitter, ExponentialFitter,
                       LogNormalFitter, LogLogisticFitter, PiecewiseExponentialFitter,
                       GeneralizedGammaFitter, SplineFitter)

# Make TensorFlow log less verbose
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.initializers import VarianceScaling
from tensorflow.keras.losses import kl_divergence

import umap
import hdbscan
from mdutils.mdutils import MdUtils

from flwr.common.typing import Parameters

FLP_path = pathlib.Path('/home/relogu/Desktop/OneDrive/UNIBO/Magistrale/Federated Learning Project').absolute()
sys.path.insert(1, str(FLP_path))
from py.losses import get_keras_loss
from py.dataset_util import get_euromds_dataset, get_euromds_ids, get_outcome_euromds_dataset, fillcolumn_prob
from py.dec.util import create_autoencoder, create_clustering_model, target_distribution
from py.util import compute_centroid_np, get_dims_from_weights, return_not_binary_indices
import py.metrics as my_metrics

# Folder the saved weights, centroids and predictions are read from; every
# figure produced by this notebook is written there as well.
results_path = pathlib.Path(
    '/home/relogu/Desktop/OneDrive/UNIBO/Magistrale/Federated Learning Project/output_dec_euromds'
)
# N_CLUSTERS: number of clusters requested from k-means.
# LOSS: loss name resolved through `get_keras_loss` when compiling the autoencoder.
N_CLUSTERS, LOSS = 20, 'mse'

# Hide every GPU from TensorFlow for this kernel.
tf.config.set_visible_devices(devices=[], device_type='GPU')

In [None]:
def _load_npz_arrays(filename, squeeze=False):
    """Stack every array stored in ``results_path/filename`` into one ndarray.

    Args:
        filename: name of the ``.npz`` archive inside ``results_path``.
        squeeze: when True the stacked array is passed through ``np.squeeze``.

    Returns:
        ``np.array`` built from the archive entries, in archive order.
    """
    # NOTE: allow_pickle=True can execute code embedded in the archive, so
    # this must only be pointed at files produced by our own training runs.
    # The `with` block closes the underlying zip file once the arrays are read.
    with np.load(results_path/filename, allow_pickle=True) as archive:
        stacked = np.array([archive[key] for key in archive])
    return np.squeeze(stacked) if squeeze else stacked


# prefix = 'aggregated_weights_'
# Encoder weights saved under three different names (used further down as the
# pre-trained, fine-tuned and final encoders).
encoder_param = _load_npz_arrays('encoder.npz', squeeze=True)
encoder_ft_param = _load_npz_arrays('encoder_ft.npz', squeeze=True)
encoder_final_param = _load_npz_arrays('encoder_final.npz', squeeze=True)
# Cluster centres stored at the start and at the end of the clustering run.
initial_centroids = _load_npz_arrays('initial_centroids.npz')
final_centroids = _load_npz_arrays('final_centroids.npz')

In [None]:
# Print the stored centroids, then their element-wise displacement.
for caption, centroid_set in (('Initial Centroids: {}', initial_centroids),
                              ('Final Centroids: {}', final_centroids)):
    print(caption.format(centroid_set))
delta_centroids = final_centroids - initial_centroids
print('Delta Centroids: {}'.format(delta_centroids))

In [None]:
# FILL selects the `accept_nan` threshold handed to the dataset loader and the
# 2-D shape used later to draw a sample as an image.
# NOTE(review): presumably FILL=True keeps columns with missing values and
# fills them via `fillcolumn_prob` - confirm against `get_euromds_dataset`.
FILL = True
groups = ['CNA', 'Genetics']
ex_cols = ['UTX', 'CSF3R', 'SETBP1', 'PPM1D']
accept_nan, img_shape = (2044, (6, 9)) if FILL else (0, (5, 8))

# Feature table: remember column names and count before dropping to numpy.
x = get_euromds_dataset(groups=groups,
                        exclude_cols=ex_cols,
                        accept_nan=accept_nan,
                        fill_fn=fillcolumn_prob)
columns = x.columns
n_features = len(x.columns)

# getting labels from HDP: index of the largest entry of each row, or -1 when
# the row does not sum to a positive value
prob = get_euromds_dataset(groups=['HDP'])
y = np.array([row.argmax() if np.sum(row) > 0 else -1
              for _, row in prob.iterrows()])

# outcomes and sample identifiers, converted to plain arrays
outcomes = np.array(get_outcome_euromds_dataset()[['outcome_3', 'outcome_2']])
ids = np.array(get_euromds_ids())
x = np.array(x)

# labels predicted by the saved run, ordered by the CSV's unnamed index column
y_pred = np.array(
    pd.read_csv(results_path/'pred.csv').sort_values(by=['Unnamed: 0'])['label'])

In [None]:
# Fraction of samples with a non-zero value, one entry per feature column.
up_frequencies = np.count_nonzero(x, axis=0) / x.shape[0]
# Column indices reported as non-binary by the project helper.
nb_idx = return_not_binary_indices(x)
# Every column index from len(nb_idx) onwards.
# NOTE(review): this assumes the non-binary columns come first - confirm.
b_idx = list(range(len(nb_idx), x.shape[1]))

In [None]:
# setting up the autoencoder
# `dims` is derived from the saved encoder weights (presumably the layer
# sizes - confirm against `get_dims_from_weights`); its last entry is used
# further down as the number of embedding dimensions to plot.
dims = get_dims_from_weights(encoder_param)
config = dict(
    binary=False,
    tied=True,
    dims=dims,
    dropout=0.0,
    ran_flip=0.0,
    act='selu',
    ortho=False,
    u_norm=True,
    init=None,
    b_idx=b_idx,
    use_bias=True,
    # no metric is active; the candidates below are kept for reference
    ae_metrics=[
        # my_metrics.get_rounded_accuracy(idx=b_idx),
        # my_metrics.get_slice_accuracy(idx=nb_idx),
        # my_metrics.get_slice_hamming_loss(mode='multilabel', threshold=0.50, idx=b_idx),
        # my_metrics.get_slice_log_mse_loss(idx=nb_idx)
    ],
)
autoencoder, encoder, decoder = create_autoencoder(config, None)
# compiling the autoencoder (necessary for evaluating)
autoencoder.compile(loss=get_keras_loss(LOSS), metrics=config['ae_metrics'])

In [None]:
# `kmeans` is first fitted in a later cell, so on a fresh kernel this cell used
# to raise NameError; show the centres only when a fitted model already exists.
kmeans.cluster_centers_ if 'kmeans' in globals() else None

In [None]:
# Cluster the embedding of the fine-tuned encoder with k-means and use the
# per-cluster centroids to initialise the clustering layer.
encoder.set_weights(encoder_ft_param)
z = encoder(x).numpy()
kmeans = KMeans(
    n_clusters=N_CLUSTERS,
    n_init=20
    ).fit(z)
cluster_ids = np.unique(kmeans.labels_)
n_classes = len(cluster_ids)
# one centroid per k-means label, in ascending label order; the extra leading
# axis makes the array a one-element weight list for `set_weights`
centroids = np.array([[compute_centroid_np(z[kmeans.labels_ == i, :])
                       for i in cluster_ids]])
# The layer must have exactly as many clusters as centroids were computed.
# The previous hard-coded `n_clusters=8` only matched when N_CLUSTERS happened
# to be 8 and failed in `set_weights` otherwise.
clustering_model = create_clustering_model(
    n_clusters=n_classes,
    alpha=1,
    encoder=encoder)
clustering_model.get_layer(
    name='clustering').set_weights(centroids)
# compiling the clustering model (necessary for fitting/evaluating)
clustering_model.compile(
    optimizer=tf.keras.optimizers.SGD(
        learning_rate=0.001,
        momentum=0.9
        ),
    loss='kld')
clustering_model.summary()

In [None]:
# Soft assignments of the current model and the auxiliary target
# distribution p derived from them.
train_q = clustering_model(x).numpy()
train_p = target_distribution(train_q)
# One `fit` call (default number of epochs) against the target distribution.
clustering_model.fit(x=x, y=train_p, batch_size=100, verbose=2)
# Hard assignments after the update and the fraction of samples whose
# assignment differs from the k-means label.
q = clustering_model(x).numpy().argmax(axis=1)
tol = float(1 - np.count_nonzero(kmeans.labels_ == q)/len(x))
print(tol)

In [None]:
# Difference between the clustering-layer weights and the k-means centres:
# print its signed sum first, then the full array.
new_centroids = clustering_model.get_layer(name='clustering').get_weights()
diff = new_centroids - kmeans.cluster_centers_[np.newaxis]
print(diff.sum(), diff, sep='\n')

In [None]:
# Recompute the centroid displacement and display only its signed sum.
new_centroids = clustering_model.get_layer(name='clustering').get_weights()
diff = new_centroids - kmeans.cluster_centers_[np.newaxis]
diff.sum()

In [None]:
# Embed the data with the fine-tuned encoder and build the labelings that the
# plotting cells below iterate over.
encoder.set_weights(encoder_ft_param)
z = encoder(x).numpy()
tsne = TSNE(n_components=2, n_jobs=-1, random_state=51550).fit_transform(z)
# DBSCAN on the 2-D t-SNE projection and on the embedding itself
dbcl_tsne = DBSCAN(min_samples=40, eps=3).fit(tsne)
dbcl = DBSCAN(eps=35).fit(z)  # min_samples left at the library default
# NOTE: this overrides the notebook-level N_CLUSTERS for everything run afterwards
N_CLUSTERS = 8
kmeans = KMeans(n_clusters=N_CLUSTERS).fit(z)
print('Accuracies obtained (HDP as ground-truth, 6+1 labels)')
# (description, labels) pairs transposed into the two parallel lists; the
# trailing None entry stands for "no labels"
descs_list, labels_list = (list(column) for column in zip(
    ('dbscan_tsne', dbcl_tsne.labels_),
    ('dec', y_pred),
    ('dbscan', dbcl.labels_),
    ('kmeans', kmeans.labels_),
    ('hdp', y),
    ('nolabels', None),
))
# accuracy report, currently disabled:
# for labels, desc in zip(labels_list[:-1], descs_list[:-1]):
#     n_classes = len(np.unique(labels))
#     accuracy =  my_metrics.acc(y, labels)
#     s_accuracy = (n_classes/(n_classes-1))*accuracy-(1/(n_classes-1))
#     nmi_score = my_metrics.nmi(y, labels)
#     print('\t{}: acc {}, s_acc {}, nmi {}, {} labels'. \
#         format(desc, accuracy, s_accuracy, nmi_score, n_classes))

In [None]:
# ## some stats
# # initializing clustering model
# clustering_model = create_clustering_model(
#     n_clusters=13,
#     alpha=1,
#     encoder=encoder)
# # compiling the clustering model (necessary for evaluating)
# clustering_model.compile(
#     loss='kld')
# encoder.set_weights(encoder_final_param)
# clustering_model.get_layer(
#     name='clustering').set_weights(final_centroids[0])
# feat_test = encoder(x)
# # computations to see the effective retro-projecting capability of DEC
# loss = autoencoder.evaluate(x, x, verbose=0)
# x_ae_test = autoencoder(x)
# y_pred = clustering_model.predict(x, verbose=0).argmax(1)
# y_ae_pred = clustering_model.predict(np.round(x_ae_test), verbose=0).argmax(1)
# print('Cycle accuracy is {}\n{} loss is {}'.
#       format(my_metrics.acc(y_pred, y_ae_pred), LOSS, loss))
# confusion_matrix = sklearn.metrics.confusion_matrix(y_pred, y_ae_pred)
# plt.figure(figsize=(16, 14))
# sns.heatmap(confusion_matrix, annot=True, fmt="d", annot_kws={"size": 20})
# plt.title("Confusion matrix", fontsize=30)
# plt.ylabel('Predicted Label', fontsize=25)
# plt.xlabel('Cycle Predicted Label', fontsize=25)
# plt.savefig(results_path/'conf_matrix_cycle_accuracy.png')
# plt.close()

# print("{},{},{},{},{}\n{},{},{},{},{}\n{},{},{},{},{}\n{},{},{},{},{}\n"
#         "{},{},{},{},{}\n{},{},{},{},{}\n{},{},{},{},{}\n{},{},{},{},{}".format(*columns))

# for labels, desc in zip(labels_list[:-1], descs_list[:-1]):
#     print("Get properties of {} labels".format(desc))
#     centroids = None
#     if labels is not None:
#         centroids = []
#         n_clusters_found = len(np.unique(labels))
#         for i in np.unique(labels):
#             idx = (labels == i)
#             centroids.append(compute_centroid_np(z[idx, :]))
#     #centroids
#     x_ae_centroids = decoder.predict(np.array(centroids))
#     r_x_ae_centroids = np.round(x_ae_centroids)
#     drivers = pd.DataFrame(data=r_x_ae_centroids, columns=columns)
#     font = {'family': 'sans-serif',
#             'size': 15}
#     matplotlib.rc('font', **font)
#     #cov = np.array([1*np.std(centroids, axis=0)]*N_CLUSTERS)
#     #print('Centroid {}\nVariations{}'.format(centroids[0], multivariate_normal(centroids[0], cov, size=10)))
#     silhests, nearests, s_means, s_modes, means, modes = \
#         pd.DataFrame(), pd.DataFrame(), pd.DataFrame(
#         ), pd.DataFrame(), pd.DataFrame(), pd.DataFrame()
#     silh = sklearn.metrics.silhouette_samples(feat_test, labels)
#     for i, feat_centroid in enumerate(centroids):
#         if len(labels[labels == i]) > 0:
#             distances = [np.linalg.norm(feat_centroid-f_t)
#                          for f_t in feat_test]
#             cov = np.array([5*np.std(feat_test[labels == i], axis=0)]*dims[-1])
#             feat_samples = multivariate_normal(feat_centroid, cov, size=100)
#             x_samples = decoder.predict(feat_samples)
#             r_x_samples = np.round(x_samples)
#             x_i = x[labels == i]
#             nearests = nearests.append(pd.DataFrame(
#                 data=[x[np.argmin(distances)]], columns=columns))
#             silhests = silhests.append(pd.DataFrame(
#                 data=[x_i[np.argmax(silh[labels == i])]], columns=columns))
#             means = means.append(pd.DataFrame(
#                 data=[np.average(x_i, axis=0)], columns=columns))
#             modes = modes.append(pd.DataFrame(
#                 data=mode(x_i, axis=0)[0], columns=columns))

#             s_means = s_means.append(pd.DataFrame(
#                 data=[np.average(x_samples, axis=0)], columns=columns))
#             s_modes = s_modes.append(pd.DataFrame(
#                 data=mode(r_x_samples, axis=0)[0], columns=columns))

#     names = ['Mean', 'Mode', 'Nearest', 'Silh-est',
#                 'Centroid', 'S Mean', 'S Mode']
#     dfs = [means, modes, nearests, silhests,
#             drivers, s_means, s_modes]
#     for name, df in zip(names, dfs):
#         fig, ax = plt.subplots(
#             1, n_clusters_found, figsize=(2*n_clusters_found, 3))
#         for j, i in enumerate(np.unique(labels)):
#             print('j{},i{}'.format(j, i))
#             img = df.to_numpy()[j].reshape(img_shape)
#             ax[j].set_title('{} {}'.format(name, i))
#             ax[j].imshow(img.squeeze(), interpolation='none', cmap='gray')
#             ax[j].axis('off')
#             if name in set(names[:2]):
#                 ax[j].text(-0.5, 9.0,
#                             "{} samples".format(len(x[labels == i])))
#         plt.savefig(
#             results_path/'{}_{}_imgs.png'.format(name.lower().replace(' ', '_'), desc))
#         plt.close()
#     del drivers, silhests, nearests, s_means, s_modes, means, modes
#     del silh, x_ae_test, y_ae_pred, x_ae_centroids, r_x_ae_centroids


In [None]:
## Gradients plot
# For every labeling, initialise a clustering layer with the per-label
# centroids of the fine-tuned embedding and plot, one panel per cluster, the
# norm of the KL-divergence gradient w.r.t. each embedded point against the
# target probability of that cluster.
encoder.set_weights(encoder_ft_param)
z = encoder(x).numpy()
for labels, desc in zip(labels_list[:-1], descs_list[:-1]):
    print('Visualizing {} gradients'.format(desc))
    centroids = None
    # negative labels (e.g. the -1 used for unassigned samples) are left out
    cluster_labels = np.unique(labels[labels >= 0])
    n_clusters_found = len(cluster_labels)
    if n_clusters_found <= 1:
        # nothing to compare with fewer than two clusters
        continue
    # centroid k belongs to cluster_labels[k]
    centroids = [compute_centroid_np(z[labels == lab, :])
                 for lab in cluster_labels]
    # initializing clustering model
    clustering_model = create_clustering_model(
        n_clusters=n_clusters_found,
        alpha=1,
        encoder=encoder)
    # compiling the clustering model (necessary for evaluating)
    clustering_model.compile(loss='kld')
    cl_layer = clustering_model.get_layer(name='clustering')
    cl_layer.set_weights(np.array([centroids]))
    cl_layer.trainable = False
    # soft assignments, their arg-max column and the target distribution
    q = clustering_model(x).numpy()
    soft_labels = q.argmax(1)
    p = target_distribution(q)
    datapoints = tf.Variable(x, dtype=tf.float32)
    soft_assignments = tf.Variable(p)

    # Record only the forward pass; the gradient is taken after leaving the
    # tape context so that the backward computation is not recorded as well.
    with tf.GradientTape() as tape:
        embedded_points = encoder(datapoints)
        outputs = cl_layer(embedded_points)
        kld = kl_divergence(soft_assignments, outputs)
    dc_da = tape.gradient(kld, embedded_points)
    norm_dc_da = np.linalg.norm(dc_da, axis=1)

    fig, axs = plt.subplots(1, n_clusters_found,
                            figsize=(10*n_clusters_found, 10),
                            squeeze=True,
                            facecolor='white',
                            dpi=200)
    colors = cm.rainbow(np.linspace(0, 1, n_clusters_found))
    for k, ax in enumerate(axs):
        # `soft_labels` and the columns of `p` are indexed by cluster position
        # k (0..n-1), not by the original label value; the label value is only
        # used for the axis caption. Comparing against cluster_labels[k] was
        # wrong whenever the label ids were not exactly 0..n-1.
        idx = (soft_labels == k)
        ax.scatter(p[idx, k],
                   norm_dc_da[idx],
                   color=colors[k],
                   s=20)
        ax.set_xlabel(r'$p_{i%d}$' % cluster_labels[k])
        ax.set_ylabel(r'$|\frac{\partial L}{\partial z_i}|$')
        ax.grid()
        ax.set_facecolor('gray')
    fig.savefig(results_path/'initial_gradients_{}.png'.format(desc),
                facecolor=fig.get_facecolor(),
                edgecolor='none')
    plt.show()
    # release the (large) figure before drawing the next one
    plt.close(fig)


In [None]:
# Leftover inspection of the axes array created by the last figure drawn in
# the gradients loop; raises NameError if that loop never created a figure.
axs

In [None]:
# t-SNE view of the embedding produced with `encoder_final_param`, one figure
# per labeling, with the per-label centroids overlaid as '+' markers.
encoder.set_weights(encoder_final_param)
z = encoder(x).numpy()

for labels, desc in zip(labels_list, descs_list):
    print('Visualizing {} labels'.format(desc))
    a = z
    centroids = None
    if labels is not None:
        unique_labels = np.unique(labels)
        n_clusters_found = len(unique_labels)
        centroids = [compute_centroid_np(z[labels == i, :])
                     for i in unique_labels]
        # centroids are appended after the samples so that they are projected
        # together with them; rows len(x): of the projection are the centroids
        a = np.concatenate((z, centroids), axis=0)

    tsne = TSNE(
        n_components=2,
        n_jobs=-1,
        random_state=51550).fit_transform(a)
    fig, ax = plt.subplots(figsize=(16, 8))
    scatter = ax.scatter(tsne[:len(x), 0],
                         tsne[:len(x), 1],
                         c=labels,
                         s=[25]*len(x),
                         cmap='Spectral',
                         alpha=0.6)
    if labels is not None:
        ax.scatter(tsne[len(x):, 0],
                   tsne[len(x):, 1],
                   c=range(n_clusters_found),
                   marker='+',
                   s=[1000]*n_clusters_found,
                   cmap='Spectral')
        ax.set_facecolor('gray')
        ax.legend(*scatter.legend_elements(), title="Labels", framealpha=0.3)
    ax.grid()
    ax.set_xlabel(r'$\tilde{x}$')
    ax.set_ylabel(r'$\tilde{y}$')
    # Saved under the 'cluster_' prefix, like the other outputs computed from
    # the final encoder. The previous 'finetune_tsne_space_' name collided with
    # the fine-tuned-encoder cell, which overwrote these figures.
    fig.savefig(results_path/'cluster_tsne_space_{}.png'.format(desc))
    plt.show()
del tsne

In [None]:
# t-SNE view of the embedding produced with `encoder_ft_param`: one figure per
# labeling, samples coloured by label and per-label centroids drawn as '+'.
encoder.set_weights(encoder_ft_param)
z = encoder(x).numpy()

for labels, desc in zip(labels_list, descs_list):
    print('Visualizing {} labels'.format(desc))
    has_labels = labels is not None
    a, centroids = z, None
    if has_labels:
        n_clusters_found = len(np.unique(labels))
        centroids = [compute_centroid_np(z[labels == i, :])
                     for i in np.unique(labels)]
        # centroids go after the samples, so rows len(x): are the centroids
        a = np.concatenate((z, centroids), axis=0)

    tsne = TSNE(n_components=2, n_jobs=-1, random_state=51550).fit_transform(a)
    sample_xy, centroid_xy = tsne[:len(x)], tsne[len(x):]

    fig, ax = plt.subplots(figsize=(16, 8))
    scatter = ax.scatter(sample_xy[:, 0], sample_xy[:, 1],
                         c=labels, s=[25]*len(x), cmap='Spectral', alpha=0.6)
    if has_labels:
        scatter1 = ax.scatter(centroid_xy[:, 0], centroid_xy[:, 1],
                              c=range(n_clusters_found), marker='+',
                              s=[1000]*n_clusters_found, cmap='Spectral')
    ax.grid()
    if has_labels:
        ax.set_facecolor('gray')
        legend = ax.legend(*scatter.legend_elements(), title="Labels", framealpha=0.3)
    ax.set_xlabel(r'$\tilde{x}$')
    ax.set_ylabel(r'$\tilde{y}$')
    fig.savefig(results_path/'finetune_tsne_space_{}.png'.format(desc))
    plt.show()
    #plt.close()
del tsne#, a


In [None]:
# t-SNE view of the embedding produced with `encoder_param`, without labels.
encoder.set_weights(encoder_param)
a = encoder(x).numpy()
tsne = TSNE(n_components=2, n_jobs=-1, random_state=51550).fit_transform(a)

fig, ax = plt.subplots(figsize=(16, 8))
scatter = ax.scatter(tsne[:len(x), 0], tsne[:len(x), 1],
                     s=[25]*len(x), alpha=0.4)
ax.grid()
#ax.set_facecolor('gray')
ax.set_xlabel(r'$\tilde{x}$')
ax.set_ylabel(r'$\tilde{y}$')
fig.savefig(results_path/'pretrain_tsne_space.png')
#plt.close()
del tsne, a

In [None]:
# Pairwise scatter matrix of the embedding dimensions obtained with
# `encoder_final_param`; one figure per labeling, centroids drawn as '+'.
encoder.set_weights(encoder_final_param)
z = encoder(x).numpy()

for labels, desc in zip(labels_list, descs_list):
    print('Visualizing {} labels'.format(desc))
    a, centroids = z, None
    if labels is not None:
        n_clusters_found = len(np.unique(labels))
        centroids = [compute_centroid_np(z[labels == i, :])
                     for i in np.unique(labels)]
        # rows len(x): of `a` hold the centroids
        a = np.concatenate((z, centroids), axis=0)

    fig, axs = plt.subplots(dims[-1], dims[-1],
                            figsize=(8*dims[-1], 4*dims[-1]),
                            squeeze=False,
                            facecolor='white',
                            dpi=200)
    # panel (i, j) plots embedding dimension i against dimension j
    for (i, j), ax in np.ndenumerate(axs):
        scatter = ax.scatter(a[:len(x), i], a[:len(x), j],
                             c=labels, s=[25]*len(x),
                             cmap='Spectral', alpha=0.6)
        if centroids is not None:
            scatter1 = ax.scatter(a[len(x):, i], a[len(x):, j],
                                  c=range(n_clusters_found), marker='+',
                                  s=[1000]*n_clusters_found, cmap='Spectral')
        ax.grid()
        if labels is not None:
            ax.set_facecolor('gray')
            legend = ax.legend(*scatter.legend_elements(),
                               title="Labels",
                               framealpha=0.3,
                               prop={'size': 6})
        ax.set_xlabel(r'$\tilde{x}_{%d}$' % i)
        ax.set_ylabel(r'$\tilde{x}_{%d}$' % j)
        # ax.set_yscale('log')
        # ax.set_xscale('log')
    fig.savefig(results_path/'cluster_feature_space_{}.png'.format(desc),
                facecolor=fig.get_facecolor(),
                edgecolor='none')
    plt.close(fig)
del a

In [None]:
# Pairwise scatter matrix of the embedding dimensions obtained with
# `encoder_ft_param`; one figure per labeling, centroids drawn as '+'.
encoder.set_weights(encoder_ft_param)
z = encoder(x).numpy()

# dbcl1 = DBSCAN(
#     #min_samples=5,
#     eps=0.1,
#     ).fit(z)
# zip([dbcl1.labels_], ['dbscan1']):
n_dims = dims[-1]
for labels, desc in zip(labels_list, descs_list):
    #print('Visualizing {} labels'.format(desc))
    labelled = labels is not None
    a, centroids = z, None
    if labelled:
        n_clusters_found = len(np.unique(labels))
        centroids = [compute_centroid_np(z[labels == i, :])
                     for i in np.unique(labels)]
        # rows len(x): of `a` hold the centroids
        a = np.concatenate((z, centroids), axis=0)
    samples, extra_rows = a[:len(x)], a[len(x):]

    fig, axs = plt.subplots(n_dims, n_dims,
                            figsize=(8*n_dims, 4*n_dims),
                            squeeze=False,
                            facecolor='white',
                            dpi=200)
    for i, axes_row in enumerate(axs):
        for j, ax in enumerate(axes_row):
            # dimension i on the horizontal axis, dimension j on the vertical
            scatter = ax.scatter(samples[:, i], samples[:, j],
                                 c=labels, s=[25]*len(x),
                                 cmap='Spectral', alpha=0.6)
            if centroids is not None:
                scatter1 = ax.scatter(extra_rows[:, i], extra_rows[:, j],
                                      c=range(n_clusters_found), marker='+',
                                      s=[1000]*n_clusters_found, cmap='Spectral')
            ax.grid()
            if labelled:
                ax.set_facecolor('gray')
                legend = ax.legend(*scatter.legend_elements(),
                                   title="Labels",
                                   framealpha=0.3,
                                   prop={'size': 6})
            ax.set_xlabel(r'$\tilde{x}_{%d}$' % i)
            ax.set_ylabel(r'$\tilde{x}_{%d}$' % j)
            # ax.set_yscale('log')
            # ax.set_xscale('log')
    fig.savefig(results_path/'finetune_feature_space_{}.png'.format(desc),
                facecolor=fig.get_facecolor(),
                edgecolor='none')
    plt.close(fig)
del a

In [None]:
# Leftover inspection: shows whatever the most recently run loop over
# `labels_list` left in `labels` (the last entry of that list is None).
labels

In [None]:
# Pairwise scatter matrix of the embedding dimensions obtained with
# `encoder_param`, without labels.
encoder.set_weights(encoder_param)
a = encoder(x).numpy()
fig, axs = plt.subplots(dims[-1], dims[-1],
                        figsize=(8*dims[-1], 4*dims[-1]),
                        squeeze=False,
                        facecolor='white',
                        dpi=200)
# panel (i, j) plots embedding dimension i against dimension j
for (i, j), ax in np.ndenumerate(axs):
    scatter = ax.scatter(a[:len(x), i], a[:len(x), j],
                         s=[25]*len(x), alpha=0.6)
    ax.grid()
    ax.set_xlabel(r'$\tilde{x}_{%d}$' % i)
    ax.set_ylabel(r'$\tilde{x}_{%d}$' % j)
    # ax.set_yscale('log')
    # ax.set_xscale('log')
fig.savefig(results_path/'pretrain_feature_space.png',
            facecolor=fig.get_facecolor(),
            edgecolor='none')
plt.close(fig)
del a

In [None]:
# Survival curves per cluster: for every labeling and every fitter, fit one
# curve per label that has more than `n_min` samples and draw them together.
fitters = {'KaplanMeierFitter': KaplanMeierFitter(),
           'WeibullFitter': WeibullFitter(),
            #'ExponentialFitter': ExponentialFitter(),
            #'LogNormalFitter': LogNormalFitter(),
            #'LogLogisticFitter': LogLogisticFitter(),
            #'PiecewiseExponentialFitter': PiecewiseExponentialFitter([40, 60]),
            #'GeneralizedGammaFitter': GeneralizedGammaFitter()
            #'SplineFitter': SplineFitter(T.loc[E.astype(bool)], [0, 50, 100])
            }
n_min = 10
# first outcome column is passed as durations, second as event indicator
times, events = outcomes[:, 0], outcomes[:, 1]
for labels, desc in zip(labels_list[:-1], descs_list[:-1]):
    # loop on fitters
    for key, fitter in fitters.items():
        fig, axes = plt.subplots(1, 1, figsize=(15, 8))
        axes.set_title('{} {}'.format(desc, key))
        # loop on labels
        for label in np.unique(labels):
            idx = (labels == label)
            # skip labels with too few members
            if np.count_nonzero(idx) > n_min:
                fitter.fit(times[idx], events[idx], label='label {}'.format(label))
                fitter.plot_survival_function(ax=axes, ci_show=False)
        axes.grid()
        fig.savefig(results_path/'{}_{}_lifelines.png'.format(key, desc))

In [None]:
## Example from EUROMDS
# Draw ten randomly chosen samples as `img_shape` images on a 2x5 grid.
figure = plt.figure(figsize=(20, 4),
                    facecolor='white',)
# random order of the samples; `xx` is reused by the reconstruction cells
idx = np.random.permutation(len(x))
xx = x[idx, :]

for j, example in enumerate(xx[:10]):
    ax = figure.add_subplot(2, 5, j+1)
    ax.imshow(example.reshape(img_shape), cmap='gray', aspect='equal')
    ax.axis('off')
# Use the face colour of *this* figure; the previous code read it from `fig`,
# a stale figure left over from an earlier cell.
figure.savefig(results_path/'original_data.png',
               facecolor=figure.get_facecolor(),
               edgecolor='none')

In [None]:
## Example from EUROMDS from pretrained autoencoder
# Reconstruct the first ten shuffled samples after loading `encoder_param`
# into the encoder and draw them as `img_shape` images on a 2x5 grid.
# NOTE(review): only the encoder weights are replaced here; presumably the
# tied decoder follows them - confirm against `create_autoencoder`.
encoder.set_weights(encoder_param)
figure = plt.figure(figsize=(20, 4),
                    facecolor='white',)

for j, example in enumerate(autoencoder(xx[:10]).numpy()):
    ax = figure.add_subplot(2, 5, j+1)
    ax.imshow(example.reshape(img_shape), cmap='gray', aspect='equal')
    ax.axis('off')
# Use the face colour of *this* figure; the previous code read it from `fig`,
# a stale figure left over from an earlier cell.
figure.savefig(results_path/'pretrain_ae_data.png',
               facecolor=figure.get_facecolor(),
               edgecolor='none')

In [None]:
## Example from EUROMDS from finetuned autoencoder
# Reconstruct the first ten shuffled samples after loading `encoder_ft_param`
# into the encoder and draw them as `img_shape` images on a 2x5 grid.
# NOTE(review): only the encoder weights are replaced here; presumably the
# tied decoder follows them - confirm against `create_autoencoder`.
encoder.set_weights(encoder_ft_param)
figure = plt.figure(figsize=(20, 4),
                    facecolor='white',)

for j, example in enumerate(autoencoder(xx[:10]).numpy()):
    ax = figure.add_subplot(2, 5, j+1)
    ax.imshow(example.reshape(img_shape), cmap='gray', aspect='equal')
    ax.axis('off')
# Use the face colour of *this* figure; the previous code read it from `fig`,
# a stale figure left over from an earlier cell.
figure.savefig(results_path/'finetune_ae_data.png',
               facecolor=figure.get_facecolor(),
               edgecolor='none')

In [None]:
## Example from EUROMDS from final model
# Reconstruct the first ten shuffled samples after loading
# `encoder_final_param` into the encoder and draw them as `img_shape` images
# on a 2x5 grid.
# NOTE(review): only the encoder weights are replaced here; presumably the
# tied decoder follows them - confirm against `create_autoencoder`.
encoder.set_weights(encoder_final_param)
figure = plt.figure(figsize=(20, 4),
                    facecolor='white',)

for j, example in enumerate(autoencoder(xx[:10]).numpy()):
    ax = figure.add_subplot(2, 5, j+1)
    ax.imshow(example.reshape(img_shape), cmap='gray', aspect='equal')
    ax.axis('off')
# Use the face colour of *this* figure; the previous code read it from `fig`,
# a stale figure left over from an earlier cell.
figure.savefig(results_path/'cluster_ae_data.png',
               facecolor=figure.get_facecolor(),
               edgecolor='none')