!pip install imageio
!pip install imageio-ffmpeg
!pip install dmslogo
!pip install palettable
!pip install array_to_latex

In [1]:
import sys
import os
import time

import torch
import torch.nn as nn

import boda

# Build the path to the `src` directory that sits under the grandparent of the
# current working directory, and put it first on sys.path so `main` resolves below.
# NOTE(review): the result depends on the directory the kernel was started in.
boda_src = os.path.join( os.path.dirname( os.path.dirname( os.getcwd() ) ), 'src' )
sys.path.insert(0, boda_src)

# Helpers used further down to fetch a model artifact and build the model from it.
from main import unpack_artifact, model_fn

In [2]:
# Batch size: used to tile the flank tensors and to size the sampled parameter tensor.
bsz = 40

## Set up parameters

In [3]:
# Convert the flanking DNA strings to tensors with dna2tensor and tile each one
# `bsz` times along the first dimension.
# Left flank: the last 200 characters of MPRA_UPSTREAM.
left_hold = boda.common.utils.dna2tensor( (boda.common.constants.MPRA_UPSTREAM)[-200:] ).repeat(bsz,1,1)
# Right flank: the first 200 characters of MPRA_UPSTREAM.
# NOTE(review): both flanks are sliced from MPRA_UPSTREAM. A right flank would normally
# come from a downstream constant -- confirm this is intentional and not a copy-paste slip.
right_hold= boda.common.utils.dna2tensor( (boda.common.constants.MPRA_UPSTREAM)[:200] ).repeat(bsz,1,1)

In [4]:
# Display the tiled left-flank tensor.
left_hold

