In [1]:
import sys
import os
import shutil
import gzip
import multiprocessing

import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import boda

In [2]:
# The repository's `src` directory sits two levels above the notebook's working
# directory; put it at the front of the import path for the `main` import below.
repo_root = os.path.dirname(os.path.dirname(os.getcwd()))
boda_src = os.path.join(repo_root, 'src')
sys.path.insert(0, boda_src)

from main import unpack_artifact, model_fn

# Set files and variables

## Files
1. BODA model artifacts
2. ENCODE's preferred FASTA reference
3. A GRCh38 clinvar VCF (or other VCF in GRCh38)

## Variables
These control input augmentation
1. Size of MPRA insert
2. First position where tested variant is placed within insert
3. Last position where tested variant is placed within insert
4. Flag to specify if averaging predictions over reverse complement
5. Loader batch size

In [3]:
# Inputs: model artifact archive, genome reference FASTA, and the VCF of variants to score.
hpo_rec = 'gs://syrgoth/aip_ui_test/model_artifacts__20211113_021200__287348.tar.gz' # Model
fasta_fn = 'GRCh38_no_alt_analysis_set_GCA_000001405.15.fasta' # Genome reference
vcf_fn = 'clinvar.vcf.gz' # VCF example
# NOTE: this reassignment wins -- the ClinVar path above is overridden, so every
# later cell reads 'quicker_vars.vcf.gz'. Delete this line to score the ClinVar VCF.
vcf_fn = 'quicker_vars.vcf.gz'

In [4]:
# Input-augmentation settings; the first four are handed to VcfDataset and
# BATCH_SIZE to the DataLoader further down.
WINDOW_SIZE = 200  # size of the MPRA insert
RELATIVE_START = 95  # first position where the tested variant is placed within the insert
RELATIVE_END = 105  # last position where the tested variant is placed within the insert
REVERSE_COMPLEMENTS = True  # whether predictions are also averaged over the reverse complement

BATCH_SIZE = 500  # loader batch size

# Simple helper classes

1. One class to add MPRA vector flanks to input sequences
2. Another to predict skews from pairs of ref and alt alleles and average them over each variant's augmented inputs

In [5]:
class FlankBuilder(nn.Module):
    """Concatenate fixed flanking tensors onto the left/right of input samples.

    Either flank may be omitted (``None``), in which case nothing is added on
    that side. Flanks are stored as buffers, so they follow the module across
    ``.to()`` / ``.cuda()`` calls.

    Args:
        left_flank: Tensor prepended along ``cat_axis``, or ``None``.
        right_flank: Tensor appended along ``cat_axis``, or ``None``.
        batch_dim: Stored on the instance; not used by this class itself.
        cat_axis: Axis along which flanks and sample are concatenated.
    """

    def __init__(self,
                 left_flank=None,
                 right_flank=None,
                 batch_dim=0,
                 cat_axis=-1
                ):

        super().__init__()

        # register_buffer accepts None, which keeps the attribute defined so the
        # `is not None` checks in add_flanks work when a flank is omitted.
        self.register_buffer('left_flank', self._copy_or_none(left_flank))
        self.register_buffer('right_flank', self._copy_or_none(right_flank))

        self.batch_dim = batch_dim
        self.cat_axis  = cat_axis

    @staticmethod
    def _copy_or_none(flank):
        """Return a detached private copy of ``flank``, or ``None`` if it is absent."""
        return None if flank is None else flank.detach().clone()

    def add_flanks(self, my_sample):
        """Return ``my_sample`` with the configured flanks concatenated around it.

        The last two axes of ``my_sample`` are treated as (channels, length);
        all leading axes are batch axes that each flank is broadcast over via
        ``expand``. A flank must therefore not have more dimensions than
        ``my_sample`` (a requirement of ``torch.Tensor.expand``).
        """
        *batch_dims, channels, length = my_sample.shape

        pieces = []

        if self.left_flank is not None:
            pieces.append( self.left_flank.expand(*batch_dims, -1, -1) )

        pieces.append( my_sample )

        if self.right_flank is not None:
            pieces.append( self.right_flank.expand(*batch_dims, -1, -1) )

        return torch.cat( pieces, axis=self.cat_axis )

    def forward(self, my_sample):
        """Alias for :meth:`add_flanks` so the module can be called directly."""
        return self.add_flanks(my_sample)

