In [1]:
import shutil
import os

import torch
import torch.nn as nn

import boda

# Load models interactively

This notebook describes how to load Malinois and use it for inference. It's important to remember that Malinois processes `(bsz, 4, 600)` tensors but was trained on 200-mers. Therefore you need to pad input sequences with 200 nucleotides on each side from the MPRA vector used to generate the data. This is done using the `FlankBuilder`.

# Get Malinois

The model artifacts can be downloaded directly from a Google Storage bucket that you have access to.

In [2]:
# Location of the packaged Malinois artifacts (a .tar.gz archive in a Google Storage bucket).
malinois_path = 'gs://tewhey-public-data/CODA_resources/malinois_artifacts__20211113_021200__287348.tar.gz'
# Fetch and load the model. Per the cell output, the archive is copied, unpacked
# in ./ and the model is reported as loaded in eval mode.
my_model = boda.common.utils.load_model(malinois_path)

Copying gs://tewhey-public-data/CODA_resources/malinois_artifacts__20211113_021200__287348.tar.gz...
\ [1 files][ 49.3 MiB/ 49.3 MiB]                                                
Operation completed over 1 objects/49.3 MiB.                                     
archive unpacked in ./


Loaded model from 20211113_021200 in eval mode


In [3]:
# Only a single hyperparameter is read from the checkpoint, so map any stored
# tensors to the CPU rather than restoring them to the device they were saved on.
# `input_len` is the sequence length the model expects (insert plus both flanks).
input_len = torch.load(
    './artifacts/torch_checkpoint.pt',
    map_location='cpu',
)['model_hparams'].input_len


# Set flanks

MPRA flanks are saved as constants in the `boda` repo. These need to be sized to (1, 4, 200) each and used to init `FlankBuilder`.

In [4]:
# Length of the variable insert; the model input is this insert plus two flanks.
insert_len = 200
assert input_len >= insert_len, (
    f'model input_len ({input_len}) is shorter than the {insert_len} nt insert'
)

# Split the required padding between the two sides; if the total is odd,
# the extra nucleotide goes to the right flank.
left_pad_len = (input_len - insert_len) // 2
right_pad_len = (input_len - insert_len) - left_pad_len

# Take the insert-proximal end of each MPRA flank. The upstream slice is guarded
# because `seq[-0:]` would return the whole constant instead of an empty sequence.
upstream = boda.common.constants.MPRA_UPSTREAM
downstream = boda.common.constants.MPRA_DOWNSTREAM
left_flank_seq = upstream[-left_pad_len:] if left_pad_len > 0 else upstream[:0]
right_flank_seq = downstream[:right_pad_len]

# Slicing silently returns fewer characters when a constant is too short,
# which would yield model inputs of the wrong length; fail early instead.
assert len(left_flank_seq) == left_pad_len, 'MPRA_UPSTREAM is shorter than the required left flank'
assert len(right_flank_seq) == right_pad_len, 'MPRA_DOWNSTREAM is shorter than the required right flank'

# `unsqueeze(0)` adds a leading batch axis of size 1 to each encoded flank.
left_flank = boda.common.utils.dna2tensor(left_flank_seq).unsqueeze(0)
print(f'left flank shape: {left_flank.shape}')

right_flank = boda.common.utils.dna2tensor(right_flank_seq).unsqueeze(0)
print(f'right flank shape: {right_flank.shape}')

flank_builder = boda.common.utils.FlankBuilder(
    left_flank=left_flank,
    right_flank=right_flank,
)

# Move the flank builder to the GPU; left as the last expression so the cell
# displays the module's repr.
flank_builder.cuda()

left flank shape: torch.Size([1, 4, 200])
right flank shape: torch.Size([1, 4, 200])


FlankBuilder()

# Example call

We use `torch.no_grad()` so that the computation graph isn't saved to memory. Since sequences are passed to the model as one-hot tensors in `torch.float32` format, we can use `torch.randn` to validate the model setup. Here, a batch of 10 (fake) variable sequences, each 200 nt long, is padded to 600 nt and then passed to the model. Note that `my_model` and `flank_builder` have been placed on the GPU using `.cuda()` calls; therefore, the fake sequences also need to be sent to `cuda`.

Note: these fake sequences will result in pathological predictions; this is only an illustrative example.

In [5]:
# Simulate a batch: batch_size (10) x 4 nucleotide channels x 200 nt positions.
# Values are Gaussian noise, not real one-hot sequences.
placeholder = torch.randn((10, 4, 200)).cuda()

# Attach the MPRA flanks on both sides before calling the model.
prepped_seq = flank_builder(placeholder)

# No gradients are needed for inference, so skip building the autograd graph.
with torch.no_grad():
    example_preds = my_model(prepped_seq)
    print(example_preds)


tensor([[-2.2126, -1.8508,  5.1186],
        [-1.4284, -1.2446,  1.8232],
        [ 0.7822,  0.1400,  5.8496],
        [ 1.6264,  0.2242, 11.8879],
        [-0.5781, -1.3709,  4.5894],
        [ 0.1255, -0.0334, 10.9916],
        [-2.8357, -1.0953,  3.6760],
        [ 3.3309,  0.2382,  8.7414],
        [ 1.3750, -0.0471,  7.9200],
        [ 0.6532, -0.5759,  8.4386]], device='cuda:0')


