In [1]:
import os
original_path = os.getcwd()
root = '/rhome/yhu/bigdata/proj/experiment_GIST'
os.chdir(root)
new_path = os.getcwd()
print('redirect path: \n\t{} \n-->\t{}'.format(original_path, new_path))
print('root: {}'.format(root))

import warnings
warnings.filterwarnings('ignore')

from GIST.prepare.utils import load_hic, load_hic, iced_normalization
from GIST.visualize import display, load_data
from GIST.validation.utils import load_df_fish3d, fish3d_format, load_tad_bed
from GIST.validation.utils import pdist_3d, load_tad_bed, select_loci, remove_failed
from GIST.validation.validation_tad import select_structure3d
from GIST.validation.ab import normalizebydistance, decomposition, correlation, plot
from GIST.validation.ab import fit_genomic_spatial_func

import numpy as np
import torch
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
from scipy.spatial.distance import cdist, pdist, squareform
import plotly.express as px

redirect path: 
	/bigdata/wmalab/yhu/proj/notes/exp_GIST 
-->	/bigdata/wmalab/yhu/proj/experiment_GIST
root: /rhome/yhu/bigdata/proj/experiment_GIST


Using backend: pytorch


In [2]:
validation_name = 'comparison'
chromosomes = ['22' , '21', '20']
# chromosomes = ['22']
methods = ['ShRec3D', 'LorDG', 'pastis', 'GEM', 'ChromSDE'] 
# methods = []

gist_best = {'20':39, '21':36, '22':33} #  '22':33
selected_pc = dict()
selected_pc['GIST'] = dict()
selected_pc['GIST']['chr20'] = [[0,0], 1 ]
selected_pc['GIST']['chr21'] = [[0,0], -1]
selected_pc['GIST']['chr22'] = [[0,0], -1]

selected_pc['ShRec3D'] = dict()
selected_pc['ShRec3D']['chr20'] = [[0,0], 1]
selected_pc['ShRec3D']['chr21'] = [[0,0], -1]
selected_pc['ShRec3D']['chr22'] = [[0,0], -1]

selected_pc['LorDG'] = dict()
selected_pc['LorDG']['chr20'] = [[0,1], 1]
selected_pc['LorDG']['chr21'] = [[0,0], -1]
selected_pc['LorDG']['chr22'] = [[1,2], 1]

selected_pc['pastis'] = dict()
selected_pc['pastis']['chr20'] = [[0,1], 1]
selected_pc['pastis']['chr21'] = [[0,0], -1]
selected_pc['pastis']['chr22'] = [[2,0], -1]

selected_pc['GEM'] = dict()
selected_pc['GEM']['chr20'] = [[0,1], -1]
selected_pc['GEM']['chr21'] = [[0,2], 1]
selected_pc['GEM']['chr22'] = [[2,1], 1 ]

selected_pc['ChromSDE'] = dict()
selected_pc['ChromSDE']['chr20'] = [[2,1], -1 ]
selected_pc['ChromSDE']['chr21'] = [[1,2], -1]
selected_pc['ChromSDE']['chr22'] = [[1,1], 1 ]

results_name = dict()
results_name['ShRec3D'] = 'conformation.xyz'
results_name['pastis'] = 'PM2.structure'
results_name['GEM'] = 'conformation_'
results_name['LorDG'] = 'output/conformation.xyz'
results_name['ChromSDE'] = 'conformation.xyz'

In [3]:
ab_max_all = dict()
ab_max_tad = dict()
pc_loci = dict()
pc_tad = dict()

