In [1]:
import shutil
import os

import torch
import torch.nn as nn

import boda

# Load models interactively

This notebook describes how to load Malinois and use it for inference. It's important to remember that Malinois processes `(bsz, 4, 600)` tensors but was trained on 200-mers. Therefore you need to pad input sequences with 200 nucleotides on each side from the MPRA vector used to generate the data. This is done using the `FlankBuilder`.

# Get Malinois

The model artifact can be downloaded directly from a public Google Cloud Storage bucket (this requires `gsutil` access to the bucket).

In [2]:
# GCS location of the packaged Malinois artifacts (a .tar.gz archive).
malinois_path = 'gs://tewhey-public-data/CODA_resources/malinois_artifacts__20211113_021200__287348.tar.gz'

# Copies the archive, unpacks it into ./ and returns the model; per the
# recorded output below, the model comes back in eval mode.
my_model = boda.common.utils.load_model(
    malinois_path
)

Copying gs://tewhey-public-data/CODA_resources/malinois_artifacts__20211113_021200__287348.tar.gz...
| [1 files][ 49.3 MiB/ 49.3 MiB]                                                
Operation completed over 1 objects/49.3 MiB.                                     
archive unpacked in ./


Loaded model from 20211113_021200 in eval mode


In [3]:
# Read the model's expected input length from the unpacked checkpoint.
# Only the stored hyperparameters are needed here, so map everything to the
# CPU instead of restoring the saved tensors onto the device they were
# saved from (which could allocate GPU memory for weights we never use).
input_len = torch.load(
    './artifacts/torch_checkpoint.pt',
    map_location='cpu',
)['model_hparams'].input_len


# Set flanks

MPRA flanks are saved as constants in the `boda` repo. These need to be sized to (1, 4, 200) each and used to init `FlankBuilder`.

In [4]:
# Length of the variable insert. The rest of the model's input window is
# filled with MPRA vector sequence, split between the two sides.
insert_len = 200
left_pad_len  = (input_len - insert_len) // 2
right_pad_len = (input_len - insert_len) - left_pad_len  # takes the extra base if the total padding is odd

# Upstream flank: the last `left_pad_len` characters of the upstream constant,
# converted to a tensor and given a leading batch axis via `unsqueeze(0)`.
left_flank = boda.common.utils.dna2tensor(
    boda.common.constants.MPRA_UPSTREAM[-left_pad_len:]
).unsqueeze(0)
print(f'left flank shape: {left_flank.shape}')

# Downstream flank: the first `right_pad_len` characters of the downstream
# constant, prepared the same way.
right_flank = boda.common.utils.dna2tensor(
    boda.common.constants.MPRA_DOWNSTREAM[:right_pad_len]
).unsqueeze(0)
print(f'right flank shape: {right_flank.shape}')

flank_builder = boda.common.utils.FlankBuilder(
    left_flank=left_flank,
    right_flank=right_flank,
)

# Move the builder to the GPU so it can be applied to CUDA inputs.
# As the last expression of the cell, this also displays the builder's repr.
flank_builder.cuda()

left flank shape: torch.Size([1, 4, 200])
right flank shape: torch.Size([1, 4, 200])


FlankBuilder()

# Example call

The model call is wrapped in `torch.no_grad()` so the computation graph isn't kept in memory. Because sequences are passed to the model as one-hot tensors in `torch.float32` format, we can use `torch.randn` to validate the model setup. Here a batch of 10 (fake) 200 nt variable sequences is padded to 600 nt and then passed to the model. Note that `my_model` and `flank_builder` are both on the GPU (`flank_builder` was moved there with the `.cuda()` call above), so the fake sequences also need to be sent to `cuda`.

Note: this fake sequence will result in pathological predictions; it's only an illustrative example.

In [5]:
# Random stand-in for a batch of inputs: (batch_size, 4 nucleotide channels, 200 positions).
placeholder = torch.randn(10, 4, 200).cuda()

# Attach the MPRA flanks before handing the batch to the model.
prepped_seq = flank_builder(placeholder)

# No gradients are needed for inference.
with torch.no_grad():
    example_preds = my_model(prepped_seq)
    print(example_preds)



tensor([[-3.4321, -1.6530,  3.0183],
        [-1.9178, -1.4821,  4.1295],
        [-1.1529, -0.2439, 12.0388],
        [-1.6435, -0.3880,  3.0986],
        [ 2.6905,  0.0489,  7.6151],
        [-0.2412, -1.0400,  3.9991],
        [-2.1948, -0.6833,  5.3087],
        [-0.8940, -0.0644,  9.7120],
        [-1.0495, -1.0386,  6.9109],
        [ 0.5872,  0.7307,  9.7070]], device='cuda:0')


# Run on MPRA data set

We're focusing on sequences that are 200 nt long for simplicity. In the paper we padded smaller sequences with additional nucleotides from the flanks.

In [6]:
import pandas as pd
import numpy as np
import csv
from scipy.stats import pearsonr, spearmanr
import tqdm.notebook as tqdm
import matplotlib.pyplot as plt

