In [None]:
import torch
import torch.nn.functional as F
import pickle
from utils.dataset import load_data_pbulk
from algo.Hcformer_pretrain import Hcformer # just change here for different algo version
from algo.module import pearson_corr_coef
from torch import nn
from tqdm import tqdm
from typing import List
from scipy.sparse import coo_matrix

In [None]:
# --- Data loading / runtime -------------------------------------------------
data_path = './data'      # forwarded as `path` to load_data_pbulk
seed = 0                  # forwarded as `seed` to load_data_pbulk
batch_size = 32
num_workers = 8
gpu = [0, 4]              # CUDA device ids; gpu[0] hosts the model and the inputs

# --- Model hyper-parameters (forwarded to Hcformer.from_hparams) -------------
dim = 768
depth = 11
heads = 8
output_heads = 1          # size of the 'human' output head
target_length = 240
hic_1d_feat_num = 5

# --- Optional Hi-C inputs: these flags also decide the layout of each batch --
add_hic_1d = False
add_hic_2d = True

In [None]:
# Resolve the torch device that hosts the model and the input batches.
# `gpu` is normally a list of CUDA ids whose first entry is the primary device.
# A bare integer is accepted as well; a negative integer means "use the CPU".
if isinstance(gpu, (list, tuple)):
    gpu_ids = list(gpu)
else:
    gpu_ids = [gpu] if gpu >= 0 else []

if gpu_ids and torch.cuda.is_available():
    device = torch.device(f"cuda:{gpu_ids[0]}")
    print(f"Device is {gpu_ids}")
else:
    # No GPU was requested (empty list / negative index) or CUDA is unavailable.
    device = torch.device("cpu")
    print(f"Device is {device}.")

In [None]:
def sparse_to_torch(coo_matrix: List[coo_matrix]):
    """Convert a batch of sparse square matrices into one dense, symmetric tensor.

    Each matrix is densified and added to its own transpose. The diagonal,
    which that sum counts twice, is then halved so it keeps its original values.

    Args:
        coo_matrix: sequence of square sparse matrices exposing ``toarray()``,
            all of the same size. Presumably each one stores a single triangle
            of a symmetric Hi-C contact map -- TODO confirm against the loader.

    Returns:
        Tensor of shape ``(len(coo_matrix), n, n)`` in torch's default floating
        dtype, where ``n`` is the side length of the input matrices.

    Raises:
        ValueError: if a matrix is not square.
        RuntimeError: from ``torch.stack`` if the sequence is empty or the
            matrices differ in size.
    """
    dense_matrix = []
    for m in coo_matrix:
        # Move to torch explicitly instead of mixing NumPy arrays with tensors.
        dense = torch.as_tensor(m.toarray(), dtype=torch.get_default_dtype())
        if dense.dim() != 2 or dense.shape[0] != dense.shape[1]:
            raise ValueError(
                f"expected a square matrix, got shape {tuple(dense.shape)}"
            )
        sym = dense + dense.T
        # `diagonal()` is a view, so the in-place division halves the diagonal
        # of `sym` itself; this works for any side length, not only 400.
        sym.diagonal().div_(2)
        dense_matrix.append(sym)
    return torch.stack(dense_matrix, dim=0)

