In [1]:
import sys
import os
import subprocess
import tarfile
import shutil

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from mpl_toolkits import mplot3d
from Bio import motifs
import pickle

import boda
from boda.generator.parameters import StraightThroughParameters
from boda.generator import FastSeqProp, AdaLead
from boda.generator.plot_tools import matrix_to_dms, ppm_to_IC, ppm_to_pwm
from boda.model.mpra_basset import MPRA_Basset
from boda.common import constants, utils
from boda.generator.energy import BaseEnergy

# Put the `src` directory that sits two levels above the current working
# directory at the front of the import path, so the imports below resolve.
_repo_root = os.path.dirname(os.path.dirname(os.getcwd()))
boda_src = os.path.join(_repo_root, 'src')
sys.path.insert(0, boda_src)

from main import unpack_artifact, model_fn
from pymeme import streme, parse_streme_output

In [2]:
def reverse_complement(in_tensor, alphabet=constants.STANDARD_NT):
    """Return the reverse complement of a nucleotide matrix.

    The rows of ``in_tensor`` are re-ordered with ``align_to_alphabet`` from
    ``alphabet`` to the complemented alphabet, and the result is then flipped
    along dimension 1.

    Args:
        in_tensor: tensor whose rows follow ``alphabet`` -- presumably shaped
            (len(alphabet), length); TODO confirm against callers.
        alphabet: sequence of nucleotide letters (each one of A, C, G, T)
            giving the row order of ``in_tensor``.

    Returns:
        The re-ordered tensor, reversed along dimension 1.

    NOTE(review): ``align_to_alphabet`` is neither defined nor imported in the
    visible cells; it must already exist in the kernel namespace.
    """
    complement_of = dict(zip('AGTC', 'TCAG'))
    complemented_alphabet = [complement_of[base] for base in alphabet]
    permuted = align_to_alphabet(in_tensor, in_alphabet=alphabet,
                                 out_alphabet=complemented_alphabet)
    return torch.flip(permuted, dims=[1])

def show_streme_motifs(parsed_output):
    """Print and plot every motif found in parsed STREME results.

    For each entry of ``parsed_output['motif_results']`` the summary is
    printed, then two figures are shown: the motif and its reverse complement,
    both passed through ``ppm_to_IC`` and drawn with ``matrix_to_dms``.

    Args:
        parsed_output: mapping with a ``'motif_results'`` collection indexable
            by 0..n-1 (each item providing ``'ppm'`` and ``'summary'``) and a
            ``'meta_data'`` mapping holding the ``'alphabet'`` of the PPM rows.
    """
    motifs_found = parsed_output['motif_results']
    streme_alphabet = parsed_output['meta_data']['alphabet']
    for idx in range(len(motifs_found)):
        entry = motifs_found[idx]
        # Re-order the PPM rows from the STREME alphabet to the default one.
        forward_ppm = align_to_alphabet(torch.tensor((entry['ppm'])),
                                        in_alphabet=streme_alphabet)
        reverse_ppm = reverse_complement(forward_ppm)
        print(entry['summary'])
        # Forward strand first, reverse complement second.
        for ppm in (forward_ppm, reverse_ppm):
            matrix_to_dms(ppm_to_IC(ppm), y_max=2)
            plt.show()
        
def fasta_to_input_tensor(file_name, left_flank, right_flank):
    """Read a FASTA file and return its sequences as one flanked batch tensor.

    Each record's sequence lines are joined, encoded with ``utils.dna2tensor``
    and stacked along a new leading dimension. The flanks are tiled once per
    sequence and concatenated on either side along the last dimension.

    Blank lines and surrounding whitespace are ignored. When two records share
    the same header, the later one replaces the earlier one.

    Args:
        file_name: path of the FASTA file to read.
        left_flank: tensor placed before every sequence. It is tiled with
            ``repeat(n, 1, 1)``, so it must be 3-dimensional -- presumably
            (1, channels, flank_length); TODO confirm against callers.
        right_flank: tensor placed after every sequence, same layout.

    Returns:
        ``torch.cat([left, sequences, right], dim=-1)`` with one entry per
        FASTA record.

    Raises:
        ValueError: if sequence data appears before the first ``>`` header, or
            if the file contains no records.
    """
    fasta_dict = {}
    current_id = None
    with open(file_name, 'r') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                # Blank lines carry no data; skipping them also makes a
                # leading blank line harmless.
                continue
            if line.startswith('>'):
                current_id = line.lstrip('>')
                fasta_dict[current_id] = ''
            elif current_id is None:
                raise ValueError(
                    'Sequence data found before the first FASTA header in ' + str(file_name)
                )
            else:
                fasta_dict[current_id] += line
    if not fasta_dict:
        raise ValueError('No FASTA records found in ' + str(file_name))

    sequences = torch.stack(
        [utils.dna2tensor(sequence) for sequence in fasta_dict.values()], dim=0
    )
    n_sequences = sequences.shape[0]
    pieces = [left_flank.repeat(n_sequences, 1, 1),
              sequences,
              right_flank.repeat(n_sequences, 1, 1)]
    return torch.cat(pieces, dim=-1)

