In [None]:
import numpy as np
import pandas as pd
import os
import pickle
from scipy import stats
import base_functions as bf
import cluster_base_functions as cbf

from sklearn.cluster import KMeans
from sklearn.impute import SimpleImputer
from sklearn import metrics

import warnings
import matplotlib.pyplot as plt
import matplotlib
# Notebook-wide plotting / warning configuration.
# fonttype 42 asks matplotlib to embed TrueType (Type 42) fonts in PDF and
# PostScript output (see the matplotlib rcParams documentation).
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
# plt.rcParams["font.family"] = "Arial"
# Suppress every Python warning for the rest of the session. Both calls install
# an "ignore" filter; note that this also hides warnings raised by the
# numerical libraries used in later cells.
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')
# IPython magic: draw figures inline in the notebook output.
%matplotlib inline

In [None]:
# Load the pickled multi-modality features and keep only the diagnosis group
# selected by `care_dx`.
# NOTE(security): pickle.load can execute arbitrary code while unpickling;
# only load files that come from a trusted source.
with open('./median_data/combat_cluster_multi_modality_raw_data.pkl', 'rb') as pkl_file:
    D = pickle.load(pkl_file)
struc_feat = D['struc_feat']
conn_feat = D['conn_feat']
act_feat = D['act_feat']

basic_DF = D['basic_DF']
subjects = D['subjects']

# Standardise every feature column (axis=0) across all loaded subjects,
# i.e. before the diagnosis filter below is applied.
struc_feat = stats.zscore(struc_feat, axis=0)
act_feat = stats.zscore(act_feat, axis=0)
conn_feat = stats.zscore(conn_feat, axis=0)

care_dx = 1 # 1: only care patients; 0: only care HC; -1: use all participants
dx_labels = bf.get_subject_info(basic_DF, subjects, ['diagnosis_group'])
if care_dx == -1:
    # "use all participants": no label equals -1, so comparing against care_dx
    # would select nobody; keep every row instead.
    keep_mask = np.ones(struc_feat.shape[0], dtype=bool)
else:
    # assumes dx_labels is a 1-D sequence aligned with `subjects` — TODO confirm
    keep_mask = np.asarray(dx_labels) == care_dx

# Order matters downstream: index 0 = structure, 1 = activation, 2 = connectivity.
cluster_vals_list = [struc_feat, act_feat, conn_feat]
cluster_vals_list = [c[keep_mask] for c in cluster_vals_list]
cluster_subjects = [s for s, keep in zip(subjects, keep_mask) if keep]

In [None]:
from sklearn.model_selection import StratifiedShuffleSplit, ShuffleSplit

## resampling-based clustering
# For every repeat, draw a random 80% subsample of the selected subjects and
# run KMeans separately on each modality's features for that subsample.
repeat_num = 100
clust_num = 2
subj_num = cluster_vals_list[0].shape[0]
type_num = len(cluster_vals_list)
# all_cluster_labels[modality, subject, repeat]; -1 marks a subject that was
# not part of that repeat's subsample (KMeans labels are always >= 0).
all_cluster_labels = -1*np.ones((type_num, subj_num, repeat_num))

# Seed the splitter so that the subsamples (and therefore the stability
# statistics computed from them) are identical on every run of the notebook.
cvg = ShuffleSplit(n_splits=repeat_num, test_size=0.2, train_size=0.8, random_state=1)

for irep, (tr_idx, ts_idx) in enumerate(cvg.split(cluster_vals_list[0])):
    print(irep)
    for ifeat in range(type_num):
        # Re-standardise within the subsample; zscore returns a new array, so
        # the stored features are not modified by the NaN replacement below.
        feat = stats.zscore(cluster_vals_list[ifeat][tr_idx], axis=0)
        # Columns with zero variance in the subsample z-score to NaN; set them to 0.
        feat[np.isnan(feat)] = 0

        cluster = KMeans(n_clusters=clust_num, random_state=1, max_iter=1000).fit(feat)
        all_cluster_labels[ifeat,tr_idx,irep] = cluster.labels_

# NOTE(review): the subject axis of all_cluster_labels follows the filtered
# selection (cluster_subjects), whereas the block below would store the
# unfiltered `subjects`; svfile is also undefined when care_dx == -1.
# Confirm both before re-enabling.
# if care_dx==0:
#     svfile = './results/cluster_multi_modality_2clusters_CV_control.pkl'
# elif care_dx==1:
#     svfile = './results/cluster_multi_modality_2clusters_CV_patient.pkl'
# pickle.dump({'all_cluster_labels':all_cluster_labels, 
#              'subjects':subjects}, 
#              open(svfile, 'wb'))

