In [1]:
# Imports and notebook setup.
# NOTE(review): several of these (subprocess, tarfile, random, partial, pickle,
# datetime, plotting / motif / STREME helpers) are not used in the cells shown here.
import sys
import os
import subprocess
import tarfile
import shutil
import random
from functools import partial
from tqdm import tqdm
# NOTE(review): this rebinds `tqdm`, shadowing the plain import on the line above;
# only the `tqdm.auto` variant is actually in effect for the rest of the notebook.
from tqdm.auto import tqdm
# Hook tqdm progress bars into pandas (progress_apply-style calls).
tqdm.pandas()

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import (random_split, DataLoader, TensorDataset, ConcatDataset)
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
# Presumably imported for its side effect of enabling 3-D matplotlib axes — confirm.
from mpl_toolkits import mplot3d
from Bio import motifs
import pickle
from datetime import datetime
import scipy.stats as stats
import math

import boda
from boda.common import constants, utils

# Make `<repo>/src` (two directories above the notebook's working directory)
# importable so that `main` and `pymeme` below can be resolved.
boda_src = os.path.join( os.path.dirname( os.path.dirname( os.getcwd() ) ), 'src' )
sys.path.insert(0, boda_src)

from main import unpack_artifact, model_fn
from pymeme import streme, parse_streme_output

from torch.distributions.categorical import Categorical
from boda.generator.plot_tools import matrix_to_dms, ppm_to_IC, ppm_to_pwm, counts_to_ppm

# NOTE(review): `IPython.core.display` is a legacy location; `IPython.display`
# exposes the same `display` / `HTML` names in current IPython.
from IPython.core.display import display, HTML
# Widen the notebook container to 98% of the browser window.
display(HTML("<style>.container { width:98% !important; }</style>"))

In [2]:
# Load the trained model from its archived artifact.
# Any previously unpacked ./artifacts directory is removed first so files from an
# older run cannot mix with the freshly unpacked ones.
model_dir = './artifacts'
if os.path.isdir(model_dir):
    shutil.rmtree(model_dir)

# Model artifact tarball to evaluate.
hpo_rec = 'gs://syrgoth/aip_ui_test/model_artifacts__20211113_021200__287348.tar.gz'
unpack_artifact(hpo_rec)

model = model_fn(model_dir)
# The model is not moved to the GPU here; the contribution loop calls `.cuda()`
# on the mpra_predictor wrapper that holds it.
model.eval()
print('')

archive unpacked in ./


Loaded model from 20211113_021200 in eval mode