In [7]:
!gsutil cp gs://tewhey-public-data/CODA_resources/MPRA_ALL_HD_v2.txt ./
mpra_19 = pd.read_table('MPRA_ALL_HD_v2.txt', sep=' ', header=0)

Copying gs://tewhey-public-data/CODA_resources/MPRA_ALL_HD_v2.txt...
- [1 files][311.6 MiB/311.6 MiB]                                                
Operation completed over 1 objects/311.6 MiB.                                    


  exec(code_obj, self.user_global_ns, self.user_ns)


In [8]:
# Keep only rows whose sequence is exactly 200 nt and renumber them from 0 so
# they line up positionally with the predictions produced later.
pass_seq = mpra_19.loc[ mpra_19['nt_sequence'].str.len() == 200 ].reset_index(drop=True)

# Encode every sequence and stack along a new leading (batch) axis.
# Iterating the column directly avoids `iterrows()`, which builds a full
# Series for every row just to read one field.
seq_tensor  = torch.stack(
    [ boda.common.utils.dna2tensor(seq) for seq in tqdm.tqdm(pass_seq['nt_sequence']) ],
    dim=0,
)
seq_dataset = torch.utils.data.TensorDataset(seq_tensor)
# Default DataLoader ordering (no shuffling) keeps batches in row order.
seq_loader  = torch.utils.data.DataLoader(seq_dataset, batch_size=128)

  0%|          | 0/665326 [00:00<?, ?it/s]

In [9]:
results = []

with torch.no_grad():
    for batch in tqdm.tqdm(seq_loader):
        # TensorDataset yields 1-tuples, so the sequences are in batch[0].
        prepped_seq = flank_builder( batch[0].cuda() )

        # Score the flanked input as-is and with axes 1 and 2 both reversed,
        # then average the two. The flip is applied after flanking, so the
        # flanks are reversed too.
        # NOTE(review): reversing the channel and position axes together is a
        # reverse complement only if the channel order is symmetric under
        # complementation (e.g. A, C, G, T) - confirm against `dna2tensor`.
        fwd_preds   = my_model( prepped_seq )
        rev_preds   = my_model( prepped_seq.flip(dims=[1,2]) )
        batch_preds = (fwd_preds + rev_preds).div(2.)

        # Collect on the CPU; no detach is needed inside `no_grad`.
        results.append(batch_preds.cpu())

# One row per input sequence, in loader order.
predictions = torch.cat(results, dim=0)

  0%|          | 0/5198 [00:00<?, ?it/s]

In [10]:
# Label the three prediction columns.
# NOTE(review): this assumes the model's outputs are ordered K562, HepG2, SKNSH - confirm.
pred_columns = ['K562_preds', 'HepG2_preds', 'SKNSH_preds']
pred_df = pd.DataFrame(predictions.numpy(), columns=pred_columns)

# Place the predictions alongside the source rows (aligned on the index).
all_results = pd.concat(
    [pass_seq, pred_df],
    axis=1,
)
all_results

Unnamed: 0,HepG2_mean,HepG2_std,ID_count,IDs,K562_mean,K562_std,OL,OL_count,SKNSH_mean,SKNSH_std,...,exp_mean_hepg2,exp_mean_k562,exp_mean_sknsh,lfcSE_hepg2,lfcSE_k562,lfcSE_sknsh,nt_sequence,K562_preds,HepG2_preds,SKNSH_preds
0,0.936302,,1,10:133978962:T:C:A:wC,1.138157,,29,1.0,0.433954,,...,3903.485478,4491.882102,2758.124172,0.070170,0.071954,0.099081,CGCTTGTTCTCCCACGTGGGGCTGGTTCAGTCATGTCTGGGGGTGA...,1.414011,1.322426,0.819247
1,-0.019441,0.181691,1,10:103714782:C:T:A:wC,-0.241691,0.116453,2933,2.0,-0.383256,0.027153,...,438.217426,363.590834,307.232292,0.183618,0.149959,0.197220,CAGTTGAGCAGGTATGTCAGACTTTTATAAAATATCTCCCCCACTC...,-0.249855,-0.161317,-0.276066
2,2.198518,0.264541,1,10:1932535:G:A:R:wC,0.806424,0.155304,273031,3.0,1.832190,0.163337,...,4911.621720,2131.653819,4157.073042,0.107208,0.103578,0.113405,CAGTTGAATCCATTTCATCAAAATTTATCGATTAAAATCAGTCCTA...,0.892995,1.522612,1.519431
3,2.151191,,1,3:155535351:G:C:A:wC,2.617037,,30,1.0,2.195686,,...,701.519324,969.196210,722.443599,0.117329,0.129472,0.113711,CAGTTGAATCTTATCCTTCATTTTCTTTCTGACCTTATACTTACTT...,2.049520,2.146376,2.141627
4,2.528459,,1,3:155535351:G:C:R:wC,2.816753,,30,1.0,2.509255,,...,963.142129,1176.035964,948.146039,0.130918,0.141325,0.091585,CAGTTGAATCTTATCCTTCATTTTCTTTCTGACCTTATACTTACTT...,2.269007,2.329491,2.307048
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
665321,-0.321885,,1,10:15182849:NA:NA,-0.573023,,,,-0.193369,,...,211.399415,177.638087,231.025429,0.196973,0.172873,0.200871,CAGGGTGCCTGGGGGTTGGCTCCACAGACAGGGATGGCTGCCATCT...,-0.233415,-0.242785,-0.177466
665322,0.444297,,1,2:197276159:NA:NA,-0.047592,,,,0.319537,,...,115.003293,81.351179,105.356335,0.265899,0.251303,0.275643,CAGGGTGCAGTGCTCTTCCTTTTTAGTGGGTGCTATCACTTCTAAA...,-0.004899,0.889780,-0.015843
665323,-0.088156,,1,11:10712109:NA:NA,-0.217911,,,,-0.106013,,...,157.378913,143.499246,155.354492,0.226295,0.182207,0.231551,CAGGGTCATGGGCGTGAGTTACCTCTGCTAAGACTCTGAATTTGAA...,-0.280889,0.021142,-0.305690
665324,1.053338,,1,sample_4_4_4_02__044:0559,1.276081,,,,0.702644,,...,170.042767,198.413252,133.212516,0.238876,0.170216,0.247984,CAGGGTCATAGGGGGAGTTATTCAAGCACTCTAGCTGACCGCTGTC...,1.967246,2.293525,2.475624


