In [1]:
import sys
import os
import shutil
import gzip
import csv
import multiprocessing

import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import boda

import matplotlib.pyplot as plt
from scipy.stats import pearsonr, spearmanr

In [2]:
# Make the `src` directory that sits two levels above the working directory
# importable ahead of everything else on the path.
two_levels_up = os.path.dirname(os.path.dirname(os.getcwd()))
boda_src = os.path.join(two_levels_up, 'src')
sys.path.insert(0, boda_src)

from main import unpack_artifact, model_fn

In [3]:
# Number of CUDA devices visible to torch; the DataLoader batch size and worker
# count further down are multiplied by this value.
torch.cuda.device_count()

1

In [4]:
# Model artifact archive; handed to unpack_artifact() before the model is loaded.
hpo_rec = 'gs://syrgoth/aip_ui_test/model_artifacts__20211113_021200__287348.tar.gz'
# Reference genome FASTA; handed to boda.data.Fasta().
fasta_fn = 'GRCh38_no_alt_analysis_set_GCA_000001405.15.fasta' # Genome reference


In [6]:
# Load the reference genome. Per the recorded progress output, the parsing step
# alone took close to nine minutes, so avoid re-running this cell needlessly.
fasta_dict = boda.data.Fasta(fasta_fn)

pre-reading fasta into memory
100%|██████████| 44284892/44284892 [00:16<00:00, 2617376.00it/s]
finding keys
parsing
100%|██████████| 195/195 [08:45<00:00,  2.69s/it] 
done
loading DataFrame
Checking and filtering tokens
Allele length checks
Done


In [50]:
# (input VCF, output predictions file) for every variant set this notebook has
# been pointed at. Previously these were successive reassignments of
# vcf_fn/out_fn, where only the last pair ever took effect; keeping them in a
# table makes the active choice explicit and the alternatives selectable.
VCF_RUNS = {
    'cosmic_promoters': (
        'gencode.v42.500bp.promoter.cosmic.vars.vcf',
        'gencode.v42.500bp.promoter.cosmic.vars.pt',
    ),
    'k562_hepg2_significant': (
        'all.significant.calls.k562.hepg2.rd.35.p.point.2.vcf',
        'all.significant.calls.k562.hepg2.rd.35.p.point.2.pt',
    ),
    'k562_hepg2_significant_full_vector_rc': (
        'all.significant.calls.k562.hepg2.rd.35.p.point.2.vcf',
        'all.significant.calls.k562.hepg2.rd.35.p.point.2.full_vector_rc.pt',
    ),
    'k562_ase': (
        'k562.ase.calls.fdr.05.vcf',
        'k562.ase.calls.fdr.05.full_vector_rc.pt',
    ),
    'hepg2_ase': (
        'hepg2.ase.calls.fdr.05.vcf',
        'hepg2.ase.calls.fdr.05.pt',
    ),
}

# Change this key to switch data sets; 'hepg2_ase' matches the pair that was
# previously assigned last (and therefore in effect).
ACTIVE_RUN = 'hepg2_ase'

vcf_fn, out_fn = VCF_RUNS[ACTIVE_RUN]


In [51]:
# Parse the selected VCF file.
# NOTE(review): chr_prefix='' and the two size limits of 20 are passed straight
# to boda.data.VCF; presumably they control contig naming and filter out long
# alleles/indels -- confirm against the boda documentation.
test_vcf = boda.data.VCF(vcf_fn, chr_prefix='', max_allele_size=20, max_indel_size=20)

loading DataFrame
Checking and filtering tokens
Allele length checks
Done


In [52]:
# Windowing parameters forwarded positionally to boda.data.VcfDataset.
# NOTE(review): presumably WINDOW_SIZE is the sequence length in bp and
# RELATIVE_START/RELATIVE_END bound where the variant may sit inside the
# window -- confirm against VcfDataset.
WINDOW_SIZE = 200
RELATIVE_START = 0
RELATIVE_END = 200
# Forwarded as VcfDataset(reverse_complements=...); also decides which tester
# class is instantiated later in the notebook.
REVERSE_COMPLEMENTS = True