def plot3D_activities(activities_tensor, color='blue', fig_size=(15, 10), alpha=0.2, ax_lims=(-2, 8)):
    """Scatter the first three columns of ``activities_tensor`` in 3D.

    A new figure is created. Columns 0, 1 and 2 are plotted on the axes
    labelled 'K562', 'HepG2' and 'SKNSH' respectively, together with red
    guide lines along each coordinate axis and a dashed gray diagonal.

    Args:
        activities_tensor: 2-D tensor with at least three columns.
        color: value forwarded as ``c`` to ``scatter3D``.
        fig_size: figure size forwarded to ``plt.figure``.
        alpha: marker transparency.
        ax_lims: (low, high) extent of the guide lines.
    """
    # One numpy array per plotted column (x, y, z).
    xs, ys, zs = (activities_tensor[:, col].cpu().detach().numpy() for col in range(3))

    plt.figure(figsize=fig_size)
    ax = plt.axes(projection='3d')

    # Red segments through the origin along x, y and z, spanning ax_lims.
    origin = (0, 0)
    axis_segments = (
        (ax_lims, origin, origin),
        (origin, ax_lims, origin),
        (origin, origin, ax_lims),
    )
    for seg_x, seg_y, seg_z in axis_segments:
        ax.plot(seg_x, seg_y, seg_z, 'r')
    # Dashed diagonal where all three coordinates are equal.
    ax.plot(ax_lims, ax_lims, ax_lims, 'gray', linestyle='dashed')

    ax.scatter3D(xs, ys, zs, c=color, alpha=alpha)
    ax.set_xlabel('K562')
    ax.set_ylabel('HepG2')
    ax.set_zlabel('SKNSH')
    ax.view_init(15, -45)
    
def unpickle_logs(log_path):
    """Load the two pickled log files stored under ``log_path``.

    ``log_path`` is used as a plain string prefix: the file names are appended
    directly, so it should end with a path separator.

    Args:
        log_path: prefix of ``sequence_data.pkl`` and ``pmms_list.pkl``.

    Returns:
        Tuple ``(log_df, pmms_list)``: the object read from
        ``sequence_data.pkl`` with ``pd.read_pickle`` and the object
        unpickled from ``pmms_list.pkl``.
    """
    frame_file = log_path + 'sequence_data.pkl'
    pmms_file = log_path + 'pmms_list.pkl'

    # NOTE: unpickling can execute arbitrary code -- only read trusted logs.
    sequence_frame = pd.read_pickle(frame_file)
    with open(pmms_file, 'rb') as handle:
        pmms = pickle.load(handle)
    return sequence_frame, pmms


def create_new_log_folder_in(super_folder):
    """Create the first unused ``log_<k>`` folder inside ``super_folder``.

    Candidate names ``log_0``, ``log_1``, ... are tried in order and the first
    one that does not exist yet is created (missing parent folders included).

    Args:
        super_folder: parent directory, with or without a trailing separator.

    Returns:
        Path of the newly created folder, terminated by ``'/'``.
    """
    log_idx = 0
    # os.path.join places the folder inside super_folder whether or not the
    # caller supplied a trailing separator.
    log_path = os.path.join(super_folder, 'log_' + str(log_idx))
    # exists() rather than isdir(): a regular file holding the candidate name
    # must be skipped too, otherwise makedirs would fail on it.
    while os.path.exists(log_path):
        log_idx += 1
        log_path = os.path.join(super_folder, 'log_' + str(log_idx))
    os.makedirs(log_path)
    return log_path + '/'

def frame_print(string, marker='*', left_space=25):
    """Print ``string`` upper-cased inside a frame made of ``marker``.

    Two blank lines are printed before and after the three-line frame.

    Args:
        string: text to display.
        marker: string used for the frame border and the side markers.
        left_space: number of spaces printed before each frame line.
    """
    indent = ' ' * left_space
    banner = ' '.join([marker, string.upper(), marker])
    # The border repeats the marker once per character of the banner line.
    border = indent + len(banner) * marker
    for line in ('', '', border, indent + banner, border, '', ''):
        print(line)
    
