In [4]:
from model.EPInformer import EPInformer_v2, enhancer_predictor_256bp
from scripts.utils import prepare_input

In [5]:
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch

In [None]:
# Download ABC element-gene data for K562
!wget https://www.encodeproject.org/files/ENCFF635RHY/@@download/ENCFF635RHY.bed.gz -O ./data/K562_enhancer_gene_links.bed.gz

In [None]:
# Donwload reference genome
!wget https://hgdownload.soe.ucsc.edu/goldenPath/hg38/bigZips/hg38.fa.gz  -P ./data/
!gunzip ./data/hg38.fa.gz

In [None]:
# Load ABC enhancer-gene data
enhancer_gene_k562 = pd.read_csv('./data/K562_enhancer_gene_links.bed.gz', sep='\t')
# Select the gene-enhancer links within 100kb to the TSS of target gene and remove the promoter element
enhancer_gene_k562_100kb = enhancer_gene_k562[(enhancer_gene_k562['distance']<=100_000)&(enhancer_gene_k562['distance']>1000)].reset_index()
enhancer_gene_k562_100kb.to_csv('./data/K562_enhancer_gene_links_100kb.tsv', index=False, sep='\t')

In [16]:
enhancer_gene_k562_100kb = pd.read_csv('./data/K562_enhancer_gene_links_100kb.tsv', sep='\t')
gene_tss = pd.read_csv('./data/GeneList_K562.txt', sep='\t')[['name', 'chr', 'tss', 'strand']]
data_split = pd.read_csv('./data/leave_chrom_out_crossvalidation_split_18377genes.csv')
enhancer_gene_k562_100kb_includeNoEnhancerGene = enhancer_gene_k562_100kb.merge(gene_tss, left_on='TargetGene', right_on='name', how='right', suffixes=['', '_gene']).reset_index()

In [17]:
gene_list = list(data_split[data_split['fold_1'] == 'test']['Gene name'].head(16))
# encode gene-enhancer links for EPInformer
# num_feature == 1: distance; num_feature == 2: distance + enhancer activity; num_feature == 3: distance + enhancer activity + hic contacts
device = 'cpu'
PE_codes, PE_feats, mRNA_feats, PE_pairs = prepare_input(enhancer_gene_k562_100kb_includeNoEnhancerGene, gene_list, num_features=2)
PE_codes = torch.from_numpy(PE_codes).float().to(device)
PE_feats = torch.from_numpy(PE_feats).float().to(device)
mRNA_feats = torch.from_numpy(mRNA_feats).float().to(device)
print(PE_codes.shape, PE_feats.shape, mRNA_feats.shape)

100%|██████████| 16/16 [00:00<00:00, 26.04it/s]


torch.Size([16, 61, 2000, 4]) torch.Size([16, 61, 2]) torch.Size([16, 9])


In [18]:
# Load pre-trained EPInformer-PE-Activity (CAGE-seq)
model = EPInformer_v2(n_encoder=3, n_enhancer=60, out_dim=64, n_extraFeat=2, device=device)
model_path = '../models_with_pretrained/fold_1_best_EPInformerV2.4base.64dim.3Trans.4head.TrueBN.TrueLN.TrueFeat.2extraFeat.60enh.preTrainedConv.tuneP2.K562.rmEnhNone.bs32.seq_feat_dist.DNaseH.PSignal.distanceDist100k.hicNone.len2k.distance.CAGE_checkpoint.pt'
checkpoint = torch.load(model_path, map_location=torch.device(device))
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)

In [19]:
model.eval()
with torch.no_grad():
    pred_expr, _ = model(PE_codes, mRNA_feats, PE_feats)
    pred_expr = pred_expr.numpy().squeeze()
print(gene_list)
print(pred_expr)

['RAD52', 'M6PR', 'CYP26B1', 'ALS2', 'CASP10', 'CFLAR', 'TFPI', 'NDUFAF7', 'FKBP4', 'RECQL', 'RPAP3', 'GCFC2', 'WDR54', 'FARP2', 'ADIPOR2', 'UPP2']
[1.8046178  2.3075235  1.0193682  1.1156756  0.740408   1.7147359
 2.0511608  2.3792398  1.417219   2.0264227  2.0248396  2.1739185
 1.2864089  2.1143231  2.3980145  0.34572122]