In [3]:
class mpra_predictor(nn.Module):
    """Wrap a trained MPRA model so it accepts sequences shorter than its input.

    Each ``ini_in_len``-long query passed to ``forward`` is padded on the left with
    the tail of ``constants.MPRA_UPSTREAM`` and on the right with the head of
    ``constants.MPRA_DOWNSTREAM`` until it is ``model_in_len`` long. The padded batch
    is run through ``model``, and a single output column is returned.

    Args:
        model: trained network, called on the padded batch.
        pred_idx: index of the model output column to return (e.g. one cell type).
        ini_in_len: length of the sequences given to ``forward``.
        model_in_len: total input length the wrapped model expects.
        cat_axis: axis along which flank, sequence and flank are concatenated.

    Raises:
        ValueError: if ``ini_in_len`` is larger than ``model_in_len``.
    """

    def __init__(self,
                 model,
                 pred_idx=0,
                 ini_in_len=200,
                 model_in_len=600,
                 cat_axis=-1):
        super().__init__()
        self.model = model
        self.pred_idx = pred_idx
        self.ini_in_len = ini_in_len
        self.model_in_len = model_in_len
        self.cat_axis = cat_axis

        # Best effort: plain callables without an ``eval`` method are accepted.
        # Errors raised *inside* eval() are no longer silently swallowed.
        eval_fn = getattr(self.model, 'eval', None)
        if callable(eval_fn):
            eval_fn()

        self.register_flanks()

    def forward(self, x):
        """Pad ``x`` with the registered flanks and return ``model(...)[:, pred_idx]``.

        Args:
            x: batch of encoded sequences. Assumed (batch, channels, ini_in_len),
               matching the (1, channels, L) flank buffers — TODO confirm for a
               non-default ``cat_axis``.
        """
        batch = x.shape[0]
        # expand() broadcasts the single stored flank over the batch without copying.
        pieces = [self.left_flank.expand(batch, -1, -1),
                  x,
                  self.right_flank.expand(batch, -1, -1)]
        in_tensor = torch.cat(pieces, dim=self.cat_axis)
        return self.model(in_tensor)[:, self.pred_idx]

    @staticmethod
    def _flank_lengths(missing_len):
        """Split ``missing_len`` padding positions into ``(left_len, right_len)``.

        The left flank gets ``missing_len // 2`` positions and the right flank gets
        the rest, so an odd leftover position goes to the right side.
        """
        if missing_len < 0:
            raise ValueError(
                f'ini_in_len exceeds model_in_len by {-missing_len} positions; cannot pad')
        left_len = missing_len // 2
        return left_len, missing_len - left_len

    def register_flanks(self):
        """Encode the left/right flanks and register them as module buffers.

        Because they are buffers, ``.cuda()`` / ``.to()`` on this wrapper moves them
        together with the wrapped model.
        """
        left_len, right_len = self._flank_lengths(self.model_in_len - self.ini_in_len)
        # Left flank = last ``left_len`` upstream bases (adjacent to the insert);
        # right flank = first ``right_len`` downstream bases.
        # Guard left_len == 0: a ``[-0:]`` slice would return the WHOLE upstream
        # sequence instead of an empty flank.
        # NOTE(review): for left_len == 0 this encodes '' — assumes dna2tensor
        # returns a zero-length tensor for an empty string; verify.
        left_seq = constants.MPRA_UPSTREAM[-left_len:] if left_len else ''
        right_seq = constants.MPRA_DOWNSTREAM[:right_len]
        self.register_buffer('left_flank', utils.dna2tensor(left_seq).unsqueeze(0))
        self.register_buffer('right_flank', utils.dna2tensor(right_seq).unsqueeze(0))