# Run on MPRA data set

We're focusing on sequences that are 200 nt long for simplicity. In the paper we padded smaller sequences with additional nucleotides from the flanks.

In [6]:
import pandas as pd
import numpy as np
import csv
from scipy.stats import pearsonr, spearmanr
import tqdm.notebook as tqdm
import matplotlib.pyplot as plt

In [7]:
!gsutil cp gs://tewhey-public-data/CODA_resources/Table_S2__MPRA_dataset.txt ./
mpra_19 = pd.read_table('Table_S2__MPRA_dataset.txt', sep='\t', header=0)

mpra_19 = mpra_19.loc[ mpra_19.loc[:, ['K562_lfcSE', 'HepG2_lfcSE', 'SKNSH_lfcSE']].max(axis=1) < 1.0 ]

Copying gs://tewhey-public-data/CODA_resources/Table_S2__MPRA_dataset.txt...
\ [1 files][267.2 MiB/267.2 MiB]                                                
Operation completed over 1 objects/267.2 MiB.                                    


  exec(code_obj, self.user_global_ns, self.user_ns)


In [8]:
# Restrict to inserts that are exactly 200 nt long. The index is reset so that
# rows can later be aligned positionally with the prediction table.
pass_seq = mpra_19.loc[ mpra_19['sequence'].str.len() == 200 ].reset_index(drop=True)

# Encode every sequence and stack the results along a new leading (batch) axis.
# Iterating the column directly avoids `iterrows()`, which builds a full Series
# per row only to read a single field.
seq_tensor = torch.stack(
    [
        boda.common.utils.dna2tensor(seq)
        for seq in tqdm.tqdm(pass_seq['sequence'], total=pass_seq.shape[0])
    ],
    dim=0,
)
seq_dataset = torch.utils.data.TensorDataset(seq_tensor)
# The default DataLoader does not shuffle, so batches follow the order of `pass_seq`.
seq_loader  = torch.utils.data.DataLoader(seq_dataset, batch_size=128)

  0%|          | 0/717741 [00:00<?, ?it/s]

In [9]:
results = []

with torch.no_grad():
    for batch in tqdm.tqdm(seq_loader):
        # TensorDataset yields tuples; element 0 is the batch of encoded sequences.
        prepped_seq = flank_builder( batch[0].cuda() )

        forward_preds = my_model( prepped_seq )
        # Flipping dims 1 and 2 reverses the channel order and the position order
        # of the flanked input. Presumably this is the reverse complement, assuming
        # the channel order makes a channel flip swap complementary bases
        # (TODO: confirm against dna2tensor).
        flipped_preds = my_model( prepped_seq.flip(dims=[1,2]) )

        # Average the two orientations and keep the result on the CPU so GPU
        # memory does not grow with the dataset. No `.detach()` is needed
        # because gradients are disabled here.
        batch_preds = (forward_preds + flipped_preds).div(2.)
        results.append(batch_preds.cpu())

# One row per sequence, in the same order as the loader.
predictions = torch.cat(results, dim=0)

  0%|          | 0/5608 [00:00<?, ?it/s]

In [10]:
# One prediction column per cell type, in the order of the model outputs used here.
pred_columns = [
    'K562_preds',
    'HepG2_preds',
    'SKNSH_preds',
]
pred_df = pd.DataFrame(data=predictions.numpy(), columns=pred_columns)

# Place measurements and predictions side by side, aligned on the row index.
all_results = pd.concat([pass_seq, pred_df], axis=1)
all_results

