In [None]:
import sys
import os
import subprocess
import tarfile
import shutil

import numpy as np
import torch
import torch.nn as nn
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from mpl_toolkits import mplot3d
from Bio import motifs

import boda
from boda.generator.parameters import StraightThroughParameters
from boda.generator import FastSeqProp
from boda.generator.plot_tools import matrix_to_dms, ppm_to_IC
from boda.model.mpra_basset import MPRA_Basset
from boda.common import constants

# Make the repository's `src/` directory (two levels above the working
# directory, plus `src`) importable; the `main` / `pymeme` imports that follow
# are presumably resolved from there.
# Any existing copy of the entry is dropped first so that re-running this cell
# keeps exactly one entry, at highest priority, instead of piling up duplicates.
boda_src = os.path.join( os.path.dirname( os.path.dirname( os.getcwd() ) ), 'src' )
sys.path[:] = [entry for entry in sys.path if entry != boda_src]
sys.path.insert(0, boda_src)

from main import unpack_artifact, model_fn
from pymeme import streme, parse_streme_output

In [None]:
#----------------------- HPO model -----------------------
# Start from a clean local artifact directory, unpack the HPO run's artifact,
# then build the model from that directory and switch it to eval mode.
# NOTE(review): assumes unpack_artifact extracts into ./artifacts — confirm.
model_dir = './artifacts'
hpo_rec = 'gs://syrgoth/aip_ui_test/model_artifacts__20210623_102310__205717.tar.gz'

if os.path.isdir(model_dir):
    shutil.rmtree(model_dir)
unpack_artifact(hpo_rec)

hpo_model = model_fn(model_dir)
hpo_model.eval()

In [None]:
#----------------------- Artisanal model -----------------------
! gsutil cp gs://syrgoth/checkpoints/manual_checkpoint_multioutput_lasthidden250_L1.ckpt ./

artisan_model = MPRA_Basset(extra_hidden_size = 250)
checkpoint = torch.load('manual_checkpoint_multioutput_lasthidden250_L1.ckpt')
artisan_model.load_state_dict(checkpoint['state_dict'])
artisan_model.eval()

In [None]:
# Fixed context around the 200-position designed region: the last 200 bases of
# the MPRA upstream construct and the first 200 bases of the downstream one,
# encoded with dna2tensor and given a leading size-1 (batch) axis.
upstream_context = boda.common.constants.MPRA_UPSTREAM[-200:]
downstream_context = boda.common.constants.MPRA_DOWNSTREAM[:200]

left_flank = boda.common.utils.dna2tensor(upstream_context).unsqueeze(0)
right_flank = boda.common.utils.dna2tensor(downstream_context).unsqueeze(0)

In [None]:
def castro_reward(x):
    """Element-wise ``exp(-x) - x - 1``.

    The value is 0 at ``x == 0``. The function is strictly decreasing (its
    derivative ``-exp(-x) - 1`` is always negative), so it is positive for
    ``x < 0`` and negative for ``x > 0``.
    """
    return torch.exp(-x) - x - 1


def basic_reward(x):
    """Identity reward: returns ``x`` itself, unchanged."""
    return x


def _specificity_score(x, target):
    """Score rows of ``x`` for activity in column ``target`` relative to the rest.

    Computes ``-castro_reward(x[:, target])`` plus the mean of
    ``castro_reward`` over every other column. Because ``castro_reward`` is
    decreasing, the score rises with the target column and falls as the
    off-target columns rise.

    Args:
        x: 2-D tensor of predictions, one column per cell type (column order
            assumed to be K562, HepG2, SKNSH — TODO confirm against the model).
        target: index of the column to favour.

    Returns:
        1-D tensor with one score per row of ``x``.

    Raises:
        ValueError: if ``x`` has no off-target column to compare against.
    """
    off_targets = [col for col in range(x.shape[1]) if col != target]
    if not off_targets:
        raise ValueError('specificity needs at least one off-target column')
    off_target_sum = sum(castro_reward(x[:, col]) for col in off_targets)
    return -castro_reward(x[:, target]) + off_target_sum / len(off_targets)


def k562_score(x):
    """Per-row K562 specificity score (column 0 against all other columns)."""
    return _specificity_score(x, target=0)


def k562_specific(x):
    """Scalar loss for K562-specific design: mean of the negated ``k562_score``."""
    return torch.mean(-k562_score(x))


def hepg2_score(x):
    """Per-row HepG2 specificity score (column 1 against all other columns)."""
    return _specificity_score(x, target=1)


def hepg2_specific(x):
    """Scalar loss for HepG2-specific design: mean of the negated ``hepg2_score``."""
    return torch.mean(-hepg2_score(x))