def isg_contributions(sequences,
                      predictor,
                      num_steps=50,
                      num_samples=20,
                      eval_batch_size=1024,
                      theta_factor=15,
                      num_workers=0,
                      device='cuda',
                      show_progress=True):
    """Integrated *sampled* gradient contributions for one-hot sequences.

    Each sequence ``s`` becomes logits ``target = theta_factor * s``. Along the
    straight path ``theta_i = (i / num_steps) * target`` (i = 0 .. num_steps), the
    logits are softmaxed over the nucleotide axis (-2), and ``num_samples`` discrete
    sequences are drawn per input. The mean prediction is then differentiated with
    respect to ``theta_i`` through a straight-through estimator. The path gradients
    are averaged and multiplied by ``target``.

    Args:
        sequences: one-hot tensor, assumed (N, 4, L) with nucleotides on axis -2.
        predictor: callable mapping a (num_samples * batch, 4, L) float batch to
            per-sequence predictions (assumed 1-D).
        num_steps: number of path intervals; num_steps + 1 points are evaluated.
        num_samples: sampled sequences per path point and input sequence.
        eval_batch_size: approximate number of sequences per ``predictor`` call.
            Input batches hold ``eval_batch_size // num_samples`` sequences (at least 1).
        theta_factor: scale turning one-hot inputs into logits.
        num_workers: DataLoader worker count.
        device: device the input batches are moved to. The default 'cuda' matches
            the previous behaviour; it must be where ``predictor`` lives.
        show_progress: wrap the batch loop in a tqdm progress bar.

    Returns:
        ``(contributions, gradients)``: CPU tensors shaped like ``sequences``, where
        ``contributions = mean_grad * target`` and ``gradients = theta_factor * mean_grad``.
    """
    if len(sequences) == 0:
        # torch.cat on an empty list would raise; return empty results instead.
        empty = torch.zeros((0, *sequences.shape[1:]))
        return empty, empty.clone()

    # Guard against eval_batch_size < num_samples, which would give batch_size 0.
    batch_size = max(1, eval_batch_size // num_samples)
    loader = DataLoader(TensorDataset(sequences), batch_size=batch_size,
                        shuffle=False, num_workers=num_workers)
    batches = tqdm(loader) if show_progress else loader

    all_salient_maps = []
    all_gradients = []
    for (seq_batch,) in batches:
        # requires_grad on the target logits makes every path point part of a graph.
        target_thetas = (theta_factor * seq_batch.to(device)).requires_grad_()
        line_gradients = []
        for i in range(num_steps + 1):
            point_thetas = (i / num_steps) * target_thetas
            point_distributions = F.softmax(point_thetas, dim=-2)

            # Sample num_samples discrete sequences per input:
            # indices (num_samples, B, L) -> one-hot moved back to (num_samples, B, C, L).
            nucleotide_probs = Categorical(torch.transpose(point_distributions, -2, -1))
            sampled_idxs = nucleotide_probs.sample((num_samples,))
            sampled_onehot = F.one_hot(sampled_idxs, num_classes=point_distributions.shape[-2])
            sampled_nucleotides = torch.transpose(sampled_onehot, -2, -1)

            # Straight-through estimator: the forward value is the discrete sample, and
            # the gradient flows to the distributions as if the sample were p itself.
            dist_repeated = point_distributions.unsqueeze(0).expand(
                num_samples, *point_distributions.shape)
            sampled_nucleotides = sampled_nucleotides - dist_repeated.detach() + dist_repeated
            samples = sampled_nucleotides.flatten(0, 1)

            preds = predictor(samples)
            point_predictions = preds.unflatten(0, (num_samples, target_thetas.shape[0])).mean(dim=0)
            # A fresh graph is built for every path point, so it need not be retained.
            point_gradients = torch.autograd.grad(point_predictions.sum(), inputs=point_thetas)[0]
            line_gradients.append(point_gradients)

        gradients = torch.stack(line_gradients).mean(dim=0).detach()
        all_salient_maps.append(gradients * target_thetas.detach())
        all_gradients.append(gradients)
    return torch.cat(all_salient_maps).cpu(), theta_factor * torch.cat(all_gradients).cpu()

def df_to_onehot_tensor(in_df, seq_column='nt_sequence'):
    """One-hot encode every sequence in ``in_df[seq_column]`` and stack the results.

    Each string is encoded with ``utils.dna2tensor`` under a tqdm progress bar. The
    per-sequence tensors are stacked along a new leading axis, so every sequence
    must encode to the same shape.
    """
    encoded = []
    for seq in tqdm(in_df[seq_column]):
        encoded.append(utils.dna2tensor(seq))
    return torch.stack(encoded)

In [4]:
# Load data
# The `chr` column mixes numeric chromosomes with 'X'/'Y'. With chunked type
# inference (low_memory=True), such a column can be parsed with different types
# per chunk — NOTE(review): the warning residue in this cell's output looks like
# that DtypeWarning. Reading it as text keeps labels such as '12' exact for the
# string comparison used when selecting a chromosome.
df = pd.read_csv('MPRA_ALL_no_cutoffs_v2_pred.txt', sep=' ', low_memory=True,
                 dtype={'chr': str})
df['chr'] = df['chr'].astype(str)

  interactivity=interactivity, compiler=compiler, result=result)


In [5]:
# Select chromosome
# train_chrs = [1, 2, 3, 4, 5, 6, 8, 9, 10, 11, 12, 14, 15, 16, 17, 18, 20, Y]
# val_chrs   = [19, 21, X]
# test_chrs  = [7, 13]

# Chromosome labels are compared as strings (the `chr` column is cast to str on load).
chromosome = '12'

chr_df = df[df['chr'] == chromosome].reset_index(drop=True)

# Per-sequence length: the contribution loop groups sequences by this value.
# Series.map(len) gives the same values as the former row-wise
# DataFrame.apply(..., axis=1), but is much faster and also works when the
# selection is empty (apply on an empty frame returns a DataFrame).
chr_df['seq_len'] = chr_df['nt_sequence'].map(len)

extended_contribution_df = chr_df[['IDs', 'nt_sequence', 'seq_len']].copy()

print(len(chr_df))

39303


In [6]:
%%time

#Compute contributions

eval_batch_size = 1040

