In [1]:
import sys
import os
import shutil
import gzip
import csv
import multiprocessing

import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from functorch import combine_state_for_ensemble, vmap

import numpy as np
import pandas as pd
import boda

import matplotlib.pyplot as plt
from scipy.stats import pearsonr, spearmanr

In [2]:
# Put the `src` directory that sits two levels above the current working
# directory on the import path, so that `main` can be imported below.
# NOTE: this depends on the directory the kernel was started from.
boda_src = os.path.join( os.path.dirname( os.path.dirname( os.getcwd() ) ), 'src' )
sys.path.insert(0, boda_src)

# Helpers used by load_model() further down to unpack and instantiate a model artifact.
from main import unpack_artifact, model_fn

In [3]:
# Number of CUDA devices visible to PyTorch; the DataLoader batch size and
# worker count further down are multiplied by this value.
torch.cuda.device_count()

1

In [4]:
# Model artifact archive that load_model() unpacks when the single-model path is used.
hpo_rec = 'gs://syrgoth/aip_ui_test/model_artifacts__20211113_021200__287348.tar.gz'
# Path is relative to the working directory.
fasta_fn = 'GRCh38_no_alt_analysis_set_GCA_000001405.15.fasta' # Genome reference


In [5]:
# Load the reference genome. The recorded log shows the file is pre-read into
# memory and then parsed, which took roughly 12 minutes in that run.
fasta_dict = boda.data.Fasta(fasta_fn)

pre-reading fasta into memory
100%|██████████████████████████| 44284892/44284892 [00:16<00:00, 2684409.76it/s]
finding keys
parsing
100%|█████████████████████████████████████████| 195/195 [11:35<00:00,  3.57s/it]
done


In [32]:
# Input VCF / output prediction file pairs that this notebook has been run on.
# Previously every pair was assigned in sequence and only the last one survived;
# keeping them in a table makes the active choice explicit and switchable.
VCF_RUNS = {
    'promoter_cosmic': (
        'gencode.v42.500bp.promoter.cosmic.vars.vcf',
        'gencode.v42.500bp.promoter.cosmic.vars.pt',
    ),
    'significant_calls': (
        'all.significant.calls.k562.hepg2.rd.35.p.point.2.vcf',
        'all.significant.calls.k562.hepg2.rd.35.p.point.2.pt',
    ),
    'significant_calls_full_vector_rc': (
        'all.significant.calls.k562.hepg2.rd.35.p.point.2.vcf',
        'all.significant.calls.k562.hepg2.rd.35.p.point.2.full_vector_rc.pt',
    ),
    'k562_ase_full_vector_rc': (
        'k562.ase.calls.fdr.05.vcf',
        'k562.ase.calls.fdr.05.full_vector_rc.pt',
    ),
    'dnase_all_cells': (
        'dnase.v2.all.cells.aggreg.vcf',
        'dnase.v2.all.cells.aggreg.pt',
    ),
    'hepg2_ase': (
        'hepg2.ase.calls.fdr.05.vcf',
        'hepg2.ase.calls.fdr.05.pt',
    ),
}

# Change this key to process a different variant set.
ACTIVE_RUN = 'hepg2_ase'

vcf_fn, out_fn = VCF_RUNS[ACTIVE_RUN]

In [33]:
# Parse the selected VCF. The recorded log shows token filtering and allele
# length checks being applied during loading.
# NOTE(review): the two size limits presumably drop alleles/indels longer than 20 nt — confirm in boda.data.VCF.
test_vcf = boda.data.VCF(vcf_fn, chr_prefix='', max_allele_size=20, max_indel_size=20)

loading DataFrame
Checking and filtering tokens
Allele length checks
Done


In [34]:
# Scratch cell: builds a random 499x3 array and converts it to NumPy.
# The result is not assigned, so nothing downstream depends on it.
torch.randn(499, 3).cpu().numpy()

array([[ 0.29374588, -0.17407839, -2.332368  ],
       [-0.14504315,  0.01482684, -0.29215747],
       [-1.3820796 ,  0.9389331 ,  0.24991722],
       ...,
       [ 0.8460486 , -1.0635027 , -1.2409896 ],
       [ 2.0486255 , -0.82711196, -1.5808923 ],
       [-0.7439912 , -0.0762727 ,  0.4138339 ]], dtype=float32)

In [35]:
# Windowing configuration handed to boda.data.VcfDataset below.
# Length (nt) of each extracted window; the dataset's 'ref' tensor has shape (..., 4, 200).
WINDOW_SIZE = 200
# NOTE(review): presumably the variant is placed at offsets
# range(RELATIVE_START, RELATIVE_END, STEP_SIZE) inside the window (7 offsets,
# matching the 7 windows per strand seen below) — confirm in VcfDataset.
RELATIVE_START = 25
RELATIVE_END = 180
STEP_SIZE = 25
# Also selects which tester is built later: VepTester when True, VepTester_FullRC when False.
REVERSE_COMPLEMENTS = True

