In [1]:
import shutil
import os

import torch
import torch.nn as nn

import boda

# Helpers

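1. `load_model` checks whether a GPU is available, removes any existing `./artifacts` directory so the download starts from a clean slate, downloads and unpacks the model artifact, and loads the model in `eval` mode (moving it to the GPU when one is available).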

2. `FlankBuilder` is used to pad inputs with MPRA vector backbone sequence. For technical reasons, Malinois reads 600 nt sequences (i.e., n x 4 x 600 inputs), but each input should consist of 200 nt of variable sequence flanked on either side by 200 nt of MPRA backbone.

In [2]:
def load_model(artifact_path):
    """Fetch a packaged model artifact and return the model ready for inference.

    Any existing ``./artifacts`` directory is deleted first, the artifact at
    ``artifact_path`` is unpacked with ``boda.common.utils.unpack_artifact``,
    and the model is then built from ``./artifacts`` with
    ``boda.common.utils.model_fn``.

    Args:
        artifact_path: Location of the model artifact, passed unchanged to
            ``boda.common.utils.unpack_artifact``.

    Returns:
        The loaded model, switched to ``eval`` mode and moved to CUDA when at
        least one CUDA device is reported by ``torch``.
    """
    has_gpu = torch.cuda.device_count() > 0
    model_dir = './artifacts'

    # Drop leftovers from a previous download so stale files are never loaded.
    if os.path.isdir(model_dir):
        shutil.rmtree(model_dir)

    # NOTE(review): assumes the archive unpacks into ./artifacts relative to
    # the working directory -- confirm against unpack_artifact.
    boda.common.utils.unpack_artifact(artifact_path)

    model = boda.common.utils.model_fn(model_dir)
    model.eval()

    if has_gpu:
        model.cuda()

    return model

class FlankBuilder(nn.Module):
    """Pad a batch of sequences with fixed left and/or right flanking sequence.

    The flanks are stored as module buffers, so they follow the module across
    devices (``.cuda()``, ``.to(...)``). Either flank may be omitted, in which
    case nothing is added on that side.

    Args:
        left_flank: Tensor concatenated before each sample, or ``None`` for no
            left flank. Its last two dimensions are kept as-is and its leading
            dimension(s) are expanded to the sample's batch dimensions, so a
            leading singleton batch axis is expected.
        right_flank: Tensor concatenated after each sample, or ``None`` for no
            right flank. Same layout expectations as ``left_flank``.
        batch_dim: Stored on the module as ``self.batch_dim``; not used by the
            methods defined in this class.
        cat_axis: Axis along which the flanks and the sample are concatenated.
    """

    def __init__(self,
                 left_flank=None,
                 right_flank=None,
                 batch_dim=0,
                 cat_axis=-1
                ):

        super().__init__()

        # register_buffer accepts None, which keeps the attribute defined while
        # letting add_flanks skip that side. Calling .detach() on the default
        # None would otherwise raise AttributeError.
        self.register_buffer(
            'left_flank',
            None if left_flank is None else left_flank.detach().clone()
        )
        self.register_buffer(
            'right_flank',
            None if right_flank is None else right_flank.detach().clone()
        )

        self.batch_dim = batch_dim
        self.cat_axis  = cat_axis

    def add_flanks(self, my_sample):
        """Return ``my_sample`` with the configured flanks concatenated on.

        Args:
            my_sample: Tensor whose last two dimensions are (channels, length);
                all leading dimensions are treated as batch dimensions.

        Returns:
            The left flank (if any), ``my_sample`` and the right flank (if any)
            concatenated along ``self.cat_axis``.
        """
        *batch_dims, channels, length = my_sample.shape

        pieces = []

        if self.left_flank is not None:
            # -1 keeps the flank's own channel/length sizes; only the batch
            # dimensions are broadcast to match the sample.
            pieces.append( self.left_flank.expand(*batch_dims, -1, -1) )

        pieces.append( my_sample )

        if self.right_flank is not None:
            pieces.append( self.right_flank.expand(*batch_dims, -1, -1) )

        return torch.cat( pieces, dim=self.cat_axis )

    def forward(self, my_sample):
        """Module entry point; identical to :meth:`add_flanks`."""
        return self.add_flanks(my_sample)

# Get Malinois

The model artifact can be downloaded directly from any Google Storage bucket you have access to.

In [3]:
# Google Storage URI of the packaged Malinois model artifact (a .tar.gz archive).
malinois_path = 'gs://syrgoth/aip_ui_test/model_artifacts__20211113_021200__287348.tar.gz'
# load_model replaces any existing ./artifacts directory and returns the model
# in eval mode (on the GPU when one is detected).
my_model = load_model(malinois_path)

Copying gs://syrgoth/aip_ui_test/model_artifacts__20211113_021200__287348.tar.gz...
/ [1 files][ 49.3 MiB/ 49.3 MiB]                                                
Operation completed over 1 objects/49.3 MiB.                                     
archive unpacked in ./


Loaded model from 20211113_021200 in eval mode


# Set flanks

MPRA flanks are saved as constants in the `boda` repo. These need to be sized to (1, 4, 200) each and used to init `FlankBuilder`.

In [4]:
# Last 200 nt of the upstream MPRA backbone, with a leading batch axis added.
left_flank = boda.common.utils.dna2tensor(
    boda.common.constants.MPRA_UPSTREAM[-200:]
).unsqueeze(0)
print(f'left flank shape: {left_flank.shape}')

# First 200 nt of the downstream MPRA backbone, with a leading batch axis added.
right_flank = boda.common.utils.dna2tensor(
    boda.common.constants.MPRA_DOWNSTREAM[:200]
).unsqueeze(0)
print(f'right flank shape: {right_flank.shape}')

flank_builder = FlankBuilder(
    left_flank=left_flank,
    right_flank=right_flank,
)

# Same GPU check as load_model: only move to CUDA when a device is present, so
# this cell also runs on CPU-only machines instead of raising.
if torch.cuda.device_count() >= 1:
    flank_builder.cuda()

# Last expression of the cell: display the module repr.
flank_builder

left flank shape: torch.Size([1, 4, 200])
right flank shape: torch.Size([1, 4, 200])


FlankBuilder()

# Example call

We use `torch.no_grad()` so that the computation graph isn't saved to memory. Since sequences are passed to the model as one-hot encodings in `torch.float32` format, we can use `torch.randn` to validate the model setup. Here, a batch of 10 variable 200 nt (fake) sequences is padded to 600 nt and then passed to the model. Note that `my_model` and `flank_builder` have been placed on the GPU using `.cuda()` calls; therefore, the fake sequences also need to be sent to `cuda`.

In [5]:
with torch.no_grad():
    # Simulate a batch_size x 4 nucleotide x 200 nt long sequence and place it
    # on whatever device the flank buffers live on, rather than hard-coding
    # CUDA, so the example does not crash on a CPU-only machine.
    fake_batch = torch.randn((10, 4, 200)).to(flank_builder.left_flank.device)

    # Need to add MPRA flanks before the model sees the sequences
    padded_batch = flank_builder(fake_batch)

    print(my_model(padded_batch))

tensor([[-5.9237, -4.3324, -2.6590],
        [ 1.8957, -1.6334,  7.3978],
        [-1.3470, -1.1018,  2.1494],
        [ 0.0840, -0.4010,  7.8278],
        [-2.9887, -2.0866,  3.6742],
        [ 0.1913, -1.0352,  8.1859],
        [-1.4896, -0.7747,  5.5385],
        [-0.8139, -1.0344,  3.1287],
        [-0.5258, -0.4886,  4.0395],
        [-1.6722, -0.6007,  3.7459]], device='cuda:0')