#cell_idxs = {0: 'K562', 1: 'HepG2', 2: 'SKNSH'}
cell_types = ['K562', 'HepG2', 'SKNSH']
length_list = sorted(chr_df['seq_len'].unique())
for seq_length in length_list:
    row_filter = (chr_df['seq_len'] == seq_length)
    temp_df = chr_df[row_filter]
    sequences = temp_df['nt_sequence'].tolist()
    num_sequences = len(sequences)
    print(f'---------- Processing {num_sequences} sequences of length {seq_length} ----------')
    print('')
    print(f'Tokenizing sequences:')
    onehot_sequences = torch.stack([utils.dna2tensor(sequence) for sequence in tqdm(sequences)])    
    for cell_idx, cell_type in enumerate(cell_types):
        predictor = mpra_predictor(model=model, pred_idx=cell_idx, ini_in_len=seq_length).cuda()       
        #Run contributions
        print(f'{cell_type} contributions')
        contributions, ext_contributions = isg_contributions(onehot_sequences, predictor, eval_batch_size=eval_batch_size)
        #Store flat contributions
        contributions = contributions.sum(dim=1)
        str_contribution_list = [np.array2string(contributions[i,...].numpy(), separator=' ', precision=16, max_line_width=1e6) \
                             for i in range(num_sequences)]
        column_name = f'contrib_{cell_type}'
        chr_df.loc[row_filter, column_name] = str_contribution_list
        #Store extended contributions
        str_ext_contribution_list = [np.array2string(ext_contributions[i,...].numpy(), separator=' ', precision=16, max_line_width=1e6) \
                             for i in range(num_sequences)]
        ext_column_name = f'ext_contrib_{cell_type}'       
        extended_contribution_df.loc[row_filter, ext_column_name] = str_ext_contribution_list
        #Save file after each cell type
        if seq_length == 200:
            chr_df.to_csv(f'train_set_contribution_files/BODA2_TrainSet_contributions_chr{chromosome}.txt', index=None, sep=' ')
            extended_contribution_df.to_csv(f'train_set_contribution_files/BODA2_TrainSet_extended_contributions_chr{chromosome}.txt', index=None, sep=' ')

---------- Processing 1 sequences of length 148 ----------

Tokenizing sequences:


  0%|          | 0/1 [00:00<?, ?it/s]

K562 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

  return torch.max_pool1d(input, kernel_size, stride, padding, dilation, ceil_mode)


HepG2 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

SKNSH contributions


  0%|          | 0/1 [00:00<?, ?it/s]

---------- Processing 1 sequences of length 157 ----------

Tokenizing sequences:


  0%|          | 0/1 [00:00<?, ?it/s]

K562 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

HepG2 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

SKNSH contributions


  0%|          | 0/1 [00:00<?, ?it/s]

---------- Processing 3 sequences of length 160 ----------

Tokenizing sequences:


  0%|          | 0/3 [00:00<?, ?it/s]

K562 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

HepG2 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

SKNSH contributions


  0%|          | 0/1 [00:00<?, ?it/s]

---------- Processing 1 sequences of length 161 ----------

Tokenizing sequences:


  0%|          | 0/1 [00:00<?, ?it/s]

K562 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

HepG2 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

SKNSH contributions


  0%|          | 0/1 [00:00<?, ?it/s]

---------- Processing 1 sequences of length 164 ----------

Tokenizing sequences:


  0%|          | 0/1 [00:00<?, ?it/s]

K562 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

HepG2 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

SKNSH contributions


  0%|          | 0/1 [00:00<?, ?it/s]

---------- Processing 1 sequences of length 166 ----------

Tokenizing sequences:


  0%|          | 0/1 [00:00<?, ?it/s]

K562 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

HepG2 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

SKNSH contributions


  0%|          | 0/1 [00:00<?, ?it/s]

---------- Processing 1 sequences of length 167 ----------

Tokenizing sequences:


  0%|          | 0/1 [00:00<?, ?it/s]

K562 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

HepG2 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

SKNSH contributions


  0%|          | 0/1 [00:00<?, ?it/s]

---------- Processing 1 sequences of length 168 ----------

Tokenizing sequences:


  0%|          | 0/1 [00:00<?, ?it/s]

K562 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

HepG2 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

SKNSH contributions


  0%|          | 0/1 [00:00<?, ?it/s]

---------- Processing 1 sequences of length 170 ----------

Tokenizing sequences:


  0%|          | 0/1 [00:00<?, ?it/s]