# Per-device values; both are multiplied by torch.cuda.device_count() when the DataLoader is built.
BATCH_SIZE = 16 * 1
NUM_WORKERS= 0


In [36]:
# Dataset of windowed ref/alt sequences for every VCF record (no contig filter),
# wrapped in a DataLoader whose batch size and worker count scale with the
# number of visible CUDA devices.
vcf_data = boda.data.VcfDataset(
    test_vcf.vcf,
    fasta_dict.fasta,
    WINDOW_SIZE,
    RELATIVE_START,
    RELATIVE_END,
    step_size=STEP_SIZE,
    reverse_complements=REVERSE_COMPLEMENTS,
    use_contigs=[],
)
vcf_loader = torch.utils.data.DataLoader(
    vcf_data,
    batch_size=BATCH_SIZE * torch.cuda.device_count(),
    num_workers=NUM_WORKERS * torch.cuda.device_count(),
)

499/499 records have matching contig in FASTA
returned 499/499 records


In [37]:
# Shape of the reference-allele tensor for the first record: (windows, 4, WINDOW_SIZE).
vcf_data[0]['ref'].shape

torch.Size([14, 4, 200])

In [38]:
# Split the leading window axis into (strands, windows_per_strand).
# The divisor must match the number of strands; the previous hard-coded `// 2`
# produced an invalid shape whenever REVERSE_COMPLEMENTS was False.
n_strands = 2 if REVERSE_COMPLEMENTS else 1
first_ref = vcf_data[0]['ref']
first_ref.unflatten(0, (n_strands, first_ref.shape[0] // n_strands)).shape

torch.Size([2, 7, 4, 200])

In [39]:
# Decode and print the first-strand windows of the first record as nucleotide strings.
# The reshape is loop-invariant, so it is done once up front, and the split uses
# the actual strand count instead of a hard-coded 2.
n_strands = 2 if REVERSE_COMPLEMENTS else 1
first_ref = vcf_data[0]['ref']
reshaped_vcf_rec = first_ref.unflatten(0, (n_strands, first_ref.shape[0] // n_strands))

for i in range(reshaped_vcf_rec.shape[1]):
    # argmax over the channel axis turns each one-hot column into a base index.
    base_indices = reshaped_vcf_rec[0, i].argmax(0).tolist()
    print(''.join([ ['A','C','G','T'][n] for n in base_indices ]))

ACACCTGGTCCGCCCAGTCGGAACTCACCCCTACGCCGCCGCCGCTGCCGCCGCCGCCGCCGCCGGTCCCGGAGCCAGAGAAGAAACAGCAACCGGCGCGCGCCAAAAGTATCGTCACTTCCTGTATTGGCGCGTAATGATGATATAATAGCCGACCTCCGGCCCAGAACTCGAGACAACGACAGGGGCTCGCTCTGTGC
CACCCCTACGCCGCCGCCGCTGCCGCCGCCGCCGCCGCCGGTCCCGGAGCCAGAGAAGAAACAGCAACCGGCGCGCGCCAAAAGTATCGTCACTTCCTGTATTGGCGCGTAATGATGATATAATAGCCGACCTCCGGCCCAGAACTCGAGACAACGACAGGGGCTCGCTCTGTGCGGCACTTCCTGTGTCTGCGCGGGAT
CCGCCGCCGCCGCCGGTCCCGGAGCCAGAGAAGAAACAGCAACCGGCGCGCGCCAAAAGTATCGTCACTTCCTGTATTGGCGCGTAATGATGATATAATAGCCGACCTCCGGCCCAGAACTCGAGACAACGACAGGGGCTCGCTCTGTGCGGCACTTCCTGTGTCTGCGCGGGATGATAACGCATAAAACAGCGCTTGCT
CAGAGAAGAAACAGCAACCGGCGCGCGCCAAAAGTATCGTCACTTCCTGTATTGGCGCGTAATGATGATATAATAGCCGACCTCCGGCCCAGAACTCGAGACAACGACAGGGGCTCGCTCTGTGCGGCACTTCCTGTGTCTGCGCGGGATGATAACGCATAAAACAGCGCTTGCTCAGGTCCAGGACGCCAGAAGAAACA
CGCCAAAAGTATCGTCACTTCCTGTATTGGCGCGTAATGATGATATAATAGCCGACCTCCGGCCCAGAACTCGAGACAACGACAGGGGCTCGCTCTGTGCGGCACTTCCTGTGTCTGCGCGGGATGATAACGCATAAAACAGCGCTTGCTCAGGTCCAGGACGCCAGAAGAAACAGCCCGGTGAGCGCACTTCCGA

In [40]:
# Boolean Series: True where a record's 'chrom' value is a key of the loaded FASTA.
vcf_data.vcf['chrom'].isin(fasta_dict.fasta.keys())

0      True
1      True
2      True
3      True
4      True
       ... 
494    True
495    True
496    True
497    True
498    True
Name: chrom, Length: 499, dtype: bool

In [41]:
# (n_variants, n_alleles, n_strands, n_positions)

In [42]:
class FlankBuilder(nn.Module):
    """Concatenate constant flanking sequences onto a batch of sequences.

    The flanks are stored as (non-trainable) buffers so they follow the module
    across devices. Either flank may be omitted, in which case nothing is added
    on that side.

    Args:
        left_flank: Tensor prepended along `cat_axis`, or None for no left flank.
            It is expanded over the leading batch dimensions of the input, so
            its leading dimensions must be broadcastable (e.g. size 1).
        right_flank: Tensor appended along `cat_axis`, or None for no right flank.
        batch_dim: Stored on the instance; not used by this class itself.
        cat_axis: Axis along which flanks and sample are concatenated.
    """

    def __init__(self,
                 left_flank=None,
                 right_flank=None,
                 batch_dim=0,
                 cat_axis=-1
                ):
        
        super().__init__()
        
        # A None buffer is valid in PyTorch. Guarding here makes the None
        # defaults usable: calling .detach() on None used to raise.
        self.register_buffer(
            'left_flank',
            None if left_flank is None else left_flank.detach().clone()
        )
        self.register_buffer(
            'right_flank',
            None if right_flank is None else right_flank.detach().clone()
        )
        
        self.batch_dim = batch_dim
        self.cat_axis  = cat_axis
        
    def add_flanks(self, my_sample):
        """Return `my_sample` with the configured flanks attached.

        Args:
            my_sample: Tensor of shape (*batch_dims, channels, length).

        Returns:
            Tensor with the left flank, the sample and the right flank
            concatenated along `cat_axis`.
        """
        # Only the leading batch dimensions are needed to broadcast the flanks.
        *batch_dims, _, _ = my_sample.shape
        
        pieces = []
        
        if self.left_flank is not None:
            pieces.append( self.left_flank.expand(*batch_dims, -1, -1) )
            
        pieces.append( my_sample )
        
        if self.right_flank is not None:
            pieces.append( self.right_flank.expand(*batch_dims, -1, -1) )
            
        return torch.cat( pieces, dim=self.cat_axis )
    
    def forward(self, my_sample):
        """Alias for `add_flanks` so the module can be called directly."""
        return self.add_flanks(my_sample)

class VepTester(nn.Module):
    """Score reference and alternate alleles with a model and report the difference.

    Intended for inputs whose second dimension can be split into two equal
    halves (presumably forward and reverse-complement windows produced by the
    dataset — confirm against the caller).

    Args:
        model: Callable mapping a (N, ...) batch to (N, n_outputs) predictions.
            Wrapped in DataParallel when more than one CUDA device is visible.
    """
    
    def __init__(self,
                  model
                 ):
        
        super().__init__()
        self.model = torch.nn.DataParallel(model) if torch.cuda.device_count() > 1 else model

    def _target_device(self):
        """Pick the device the inputs must live on.

        Uses the device of the model's first parameter. Models that expose no
        parameters fall back to CUDA when available (the previous hard-coded
        behaviour) and to CPU otherwise.
        """
        get_params = getattr(self.model, 'parameters', None)
        if callable(get_params):
            first_param = next(iter(get_params()), None)
            if first_param is not None:
                return first_param.device
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
    def forward(self, ref_batch, alt_batch):
        """Predict on both alleles.

        Args:
            ref_batch: Tensor of shape (batch, windows, ...); `windows` must be even.
            alt_batch: Tensor with the same shape as `ref_batch`.

        Returns:
            Dict with keys 'ref', 'alt' and 'skew' (alt - ref), each of shape
            (batch, 2, windows // 2, n_outputs).

        Raises:
            AssertionError: If the two inputs differ in shape.
        """
        
        ref_shape, alt_shape = ref_batch.shape, alt_batch.shape
        assert ref_shape == alt_shape, \
            f'ref/alt shape mismatch: {tuple(ref_shape)} vs {tuple(alt_shape)}'
        
        device = self._target_device()
        
        # Merge (batch, windows) so the model sees an ordinary 1-D batch.
        ref_batch = ref_batch.flatten(0,1).to(device)
        alt_batch = alt_batch.flatten(0,1).to(device)
        
        # Mixed precision only applies on CUDA; it is disabled elsewhere.
        with torch.cuda.amp.autocast(enabled=(device.type == 'cuda')):
            ref_preds = self.model(ref_batch.contiguous())
            alt_preds = self.model(alt_batch.contiguous())

        # Restore (batch, windows), then split windows into two halves.
        ref_preds = ref_preds.unflatten(0, ref_shape[0:2])
        ref_preds = ref_preds.unflatten(1, (2, ref_shape[1]//2))
        
        alt_preds = alt_preds.unflatten(0, alt_shape[0:2])
        alt_preds = alt_preds.unflatten(1, (2, alt_shape[1]//2))
            
        skew_preds = alt_preds - ref_preds

        return {'ref': ref_preds, 
                'alt': alt_preds, 
                'skew': skew_preds}
    
class VepTester_FullRC(nn.Module):
    """Score ref/alt alleles on the given inputs and on their flipped copies.

    Each input is evaluated as provided and again after flipping its last two
    axes (channel and position), which for one-hot A/C/G/T sequences is the
    reverse complement.

    Args:
        model: Callable mapping a (N, ...) batch to (N, n_outputs) predictions.
            Wrapped in DataParallel when more than one CUDA device is visible.
    """
    
    def __init__(self,
                  model
                 ):
        
        super().__init__()
        self.model = torch.nn.DataParallel(model) if torch.cuda.device_count() > 1 else model

    def _target_device(self):
        """Pick the device the inputs must live on.

        Uses the device of the model's first parameter. Models that expose no
        parameters fall back to CUDA when available (the previous hard-coded
        behaviour) and to CPU otherwise.
        """
        get_params = getattr(self.model, 'parameters', None)
        if callable(get_params):
            first_param = next(iter(get_params()), None)
            if first_param is not None:
                return first_param.device
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
    def forward(self, ref_batch, alt_batch):
        """Predict on both alleles in both orientations.

        Args:
            ref_batch: Tensor of shape (batch, windows, channels, length).
            alt_batch: Tensor with the same shape as `ref_batch`.

        Returns:
            Dict with keys 'ref', 'alt' and 'skew' (alt - ref), each of shape
            (batch, 2, windows, n_outputs); index 0 on axis 1 is the input as
            given, index 1 is the flipped input.

        Raises:
            AssertionError: If the two inputs differ in shape.
        """
        
        ref_shape, alt_shape = ref_batch.shape, alt_batch.shape
        assert ref_shape == alt_shape, \
            f'ref/alt shape mismatch: {tuple(ref_shape)} vs {tuple(alt_shape)}'
        
        device = self._target_device()
        
        # Merge (batch, windows) so the model sees an ordinary 1-D batch.
        ref_batch = ref_batch.flatten(0,1).to(device)
        alt_batch = alt_batch.flatten(0,1).to(device)
        
        # Mixed precision only applies on CUDA; it is disabled elsewhere.
        with torch.cuda.amp.autocast(enabled=(device.type == 'cuda')):
            ref_preds = self.model(ref_batch.contiguous()).unflatten(0, ref_shape[0:2])
            alt_preds = self.model(alt_batch.contiguous()).unflatten(0, ref_shape[0:2])
            
            # Flipping channel and position axes together gives the second orientation.
            ref_preds_rc = self.model(ref_batch.flip(dims=[-2,-1]).contiguous()).unflatten(0, ref_shape[0:2])
            alt_preds_rc = self.model(alt_batch.flip(dims=[-2,-1]).contiguous()).unflatten(0, ref_shape[0:2])

        ref_preds = torch.stack([ref_preds, ref_preds_rc], dim=1)
        alt_preds = torch.stack([alt_preds, alt_preds_rc], dim=1)
            
        skew_preds = alt_preds - ref_preds

        return {'ref': ref_preds, 
                'alt': alt_preds, 
                'skew': skew_preds}

In [43]:
class reductions(object):
    """Namespace of reductions that collapse one dimension of a tensor.

    Every function takes `(tensor, dim)` and returns a tensor with `dim`
    removed. `dim` may be negative.
    """
    
    @staticmethod
    def mean(tensor, dim):
        """Arithmetic mean along `dim`."""
        return tensor.mean(dim=dim)
    
    @staticmethod
    def sum(tensor, dim):
        """Sum along `dim`."""
        return tensor.sum(dim=dim)
    
    @staticmethod
    def max(tensor, dim):
        """Largest value along `dim`."""
        return tensor.amax(dim=dim)
    
    @staticmethod
    def min(tensor, dim):
        """Smallest value along `dim`."""
        return tensor.amin(dim=dim)

    @staticmethod
    def _take_along(tensor, index, dim):
        """Pick one element per slice along `dim`.

        `index` must keep `dim` with size 1 (i.e. come from a `keepdim=True`
        reduction). gather handles negative dims and keeps everything on the
        tensor's own device, unlike hand-built advanced indexing.
        """
        return tensor.gather(dim, index).squeeze(dim)
    
    @staticmethod
    def abs_max(tensor, dim):
        """Signed value with the largest magnitude along `dim`."""
        index = tensor.abs().argmax(dim=dim, keepdim=True)
        return reductions._take_along(tensor, index, dim)
    
    @staticmethod
    def abs_min(tensor, dim):
        """Signed value with the smallest magnitude along `dim`."""
        index = tensor.abs().argmin(dim=dim, keepdim=True)
        return reductions._take_along(tensor, index, dim)


In [6]:
def load_model(artifact_path):
    """Unpack a model artifact and return the model in eval mode.

    Any existing ./artifacts directory is deleted first, so only one artifact
    is unpacked on disk at a time. The model is moved to CUDA when at least one
    CUDA device is visible.

    Args:
        artifact_path: Location of the artifact archive, passed to `unpack_artifact`.

    Returns:
        The model built by `model_fn` from the unpacked ./artifacts directory.
    """
    model_dir = './artifacts'
    cuda_available = torch.cuda.device_count() >= 1

    # Start from a clean directory so files from a previous artifact cannot leak in.
    if os.path.isdir(model_dir):
        shutil.rmtree(model_dir)
    unpack_artifact(artifact_path)

    my_model = model_fn(model_dir)
    my_model.eval()
    if cuda_available:
        my_model.cuda()

    return my_model

class ConsistentModelPool(nn.Module):
    """Ensemble of same-architecture models evaluated in one vmapped call.

    The models are loaded one by one with `load_model` and their states are
    stacked with functorch's `combine_state_for_ensemble`; `forward` returns the
    mean prediction over the ensemble.

    Args:
        path_list: Iterable of artifact paths, one per ensemble member.

    Attributes:
        fmodel: Functional model returned by `combine_state_for_ensemble`.
        params: Stacked parameters of all members.
        ensemble_buffers: Stacked buffers of all members.
    """
    
    def __init__(self,
                 path_list
                ):
        super().__init__()
        
        models = [ load_model(model_path) for model_path in path_list ]
        fmodel, params, buffers = combine_state_for_ensemble(models)
        
        self.fmodel = fmodel
        self.params = params
        # Deliberately NOT stored as `self.buffers`: that name is the
        # nn.Module.buffers() method, and overwriting it with a tuple breaks
        # every caller of .buffers() (torch.nn.DataParallel among them).
        self.ensemble_buffers = buffers
            
    def forward(self, batch):
        """Return the ensemble-mean prediction for `batch`.

        The same `batch` is fed to every member (in_dims None), while
        parameters and buffers are mapped over their leading ensemble axis.
        """
        
        preds = vmap(self.fmodel, in_dims=(0, 0, None))(self.params, self.ensemble_buffers, batch)
        return preds.mean(dim=0)
            
class VariableModelPool(nn.Module):
    """Ensemble of independently evaluated models whose predictions are averaged.

    Unlike a vmapped ensemble, members are called one at a time, so they do not
    need to share an architecture; they only need to return equally shaped outputs.

    Args:
        path_list: Iterable of artifact paths, one per ensemble member.

    Attributes:
        models: nn.ModuleList holding the loaded members.
    """
    
    def __init__(self,
                 path_list
                ):
        super().__init__()
        
        # An nn.ModuleList (rather than a plain list) registers the members as
        # sub-modules, so .to()/.cuda(), .eval()/.train(), .parameters() and
        # state_dict() on the pool actually reach them.
        self.models = nn.ModuleList(
            [ load_model(model_path) for model_path in path_list ]
        )
            
    def forward(self, batch):
        """Return the mean of all members' predictions for `batch`."""
        
        return torch.stack([model(batch) for model in self.models]).mean(dim=0)


In [10]:
# Choose between the 10-model cross-validation ensemble and the single artifact.
# A single flag with if/else guarantees exactly one branch runs; the previous
# independent `if False:` / `if True:` pair could leave `my_model` undefined or
# silently overwrite it.
USE_ENSEMBLE = False

if USE_ENSEMBLE:
    my_model = ConsistentModelPool([
        'gs://tewhey-public-data/crossval_models/models_230111/test_1_val_10/model_artifacts__20230112_024037__325298.tar.gz',
        'gs://tewhey-public-data/crossval_models/models_230111/test_1_val_11/model_artifacts__20230112_003539__758935.tar.gz',
        'gs://tewhey-public-data/crossval_models/models_230111/test_1_val_2/model_artifacts__20230111_192235__715916.tar.gz',
        'gs://tewhey-public-data/crossval_models/models_230111/test_1_val_3/model_artifacts__20230111_190417__993143.tar.gz',
        'gs://tewhey-public-data/crossval_models/models_230111/test_1_val_4/model_artifacts__20230112_000934__941678.tar.gz',
        'gs://tewhey-public-data/crossval_models/models_230111/test_1_val_5/model_artifacts__20230112_003327__287605.tar.gz',
        'gs://tewhey-public-data/crossval_models/models_230111/test_1_val_6/model_artifacts__20230112_020038__431749.tar.gz',
        'gs://tewhey-public-data/crossval_models/models_230111/test_1_val_7/model_artifacts__20230112_182326__436818.tar.gz',
        'gs://tewhey-public-data/crossval_models/models_230111/test_1_val_8/model_artifacts__20230112_014853__150994.tar.gz',
        'gs://tewhey-public-data/crossval_models/models_230111/test_1_val_9/model_artifacts__20230111_232551__863644.tar.gz'
    ])
else:
    my_model = load_model(hpo_rec)
    # The checkpoint is only inspected by the next cells (keys, data hparams).
    ckpt = torch.load('./artifacts/torch_checkpoint.pt')


Copying gs://syrgoth/aip_ui_test/model_artifacts__20211113_021200__287348.tar.gz...
- [1 files][ 49.3 MiB/ 49.3 MiB]                                                
Operation completed over 1 objects/49.3 MiB.                                     


Loaded model from 20211113_021200 in eval mode


archive unpacked in ./


In [13]:
# Top-level entries stored in the unpacked checkpoint.
ckpt.keys()

dict_keys(['data_module', 'data_hparams', 'model_module', 'model_hparams', 'graph_module', 'graph_hparams', 'model_state_dict', 'timestamp', 'random_tag'])

In [14]:
# Data-module hyperparameters saved with the checkpoint; the recorded output
# shows padded_seq_len=600 and three activity columns (K562, HepG2, SKNSH).
ckpt['data_hparams']

Namespace(activity_columns=['K562_mean', 'HepG2_mean', 'SKNSH_mean'], batch_size=1076, chr_column='chr', data_project=['BODA', 'UKBB', 'GTEX'], datafile_path='gs://syrgoth/data/MPRA_ALL_v3.txt', duplication_cutoff=0.5, exclude_chr_train=[''], normalize=False, num_workers=8, padded_seq_len=600, project_column='data_project', sequence_column='nt_sequence', std_multiple_cut=6.0, synth_chr='synth', synth_seed=102202, synth_test_pct=99.98, synth_val_pct=0.0, test_chrs=['7', '13'], up_cutoff_move=3.0, use_reverse_complements=True, val_chrs=['19', '21', 'X'])

In [13]:
# Smoke test: one random (1, 4, 600) input on the GPU; the recorded output
# shape is (1, 3). Requires CUDA.
my_model(torch.randn((1,4,600)).cuda()).shape

torch.Size([1, 3])

In [52]:
# Constant flanks placed around each 200-nt window: the last 200 nt of the MPRA
# upstream sequence and the first 200 nt of the downstream sequence. Two leading
# singleton axes are added so the flanks broadcast over (batch, windows).
left_flank = boda.common.utils.dna2tensor(boda.common.constants.MPRA_UPSTREAM[-200:])[None, None]
right_flank = boda.common.utils.dna2tensor(boda.common.constants.MPRA_DOWNSTREAM[:200])[None, None]

right_flank.shape

torch.Size([1, 1, 4, 200])

In [53]:
# Build the flank-padding module on the GPU and pick the tester that matches
# how the dataset was built: VepTester when the dataset already contains
# reverse complements, VepTester_FullRC otherwise.
flank_builder = FlankBuilder(left_flank=left_flank, right_flank=right_flank)
flank_builder.cuda()

if REVERSE_COMPLEMENTS:
    vep_tester = VepTester(my_model)
else:
    vep_tester = VepTester_FullRC(my_model)

In [54]:
# Run the model over every batch of variants and collect predictions on the CPU.
# ref/alt predictions are collected as well as the skew: the save cell below
# writes `ref_preds` and `alt_preds`, which were previously left as empty lists.
ref_preds = []
alt_preds = []
skew_preds= []

with torch.no_grad():
    for batch in tqdm.tqdm(vcf_loader):
        # Pad every window with the constant flanks before scoring.
        ref_allele = flank_builder(batch['ref'].cuda()).contiguous()
        alt_allele = flank_builder(batch['alt'].cuda()).contiguous()
        
        all_preds = vep_tester(ref_allele, alt_allele)
        
        # Move results off the GPU immediately so device memory stays flat.
        ref_preds.append(all_preds['ref'].cpu())
        alt_preds.append(all_preds['alt'].cpu())
        skew_preds.append(all_preds['skew'].cpu())

# Concatenate per-batch results along the variant axis.
ref_preds = torch.cat(ref_preds, dim=0)
alt_preds = torch.cat(alt_preds, dim=0)
skew_preds= torch.cat(skew_preds, dim=0)

  return torch.max_pool1d(input, kernel_size, stride, padding, dilation, ceil_mode)
100%|███████████████████████████████████████████| 32/32 [00:03<00:00,  8.47it/s]


In [56]:
# For the first 20 variants: merge axes 1 and 2 of the skew tensor (the two
# axes between variant and output), then keep, per variant and output, the
# signed skew with the largest magnitude.
reductions.abs_max( skew_preds.flatten(1,2), dim=1 )[:20]

tensor([[-0.1680, -0.0453, -0.2559],
        [-1.3037, -0.7322, -1.1129],
        [ 0.1487,  0.0974,  0.1986],
        [ 0.7297,  0.7278,  0.3882],
        [ 0.1549,  0.1276,  0.1208],
        [-0.1534, -0.3621, -0.3195],
        [ 0.2609,  0.2534,  0.3482],
        [-1.8330, -1.3625, -1.8389],
        [ 0.3339,  0.2508,  0.2232],
        [ 0.1988,  0.2930,  0.2083],
        [-0.2940, -0.4840, -0.5363],
        [ 0.5022, -0.2696, -0.4596],
        [ 1.6947,  1.0100,  0.5921],
        [-0.1018, -0.1527, -0.2493],
        [-0.0582,  0.0364,  0.0642],
        [ 0.2422,  0.7290,  0.5186],
        [ 0.0731,  0.2964, -0.0660],
        [ 0.0966,  0.1153,  0.1246],
        [-0.5113, -0.6667, -0.8171],
        [-1.5726, -0.3283, -0.1331]])

In [21]:
# Persist the predictions. 'skew' is included because it is the quantity the
# inference loop always computes; the 'ref' and 'alt' keys are kept unchanged
# for existing readers of the file.
torch.save({'ref': ref_preds, 'alt': alt_preds, 'skew': skew_preds}, out_fn)

In [22]:
# Scratch cell: repeats the earlier shape check of the first record's 'ref' tensor.
vcf_data[0]['ref'].shape

torch.Size([14, 4, 200])

In [23]:
# Scratch cell: take every third column of a 2x5 integer grid.
torch.arange(10).view(2,-1)[:,::3]

tensor([[0, 3],
        [5, 8]])

In [27]:
# Scratch cell: every fourth column of a 2x5 integer grid, then the column-wise maximum over rows.
torch.arange(10).view(2,-1)[:,::4].amax(dim=0)

tensor([5, 9])

In [31]:
# Scratch tensor used by the following cells to check the abs-max/abs-min
# indexing logic. No seed is set, so values differ between runs.
a = torch.randn(4, 4, 4)
a

tensor([[[ 2.6856e+00,  3.4301e-01, -6.0946e-01, -1.0167e+00],
         [ 6.7576e-01, -1.9882e+00, -3.2934e+00,  4.9012e-01],
         [ 6.6305e-01, -8.5159e-02, -1.3291e+00, -1.7398e-01],
         [ 2.8326e-01, -3.7289e-01,  1.1904e+00, -1.2813e+00]],

        [[-1.6285e+00, -1.6911e+00,  5.5387e-02, -1.6507e+00],
         [ 2.1087e-01, -1.1325e+00,  1.2331e+00,  2.8895e-01],
         [ 1.8064e+00, -1.1330e-01,  2.9672e-01,  1.5957e-01],
         [ 2.4114e-01, -1.1229e+00,  9.7628e-01, -6.5015e-01]],

        [[ 4.7679e-01,  5.6372e-01,  2.8561e-01,  1.0265e+00],
         [-6.6995e-01, -1.2969e-01,  2.1860e+00,  1.8343e+00],
         [-1.2044e+00,  4.7310e-01,  8.5648e-01,  8.2419e-02],
         [-4.5638e-01, -1.3656e+00,  6.9233e-01,  1.0186e-01]],

        [[ 1.0010e-01,  4.2466e-01,  8.6448e-01, -8.2313e-02],
         [-8.7955e-01,  9.1755e-01,  9.4769e-02, -1.9924e+00],
         [-3.3751e-01,  5.4379e-01, -2.0579e+00, -3.0400e-03],
         [ 2.4735e-01,  3.9347e-01, -1.2182e+00, 

In [35]:
# Scratch cell: index of the (signed) maximum along axis 1.
a.argmax(dim=1)

tensor([[0, 0, 3, 1],
        [2, 2, 1, 1],
        [0, 0, 1, 1],
        [3, 1, 0, 3]])

In [40]:
# Scratch cell: a 4x4 grid of column indices, one of the index grids used in the manual gather below.
torch.arange(4).view(1,4).expand(4,4)

tensor([[0, 1, 2, 3],
        [0, 1, 2, 3],
        [0, 1, 2, 3],
        [0, 1, 2, 3]])

In [47]:
# Manual check of the abs-max selection along axis 1: one index grid per axis,
# with the axis-1 grid given by the position of the largest magnitude.
# The grids are packed in a tuple: indexing a tensor with a *list* of index
# tensors relies on deprecated non-tuple multidimensional indexing.
slicer = (
    torch.arange(4).view(4,1).expand(4,4),
    a.abs().argmax(dim=1),
    torch.arange(4).view(1,4).expand(4,4),
)

a[slicer]

tensor([[ 2.6856, -1.9882, -3.2934, -1.2813],
        [ 1.8064, -1.6911,  1.2331, -1.6507],
        [-1.2044, -1.3656,  2.1860,  1.8343],
        [-0.8795,  0.9175, -2.0579, -1.9924]])

In [59]:
# Scratch cell: display the test tensor again for comparison with the result below.
a

tensor([[[ 2.6856e+00,  3.4301e-01, -6.0946e-01, -1.0167e+00],
         [ 6.7576e-01, -1.9882e+00, -3.2934e+00,  4.9012e-01],
         [ 6.6305e-01, -8.5159e-02, -1.3291e+00, -1.7398e-01],
         [ 2.8326e-01, -3.7289e-01,  1.1904e+00, -1.2813e+00]],

        [[-1.6285e+00, -1.6911e+00,  5.5387e-02, -1.6507e+00],
         [ 2.1087e-01, -1.1325e+00,  1.2331e+00,  2.8895e-01],
         [ 1.8064e+00, -1.1330e-01,  2.9672e-01,  1.5957e-01],
         [ 2.4114e-01, -1.1229e+00,  9.7628e-01, -6.5015e-01]],

        [[ 4.7679e-01,  5.6372e-01,  2.8561e-01,  1.0265e+00],
         [-6.6995e-01, -1.2969e-01,  2.1860e+00,  1.8343e+00],
         [-1.2044e+00,  4.7310e-01,  8.5648e-01,  8.2419e-02],
         [-4.5638e-01, -1.3656e+00,  6.9233e-01,  1.0186e-01]],

        [[ 1.0010e-01,  4.2466e-01,  8.6448e-01, -8.2313e-02],
         [-8.7955e-01,  9.1755e-01,  9.4769e-02, -1.9924e+00],
         [-3.3751e-01,  5.4379e-01, -2.0579e+00, -3.0400e-03],
         [ 2.4735e-01,  3.9347e-01, -1.2182e+00, 

In [54]:
# Scratch cell: signed smallest-magnitude value along axis 1 of the test tensor.
reductions.abs_min( a, 1 )

tensor([[ 0.2833, -0.0852, -0.6095, -0.1740],
        [ 0.2109, -0.1133,  0.0554,  0.1596],
        [-0.4564, -0.1297,  0.2856,  0.0824],
        [ 0.1001,  0.3935,  0.0948, -0.0030]])

In [60]:
# Same dataset configuration as `vcf_data`, but with a contig filter.
# NOTE(review): the recorded log reports records being removed "based on contig
# blacklist"; presumably only chr10/chr11 are kept — confirm in VcfDataset.
# NOTE(review): chr_prefix is '' when the VCF is loaded, while these names carry
# a 'chr' prefix — verify they match the 'chrom' values.
vcf_sub = boda.data.VcfDataset(test_vcf.vcf, fasta_dict.fasta, 
                                                 WINDOW_SIZE, RELATIVE_START, RELATIVE_END, step_size=STEP_SIZE,
                                                 reverse_complements=REVERSE_COMPLEMENTS, use_contigs=['chr10', 'chr11'])


1165430/1165430 records have matching contig in FASTA
removing 1045783/1165430 records based on contig blacklist
returned 119647/1165430 records
