In [27]:
import os, pickle, h5py
import numpy as np, pandas as pd
from pathlib import Path
from tqdm.auto import tqdm
from scipy.sparse import coo_matrix
from scipy.stats import pearsonr, spearmanr

import torch
from torch import nn
from algo.Hcformer_pretrain import Hcformer, CNN_Extractor
from utils.data import str_to_seq_indices

# Load model

In [88]:
pretrain_para_path = Path('/work/magroup/hanzhan4/pretrain')
hcformer_path = Path('/work/magroup/hanzhan4/model/hcformer_pbulk/hic1d2d/d1xcmvsr')
hcformer_path = Path('/work/magroup/hanzhan4/model/hcformer_pbulk/hic2d/awasqs3z')

device = 'cuda'

In [3]:
# Define the CNN Extractor
# We need to first use this CNN Extractor to convert one hot DNA sequence 
cnn = CNN_Extractor().to(device)
cnn.stem.load_state_dict(torch.load(pretrain_para_path / 'stem.pt', map_location=torch.device('cpu')))
cnn.conv_tower.load_state_dict(torch.load(pretrain_para_path / 'conv_tower.pt', map_location=torch.device('cpu')))
cnn.eval()
pass

In [85]:
# Define the Hcformer
# Do not need to change the hyperparameter in the following model definition
model = Hcformer.from_hparams(
    dim = 768,
    seq_dim = 768,
    depth = 11,
    heads = 8,
    output_heads = dict(human=1),
    # target_length = 400,
    target_length = 240,
    dim_divisible_by = 128,
    hic_1d = False,
    hic_1d_feat_num = 5,
    hic_1d_feat_dim = 768,
    hic_2d = True,
).to(device)

In [89]:
state_dict = torch.load(hcformer_path, map_location=torch.device('cpu'))
model.load_state_dict({k.split('.', 1)[1] if k.startswith('module.') else k: v for k, v in state_dict.items()})

<All keys matched successfully>

# Toy example

In [52]:
# prepare the input sample
# seq = torch.randint(0, 5, (1, 409600))
# if your input sample is str type, you can use this function to convert your string
# we convert A->0, C->1, G->2, T->3, N->4
seq_str = ''.join(map(chr, np.random.choice(list(map(ord, 'ATGCN')), size=409600, replace=True)))
seq = str_to_seq_indices(seq_str)[None]

hic_1d = torch.rand(2, 400, 5)
hic_2d = torch.rand(2, 400, 400)

print(seq.shape, hic_1d.shape, hic_2d.shape)

torch.Size([1, 409600]) torch.Size([2, 400, 5]) torch.Size([2, 400, 400])


In [53]:
with torch.no_grad():
    seq, hic_1d, hic_2d = seq.to(device), hic_1d.to(device), hic_2d.to(device)
    seq_dense = cnn(seq)
    print(seq_dense.shape)
    pred = model(seq_dense, head='human', hic_1d=hic_1d, hic_2d=hic_2d)
    print(pred.shape)

torch.Size([1, 3200, 1536])
torch.Size([2, 240, 1])


# Evaluate at our dataset

In [79]:
path2dir = Path('/work/magroup/tianming/Researches/seqhic2expr/data/gage-seq-mBC')

feat_1d_list = ['ab', 'is-hw25', 'is-hw50', 'is-hw100', 'genebody']
hic_1d_all = []
for feat_name in tqdm(feat_1d_list):
    with open(path2dir / f'1d-score-celltypebulk-10kb-{feat_name}_1024_200_uint8.pkl', 'rb') as f:
        hic_1d_all.append(pickle.load(f).reshape(28, 3740, 400).astype(float) / 255)
hic_1d_all = np.stack(hic_1d_all, axis=-1)
print('Hi-C 1D', hic_1d_all.shape)

with open(path2dir / 'expression_cov_1024_200_celltypebulk.pkl', 'rb') as f:
    expression_all = pickle.load(f)
expression_all = np.concatenate([_.toarray() for _ in expression_all.ravel()], axis=0).reshape(28, 3740, 400)
expression_all = expression_all.reshape(28, 3740 * 400)
row_min = np.min(expression_all, axis=1, keepdims=True)
row_max = np.max(expression_all, axis=1, keepdims=True)
expression_all = (expression_all - row_min) / (row_max - row_min)
expression_all = np.log1p(expression_all * 1e4)
expression_all = expression_all.reshape(28, 3740, 400)
print('expression', expression_all.shape)

with open(path2dir / 'sequence_1024_200.tsv', 'r') as f:
    sequence_all = np.array(f.read().upper().strip().split())
print('sequence', sequence_all.shape, len(sequence_all[0]))

df_meta_gene = pd.read_csv(path2dir / 'genes.tsv', sep='\t')
print('Gene meta data', df_meta_gene.shape)

df_meta_cell = pd.read_csv(path2dir / 'cells.tsv', sep='\t')
print('Cell meta data', df_meta_cell.shape)