def evaluation(model, data_loader, device):
    """Run ``model`` over ``data_loader`` and score its predictions.

    The batch layout is governed by the notebook-level flags ``add_hic_1d``
    and ``add_hic_2d``: every batch starts with ``(seq, exp)``, followed by the
    1D Hi-C features and/or the list of sparse 2D Hi-C matrices, in that
    order, for whichever flags are set.

    Args:
        model: network called as
            ``model(seq, head='human', hic_1d=..., hic_2d=...)``.
        data_loader: sized iterable of batches laid out as described above.
        device: device the model inputs are moved to.

    Returns:
        The first element of ``pearson_corr_coef(predictions, targets)``,
        evaluated on CPU tensors concatenated over all batches.
    """
    model.eval()
    total_pred, total_exp = [], []
    with torch.no_grad(), tqdm(total=len(data_loader), dynamic_ncols=True) as t:
        t.set_description('Evaluation: ')
        for item in data_loader:
            seq, exp = item[0].to(device), item[1]
            if add_hic_1d:
                hic_1d = item[2].to(device)
                hic_2d = sparse_to_torch(item[3]).to(device) if add_hic_2d else None
            elif add_hic_2d:
                hic_1d = None
                hic_2d = sparse_to_torch(item[2]).to(device)
            else:
                hic_1d, hic_2d = None, None
            pred = model(seq, head='human', hic_1d=hic_1d, hic_2d=hic_2d)

            total_pred.append(pred.detach().cpu())
            # Predictions are collected on the CPU, so the targets must live
            # there too; otherwise the final correlation mixes devices.
            total_exp.append(exp.detach().cpu().unsqueeze(-1))
            t.update()

    total_pred = torch.concat(total_pred, dim=0)
    total_exp = torch.concat(total_exp, dim=0)
    return pearson_corr_coef(total_pred, total_exp)[0]

In [None]:
# Only the third value returned by load_data_pbulk is kept (bound to test_loader).
loader_kwargs = dict(
    path=data_path,
    seed=seed,
    batch_size=batch_size,
    num_workers=num_workers,
    target_len=target_length,
    hic_1d=add_hic_1d,
    hic_2d=add_hic_2d,
)
_, _, test_loader = load_data_pbulk(**loader_kwargs)

# Hyper-parameters for the network; they have to match the checkpoint loaded later.
# NOTE(review): `dim / 12` is float division (64.0 for dim=768) -- confirm that
# Hcformer accepts a float for `dim_divisible_by`.
model_hparams = dict(
    dim=dim,
    seq_dim=dim,
    depth=depth,
    heads=heads,
    output_heads={'human': output_heads},
    target_length=target_length,
    dim_divisible_by=dim / 12,
    hic_1d=add_hic_1d,
    hic_1d_feat_num=hic_1d_feat_num,
    hic_1d_feat_dim=dim,
    hic_2d=add_hic_2d,
)
model = Hcformer.from_hparams(**model_hparams).to(device)

In [None]:
# Checkpoint to evaluate. Alternative run, kept for reference:
#   /home/han_harry_zhang/SeqHiC2RNA/output/hcformer_pbulk/model/hcformer_pbulk5/yakuks2o
checkpoint_path = '/home/han_harry_zhang/SeqHiC2RNA/output/hcformer_pbulk/model/hcformer_pbulk2/ai9o5a06'

# `map_location` remaps the stored tensors onto `device`, so loading does not
# require the CUDA device the checkpoint was saved from to exist on this machine.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
state_dict = torch.load(checkpoint_path, map_location=device)

# Load into the bare module even if this cell already ran and wrapped the model,
# which keeps the cell safe to re-execute.
bare_model = model.module if isinstance(model, nn.DataParallel) else model
bare_model.load_state_dict(state_dict)

# Wrap for multi-GPU inference only after the weights are in place, and only once.
if len(gpu) > 1 and not isinstance(model, nn.DataParallel):
    model = nn.DataParallel(model, device_ids=gpu)

In [None]:
# Inference over the test set. Unlike `evaluation`, the raw predictions and
# targets stay in notebook scope (total_pred / total_exp) for later analysis.
model.eval()
total_pred, total_exp = [], []
with torch.no_grad():
    with tqdm(total=len(test_loader), dynamic_ncols=True) as t:
        t.set_description('Evaluation: ')
        for item in test_loader:
            # Every batch starts with (seq, exp); the Hi-C inputs follow in the
            # order 1D, 2D for whichever of the two flags are enabled.
            seq, exp = item[0].to(device), item[1]
            if add_hic_1d:
                hic_1d = item[2].to(device)
                hic_2d = sparse_to_torch(item[3]).to(device) if add_hic_2d else None
            elif add_hic_2d:
                hic_1d = None
                hic_2d = sparse_to_torch(item[2]).to(device)
            else:
                hic_1d = hic_2d = None

            pred = model(seq, head='human', hic_1d=hic_1d, hic_2d=hic_2d)

            # Collect on the CPU; targets get a trailing axis of size 1.
            total_pred.append(pred.detach().cpu())
            total_exp.append(exp.unsqueeze(-1))
            t.update()
        total_pred = torch.cat(total_pred, dim=0)
        total_exp = torch.cat(total_exp, dim=0)
