In [None]:
import sys
import os
import subprocess
import tarfile
import shutil

import numpy as np
import torch
import torch.nn as nn
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from mpl_toolkits import mplot3d

import boda
from boda.generator.parameters import StraightThroughParameters
from boda.generator import FastSeqProp
from boda.model.mpra_basset import MPRA_Basset

# Make the repo's `src/` directory (the `src` folder inside the grandparent of the
# current working directory) importable, so `from main import ...` resolves.
# Drop any earlier copy first: re-running this cell must not keep prepending
# duplicate entries, while the path still ends up with top precedence.
boda_src = os.path.join( os.path.dirname( os.path.dirname( os.getcwd() ) ), 'src' )
while boda_src in sys.path:
    sys.path.remove(boda_src)
sys.path.insert(0, boda_src)

from main import unpack_artifact, model_fn

In [None]:
#----------------------- HPO model -----------------------

# Directory the HPO model is loaded from. It is wiped before unpacking so no
# stale files from an earlier run remain (presumably unpack_artifact extracts
# into this directory — the model is loaded from it right afterwards).
model_dir = './artifacts'
hpo_rec = 'gs://syrgoth/aip_ui_test/model_artifacts__20210623_102310__205717.tar.gz'

if os.path.isdir(model_dir):
    shutil.rmtree(model_dir)
unpack_artifact(hpo_rec)

hpo_model = model_fn(model_dir)
hpo_model.eval()

In [None]:
#----------------------- Artisanal model -----------------------

ckpt_name = 'manual_checkpoint_multioutput_lasthidden250_L1.ckpt'
ckpt_uri = 'gs://syrgoth/checkpoints/' + ckpt_name

# check=True makes a failed download stop the cell right here, instead of
# surfacing later as a misleading FileNotFoundError from torch.load.
subprocess.run(['gsutil', 'cp', ckpt_uri, './'], check=True)

artisan_model = MPRA_Basset(extra_hidden_size = 250)
# map_location='cpu' lets the checkpoint load even if it was saved from CUDA
# tensors and no matching GPU is present; load_state_dict then copies the
# values into the model's own parameters.
checkpoint = torch.load(ckpt_name, map_location='cpu')
artisan_model.load_state_dict(checkpoint['state_dict'])
artisan_model.eval()

In [None]:
# Fixed flanking sequences passed to StraightThroughParameters: the last
# `flank_len` positions of MPRA_UPSTREAM and the first `flank_len` positions of
# MPRA_DOWNSTREAM, each encoded by dna2tensor and given a leading dim of size 1.
flank_len = 200
upstream_tail = boda.common.constants.MPRA_UPSTREAM[-flank_len:]
downstream_head = boda.common.constants.MPRA_DOWNSTREAM[:flank_len]

left_flank = boda.common.utils.dna2tensor(upstream_tail).unsqueeze(0)
right_flank = boda.common.utils.dna2tensor(downstream_head).unsqueeze(0)

In [None]:
def castro_reward(x):
    """Elementwise reward ``exp(-x) - x - 1`` on a tensor.

    Equals 0 at x == 0 and is strictly decreasing in x (its derivative is
    ``-exp(-x) - 1 < 0``): positive for x < 0, negative for x > 0.
    """
    decay = torch.exp(-x)
    return decay - x - 1

def basic_reward(x):
    """Identity reward: hands the input back unchanged (the same object, no copy)."""
    unchanged = x
    return unchanged

def k562_score(x):
    """Per-row K562-specificity score for a 2-D prediction tensor ``x``.

    Column 0 is the target and columns 1 and 2 are off-target (presumably
    K562 / HepG2 / a third cell type — TODO confirm the column order against
    the model head). Because castro_reward is decreasing, the score rises as
    column 0 rises and as columns 1 and 2 fall.
    """
    on_target = castro_reward(x[:, 0])
    off_target = castro_reward(x[:, 1]) + castro_reward(x[:, 2])
    return 0.5 * off_target - on_target
# Alternative (disabled) linear score:
# def k562_score(x):
#     return x[:,0] - torch.mean(x[:,1:], axis=1) 
def k562_specific(x):
    """Loss for K562-specific design: the mean of the negated ``k562_score``."""
    return (-k562_score(x)).mean()

def hepg2_score(x):
    """Per-row HepG2-specificity score for a 2-D prediction tensor ``x``.

    Column 1 is the target and columns 2 and 0 are off-target (TODO confirm
    the column order against the model head). Because castro_reward is
    decreasing, the score rises as column 1 rises and as columns 0 and 2 fall.
    """
    on_target = castro_reward(x[:, 1])
    off_target = castro_reward(x[:, 2]) + castro_reward(x[:, 0])
    return 0.5 * off_target - on_target