# Per-GPU values: both are multiplied by torch.cuda.device_count() when the
# DataLoader is built.
BATCH_SIZE = 16 * 1
NUM_WORKERS= 0


In [53]:
# Dataset of reference/alternate windows for every VCF record.
vcf_data = boda.data.VcfDataset(
    test_vcf.vcf,
    fasta_dict.fasta,
    WINDOW_SIZE,
    RELATIVE_START,
    RELATIVE_END,
    reverse_complements=REVERSE_COMPLEMENTS,
)

# Batch size and worker count are configured per GPU, so scale both by the
# number of visible devices.
n_gpus = torch.cuda.device_count()
vcf_loader = torch.utils.data.DataLoader(
    vcf_data,
    batch_size=BATCH_SIZE * n_gpus,
    num_workers=NUM_WORKERS * n_gpus,
)

returned 499/499 records


In [54]:
# Sanity check on the first record's reference tensor. The recorded output is
# (400, 4, 200); presumably (windows, one-hot channels, window length) -- the
# later decoding cell takes argmax over the 4-sized axis to recover bases.
vcf_data[0]['ref'].shape

torch.Size([400, 4, 200])

In [55]:
# Split the leading axis into (strand group, window index): two groups when the
# dataset was built with reverse complements, otherwise one.
vcf_data[0]['ref'].unflatten(0,(2 if REVERSE_COMPLEMENTS else 1,200)).shape

torch.Size([2, 200, 4, 200])

In [56]:
# Decode and print every window in the first strand group of the first record,
# to eyeball how the windows tile across the variant.
# The record is fetched and reshaped once, outside the loop (it does not depend
# on i), and the window count is derived from the tensor rather than hard-coded.
n_strands = 2 if REVERSE_COMPLEMENTS else 1
first_ref = vcf_data[0]['ref']
n_windows = first_ref.shape[0] // n_strands
reshaped_vcf_rec = first_ref.unflatten(0, (n_strands, n_windows))

for i in range(n_windows):
    # argmax over the one-hot channel axis gives one base index per position.
    base_idx = reshaped_vcf_rec[0, i].argmax(0).tolist()
    print(''.join('ACGT'[n] for n in base_idx))

GCAGTTCTCCTCGCGCTCCCACACCTGGTCCGCCCAGTCGGAACTCACCCCTACGCCGCCGCCGCTGCCGCCGCCGCCGCCGCCGGTCCCGGAGCCAGAGAAGAAACAGCAACCGGCGCGCGCCAAAAGTATCGTCACTTCCTGTATTGGCGCGTAATGATGATATAATAGCCGACCTCCGGCCCAGAACTCGAGACAAC
CAGTTCTCCTCGCGCTCCCACACCTGGTCCGCCCAGTCGGAACTCACCCCTACGCCGCCGCCGCTGCCGCCGCCGCCGCCGCCGGTCCCGGAGCCAGAGAAGAAACAGCAACCGGCGCGCGCCAAAAGTATCGTCACTTCCTGTATTGGCGCGTAATGATGATATAATAGCCGACCTCCGGCCCAGAACTCGAGACAACG
AGTTCTCCTCGCGCTCCCACACCTGGTCCGCCCAGTCGGAACTCACCCCTACGCCGCCGCCGCTGCCGCCGCCGCCGCCGCCGGTCCCGGAGCCAGAGAAGAAACAGCAACCGGCGCGCGCCAAAAGTATCGTCACTTCCTGTATTGGCGCGTAATGATGATATAATAGCCGACCTCCGGCCCAGAACTCGAGACAACGA
GTTCTCCTCGCGCTCCCACACCTGGTCCGCCCAGTCGGAACTCACCCCTACGCCGCCGCCGCTGCCGCCGCCGCCGCCGCCGGTCCCGGAGCCAGAGAAGAAACAGCAACCGGCGCGCGCCAAAAGTATCGTCACTTCCTGTATTGGCGCGTAATGATGATATAATAGCCGACCTCCGGCCCAGAACTCGAGACAACGAC
TTCTCCTCGCGCTCCCACACCTGGTCCGCCCAGTCGGAACTCACCCCTACGCCGCCGCCGCTGCCGCCGCCGCCGCCGCCGGTCCCGGAGCCAGAGAAGAAACAGCAACCGGCGCGCGCCAAAAGTATCGTCACTTCCTGTATTGGCGCGTAATGATGATATAATAGCCGACCTCCGGCCCAGAACTCGAGACAAC

