In [1]:
import sys
import os
import subprocess
import tarfile
import shutil
import random
from functools import partial
from tqdm import tqdm
from tqdm.auto import tqdm
tqdm.pandas()

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import (random_split, DataLoader, TensorDataset, ConcatDataset)
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from mpl_toolkits import mplot3d
from Bio import motifs
import pickle
from datetime import datetime
import scipy.stats as stats
import math

import boda
from boda.common import constants, utils

boda_src = os.path.join( os.path.dirname( os.path.dirname( os.getcwd() ) ), 'src' )
sys.path.insert(0, boda_src)

from main import unpack_artifact, model_fn
from pymeme import streme, parse_streme_output

from torch.distributions.categorical import Categorical
from boda.generator.plot_tools import matrix_to_dms, ppm_to_IC, ppm_to_pwm, counts_to_ppm

from IPython.core.display import display, HTML
display(HTML("<style>.container { width:98% !important; }</style>"))

In [2]:
# Local directory the model is loaded from, and the remote artifact to fetch.
model_dir = './artifacts'
hpo_rec = 'gs://syrgoth/aip_ui_test/model_artifacts__20211113_021200__287348.tar.gz'

# Remove any previously unpacked artifacts before fetching again
# (presumably unpack_artifact extracts into ./artifacts -- confirm in main.py).
if os.path.isdir(model_dir):
    shutil.rmtree(model_dir)
unpack_artifact(hpo_rec)

# Load the model and switch it to inference mode. It is not moved to the GPU
# in this cell.
model = model_fn(model_dir)
model.eval()
# Printing an empty string keeps the cell from echoing the model repr.
print('')

Loaded model from 20211113_021200 in eval mode



archive unpacked in ./


In [3]:
class mpra_predictor(nn.Module):
    """Score short inserts by padding them with constant MPRA flanks.

    The wrapped ``model`` expects inputs of length ``model_in_len`` while the
    sequences passed to :meth:`forward` have length ``ini_in_len``. The
    difference is filled with the tail of ``constants.MPRA_UPSTREAM`` on the
    left and the head of ``constants.MPRA_DOWNSTREAM`` on the right.

    Args:
        model: Callable applied to the padded batch. Its output is indexed as
            ``output[:, pred_idx]``.
        pred_idx: Column of the model output returned by :meth:`forward`.
        ini_in_len: Length of the un-padded input sequences.
        model_in_len: Sequence length expected by ``model``.
        cat_axis: Axis along which flanks and input are concatenated.

    Raises:
        ValueError: If ``ini_in_len`` is larger than ``model_in_len``.
    """

    def __init__(self,
                 model,
                 pred_idx=0,
                 ini_in_len=200,
                 model_in_len=600,
                 cat_axis=-1):
        super().__init__()
        self.model = model
        self.pred_idx = pred_idx
        self.ini_in_len = ini_in_len
        self.model_in_len = model_in_len
        self.cat_axis = cat_axis

        # Best effort: put the wrapped model in eval mode when it has an
        # ``eval`` method. Only the "no such attribute" case is tolerated so
        # that genuine failures (and KeyboardInterrupt) are not swallowed.
        try:
            self.model.eval()
        except AttributeError:
            pass

        self.register_flanks()

    def forward(self, x):
        """Pad ``x`` with the flanks and return column ``pred_idx`` of the model output."""
        batch = x.shape[0]
        # The flanks are stored with a leading batch axis of size 1; tile them
        # to the batch size of ``x`` before concatenating.
        pieces = [self.left_flank.repeat(batch, 1, 1), x, self.right_flank.repeat(batch, 1, 1)]
        in_tensor = torch.cat(pieces, dim=self.cat_axis)
        out_tensor = self.model(in_tensor)[:, self.pred_idx]
        return out_tensor

    def register_flanks(self):
        """Build the left/right flank tensors and register them as buffers.

        The missing length is split so that the left flank gets the floor half
        and the right flank gets the remainder (one extra position when the
        missing length is odd).
        """
        missing_len = self.model_in_len - self.ini_in_len
        if missing_len < 0:
            raise ValueError(
                f'ini_in_len ({self.ini_in_len}) cannot exceed '
                f'model_in_len ({self.model_in_len})'
            )
        left_len = missing_len // 2
        right_len = missing_len - left_len

        upstream = constants.MPRA_UPSTREAM
        # A negative-index slice with a length of 0 (``upstream[-0:]``) would
        # return the WHOLE upstream sequence, so the empty flank is handled
        # explicitly.
        left_seq = upstream[-left_len:] if left_len else upstream[:0]
        right_seq = constants.MPRA_DOWNSTREAM[:right_len]

        # NOTE(review): assumes utils.dna2tensor accepts an empty sequence
        # when a flank has length 0 -- confirm.
        left_flank = utils.dna2tensor(left_seq).unsqueeze(0)
        right_flank = utils.dna2tensor(right_seq).unsqueeze(0)
        self.register_buffer('left_flank', left_flank)
        self.register_buffer('right_flank', right_flank)

        