for chrom in chromosomes:
    # load config .json
    configuration_path = os.path.join(root, 'data')
    configuration_name = 'config_predict_{}.json'.format(chrom)
    info, config_data = load_data.load_configuration(configuration_path, configuration_name)
    resolution = info['resolution']

    # load dataset
    dataset_path = os.path.join(root, 'data', info['cell'], info['hyper'])
    dataset_name = 'dataset.pt'
    HD = load_data.load_dataset(dataset_path, dataset_name)
    graph, feat, ratio, indx  = HD[0]
    raw_id = graph['top_graph'].ndata['id'].cpu().numpy()
    rmNaN_id = np.arange(len(raw_id))
    raw2rm = {} # raw -> rm id
    rm2raw = {} # rm -> raw id
    for A, B in zip(raw_id, rmNaN_id):
        raw2rm[A] = B
        rm2raw[B] = A

    # load prediction
    prediction_path = os.path.join(root, 'data', info['cell'], info['hyper'], 'output')
    prediction_name = 'prediction.pkl'
    prediction = load_data.load_prediction(prediction_path, prediction_name)

    # assignment
    structures = dict()
    structures['GIST'] = prediction['{}_0'.format(chrom)]['structures']
    xweights = prediction['{}_0'.format(chrom)]['structures_weights'].astype(float).flatten()
    true_cluster = np.array(prediction['{}_0'.format(chrom)]['true_cluster'])
    predict_cluster = np.array(prediction['{}_0'.format(chrom)]['predict_cluster'][0])
    print( 'GIST structure shape: ', structures['GIST'].shape )

    # load .cool
    cool_file = info['cool_file']
    cool_data_path = os.path.join(root, 'data', 'raw')
    file = os.path.join(cool_data_path, cool_file)
    raw_hic, _, cooler = load_data.load_hic( name=file, chromosome='chr{}'.format(str(info['chromosome'])))
    # norm_hic = iced_normalization(raw_hic)

    mat = raw_hic # normalizebydistance(avg_fishTAD_3d, genomic_index=None)
    print('Hi-C shape with NAN: {}'.format(mat.shape) )
    mat = normalizebydistance(mat)
    mat = mat[raw_id,:]
    mat = mat[:, raw_id]
    np.fill_diagonal(mat, 0)
    print('Hi-C shape after rm NAN: {}'.format(mat.shape) )
    corr = correlation(mat, center=True)
    hic_pc = decomposition(corr, method='PCA', nc=3)

    for i, m in enumerate(methods):
        re = info['cool_file'].split('.')[1]
        mpath = os.path.join(root, 'comparison', m.lower(), info['cell'], re, chrom)
        if m.lower() != 'gem':
            file = os.path.join(mpath, results_name[m])
            if m.lower() != 'lordg':
                x = np.loadtxt(file, dtype=np.float).reshape(-1,1,3)
            elif m.lower() =='lordg':
                s = np.loadtxt(file, dtype=np.float).reshape(-1,1,4)
                x = s[:,:,0:3]
        else:
            x = list()
            for k in np.arange(1,5):
                file = os.path.join(mpath, results_name[m]+'{}.txt'.format(k))
                x.append(np.loadtxt(file, dtype=np.float).reshape(-1,1,3))
            x = np.concatenate(x, axis=1)
        structures[m] = x
    
    pc_loci[chrom] = dict()
    pc_loci[chrom]['hic'] = hic_pc
    pc_tad[chrom] = dict() 
    X_pc = dict()
    x_pdist = dict()
    
    for key, data3d in structures.items():
        K = data3d.shape[1]
        res = np.empty((raw_hic.shape[0], data3d.shape[1], 3))
        res.fill(np.nan)
        res[raw_id, :, :] = data3d
        x_3d = res.transpose( (1, 0, 2) )
        x_pdist[key] = pdist_3d(x_3d)
        if key == 'GIST':
            for k in np.arange(K):
                x_pdist[key][k,:,:] = xweights[k]*x_pdist[key][k,:,:]*len(xweights)

    X_corr = dict()
    for i, (key, pmat) in enumerate(x_pdist.items()):
        mat = np.nanmean(pmat, axis=0) # normalizebydistance(avg_fishTAD_3d, genomic_index=None)
        mat = normalizebydistance(mat)
        mat = mat[raw_id,:]
        mat = mat[:, raw_id]
        np.nan_to_num(mat, 0)
        rid = (mat.sum(axis=0) !=0 )
        mat = mat[rid,:]
        mat = mat[:,rid]
        corr = correlation(mat, center=False)
        pc = decomposition(corr, method='PCA', nc=3)
        X_pc[key] = np.zeros((len(raw_id),3))
        X_pc[key][rid,:] = pc
        X_corr[key] = corr
        pc_loci[chrom][key] = np.zeros((len(raw_id),3))
        pc_loci[chrom][key][rid,:] = pc
        if key == 'GIST':
            for j in np.arange(pmat.shape[0]):
                print( '{} GIST structure shape: '.format(j), structures['GIST'].shape )
                mat = pmat[j,:,:]
                mat = normalizebydistance(mat)
                mat = mat[raw_id,:]
                mat = mat[:, raw_id]
                np.nan_to_num(mat, 0)
                rid = (mat.sum(axis=0) !=0 )
                mat = mat[rid,:]
                mat = mat[:,rid]
                corr = correlation(mat, center=False)
                pc = decomposition(corr, method='PCA', nc=3)
                X_pc['GIST_{}'.format(j)] = np.zeros((len(raw_id),3))
                X_pc['GIST_{}'.format(j)][rid,:] = pc
                X_corr['GIST_{}'.format(j)] = corr
                pc_loci[chrom]['GIST_{}'.format(j)] = np.zeros((len(raw_id),3))
                pc_loci[chrom]['GIST_{}'.format(j)][rid,:] = pc

    ab_max_all[chrom] = dict()
    path = os.path.join(root, 'data', 'FISH', 'loci_position')
    name = 'hg19_Chr{}.bed'.format(chrom)
    df = load_tad_bed(path, name)
    select_idx = select_loci(df, resolution)
    for i, idx in enumerate(select_idx):
        arr = np.intersect1d(raw_id, idx)
        select_idx[i] = np.array([ raw2rm[x] for x in arr ] )

    for _, (key, x_pc) in enumerate(X_pc.items()):
        mat = np.ones((hic_pc.shape[1], hic_pc.shape[1]) )
        for i in np.arange(hic_pc.shape[1]):
            for j in np.arange(hic_pc.shape[1]):        
                fb = hic_pc[:,i] > 0
                xb = x_pc[:, j] > 0
                mat[i,j] = np.equal(fb, xb).sum()/len(xb)
        for _, (selected_key, _) in enumerate(selected_pc.items()):
            if selected_key in key:
                x,y = selected_pc[selected_key]['chr{}'.format(chrom)][0]
                ab_max_all[chrom][key] = max(1-mat[x,y], mat[x,y])
        print(ab_max_all[chrom][key], key)

    
    tad_idx = np.empty(0)
    center_tad_idx = np.ones(len(select_idx))
    for i, idx in enumerate(select_idx):
        center_tad_idx[i] = np.mean(idx).astype(int)
        tad_idx = np.append( tad_idx, idx)
    tad_idx = np.asarray(tad_idx, dtype=int)
    ab_max_tad[chrom] = dict()
    pc_tad[chrom]['hic'] = hic_pc[center_tad_idx.astype(int),:]
    for _, (key, x_pc) in enumerate(X_pc.items()):
        mat = np.ones((hic_pc.shape[1], hic_pc.shape[1]) )
        idx = center_tad_idx.astype(int)
        
        pc_tad[chrom][key] = x_pc[idx,:]
        for i in np.arange(hic_pc.shape[1]):
            for j in np.arange(hic_pc.shape[1]):        
                fb = hic_pc[idx,i] > 0
                xb = x_pc[idx, j] > 0
                mat[i,j] = np.equal(fb, xb).sum()/len(xb)
        for _, (selected_key, _) in enumerate(selected_pc.items()):
            if selected_key in key:
                x,y = selected_pc[selected_key]['chr{}'.format(chrom)][0]
                ab_max_tad[chrom][key] = max(1-mat[x,y], mat[x,y])
        print(ab_max_tad[chrom][key], key)