In [57]:
# (n_variants, n_alleles, n_strands, n_positions)

In [58]:
class FlankBuilder(nn.Module):
    """Concatenate fixed flanking tensors onto both ends of a batch of sequences.

    Args:
        left_flank: Tensor placed before each sample along ``cat_axis``, or
            ``None`` to add nothing on that side.
        right_flank: Tensor placed after each sample along ``cat_axis``, or
            ``None`` to add nothing on that side.
        batch_dim: Stored on the instance; not used by this class itself.
        cat_axis: Axis along which the pieces are concatenated (default: last).

    The flanks are registered as buffers, so they follow the module through
    ``.cuda()`` / ``.to()`` calls.
    """

    def __init__(self,
                 left_flank=None,
                 right_flank=None,
                 batch_dim=0,
                 cat_axis=-1
                ):

        super().__init__()

        # register_buffer accepts None, which keeps the attribute defined while
        # letting add_flanks() skip that side. Calling .detach() on the default
        # None used to raise AttributeError before the None checks were reached.
        self.register_buffer('left_flank', self._private_copy(left_flank))
        self.register_buffer('right_flank', self._private_copy(right_flank))

        self.batch_dim = batch_dim
        self.cat_axis  = cat_axis

    @staticmethod
    def _private_copy(flank):
        """Return a detached clone of ``flank``, or ``None`` if no flank was given."""
        return None if flank is None else flank.detach().clone()

    def add_flanks(self, my_sample):
        """Return ``my_sample`` with the configured flanks concatenated on.

        Args:
            my_sample: Tensor whose last two axes are unpacked as
                ``(channels, length)``; all leading axes are treated as batch
                axes.

        Returns:
            A new tensor: ``[left_flank] + my_sample + [right_flank]`` joined
            along ``cat_axis``.
        """
        *batch_dims, channels, length = my_sample.shape

        pieces = []

        # Each flank is broadcast across the sample's leading axes; -1 keeps the
        # flank's own channel and length sizes.
        if self.left_flank is not None:
            pieces.append( self.left_flank.expand(*batch_dims, -1, -1) )

        pieces.append( my_sample )

        if self.right_flank is not None:
            pieces.append( self.right_flank.expand(*batch_dims, -1, -1) )

        return torch.cat( pieces, dim=self.cat_axis )

    def forward(self, my_sample):
        """Alias for :meth:`add_flanks` so the module can be called directly."""
        return self.add_flanks(my_sample)

