In [1]:
import sys
import os
import subprocess
import tarfile
import shutil
import random
from functools import partial
from tqdm import tqdm
from tqdm.auto import tqdm
tqdm.pandas()

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import (random_split, DataLoader, TensorDataset, ConcatDataset)
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from mpl_toolkits import mplot3d
from Bio import motifs
import pickle
from datetime import datetime
import scipy.stats as stats
import math

import boda
from boda.common import constants, utils

boda_src = os.path.join( os.path.dirname( os.path.dirname( os.getcwd() ) ), 'src' )
sys.path.insert(0, boda_src)

from main import unpack_artifact, model_fn
from pymeme import streme, parse_streme_output

from torch.distributions.categorical import Categorical
from boda.generator.plot_tools import matrix_to_dms, ppm_to_IC, ppm_to_pwm, counts_to_ppm

from IPython.core.display import display, HTML
display(HTML("<style>.container { width:98% !important; }</style>"))

In [2]:
# Location the model is loaded from, and the remote archive holding the artifacts.
# NOTE(review): unpack_artifact presumably extracts into ./artifacts - confirm.
model_dir = './artifacts'
hpo_rec = 'gs://syrgoth/aip_ui_test/model_artifacts__20211113_021200__287348.tar.gz'

# Remove any artifacts left over from a previous run before unpacking again.
if os.path.isdir(model_dir):
    shutil.rmtree(model_dir)
unpack_artifact(hpo_rec)

model = model_fn(model_dir)
#model.cuda()
model.eval()
# Printing an empty line as the final statement keeps the cell output short.
print('')

Loaded model from 20211113_021200 in eval mode



archive unpacked in ./


In [3]:
class mpra_predictor(nn.Module):
    """Pad inputs with constant MPRA flanks and return one output column of a model.

    The wrapped ``model`` is called on ``left_flank + x + right_flank``
    (concatenated along ``cat_axis``) and column ``pred_idx`` of its output is
    returned.

    Args:
        model: Callable applied to the padded input. If it exposes a callable
            ``eval`` attribute it is switched to eval mode on construction.
        pred_idx: Index selected on dimension 1 of the model output.
        ini_in_len: Length of the un-padded input along the concatenation axis.
        model_in_len: Length the model expects after padding.
        cat_axis: Axis along which flanks and input are concatenated.

    Raises:
        ValueError: If ``ini_in_len`` is larger than ``model_in_len``.
    """

    def __init__(self,
                 model,
                 pred_idx=0,
                 ini_in_len=200,
                 model_in_len=600,
                 cat_axis=-1):
        super().__init__()
        self.model = model
        self.pred_idx = pred_idx
        self.ini_in_len = ini_in_len
        self.model_in_len = model_in_len
        self.cat_axis = cat_axis

        # Best effort: objects without an ``eval`` method (e.g. plain callables)
        # are accepted as-is, but genuine errors are no longer silently swallowed.
        set_eval = getattr(self.model, 'eval', None)
        if callable(set_eval):
            set_eval()

        self.register_flanks()

    def forward(self, x):
        """Return ``model(left_flank + x + right_flank)[:, pred_idx]``.

        ``x`` must be 3-D because the flanks are broadcast as
        ``(x.shape[0], -1, -1)`` before concatenation.
        """
        batch = x.shape[0]
        # expand() creates views; torch.cat materialises the result anyway.
        pieces = [self.left_flank.expand(batch, -1, -1),
                  x,
                  self.right_flank.expand(batch, -1, -1)]
        in_tensor = torch.cat(pieces, dim=self.cat_axis)
        return self.model(in_tensor)[:, self.pred_idx]

    def register_flanks(self):
        """Build the flank tensors and register them as buffers.

        The missing length is split as ``missing_len // 2`` on the left and the
        remainder on the right, so the right flank gets the extra position when
        the missing length is odd.
        """
        missing_len = self.model_in_len - self.ini_in_len
        if missing_len < 0:
            raise ValueError(
                f'ini_in_len ({self.ini_in_len}) must not exceed '
                f'model_in_len ({self.model_in_len})'
            )
        left_len = missing_len // 2
        right_len = missing_len - left_len

        # A negative-index slice cannot express "take nothing": [-0:] is [0:],
        # i.e. the whole flank. Handle the zero-length case explicitly.
        if left_len > 0:
            upstream = constants.MPRA_UPSTREAM[-left_len:]
        else:
            upstream = constants.MPRA_UPSTREAM[:0]
        downstream = constants.MPRA_DOWNSTREAM[:right_len]

        left_flank = utils.dna2tensor(upstream).unsqueeze(0)
        right_flank = utils.dna2tensor(downstream).unsqueeze(0)
        self.register_buffer('left_flank', left_flank)
        self.register_buffer('right_flank', right_flank)

