In [1]:
from model.EPInformer import EPInformer_v2, enhancer_predictor_256bp
from scripts.utils import prepare_input

In [2]:
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch

In [None]:
# Download ABC element-gene data for K562
!wget https://www.encodeproject.org/files/ENCFF635RHY/@@download/ENCFF635RHY.bed.gz -O ./data/K562_enhancer_gene_links.bed.gz

In [None]:
# Donwload reference genome
!wget https://hgdownload.soe.ucsc.edu/goldenPath/hg38/bigZips/hg38.fa.gz  -P ./data/
!gunzip ./data/hg38.fa.gz

In [8]:
# Read the tab-separated ABC enhancer-gene link table for K562.
enhancer_gene_k562 = pd.read_csv('./data/K562_enhancer_gene_links.bed.gz', sep='\t')
# Keep only links whose 'distance' to the target gene's TSS lies in (1 kb, 100 kb]:
# the upper bound restricts the window, the lower bound removes the promoter element.
# reset_index() keeps the previous row labels as an 'index' column, which is written to the TSV too.
enhancer_gene_k562_100kb = (
    enhancer_gene_k562
    .loc[lambda links: (links['distance'] > 1000) & (links['distance'] <= 100_000)]
    .reset_index()
)
# Persist the filtered links so later runs can start from this smaller table.
enhancer_gene_k562_100kb.to_csv(
    './data/K562_enhancer_gene_links_100kb.tsv',
    sep='\t',
    index=False,
)

In [3]:
# Read back the filtered (<= 100 kb) K562 gene-enhancer links; read_table defaults to tab-separated input.
enhancer_gene_k562_100kb = pd.read_table('./data/K562_enhancer_gene_links_100kb.tsv')

In [33]:
# Show every candidate element linked to the HBB gene.
enhancer_gene_k562_100kb.loc[enhancer_gene_k562_100kb['TargetGene'].eq('HBB')]

Unnamed: 0,index,#chr,start,end,name,class,activity_base,normalized_h3K27ac,normalized_dhs,activity_base_squared,...,hic_contact_pl_scaled,hic_pseudocount,hic_contact_pl_scaled_adj,ABC.Score.Numerator,ABC.Score,powerlaw.Score.Numerator,powerlaw.Score,CellType,hic_contact_squared,Ensembl_ID
2467,60746,chr11,5128857,5130098,intergenic|chr11:5128857-5130098,intergenic,10.566709,14.776777,7.556136,111.655339,...,0.004363,0.0,0.004363,0.046099,0.015107,0.098347,0.001982,K562_ID_2644,1.9e-05,ENSG00000244734
2468,60747,chr11,5131702,5132324,promoter|chr11:5131702-5132324,promoter,7.500498,11.2345,5.007564,56.25747,...,0.003526,0.0,0.003526,0.026447,0.008667,0.071426,0.001439,K562_ID_2644,1.2e-05,ENSG00000244734
2469,60748,chr11,5132392,5132892,promoter|chr11:5132392-5132892,promoter,1.249283,3.207317,0.486608,1.560708,...,0.003526,0.0,0.003526,0.004405,0.001444,0.011966,0.000241,K562_ID_2644,1.2e-05,ENSG00000244734
2470,60749,chr11,5151586,5152619,promoter|chr11:5151586-5152619,promoter,4.699122,4.584471,4.81664,22.081748,...,0.001954,0.0,0.001954,0.009184,0.00301,0.055016,0.001109,K562_ID_2644,4e-06,ENSG00000244734
2471,60750,chr11,5153509,5154470,intergenic|chr11:5153509-5154470,intergenic,2.602571,2.862061,2.366608,6.773376,...,0.001954,0.0,0.001954,0.005086,0.001667,0.031153,0.000628,K562_ID_2644,4e-06,ENSG00000244734
2472,60751,chr11,5160051,5160551,intergenic|chr11:5160051-5160551,intergenic,1.819946,1.802978,1.837073,3.312203,...,0.005903,0.0,0.005903,0.010744,0.003521,0.023566,0.000475,K562_ID_2644,3.5e-05,ENSG00000244734
2473,60752,chr11,5171825,5172325,intergenic|chr11:5171825-5172325,intergenic,1.595402,0.92125,2.762886,2.545308,...,0.002591,0.0,0.002591,0.004134,0.001355,0.024457,0.000493,K562_ID_2644,7e-06,ENSG00000244734
2474,60753,chr11,5178461,5178961,promoter|chr11:5178461-5178961,promoter,0.895923,1.649535,0.486608,0.802678,...,0.00201,0.0,0.00201,0.001801,0.00059,0.01536,0.00031,K562_ID_2644,4e-06,ENSG00000244734
2475,60754,chr11,5196012,5196698,intergenic|chr11:5196012-5196698,intergenic,2.715718,1.526782,4.830502,7.375124,...,0.005493,0.0,0.005493,0.014917,0.004888,0.069101,0.001393,K562_ID_2644,3e-05,ENSG00000244734
2476,60755,chr11,5196781,5197281,intergenic|chr11:5196781-5197281,intergenic,1.540977,1.63778,1.449895,2.37461,...,0.005493,0.0,0.005493,0.008464,0.002774,0.039977,0.000806,K562_ID_2644,3e-05,ENSG00000244734