K562 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

HepG2 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

SKNSH contributions


  0%|          | 0/1 [00:00<?, ?it/s]

---------- Processing 1 sequences of length 174 ----------

Tokenizing sequences:


  0%|          | 0/1 [00:00<?, ?it/s]

K562 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

HepG2 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

SKNSH contributions


  0%|          | 0/1 [00:00<?, ?it/s]

---------- Processing 3 sequences of length 175 ----------

Tokenizing sequences:


  0%|          | 0/3 [00:00<?, ?it/s]

K562 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

HepG2 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

SKNSH contributions


  0%|          | 0/1 [00:00<?, ?it/s]

---------- Processing 1 sequences of length 176 ----------

Tokenizing sequences:


  0%|          | 0/1 [00:00<?, ?it/s]

K562 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

HepG2 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

SKNSH contributions


  0%|          | 0/1 [00:00<?, ?it/s]

---------- Processing 13 sequences of length 177 ----------

Tokenizing sequences:


  0%|          | 0/13 [00:00<?, ?it/s]

K562 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

HepG2 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

SKNSH contributions


  0%|          | 0/1 [00:00<?, ?it/s]

---------- Processing 1 sequences of length 178 ----------

Tokenizing sequences:


  0%|          | 0/1 [00:00<?, ?it/s]

K562 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

HepG2 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

SKNSH contributions


  0%|          | 0/1 [00:00<?, ?it/s]

---------- Processing 3 sequences of length 179 ----------

Tokenizing sequences:


  0%|          | 0/3 [00:00<?, ?it/s]

K562 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

HepG2 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

SKNSH contributions


  0%|          | 0/1 [00:00<?, ?it/s]

---------- Processing 3 sequences of length 180 ----------

Tokenizing sequences:


  0%|          | 0/3 [00:00<?, ?it/s]

K562 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

HepG2 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

SKNSH contributions


  0%|          | 0/1 [00:00<?, ?it/s]

---------- Processing 2 sequences of length 181 ----------

Tokenizing sequences:


  0%|          | 0/2 [00:00<?, ?it/s]

K562 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

HepG2 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

SKNSH contributions


  0%|          | 0/1 [00:00<?, ?it/s]

---------- Processing 4 sequences of length 182 ----------

Tokenizing sequences:


  0%|          | 0/4 [00:00<?, ?it/s]

K562 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

HepG2 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

SKNSH contributions


  0%|          | 0/1 [00:00<?, ?it/s]

---------- Processing 4 sequences of length 183 ----------

Tokenizing sequences:


  0%|          | 0/4 [00:00<?, ?it/s]

K562 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

HepG2 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

SKNSH contributions


  0%|          | 0/1 [00:00<?, ?it/s]

---------- Processing 2 sequences of length 184 ----------

Tokenizing sequences:


  0%|          | 0/2 [00:00<?, ?it/s]

K562 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

HepG2 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

SKNSH contributions


  0%|          | 0/1 [00:00<?, ?it/s]

---------- Processing 11 sequences of length 185 ----------

Tokenizing sequences:


  0%|          | 0/11 [00:00<?, ?it/s]

K562 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

HepG2 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

SKNSH contributions


  0%|          | 0/1 [00:00<?, ?it/s]

---------- Processing 12 sequences of length 186 ----------

Tokenizing sequences:


  0%|          | 0/12 [00:00<?, ?it/s]

K562 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

HepG2 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

SKNSH contributions


  0%|          | 0/1 [00:00<?, ?it/s]

---------- Processing 8 sequences of length 187 ----------

Tokenizing sequences:


  0%|          | 0/8 [00:00<?, ?it/s]

K562 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

HepG2 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

SKNSH contributions


  0%|          | 0/1 [00:00<?, ?it/s]

---------- Processing 6 sequences of length 188 ----------

Tokenizing sequences:


  0%|          | 0/6 [00:00<?, ?it/s]

K562 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

HepG2 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

SKNSH contributions


  0%|          | 0/1 [00:00<?, ?it/s]

---------- Processing 10 sequences of length 189 ----------

Tokenizing sequences:


  0%|          | 0/10 [00:00<?, ?it/s]

K562 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

HepG2 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

SKNSH contributions


  0%|          | 0/1 [00:00<?, ?it/s]