# Validation set performance
Check performance on chromosomes 19, 21, and X (held-out for validation during hparam selection).

In [11]:
# Select rows on chromosomes 19, 21 and X. The numeric chromosomes are
# matched both as integers and as strings.
chrom = all_results['chr']
chr_filter = (
    chrom.eq(19)
    | chrom.eq(21)
    | chrom.eq('19')
    | chrom.eq('21')
    | chrom.eq('X')
)

val_results = all_results.loc[chr_filter]

## Pearson's r

In [12]:
# Pearson correlation of each cell type's `_mean` column against its `_preds`
# column on the validation rows.
for cell_type in ('K562', 'HepG2', 'SKNSH'):
    observed  = val_results[f'{cell_type}_mean']
    predicted = val_results[f'{cell_type}_preds']
    stat, pvalue = pearsonr(observed, predicted)
    print(cell_type)
    print(f'stat: {stat:.4f}, pvalue: {pvalue}')

K562
stat: 0.9152, pvalue: 0.0
HepG2
stat: 0.9131, pvalue: 0.0
SKNSH
stat: 0.9104, pvalue: 0.0


## Spearman's rho

In [13]:
# Spearman rank correlation of each cell type's `_mean` column against its
# `_preds` column on the validation rows.
for cell_type in ('K562', 'HepG2', 'SKNSH'):
    observed  = val_results[f'{cell_type}_mean']
    predicted = val_results[f'{cell_type}_preds']
    stat, pvalue = spearmanr(observed, predicted)
    print(cell_type)
    print(f'stat: {stat:.4f}, pvalue: {pvalue}')

K562
stat: 0.8504, pvalue: 0.0
HepG2
stat: 0.8675, pvalue: 0.0
SKNSH
stat: 0.8661, pvalue: 0.0


# Test set performance
Check performance on chromosomes 7 and 13 (held-out for final testing, not used for model selection).

In [14]:
# Select rows on chromosomes 7 and 13, matching the chromosome label both as
# an integer and as a string.
chrom = all_results['chr']
chr_filter = (
    chrom.eq(7)
    | chrom.eq(13)
    | chrom.eq('7')
    | chrom.eq('13')
)

test_results = all_results.loc[chr_filter]

## Pearson's r

In [15]:
# Pearson correlation of each cell type's `_mean` column against its `_preds`
# column on the test rows.
for cell_type in ('K562', 'HepG2', 'SKNSH'):
    observed  = test_results[f'{cell_type}_mean']
    predicted = test_results[f'{cell_type}_preds']
    stat, pvalue = pearsonr(observed, predicted)
    print(cell_type)
    print(f'stat: {stat:.4f}, pvalue: {pvalue}')

K562
stat: 0.8821, pvalue: 0.0
HepG2
stat: 0.8856, pvalue: 0.0
SKNSH
stat: 0.8766, pvalue: 0.0


## Spearman's rho

In [16]:
# Spearman rank correlation of each cell type's `_mean` column against its
# `_preds` column on the test rows. `test_results` is reused from the cell
# that builds the test-set filter rather than being recomputed here, which
# matches how the validation cells share `val_results`.
for cell in ['K562', 'HepG2', 'SKNSH']:
    corr = spearmanr(test_results[f'{cell}_mean'], test_results[f'{cell}_preds'])
    print(cell)
    print(f'stat: {corr[0]:.4f}, pvalue: {corr[1]}')

K562
stat: 0.8198, pvalue: 0.0
HepG2
stat: 0.8381, pvalue: 0.0
SKNSH
stat: 0.8378, pvalue: 0.0