In [27]:
# List the distinct target genes that have at least one linked element on chr8.
enhancer_gene_k562_100kb.loc[enhancer_gene_k562_100kb['#chr'] == 'chr8', 'TargetGene'].unique()

array(['ZNF596', 'RPL23AP53', 'FAM87A', 'ERICH1', 'LOC101927752', 'CLN8',
       'ARHGEF10', 'LOC101927815', 'LOC100287015', 'MCPH1', 'ANGPT2',
       'MCPH1-AS1', 'AGPAT5', 'FAM86B3P', 'CLDN23', 'MFHAS1', 'ERI1',
       'PPP1R3B', 'LOC101929128', 'TNKS', 'MSRA', 'RP1L1', 'PINX1',
       'LOC101929229', 'MTMR9', 'GATA4', 'NEIL2', 'FDFT1', 'CTSB',
       'LONRF1', 'DLC1', 'C8orf48', 'TUSC3', 'MICU3', 'ZDHHC2', 'VPS37A',
       'CNOT7', 'MTMR7', 'PCM1', 'LOC101929066', 'ASAH1', 'NAT1',
       'INTS10', 'LPL', 'SLC18A1', 'ATP6V1B2', 'GFRA2', 'DOK2', 'XPO7',
       'FGF17', 'DMTN', 'FAM160B2', 'NUDT18', 'REEP4', 'LGI3', 'SFTPC',
       'BMP1', 'POLR3D', 'LOC100507071', 'SLC39A14', 'PPP3CC', 'SORBS3',
       'C8orf58', 'CCAR2', 'BIN3-IT1', 'BIN3', 'EGR3', 'RHOBTB2',
       'LOC286059', 'TNFRSF10B', 'LOC254896', 'TNFRSF10A', 'LOC389641',
       'R3HCC1', 'LOC100507156', 'ENTPD4', 'SLC25A37', 'KCTD9', 'CDCA2',
       'PPP2R2A', 'SDAD1P1', 'BNIP3L', 'DPYSL2', 'TRIM35', 'PTK2B',
       'CHRNA2'

In [65]:
# Encode the gene-enhancer links of the selected genes as EPInformer inputs.
# num_features controls which per-element features are encoded:
#   1 -> distance
#   2 -> distance + enhancer activity
#   3 -> distance + enhancer activity + Hi-C contacts
device = 'cpu'
gene_list = ['MYC', 'KLF1', 'LPL']
PE_codes, PE_feats, mRNA_feats, PE_pairs = prepare_input(
    enhancer_gene_k562_100kb,
    gene_list,
    num_features=2,
)
# Convert the NumPy arrays returned by prepare_input to float32 tensors on the chosen device.
PE_codes, PE_feats, mRNA_feats = (
    torch.from_numpy(array).float().to(device)
    for array in (PE_codes, PE_feats, mRNA_feats)
)
print(PE_codes.shape, PE_feats.shape, mRNA_feats.shape)

100%|██████████| 3/3 [00:00<00:00, 11.77it/s]

torch.Size([3, 61, 2000, 4]) torch.Size([3, 61, 2]) torch.Size([3, 9])





In [66]:
# Build EPInformer-PE-Activity (CAGE-seq) and restore its pre-trained weights.
model = EPInformer_v2(
    n_encoder=3,
    n_enhancer=60,
    out_dim=64,
    n_extraFeat=2,
    device=device,
)
model_path = '../models_with_pretrained/fold_8_best_EPInformerV2.4base.64dim.3Trans.4head.TrueBN.TrueLN.TrueFeat.2extraFeat.60enh.preTrainedConv.tuneP2.K562.rmEnhNone.bs32.seq_feat_dist.DNaseH.PSignal.distanceDist100k.hicNone.len2k.distance.CAGE_checkpoint.pt'
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
# map_location remaps the stored tensors onto `device` regardless of where they were saved.
checkpoint = torch.load(model_path, map_location=device)
# The checkpoint is indexed like a dict; the weights are stored under 'model_state_dict'.
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)

In [69]:
# Predict expression for the genes in gene_list with the pre-trained model.
model.eval()  # switch the model to evaluation mode before predicting
with torch.no_grad():  # gradients are not needed for inference
    # The model returns two values; only the first (the predicted expression) is used here.
    pred_expr, _ = model(PE_codes, mRNA_feats, PE_feats)
    # Copy to host memory before converting: Tensor.numpy() only works on CPU tensors,
    # so without .cpu() this cell fails as soon as `device` is set to something like 'cuda'.
    pred_expr = pred_expr.cpu().numpy().squeeze()
print(gene_list)
# presumably one value per gene, in the same order as gene_list -- confirm against prepare_input
print(pred_expr)

['MYC', 'KLF1', 'LPL']
[1.3414612  0.935605   0.41467816]