class VepTester(nn.Module):
    """Score reference/alternate windows whose second axis holds two strand groups.

    The second axis of each input is expected to contain two equally sized
    groups of windows laid out back to back (presumably forward windows followed
    by their reverse complements, as produced by a dataset built with
    ``reverse_complements=True`` -- confirm against the dataset).

    Args:
        model: Module applied to the flattened windows. It is wrapped in
            ``DataParallel`` when more than one CUDA device is visible.
    """

    def __init__(self,
                  model
                 ):

        super().__init__()
        self.model = torch.nn.DataParallel(model) if torch.cuda.device_count() > 1 else model

    def _model_device(self, fallback):
        """Return the device of the model's first parameter, or ``fallback`` if it has none."""
        first_param = next(self.model.parameters(), None)
        return fallback if first_param is None else first_param.device

    def forward(self, ref_batch, alt_batch):
        """Predict on both alleles and return the predictions and their difference.

        Args:
            ref_batch: Tensor whose first two axes are ``(batch, windows)``.
            alt_batch: Tensor with exactly the same shape as ``ref_batch``.

        Returns:
            Dict with keys ``'ref'``, ``'alt'`` and ``'skew'`` (``alt - ref``).
            Each value has shape ``(batch, 2, windows // 2, ...)``, where the
            trailing axes are whatever the model emits per window.
        """
        ref_shape, alt_shape = ref_batch.shape, alt_batch.shape
        assert ref_shape == alt_shape
        assert ref_shape[1] % 2 == 0, (
            f"expected an even number of windows to split into 2 strand groups, got {ref_shape[1]}"
        )

        # Follow the model's device instead of assuming the default CUDA device,
        # so the tester also works on CPU or on a non-default GPU.
        device = self._model_device(fallback=ref_batch.device)
        ref_batch = ref_batch.flatten(0,1).to(device)
        alt_batch = alt_batch.flatten(0,1).to(device)

        # Mixed precision only applies on CUDA; elsewhere it is switched off.
        with torch.cuda.amp.autocast(enabled=(device.type == 'cuda')):
            ref_preds = self.model(ref_batch.contiguous())
            alt_preds = self.model(alt_batch.contiguous())

        # Undo the flatten, then split the window axis into its two groups.
        ref_preds = ref_preds.unflatten(0, ref_shape[0:2])
        ref_preds = ref_preds.unflatten(1, (2, ref_shape[1]//2))

        alt_preds = alt_preds.unflatten(0, alt_shape[0:2])
        alt_preds = alt_preds.unflatten(1, (2, alt_shape[1]//2))

        skew_preds = alt_preds - ref_preds

        return {'ref': ref_preds,
                'alt': alt_preds,
                'skew': skew_preds}
    
class VepTester_FullRC(nn.Module):
    """Score reference/alternate windows in both orientations.

    Unlike a tester that expects the reverse-complement windows to be present in
    the input, this one builds them itself by flipping the channel and length
    axes of every window. That flip is a reverse complement only when the
    channel order is complement-symmetric (for example A, C, G, T).

    Args:
        model: Module applied to the flattened windows. It is wrapped in
            ``DataParallel`` when more than one CUDA device is visible.
    """

    def __init__(self,
                  model
                 ):

        super().__init__()
        self.model = torch.nn.DataParallel(model) if torch.cuda.device_count() > 1 else model

    def _model_device(self, fallback):
        """Return the device of the model's first parameter, or ``fallback`` if it has none."""
        first_param = next(self.model.parameters(), None)
        return fallback if first_param is None else first_param.device

    def forward(self, ref_batch, alt_batch):
        """Predict on both alleles, as given and flipped, and return the results.

        Args:
            ref_batch: Tensor whose first two axes are ``(batch, windows)`` and
                whose last two axes are ``(channels, length)``.
            alt_batch: Tensor with exactly the same shape as ``ref_batch``.

        Returns:
            Dict with keys ``'ref'``, ``'alt'`` and ``'skew'`` (``alt - ref``).
            Each value has shape ``(batch, 2, windows, ...)``; index 0 on axis 1
            holds predictions for the windows as given, index 1 for the flipped
            windows.
        """
        ref_shape, alt_shape = ref_batch.shape, alt_batch.shape
        assert ref_shape == alt_shape

        # Follow the model's device instead of assuming the default CUDA device,
        # so the tester also works on CPU or on a non-default GPU.
        device = self._model_device(fallback=ref_batch.device)
        ref_batch = ref_batch.flatten(0,1).to(device)
        alt_batch = alt_batch.flatten(0,1).to(device)

        leading = ref_shape[0:2]

        # Mixed precision only applies on CUDA; elsewhere it is switched off.
        with torch.cuda.amp.autocast(enabled=(device.type == 'cuda')):
            ref_preds = self.model(ref_batch.contiguous()).unflatten(0, leading)
            alt_preds = self.model(alt_batch.contiguous()).unflatten(0, leading)

            # Flipping channels and positions together yields the opposite strand.
            ref_preds_rc = self.model(ref_batch.flip(dims=[-2,-1]).contiguous()).unflatten(0, leading)
            alt_preds_rc = self.model(alt_batch.flip(dims=[-2,-1]).contiguous()).unflatten(0, leading)

        ref_preds = torch.stack([ref_preds, ref_preds_rc], dim=1)
        alt_preds = torch.stack([alt_preds, alt_preds_rc], dim=1)

        skew_preds = alt_preds - ref_preds

        return {'ref': ref_preds,
                'alt': alt_preds,
                'skew': skew_preds}

In [59]:
model_dir = './artifacts'

# Remove any previously unpacked artifacts so the model is always loaded from
# a fresh copy of the archive.
if os.path.isdir(model_dir):
    shutil.rmtree(model_dir)

# The recorded output reports the archive being unpacked into ./
unpack_artifact(hpo_rec)

# Load the model, move it to the GPU and switch to inference mode.
my_model = model_fn(model_dir)
my_model.cuda()
my_model.eval()

Loaded model from 20211113_021200 in eval mode


archive unpacked in ./


BassetBranched(
  (pad1): ConstantPad1d(padding=(9, 9), value=0.0)
  (conv1): Conv1dNorm(
    (conv): Conv1d(4, 300, kernel_size=(19,), stride=(1,))
    (bn_layer): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (pad2): ConstantPad1d(padding=(5, 5), value=0.0)
  (conv2): Conv1dNorm(
    (conv): Conv1d(300, 200, kernel_size=(11,), stride=(1,))
    (bn_layer): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (pad3): ConstantPad1d(padding=(3, 3), value=0.0)
  (conv3): Conv1dNorm(
    (conv): Conv1d(200, 200, kernel_size=(7,), stride=(1,))
    (bn_layer): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (pad4): ConstantPad1d(padding=(1, 1), value=0.0)
  (maxpool_3): MaxPool1d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
  (maxpool_4): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
  (linear1): LinearNorm(
    (linear): Linear(in

In [60]:
# Encode 200 bp of constant sequence for each side of the insert: the tail of
# the upstream constant and the head of the downstream one. Two singleton
# leading axes are added so each flank can be expanded over (batch, windows).
upstream_seq = boda.common.constants.MPRA_UPSTREAM[-200:]
left_flank = boda.common.utils.dna2tensor(upstream_seq).unsqueeze(0).unsqueeze(0)

downstream_seq = boda.common.constants.MPRA_DOWNSTREAM[:200]
right_flank = boda.common.utils.dna2tensor(downstream_seq).unsqueeze(0).unsqueeze(0)

right_flank.shape

torch.Size([1, 1, 4, 200])

In [61]:
# Module that pads every window with the constant flanks; lives on the GPU.
flank_builder = FlankBuilder(left_flank=left_flank, right_flank=right_flank)
flank_builder.cuda()

# When the dataset was built with reverse complements, use the tester that
# splits the window axis in two; otherwise use the one that flips the windows
# itself.
if REVERSE_COMPLEMENTS:
    vep_tester = VepTester(my_model)
else:
    vep_tester = VepTester_FullRC(my_model)

In [62]:
# Per-batch outputs are collected on the CPU and concatenated once at the end.
# The lists have their own names so that ref_preds/alt_preds/skew_preds only
# ever refer to the final tensors.
ref_chunks, alt_chunks, skew_chunks = [], [], []

with torch.no_grad():
    for batch in tqdm.tqdm(vcf_loader):
        ref_allele, alt_allele = batch['ref'], batch['alt']

        # Pad each window with the constant flanks before scoring.
        ref_allele = flank_builder(ref_allele.cuda()).contiguous()
        alt_allele = flank_builder(alt_allele.cuda()).contiguous()

        all_preds = vep_tester(ref_allele, alt_allele)

        ref_chunks.append(all_preds['ref'].cpu())
        alt_chunks.append(all_preds['alt'].cpu())
        skew_chunks.append(all_preds['skew'].cpu())

# Join along the variant (batch) axis.
ref_preds = torch.cat(ref_chunks, dim=0)
alt_preds = torch.cat(alt_chunks, dim=0)
skew_preds= torch.cat(skew_chunks, dim=0)

100%|██████████| 32/32 [00:06<00:00,  4.94it/s]


In [63]:
# Persist the reference and alternate predictions. skew_preds is not written;
# it can be recomputed from the saved tensors as alt - ref.
torch.save({'ref': ref_preds, 'alt': alt_preds}, out_fn)

In [64]:
# Final layout check. The recorded output is (499, 2, 200, 3): one row per VCF
# record, the two strand groups, the windows per group, and the model's outputs
# per window.
ref_preds.shape

torch.Size([499, 2, 200, 3])