with open(path2dir / 'cell_types.tsv', 'r') as f:
    cell_type_list = np.array(f.read().strip().split('\n'))
print(cell_type_list)

  0%|          | 0/5 [00:00<?, ?it/s]

Hi-C 1D (28, 3740, 400, 5)
expression (28, 3740, 400)
sequence (3740,) 409600
Gene meta data (3740, 4)
Cell meta data (3105, 3)
['Astro' 'L2 IT RvPP' 'L2/3 IT CTX a' 'L2/3 IT CTX b' 'L2/3 IT CTX c'
 'L2/3 IT RSP' 'L4 IT CTX' 'L4/5 IT CTX' 'L5 IT CTX' 'L5 IT RSP'
 'L5 PT CTX' 'L5/6 NP CTX' 'L6 CT CTX a' 'L6 CT CTX b' 'L6 IT CTX'
 'L6b CTX' 'Lamp5' 'Meis2' 'Micro' 'ODC' 'OPC' 'Pvalb a' 'Pvalb b' 'Sncg'
 'Sst a' 'Sst b' 'VLMC' 'Vip']


In [76]:
with open('/work/magroup/hanzhan4/model_output/hcformer_pbulk/test_expression.pkl', 'rb') as f:
    tmp = pickle.load(f)
tmp.reshape(6, -1, 240).shape

torch.Size([6, 748, 240])

In [78]:
tmp.shape

torch.Size([4488, 240, 1])

In [72]:
!ls /work/magroup/hanzhan4/model_output/hcformer_pbulk

baseline  hic1d  hic1d2d  hic2d  test_expression.pkl  valid_expression.pkl


In [55]:
def load_contact_map_of_one_gene(gene_id, cell_type_list=cell_type_list):
    # path2file = path2dir / 'contact_1024_200_celltypebulk.h5'
    path2file = Path('/scratch/tmp-tianming') / 'contact_1024_200_celltypebulk.h5'
    hic_2d = []
    with h5py.File(path2file, 'r') as f:
        for cell_type in cell_type_list:
            g = f[cell_type.replace('/', '')][gene_id]
            m = coo_matrix((g['data'][()], (g['row'][()], g['col'][()])), shape=(400, 400)).toarray()
            # coo_mat = coo_matrix((data, (row, col)), shape=(400, 400))
            # coo_mat.sum_duplicates()
            # coo_mat.data = np.log1p(coo_mat.data)
            m = np.log1p((m + m.T) / (np.eye(len(m)) + 1))
            hic_2d.append(m)
    hic_2d = np.stack(hic_2d, axis=0)
    return hic_2d
hic_2d_example = load_contact_map_of_one_gene(df_meta_gene.gene_id.iloc[0])
print(hic_2d_example.shape)

(28, 400, 400)


In [56]:
seen_cell_type_idx = np.array([4, 10, 3, 13, 23, 7, 21, 17, 25, 2, 27, 6, 16, 22, 12, 15, 14, 1, 9, 18, 26, 0])
unseen_cell_type_idx = np.array([5, 19, 24, 11, 8, 20])
seen_gene_idx = np.arange(2992)
unseen_gene_idx = np.arange(2992, 3740)

In [None]:
prediction_all = np.empty([28, 3740, 240])
cell_type_indices = unseen_cell_type_idx
gene_indices = unseen_gene_idx
for gene_idx, gene_id in tqdm(
    zip(gene_indices, df_meta_gene.iloc[gene_indices].gene_id), total=len(gene_indices)):
    seq = str_to_seq_indices(sequence_all[gene_idx])[None].to(device)
    hic_1d = torch.tensor(hic_1d_all[cell_type_indices, gene_idx], device=device, dtype=torch.float32)
    hic_2d = load_contact_map_of_one_gene(gene_id, cell_type_list[cell_type_idx])
    hic_2d = torch.tensor(hic_2d, device=device, dtype=torch.float32)
    with torch.no_grad():
        seq_dense = cnn(seq)
        seq_dense = seq_dense.repeat(len(hic_1d), 1, 1)
        pred = model(seq_dense, head='human', hic_1d=hic_1d, hic_2d=hic_2d)
        prediction_all[cell_type_indices, gene_idx] = pred.cpu().numpy().squeeze(-1)
    # del seq, hic_1d, hic_2d, seq_dense, pred

In [21]:
1/0

ZeroDivisionError: division by zero

In [82]:
def get_matrix(x):
    x = x[cell_type_indices][:, gene_indices]
    # x = x[..., 80: -80]
    return x
pred = get_matrix(prediction_all)
truth = get_matrix(expression_all[..., 80:-80])
c = pearsonr(pred.ravel(), truth.ravel())
c

PearsonRResult(statistic=0.49327673616329243, pvalue=0.0)

In [84]:
pearsonr(get_matrix(prediction_all).ravel(), tmp.reshape(6, -1, 240).ravel())

PearsonRResult(statistic=0.4932767361815076, pvalue=0.0)