def isg_contributions(sequences,
                      predictor,
                      num_steps=50,
                      num_samples=20,
                      eval_batch_size=1024,
                      theta_factor=15,
                      num_workers=0):
    """Average sampled gradients along a straight path in logit space.

    For each batch, logits ``theta_factor * sequences`` are scaled by
    ``i / num_steps`` for ``i = 0..num_steps``. At every step the logits are
    soft-maxed over dim ``-2``, ``num_samples`` one-hot sequences are drawn
    per input, and the gradient of the sample-averaged prediction with respect
    to the step's logits is taken through a straight-through estimator.
    Gradients are then averaged over all steps.

    Args:
        sequences: Tensor whose dim ``-2`` is the category axis; expected to be
            3-D ``(N, C, L)`` as produced by stacking one-hot sequences.
        predictor: Callable mapping a ``(M, C, L)`` tensor to ``M`` predictions.
            If it is an ``nn.Module``, batches are moved to its device.
        num_steps: Number of path segments (``num_steps + 1`` points are used).
        num_samples: Sequences sampled per input at every path point.
        eval_batch_size: Approximate number of sampled sequences per predictor
            call; the DataLoader batch size is ``eval_batch_size // num_samples``
            (at least 1).
        theta_factor: Scale applied to ``sequences`` to obtain target logits.
        num_workers: Forwarded to the DataLoader.

    Returns:
        Tuple of CPU tensors shaped like ``sequences``:
        ``(mean_gradient * target_logits, theta_factor * mean_gradient)``.

    Raises:
        ValueError: If ``num_steps`` or ``num_samples`` is smaller than 1.
    """
    if num_steps < 1:
        raise ValueError(f'num_steps must be >= 1, got {num_steps}')
    if num_samples < 1:
        raise ValueError(f'num_samples must be >= 1, got {num_samples}')

    # Nothing to attribute: return empty results instead of failing in torch.cat.
    if len(sequences) == 0:
        empty = torch.zeros_like(sequences, device='cpu')
        return empty, empty.clone()

    device = _predictor_device(predictor)
    # Guard against a zero batch size when num_samples exceeds eval_batch_size.
    batch_size = max(1, eval_batch_size // num_samples)
    temp_dataset = TensorDataset(sequences)
    temp_dataloader = DataLoader(temp_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    all_salient_maps = []
    all_gradients = []
    for local_batch in tqdm(temp_dataloader):
        target_thetas = (theta_factor * local_batch[0].to(device)).requires_grad_()
        num_classes = target_thetas.shape[-2]
        line_gradients = []
        for step in range(num_steps + 1):
            point_thetas = step / num_steps * target_thetas
            point_distributions = F.softmax(point_thetas, dim=-2)

            # Categorical expects the category axis last, hence the transposes.
            nucleotide_probs = Categorical(torch.transpose(point_distributions, -2, -1))
            sampled_idxs = nucleotide_probs.sample((num_samples, ))
            sampled_nucleotides_T = F.one_hot(sampled_idxs, num_classes=num_classes)
            sampled_nucleotides = torch.transpose(sampled_nucleotides_T, -2, -1)

            # Straight-through estimator: the forward value is the hard sample,
            # the gradient flows to the distribution (broadcast over samples).
            soft = point_distributions.unsqueeze(0)
            sampled_nucleotides = sampled_nucleotides - soft.detach() + soft
            samples = sampled_nucleotides.flatten(0, 1)

            preds = predictor(samples)
            point_predictions = preds.unflatten(0, (num_samples, target_thetas.shape[0])).mean(dim=0)
            # A fresh graph is built at every step, so it need not be retained.
            point_gradients = torch.autograd.grad(point_predictions.sum(), inputs=point_thetas)[0]
            line_gradients.append(point_gradients)

        gradients = torch.stack(line_gradients).mean(dim=0).detach()
        all_salient_maps.append(gradients * target_thetas.detach())
        all_gradients.append(gradients)
    return torch.cat(all_salient_maps).cpu(), theta_factor * torch.cat(all_gradients).cpu()


def _predictor_device(predictor):
    """Return the device of the predictor's first parameter or buffer.

    Falls back to CUDA when available (the previous hard-coded behaviour),
    otherwise CPU, for predictors that carry no tensors.
    """
    if isinstance(predictor, torch.nn.Module):
        for tensor in predictor.parameters():
            return tensor.device
        for tensor in predictor.buffers():
            return tensor.device
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def df_to_onehot_tensor(in_df, seq_column='nt_sequence'):
    """Encode each entry of ``in_df[seq_column]`` with ``utils.dna2tensor`` and stack the results.

    Args:
        in_df: DataFrame holding the sequences.
        seq_column: Name of the column to encode.

    Returns:
        A tensor made by stacking the per-sequence tensors along a new first dim.
    """
    encoded = []
    for subsequence in tqdm(in_df[seq_column]):
        encoded.append(utils.dna2tensor(subsequence))
    return torch.stack(encoded)

In [4]:
# Load the space-delimited MPRA results table.

boda2_df = pd.read_csv(
    'BODA2_MPRA_results_pred_v3.txt',
    sep=' ',
    low_memory=True,
)

In [5]:
# Drop controls (rows without a 'method' value) and renumber the remaining rows.

no_controls_df = (
    boda2_df
    .loc[boda2_df['method'].notna()]
    .reset_index(drop=True)
    .copy()
)

# Optional sanity check on sequence lengths (disabled):
# no_controls_df['seq_len'] = no_controls_df.progress_apply(lambda x: len(x['sequence']), axis=1)
# no_controls_df['seq_len'].unique()

In [6]:
# Encode every sequence with utils.dna2tensor and stack them into one tensor.

sequences = no_controls_df['sequence'].to_list()
onehot_sequences = torch.stack(list(map(utils.dna2tensor, tqdm(sequences))))

  0%|          | 0/117900 [00:00<?, ?it/s]

In [7]:
%%time

#Compute contributions

seq_len = 200
eval_batch_size = 1040
num_sequences = onehot_sequences.shape[0]

predictors = [mpra_predictor(model=model, pred_idx=i, ini_in_len=seq_len).cuda() for i in range(3)]

cell_idxs = {0: 'K562', 1: 'HepG2', 2:'SKNSH'}
for idx, predictor in enumerate(predictors):
    contributions, _ = isg_contributions(onehot_sequences, predictor, eval_batch_size=eval_batch_size)
    contributions = contributions.sum(dim=1)
    str_contribution_list = [np.array2string(contributions[i,...].numpy(), separator=' ', precision=16, max_line_width=1e6) \
                         for i in range(num_sequences)]
    no_controls_df[f'contrib_{cell_idxs[idx]}'] = str_contribution_list
    no_controls_df.to_csv('BODA2_MPRA_results_no_controls_pred_contributions.txt', index=None, sep=' ')

  0%|          | 0/2268 [00:00<?, ?it/s]

  return torch.max_pool1d(input, kernel_size, stride, padding, dilation, ceil_mode)


  0%|          | 0/2268 [00:00<?, ?it/s]

  0%|          | 0/2268 [00:00<?, ?it/s]

CPU times: user 10h 3min 40s, sys: 42.6 s, total: 10h 4min 23s
Wall time: 10h 4min 3s


In [9]:
# Write the table, including the contribution columns, to disk.
no_controls_df.to_csv(
    'BODA2_MPRA_results_no_controls_pred_contributions.txt',
    index=None,
    sep=' ',
)