In [1]:
import subprocess
import sys
import os
import shutil
import gzip
import csv
import argparse
import multiprocessing

import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from functorch import combine_state_for_ensemble, vmap

import numpy as np
import pandas as pd
import boda
from boda.common import constants, utils
from boda.common.utils import unpack_artifact, model_fn


In [2]:
def load_model(artifact_path):
    """Unpack a model artifact and load it, ready for inference.

    Any existing ``./artifacts`` directory is removed before unpacking, then the
    model is built from ``./artifacts`` via ``model_fn``, switched to eval mode,
    and moved to the GPU when at least one CUDA device is visible.

    Args:
        artifact_path: location of the artifact archive, passed straight to
            ``unpack_artifact`` (e.g. a ``gs://`` URL).

    Returns:
        The loaded model, in eval mode.
    """
    artifact_dir = './artifacts'

    # Start from a clean directory. NOTE(review): assumes unpack_artifact
    # extracts into the current directory and produces ./artifacts.
    if os.path.isdir(artifact_dir):
        shutil.rmtree(artifact_dir)
    unpack_artifact(artifact_path)

    model = model_fn(artifact_dir)
    model.eval()

    has_gpu = torch.cuda.device_count() >= 1
    if has_gpu:
        model.cuda()  # nn.Module.cuda moves parameters in place

    return model

class FlankBuilder(nn.Module):
    """Concatenate fixed left/right flanking sequences onto a batch of samples.

    The flanks are registered as buffers, so ``.to()`` / ``.cuda()`` on the
    module moves them together with the rest of the module.

    Args:
        left_flank: tensor prepended along ``cat_axis``, or ``None`` for no left
            flank. Its rank must not exceed the sample's rank; a leading
            singleton batch dim such as ``(1, channels, flank_length)`` is
            expanded over the sample's batch dims.
        right_flank: tensor appended along ``cat_axis``, or ``None`` for no
            right flank. Same shape rules as ``left_flank``.
        batch_dim: stored as ``self.batch_dim`` but not read by this class;
            kept for backward compatibility.
        cat_axis: dimension along which flanks and sample are concatenated.
    """

    def __init__(self,
                 left_flank=None,
                 right_flank=None,
                 batch_dim=0,
                 cat_axis=-1
                ):

        super().__init__()

        # The defaults are None, so guard before detaching. Registering a None
        # buffer is allowed and makes the attribute read back as None, which
        # add_flanks treats as "no flank".
        self.register_buffer('left_flank', self._copy_or_none(left_flank))
        self.register_buffer('right_flank', self._copy_or_none(right_flank))

        self.batch_dim = batch_dim
        self.cat_axis  = cat_axis

    @staticmethod
    def _copy_or_none(flank):
        """Return a detached copy of ``flank`` (never aliasing the caller's tensor), or None."""
        return None if flank is None else flank.detach().clone()

    def add_flanks(self, my_sample):
        """Return ``my_sample`` with the configured flanks concatenated on ``cat_axis``.

        ``my_sample`` is read as ``(*batch_dims, channels, length)``. Each flank
        is expanded over ``batch_dims`` (only its last two dims are kept as-is)
        before concatenation.
        """
        *batch_dims, _channels, _length = my_sample.shape

        pieces = []

        if self.left_flank is not None:
            pieces.append(self.left_flank.expand(*batch_dims, -1, -1))

        pieces.append(my_sample)

        if self.right_flank is not None:
            pieces.append(self.right_flank.expand(*batch_dims, -1, -1))

        return torch.cat(pieces, dim=self.cat_axis)

    def forward(self, my_sample):
        """Alias for :meth:`add_flanks` so the module can be called directly."""
        return self.add_flanks(my_sample)

In [3]:
# Malinois model artifact (a tarball on Google Cloud Storage) to fetch, unpack and load.
malinois_path = (
    'gs://syrgoth/aip_ui_test/model_artifacts__20211113_021200__287348.tar.gz'
)
my_model = load_model(artifact_path=malinois_path)

Copying gs://syrgoth/aip_ui_test/model_artifacts__20211113_021200__287348.tar.gz...
/ [1 files][ 49.3 MiB/ 49.3 MiB]                                                
Operation completed over 1 objects/49.3 MiB.                                     
archive unpacked in ./


Loaded model from 20211113_021200 in eval mode


In [4]:
# MPRA flanks: the last FLANK_LENGTH nt of the upstream constant and the first
# FLANK_LENGTH nt of the downstream constant, each given a leading batch dim.
FLANK_LENGTH = 200

left_flank = boda.common.utils.dna2tensor(
    boda.common.constants.MPRA_UPSTREAM[-FLANK_LENGTH:]
).unsqueeze(0)
print(f'left flank shape: {left_flank.shape}')

right_flank = boda.common.utils.dna2tensor(
    boda.common.constants.MPRA_DOWNSTREAM[:FLANK_LENGTH]
).unsqueeze(0)
print(f'right flank shape: {right_flank.shape}')

flank_builder = FlankBuilder(
    left_flank=left_flank,
    right_flank=right_flank,
)
# Follow the model's device instead of hard-coding CUDA: load_model only moves
# the model to the GPU when one exists, so this keeps CPU-only runs working.
flank_builder.to(next(my_model.parameters()).device)

left flank shape: torch.Size([1, 4, 200])
right flank shape: torch.Size([1, 4, 200])


FlankBuilder()

In [5]:
# Smoke test: push a dummy batch through flanking + model and check the batch
# dimension survives. The input is Gaussian noise, not an encoded DNA sequence,
# so the predicted values themselves carry no meaning.
BATCH_SIZE, N_CHANNELS, SEQ_LENGTH = 10, 4, 200
device = next(my_model.parameters()).device
# Local seeded generator: reproducible output without touching the global RNG.
generator = torch.Generator().manual_seed(0)

with torch.no_grad():
    dummy_batch = torch.randn(
        (BATCH_SIZE, N_CHANNELS, SEQ_LENGTH), generator=generator
    ).to(device)
    preds = my_model(flank_builder(dummy_batch))  # model expects MPRA flanks attached
    assert preds.shape[0] == BATCH_SIZE, preds.shape
    print(preds)

tensor([[-1.1762, -0.6605,  4.3805],
        [-3.6497, -2.2724,  2.4733],
        [-1.0202, -1.0788,  1.3754],
        [-1.5042, -1.3004,  6.2957],
        [-2.4146, -1.6084,  8.1475],
        [-1.9277, -1.3732,  5.3238],
        [-1.0373, -0.7777,  4.8458],
        [-1.6621, -1.1395,  2.2218],
        [-1.3609, -0.5861,  4.8060],
        [-1.7919, -1.1859,  7.3544]], device='cuda:0')