class VepTester(nn.Module):
    """Predict variant effect ("skew") as ``model(alt) - model(ref)``.

    Inputs carry two leading axes; they are flattened into one batch axis for
    the model call, restored afterwards, and the skew is averaged over the
    second axis.

    Args:
        model: Callable module mapping a flattened batch to predictions whose
            first axis matches the flattened batch size.
    """

    def __init__(self,
                  model
                 ):

        super().__init__()
        self.model = model

    def _model_device(self):
        """Return the device the wrapped model lives on, or ``None`` if unknown.

        Uses the model's ``device`` attribute when it has one (e.g.
        Lightning-style modules); otherwise falls back to the device of the
        first parameter, since a plain ``nn.Module`` has no ``device``.
        """
        device = getattr(self.model, 'device', None)
        if device is not None:
            return device
        first_param = next(self.model.parameters(), None)
        return None if first_param is None else first_param.device

    def forward(self, ref_batch, alt_batch):
        """Return alt-minus-ref predictions averaged over axis 1 of the inputs.

        Args:
            ref_batch: Reference-allele inputs with at least two leading axes.
            alt_batch: Alternate-allele inputs, same shape as ``ref_batch``.

        Raises:
            AssertionError: If the two inputs differ in shape.
        """
        ref_shape, alt_shape = ref_batch.shape, alt_batch.shape
        assert ref_shape == alt_shape, \
            f"ref/alt shape mismatch: {tuple(ref_shape)} vs {tuple(alt_shape)}"

        # Merge the first two axes so the model sees one flat batch.
        ref_batch = ref_batch.flatten(0,1)
        alt_batch = alt_batch.flatten(0,1)

        device = self._model_device()
        if device is not None:
            ref_batch = ref_batch.to(device)
            alt_batch = alt_batch.to(device)

        ref_preds = self.model(ref_batch)
        alt_preds = self.model(alt_batch)

        skew_preds = alt_preds - ref_preds
        # Restore the two leading axes, then average over the second one.
        skew_preds = skew_preds.unflatten(0, ref_shape[0:2]).mean(1)

        return skew_preds

# Load model

In [6]:
# Directory the unpacked artifact is loaded from (presumably created by
# unpack_artifact in the working directory -- see the "archive unpacked" log).
model_dir = './artifacts'

# Drop any artifacts left by a previous run so the unpacked files are fresh.
if os.path.isdir(model_dir):
    shutil.rmtree(model_dir)

unpack_artifact(hpo_rec)

my_model = model_fn(model_dir)
my_model.cuda()

archive unpacked in ./


Loaded model from 20211113_021200 in eval mode


