In [None]:
import argparse
import os
import pickle
import shutil
import subprocess
import sys
import tarfile

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from mpl_toolkits import mplot3d
from Bio import motifs

import boda
from boda.generator.parameters import StraightThroughParameters
from boda.generator import FastSeqProp, AdaLead
from boda.generator.energy import OverMaxEnergy
from boda.common import constants, utils

boda_src = os.path.join(os.path.dirname(os.path.dirname(os.getcwd())), 'src')
# Put `src/` (two directories above the working directory) first on the import
# path so the `main` / `pymeme` imports below resolve from it. Drop any earlier
# copy first so re-running this cell doesn't keep prepending duplicates.
while boda_src in sys.path:
    sys.path.remove(boda_src)
sys.path.insert(0, boda_src)

from main import unpack_artifact, model_fn
from pymeme import streme, parse_streme_output

In [None]:
#----------------------- ReLU6 model -----------------------
# Load the trained model from a packaged training artifact on GCS.
model_dir = './artifacts'
hpo_rec = 'gs://syrgoth/aip_ui_test/model_artifacts__20211110_194934__672830.tar.gz'

# Clear out any previously unpacked artifacts before unpacking again
# (presumably unpack_artifact extracts into ./artifacts — confirm).
if os.path.isdir(model_dir):
    shutil.rmtree(model_dir)
unpack_artifact(hpo_rec)

model = model_fn(model_dir)
model.eval()

In [None]:
# Encode the 200 nt of MPRA vector sequence on either side of the designed
# insert with dna2tensor, then add a leading batch dimension of size 1.
left_flank = utils.dna2tensor(
    constants.MPRA_UPSTREAM[-200:]
).unsqueeze(0)
right_flank = utils.dna2tensor(
    constants.MPRA_DOWNSTREAM[:200]
).unsqueeze(0)