mean_test_pearson_corr_coef = pearson_corr_coef(total_pred, total_exp)[0]

In [None]:
# Per-sample Pearson correlation along dim 1: centre predictions and targets on
# their dim-1 mean, then take the cosine similarity of the centred values.
pred_center = total_pred.sub(total_pred.mean(dim=1, keepdim=True))
exp_center = total_exp.sub(total_exp.mean(dim=1, keepdim=True))
test_pearson_list = F.cosine_similarity(pred_center, exp_center, dim=1)

In [None]:
# Lay the per-sample correlations out as a 2D table with 748 columns. Going by
# the statistic names below, rows are cell types and columns are sequences
# (748 is presumably the number of test sequences per cell type -- TODO confirm).
num_sequences = 748
reshaped_test = test_pearson_list.reshape(-1, num_sequences)

cell_type_averages = torch.mean(reshaped_test, dim=1)  # one value per row
cell_type_std = torch.std(reshaped_test, dim=1)
sequence_averages = torch.mean(reshaped_test, dim=0)   # one value per column
sequence_std = torch.std(reshaped_test, dim=0)
print(cell_type_averages)
print(cell_type_std)
print(sequence_averages)
print(sequence_std)

# Each statistic is pickled under its own name, so a file can never receive
# another statistic's values (sequence_std.pkl used to get sequence_averages).
# For the hic_1d run, add a '_hic_1d' suffix to the file stem instead,
# e.g. 'try/cell_type_averages_hic_1d.pkl'.
summary_stats = {
    'cell_type_averages': cell_type_averages,
    'cell_type_std': cell_type_std,
    'sequence_averages': sequence_averages,
    'sequence_std': sequence_std,
}
for stat_name, stat_value in summary_stats.items():
    with open(f'try/{stat_name}.pkl', 'wb') as f:
        pickle.dump(stat_value, f)


In [None]:
# Same statistics as above, but ignoring correlations that are exactly 0.
# NOTE(review): an exact 0 presumably marks samples whose centred prediction or
# target is all zeros -- confirm that this is what the filter is meant to drop.
# A row/column left with fewer than two values yields NaN from torch.std.

# Per row (cell type).
row_values = [row[row != 0] for row in reshaped_test]
cell_type_averages = [torch.mean(values) for values in row_values]
cell_type_std = [torch.std(values) for values in row_values]
print(cell_type_averages)
print(cell_type_std)

# Per column (sequence); iterating the transpose walks the columns in order.
column_values = [column[column != 0] for column in reshaped_test.T]
sequence_averages = [torch.mean(values) for values in column_values]
sequence_std = [torch.std(values) for values in column_values]
print(sequence_averages)
print(sequence_std)

# Each statistic is pickled under its own name, so a file can never receive
# another statistic's values (sequence_std_no0.pkl used to get sequence_averages).
# For the hic_1d run, use a '_hic_1d_no0' suffix instead,
# e.g. 'try/cell_type_averages_hic_1d_no0.pkl'.
nonzero_stats = {
    'cell_type_averages': cell_type_averages,
    'cell_type_std': cell_type_std,
    'sequence_averages': sequence_averages,
    'sequence_std': sequence_std,
}
for stat_name, stat_value in nonzero_stats.items():
    with open(f'try/{stat_name}_no0.pkl', 'wb') as f:
        pickle.dump(stat_value, f)
    