GIST structure shape:  (3161, 40, 3)
Hi-C shape with NAN: (5131, 5131)
Hi-C shape after rm NAN: (3161, 3161)
0 GIST structure shape:  (3161, 40, 3)
1 GIST structure shape:  (3161, 40, 3)
2 GIST structure shape:  (3161, 40, 3)
3 GIST structure shape:  (3161, 40, 3)
4 GIST structure shape:  (3161, 40, 3)
5 GIST structure shape:  (3161, 40, 3)
6 GIST structure shape:  (3161, 40, 3)
7 GIST structure shape:  (3161, 40, 3)
8 GIST structure shape:  (3161, 40, 3)
9 GIST structure shape:  (3161, 40, 3)
10 GIST structure shape:  (3161, 40, 3)
11 GIST structure shape:  (3161, 40, 3)
12 GIST structure shape:  (3161, 40, 3)
13 GIST structure shape:  (3161, 40, 3)
14 GIST structure shape:  (3161, 40, 3)
15 GIST structure shape:  (3161, 40, 3)
16 GIST structure shape:  (3161, 40, 3)
17 GIST structure shape:  (3161, 40, 3)
18 GIST structure shape:  (3161, 40, 3)
19 GIST structure shape:  (3161, 40, 3)
20 GIST structure shape:  (3161, 40, 3)
21 GIST structure shape:  (3161, 40, 3)
22 GIST structure sha

In [4]:
pd.DataFrame.from_dict(ab_max_all, orient='index')
pd.DataFrame.from_dict(ab_max_all, orient='index').style.highlight_max(color = 'yellow', axis = 1)

