In [None]:
import numpy as np
import pandas as pd
import os
import pickle
from scipy import stats
import base_functions as bf
import cluster_base_functions as cbf

from sklearn.cluster import KMeans
from sklearn.impute import SimpleImputer
from sklearn import metrics

import warnings
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
# plt.rcParams["font.family"] = "Arial"
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')
%matplotlib inline

In [None]:
# Load the pickled dictionary that holds the three feature sets and the
# subject metadata used by the cells below.
# NOTE: unpickling can execute arbitrary code -- only load files you trust.
# The context manager guarantees the file handle is closed after reading.
with open('./median_data/combat_cluster_multi_modality_raw_data.pkl', 'rb') as data_file:
    D = pickle.load(data_file)

# Feature sets, one per modality (presumably subjects x features -- TODO confirm).
struc_feat = D['struc_feat']
conn_feat = D['conn_feat']
act_feat = D['act_feat']

# Subject metadata and the subject list the feature rows are selected against.
basic_DF = D['basic_DF']
subjects = D['subjects']

In [None]:
## for each feature modality, determine the optimal cluster number

care_dx = 1 # 1: only care patients; 0: only care HC; -1: use all participants
dx_labels = bf.get_subject_info(basic_DF, subjects, ['diagnosis_group'])

# Boolean row mask of the participants to cluster. care_dx == -1 is documented
# as "use all participants", so it must keep every row instead of matching
# rows whose label happens to equal -1.
if care_dx == -1:
    keep_mask = np.ones(len(dx_labels), dtype=bool)
else:
    keep_mask = np.asarray(dx_labels) == care_dx

cluster_vals_list = [struc_feat, act_feat, conn_feat]
cluster_vals_list = [c[keep_mask] for c in cluster_vals_list]
cluster_subjects = [subjects[i] for i, keep in enumerate(keep_mask) if keep]

# Loop-invariant settings (names follow the order of cluster_vals_list).
feat_names = ['anat', 'act', 'conn']
clust_range = np.arange(2, 10)
linewidth = 3
xylab_size = 15
xytick_size = 12

# Make sure the output folder exists before any figure is written.
os.makedirs('./results', exist_ok=True)

for ifeat in range(len(cluster_vals_list)):
    # Standardise every feature column; NaNs produced by the z-scoring
    # (e.g. constant columns, whose std is 0) are replaced with 0.
    feat = cluster_vals_list[ifeat].copy()
    feat = stats.zscore(feat, axis=0)
    feat[np.isnan(feat)] = 0

    # Column 0: silhouette score, column 1: Calinski-Harabasz score.
    # NOTE(review): n_init is left at the installed scikit-learn default,
    # which differs between releases -- consider pinning it.
    CHI_scores = np.zeros((len(clust_range), 2))
    for iclust, clust_num in enumerate(clust_range):
        cluster = KMeans(n_clusters=clust_num, random_state=1).fit(feat)
        CHI_scores[iclust,0] = metrics.silhouette_score(feat, cluster.labels_)
        CHI_scores[iclust,1] = metrics.calinski_harabasz_score(feat, cluster.labels_)

    # Left axis: silhouette score versus number of clusters.
    fig, ax1 = plt.subplots(figsize=(6,6))
    color = 'tab:red'
    ax1.set_xlabel('Cluster number', fontsize=xylab_size)
    ax1.set_ylabel('Silhouette score', color=color, fontsize=xylab_size)
    ax1.plot(clust_range, CHI_scores[:,0], color=color, linewidth=linewidth, marker='*', markersize=15)
    ax1.tick_params(axis='y', labelcolor=color)
    # The third modality (index 2) gets a coarser tick step.
    ystep = 0.1 if ifeat == 2 else 0.05
    # Round the axis maximum up to the next multiple of the tick step.
    ymax = (CHI_scores[:,0].max()//ystep + 1)*ystep
    ax1.set_ylim([0,ymax])
    ax1.set_yticks(np.arange(0, ymax+0.01, ystep))
    ax1.tick_params(axis='both', which='major', labelsize=xytick_size)

    # Right axis (shares the x-axis): Calinski-Harabasz score.
    ax2 = ax1.twinx()

    color = 'tab:blue'
    ax2.set_ylabel('Calinski-Harabasz score', color=color, fontsize=xylab_size)
    ax2.plot(clust_range, CHI_scores[:,1], color=color, linewidth=linewidth, linestyle='--', marker='o', markersize=10)
    ax2.tick_params(axis='y', labelcolor=color)
    ystep = 50 if ifeat == 2 else 10
    ymax = (CHI_scores[:,1].max()//ystep + 1)*ystep
    ax2.set_ylim([0,ymax])
    ax2.set_yticks(np.arange(0, ymax+0.01, ystep))
    ax2.tick_params(axis='both', which='major', labelsize=xytick_size)

    # Figures are intentionally left open so the inline backend renders
    # them at the end of the cell.
    fig.savefig(f'./results/clusters_Silhouette_scores_{feat_names[ifeat]}.pdf', bbox_inches='tight')

In [None]:
# Fit the final KMeans solution for every modality and save the labels.
clust_num = 2

# Output file per cohort selection. Resolving it up front avoids the
# undefined (or stale, from a previous run) `svfile` that occurred whenever
# care_dx was neither 0 nor 1.
svfile_by_dx = {
    0: './results/cluster_multi_modality_2clusters_binary_control.pkl',
    1: './results/cluster_multi_modality_2clusters_binary_patient.pkl',
    -1: './results/cluster_multi_modality_2clusters_binary_all.pkl',
}
if care_dx not in svfile_by_dx:
    raise ValueError(f'care_dx must be one of {sorted(svfile_by_dx)}, got {care_dx!r}')
svfile = svfile_by_dx[care_dx]

cluster_labels_list = []
for ifeat in range(len(cluster_vals_list)):
    # Same preprocessing as the cluster-number search: z-score each column
    # and replace the NaNs this produces with 0.
    feat = cluster_vals_list[ifeat].copy()
    feat = stats.zscore(feat, axis=0)
    feat[np.isnan(feat)] = 0
    cluster = KMeans(n_clusters=clust_num, random_state=1, max_iter=1000).fit(feat)
    if ifeat==0:
        # The first modality defines the reference labelling.
        ref_labels = cluster.labels_
        cluster_labels_list.append(ref_labels)
    else:
        # Cluster ids are arbitrary per fit; presumably this helper relabels
        # them to line up with the reference -- confirm in cluster_base_functions.
        alt_lab = cbf.reorder_clustering_labels(ref_labels, cluster.labels_)
        cluster_labels_list.append(alt_lab)

# Write the results; the context manager closes (and flushes) the file.
os.makedirs('./results', exist_ok=True)
with open(svfile, 'wb') as out_file:
    pickle.dump({'cluster_labels_list':cluster_labels_list, 'cluster_subjects':cluster_subjects, 'cluster_vals_list':cluster_vals_list},
                out_file)

In [None]:
# Adjusted Rand index between the label sets of every pair of modalities,
# printed in the order (0, 1), (0, 2), (1, 2).
for first, second in ((0, 1), (0, 2), (1, 2)):
    pair_ari = metrics.adjusted_rand_score(cluster_labels_list[first], cluster_labels_list[second])
    print(pair_ari)