In [None]:
## For each modality, estimate ARI and AMI between resampling-based subtypes
# Every pair of repeats (i < j) is compared on the subjects that were clustered
# in both repeats. Result arrays have shape (type_num, number_of_repeat_pairs).
all_aris = []
all_amis = []
for itp in range(type_num):
    print(itp)
    labels = all_cluster_labels[itp]   # (subj_num, repeat_num)
    clustered = labels >= 0            # False where the subject was left out (-1)
    aris = []
    amis = []
    for i in range(repeat_num-1):
        for j in range(i+1,repeat_num):
            # subjects present in both subsamples
            idx = clustered[:, i] & clustered[:, j]
            v = metrics.adjusted_rand_score(labels[idx, i], labels[idx, j])
            t = metrics.adjusted_mutual_info_score(labels[idx, i], labels[idx, j])
            aris.append(v)
            amis.append(t)
    all_aris.append(aris)
    all_amis.append(amis)
all_aris = np.stack(all_aris, axis=0)
all_amis = np.stack(all_amis, axis=0)
# Print both summaries explicitly: a bare expression in the middle of a cell is
# not displayed, so only the last one would otherwise be shown.
print('mean ARI per modality:', all_aris.mean(axis=1))
print('mean AMI per modality:', all_amis.mean(axis=1))

In [None]:
## plot figures for ARI and AMI of each modality
# One figure per modality: box plots of the pairwise ARI and AMI values.
import seaborn as sns
# One RGB colour per modality, rescaled from 0-255 to the 0-1 range matplotlib expects.
colors = [[87, 99, 111], [146, 129, 135], [127, 159, 174]]
colors = [[i/255  for i in c] for c in colors]
feat_names = ['anat', 'act', 'conn']
for ifeat in range(type_num):
    # Long-format frame: one row per (metric, value) observation.
    vals = np.concatenate((all_aris[ifeat],all_amis[ifeat]))
    metric = len(all_aris[ifeat])*['ARI'] + len(all_amis[ifeat])*['AMI']
    mdf = pd.DataFrame({'metric':metric, 'vals':vals})
    fig, ax = plt.subplots(figsize=(6,6))
    # Draw on the axes created above rather than relying on the implicit current axes.
    sns.boxplot(data=mdf, x='metric', y='vals', orient='v', width=0.5, color=colors[ifeat], ax=ax)
    # Label the figure so the three modalities can be told apart.
    ax.set_title(feat_names[ifeat])
    ax.set_ylim([0,1.01])
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    # ax.tick_params(axis='both', which='major', labelsize=10)
    # fig.savefig(f'./results/single_modality_dist_ari_{feat_names[ifeat]}.pdf', bbox_inches="tight")

In [None]:
## estimate ARI and AMI between different modality-based subtypes
# Rows of both result arrays follow this pair order (indices into the first
# axis of all_cluster_labels): (0,1), (0,2), (1,2); columns are repeats.
modality_pairs = [(0, 1), (0, 2), (1, 2)]
btw_modality_ari = np.zeros((3,repeat_num))
btw_modality_ami = np.zeros((3,repeat_num))
for rep in range(repeat_num):
    for row, (mod_a, mod_b) in enumerate(modality_pairs):
        # subjects that carry a valid (>= 0) label in both modalities for this repeat
        both = np.all(all_cluster_labels[[mod_a,mod_b],:,rep]>=0, axis=0)
        labels_a = all_cluster_labels[mod_a,both,rep]
        labels_b = all_cluster_labels[mod_b,both,rep]

        btw_modality_ari[row,rep] = metrics.adjusted_rand_score(labels_a, labels_b)
        btw_modality_ami[row,rep] = metrics.adjusted_mutual_info_score(labels_a, labels_b)

In [None]:
# plot figures
# Box plots of between-modality agreement, grouped by metric and coloured by modality pair.
# Row-major flattening keeps each pair's repeats contiguous: pair 0, then pair 1, then pair 2.
vals0 = btw_modality_ari.ravel()
vals1 = btw_modality_ami.ravel()
vals = np.concatenate((vals0,vals1), axis=0)

# Category labels in the same order as the flattened values, once per metric.
pair_names = ['Anat&Act', 'Anat&Conn', 'Act&Conn']
btw_cat = [name for name in pair_names for _ in range(repeat_num)] * 2
n_per_metric = len(vals0)
metric = ['ARI'] * n_per_metric + ['AMI'] * n_per_metric
mdf = pd.DataFrame({'metric':metric, 'vals':vals, 'btw_cat':btw_cat})

fig, ax = plt.subplots(figsize=(6,6))
sns.boxplot(data=mdf, x='metric', y='vals', hue='btw_cat', orient='v', width=0.5, palette='deep', ax=ax)
ax.set_ylim([-0.025,0.05])
ax.set_yticks(np.arange(-0.025, 0.051, 0.025));
# dashed reference line at zero agreement
ax.plot([-.5,1.5], [0,0], linestyle='--', linewidth=1, color='gray')
for side in ('top', 'right'):
    ax.spines[side].set_visible(False)
# fig.savefig(f'./results/between_modality_dist_ari.pdf', bbox_inches="tight")