In [1]:
import sys
import os
import subprocess
import tarfile
import shutil
import types
import math
import pickle

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as dist
from torch.autograd import grad

import matplotlib as mp
import matplotlib.pyplot as plt
import seaborn as sns

from tqdm import tqdm

import boda
from boda.generator import plot_tools
from boda.common import constants

# Put the `src` directory that sits two levels above the current working
# directory at the front of the import path, so that `main` (imported right
# below) can be resolved from it.
two_levels_up = os.path.dirname(os.path.dirname(os.getcwd()))
boda_src = os.path.join(two_levels_up, 'src')
sys.path.insert(0, boda_src)

from main import unpack_artifact, model_fn

from scipy import spatial
from scipy.cluster import hierarchy

from boda.generator import BasicParameters, NaiveMH, SimulatedAnnealing
from boda.generator.metropolis_hastings import PolynomialDecay
from boda.generator.energy import OverMaxEnergy

In [2]:
# Directory the model is loaded from. Any copy left behind by a previous run is
# removed first so that stale artifacts cannot be picked up by accident.
model_dir = './artifacts'
if os.path.isdir(model_dir):
    shutil.rmtree(model_dir)

# Remote archive holding the model artifacts.
# NOTE(review): presumably `unpack_artifact` extracts into `./artifacts`, the
# directory cleared above and read below -- confirm against `main.unpack_artifact`.
hpo_rec = 'gs://syrgoth/aip_ui_test/model_artifacts__20210623_102310__205717.tar.gz'
unpack_artifact(hpo_rec)

# Load the model, move it to the GPU and switch it to evaluation mode.
my_model = model_fn(model_dir)
my_model.cuda()
my_model.eval()

# Fixed flanking sequences: the last 200 characters of the upstream constant
# and the first 200 of the downstream constant, each converted to a tensor with
# a leading singleton dimension added by `unsqueeze(0)`.
left_flank = boda.common.utils.dna2tensor(
    boda.common.constants.MPRA_UPSTREAM[-200:]
).unsqueeze(0)
print(f'Left flanking sequence shape: {left_flank.shape}')

right_flank = boda.common.utils.dna2tensor(
    boda.common.constants.MPRA_DOWNSTREAM[:200]
).unsqueeze(0)
# Fixed label: this message previously said "Left" for the right flank.
print(f'Right flanking sequence shape: {right_flank.shape}')

Loaded model from 20210623_102310 in eval mode
Left flanking sequence shape: torch.Size([1, 4, 200])
Left flanking sequence shape: torch.Size([1, 4, 200])


In [3]:
# Batch size (number of sequences handled in parallel) and the step count
# reused by the sampling cells below.
bsz = 512
step_count = 1000

# Random starting inserts: for each of the 200 positions draw one of 4
# categories from random logits, then permute (bsz, 200, 4) -> (bsz, 4, 200).
random_logits = torch.randn(bsz, 200, 4)
test_insert = dist.OneHotCategorical(logits=random_logits).sample().permute(0, 2, 1)

# The flanks carry a leading singleton dimension; `expand` repeats them along
# that dimension so each sequence in the batch gets the same flanks.
batched_left_flank = left_flank.expand(bsz, -1, -1)
batched_right_flank = right_flank.expand(bsz, -1, -1)

# The insert is wrapped in an `nn.Parameter`; the flanks are passed as plain tensors.
my_params = BasicParameters(
    nn.Parameter(test_insert),
    batched_left_flank,
    batched_right_flank,
)

# Energy function built on the loaded model.
# NOTE(review): `bias_cell=0` presumably selects which model output is favoured -- confirm in OverMaxEnergy.
my_energy = OverMaxEnergy(my_model, bias_cell=0)

In [4]:
# NaiveMH sampler operating on the sequence parameters and scored by the energy
# function defined in the previous cell.
mcmc = NaiveMH(
    my_params,
    my_energy,
)

In [None]:
# Run the sampler with `step_count` burn-in steps followed by `step_count`
# collected steps.
samples = mcmc.collect_samples(
    n_samples=step_count,
    n_burnin=step_count,
)

# Keep the 'samples' entry of the result; the cells below read its 'samples'
# and 'energies' items.
mcmc_run = samples['samples']

  0%|          | 0/1000 [00:00<?, ?it/s]

burn in


  return torch.max_pool1d(input, kernel_size, stride, padding, dilation, ceil_mode)
 91%|█████████ | 906/1000 [00:51<00:05, 17.64it/s]

In [None]:
# Display the shape of the samples collected by the most recent run.
mcmc_run['samples'].shape

In [None]:
# Energies recorded during the most recent run, indexed as [step, chain].
energies = mcmc_run['energies'].numpy()
# Name of the sampler currently bound to `mcmc`, used to label the figures.
sampler_name = type(mcmc).__name__

# Energy trace of every chain. Passing the 2-D array draws one line per column
# and takes the x-axis from the row index, so the plot follows the data instead
# of relying on the global `step_count` matching the last collection call.
fig, ax = plt.subplots()
ax.plot(energies)
ax.set(xlabel='Step', ylabel='Energy', title=f'{sampler_name}: energy traces')
plt.show()

# Distribution of energies across chains at the final collected step.
fig, ax = plt.subplots()
ax.hist(energies[-1].flatten(), bins=50)
ax.set(xlabel='Energy at final step', ylabel='Count', title=f'{sampler_name}: final energies')
plt.show()

In [None]:
# Temperature schedule handed to the annealing sampler.
temp = PolynomialDecay(a=1, b=1, gamma=0.5001)

# NOTE(review): `my_params` is the same object that was given to the NaiveMH
# sampler earlier; if that sampler updates it in place, annealing starts from
# the state the previous run ended in rather than from `test_insert` -- confirm
# this is intended.
mcmc = SimulatedAnnealing(
    my_params,
    my_energy,
    n_positions=1,
    temperature_schedule=temp,
)

In [None]:
# Run the annealing sampler for `step_count` collected steps with no burn-in.
samples = mcmc.collect_samples(
    n_samples=step_count,
    n_burnin=0,
)

# Keep the 'samples' entry of the result; the cells below read its 'samples'
# and 'energies' items.
mcmc_run = samples['samples']

In [None]:
# Display the shape of the samples collected by the most recent run.
mcmc_run['samples'].shape

In [None]:
# Energies recorded during the most recent run, indexed as [step, chain].
energies = mcmc_run['energies'].numpy()
# Name of the sampler currently bound to `mcmc`, used to label the figures.
sampler_name = type(mcmc).__name__

# Energy trace of every chain. Passing the 2-D array draws one line per column
# and takes the x-axis from the row index, so the plot follows the data instead
# of relying on the global `step_count` matching the last collection call.
fig, ax = plt.subplots()
ax.plot(energies)
ax.set(xlabel='Step', ylabel='Energy', title=f'{sampler_name}: energy traces')
plt.show()

# Distribution of energies across chains at the final collected step.
fig, ax = plt.subplots()
ax.hist(energies[-1].flatten(), bins=50)
ax.set(xlabel='Energy at final step', ylabel='Count', title=f'{sampler_name}: final energies')
plt.show()