# Let's predict for the first chromosome
## Let's try to load data via DataLoader to predict (to avoid indexing error!)

In [1]:
# Import necessary packages
import numpy as np
import sys
import os
import pandas as pd
from Bio import SeqIO
from matplotlib import pyplot as plt
from sklearn.metrics import r2_score
from scipy.stats import pearsonr
import math
import torch
import torch.nn as nn
from tqdm.auto import tqdm
import re
import csv

from enformer_pytorch import Enformer, seq_indices_to_one_hot, GenomeIntervalDataset
from enformer_loader import GenomeDataIntervalDataset
import ast
from enformer_pytorch import from_pretrained
from enformer_pytorch.finetune import HeadAdapterWrapper

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
print('using GPU') if torch.cuda.is_available() else print('using CPU')

using GPU


In [3]:
# get file paths
path = 'tests/data/'
fasta_file = '/lustre/scratch126/gengen/projects/graft/Dataset/reference/hg38_galGal6_full/fasta/GRCh38.GRCg6a.full.renamed.merged.fa'
bed_file = path + 'test1_train_dataset.bed'
chroms = ['hg38_1']
print(f'File paths loaded \n')
# and load pre-trained model
weight_file = '/lustre/scratch126/gengen/teams/parts/jh47/enformer_pytorch/model_weights/'
if torch.cuda.is_available():
    print(f'Using GPU\n')
    enformer = from_pretrained(weight_file, use_checkpointing = True).cuda()
    print(f"Pre-trained model loaded.")
else:
    print('using CPU')


File paths loaded 

Using GPU

Pre-trained model loaded.


In [77]:
# let's try with our toy dataset - using dataloader it will not give us an error
from torch.utils.data import DataLoader
dataset = GenomeDataIntervalDataset(
    all_chroms=['hg38_1'],
    bed_file=bed_file,
    fasta_file=fasta_file)
data_loader = DataLoader(dataset, batch_size=4, shuffle=False)

with torch.no_grad():
    i = 0
    y_hat_vector = []
    for batch in data_loader:
        i += 1
        s = batch[0].cuda()
        y_hat = enformer(s, head = 'human')
        y_hat_vector.append(y_hat)

print(f"{i} batches were predicted")




10 batches were predicted


In [82]:
# get our desired track!
predtest = []
for i in range(len(y_hat_vector)):
    #batchpred = y_hat_vector[i]['human'].cpu().detach().squeeze(0).numpy()
    for j in range(y_hat_vector[0].shape[0]):
        predtest.append(y_hat_vector[i][j][:,1436].cpu().detach().numpy())

print(f"There are {len(predtest[0])} values in each of the {len(predtest)} predictions.")

There are 896 values in each of the 40 predictions.


# Let's simply predict for genomic windows of 196kb

### Get ranges