def decor_print(string):
    """Print ``string`` between two 15-dash rules, with a blank line before and after."""
    rule = '-' * 15
    print('')
    print(' '.join((rule, string, rule)))
    print('')

In [3]:
#----------------------- Artisanal model -----------------------
# ! gsutil cp gs://syrgoth/checkpoints/manual_checkpoint_multioutput_lasthidden250_L1Loss_ReLU6_sneak1_double0_ACGT.ckpt ./

# Checkpoint file, expected in the working directory (the gsutil command above
# copies it there).
ARTISAN_CHECKPOINT = 'manual_checkpoint_multioutput_lasthidden250_L1Loss_ReLU6_sneak1_double0_ACGT.ckpt'

artisan_model = MPRA_Basset(extra_hidden_size = 250)
# map_location='cpu' lets the checkpoint load on machines that lack the device
# it was saved from; the model is moved to the GPU explicitly in a later cell.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load(ARTISAN_CHECKPOINT, map_location='cpu')
artisan_model.load_state_dict(checkpoint['state_dict'])
# Last expression of the cell: switches to eval mode and displays the model.
artisan_model.eval()

MPRA_Basset(
  (criterion): MSELoss()
  (last_activation): Tanh()
  (basset_net): Basset(
    (pad1): ConstantPad1d(padding=(9, 9), value=0.0)
    (conv1): Conv1dNorm(
      (conv): Conv1d(4, 300, kernel_size=(19,), stride=(1,))
      (bn_layer): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (pad2): ConstantPad1d(padding=(5, 5), value=0.0)
    (conv2): Conv1dNorm(
      (conv): Conv1d(300, 200, kernel_size=(11,), stride=(1,))
      (bn_layer): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (pad3): ConstantPad1d(padding=(3, 3), value=0.0)
    (conv3): Conv1dNorm(
      (conv): Conv1d(200, 200, kernel_size=(7,), stride=(1,))
      (bn_layer): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (pad4): ConstantPad1d(padding=(1, 1), value=0.0)
    (maxpool_3): MaxPool1d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
    (maxpool_4): MaxPool1d(ker

In [4]:
# Flanks placed around every generated sequence: the last FLANK_LENGTH
# characters of the upstream constant and the first FLANK_LENGTH characters of
# the downstream constant, each with a leading dimension of size 1 added.
FLANK_LENGTH = 200
left_flank = utils.dna2tensor(constants.MPRA_UPSTREAM[-FLANK_LENGTH:]).unsqueeze(0)
right_flank = utils.dna2tensor(constants.MPRA_DOWNSTREAM[:FLANK_LENGTH]).unsqueeze(0)

In [15]:
class OverMaxEnergy(BaseEnergy):
    """Energy equal to the largest non-bias output minus the scaled bias output.

    Args:
        model: callable model exposing ``eval()`` and a ``device`` attribute;
            it is put in eval mode on construction.
        bias_cell: index, along the last output dimension, of the output that
            is subtracted.
        bias_alpha: factor applied to the bias output before subtraction.
    """
    def __init__(self, model, bias_cell=0, bias_alpha=1.):
        super().__init__()

        self.model = model
        self.model.eval()

        self.bias_cell = bias_cell
        self.bias_alpha = bias_alpha

    def forward(self, x):
        """Return ``max(other outputs) - bias_alpha * output[bias_cell]`` for ``x``.

        ``x`` is moved to the model's device and passed through the model; the
        reduction runs over the last dimension of the model output.
        """
        hook = x.to(self.model.device)

        hook = self.model(hook)

        # Every index of the last output dimension except the bias cell.
        other_cells = [idx for idx in range(hook.shape[-1]) if idx != self.bias_cell]
        return hook[..., other_cells].max(-1).values \
                 - hook[..., self.bias_cell].mul(self.bias_alpha)

    def register_penalty(self, x):
        """Store ``x`` as the ``penalty`` attribute.

        When ``penalty`` already exists, ``x`` is cast to its type; otherwise
        it is registered as a buffer named ``'penalty'``.
        """
        try:
            self.penalty = x.type_as(self.penalty)
        except AttributeError:
            # No penalty registered yet: create the buffer.
            self.register_buffer('penalty', x)

    def streme_penalty(self, streme_output):
        """Parse STREME results and return the PPM of the first motif.

        Args:
            streme_output: mapping whose ``'output'`` entry is handed to
                ``parse_streme_output``.

        Returns:
            The ``'ppm'`` entry of the first item of ``'motif_results'``.

        NOTE(review): no penalty is built or registered from the PPM yet;
        this method looks unfinished.
        """
        # Read the argument that was passed in, not a notebook-level global.
        motif_data = parse_streme_output(streme_output['output'])
        top_ppm    = motif_data['motif_results'][0]['ppm']
        return top_ppm
       
    
    
class AvgDiffEnergy(BaseEnergy):
    """Energy container recording a bias cell and the remaining silenced cells.

    Args:
        model: model to wrap; it is switched to eval mode when it offers a
            callable ``eval`` attribute.
        bias_cell: one of 0, 1 or 2; the remaining two indices are stored in
            ``silenced_cells``.
        bending: flag stored as ``self.bending``.

    Raises:
        ValueError: if ``bias_cell`` is not one of 0, 1, 2.

    NOTE(review): no ``forward`` is defined in this class body; unless
    ``BaseEnergy`` supplies one, the energy cannot be evaluated yet.
    """
    def __init__(self, model, bias_cell=0, bending=False):
        super().__init__()

        self.model = model
        # Models without an eval() method are accepted as-is, but an error
        # raised inside eval() itself is no longer silently swallowed.
        switch_to_eval = getattr(self.model, 'eval', None)
        if callable(switch_to_eval):
            switch_to_eval()

        self.bias_cell = bias_cell
        # NOTE(review): this stores numpy's index helper itself, not an index
        # array -- probably an unfinished line; confirm the intended value.
        self.silent_cells = np.r_
        self.bending = bending

        self.silenced_cells = [0,1,2]
        if self.bias_cell not in self.silenced_cells:
            raise ValueError(
                'bias_cell must be one of [0, 1, 2], got ' + repr(self.bias_cell)
            )
        self.silenced_cells.remove(self.bias_cell)


In [16]:
# Build the energy around the loaded model, with output 0 as the bias cell.
energy = AvgDiffEnergy(model=artisan_model, bias_cell=0)

In [17]:
# Display the indices left in silenced_cells after removing the bias cell.
energy.silenced_cells

[1, 2]

In [13]:
# Scratch check: drop index 1 from (0, 1, 2) and turn what is left into a
# numpy index array via np.r_.
silenced_cells = np.r_[[cell for cell in (0, 1, 2) if cell != 1]]
print(silenced_cells)

[0 2]


In [18]:
# Select the columns of `x` listed in energy.silenced_cells.
# NOTE(review): `x` is not defined in any visible cell, so this relies on
# leftover kernel state and will fail after Restart & Run All -- confirm
# where `x` should come from.
x[:, np.r_[energy.silenced_cells]]

tensor([[-2.0885, -1.4613],
        [-0.2209, -0.2344],
        [-1.3188, -0.7413],
        [-0.9203,  1.6975],
        [ 0.6752,  0.7456]])

In [26]:
artisan_model.cuda()

# Run configuration.
batch_size  = 50
n_samples  = 20
num_steps  = 300

affine_trans = False
scheduler    = True
loss_plots   = False

# Only one energy is needed. An OverMaxEnergy built here was overwritten by
# the next statement before it was ever used, so that dead assignment is gone.
energy = AvgDiffEnergy(model=artisan_model,
                       bias_cell=0)

# Random initial parameters, one (4, 200) matrix per batch element.
# NOTE(review): no RNG seed is set, so theta_ini differs between runs.
theta_ini = torch.randn(batch_size, 4, 200)
params = StraightThroughParameters(data=theta_ini,
                                   left_flank=left_flank,
                                   right_flank=right_flank,
                                   n_samples=n_samples,
                                   affine=affine_trans)
generator = FastSeqProp(energy_fn=energy,
                        params=params)
generator.cuda()
# The optimisation run is left disabled; num_steps, scheduler and loss_plots
# above are only used by this call.
# generator.run(steps=num_steps,
#               learning_rate=0.5,
#               step_print=5,
#               lr_scheduler=scheduler,
#               create_plot=loss_plots)

FastSeqProp(
  (energy_fn): AvgDiffEnergy(
    (model): MPRA_Basset(
      (criterion): MSELoss()
      (last_activation): Tanh()
      (basset_net): Basset(
        (pad1): ConstantPad1d(padding=(9, 9), value=0.0)
        (conv1): Conv1dNorm(
          (conv): Conv1d(4, 300, kernel_size=(19,), stride=(1,))
          (bn_layer): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (pad2): ConstantPad1d(padding=(5, 5), value=0.0)
        (conv2): Conv1dNorm(
          (conv): Conv1d(300, 200, kernel_size=(11,), stride=(1,))
          (bn_layer): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (pad3): ConstantPad1d(padding=(3, 3), value=0.0)
        (conv3): Conv1dNorm(
          (conv): Conv1d(200, 200, kernel_size=(7,), stride=(1,))
          (bn_layer): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (pad4): ConstantPad1d(padding=(1, 1), v