Unnamed: 0,GIST,GIST_0,GIST_1,GIST_2,GIST_3,GIST_4,GIST_5,GIST_6,GIST_7,GIST_8,GIST_9,GIST_10,GIST_11,GIST_12,GIST_13,GIST_14,GIST_15,GIST_16,GIST_17,GIST_18,GIST_19,GIST_20,GIST_21,GIST_22,GIST_23,GIST_24,GIST_25,GIST_26,GIST_27,GIST_28,GIST_29,GIST_30,GIST_31,GIST_32,GIST_33,GIST_34,GIST_35,GIST_36,GIST_37,GIST_38,GIST_39,ShRec3D,LorDG,pastis,GEM,ChromSDE
22,0.745966,0.605821,0.65802,0.645049,0.527365,0.590952,0.576716,0.751028,0.538121,0.751977,0.603923,0.613413,0.538437,0.703258,0.707687,0.551091,0.512496,0.637773,0.660867,0.612464,0.56596,0.829801,0.575451,0.796583,0.686492,0.617842,0.777602,0.574502,0.716229,0.591901,0.593799,0.7292,0.534957,0.589687,0.792787,0.714015,0.524518,0.529579,0.615628,0.683961,0.636507,0.730149,0.737425,0.529896,0.616577,0.573869
21,0.832402,0.80509,0.573557,0.884544,0.772502,0.612042,0.854128,0.735258,0.80509,0.778088,0.637803,0.742086,0.782744,0.780261,0.583489,0.698945,0.711359,0.662011,0.801055,0.707014,0.787089,0.677529,0.6527,0.62725,0.687461,0.679392,0.787089,0.7036,0.831471,0.718187,0.730602,0.743327,0.605214,0.753569,0.759466,0.860335,0.505897,0.88144,0.573246,0.761949,0.783054,0.786468,0.722843,0.604904,0.555556,0.646493
20,0.746535,0.722453,0.646743,0.671864,0.819127,0.505371,0.769058,0.635482,0.774428,0.681566,0.642758,0.811157,0.553881,0.700104,0.655405,0.720894,0.664934,0.673597,0.669612,0.599619,0.6114,0.735793,0.536729,0.800762,0.761088,0.657138,0.501906,0.521137,0.569993,0.761608,0.755024,0.519058,0.812543,0.557519,0.741857,0.512301,0.523216,0.758316,0.621275,0.758836,0.873008,0.795392,0.630804,0.515419,0.5149,0.537249


In [5]:
df.to_csv(index=True)

',Chromosome,Start,End\n0,20,60000,1492000\n1,20,2652000,3052000\n2,20,4132000,5172000\n3,20,6092000,7892000\n4,20,8732000,9452000\n5,20,10412000,11572000\n6,20,14052000,16772000\n7,20,17452000,18292000\n8,20,19652000,20052000\n9,20,20812000,22252000\n10,20,24052000,24932000\n11,20,31976339,32256339\n12,20,32896339,33136339\n13,20,33376339,34296586\n14,20,34536586,34776586\n15,20,35806586,36166586\n16,20,37206586,39606586\n17,20,40326586,42046586\n18,20,42846586,43526586\n19,20,45326593,45806593\n20,20,46286593,47846593\n21,20,48566593,49566593\n22,20,49966593,50206593\n23,20,50806593,52166593\n24,20,52846593,53286593\n25,20,54926593,55926593\n26,20,56766594,58126605\n27,20,60566605,60726605\n28,20,61449555,62129556\n'

In [6]:
pd.DataFrame.from_dict(ab_max_tad, orient='index')
pd.DataFrame.from_dict(ab_max_tad, orient='index').style.highlight_max(color = 'yellow', axis = 1)