BassetBranched(
  (pad1): ConstantPad1d(padding=(9, 9), value=0.0)
  (conv1): Conv1dNorm(
    (conv): Conv1d(4, 300, kernel_size=(19,), stride=(1,))
    (bn_layer): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (pad2): ConstantPad1d(padding=(5, 5), value=0.0)
  (conv2): Conv1dNorm(
    (conv): Conv1d(300, 200, kernel_size=(11,), stride=(1,))
    (bn_layer): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (pad3): ConstantPad1d(padding=(3, 3), value=0.0)
  (conv3): Conv1dNorm(
    (conv): Conv1d(200, 200, kernel_size=(7,), stride=(1,))
    (bn_layer): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (pad4): ConstantPad1d(padding=(1, 1), value=0.0)
  (maxpool_3): MaxPool1d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
  (maxpool_4): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
  (linear1): LinearNorm(
    (linear): Linear(in

# Read VCF and FASTA
Parsing the FASTA file takes about 8.5 minutes, so fear not. `tqdm` sees 195 contigs and calculates remaining time based on the size of the last contig. As you know, contig size is highly variable, so the timing estimates are bad.

In [7]:
# Load the reference genome. Per the log below, this pre-reads the FASTA into
# memory and then parses it contig by contig (the slow part).
fasta_dict = boda.data.Fasta(fasta_fn)
# Read the VCF. Records reported as "large allele" / "large indel" in the log are
# skipped -- presumably governed by max_allele_size / max_indel_size; confirm in boda.
test_vcf = boda.data.fasta_datamodule.VCF(vcf_fn, chr_prefix='chr', max_allele_size=10, max_indel_size=10)

pre-reading fasta into memory
100%|██████████| 44284892/44284892 [00:15<00:00, 2832350.11it/s]
finding keys
parsing
100%|██████████| 195/195 [08:06<00:00,  2.49s/it] 
done
66it [00:02, 23.79it/s]skipping large indel at line 611, id: 13:107222934:GTCCCAGCAGCCCCACCTCCTCATACCGTCATC:G:R:wC
skipping large allele at line 686, id: 13:108910495:TATATATATAC:T:R:wC
skipping large indel at line 1080, id: 13:110940361:T:TGGGAGGCTGAG:R:wC
skipping large indel at line 1331, id: 13:111359129:T:TCCCAATACAGTGTGCAC:R:wC
skipping large indel at line 1515, id: 13:112552087:ACAACAGCCGGGG:A:R:wC
skipping large indel at line 1617, id: 13:113592551:GGGGTTGGTTCAGGTGAGCGCTGAT:G:R:wC
skipping large indel at line 1684, id: 13:113641625:C:CATGCAGCCCCCGGTGGGGGGCAGGG:R:wC
skipping large allele at line 1733, id: 13:113759830:T:TTCCTATATCC:R:wC
skipping large indel at line 1941, id: 13:114202200:TCCACCAGCTCTTAGCTACTCGGCCCAAAGACACA:T:R:wC
skipping large indel at line 1979, id: 13:114514584:G:GGGGGCCCTGTGTGA:R:wC
2016it

In [8]:
# Dataset built from the parsed VCF records and the reference sequence, using the
# augmentation settings defined above. Batches expose 'ref' and 'alt' entries.
# NOTE(review): the run recorded in this notebook failed inside
# VcfDataset.__getitem__ (window_slicer received a length-1 input for a
# 200-wide kernel), so at least one record appears to yield a segment shorter
# than WINDOW_SIZE -- TODO investigate / filter such records before iterating.
vcf_data = boda.data.fasta_datamodule.VcfDataset(test_vcf.vcf, fasta_dict.fasta, 
                                                 WINDOW_SIZE, RELATIVE_START, RELATIVE_END, 
                                                 reverse_complements=REVERSE_COMPLEMENTS)
# Batched loader over the dataset, using 4 worker processes.
vcf_loader = torch.utils.data.DataLoader(vcf_data, batch_size=BATCH_SIZE, num_workers=4)

filtering vcf by contig keys
100%|██████████| 29457/29457 [00:00<00:00, 1614019.95it/s]
returned 29457/29457 records


# Prepare MPRA flanks and VEP function

In [9]:
def _as_flank(sequence):
    """Convert a DNA string with dna2tensor and add two leading singleton axes."""
    return boda.common.utils.dna2tensor(sequence).unsqueeze(0).unsqueeze(0)

# Last 200 characters of the upstream constant, first 200 of the downstream one.
left_flank = _as_flank(boda.common.constants.MPRA_UPSTREAM[-200:])
right_flank = _as_flank(boda.common.constants.MPRA_DOWNSTREAM[:200])

right_flank.shape

torch.Size([1, 1, 4, 200])

In [10]:
# Module that wraps every input sequence in the MPRA vector flanks built above.
flank_builder = FlankBuilder(
    left_flank=left_flank,
    right_flank=right_flank,
)

# Module that turns (ref, alt) input pairs into predicted skews with the loaded model.
vep_tester = VepTester(my_model)

# Run VEP

In [11]:
# Collect per-batch skews in a list and concatenate once at the end; calling
# torch.cat onto a growing tensor inside the loop re-copies every earlier
# result on each iteration.
skew_batches = []

with torch.no_grad():
    for batch in tqdm.tqdm(vcf_loader):
        # Add the MPRA vector flanks around both alleles before prediction.
        ref_allele = flank_builder(batch['ref'])
        alt_allele = flank_builder(batch['alt'])

        # Move each result to the CPU straight away so device memory is not
        # held for the whole run.
        skew_batches.append(vep_tester(ref_allele, alt_allele).cpu())

# An empty loader yields the same empty (0, 3) tensor the original started from.
predictions = torch.cat(skew_batches, dim=0) if skew_batches else torch.empty(size=(0,3))
                

  5%|▌         | 3/59 [00:03<01:12,  1.29s/it]


RuntimeError: Caught RuntimeError in DataLoader worker process 3.
Original Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/ubuntu/boda2/boda/data/fasta_datamodule.py", line 366, in __getitem__
    ref_slices = self.window_slicer(ref_segments)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/boda2/boda/data/fasta_datamodule.py", line 46, in forward
    hook = F.conv1d(input, self.weight)
RuntimeError: Calculated padded input size per channel: (1). Kernel size: (200). Kernel size can't be greater than actual input size


# Assemble DataFrame

In [None]:
def _decode_onehot(onehot):
    """Turn a (channels, length) array back into a nucleotide string.

    A position whose column sums to exactly 1 is decoded as the nucleotide at
    the column's argmax; any other position is written as 'N'.
    """
    letters = []
    for pos in range(onehot.shape[-1]):
        column = onehot[:, pos]
        if column.sum() == 1:
            letters.append(boda.common.constants.STANDARD_NT[column.argmax()])
        else:
            letters.append('N')
    return "".join(letters)


# Variant records and model predictions side by side, one row per record.
record_frame = pd.DataFrame(vcf_data.vcf)
skew_frame = pd.DataFrame(
    data=predictions.numpy(),
    columns=['K562_skew_pred','HepG2_skew_pred','SKNSH_skew_pred'],
)
vcf_df = pd.concat([record_frame, skew_frame], axis=1)

# Replace the encoded allele arrays with readable sequence strings.
for allele_col in ('ref', 'alt'):
    vcf_df[allele_col] = [_decode_onehot(x) for x in vcf_df[allele_col]]



In [None]:
# Display the assembled table of variant records and predicted skews.
vcf_df