In [1]:
import sys
import os
import subprocess
import tarfile
import shutil
import random
from functools import partial
from tqdm import tqdm
from tqdm.auto import tqdm
tqdm.pandas()

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import (random_split, DataLoader, TensorDataset, ConcatDataset)
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from mpl_toolkits import mplot3d
from Bio import motifs
import pickle
from datetime import datetime
import scipy.stats as stats
import math

import boda
from boda.common import constants, utils

boda_src = os.path.join( os.path.dirname( os.path.dirname( os.getcwd() ) ), 'src' )
sys.path.insert(0, boda_src)

from main import unpack_artifact, model_fn
from pymeme import streme, parse_streme_output

from torch.distributions.categorical import Categorical
from boda.generator.plot_tools import matrix_to_dms, ppm_to_IC, ppm_to_pwm, counts_to_ppm

from IPython.core.display import display, HTML
display(HTML("<style>.container { width:98% !important; }</style>"))

### Batch example

In [11]:
def dna2tensor_approx(sequence_str, vocab_list=constants.STANDARD_NT):
    seq_tensor = np.zeros((len(vocab_list), len(sequence_str)))
    for letterIdx, letter in enumerate(sequence_str):
        try:
            seq_tensor[vocab_list.index(letter), letterIdx] = 1
        except:
            seq_tensor[:, letterIdx] = 0.25
    seq_tensor = torch.Tensor(seq_tensor)
    return seq_tensor

def df_to_onehot_tensor(in_df, seq_column='nt_sequence'):
    onehot_sequences = torch.stack([dna2tensor_approx(subsequence) \
                                for subsequence in tqdm(in_df[seq_column])])
    return onehot_sequences

In [5]:
chunk_idx = 1
chunk_name = 'chunk_' + str(chunk_idx).zfill(2)
print(chunk_name)
chunk_path = 'df_chunks/' + chunk_name + '.txt'

line_dict = {} 
with open(chunk_path, 'r') as f:
    for line in f:
        ID, sequence = line.lstrip('>::').rstrip('\n').split('\t')
        line_dict[ID] = sequence.upper()
temp_df = pd.DataFrame(line_dict.items(), columns=['ID', 'nt_sequence'])
temp_df['seq_len'] = temp_df.apply(lambda x: len(x['nt_sequence']), axis=1)

chunk_01


In [31]:
example_batch = df_to_onehot_tensor(temp_df[-20:])

  0%|          | 0/20 [00:00<?, ?it/s]

In [32]:
example_batch.shape

torch.Size([20, 4, 200])

# Function drafts

In [15]:
class mpra_predictor(nn.Module):
    def __init__(self,
                 model,
                 pred_idx=0,
                 ini_in_len=200,
                 model_in_len=600,
                 cat_axis=-1,
                 dual_pred=False):
        super().__init__()
        self.model = model
        self.pred_idx = pred_idx
        self.ini_in_len = ini_in_len 
        self.model_in_len = model_in_len
        self.cat_axis = cat_axis  
        self.dual_pred = dual_pred
        
        try: self.model.eval()
        except: pass
        
        self.register_flanks()
    
    def forward(self, x):
        pieces = [self.left_flank.repeat(x.shape[0], 1, 1), x, self.right_flank.repeat(x.shape[0], 1, 1)]
        in_tensor = torch.cat( pieces, axis=self.cat_axis)
        if self.dual_pred:
            dual_tensor = utils.reverse_complement_onehot(in_tensor)
            out_tensor = self.model(in_tensor)[:, self.pred_idx] + self.model(dual_tensor)[:, self.pred_idx]
            out_tensor = out_tensor / 2.0
        else:
            out_tensor = self.model(in_tensor)[:, self.pred_idx]
        return out_tensor
    
    def register_flanks(self):
        missing_len = self.model_in_len - self.ini_in_len
        left_idx = - missing_len//2 + missing_len%2
        right_idx = missing_len//2 + missing_len%2
        left_flank = utils.dna2tensor(constants.MPRA_UPSTREAM[left_idx:]).unsqueeze(0)
        right_flank = utils.dna2tensor(constants.MPRA_DOWNSTREAM[:right_idx]).unsqueeze(0)         
        self.register_buffer('left_flank', left_flank)
        self.register_buffer('right_flank', right_flank) 
        