---------- Processing 14 sequences of length 190 ----------

Tokenizing sequences:


  0%|          | 0/14 [00:00<?, ?it/s]

K562 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

HepG2 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

SKNSH contributions


  0%|          | 0/1 [00:00<?, ?it/s]

---------- Processing 25 sequences of length 191 ----------

Tokenizing sequences:


  0%|          | 0/25 [00:00<?, ?it/s]

K562 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

HepG2 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

SKNSH contributions


  0%|          | 0/1 [00:00<?, ?it/s]

---------- Processing 24 sequences of length 192 ----------

Tokenizing sequences:


  0%|          | 0/24 [00:00<?, ?it/s]

K562 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

HepG2 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

SKNSH contributions


  0%|          | 0/1 [00:00<?, ?it/s]

---------- Processing 26 sequences of length 193 ----------

Tokenizing sequences:


  0%|          | 0/26 [00:00<?, ?it/s]

K562 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

HepG2 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

SKNSH contributions


  0%|          | 0/1 [00:00<?, ?it/s]

---------- Processing 40 sequences of length 194 ----------

Tokenizing sequences:


  0%|          | 0/40 [00:00<?, ?it/s]

K562 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

HepG2 contributions


  0%|          | 0/1 [00:00<?, ?it/s]

SKNSH contributions


  0%|          | 0/1 [00:00<?, ?it/s]

---------- Processing 67 sequences of length 195 ----------

Tokenizing sequences:


  0%|          | 0/67 [00:00<?, ?it/s]

K562 contributions


  0%|          | 0/2 [00:00<?, ?it/s]

HepG2 contributions


  0%|          | 0/2 [00:00<?, ?it/s]

SKNSH contributions


  0%|          | 0/2 [00:00<?, ?it/s]

---------- Processing 214 sequences of length 196 ----------

Tokenizing sequences:


  0%|          | 0/214 [00:00<?, ?it/s]

K562 contributions


  0%|          | 0/5 [00:00<?, ?it/s]

HepG2 contributions


  0%|          | 0/5 [00:00<?, ?it/s]

SKNSH contributions


  0%|          | 0/5 [00:00<?, ?it/s]

---------- Processing 216 sequences of length 197 ----------

Tokenizing sequences:


  0%|          | 0/216 [00:00<?, ?it/s]

K562 contributions


  0%|          | 0/5 [00:00<?, ?it/s]

HepG2 contributions


  0%|          | 0/5 [00:00<?, ?it/s]

SKNSH contributions


  0%|          | 0/5 [00:00<?, ?it/s]

---------- Processing 312 sequences of length 198 ----------

Tokenizing sequences:


  0%|          | 0/312 [00:00<?, ?it/s]

K562 contributions


  0%|          | 0/6 [00:00<?, ?it/s]

HepG2 contributions


  0%|          | 0/6 [00:00<?, ?it/s]

SKNSH contributions


  0%|          | 0/6 [00:00<?, ?it/s]

---------- Processing 901 sequences of length 199 ----------

Tokenizing sequences:


  0%|          | 0/901 [00:00<?, ?it/s]

K562 contributions


  0%|          | 0/18 [00:00<?, ?it/s]

HepG2 contributions


  0%|          | 0/18 [00:00<?, ?it/s]

SKNSH contributions


  0%|          | 0/18 [00:00<?, ?it/s]

---------- Processing 37369 sequences of length 200 ----------

Tokenizing sequences:


  0%|          | 0/37369 [00:00<?, ?it/s]

K562 contributions


  0%|          | 0/719 [00:00<?, ?it/s]

HepG2 contributions


  0%|          | 0/719 [00:00<?, ?it/s]

SKNSH contributions


  0%|          | 0/719 [00:00<?, ?it/s]

CPU times: user 3h 39min 17s, sys: 24.6 s, total: 3h 39min 41s
Wall time: 3h 40min 17s


In [7]:
# chr_df.to_csv(f'train_set_contribution_files/BODA2_TrainSet_contributions_chr{chromosome}.txt', index=None, sep=' ')
# extended_contribution_df.to_csv(f'train_set_contribution_files/BODA2_TrainSet_extended_contributions_chr{chromosome}.txt', index=None, sep=' ')