In [None]:
class BasePenalty(nn.Module):
    """Base class for penalty modules.

    Subclasses override :meth:`penalty` to compute a penalty term from an
    input ``x``.
    """

    def __init__(self):
        super().__init__()

    def penalty(self, x):
        """Compute the penalty for ``x``; must be overridden by subclasses.

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError("Penalty not implemented")

class StremePenalty(BasePenalty):
    """Penalty built from motifs that STREME finds in proposed sequences.

    Each call to :meth:`update_penalty` runs ``streme`` on a batch of proposals,
    converts the first reported motif with ``ppm_to_pwm``, and appends it plus
    its reverse complement to a bank of ``conv1d`` filters (registered as the
    ``penalty_filters`` buffer). Each filter gets a threshold equal to
    ``score_pct`` times the motif's maximum attainable score (the
    ``score_thresholds`` buffer). :meth:`penalty` convolves inputs with the
    bank and sums the scores that reach their threshold.

    Args:
        score_pct: Fraction of a motif's maximum score used as its threshold.

    Note:
        If a ``model`` attribute with a ``.device`` is present (e.g. provided by
        a subclass or mixin), penalty tensors are placed on that device.
        Otherwise they follow the existing filter bank or the input tensor.
        This class does not set ``model`` itself.
    """

    @staticmethod
    def add_penalty_specific_args(parent_parser):
        """Return a parser extending ``parent_parser`` with ``--score_pct``."""
        parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
        group  = parser.add_argument_group('Penalty Module args')
        group.add_argument('--score_pct', type=float, default=0.3)
        return parser

    @staticmethod
    def process_args(grouped_args):
        """Pick this module's argument group out of ``grouped_args``."""
        penalty_args = grouped_args['Penalty Module args']
        return penalty_args

    def __init__(self, score_pct):
        super().__init__()
        self.score_pct = score_pct

    def _target_device(self, fallback):
        """Return ``self.model.device`` when such an attribute exists, else ``fallback``."""
        # getattr with a default works on nn.Module: missing attributes raise
        # AttributeError from Module.__getattr__, which getattr catches.
        model_device = getattr(getattr(self, 'model', None), 'device', None)
        return fallback if model_device is None else model_device

    def register_penalty(self, x):
        """Store ``x`` as the ``penalty_filters`` buffer, registering it on first use."""
        try:
            self.penalty_filters = x.type_as(self.penalty_filters)
        except AttributeError:
            self.register_buffer('penalty_filters', x)

    def register_threshold(self, x):
        """Store ``x`` as the ``score_thresholds`` buffer, registering it on first use."""
        try:
            self.score_thresholds = x.type_as(self.score_thresholds)
        except AttributeError:
            self.register_buffer('score_thresholds', x)

    def streme_penalty(self, streme_output):
        """Add the first motif in ``streme_output`` (and its reverse complement) to the bank.

        Args:
            streme_output: Mapping whose ``'output'`` entry is passed to
                ``parse_streme_output``.
        """
        existing_filters = getattr(self, 'penalty_filters', None)
        # Filters come in (forward, reverse-complement) pairs, so this is the
        # number of motifs already banked; later motifs are scaled up more.
        n_motifs = 0 if existing_filters is None else existing_filters.shape[0] // 2
        penalty_weight = n_motifs + 1

        motif_data = parse_streme_output(streme_output['output'])
        top_ppm = utils.align_to_alphabet(
            motif_data['motif_results'][0]['ppm'],
            motif_data['meta_data']['alphabet'],
            constants.STANDARD_NT,
        )
        top_ppm = torch.tensor(top_ppm).float()
        background = [motif_data['meta_data']['frequencies'][nt]
                      for nt in constants.STANDARD_NT]
        # NOTE(review): ppm_to_pwm is not defined or imported in the visible
        # notebook cells — confirm it is in scope before this method runs.
        top_pwm = ppm_to_pwm(top_ppm, background) * (penalty_weight**0.33)  # (4, L)
        # Best achievable score: the per-position max summed over positions.
        max_score = torch.max(top_pwm, dim=0)[0].sum()
        top_pwm_rc = utils.reverse_complement_onehot(top_pwm)  # (4, L)

        proposed_penalty = torch.stack([top_pwm, top_pwm_rc], dim=0)  # (2, 4, L)
        proposed_thresholds = torch.tensor(2 * [self.score_pct * max_score])  # (2,)

        if existing_filters is None:
            device = self._target_device(proposed_penalty.device)
            penalty_filters = proposed_penalty.to(device)
            score_thresholds = proposed_thresholds.to(device)
        else:
            # Assumes every banked motif has the same width L — TODO confirm.
            penalty_filters = torch.cat(
                [existing_filters, proposed_penalty.to(existing_filters.device)],
                dim=0,
            )  # (2k+2, 4, L)
            score_thresholds = torch.cat(
                [self.score_thresholds,
                 proposed_thresholds.to(self.score_thresholds.device)]
            )  # (2k+2,)

        self.register_penalty(penalty_filters)
        self.register_threshold(score_thresholds)

    def motif_penalty(self, x):
        """Sum of filter scores on ``x`` that reach their thresholds.

        Returns ``0`` while no filters have been registered. Otherwise it
        returns one value per row of ``x``, divided by
        (number of motifs * batch size).
        """
        filters = getattr(self, 'penalty_filters', None)
        if filters is None:
            return 0
        motif_scores = F.conv1d(x, filters)  # (N, 2k, L_out)
        # Broadcast each filter's threshold across batch and positions.
        mask = torch.ge(motif_scores, self.score_thresholds[None, :, None])
        masked_scores = motif_scores * mask.float()
        n_motifs = filters.shape[0] // 2
        return masked_scores.flatten(1).sum(dim=-1).div(n_motifs * x.shape[0])

    def penalty(self, x):
        """Move ``x`` to the penalty device and return :meth:`motif_penalty` of it."""
        filters = getattr(self, 'penalty_filters', None)
        fallback = x.device if filters is None else filters.device
        hook = x.to(self._target_device(fallback))
        return self.motif_penalty(hook)

    def update_penalty(self, proposal):
        """Run STREME on ``proposal['proposals']`` and bank the first motif found.

        Returns:
            dict with the raw STREME output and detached copies of the current
            filters and score thresholds.
        """
        proposals_list = utils.batch2list(proposal['proposals'])
        # w=15 is forwarded to pymeme.streme (presumably the motif width).
        streme_results = streme(proposals_list, w=15)
        self.streme_penalty(streme_results)
        update_summary = {
            'streme_output': streme_results,
            'filters': self.penalty_filters.detach().clone(),
            'score_thresholds': self.score_thresholds.detach().clone()
        }
        return update_summary

In [None]:
# ---- FastSeqProp run configuration ----
bias_cell = 0       # forwarded to OverMaxEnergy

batch_size = 50     # first dimension of theta_ini
n_samples  = 20     # forwarded to StraightThroughParameters
num_steps  = 300    # NOTE(review): not used in this cell
score_pct  = 0.0    # NOTE(review): OverMaxEnergy below receives score_pct=.3, not this value — confirm which is intended

# NOTE(review): `affine_trans` was used below but never defined in this notebook,
# so this cell raised NameError on a fresh kernel. True is assumed (presumably a
# learnable affine in the parameter normalization, as in Fast SeqProp) — confirm
# against the original experiment.
affine_trans = True

energy = OverMaxEnergy(model=model, bias_cell=bias_cell, score_pct=.3)

theta_ini = torch.randn(batch_size, 4, 200)  # (batch_size, 4, 200)
params = StraightThroughParameters(data=theta_ini,
                                   left_flank=left_flank,
                                   right_flank=right_flank,
                                   n_samples=n_samples,
                                   affine=affine_trans)
generator = FastSeqProp(energy_fn=energy,
                        params=params)