In [1]:
import shutil
import os

import torch
import torch.nn as nn

import boda

# Helpers

1. `load_model` checks for a GPU, clears the local `./artifacts` directory, downloads and unpacks the model artifact, and loads the model in `eval` mode (moving it to the GPU when one is available).

2. `FlankBuilder` pads inputs with MPRA vector backbone sequence. For technical reasons, Malinois reads 600-nt sequences (i.e., n × 4 × 600 inputs), but only the central 200 nt are variable sequence; the 200 nt on each side are MPRA backbone.

In [2]:
def load_model(artifact_path):
    """Download/unpack a Malinois model artifact and load it for inference.

    Any existing ``./artifacts`` directory is removed first, so the model that
    gets loaded is the one just unpacked (presumably the archive unpacks into
    ``./artifacts``, which is where the model is loaded from — confirm against
    ``boda.common.utils.unpack_artifact``).

    Args:
        artifact_path: Location of the model archive, forwarded unchanged to
            ``boda.common.utils.unpack_artifact`` (e.g. a ``gs://`` URI).

    Returns:
        The model built by ``boda.common.utils.model_fn``, switched to ``eval``
        mode and moved to the GPU when at least one CUDA device is visible.
    """
    model_dir = './artifacts'
    has_gpu = torch.cuda.device_count() >= 1

    # Wipe any previously unpacked artifact before unpacking the new one.
    if os.path.isdir(model_dir):
        shutil.rmtree(model_dir)
    boda.common.utils.unpack_artifact(artifact_path)

    model = boda.common.utils.model_fn(model_dir)
    model.eval()
    if has_gpu:
        model.cuda()
    return model

# Get Malinois

The model artifact can be downloaded directly from any Google Cloud Storage bucket you have access to.

In [3]:
# Malinois checkpoint in the public `tewhey-public-data` bucket; load_model
# downloads, unpacks and loads it in eval mode.
malinois_path = (
    'gs://tewhey-public-data/CODA_resources/'
    'malinois_model__20211113_021200__287348.tar.gz'
)
my_model = load_model(malinois_path)

Copying gs://tewhey-public-data/CODA_resources/malinois_model__20211113_021200__287348.tar.gz...
\ [1 files][ 49.3 MiB/ 49.3 MiB]                                                
Operation completed over 1 objects/49.3 MiB.                                     
archive unpacked in ./


Loaded model from 20211113_021200 in eval mode


# Set flanks

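MPRA flanks are saved as constants in the `boda` repo. Each flank is trimmed to the 200 nt adjacent to the variable region and converted to a tensor of shape (1, 4, 200); the two tensors are then used to initialize `FlankBuilder`.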

In [4]:
# Trim the MPRA backbone to the 200 nt adjacent to the variable region on each
# side, one-hot encode it, and add a leading batch dimension -> (1, 4, 200).
left_flank = boda.common.utils.dna2tensor(
    boda.common.constants.MPRA_UPSTREAM[-200:]
).unsqueeze(0)
print(f'left flank shape: {left_flank.shape}')

right_flank = boda.common.utils.dna2tensor(
    boda.common.constants.MPRA_DOWNSTREAM[:200]
).unsqueeze(0)
print(f'right flank shape: {right_flank.shape}')

flank_builder = boda.common.utils.FlankBuilder(
    left_flank=left_flank,
    right_flank=right_flank,
)

# Put the flank builder on the same device as the model rather than
# hard-coding .cuda(): load_model only moves the model to the GPU when one is
# present, so this keeps the two in sync (and works on CPU-only machines).
flank_builder.to(next(my_model.parameters()).device)

left flank shape: torch.Size([1, 4, 200])
right flank shape: torch.Size([1, 4, 200])


FlankBuilder()

# Example call

We use `torch.no_grad()` so the computation graph isn't kept in memory. Because the model takes sequences as one-hot tensors in `torch.float32`, random `torch.randn` input is enough to check that the model is set up correctly. Here a batch of 10 variable 200-nt (fake) sequences is padded to 600 nt and then passed to the model. The fake batch must be on the same device as `my_model` and `flank_builder` (the GPU, when one is available).

In [5]:
# Fixed seed so this smoke-test output is reproducible across re-runs.
torch.manual_seed(0)

# The fake batch must live on the same device as the model (GPU when available).
device = next(my_model.parameters()).device

# Simulate a batch of 10 sequences: 4 nucleotide channels x 200 nt of variable
# region. flank_builder pads them with MPRA backbone before they reach the model.
fake_batch = torch.randn((10, 4, 200)).to(device)

with torch.no_grad():
    fake_preds = my_model(flank_builder(fake_batch))

fake_preds

tensor([[ 0.8397,  0.6700,  9.3325],
        [-3.0140, -1.0176,  6.5865],
        [-0.7549,  0.3068, 10.9675],
        [ 3.1521, -0.2887, 11.7337],
        [-3.2651, -2.1552,  3.8119],
        [-0.5821,  1.0458,  9.1559],
        [-2.9773, -1.3296,  3.6720],
        [ 0.8760, -0.3315,  7.7030],
        [-2.0601, -0.3807, 11.6970],
        [ 2.0034,  0.5785,  5.9147]], device='cuda:0')


# Run on MPRA data set

For simplicity, we only score sequences that are exactly 200 nt long, so they fit the 200-nt variable region between the flanks.

In [6]:
import pandas as pd
import numpy as np
import csv
from scipy.stats import pearsonr
import tqdm.notebook as tqdm
import matplotlib.pyplot as plt

In [7]:
import subprocess

MPRA_TABLE = 'MPRA_ALL_HD_v2.txt'

# Only download when the table is not already on disk: it is ~300 MiB, so
# re-fetching on every re-run is wasteful. check=True stops the notebook here if
# the download fails, instead of failing later with a confusing read error.
if not os.path.exists(MPRA_TABLE):
    subprocess.run(
        ['gsutil', 'cp', f'gs://tewhey-public-data/CODA_resources/{MPRA_TABLE}', './'],
        check=True,
    )

# The 'chr' column appears to hold both ints and strings (the test-set filter
# below checks both 7 and '7'). Forcing it to str and parsing in one pass
# (low_memory=False) keeps dtypes consistent and avoids pandas' mixed-dtype warning.
mpra_19 = pd.read_table(
    MPRA_TABLE, sep=' ', header=0, dtype={'chr': str}, low_memory=False
)

Copying gs://tewhey-public-data/CODA_resources/MPRA_ALL_HD_v2.txt...
| [1 files][311.6 MiB/311.6 MiB]                                                
Operation completed over 1 objects/311.6 MiB.                                    


  exec(code_obj, self.user_global_ns, self.user_ns)


In [8]:
# Keep only sequences whose variable region is exactly 200 nt, matching the
# 200-nt flanks configured above.
pass_seq = mpra_19.loc[mpra_19['nt_sequence'].str.len() == 200].reset_index(drop=True)

# One-hot encode every sequence and stack into a single (N, 4, 200) tensor.
# Iterating the column directly avoids DataFrame.iterrows, which builds a full
# Series for each of the ~665k rows.
seq_tensor = torch.stack(
    [boda.common.utils.dna2tensor(seq) for seq in tqdm.tqdm(pass_seq['nt_sequence'])],
    dim=0,
)
seq_dataset = torch.utils.data.TensorDataset(seq_tensor)
# No shuffling: batch order must match pass_seq row order, because predictions
# are later joined back onto pass_seq by position.
seq_loader = torch.utils.data.DataLoader(seq_dataset, batch_size=128, shuffle=False)

  0%|          | 0/665326 [00:00<?, ?it/s]

In [9]:
# Score every batch on the model's device, averaging the prediction for each
# construct with the prediction for its flipped version.
device = next(my_model.parameters()).device
results = []

with torch.no_grad():
    for (batch,) in tqdm.tqdm(seq_loader):
        prepped_seq = flank_builder(batch.to(device))
        # flip(dims=[1, 2]) reverses both the nucleotide-channel axis and the
        # position axis of the full flanked construct (n x 4 x 600). If the
        # one-hot channels are ordered A, C, G, T this is the reverse complement
        # — TODO confirm the channel order used by boda.common.utils.dna2tensor.
        batch_preds = (
            my_model(prepped_seq) + my_model(prepped_seq.flip(dims=[1, 2]))
        ) / 2.
        # No graph is recorded under no_grad, so only a move to CPU is needed.
        results.append(batch_preds.cpu())

predictions = torch.cat(results, dim=0)

  0%|          | 0/5198 [00:00<?, ?it/s]

In [10]:
# One column per model output; the order is assumed to match the model's three
# output heads (K562, HepG2, SK-N-SH).
pred_columns = ['K562_preds', 'HepG2_preds', 'SKNSH_preds']
pred_df = pd.DataFrame(data=predictions.numpy(), columns=pred_columns)
pred_df

Unnamed: 0,K562_preds,HepG2_preds,SKNSH_preds
0,1.414037,1.322234,0.819266
1,-0.249841,-0.161379,-0.276085
2,0.893052,1.522660,1.519637
3,2.049274,2.146186,2.141511
4,2.268655,2.329347,2.307040
...,...,...,...
665321,-0.233427,-0.242739,-0.177346
665322,-0.004860,0.889347,-0.015702
665323,-0.280905,0.021103,-0.305730
665324,1.966225,2.292811,2.474705


In [11]:
# pd.concat(axis=1) joins on the index. Both frames carry a fresh 0..N-1 index
# (pass_seq via reset_index, pred_df by construction), so rows only pair up
# correctly when the lengths match. Check first so a mismatch fails loudly
# instead of silently padding with NaN.
assert len(pass_seq) == len(pred_df), (len(pass_seq), len(pred_df))
all_results = pd.concat([pass_seq, pred_df], axis=1)
all_results

Unnamed: 0,HepG2_mean,HepG2_std,ID_count,IDs,K562_mean,K562_std,OL,OL_count,SKNSH_mean,SKNSH_std,...,exp_mean_hepg2,exp_mean_k562,exp_mean_sknsh,lfcSE_hepg2,lfcSE_k562,lfcSE_sknsh,nt_sequence,K562_preds,HepG2_preds,SKNSH_preds
0,0.936302,,1,10:133978962:T:C:A:wC,1.138157,,29,1.0,0.433954,,...,3903.485478,4491.882102,2758.124172,0.070170,0.071954,0.099081,CGCTTGTTCTCCCACGTGGGGCTGGTTCAGTCATGTCTGGGGGTGA...,1.414037,1.322234,0.819266
1,-0.019441,0.181691,1,10:103714782:C:T:A:wC,-0.241691,0.116453,2933,2.0,-0.383256,0.027153,...,438.217426,363.590834,307.232292,0.183618,0.149959,0.197220,CAGTTGAGCAGGTATGTCAGACTTTTATAAAATATCTCCCCCACTC...,-0.249841,-0.161379,-0.276085
2,2.198518,0.264541,1,10:1932535:G:A:R:wC,0.806424,0.155304,273031,3.0,1.832190,0.163337,...,4911.621720,2131.653819,4157.073042,0.107208,0.103578,0.113405,CAGTTGAATCCATTTCATCAAAATTTATCGATTAAAATCAGTCCTA...,0.893052,1.522660,1.519637
3,2.151191,,1,3:155535351:G:C:A:wC,2.617037,,30,1.0,2.195686,,...,701.519324,969.196210,722.443599,0.117329,0.129472,0.113711,CAGTTGAATCTTATCCTTCATTTTCTTTCTGACCTTATACTTACTT...,2.049274,2.146186,2.141511
4,2.528459,,1,3:155535351:G:C:R:wC,2.816753,,30,1.0,2.509255,,...,963.142129,1176.035964,948.146039,0.130918,0.141325,0.091585,CAGTTGAATCTTATCCTTCATTTTCTTTCTGACCTTATACTTACTT...,2.268655,2.329347,2.307040
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
665321,-0.321885,,1,10:15182849:NA:NA,-0.573023,,,,-0.193369,,...,211.399415,177.638087,231.025429,0.196973,0.172873,0.200871,CAGGGTGCCTGGGGGTTGGCTCCACAGACAGGGATGGCTGCCATCT...,-0.233427,-0.242739,-0.177346
665322,0.444297,,1,2:197276159:NA:NA,-0.047592,,,,0.319537,,...,115.003293,81.351179,105.356335,0.265899,0.251303,0.275643,CAGGGTGCAGTGCTCTTCCTTTTTAGTGGGTGCTATCACTTCTAAA...,-0.004860,0.889347,-0.015702
665323,-0.088156,,1,11:10712109:NA:NA,-0.217911,,,,-0.106013,,...,157.378913,143.499246,155.354492,0.226295,0.182207,0.231551,CAGGGTCATGGGCGTGAGTTACCTCTGCTAAGACTCTGAATTTGAA...,-0.280905,0.021103,-0.305730
665324,1.053338,,1,sample_4_4_4_02__044:0559,1.276081,,,,0.702644,,...,170.042767,198.413252,133.212516,0.238876,0.170216,0.247984,CAGGGTCATAGGGGGAGTTATTCAAGCACTCTAGCTGACCGCTGTC...,1.966225,2.292811,2.474705


# Test set performance
Check performance on chromosomes 7 and 13, which were held out from both training and validation (i.e., the test set).

In [12]:
# Test-set rows are chromosomes 7 and 13. 'chr' may hold ints or strings
# depending on how it was parsed, so both spellings are matched.
HELD_OUT_CHROMS = [7, 13, '7', '13']
chr_filter = all_results['chr'].isin(HELD_OUT_CHROMS)


In [13]:
# K562: observed K562_mean vs. Malinois K562_preds on the held-out chromosomes.
k562_test = all_results.loc[chr_filter, ['K562_mean', 'K562_preds']]
pearsonr(k562_test['K562_mean'], k562_test['K562_preds'])

(0.8821404660621879, 0.0)

In [14]:
# HepG2: observed HepG2_mean vs. Malinois HepG2_preds on the held-out chromosomes.
hepg2_test = all_results.loc[chr_filter, ['HepG2_mean', 'HepG2_preds']]
pearsonr(hepg2_test['HepG2_mean'], hepg2_test['HepG2_preds'])

(0.8855923303550868, 0.0)

In [15]:
# SK-N-SH: observed SKNSH_mean vs. Malinois SKNSH_preds on the held-out chromosomes.
sknsh_test = all_results.loc[chr_filter, ['SKNSH_mean', 'SKNSH_preds']]
pearsonr(sknsh_test['SKNSH_mean'], sknsh_test['SKNSH_preds'])

(0.8765583735317558, 0.0)

In [16]:
# Export per-sequence predictions to a TSV, then display the same table.
export_cols = ['IDs', 'nt_sequence', 'K562_preds', 'HepG2_preds', 'SKNSH_preds']
inference_check = all_results.loc[:, export_cols]
inference_check.to_csv(
    'inference_check.tsv',
    sep='\t',
    index=False,
    header=True,
    quoting=csv.QUOTE_NONE,
)
inference_check

Unnamed: 0,IDs,nt_sequence,K562_preds,HepG2_preds,SKNSH_preds
0,10:133978962:T:C:A:wC,CGCTTGTTCTCCCACGTGGGGCTGGTTCAGTCATGTCTGGGGGTGA...,1.414037,1.322234,0.819266
1,10:103714782:C:T:A:wC,CAGTTGAGCAGGTATGTCAGACTTTTATAAAATATCTCCCCCACTC...,-0.249841,-0.161379,-0.276085
2,10:1932535:G:A:R:wC,CAGTTGAATCCATTTCATCAAAATTTATCGATTAAAATCAGTCCTA...,0.893052,1.522660,1.519637
3,3:155535351:G:C:A:wC,CAGTTGAATCTTATCCTTCATTTTCTTTCTGACCTTATACTTACTT...,2.049274,2.146186,2.141511
4,3:155535351:G:C:R:wC,CAGTTGAATCTTATCCTTCATTTTCTTTCTGACCTTATACTTACTT...,2.268655,2.329347,2.307040
...,...,...,...,...,...
665321,10:15182849:NA:NA,CAGGGTGCCTGGGGGTTGGCTCCACAGACAGGGATGGCTGCCATCT...,-0.233427,-0.242739,-0.177346
665322,2:197276159:NA:NA,CAGGGTGCAGTGCTCTTCCTTTTTAGTGGGTGCTATCACTTCTAAA...,-0.004860,0.889347,-0.015702
665323,11:10712109:NA:NA,CAGGGTCATGGGCGTGAGTTACCTCTGCTAAGACTCTGAATTTGAA...,-0.280905,0.021103,-0.305730
665324,sample_4_4_4_02__044:0559,CAGGGTCATAGGGGGAGTTATTCAAGCACTCTAGCTGACCGCTGTC...,1.966225,2.292811,2.474705