def isg_contributions(sequences,
                      predictor,
                      num_steps=50,
                      num_samples=20,
                      eval_batch_size=1024,
                      theta_factor=15):
    
    batch_size = eval_batch_size // num_samples
    temp_dataset = TensorDataset(sequences)
    temp_dataloader = DataLoader(temp_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    all_salient_maps = []
    all_gradients = []
    for local_batch in tqdm(temp_dataloader):
        target_thetas = (theta_factor * local_batch[0].cuda()).requires_grad_()
        line_gradients = []
        for i in range(0, num_steps + 1):
            point_thetas = (i / num_steps * target_thetas)
            point_distributions = F.softmax(point_thetas, dim=-2)

            nucleotide_probs = Categorical(torch.transpose(point_distributions, -2, -1))
            sampled_idxs = nucleotide_probs.sample((num_samples, ))
            sampled_nucleotides_T = F.one_hot(sampled_idxs, num_classes=4)
            sampled_nucleotides = torch.transpose(sampled_nucleotides_T, -2, -1)
            distribution_repeater = point_distributions.repeat(num_samples, *[1 for i in range(3)])
            sampled_nucleotides = sampled_nucleotides - distribution_repeater.detach() + distribution_repeater 
            samples = sampled_nucleotides.flatten(0,1)

            preds = predictor(samples)
            point_predictions = preds.unflatten(0, (num_samples, target_thetas.shape[0])).mean(dim=0)
            point_gradients = torch.autograd.grad(point_predictions.sum(), inputs=point_thetas, retain_graph=True)[0]
            line_gradients.append(point_gradients)
            
        gradients = torch.stack(line_gradients).mean(dim=0).detach()
        all_salient_maps.append(gradients * target_thetas.detach())
        all_gradients.append(gradients)
    return theta_factor * torch.cat(all_gradients).cpu()
    # return torch.cat(all_salient_maps).cpu(), theta_factor * torch.cat(all_gradients).cpu()


def batch_to_contributions(onehot_sequences,
                           model,
                           model_output_len=3,
                           seq_len = 200,
                           eval_batch_size=1040):
    
    extended_contributions = []
    for i in range(model_output_len):
        predictor = mpra_predictor(model=model, pred_idx=i, ini_in_len=seq_len).cuda()
        extended_contributions.append(isg_contributions(onehot_sequences, predictor, eval_batch_size=eval_batch_size))
        
    return torch.stack(extended_contributions)  

### Mock run        

In [14]:
# if os.path.isdir('./artifacts'):
#     shutil.rmtree('./artifacts')
# hpo_rec = 'gs://syrgoth/aip_ui_test/model_artifacts__20211113_021200__287348.tar.gz'
# unpack_artifact(hpo_rec)

model_dir = './artifacts'
model = model_fn(model_dir)
#model.cuda()
model.eval()
print('')

Loaded model from 20211113_021200 in eval mode



In [33]:
out_contributions = batch_to_contributions(onehot_sequences=example_batch,
                                           model=model,
                                           eval_batch_size=1024)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

In [34]:
out_contributions

tensor([[[[ 1.3563e-02,  1.1861e-02,  4.3333e-03,  ..., -1.7194e-02,
           -2.7673e-02, -1.6030e-03],
          [ 1.7821e-03, -1.6564e-02, -1.2404e-02,  ...,  1.0409e-02,
            5.4094e-02, -3.4538e-02],
          [-2.2099e-02,  7.3780e-04, -6.5452e-03,  ...,  2.1037e-02,
           -4.0910e-03, -1.7237e-02],
          [ 6.7543e-03,  3.9654e-03,  1.4616e-02,  ..., -1.4251e-02,
           -2.2330e-02,  5.3378e-02]],

         [[-2.6193e-02,  5.1062e-02,  3.4286e-02,  ..., -1.4842e-02,
           -2.1771e-02, -6.2029e-03],
          [-9.9690e-03, -2.9622e-03, -1.9139e-02,  ...,  1.1137e-02,
            2.3121e-02, -2.2516e-02],
          [-6.0820e-02, -3.1922e-02,  1.0546e-02,  ...,  9.2081e-03,
           -2.8667e-04,  4.1596e-02],
          [ 9.6982e-02, -1.6178e-02, -2.5692e-02,  ..., -5.5031e-03,
           -1.0633e-03, -1.2877e-02]],

         [[-1.2120e-02,  6.7470e-02,  4.3175e-02,  ..., -2.4565e-02,
            2.6996e-05,  2.3800e-02],
          [-5.3221e-03, -2.2707e-