In [28]:
# jacob's code
seq_len = 196608
# Only center 114688 basepairs are used, 320 bins cropped on either side.
PRED_WIDTH = seq_len - (320 * 128) - (320 * 128)
OUT_N_BINS = (PRED_WIDTH // 128)
PRED_WIDTH

114688

In [85]:
chrom_length = 248956422 # for hg38_1
for s in range(0, chrom_length - seq_len, PRED_WIDTH):
    print(s)

0
114688
229376
344064
458752
573440
688128
802816
917504
1032192
1146880
1261568
1376256
1490944
1605632
1720320
1835008
1949696
2064384
2179072
2293760
2408448
2523136
2637824
2752512
2867200
2981888
3096576
3211264
3325952
3440640
3555328
3670016
3784704
3899392
4014080
4128768
4243456
4358144
4472832
4587520
4702208
4816896
4931584
5046272
5160960
5275648
5390336
5505024
5619712
5734400
5849088
5963776
6078464
6193152
6307840
6422528
6537216
6651904
6766592
6881280
6995968
7110656
7225344
7340032
7454720
7569408
7684096
7798784
7913472
8028160
8142848
8257536
8372224
8486912
8601600
8716288
8830976
8945664
9060352
9175040
9289728
9404416
9519104
9633792
9748480
9863168
9977856
10092544
10207232
10321920
10436608
10551296
10665984
10780672
10895360
11010048
11124736
11239424
11354112
11468800
11583488
11698176
11812864
11927552
12042240
12156928
12271616
12386304
12500992
12615680
12730368
12845056
12959744
13074432
13189120
13303808
13418496
13533184
13647872
13762560
13877248
1399

In [102]:
# jacob's code
chrom_length = 248956422 # for hg38_1
intervals = [("hg38_1", s, s+seq_len) for s in range(0, chrom_length - seq_len, PRED_WIDTH)]

intervals

[('hg38_1', 0, 196608),
 ('hg38_1', 114688, 311296),
 ('hg38_1', 229376, 425984),
 ('hg38_1', 344064, 540672),
 ('hg38_1', 458752, 655360),
 ('hg38_1', 573440, 770048),
 ('hg38_1', 688128, 884736),
 ('hg38_1', 802816, 999424),
 ('hg38_1', 917504, 1114112),
 ('hg38_1', 1032192, 1228800),
 ('hg38_1', 1146880, 1343488),
 ('hg38_1', 1261568, 1458176),
 ('hg38_1', 1376256, 1572864),
 ('hg38_1', 1490944, 1687552),
 ('hg38_1', 1605632, 1802240),
 ('hg38_1', 1720320, 1916928),
 ('hg38_1', 1835008, 2031616),
 ('hg38_1', 1949696, 2146304),
 ('hg38_1', 2064384, 2260992),
 ('hg38_1', 2179072, 2375680),
 ('hg38_1', 2293760, 2490368),
 ('hg38_1', 2408448, 2605056),
 ('hg38_1', 2523136, 2719744),
 ('hg38_1', 2637824, 2834432),
 ('hg38_1', 2752512, 2949120),
 ('hg38_1', 2867200, 3063808),
 ('hg38_1', 2981888, 3178496),
 ('hg38_1', 3096576, 3293184),
 ('hg38_1', 3211264, 3407872),
 ('hg38_1', 3325952, 3522560),
 ('hg38_1', 3440640, 3637248),
 ('hg38_1', 3555328, 3751936),
 ('hg38_1', 3670016, 3866624),

In [94]:
int((seq_len - PRED_WIDTH) /2/128)

320

In [29]:
chrom_length = 248956422 # for hg38_1
intervals = [("hg38_1", s - int((seq_len - PRED_WIDTH) /2), s+seq_len+int((seq_len - PRED_WIDTH) /2)) for s in range(0, chrom_length - seq_len, PRED_WIDTH) if s>0]
# we now have intervals
intervals

[('hg38_1', 73728, 352256),
 ('hg38_1', 188416, 466944),
 ('hg38_1', 303104, 581632),
 ('hg38_1', 417792, 696320),
 ('hg38_1', 532480, 811008),
 ('hg38_1', 647168, 925696),
 ('hg38_1', 761856, 1040384),
 ('hg38_1', 876544, 1155072),
 ('hg38_1', 991232, 1269760),
 ('hg38_1', 1105920, 1384448),
 ('hg38_1', 1220608, 1499136),
 ('hg38_1', 1335296, 1613824),
 ('hg38_1', 1449984, 1728512),
 ('hg38_1', 1564672, 1843200),
 ('hg38_1', 1679360, 1957888),
 ('hg38_1', 1794048, 2072576),
 ('hg38_1', 1908736, 2187264),
 ('hg38_1', 2023424, 2301952),
 ('hg38_1', 2138112, 2416640),
 ('hg38_1', 2252800, 2531328),
 ('hg38_1', 2367488, 2646016),
 ('hg38_1', 2482176, 2760704),
 ('hg38_1', 2596864, 2875392),
 ('hg38_1', 2711552, 2990080),
 ('hg38_1', 2826240, 3104768),
 ('hg38_1', 2940928, 3219456),
 ('hg38_1', 3055616, 3334144),
 ('hg38_1', 3170304, 3448832),
 ('hg38_1', 3284992, 3563520),
 ('hg38_1', 3399680, 3678208),
 ('hg38_1', 3514368, 3792896),
 ('hg38_1', 3629056, 3907584),
 ('hg38_1', 3743744, 402

In [30]:
intervals[0][1] + ((320) * 128) # first actual result

114688

In [31]:
intervals[-1][2] - ((320) * 128) # last actual result

248954880

In [32]:
n_intervals = len(intervals) # the number of intervals
# now we want to store the bins, which are 128 bp long. We take the start of our actual result, and the end, and split it in 128bp bins.
all_pred_intervals = [(s, s+128) for s in range(intervals[0][1] + ((320) * 128), intervals[-1][2] - ((320) * 128), 128)]
len(all_pred_intervals) # we should have all of these values!

1944064

In [126]:
# let's write a .bed file to load later
with open("pred_hg38_1.bed", "w") as bed_file:
    # Loop through the list of tuples and write each line
    for region in intervals:
        chrom, start, end = region
        bed_file.write(f"{chrom}\t{start}\t{end}\n")

print("BED file has been written as 'pred_hg38_1.bed'")

BED file has been written as 'pred_hg38_1.bed'


### Now we have our ranges and the coordinates of our results. Let's predict. I'll use some of Jacob's code.

In [21]:
targets_txt = 'targets_human.txt' # change to targets_mouse.txt if mouse
chrom = 'hg38_1'
chrom_sizes_path = path + 'all_chrom_len.txt'
chrom_sizes = pd.read_csv(chrom_sizes_path, sep='\t', header=None, index_col=0)
chrom_sizes.columns = ['size']
chrom_length = chrom_sizes.loc[chrom].item()


In [117]:
pred_dir = path + f'/predictions/{chrom}'
if not os.path.exists(pred_dir):
    os.makedirs(pred_dir)

In [119]:
import pysam
import pyfaidx
fa_loc = '/lustre/scratch126/gengen/projects/graft/Dataset/reference/hg38_galGal6_full/fasta/GRCh38.GRCg6a.full.renamed.merged.fa'
pyfaidx.Faidx(fa_loc)
fasta_open = pysam.Fastafile(fa_loc)

In [22]:
# our targets
targets_predicted_list = {
        'HEK293_H3K4me3': 'CHIP:H3K4me3:HEK293',
        'HEK293_H3K9me3': 'CHIP:H3K9me3:HEK293'
        }

In [23]:
targets_df = pd.read_csv(targets_txt, sep='\t')
targets_to_pred = {}
for key, val in targets_predicted_list.items():
    targets_to_pred[key] = targets_df.loc[targets_df['description'] == val]['index'].tolist()

# Check if all targets contain at least one track
assert sum([len(targets_to_pred[key]) == 0 for key in targets_to_pred.keys()]) == 0, 'Some targets are not found'


In [24]:
targets_to_pred

{'HEK293_H3K4me3': [1176], 'HEK293_H3K9me3': [1436]}

In [14]:
# we can create a data loader with our bed and fasta file:
from enformer_pytorch import GenomeIntervalDataset
from torch.utils.data import DataLoader
bed_file = 'pred_hg38_1.bed'
datachr1 = GenomeIntervalDataset(
    bed_file=bed_file,
    fasta_file=fasta_file)
data_loader = DataLoader(datachr1, batch_size=1, shuffle=False)
# batch size of 1 to save computation time

In [15]:
len(data_loader)

2169

In [16]:
# pred loop
print(f'Batch 1/{len(data_loader)}')
with torch.no_grad():
    y_hat_vector = []
    for i, batch in enumerate(data_loader):
        if (i + 1) % 10 == 0:
            print(f'Batch {i+1}/{len(data_loader)}')
        s = batch.cuda()
        y_hat = enformer(s, head = 'human')
        y_hat_vector.append(y_hat)

Batch 1/2169




Batch 10/2169
Batch 20/2169
Batch 30/2169
Batch 40/2169
Batch 50/2169
Batch 60/2169
Batch 70/2169
Batch 80/2169
Batch 90/2169
Batch 100/2169
Batch 110/2169
Batch 120/2169
Batch 130/2169
Batch 140/2169
Batch 150/2169
Batch 160/2169
Batch 170/2169
Batch 180/2169
Batch 190/2169
Batch 200/2169
Batch 210/2169
Batch 220/2169
Batch 230/2169
Batch 240/2169
Batch 250/2169
Batch 260/2169
Batch 270/2169
Batch 280/2169
Batch 290/2169
Batch 300/2169
Batch 310/2169
Batch 320/2169
Batch 330/2169
Batch 340/2169
Batch 350/2169
Batch 360/2169
Batch 370/2169
Batch 380/2169
Batch 390/2169
Batch 400/2169
Batch 410/2169
Batch 420/2169
Batch 430/2169
Batch 440/2169
Batch 450/2169
Batch 460/2169
Batch 470/2169
Batch 480/2169
Batch 490/2169
Batch 500/2169
Batch 510/2169
Batch 520/2169
Batch 530/2169
Batch 540/2169
Batch 550/2169
Batch 560/2169
Batch 570/2169
Batch 580/2169
Batch 590/2169
Batch 600/2169
Batch 610/2169
Batch 620/2169
Batch 630/2169
Batch 640/2169
Batch 650/2169
Batch 660/2169
Batch 670/2169
Batc

In [18]:
print(f"Shape of the prediction: {y_hat_vector[0].shape}. There are {len(y_hat_vector)} predictions.")
# y_hat_vector[0].shape

Shape of the prediction: torch.Size([1, 896, 5313]). There are 2169 predictions.


Great we have predicted things!

In [35]:
len(all_pred_intervals)
# all_pred_intervals

1944064

In [38]:
# I use this to store the values:
all_target_preds = dict(zip(
        targets_to_pred.keys(),
        [np.zeros((len(all_pred_intervals), 1)) for i in range(len(targets_to_pred.keys()))]
    ))
print(f"The shape of target is  {all_target_preds['HEK293_H3K4me3'].shape}.")
all_target_preds

The shape of target is  (1944064, 1).


{'HEK293_H3K4me3': array([[0.],
        [0.],
        [0.],
        ...,
        [0.],
        [0.],
        [0.]]),
 'HEK293_H3K9me3': array([[0.],
        [0.],
        [0.],
        ...,
        [0.],
        [0.],
        [0.]])}

In [None]:

# (in this case targets_to_pred is a dictionary of targets to be predicted)

# in the prediction loop I have this:
for target in targets_to_pred.keys():
            all_target_preds[target][curr_i:curr_i+bs, 0] = predictions[:, targets_to_pred[target]].mean(1)
# !! Note that this is based on the original Enformer, shapes / indexing in `predictions[:, targets_to_pred[target]].mean(1)` might be different in enformer-pytorch

# you can save predictions with something like this:
for target in all_target_preds.keys():
        with open(f'{pred_dir}/{prefix}_enformer_fullpreds_{target}.bedGraph', 'w') as f:
            f.write(f'track type=bedGraph name="enformer predictions ({target})" description="enformer predictions ({target})"\n')
            for i, ((s, e), score) in enumerate(zip((all_pred_intervals + intervals_to_append), all_target_preds[target][:,0])):
                f.write(f'{chrom}\t{s}\t{e}\t{score}\n')

In [None]:
## concatenate predictions
concat_preds = {}
for target in targets_to_pred.keys():
    concat_preds[target] = np.concatenate((all_target_preds[target], last_preds[target][indices_to_append]))

for target in concat_preds.keys():
    with open(f'{pred_dir}/{prefix}_enformer_fullpreds_{target}.bedGraph', 'w') as f:
        f.write(f'track type=bedGraph name="enformer predictions ({target})" description="enformer predictions ({target})"\n')
        for i, ((s, e), score) in enumerate(zip((all_pred_intervals + intervals_to_append), concat_preds[target][:,0])):
            f.write(f'{chrom}\t{s}\t{e}\t{score}\n')
    