In [1]:
import os
import pickle
from typing import List

import torch
import torch.nn.functional as F
from scipy.sparse import coo_matrix
from torch import nn
from tqdm import tqdm

from utils.dataset import load_data_pbulk
from algo.Hcformer_pretrain_1d import Hcformer # just change here for old algo version
from algo.module import pearson_corr_coef

In [2]:
# --- Data / loader settings ---
data_path = './data'
seed = 0
num_workers = 8
batch_size = 32

# --- Model architecture (should match the checkpoint that is loaded below) ---
target_length = 240
dim, depth, heads = 768, 11, 8
output_heads = 1
hic_1d_feat_num = 5

# --- Optional Hi-C inputs (both False: neither loader nor model receives Hi-C data) ---
add_hic_1d = False
add_hic_2d = False

# --- CUDA device ids: the first hosts the model, more than one enables nn.DataParallel ---
gpu = [3, 7]

In [3]:
# Pick the device the model lives on. With a non-empty `gpu` list the first id hosts
# the model (a later cell wraps it in nn.DataParallel over all ids when there are
# several). An empty list, or a machine without CUDA, falls back to the CPU instead of
# comparing the list against an int (which raised TypeError) or selecting a CUDA
# device that cannot be used.
if len(gpu) > 0 and torch.cuda.is_available():
    device = torch.device(f"cuda:{gpu[0]}")
    print(f"Device is {gpu}")
else:
    device = torch.device("cpu")
    print(f"Device is {device}.")

Device is [3, 7]


In [4]:
def sparse_to_torch(coo_matrix: List[coo_matrix]):
    """Densify a batch of sparse Hi-C matrices into one stacked torch tensor.

    Each matrix ``m`` is symmetrised as ``m + m.T``. That sum counts the diagonal
    twice, so the diagonal is halved afterwards (equivalent to the original
    ``/ (ones + eye)``). NOTE(review): this presumes each input stores only one
    triangle of a symmetric contact map — confirm against the dataset code.

    Args:
        coo_matrix: list of square scipy sparse matrices of identical shape. (The
            parameter name shadows the imported scipy class inside this function;
            it is kept so keyword callers keep working.)

    Returns:
        Tensor of shape ``(len(coo_matrix), n, n)``. Floating-point inputs keep
        their dtype; integer inputs are converted to ``torch.float32``.

    Raises:
        ValueError: if the list is empty or a matrix is not square.
    """
    dense_matrix = []
    for m in coo_matrix:
        # Convert to torch up front: torch.stack below only accepts tensors, not
        # the numpy arrays that .toarray() returns.
        dense = torch.as_tensor(m.toarray())
        if not dense.is_floating_point():
            dense = dense.to(torch.float32)
        if dense.dim() != 2 or dense.shape[0] != dense.shape[1]:
            raise ValueError(f"expected a square matrix, got shape {tuple(dense.shape)}")
        dense = dense + dense.T
        # The symmetrisation doubled the diagonal; halve it in place (works for any n,
        # not just the previously hard-coded 400).
        dense.diagonal().div_(2)
        dense_matrix.append(dense)
    if not dense_matrix:
        raise ValueError("sparse_to_torch() needs at least one matrix")
    return torch.stack(dense_matrix, dim=0)

def _unpack_batch(item, device, use_hic_1d, use_hic_2d):
    """Split one loader batch into ``(seq, exp, hic_1d, hic_2d)``.

    The batch layout is ``(seq, exp[, hic_1d][, hic_2d])`` depending on which Hi-C
    inputs are enabled. ``seq`` and ``hic_1d`` are moved to ``device``; the sparse
    ``hic_2d`` list is densified with ``sparse_to_torch`` first. ``exp`` is kept on
    the CPU because it is only compared against CPU copies of the predictions.
    Disabled Hi-C inputs are returned as ``None``.
    """
    seq, exp = item[0].to(device), item[1].cpu()
    hic_1d = hic_2d = None
    slot = 2
    if use_hic_1d:
        hic_1d = item[slot].to(device)
        slot += 1
    if use_hic_2d:
        hic_2d = sparse_to_torch(item[slot]).to(device)
    return seq, exp, hic_1d, hic_2d