Unnamed: 0,IDs,chr,data_project,OL,class,K562_log2FC,HepG2_log2FC,SKNSH_log2FC,K562_lfcSE,HepG2_lfcSE,SKNSH_lfcSE,sequence,K562_preds,HepG2_preds,SKNSH_preds
0,7:70038969:G:T:A:wC,7,UKBB,29,"BMI,BFP",0.060779,0.233601,0.047194,0.098795,0.118254,0.130671,CCTGGTCTTTCTTGCTAAATAAACATATCGTGCATCATCCAGATCT...,0.022539,0.491981,0.470595
1,1:192696196:C:T:A:wC,1,UKBB,33,Depression_GP,0.379639,0.004565,-0.244395,0.162169,0.186394,0.118952,CATAAAGATGAGGCTTGGCAAAGAACATCTCTCGGTGCCTCCCATT...,-0.147521,-0.183192,-0.356858
2,1:211209457:C:T:A:wC,1,UKBB,33,CAD,0.036707,0.384537,-0.004578,0.098391,0.121640,0.087458,CATAAAGCCAATCACTGAGATGACAAGTACTGCCAGGAAAGAAGGC...,-0.171176,0.195917,-0.020640
3,15:89574440:GT:G:R:wC,15,UKBB,33,CAD,4.508784,4.116494,3.040183,0.157035,0.209049,0.195014,CATAAAGGCAGTGTAGACCCAAACAGTGAGCAGTAGCAAGATTTAT...,4.552554,3.782228,4.054496
4,12:63513920:G:A:A:wC,12,UKBB,32,Morning_Person,1.616602,1.423444,1.335892,0.159670,0.148307,0.224775,CATAAAGGGCTGAACATGCTGTTGAAAAAATGTAGATATAAAAGTT...,1.276844,1.127613,1.073648
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
717736,4:44680358:NA:NA,4,CRE,15,K27_All,7.443810,5.344297,6.585129,0.097905,0.069708,0.140760,CAGTAGTAAGAAAGAGACAATGCAAAGGAATTGGCACAGCACTCAG...,4.823985,4.611618,5.402607
717737,18:9125893:NA:NA,18,CRE,15,K27_Uniq,-0.204913,-0.156933,-0.209358,0.133052,0.157279,0.185115,CAGTACTGCTGGCCCCAGAAAAGCCCCTCTCCTTATACCCTAGGCC...,-0.096624,0.147689,-0.033426
717738,12:33905808:NA:NA,12,CRE,15,K27_Uniq,1.218233,0.613623,0.569894,0.127132,0.167222,0.190639,CAGTACCTTGTCCCCACTTCCCATTTGGCCTCTGGCAGAGGAGGAG...,1.395628,0.742739,0.536832
717739,3:128145854:NA:NA,3,CRE,15,K27_Uniq,-0.222234,-0.338764,-0.817852,0.159002,0.198187,0.238637,CAGTACACCCCAGCTTCCAAAGGCCTTCTGTGACAAAGAGAGACTA...,-0.097387,-0.072233,-0.320700


# Validation set performance
Check performance on chromosomes 19, 21, and X (held-out for validation during hparam selection).

In [11]:
# Validation chromosomes. Labels are matched in both integer and string form
# because `chr` may hold either representation.
val_chr_labels = [19, 21, '19', '21', 'X']
chr_filter = all_results['chr'].isin(val_chr_labels)

val_results = all_results.loc[chr_filter]

## Pearson's r

In [12]:
# Pearson correlation between measured and predicted activity on the
# validation chromosomes, reported per cell type.
for cell in ('K562', 'HepG2', 'SKNSH'):
    measured = val_results[f'{cell}_log2FC']
    predicted = val_results[f'{cell}_preds']
    stat, pvalue = pearsonr(measured, predicted)
    print(cell)
    print(f'stat: {stat:.4f}, pvalue: {pvalue}')

K562
stat: 0.9131, pvalue: 0.0
HepG2
stat: 0.9110, pvalue: 0.0
SKNSH
stat: 0.9073, pvalue: 0.0


## Spearman's rho

In [13]:
# Spearman rank correlation between measured and predicted activity on the
# validation chromosomes, reported per cell type.
for cell in ('K562', 'HepG2', 'SKNSH'):
    measured = val_results[f'{cell}_log2FC']
    predicted = val_results[f'{cell}_preds']
    stat, pvalue = spearmanr(measured, predicted)
    print(cell)
    print(f'stat: {stat:.4f}, pvalue: {pvalue}')

K562
stat: 0.8405, pvalue: 0.0
HepG2
stat: 0.8615, pvalue: 0.0
SKNSH
stat: 0.8588, pvalue: 0.0


# Test set performance
Check performance on chromosomes 7 and 13 (held-out for final testing, not used for model selection).

In [14]:
# Test chromosomes. Labels are matched in both integer and string form
# because `chr` may hold either representation.
test_chr_labels = [7, 13, '7', '13']
chr_filter = all_results['chr'].isin(test_chr_labels)

test_results = all_results.loc[chr_filter]

## Pearson's r

In [15]:
# Pearson correlation between measured and predicted activity on the
# test chromosomes, reported per cell type.
for cell in ('K562', 'HepG2', 'SKNSH'):
    measured = test_results[f'{cell}_log2FC']
    predicted = test_results[f'{cell}_preds']
    stat, pvalue = pearsonr(measured, predicted)
    print(cell)
    print(f'stat: {stat:.4f}, pvalue: {pvalue}')

K562
stat: 0.8842, pvalue: 0.0
HepG2
stat: 0.8880, pvalue: 0.0
SKNSH
stat: 0.8785, pvalue: 0.0


## Spearman's rho

In [16]:
# Spearman rank correlation between measured and predicted activity on the
# test chromosomes, reported per cell type.
# `test_results` is reused as built from the chromosome 7/13 filter; it is not
# rebuilt from `chr_filter` here, because that name is shared with the
# validation filter and re-deriving from it would silently depend on which
# filter cell ran last.
for cell in ['K562', 'HepG2', 'SKNSH']:
    stat, pvalue = spearmanr(test_results[f'{cell}_log2FC'], test_results[f'{cell}_preds'])
    print(cell)
    print(f'stat: {stat:.4f}, pvalue: {pvalue}')

K562
stat: 0.8104, pvalue: 0.0
HepG2
stat: 0.8334, pvalue: 0.0
SKNSH
stat: 0.8306, pvalue: 0.0