def df_to_onehot_tensor(in_df, seq_column='sequence'):
    """Encode every sequence in ``in_df[seq_column]`` and stack the results.

    Each entry is converted with ``utils.dna2tensor`` (progress is shown with
    tqdm) and the per-sequence tensors are stacked along a new leading axis.
    """
    encoded = []
    for subsequence in tqdm(in_df[seq_column]):
        encoded.append(utils.dna2tensor(subsequence))
    return torch.stack(encoded)

def fasta_to_tensor(file_name):
    """Read a FASTA file and return its sequences as one stacked tensor.

    Lines starting with ``>`` open a new record; the following lines are
    concatenated into that record's sequence. Each sequence is encoded with
    ``utils.dna2tensor`` and the results are stacked along dim 0, so all
    records must encode to tensors of the same shape.

    Records are keyed by header text: when a header is repeated, the later
    record replaces the earlier one.

    Args:
        file_name: Path of the FASTA file.

    Returns:
        Tensor with one entry per distinct header, in order of first appearance.

    Raises:
        ValueError: If sequence text appears before the first header line.
    """
    fasta_dict = {}
    current_id = None
    with open(file_name, 'r') as f:
        for raw_line in f:
            # Drop the line terminator, including the '\r' of CRLF files,
            # which would otherwise end up inside the sequence.
            line_str = raw_line.rstrip('\r\n')
            if line_str.startswith('>'):
                current_id = line_str.lstrip('>')
                fasta_dict[current_id] = []
            elif line_str:
                if current_id is None:
                    raise ValueError(
                        f'{file_name}: sequence data found before the first ">" header'
                    )
                fasta_dict[current_id].append(line_str)
    seq_tensors = [utils.dna2tensor(''.join(parts)) for parts in fasta_dict.values()]
    return torch.stack(seq_tensors, dim=0)

def dna2tensor_approx(sequence_str, vocab_list=constants.STANDARD_NT, N_value=0.25):
    """One-hot encode a sequence, filling unknown characters with ``N_value``.

    Args:
        sequence_str: Sequence to encode, one character per position.
        vocab_list: Ordered vocabulary; the position of a symbol in this
            sequence is the row that is set to 1 for that symbol.
        N_value: Value written to every row of a column whose character is
            not in ``vocab_list`` (for example ``N``, or lowercase letters
            when the vocabulary is uppercase).

    Returns:
        ``torch.Tensor`` of shape ``(len(vocab_list), len(sequence_str))``.
    """
    # Map each symbol to the row of its first occurrence (same row that
    # ``vocab_list.index`` would return), so every character is a constant-time
    # lookup instead of a linear scan with exception-based fallback.
    row_of = {}
    for row, symbol in enumerate(vocab_list):
        row_of.setdefault(symbol, row)

    seq_tensor = np.zeros((len(vocab_list), len(sequence_str)))
    for letterIdx, letter in enumerate(sequence_str):
        row = row_of.get(letter)
        if row is None:
            seq_tensor[:, letterIdx] = N_value
        else:
            seq_tensor[row, letterIdx] = 1
    return torch.Tensor(seq_tensor)