Unnamed: 0,GIST,GIST_0,GIST_1,GIST_2,GIST_3,GIST_4,GIST_5,GIST_6,GIST_7,GIST_8,GIST_9,GIST_10,GIST_11,GIST_12,GIST_13,GIST_14,GIST_15,GIST_16,GIST_17,GIST_18,GIST_19,GIST_20,GIST_21,GIST_22,GIST_23,GIST_24,GIST_25,GIST_26,GIST_27,GIST_28,GIST_29,GIST_30,GIST_31,GIST_32,GIST_33,GIST_34,GIST_35,GIST_36,GIST_37,GIST_38,GIST_39,ShRec3D,LorDG,pastis,GEM,ChromSDE
22,0.769231,0.576923,0.692308,0.538462,0.538462,0.5,0.653846,0.884615,0.5,0.807692,0.538462,0.615385,0.5,0.576923,0.576923,0.5,0.5,0.5,0.692308,0.730769,0.538462,0.807692,0.5,0.846154,0.692308,0.538462,0.884615,0.538462,0.807692,0.615385,0.576923,0.692308,0.538462,0.576923,0.923077,0.769231,0.538462,0.576923,0.730769,0.653846,0.730769,0.730769,0.769231,0.653846,0.653846,0.653846
21,0.818182,0.787879,0.545455,0.909091,0.878788,0.606061,0.909091,0.727273,0.818182,0.787879,0.575758,0.69697,0.757576,0.818182,0.515152,0.666667,0.69697,0.545455,0.757576,0.666667,0.787879,0.69697,0.666667,0.515152,0.606061,0.666667,0.757576,0.757576,0.848485,0.69697,0.757576,0.69697,0.606061,0.727273,0.727273,0.818182,0.515152,0.939394,0.575758,0.727273,0.666667,0.757576,0.636364,0.606061,0.727273,0.69697
20,0.724138,0.62069,0.62069,0.689655,0.758621,0.586207,0.793103,0.62069,0.62069,0.62069,0.724138,0.758621,0.62069,0.724138,0.62069,0.62069,0.655172,0.586207,0.586207,0.655172,0.62069,0.724138,0.551724,0.793103,0.689655,0.655172,0.62069,0.551724,0.586207,0.758621,0.758621,0.586207,0.724138,0.689655,0.724138,0.551724,0.517241,0.793103,0.586207,0.758621,0.896552,0.827586,0.62069,0.62069,0.62069,0.724138


In [7]:
for key, value in pc_loci.items():
    print(pc_loci[key].keys(), key)

dict_keys(['hic', 'GIST', 'GIST_0', 'GIST_1', 'GIST_2', 'GIST_3', 'GIST_4', 'GIST_5', 'GIST_6', 'GIST_7', 'GIST_8', 'GIST_9', 'GIST_10', 'GIST_11', 'GIST_12', 'GIST_13', 'GIST_14', 'GIST_15', 'GIST_16', 'GIST_17', 'GIST_18', 'GIST_19', 'GIST_20', 'GIST_21', 'GIST_22', 'GIST_23', 'GIST_24', 'GIST_25', 'GIST_26', 'GIST_27', 'GIST_28', 'GIST_29', 'GIST_30', 'GIST_31', 'GIST_32', 'GIST_33', 'GIST_34', 'GIST_35', 'GIST_36', 'GIST_37', 'GIST_38', 'GIST_39', 'ShRec3D', 'LorDG', 'pastis', 'GEM', 'ChromSDE']) 22
dict_keys(['hic', 'GIST', 'GIST_0', 'GIST_1', 'GIST_2', 'GIST_3', 'GIST_4', 'GIST_5', 'GIST_6', 'GIST_7', 'GIST_8', 'GIST_9', 'GIST_10', 'GIST_11', 'GIST_12', 'GIST_13', 'GIST_14', 'GIST_15', 'GIST_16', 'GIST_17', 'GIST_18', 'GIST_19', 'GIST_20', 'GIST_21', 'GIST_22', 'GIST_23', 'GIST_24', 'GIST_25', 'GIST_26', 'GIST_27', 'GIST_28', 'GIST_29', 'GIST_30', 'GIST_31', 'GIST_32', 'GIST_33', 'GIST_34', 'GIST_35', 'GIST_36', 'GIST_37', 'GIST_38', 'GIST_39', 'ShRec3D', 'LorDG', 'pastis', 'GEM'

In [8]:
pd.DataFrame.from_dict(ab_max_tad, orient='index').to_csv(path_or_buf='/bigdata/wmalab/yhu/proj/notes/exp_GIST/results/ab_PC_best_tad.csv', index=True)
pd.DataFrame.from_dict(ab_max_all, orient='index').to_csv(path_or_buf='/bigdata/wmalab/yhu/proj/notes/exp_GIST/results/ab_PC_best_all.csv', index=True)

In [9]:
path='/bigdata/wmalab/yhu/proj/notes/exp_GIST/results/pc'
np.savez_compressed(path, loci=pc_loci, tad=pc_tad, allow_pickle=True)

In [10]:
# path='/bigdata/wmalab/yhu/proj/notes/exp_GIST/results/pc.npz'
# d = np.load(path, allow_pickle=True)
# d['loci'].item()['21']['hic']