class mpra_energy(nn.Module):
    """Energy used by FastSeqProp: ``loss_fn(predictor(x))``.

    When ``predictor`` is an ``nn.Module``, assigning it as an attribute
    registers it as a submodule, so ``.cuda()`` / ``.to()`` on this module
    also moves the predictor.

    Args:
        predictor: callable mapping a batch of sequences to predictions. If it
            exposes ``.eval()``, it is switched to eval mode at construction.
        loss_fn: callable reducing the predictions to the energy to minimise.
        **kwargs: accepted for call-site compatibility and ignored.
    """
    def __init__(self,
                 predictor,
                 loss_fn,
                 **kwargs):
        super().__init__()
        self.predictor = predictor
        self.loss_fn = loss_fn

        # Plain callables (e.g. lambdas) have no .eval(); only that case is
        # skipped. Errors raised inside a real .eval() propagate instead of
        # being silently swallowed.
        eval_fn = getattr(self.predictor, 'eval', None)
        if callable(eval_fn):
            eval_fn()

    def forward(self, x):
        """Return ``loss_fn(predictor(x))``."""
        preds = self.predictor(x)
        return self.loss_fn(preds)

In [None]:
#Comparing entropy distributions between Affine and No-Affine
# Each variant is optimised from a fresh random initialisation and then sampled
# `sample_takes` times; the log of the per-sample Shannon entropy of the model
# predictions is collected, and the two distributions are plotted together.

batch_size = 50
sample_takes = 10
n_samples = 10
num_steps = 300
scheduler = False

loss_fn = hepg2_specific
model = hpo_model #hpo_model or artisan_model

energy = mpra_energy(predictor=model,
                     loss_fn=loss_fn)


def run_and_sample_log_entropies(affine, learning_rate):
    """Optimise one random batch with FastSeqProp and sample log-entropies.

    Reads the notebook-level settings above (batch_size, n_samples, num_steps,
    scheduler, sample_takes), the shared ``energy`` and the global flanks.

    Args:
        affine: forwarded to ``StraightThroughParameters(affine=...)``.
        learning_rate: learning rate passed to ``FastSeqProp.run``.

    Returns:
        Tuple ``(params, generator, log_entropies)``: the optimised parameters,
        the generator that optimised them, and a flat list holding
        ``log(shannon_entropy(preds))`` for every sample of every take.
    """
    theta_ini = torch.randn(batch_size, 4, 200)
    run_params = StraightThroughParameters(data=theta_ini,
                                           left_flank=left_flank,
                                           right_flank=right_flank,
                                           n_samples=n_samples,
                                           affine=affine)
    run_generator = FastSeqProp(energy_fn=energy,
                                params=run_params)
    run_generator.cuda()
    run_generator.run(steps=num_steps,
                      learning_rate=learning_rate,
                      step_print=20,
                      lr_scheduler=scheduler)

    log_entropies = []
    # Sampling only needs forward passes, so skip building an autograd graph.
    with torch.no_grad():
        for _ in range(sample_takes):
            preds = energy.predictor(run_params())
            entropy = boda.graph.utils.shannon_entropy(preds)
            log_entropies += list(entropy.detach().log().cpu().numpy())
    return run_params, run_generator, log_entropies


# NOTE(review): the two variants use different learning rates (0.5 vs 0.05),
# which confounds the comparison — confirm this is intended.
#--------------------- Affine ------------------------
_, _, entropies_affine = run_and_sample_log_entropies(affine=True, learning_rate=0.5)

#--------------------- No Affine ------------------------
# The no-affine `params` stays a global because the following cell reads it.
params, generator, entropies_no_affine = run_and_sample_log_entropies(affine=False, learning_rate=0.05)

df_1 = pd.DataFrame(entropies_affine, columns=['entropy'])
df_1['type'] = 'Affine'
df_2 = pd.DataFrame(entropies_no_affine, columns=['entropy'])
df_2['type'] = 'No affine'
df = pd.concat([df_1, df_2], ignore_index=True)

sns.displot(data=df, x='entropy', hue='type', kind='kde', fill=True, height=7, aspect=10/6)
# The plotted values are log-entropies, not raw entropies.
plt.xlabel('log(entropy)')
plt.show()

In [None]:
# Average prediction per designed sequence over its n_samples draws
# (assumes params.preds is flattened sample-major — TODO confirm).
torch.mean(params.preds.unflatten(0, (n_samples, batch_size)), dim=0)

In [None]:
### Generating multiple mini-batches

In [None]:
%%time
#------------------ Choose settings ------------------
affine_trans = False
iterations = 20
batch_size = 50
#sample_takes = 10
n_samples = 20
num_steps = 300
scheduler = True
loss_plots = False

loss_fn = k562_specific
model = artisan_model     #hpo_model or artisan_model

#------------------ Optimization run ------------------
energy = mpra_energy(predictor=model,
                     loss_fn=loss_fn)