def frame_print(string, marker='*', left_space=25):
    """Print ``string`` upper-cased inside a frame of ``marker`` characters.

    The frame is indented by ``left_space`` spaces and surrounded by two empty
    lines above and below. Every line is flushed immediately.
    """
    indent = ' ' * left_space
    banner = ' '.join((marker, string.upper(), marker))
    border = len(banner) * marker
    rows = ('', '', indent + border, indent + banner, indent + border, '', '')
    for row in rows:
        print(row, flush=True)
    
def decor_print(string):
    """Print ``string`` between two 15-dash rules, with an empty line above and below."""
    rule = '-' * 15
    for row in ('', ' '.join((rule, string, rule)), ''):
        print(row, flush=True)

In [4]:
def isg_contributions(sequences,
                      predictor,
                      num_steps=50,
                      num_samples=20,
                      eval_batch_size=1024,
                      theta_factor=15):
    """Compute per-position contribution maps by averaging sampled gradients.

    For every sequence, logits ``theta_factor * sequence`` are scaled linearly
    from 0 to their full value in ``num_steps + 1`` points. At each point the
    logits are turned into per-position distributions (softmax over dim -2),
    ``num_samples`` one-hot sequences are sampled from them, the predictor is
    evaluated on the samples, and the gradient of the mean prediction with
    respect to the point logits is taken. The gradients are averaged over the
    points and multiplied by the full logits.

    Args:
        sequences: Tensor of sequences; dim -2 must have size 4 (the samples
            are one-hot encoded with ``num_classes=4``) and the tensor is
            expected to be 3-D (batch, 4, length).
        predictor: Callable mapping a batch of sequences to one prediction per
            sequence. It must live on the default CUDA device, because each
            batch is moved there with ``.cuda()``.
        num_steps: Number of interpolation intervals (``num_steps + 1`` points
            are evaluated).
        num_samples: Sequences sampled per input and interpolation point.
        eval_batch_size: Upper bound on the number of sampled sequences sent
            to the predictor at once; the number of inputs per batch is
            ``eval_batch_size // num_samples`` (at least 1).
        theta_factor: Scale applied to ``sequences`` to obtain the logits.

    Returns:
        Tensor with the same shape as ``sequences`` holding the contribution
        maps, on the CUDA device and detached from the autograd graph.

    Note:
        Sampling is random; results vary between runs unless the torch RNG is
        seeded by the caller.
    """
    # DataLoader rejects batch_size=0, which the integer division would yield
    # whenever eval_batch_size < num_samples.
    batch_size = max(1, eval_batch_size // num_samples)
    temp_dataset = TensorDataset(sequences)
    temp_dataloader = DataLoader(temp_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    all_salient_maps = []
    for local_batch in tqdm(temp_dataloader):
        target_thetas = (theta_factor * local_batch[0].cuda()).requires_grad_()
        line_gradients = []
        for step in range(num_steps + 1):
            point_thetas = step / num_steps * target_thetas
            point_distributions = F.softmax(point_thetas, dim=-2)

            # Categorical samples over the last axis, hence the transposes
            # around sampling and one-hot encoding.
            nucleotide_probs = Categorical(torch.transpose(point_distributions, -2, -1))
            sampled_idxs = nucleotide_probs.sample((num_samples, ))
            sampled_nucleotides_T = F.one_hot(sampled_idxs, num_classes=4)
            sampled_nucleotides = torch.transpose(sampled_nucleotides_T, -2, -1)

            # Straight-through estimator: the forward value is the sampled
            # one-hot tensor, while gradients flow through the distributions.
            distribution_repeater = point_distributions.repeat(num_samples, 1, 1, 1)
            sampled_nucleotides = sampled_nucleotides - distribution_repeater.detach() + distribution_repeater
            samples = sampled_nucleotides.flatten(0, 1)

            preds = predictor(samples)
            point_predictions = preds.unflatten(0, (num_samples, target_thetas.shape[0])).mean(dim=0)
            # Each point builds its own graph and is differentiated exactly
            # once, so the graph is released right away instead of retained.
            point_gradients = torch.autograd.grad(point_predictions.sum(), inputs=point_thetas)[0]
            line_gradients.append(point_gradients)

        gradients = torch.stack(line_gradients).mean(dim=0)
        # The maps are final results; detach them so callers (and torch.save)
        # do not receive tensors tied to the autograd graph.
        all_salient_maps.append((gradients * target_thetas).detach())
    return torch.cat(all_salient_maps)

In [5]:
# Wrap the loaded model so 200-bp inputs are padded with the MPRA flanks, and
# move the wrapper (model and flank buffers) to the GPU. pred_idx=0 selects
# column 0 of the model output -- presumably the K562 activity; confirm
# against the model's output order.
k562_predictor = mpra_predictor(model=model, pred_idx=0).cuda()

In [6]:
# Padding added on each side of the 2 Mb core region (chrX:47644750-49644749).
left_pad = 200
right_pad = 200
gata_chr = 'X'

# Inclusive start/end of the padded locus.
gata_locus_start = 47644750 - left_pad
gata_locus_end = 49644750 + right_pad - 1

# Region string with thousands separators, e.g. for a genome browser query.
gata_locus_coord = f'chr{gata_chr}:{gata_locus_start:,}-{gata_locus_end:,}'
print(gata_locus_coord)

chrX:47,644,550-49,644,949


In [7]:
#! gsutil cp gs://syrgoth/data/locus_select/chrX-47,644,550-49,644,949.txt ./

gata_locus_file = 'chrX-47,644,550-49,644,949.txt'

# Concatenate every non-header line of the FASTA-style file into one string.
with open(gata_locus_file) as f:
    gata_locus_str = ''.join(line.strip() for line in f if line[0] != '>')

# Sanity check: the sequence without padding must match the core region length.
print(len(gata_locus_str[left_pad:-right_pad]), len(range(47644750, 49644750)))

2000000 2000000


In [8]:
# One-hot encode the padded locus. With N_value=0. any character outside the
# vocabulary (e.g. N) becomes an all-zero column instead of a uniform 0.25.
gata_locus_tensor = dna2tensor_approx(gata_locus_str, N_value=0.)
# Expected shape: (vocabulary size, padded locus length).
gata_locus_tensor.shape

torch.Size([4, 2000400])

In [9]:
# Slice the locus into overlapping windows: window_len positions wide,
# advancing step_size positions at a time, stacked along a new leading axis.
window_len = 200
step_size = 10
locus_tensor_windows = torch.stack(
    [
        gata_locus_tensor[:, offset:offset + window_len]
        for offset in range(0, gata_locus_tensor.shape[1] - window_len + 1, step_size)
    ]
)

In [10]:
# Genomic coordinates (inclusive start-end) of every window, generated with the
# same offsets as the windows themselves so that index k in this list
# describes window k.
windows_coordinates = [
    f'chr{gata_chr}:{gata_locus_start + offset}-{gata_locus_start + offset + window_len-1}'
    for offset in range(0, gata_locus_tensor.shape[1] - window_len + 1, step_size)
]

In [11]:
# chunk_example = locus_tensor_windows[:1020,...]
# chunk_example.shape

In [12]:
%%time
data_tensor = locus_tensor_windows #locus_tensor_windows
chunk_size = 10002 #10002 #204
eval_batch_size = 1040

cell_type = 'k562'
targetdir = 'gata_locus/contributions_v1'

print(f'Results will be saved at {targetdir}', flush=True)
print('', flush=True)

num_chunks = math.ceil(data_tensor.shape[0] / chunk_size)
processed_chunks = 0
for i in range(0, data_tensor.shape[0], chunk_size):
    start_time = datetime.now()
    
    decor_print(f'Processing chunk {processed_chunks+1}/{num_chunks}')
        
    chunk = data_tensor[i:i+chunk_size, ...]    
    
    salient_maps = isg_contributions(chunk, k562_predictor, eval_batch_size=eval_batch_size)
    coordinate_list = windows_coordinates[i:i+chunk.shape[0]]
    
    save_dict = {}
    save_dict['window_contributions'] = salient_maps
    save_dict['window_coordinates'] = coordinate_list
    
    first_coordinate = coordinate_list[0].split('-')[0]
    last_coordinate = coordinate_list[-1].split('-')[1]
    chunk_name = f'gata_locus_contributions__{cell_type}__window_len_{window_len}__step_size_{step_size}'
    chunk_name += f'__{first_coordinate}-{last_coordinate}' + '.pt'
    
    save_path = os.path.join(targetdir, chunk_name)   
    torch.save(save_dict, save_path)
    
    print(f'Contributions saved in {save_path}')
    print('', flush=True)
    
    processed_chunks += 1
    left_chunks = num_chunks - processed_chunks
    end_time = datetime.now()
    chunk_time = end_time - start_time
    
    print(f'Chunk processing time: {chunk_time}', flush=True)
    print('', flush=True)
    print(f'Estimated time remaining: {chunk_time*left_chunks}', flush=True)
    print('', flush=True)

Results will be saved at gata_locus/contributions_v1


--------------- Processing chunk 1/20 ---------------



  0%|          | 0/193 [00:00<?, ?it/s]

  return torch.max_pool1d(input, kernel_size, stride, padding, dilation, ceil_mode)


Contributions saved in gata_locus/contributions_v1/gata_locus_contributions__k562__window_len_200__step_size_10__chrX:47644550-47744759.pt

Chunk processing time: 0:16:38.812780

Estimated time remaining: 5:16:17.442820


--------------- Processing chunk 2/20 ---------------



  0%|          | 0/193 [00:00<?, ?it/s]

Contributions saved in gata_locus/contributions_v1/gata_locus_contributions__k562__window_len_200__step_size_10__chrX:47744570-47844779.pt

Chunk processing time: 0:16:39.803341

Estimated time remaining: 4:59:56.460138


--------------- Processing chunk 3/20 ---------------



  0%|          | 0/193 [00:00<?, ?it/s]

Contributions saved in gata_locus/contributions_v1/gata_locus_contributions__k562__window_len_200__step_size_10__chrX:47844590-47944799.pt

Chunk processing time: 0:16:41.040117

Estimated time remaining: 4:43:37.681989


--------------- Processing chunk 4/20 ---------------



  0%|          | 0/193 [00:00<?, ?it/s]

Contributions saved in gata_locus/contributions_v1/gata_locus_contributions__k562__window_len_200__step_size_10__chrX:47944610-48044819.pt

Chunk processing time: 0:16:40.756049

Estimated time remaining: 4:26:52.096784


--------------- Processing chunk 5/20 ---------------



  0%|          | 0/193 [00:00<?, ?it/s]

Contributions saved in gata_locus/contributions_v1/gata_locus_contributions__k562__window_len_200__step_size_10__chrX:48044630-48144839.pt

Chunk processing time: 0:16:39.539755

Estimated time remaining: 4:09:53.096325


--------------- Processing chunk 6/20 ---------------



  0%|          | 0/193 [00:00<?, ?it/s]

Contributions saved in gata_locus/contributions_v1/gata_locus_contributions__k562__window_len_200__step_size_10__chrX:48144650-48244859.pt

Chunk processing time: 0:16:39.909640

Estimated time remaining: 3:53:18.734960


--------------- Processing chunk 7/20 ---------------



  0%|          | 0/193 [00:00<?, ?it/s]

Contributions saved in gata_locus/contributions_v1/gata_locus_contributions__k562__window_len_200__step_size_10__chrX:48244670-48344879.pt

Chunk processing time: 0:16:40.100324

Estimated time remaining: 3:36:41.304212


--------------- Processing chunk 8/20 ---------------



  0%|          | 0/193 [00:00<?, ?it/s]

Contributions saved in gata_locus/contributions_v1/gata_locus_contributions__k562__window_len_200__step_size_10__chrX:48344690-48444899.pt

Chunk processing time: 0:16:40.208348

Estimated time remaining: 3:20:02.500176


--------------- Processing chunk 9/20 ---------------



  0%|          | 0/193 [00:00<?, ?it/s]

Contributions saved in gata_locus/contributions_v1/gata_locus_contributions__k562__window_len_200__step_size_10__chrX:48444710-48544919.pt

Chunk processing time: 0:16:40.040334

Estimated time remaining: 3:03:20.443674


--------------- Processing chunk 10/20 ---------------



  0%|          | 0/193 [00:00<?, ?it/s]

Contributions saved in gata_locus/contributions_v1/gata_locus_contributions__k562__window_len_200__step_size_10__chrX:48544730-48644939.pt

Chunk processing time: 0:16:40.055059

Estimated time remaining: 2:46:40.550590


--------------- Processing chunk 11/20 ---------------



  0%|          | 0/193 [00:00<?, ?it/s]

Contributions saved in gata_locus/contributions_v1/gata_locus_contributions__k562__window_len_200__step_size_10__chrX:48644750-48744959.pt

Chunk processing time: 0:16:40.276289

Estimated time remaining: 2:30:02.486601


--------------- Processing chunk 12/20 ---------------



  0%|          | 0/193 [00:00<?, ?it/s]

Contributions saved in gata_locus/contributions_v1/gata_locus_contributions__k562__window_len_200__step_size_10__chrX:48744770-48844979.pt

Chunk processing time: 0:16:41.664750

Estimated time remaining: 2:13:33.318000


--------------- Processing chunk 13/20 ---------------



  0%|          | 0/193 [00:00<?, ?it/s]

Contributions saved in gata_locus/contributions_v1/gata_locus_contributions__k562__window_len_200__step_size_10__chrX:48844790-48944999.pt

Chunk processing time: 0:16:40.322156

Estimated time remaining: 1:56:42.255092


--------------- Processing chunk 14/20 ---------------



  0%|          | 0/193 [00:00<?, ?it/s]

Contributions saved in gata_locus/contributions_v1/gata_locus_contributions__k562__window_len_200__step_size_10__chrX:48944810-49045019.pt

Chunk processing time: 0:16:40.946123

Estimated time remaining: 1:40:05.676738


--------------- Processing chunk 15/20 ---------------



  0%|          | 0/193 [00:00<?, ?it/s]

Contributions saved in gata_locus/contributions_v1/gata_locus_contributions__k562__window_len_200__step_size_10__chrX:49044830-49145039.pt

Chunk processing time: 0:16:41.032199

Estimated time remaining: 1:23:25.160995


--------------- Processing chunk 16/20 ---------------



  0%|          | 0/193 [00:00<?, ?it/s]

Contributions saved in gata_locus/contributions_v1/gata_locus_contributions__k562__window_len_200__step_size_10__chrX:49144850-49245059.pt

Chunk processing time: 0:16:40.327921

Estimated time remaining: 1:06:41.311684


--------------- Processing chunk 17/20 ---------------



  0%|          | 0/193 [00:00<?, ?it/s]

Contributions saved in gata_locus/contributions_v1/gata_locus_contributions__k562__window_len_200__step_size_10__chrX:49244870-49345079.pt

Chunk processing time: 0:16:41.251796

Estimated time remaining: 0:50:03.755388


--------------- Processing chunk 18/20 ---------------



  0%|          | 0/193 [00:00<?, ?it/s]

Contributions saved in gata_locus/contributions_v1/gata_locus_contributions__k562__window_len_200__step_size_10__chrX:49344890-49445099.pt

Chunk processing time: 0:16:41.229941

Estimated time remaining: 0:33:22.459882


--------------- Processing chunk 19/20 ---------------



  0%|          | 0/193 [00:00<?, ?it/s]

Contributions saved in gata_locus/contributions_v1/gata_locus_contributions__k562__window_len_200__step_size_10__chrX:49444910-49545119.pt

Chunk processing time: 0:16:41.269630

Estimated time remaining: 0:16:41.269630


--------------- Processing chunk 20/20 ---------------



  0%|          | 0/192 [00:00<?, ?it/s]

Contributions saved in gata_locus/contributions_v1/gata_locus_contributions__k562__window_len_200__step_size_10__chrX:49544930-49644949.pt

Chunk processing time: 0:16:38.597384

Estimated time remaining: 0:00:00

CPU times: user 5h 33min 17s, sys: 21.4 s, total: 5h 33min 39s
Wall time: 5h 33min 27s