tensor([[[0., 0., 1.,  ..., 1., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 1.],
         [0., 1., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 1., 0.]],

        [[0., 0., 1.,  ..., 1., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 1.],
         [0., 1., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 1., 0.]],

        [[0., 0., 1.,  ..., 1., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 1.],
         [0., 1., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 1., 0.]],

        ...,

        [[0., 0., 1.,  ..., 1., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 1.],
         [0., 1., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 1., 0.]],

        [[0., 0., 1.,  ..., 1., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 1.],
         [0., 1., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 1., 0.]],

        [[0., 0., 1.,  ..., 1., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 1.],
         [0., 1., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 1

In [5]:
# Trainable tensor of shape (bsz, 4, 200), initialised from a standard normal.
theta_init = nn.Parameter(torch.randn([bsz, 4, 200]))

# Hand the trainable tensor and the two fixed flank tensors to the parameter module.
my_params = boda.generator.GumbelSoftmaxParameters(
    theta_init,
    left_flank=left_hold,
    right_flank=right_hold,
    n_samples=1,
)

In [6]:
# Call the parameter module once and check the shape of what it produces.
my_params().shape

torch.Size([40, 4, 600])

## Load model

In [7]:
# Remote tarball with the trained model artifacts, and the local directory that
# model_fn reads from (presumably where unpack_artifact extracts to -- confirm in main).
artifact_uri = 'gs://syrgoth/aip_ui_test/model_artifacts__20210515_230601__116249.tar.gz'
model_dir = './artifacts'

unpack_artifact(artifact_uri)
my_model = model_fn(model_dir)

Loaded model from 20210515_230601


## Test model + parameters combination

In [8]:
class SEBridge(nn.Module):
    """Wrap a model so that calling the wrapper returns the Shannon entropy of
    the model's softmax output.

    Args:
        model: callable module whose output carries class scores along ``dim=1``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, hook):
        """Return ``-sum(p * log p)`` over ``dim=1`` for ``p = softmax(model(hook))``.

        Args:
            hook: input passed unchanged to the wrapped model.

        Returns:
            Tensor with ``dim=1`` of the model output summed out.
        """
        # Take log-probabilities straight from log_softmax instead of softmax().log():
        # when a probability underflows to 0, log() yields -inf and 0 * -inf is NaN,
        # whereas log_softmax stays finite and the product is a clean 0.
        log_probs = self.model(hook).log_softmax(dim=1)
        probs     = log_probs.exp()
        return (probs*log_probs).sum(dim=1).mul(-1.)

In [11]:
# Wrap the loaded model; the wrapper is what gets passed to the sampler as its energy function.
energy_fn = SEBridge(my_model)

In [None]:
# Smoke test: evaluate the energy on one output of the parameter module.
energy_fn(my_params())

## Test generator

In [10]:
# Build the sampler from the parameter module and the energy function.
# NOTE(review): max_tree_depth=6 presumably bounds the NUTS tree depth -- confirm against boda.generator.NUTS3.
my_sampler = boda.generator.NUTS3( my_params, energy_fn, max_tree_depth=6 )

In [11]:
# Time the sampling run with perf_counter: it is a monotonic, high-resolution clock
# meant for measuring intervals, while time.time() follows the system clock and can
# jump if that clock is adjusted mid-run.
start = time.perf_counter()
# NOTE(review): the meaning of the positional arguments (1e-1, 10) is not visible
# here -- confirm against NUTS3.collect_samples.
my_samples = my_sampler.collect_samples(1e-1,10)
print(time.perf_counter() - start)

669.2568943500519


## Test cuda

In [12]:
# Move the parameter module and the energy wrapper to the GPU.
# The last call's return value is what the cell displays.
my_params.cuda()
energy_fn.cuda()

SEBridge(
  (model): BassetVL(
    (pad1): ConstantPad1d(padding=[9, 9], value=0.0)
    (conv1): Conv1dNorm(
      (conv): Conv1d(4, 300, kernel_size=(19,), stride=(1,))
      (bn_layer): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (pad2): ConstantPad1d(padding=[5, 5], value=0.0)
    (conv2): Conv1dNorm(
      (conv): Conv1d(300, 200, kernel_size=(11,), stride=(1,))
      (bn_layer): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (pad3): ConstantPad1d(padding=[3, 3], value=0.0)
    (conv3): Conv1dNorm(
      (conv): Conv1d(200, 200, kernel_size=(7,), stride=(1,))
      (bn_layer): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (pad4): ConstantPad1d(padding=(1, 1), value=0.0)
    (maxpool_3): MaxPool1d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
    (maxpool_4): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=Fals

In [13]:
# Smoke test after the move to the GPU: evaluate the energy on one output of the parameter module.
energy_fn(my_params())

tensor([1.0973, 1.0924, 1.0948, 1.0947, 1.0959, 1.0906, 1.0771, 1.0517, 1.0978,
        1.0881, 1.0965, 1.0974, 1.0977, 1.0963, 1.0958, 1.0811, 1.0647, 1.0937,
        1.0977, 1.0744, 1.0772, 1.0971, 1.0983, 1.0973, 1.0970, 1.0790, 1.0937,
        1.0766, 1.0877, 0.9917, 1.0934, 1.0903, 1.0892, 1.0717, 1.0660, 1.0982,
        1.0960, 1.0864, 1.0954, 1.0970], device='cuda:0',
       grad_fn=<MulBackward0>)

In [14]:
# Rebuild the sampler now that the parameter module and energy function are on the GPU.
# NOTE(review): max_tree_depth=6 presumably bounds the NUTS tree depth -- confirm against boda.generator.NUTS3.
my_sampler = boda.generator.NUTS3( my_params, energy_fn, max_tree_depth=6 )

In [16]:
# Time the sampling run with perf_counter: it is a monotonic, high-resolution clock
# meant for measuring intervals, while time.time() follows the system clock and can
# jump if that clock is adjusted mid-run.
start = time.perf_counter()
# NOTE(review): the meaning of the positional arguments (1e-1, 10) is not visible
# here -- confirm against NUTS3.collect_samples.
my_samples = my_sampler.collect_samples(1e-1,10)
print(time.perf_counter() - start)

9.746002435684204


In [17]:
# Inspect element 0 of the last entry in the collected samples.
my_samples[-1][0]

tensor([[[ -6.8437,  17.8763,  -0.1315,  ..., -18.5676, -28.9050,  -4.3739],
         [-23.0793,  17.8637, -17.6396,  ...,   9.6393, -38.7782,  -8.9497],
         [  5.9327,  18.5760,  -3.2536,  ...,  24.2850,  -8.2704, -11.5782],
         [-19.6584,  23.1742,  23.3611,  ...,  15.4485,   4.6120,  19.6974]],

        [[  8.4275,  -3.4474,   1.0758,  ...,  -7.1960,   2.7149,  -3.7693],
         [  6.4695,  40.1236,   5.8912,  ...,  16.8275,   2.4700,  14.5646],
         [-14.4251,   1.4949,  11.0203,  ...,  25.2833,  -1.3034,   6.9184],
         [ -9.5555,  11.5167,  -6.8069,  ...,   7.8747, -13.1826, -37.1442]],

        [[ -1.8678,   6.0356,  10.5502,  ...,  14.9514, -11.8969,  -9.4892],
         [ -0.9737, -15.3925,  19.8969,  ...,  -6.5278,  -7.2056,  -6.5082],
         [ -0.1698, -30.8604, -25.1898,  ..., -21.0621,  32.6659,  13.3362],
         [ -1.3777,  26.0208,   8.9887,  ...,   9.6813, -14.7003,  -0.6798]],

        ...,

        [[-18.0933,  12.2198,  42.7688,  ...,   5.2563, 

In [18]:
# Point the parameter's underlying data at element 0 of the first collected entry
# (assigning through .data is not tracked by autograd), then re-evaluate the energy.
# NOTE(review): after this assignment theta shares storage with my_samples[0][0].
my_params.theta.data = my_samples[0][0]
energy_fn( my_params() )

tensor([1.0482, 1.0931, 1.0970, 1.0960, 1.0799, 1.0967, 1.0891, 1.0861, 1.0887,
        1.0949, 1.0941, 1.0596, 1.0956, 1.0758, 1.0961, 1.0970, 1.0044, 1.0985,
        1.0971, 1.0893, 1.0800, 1.0978, 1.0978, 1.0907, 1.0596, 1.0308, 1.0981,
        1.0979, 1.0983, 1.0843, 1.0947, 1.0880, 1.0898, 1.0897, 1.0967, 1.0952,
        1.0957, 1.0961, 1.0864, 1.0019], device='cuda:0',
       grad_fn=<MulBackward0>)

In [19]:
# Point the parameter's underlying data at element 0 of the second-to-last collected
# entry (assigning through .data is not tracked by autograd), then re-evaluate the energy.
# NOTE(review): this uses index -2, not -1 -- confirm the last entry is skipped on purpose.
my_params.theta.data = my_samples[-2][0]
energy_fn( my_params() )

tensor([1.0974, 1.0969, 1.0944, 1.0948, 1.0971, 1.0918, 1.0967, 1.0985, 1.0947,
        1.0888, 1.0954, 1.0707, 1.0810, 1.0850, 1.0779, 1.0712, 1.0941, 1.0977,
        1.0838, 1.0955, 1.0984, 1.0900, 1.0967, 1.0668, 1.0951, 1.0948, 1.0972,
        1.0951, 1.0694, 1.0748, 1.0976, 1.0936, 1.0956, 1.0977, 1.0979, 1.0908,
        1.0893, 1.0956, 1.0883, 1.0946], device='cuda:0',
       grad_fn=<MulBackward0>)