def evaluation(model, data_loader, device, use_hic_1d=None, use_hic_2d=None):
    """Run ``model`` over ``data_loader`` and score predictions against targets.

    Args:
        model: network called as ``model(seq, head='human', hic_1d=..., hic_2d=...)``.
        data_loader: iterable of batches laid out as described in ``_unpack_batch``.
        device: device that model inputs are moved to.
        use_hic_1d, use_hic_2d: which Hi-C inputs the batches carry. ``None`` (the
            default) reads the module-level ``add_hic_1d`` / ``add_hic_2d`` flags at
            call time, matching the original behaviour.

    Returns:
        ``pearson_corr_coef(total_pred, total_exp)[0]`` over the concatenated CPU
        predictions and targets. NOTE(review): this is the first element of whatever
        ``pearson_corr_coef`` returns — confirm it is the intended summary statistic.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    if use_hic_1d is None:
        use_hic_1d = add_hic_1d
    if use_hic_2d is None:
        use_hic_2d = add_hic_2d

    model.eval()
    pred_batches, exp_batches = [], []
    with torch.no_grad():
        with tqdm(total=len(data_loader), dynamic_ncols=True) as t:
            t.set_description('Evaluation: ')
            for item in data_loader:
                seq, exp, hic_1d, hic_2d = _unpack_batch(item, device, use_hic_1d, use_hic_2d)
                pred = model(seq, head='human', hic_1d=hic_1d, hic_2d=hic_2d)
                # Predictions and targets are both kept on the CPU so the final
                # correlation never mixes devices.
                pred_batches.append(pred.detach().cpu())
                exp_batches.append(exp.unsqueeze(-1))
                t.update()

    if not pred_batches:
        raise ValueError("evaluation(): data_loader yielded no batches")
    total_pred = torch.concat(pred_batches, dim=0)
    total_exp = torch.concat(exp_batches, dim=0)
    return pearson_corr_coef(total_pred, total_exp)[0]

In [5]:
# Only the test loader is used here; the first two return values (presumably the
# train/validation loaders) are discarded.
_, _, test_loader = load_data_pbulk(
    path=data_path,
    seed=seed,
    batch_size=batch_size,
    num_workers=num_workers,
    target_len=target_length,
    hic_1d=add_hic_1d,
    hic_2d=add_hic_2d,
)

# Rebuild the architecture with the hyper-parameters of the checkpoint loaded in the
# next cell; load_state_dict is strict by default, so a mismatch is reported there.
# NOTE(review): `dim / 12` is a float (64.0) — confirm Hcformer expects that rather
# than the integer `dim // 12`.
model = Hcformer.from_hparams(
    dim=dim,
    seq_dim=dim,
    depth=depth,
    heads=heads,
    output_heads={'human': output_heads},
    target_length=target_length,
    dim_divisible_by=dim / 12,
    hic_1d=add_hic_1d,
    hic_1d_feat_num=hic_1d_feat_num,
    hic_1d_feat_dim=dim,
    hic_2d=add_hic_2d,
)
model = model.to(device)

In [6]:
# Checkpoint to evaluate. NOTE(review): absolute path on the training machine — consider
# making it relative to a configurable output directory.
CHECKPOINT_PATH = '/home/han_harry_zhang/SeqHiC2RNA/output/hcformer_pbulk/model/hcformer_pbulk5/yakuks2o'
# Earlier checkpoint, kept for reference:
# '/home/han_harry_zhang/SeqHiC2RNA/output/hcformer_pbulk/model/hcformer_pbulk2/ai9o5a06'

if len(gpu) > 1:
    model = nn.DataParallel(model, device_ids=gpu)

# nn.DataParallel keeps the wrapped network under `.module`, so state dicts saved from a
# wrapped model prefix every key with "module.". Align the prefix with the current
# wrapping so the same checkpoint loads whether one GPU or several are configured.
state_dict = torch.load(CHECKPOINT_PATH, map_location=device)
_dp_prefix = 'module.'
_is_wrapped = isinstance(model, nn.DataParallel)
_has_prefix = bool(state_dict) and all(k.startswith(_dp_prefix) for k in state_dict)
if _has_prefix and not _is_wrapped:
    state_dict = {k[len(_dp_prefix):]: v for k, v in state_dict.items()}
elif _is_wrapped and not _has_prefix:
    state_dict = {_dp_prefix + k: v for k, v in state_dict.items()}

model.load_state_dict(state_dict)

<All keys matched successfully>

In [7]:
# Evaluate the loaded checkpoint on the test split, keeping CPU copies of every
# prediction and target so the cells below can compute per-sample correlations.
model.eval()
pred_batches = []
exp_batches = []
with torch.no_grad(), tqdm(total=len(test_loader), dynamic_ncols=True) as t:
    t.set_description('Evaluation: ')
    for item in test_loader:
        # Batch layout: (seq, exp[, hic_1d][, hic_2d]) depending on the Hi-C flags.
        seq, exp = item[0].to(device), item[1]
        hic_1d, hic_2d = None, None
        slot = 2
        if add_hic_1d:
            hic_1d = item[slot].to(device)
            slot += 1
        if add_hic_2d:
            # hic_2d arrives as sparse matrices; densify before moving it to the device.
            hic_2d = sparse_to_torch(item[slot]).to(device)
        pred = model(seq, head='human', hic_1d=hic_1d, hic_2d=hic_2d)

        pred_batches.append(pred.detach().cpu())
        exp_batches.append(exp.unsqueeze(-1))
        t.update()

total_pred = torch.concat(pred_batches, dim=0)
total_exp = torch.concat(exp_batches, dim=0)
mean_test_pearson_corr_coef = pearson_corr_coef(total_pred, total_exp)[0]

Evaluation: : 100%|██████████| 141/141 [01:17<00:00,  1.81it/s]


In [8]:
# Per-sample Pearson r: the cosine similarity of mean-centred vectors, taken along dim 1.
# F.cosine_similarity clamps the norm product by eps, so a constant (zero-variance)
# vector centres to all zeros and yields 0 rather than nan.
pred_center = total_pred - torch.mean(total_pred, dim=1, keepdim=True)
exp_center = total_exp - torch.mean(total_exp, dim=1, keepdim=True)
test_pearson_list = F.cosine_similarity(pred_center, exp_center, dim=1)

In [9]:
# Arrange the per-sample correlations as (cell type, sequence): rows are cell types,
# columns are sequences. NOTE(review): 748 sequences per cell type is hard-coded for
# this test split (the run shown yields 6 rows) — confirm it matches the loader ordering.
NUM_SEQUENCES = 748
RESULT_DIR = 'try'
RESULT_SUFFIX = ''  # e.g. '_hic_1d' when evaluating the Hi-C 1D model

reshaped_test = test_pearson_list.reshape(-1, NUM_SEQUENCES)
cell_type_averages = torch.mean(reshaped_test, dim=1)
cell_type_std = torch.std(reshaped_test, dim=1)
sequence_averages = torch.mean(reshaped_test, dim=0)
sequence_std = torch.std(reshaped_test, dim=0)
print(cell_type_averages)
print(cell_type_std)
print(sequence_averages)
print(sequence_std)

# Save each statistic under its own name, so e.g. sequence_std.pkl really holds the std
# (it previously received sequence_averages).
os.makedirs(RESULT_DIR, exist_ok=True)
for stat_name, stat in {
    'cell_type_averages': cell_type_averages,
    'cell_type_std': cell_type_std,
    'sequence_averages': sequence_averages,
    'sequence_std': sequence_std,
}.items():
    with open(os.path.join(RESULT_DIR, f'{stat_name}{RESULT_SUFFIX}.pkl'), 'wb') as f:
        pickle.dump(stat, f)


tensor([0.3726, 0.2786, 0.3304, 0.3433, 0.4711, 0.3078])
tensor([0.2439, 0.2238, 0.2341, 0.2375, 0.2716, 0.2155])
tensor([ 2.7672e-01,  4.6281e-01,  1.2892e-01,  4.6896e-01,  3.7157e-01,
         5.0322e-01,  3.6859e-02,  2.5045e-03,  2.5428e-01,  3.9096e-01,
         4.8517e-01,  6.6358e-01,  5.0590e-01,  4.6288e-01,  3.9558e-01,
         1.9238e-01,  2.5985e-01,  5.4446e-01,  6.4720e-01,  4.7370e-01,
         5.2818e-03,  4.1119e-01,  3.6931e-01,  1.9615e-01,  3.3771e-01,
         8.9374e-02,  1.1863e-02,  3.9876e-01, -1.9272e-02,  7.5622e-02,
         5.5938e-01,  4.7661e-01,  2.2251e-01,  3.0453e-01,  5.4599e-01,
         5.2379e-01,  5.2165e-01,  1.3492e-01, -9.4183e-05,  4.3120e-01,
         5.5438e-01,  4.3887e-01,  5.3409e-01,  5.0853e-01,  5.2469e-01,
         4.6138e-01,  3.5571e-02,  2.2146e-01,  3.2400e-01,  8.6585e-03,
         5.4530e-01,  6.9646e-01,  1.4004e-01,  7.0693e-01,  3.8285e-01,
         4.4742e-01,  6.7590e-01,  7.6496e-01,  5.2466e-01,  4.7534e-01,
         5

In [10]:
# Same statistics as the previous cell, but ignoring correlations that are exactly 0.
# NOTE(review): F.cosine_similarity returns 0 when a centred vector is all zeros (e.g. a
# constant target), so zeros presumably mark undefined correlations rather than a true
# r = 0 — confirm before relying on this filter.
# A slice with no non-zero entries gives a nan mean; one with a single non-zero entry
# gives a nan std (torch.std is unbiased by default).
RESULT_DIR = 'try'
RESULT_SUFFIX = ''  # e.g. '_hic_1d' when evaluating the Hi-C 1D model


def _nonzero_mean_std(vectors):
    """Return ``(means, stds)`` lists over the non-zero entries of each 1-D tensor."""
    means, stds = [], []
    for vec in vectors:
        kept = vec[vec != 0]
        means.append(torch.mean(kept))
        stds.append(torch.std(kept))
    return means, stds


# Iterating a 2-D tensor yields its rows; `.T` turns the sequence columns into rows.
cell_type_averages, cell_type_std = _nonzero_mean_std(reshaped_test)
sequence_averages, sequence_std = _nonzero_mean_std(reshaped_test.T)
print(cell_type_averages)
print(cell_type_std)
print(sequence_averages)
print(sequence_std)

# Save each statistic under its own name (sequence_std_no0.pkl previously received
# sequence_averages).
os.makedirs(RESULT_DIR, exist_ok=True)
for stat_name, stat in {
    'cell_type_averages': cell_type_averages,
    'cell_type_std': cell_type_std,
    'sequence_averages': sequence_averages,
    'sequence_std': sequence_std,
}.items():
    with open(os.path.join(RESULT_DIR, f'{stat_name}{RESULT_SUFFIX}_no0.pkl'), 'wb') as f:
        pickle.dump(stat, f)
    

[tensor(0.3942), tensor(0.3007), tensor(0.3602), tensor(0.3722), tensor(0.4762), tensor(0.3284)]
[tensor(0.2332), tensor(0.2178), tensor(0.2214), tensor(0.2245), tensor(0.2686), tensor(0.2068)]
[tensor(0.2767), tensor(0.4628), tensor(0.1289), tensor(0.4690), tensor(0.3716), tensor(0.5032), tensor(0.0442), tensor(0.0150), tensor(0.2543), tensor(0.3910), tensor(0.4852), tensor(0.6636), tensor(0.5059), tensor(0.4629), tensor(0.3956), tensor(0.2309), tensor(0.2599), tensor(0.5445), tensor(0.6472), tensor(0.4737), tensor(0.0079), tensor(0.4112), tensor(0.3693), tensor(0.1962), tensor(0.3377), tensor(0.0894), tensor(0.0178), tensor(0.3988), tensor(-0.0289), tensor(0.1134), tensor(0.5594), tensor(0.4766), tensor(0.2225), tensor(0.3045), tensor(0.5460), tensor(0.5238), tensor(0.5216), tensor(0.2024), tensor(-0.0001), tensor(0.4312), tensor(0.5544), tensor(0.4389), tensor(0.5341), tensor(0.5085), tensor(0.5247), tensor(0.4614), tensor(0.0356), tensor(0.2215), tensor(0.3240), tensor(0.0087), ten

[tensor(0.1151), tensor(0.1623), tensor(0.1416), tensor(0.1455), tensor(0.1422), tensor(0.1386), tensor(0.0678), tensor(nan), tensor(0.1202), tensor(0.1540), tensor(0.1597), tensor(0.1547), tensor(0.1074), tensor(0.2515), tensor(0.0856), tensor(0.1601), tensor(0.0835), tensor(0.1026), tensor(0.0965), tensor(0.1661), tensor(0.0503), tensor(0.0948), tensor(0.0972), tensor(0.1349), tensor(0.2415), tensor(0.0556), tensor(0.0339), tensor(0.1638), tensor(0.0284), tensor(0.1654), tensor(0.0856), tensor(0.1051), tensor(0.1306), tensor(0.0732), tensor(0.1336), tensor(0.1384), tensor(0.1240), tensor(0.1893), tensor(0.0414), tensor(0.1470), tensor(0.1588), tensor(0.1933), tensor(0.0811), tensor(0.2134), tensor(0.1527), tensor(0.1785), tensor(0.1626), tensor(0.1310), tensor(0.0880), tensor(0.0973), tensor(0.1772), tensor(0.1136), tensor(0.0357), tensor(0.1526), tensor(0.1192), tensor(0.1332), tensor(0.1110), tensor(0.1151), tensor(0.1659), tensor(0.2005), tensor(0.1145), tensor(0.2240), tensor(0.1