# Alternative (disabled) linear score:
# def hepg2_score(x):
#     return x[:,1] - torch.mean(x[:, np.r_[0,2]], axis=1)
def hepg2_specific(x):
    """Loss for HepG2-specific design: the mean of the negated ``hepg2_score``."""
    return (-hepg2_score(x)).mean()

class mpra_energy(nn.Module):
    """Energy module: ``loss_fn(predictor(x))``, used as FastSeqProp's ``energy_fn``.

    Args:
        predictor: callable mapping an input batch to predictions. If it has a
            callable ``eval`` attribute (e.g. an ``nn.Module``), it is switched
            to eval mode once, at construction.
        loss_fn: callable reducing the predictions to the energy value.
        **kwargs: accepted for call-site compatibility and ignored.
    """
    def __init__(self,
                 predictor,
                 loss_fn,
                 **kwargs):
        super().__init__()
        self.predictor = predictor
        self.loss_fn = loss_fn

        # Best-effort: plain functions have no eval(), so only call it when it
        # exists. Errors raised *inside* eval() propagate instead of being
        # silently swallowed.
        eval_fn = getattr(self.predictor, 'eval', None)
        if callable(eval_fn):
            eval_fn()

    def forward(self, x):
        """Return ``loss_fn(predictor(x))``."""
        preds = self.predictor(x)
        return self.loss_fn(preds)

In [None]:
#Comparing entropy distributions between Affine and No-Affine
# Runs FastSeqProp twice with the same settings except for the `affine` flag
# (and learning rate), then compares the distributions of shannon_entropy over
# the predictor outputs for sequences sampled from each optimised parameter set.

seed = 0
torch.manual_seed(seed)  # reproducible theta_ini and sampling across re-runs

batch_size = 50
sample_takes = 10
n_samples = 10
num_steps = 300
seq_len = 200
scheduler = False

loss_fn = hepg2_specific
model = artisan_model #hpo_model or artisan_model

energy = mpra_energy(predictor=model,
                     loss_fn=loss_fn)


def run_entropy_trial(affine, learning_rate):
    """Optimise one StraightThroughParameters batch and collect output entropies.

    Args:
        affine: forwarded to ``StraightThroughParameters(affine=...)``.
        learning_rate: FastSeqProp learning rate for this run.

    Returns:
        ``(theta_ini, params, generator, preds, entropies)``: the random initial
        tensor, the optimised parameters, the generator, the predictor output
        from the final sampling take, and the shannon_entropy values from all
        ``sample_takes`` takes flattened into one list.
    """
    theta_ini = torch.randn(batch_size, 4, seq_len)
    params = StraightThroughParameters(data=theta_ini,
                                       left_flank=left_flank,
                                       right_flank=right_flank,
                                       n_samples=n_samples,
                                       affine=affine)
    generator = FastSeqProp(energy_fn=energy,
                            params=params)
    generator.cuda()
    generator.run(steps=num_steps,
                  learning_rate=learning_rate,
                  step_print=20,
                  lr_scheduler=scheduler)

    entropies = []
    preds = None
    # Sampling is inference only: no autograd graph is needed.
    with torch.no_grad():
        for _ in range(sample_takes):
            preds = energy.predictor(params())
            entropies.extend(boda.graph.utils.shannon_entropy(preds).detach().cpu().numpy())
    return theta_ini, params, generator, preds, entropies


#--------------------- Affine ------------------------
_, _, _, _, entropies_affine = run_entropy_trial(affine=True, learning_rate=0.5)

#--------------------- No Affine ------------------------
# The no-affine run's objects stay bound at module level; the following cell
# (`params.rebatch(preds)`) uses `params` and `preds`.
theta_ini, params, generator, preds, entropies_no_affine = run_entropy_trial(affine=False, learning_rate=0.05)

df_1 = pd.DataFrame(entropies_affine, columns=['entropy'])
df_1['type'] = 'Affine'
df_2 = pd.DataFrame(entropies_no_affine, columns=['entropy'])
df_2['type'] = 'No affine'
# ignore_index: give the stacked frame a unique index (avoids duplicate-label
# reindexing problems in seaborn).
df = pd.concat([df_1, df_2], ignore_index=True)

sns.displot(data=df, x='entropy', hue='type', kind='kde', fill=True, height=7, aspect=10/6)
plt.xlim(0, 1.2)
plt.title('Prediction entropy: affine vs. no affine')
plt.show()

In [None]:
# NOTE(review): relies on kernel state from the previous cell — `params` and `preds`
# are whatever its "No Affine" run left bound (that run's StraightThroughParameters
# object and the predictor output from its last sampling take). Run that cell first.
params.rebatch(preds)