distributions = []
sequence_samples = []
predictions = []
entropies = []
for iteration in range(iterations):
    theta_ini = torch.randn(batch_size, 4, 200)
    params = StraightThroughParameters(data=theta_ini,
                                       left_flank=left_flank,
                                       right_flank=right_flank,
                                       n_samples=n_samples,
                                       affine=affine_trans)
    generator = FastSeqProp(energy_fn=energy,
                            params=params)
    generator.cuda()
    generator.run(steps=num_steps,
                  learning_rate=0.5,
                  step_print=20,
                  lr_scheduler=scheduler,
                  create_plot=loss_plots)
       
    samples = params()
    preds = energy.predictor(samples)
    
    distributions.append(params.get_probs().detach().cpu())    
    sequence_samples.append(samples.detach().cpu().unflatten(0, (n_samples, batch_size)))
    predictions.append(preds.detach().cpu().unflatten(0, (n_samples, batch_size)))
    entropies.append(boda.graph.utils.shannon_entropy(preds).detach().cpu().unflatten(0, (n_samples, batch_size)))

entropy_tensor = torch.cat(entropies, dim=1)
prediction_tensor = torch.cat(predictions, dim=1)
sequences_tensor = torch.cat(sequence_samples, dim=1)
distributions_tensor = torch.cat(distributions, dim=0)

In [None]:
#------------------ Select best sequences ------------------
# For each designed sequence (dim 1), keep the sample (dim 0) whose prediction
# entropy is lowest.
FLANK_LEN = 200   # left flank length (MPRA_UPSTREAM[-200:])
DESIGN_LEN = 200  # optimised region length (last dim of theta_ini)

best_entropy_idxs = torch.argmin(entropy_tensor, dim=0)
seq_idxs = torch.arange(entropy_tensor.shape[1])

# Paired advanced indexing selects (best_entropy_idxs[i], i) for every i at once.
best_entropies = entropy_tensor[best_entropy_idxs, seq_idxs]
best_predictions = prediction_tensor[best_entropy_idxs, seq_idxs]
# Drop the flanks so only the designed region is kept.
best_sequences = sequences_tensor[best_entropy_idxs, seq_idxs, :, FLANK_LEN:FLANK_LEN + DESIGN_LEN]

#------------------ Plot entropy distribution ------------------
sns.displot(data=best_entropies.numpy(), kind='kde', fill=True, height=5, aspect=10/6)
plt.xlim(0, 1)
plt.xlabel('Entropy')
plt.show()

#------------------ Plot activities in 3D ------------------
# Prediction columns assumed to be K562, HepG2, SKNSH (matching the axis labels).
activities = best_predictions.numpy()
xdata = activities[:, 0]
ydata = activities[:, 1]
zdata = activities[:, 2]

fig = plt.figure(figsize=(15, 10))
ax = plt.axes(projection='3d')

# Red reference lines along each coordinate axis, spanning -2..8.
axis_span = (-2, 8)
for axis in range(3):
    line = [(0, 0), (0, 0), (0, 0)]
    line[axis] = axis_span
    ax.plot(line[0], line[1], line[2], 'r')

ax.scatter3D(xdata, ydata, zdata, c='blue')
ax.set_xlabel("K562")
ax.set_ylabel("HepG2")
ax.set_zlabel("SKNSH")
ax.view_init(15, -45)
plt.show()

In [None]:
# Export the selected designed regions for the STREME analysis below
# (presumably written as FASTA, judging by the helper's name).
# NOTE(review): `batch2fasta` is neither imported nor defined in this notebook,
# so this cell raises NameError on a fresh kernel — import it from its home
# module (TODO confirm which) in the imports cell.
file_name = 'k562_1000dist_castro.txt'
batch2fasta(best_sequences, file_name)

In [None]:
## Analyse with STREME

In [None]:
file_name = 'k562_1000dist_castro.txt'
# Resolve against the working directory instead of a machine-specific absolute
# path; assumes the export cell wrote `file_name` relative to the cwd.
test_seq_file = os.path.abspath(file_name)

streme_results = streme(test_seq_file)
# Same text as printing each line of the split output individually.
print(streme_results['output'].decode("utf-8"))

In [None]:
# Parse the raw STREME output; motif PPMs are read from the result below.
raw_streme_output = streme_results['output']
parsed_output = parse_streme_output(raw_streme_output)
print(parsed_output)

In [None]:
# Information-content matrix for the first STREME motif, built from its
# position probability matrix.
first_motif_ppm = torch.tensor(parsed_output['motif_results'][0]['ppm'])
temp_tensor = ppm_to_IC(first_motif_ppm)

In [None]:
# temp_tensor already holds ppm_to_IC output; converting it again would treat
# information-content heights as probabilities. y_max=2 matches the 2-bit
# maximum information content of a DNA position.
matrix_to_dms(temp_tensor, y_max=2)