# In [None]:
import pandas as pd
import os
from sklearn.datasets import make_blobs
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_samples, silhouette_score
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np

import concurrent.futures


def make_dir(d):
    '''Create directory *d*, including missing parents, unless it already exists.

    Uses ``exist_ok=True`` rather than a separate existence check. With a
    check-then-create pattern, two processes creating the same directory can
    both pass the check, and the second ``makedirs`` then fails. Raises
    ``FileExistsError`` if *d* exists but is not a directory.
    '''
    os.makedirs(d, exist_ok=True)
        
# Directory the per-cell-line chain tables ("<CELL>.<tad>.csv") are read from.
INTER_ALL_FORM_PATH = '../cluster/inter_clusting/'

# Output directories: labelled cluster tables go in the first,
# per-cluster frequency tables in the second. Both are created up front.
INTER_ALL_FORM_OUTPATH = '../output/inter_cluster_res/'
SUBCLUSTER_OUTPATH = '../output/inter_cluster_res/sub_cluster/'
make_dir(INTER_ALL_FORM_OUTPATH)
make_dir(SUBCLUSTER_OUTPATH)

# TADs to process, each encoded as "<chrom>.<start>.<end>".
tads = [
    'chr18.10530000.11160000',
    'chr16.85845000.86580000',
    'chr2.230805000.231690000',
]

# Cluster counts (k) tried for every TAD.
number_of_cluster_list = [4, 5, 6, 7]

def clustering_3cell(tad, resolution=5000):
    '''Jointly K-means cluster the chains of three cell lines for one TAD.

    Reads ``GM.<tad>.csv``, ``H.<tad>.csv`` and ``IMR.<tad>.csv`` from
    ``INTER_ALL_FORM_PATH`` and concatenates them column-wise. Every column
    becomes one sample row for clustering. Presumably each column is one
    chain and each CSV row one bead-pair value; confirm against whatever
    produces these files.

    For every ``k`` in ``number_of_cluster_list`` it:

    * writes the sample table plus its cluster ``label`` column to
      ``INTER_ALL_FORM_OUTPATH + '<tad>_<k>.csv'``;
    * for each cluster, writes the per-feature mean over that cluster's
      samples, paired with upper-triangle bead indices ``(i, j)`` with
      ``i <= j``, to ``SUBCLUSTER_OUTPATH + '<tad>_<k>_<label>_fq.csv'``.

    Args:
        tad: TAD identifier ``"<chrom>.<start>.<end>"``, where start and end
            are integer coordinates.
        resolution: Bead size used to turn the TAD span into a bead count.
            Defaults to 5000, the previously hard-coded value.
    '''
    import warnings

    parts = tad.split('.')
    start_posi, end_posi = int(parts[1]), int(parts[2])
    beads_number = int((end_posi - start_posi) / resolution)

    # File order GM, H, IMR fixes the row order of the saved cluster table.
    chain_frames = [
        pd.read_csv(INTER_ALL_FORM_PATH + cell + '.' + tad + '.csv')
        for cell in ('GM', 'H', 'IMR')
    ]
    print("data loaded!")

    # Transpose: one row per original CSV column (sample), one column per
    # CSV row (feature).
    samples_df = pd.concat(chain_frames, axis=1).T
    samples = samples_df.to_numpy(dtype=float)

    # Upper-triangle (i <= j) bead pairs in row-major order. This is the same
    # as the nested i/j loop, but computed once instead of per cluster.
    fq_row, fq_col = np.triu_indices(beads_number)
    n_features = samples.shape[1]
    if n_features != len(fq_row):
        # Previously zip() truncated silently here, misaligning or dropping
        # heatmap entries without any signal.
        warnings.warn(
            f'{tad}: {n_features} features but {len(fq_row)} upper-triangle '
            f'bead pairs for {beads_number} beads; heatmap output is '
            f'truncated to the shorter length',
            stacklevel=2,
        )
    n_out = min(n_features, len(fq_row))

    print('clustering ...')
    for number_of_cluster in number_of_cluster_list:
        # NOTE: n_init is left at scikit-learn's default, which changed from
        # 10 to 'auto' in scikit-learn 1.4, so labels can differ by version.
        labels = KMeans(n_clusters=number_of_cluster,
                        random_state=0).fit(samples).labels_

        # Columns: 'index' (original column name), features..., 'label'.
        all_chains_cluster = samples_df.assign(label=labels).reset_index()
        print(all_chains_cluster.head())
        all_chains_cluster.to_csv(
            INTER_ALL_FORM_OUTPATH + tad + '_' + str(number_of_cluster) + '.csv',
            index=False,
        )
        print('cluster result Saved!')

        for cluster_label in range(number_of_cluster):
            print(cluster_label)
            # Mean of every feature over the samples in this cluster.
            fq = samples[labels == cluster_label].mean(axis=0)
            fq_heatmap_df = pd.DataFrame({
                'i': fq_row[:n_out],
                'j': fq_col[:n_out],
                'fq': fq[:n_out],
            })
            fq_heatmap_df.to_csv(
                SUBCLUSTER_OUTPATH + tad + '_' + str(number_of_cluster)
                + '_' + str(cluster_label) + '_fq.csv',
                index=False,
            )
            print('heatmap saved')
            
if __name__ == '__main__':
    # Process pools using the "spawn" start method (default on Windows and
    # macOS) re-import the main module in each worker. This guard keeps the
    # workers from starting pools of their own.
    # NOTE(review): under "spawn", workers must also be able to import
    # clustering_3cell. A function defined only in a notebook cell may not be
    # importable there; confirm on the target platform.
    with concurrent.futures.ProcessPoolExecutor() as executor:
        # Executor.map re-raises a worker's exception only when that result
        # is iterated. Draining the iterator makes a failing TAD raise here
        # instead of being silently discarded.
        for _ in executor.map(clustering_3cell, tads):
            pass
