In [7]:
import sys
import os
import subprocess
import tarfile
import shutil
import random
from functools import partial
from tqdm import tqdm
from tqdm.auto import tqdm
tqdm.pandas()

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import (random_split, DataLoader, TensorDataset, ConcatDataset)
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from mpl_toolkits import mplot3d
from Bio import motifs
import pickle
from datetime import datetime
import scipy.stats as stats
import math

import boda
from boda.common import constants, utils

# Make the repository's `src` directory importable: it is resolved as
# <two levels above the current working directory>/src, so this only works
# when the notebook is launched from its expected location in the repo.
boda_src = os.path.join( os.path.dirname( os.path.dirname( os.getcwd() ) ), 'src' )
# Prepend (rather than append) so modules found there take precedence over
# same-named modules elsewhere on sys.path.
sys.path.insert(0, boda_src)

from main import unpack_artifact, model_fn
from pymeme import streme, parse_streme_output

from torch.distributions.categorical import Categorical
from boda.generator.plot_tools import matrix_to_dms, ppm_to_IC, ppm_to_pwm, counts_to_ppm

from IPython.core.display import display, HTML
# Inject CSS so elements with the `.container` class span 98% of the page width.
display(HTML("<style>.container { width:98% !important; }</style>"))

In [8]:
# Directory that the unpacked artifact is later loaded from.
model_dir = './artifacts'

# Start from a clean slate so files from a previously unpacked artifact
# cannot be mixed with the one fetched below.
if os.path.isdir(model_dir):
    shutil.rmtree(model_dir)

# Archive holding the trained model; unpack_artifact fetches and extracts it.
hpo_rec = 'gs://syrgoth/aip_ui_test/model_artifacts__20211113_021200__287348.tar.gz'
unpack_artifact(hpo_rec)

# Rebuild the model from the unpacked files. It is not moved to the GPU here.
model = model_fn(model_dir)
model.eval()
# Trailing print keeps the cell from echoing the module repr of model.eval().
print('')

Loaded model from 20211113_021200 in eval mode



archive unpacked in ./


In [9]:
class mpra_predictor(nn.Module):
    """Wrap a model so short inputs are padded with fixed MPRA flanks.

    The wrapped model expects inputs of length ``model_in_len`` along
    ``cat_axis``; callers pass inputs of length ``ini_in_len``. The missing
    positions are filled with the tail of ``constants.MPRA_UPSTREAM`` on the
    left and the head of ``constants.MPRA_DOWNSTREAM`` on the right.

    Args:
        model: Callable applied to the padded batch. Its output is indexed
            as ``output[:, pred_idx]``.
        pred_idx: Column of the model output returned by ``forward``.
        ini_in_len: Length of the sequences passed to ``forward``.
        model_in_len: Length the wrapped model expects.
        cat_axis: Axis along which flanks and input are concatenated.

    Raises:
        ValueError: If ``model_in_len`` is smaller than ``ini_in_len``.
    """

    def __init__(self,
                 model,
                 pred_idx=0,
                 ini_in_len=200,
                 model_in_len=600,
                 cat_axis=-1):
        super().__init__()
        self.model = model
        self.pred_idx = pred_idx
        self.ini_in_len = ini_in_len
        self.model_in_len = model_in_len
        self.cat_axis = cat_axis

        # Best effort: plain callables have no .eval(). Catch Exception rather
        # than using a bare except so KeyboardInterrupt/SystemExit propagate.
        try:
            self.model.eval()
        except Exception:
            pass

        self.register_flanks()

    def forward(self, x):
        """Pad ``x`` with the flanks and return column ``pred_idx`` of the model output."""
        batch = x.shape[0]
        # Flank buffers have a leading batch dimension of 1; tile them to the batch.
        pieces = [self.left_flank.repeat(batch, 1, 1), x, self.right_flank.repeat(batch, 1, 1)]
        in_tensor = torch.cat(pieces, dim=self.cat_axis)
        out_tensor = self.model(in_tensor)[:, self.pred_idx]
        return out_tensor

    def register_flanks(self):
        """Build the left/right flank tensors and register them as buffers.

        The padding is split as evenly as possible; when the missing length
        is odd the extra position goes to the right flank.
        """
        missing_len = self.model_in_len - self.ini_in_len
        if missing_len < 0:
            raise ValueError(
                f'model_in_len ({self.model_in_len}) must be >= ini_in_len ({self.ini_in_len})'
            )
        left_len = missing_len // 2
        right_len = missing_len - left_len

        upstream = constants.MPRA_UPSTREAM
        downstream = constants.MPRA_DOWNSTREAM
        # A plain upstream[-left_len:] would return the WHOLE upstream sequence
        # when left_len == 0 (because -0 == 0), so that case is handled explicitly.
        # NOTE(review): assumes utils.dna2tensor accepts an empty string - confirm.
        left_seq = upstream[-left_len:] if left_len > 0 else upstream[:0]
        right_seq = downstream[:right_len]

        left_flank = utils.dna2tensor(left_seq).unsqueeze(0)
        right_flank = utils.dna2tensor(right_seq).unsqueeze(0)
        # Buffers follow the module across .cuda()/.to() calls.
        self.register_buffer('left_flank', left_flank)
        self.register_buffer('right_flank', right_flank)

        
def df_to_onehot_tensor(in_df, seq_column='sequence'):
    """Encode every sequence in ``in_df[seq_column]`` and stack the results.

    Each entry is converted with ``utils.dna2tensor`` (progress is shown via
    tqdm) and the per-sequence tensors are stacked along a new leading axis,
    so all sequences must encode to tensors of the same shape.
    """
    encoded = []
    for seq in tqdm(in_df[seq_column]):
        encoded.append(utils.dna2tensor(seq))
    return torch.stack(encoded)

def fasta_to_tensor(file_name):
    """Read a FASTA file and return its sequences as one stacked tensor.

    Lines starting with '>' open a new record whose id is the rest of the
    line; following lines are concatenated into that record's sequence.
    Blank lines are ignored. If the same id appears twice, the later record
    replaces the earlier one.

    Args:
        file_name: Path of the FASTA file.

    Returns:
        ``torch.stack`` of ``utils.dna2tensor(sequence)`` for every record,
        in file order; all sequences must therefore have equal length.

    Raises:
        ValueError: If sequence data appears before the first '>' header.
    """
    records = {}
    current_id = None
    with open(file_name, 'r') as f:
        for raw_line in f:
            line = raw_line.rstrip('\n')
            if line.startswith('>'):
                current_id = line.lstrip('>')
                # Collect pieces in a list and join once at the end instead
                # of growing a string with repeated +=.
                records[current_id] = []
            elif not line.strip():
                # Blank lines carry no sequence data.
                continue
            elif current_id is None:
                raise ValueError(
                    f'{file_name}: sequence data found before the first ">" header'
                )
            else:
                records[current_id].append(line)

    seq_tensors = [utils.dna2tensor(''.join(pieces)) for pieces in records.values()]
    return torch.stack(seq_tensors, dim=0)

def dna2tensor_approx(sequence_str, vocab_list=constants.STANDARD_NT, N_value=0.25):
    """One-hot encode a sequence, tolerating symbols outside the vocabulary.

    Args:
        sequence_str: Sequence to encode, one symbol per position.
        vocab_list: Ordered symbols; row ``k`` of the result corresponds to
            ``vocab_list[k]``.
        N_value: Value written to every row of a column whose symbol is not
            in ``vocab_list``.

    Returns:
        ``torch.Tensor`` of shape ``(len(vocab_list), len(sequence_str))``.

    Matching is exact (case-sensitive): any character not literally present
    in ``vocab_list`` takes the ``N_value`` fallback.
    """
    seq_tensor = np.zeros((len(vocab_list), len(sequence_str)))
    for letterIdx, letter in enumerate(sequence_str):
        try:
            seq_tensor[vocab_list.index(letter), letterIdx] = 1
        except ValueError:
            # .index raises ValueError for an unknown symbol. Catching only
            # that keeps KeyboardInterrupt and real bugs from being swallowed.
            seq_tensor[:, letterIdx] = N_value
    seq_tensor = torch.Tensor(seq_tensor)
    return seq_tensor

def frame_print(string, marker='*', left_space=25):
    """Print ``string`` upper-cased inside a box of ``marker`` characters.

    The banner is indented by ``left_space`` spaces and surrounded by two
    blank lines above and below. Every line is flushed immediately.
    """
    indent = ' ' * left_space
    banner = ' '.join((marker, string.upper(), marker))
    rule = len(banner) * marker
    for row in ('', '', indent + rule, indent + banner, indent + rule, '', ''):
        print(row, flush=True)
    
def decor_print(string):
    """Print ``string`` between two 15-dash rules, with a blank line before and after."""
    rule = '-' * 15
    blank = ''
    print(blank, flush=True)
    print(' '.join([rule, string, rule]), flush=True)
    print(blank, flush=True)

In [10]:
def _predictor_device(predictor):
    """Return the device of the predictor's first parameter or buffer.

    Falls back to CUDA (the previously hard-coded behaviour) when the
    predictor exposes no tensors to inspect.
    """
    for accessor in ('parameters', 'buffers'):
        fetch = getattr(predictor, accessor, None)
        if callable(fetch):
            first = next(iter(fetch()), None)
            if first is not None:
                return first.device
    return torch.device('cuda')


def isg_contributions(sequences,
                      predictor,
                      num_steps=50,
                      num_samples=20,
                      eval_batch_size=1024,
                      theta_factor=15,
                      num_workers=2):
    """Compute per-position contribution maps by averaging sampled gradients along a line.

    For each batch, the input is scaled by ``theta_factor`` to obtain target
    logits. The logits are then walked from 0 to the target in
    ``num_steps + 1`` evenly spaced points. At every point ``num_samples``
    one-hot sequences are drawn from the softmax over axis -2, the predictor
    is evaluated on them, and the gradient of the sample-averaged prediction
    with respect to the point logits is recorded. The map is the mean
    gradient over all points multiplied by the target logits.

    Args:
        sequences: Tensor of inputs; the code treats axis -2 as the
            4-way nucleotide axis (``num_classes=4``) - presumably shaped
            (N, 4, L), confirm against the caller.
        predictor: Callable mapping a batch of sequences to one value per
            sequence. Data is moved to the device of its parameters/buffers.
        num_steps: Number of intervals along the line (``num_steps + 1`` points).
        num_samples: Sequences sampled per input at each point.
        eval_batch_size: Target number of sampled sequences per predictor
            call; inputs per batch is ``eval_batch_size // num_samples``
            (at least 1).
        theta_factor: Scale applied to the inputs to form the target logits.
        num_workers: Worker processes used by the internal DataLoader.

    Returns:
        Tensor with the same shape as ``sequences``, detached from autograd,
        living on the predictor's device.
    """
    device = _predictor_device(predictor)
    # Guard against a zero batch size when num_samples > eval_batch_size.
    batch_size = max(1, eval_batch_size // num_samples)
    temp_dataset = TensorDataset(sequences)
    temp_dataloader = DataLoader(temp_dataset, batch_size=batch_size,
                                 shuffle=False, num_workers=num_workers)

    all_salient_maps = []
    for local_batch in tqdm(temp_dataloader):
        target_thetas = (theta_factor * local_batch[0].to(device)).requires_grad_()
        line_gradients = []
        for step in range(0, num_steps + 1):
            point_thetas = (step / num_steps * target_thetas)
            point_distributions = F.softmax(point_thetas, dim=-2)

            # Categorical expects the class axis last, hence the transposes.
            nucleotide_probs = Categorical(torch.transpose(point_distributions, -2, -1))
            sampled_idxs = nucleotide_probs.sample((num_samples, ))
            sampled_nucleotides_T = F.one_hot(sampled_idxs, num_classes=4)
            sampled_nucleotides = torch.transpose(sampled_nucleotides_T, -2, -1)
            distribution_repeater = point_distributions.repeat(num_samples, 1, 1, 1)
            # Straight-through trick: the forward value is the hard sample,
            # while gradients flow through the soft distribution.
            sampled_nucleotides = sampled_nucleotides - distribution_repeater.detach() + distribution_repeater
            samples = sampled_nucleotides.flatten(0, 1)

            preds = predictor(samples)
            point_predictions = preds.unflatten(0, (num_samples, target_thetas.shape[0])).mean(dim=0)
            point_gradients = torch.autograd.grad(point_predictions.sum(), inputs=point_thetas, retain_graph=True)[0]
            line_gradients.append(point_gradients)

        gradients = torch.stack(line_gradients).mean(dim=0)
        # Detach so stored maps do not keep each batch's autograd graph alive.
        all_salient_maps.append((gradients * target_thetas).detach())
    return torch.cat(all_salient_maps)

In [11]:
# Wrap the loaded model so it accepts 200-nt inputs (flanks are added by the
# wrapper) and returns output column 0, then move wrapper and model to the GPU.
# NOTE(review): column 0 is presumably the K562 output - confirm against the model.
k562_predictor = mpra_predictor(model=model, pred_idx=0).cuda()

In [12]:
# Extra bases requested on each side of the locus of interest.
left_pad = 200
right_pad = 200

locus_chr = '11'
# Padded interval: extend the start to the left and the end to the right.
locus_start = 61787329 - left_pad
locus_end = 61898348 + right_pad - 1

# Human-readable coordinate string with thousands separators.
locus_coord = f'chr{locus_chr}:{locus_start:,}-{locus_end:,}'
print(locus_coord)

chr11:61,787,129-61,898,547


In [13]:
#! gsutil cp gs://syrgoth/data/locus_select/chr11-61,787,129-61,898,547.txt ./

# Read the locus sequence: drop header lines (those starting with '>') and
# concatenate the remaining lines with surrounding whitespace stripped.
locus_file = 'chr11-61,787,129-61,898,547.txt'
with open(locus_file) as f:
    locus_str = ''.join(line.strip() for line in f if line[0] != '>')

# Sanity check: the unpadded sequence length should equal the locus span.
print(len(locus_str[left_pad:-right_pad]), len(range(61787329, 61898348)))

111019 111019


In [14]:
# Encode the whole locus. N_value=0. makes every character that is not in the
# vocabulary an all-zero column instead of the default 0.25 per row.
locus_tensor = dna2tensor_approx(locus_str, N_value=0.)
locus_tensor.shape

torch.Size([4, 111419])

In [15]:
# Cut the locus into overlapping fixed-width windows.
window_len = 200
step_size = 10
# unfold along the position axis yields (rows, n_windows, window_len); move the
# window axis to the front and materialise a contiguous copy, giving one
# (rows, window_len) slice per window start 0, step_size, 2*step_size, ...
locus_tensor_windows = locus_tensor.unfold(1, window_len, step_size).permute(1, 0, 2).contiguous()

In [16]:
# One "chr<N>:<first>-<last>" label per window, using the same window starts
# as the window tensor; <last> is inclusive, hence the -1.
windows_coordinates = [
    f'chr{locus_chr}:{locus_start + offset}-{locus_start + offset + window_len-1}'
    for offset in range(0, locus_tensor.shape[1]-window_len+1, step_size)
]

In [17]:
# Small slice (the first 1020 windows) for quick experiments; the trailing
# expression just displays its shape.
chunk_example = locus_tensor_windows[:1020,...]
chunk_example.shape

torch.Size([1020, 4, 200])

In [56]:
%%time
# Compute contribution maps for all windows in chunks and save each chunk,
# together with its window coordinates, as a separate .pt file.
data_tensor = locus_tensor_windows #locus_tensor_windows
chunk_size = 5561 #10002 #204
eval_batch_size = 1040

cell_type = 'k562'
targetdir = 'fads_locus/contributions_v1'

# Create the output directory up front: otherwise torch.save would fail only
# after the first (several-minute) chunk has already been computed.
os.makedirs(targetdir, exist_ok=True)

print(f'Results will be saved at {targetdir}', flush=True)
print('', flush=True)

num_chunks = math.ceil(data_tensor.shape[0] / chunk_size)
processed_chunks = 0
for i in range(0, data_tensor.shape[0], chunk_size):
    start_time = datetime.now()

    decor_print(f'Processing chunk {processed_chunks+1}/{num_chunks}')

    # The last chunk may be shorter than chunk_size; slicing handles that.
    chunk = data_tensor[i:i + chunk_size, ...]

    salient_maps = isg_contributions(chunk, k562_predictor, eval_batch_size=eval_batch_size)
    # Coordinates are sliced with the actual chunk length so they stay aligned.
    coordinate_list = windows_coordinates[i:i + chunk.shape[0]]

    save_dict = {}
    save_dict['window_contributions'] = salient_maps
    save_dict['window_coordinates'] = coordinate_list

    # File name spans from the start of the first window to the end of the last.
    first_coordinate = coordinate_list[0].split('-')[0]
    last_coordinate = coordinate_list[-1].split('-')[1]
    chunk_name = f'fads_locus_contributions__{cell_type}__window_len_{window_len}__step_size_{step_size}'
    chunk_name += f'__{first_coordinate}-{last_coordinate}' + '.pt'

    save_path = os.path.join(targetdir, chunk_name)
    torch.save(save_dict, save_path)

    print(f'Contributions saved in {save_path}')
    print('', flush=True)

    processed_chunks += 1
    left_chunks = num_chunks - processed_chunks
    end_time = datetime.now()
    chunk_time = end_time - start_time

    print(f'Chunk processing time: {chunk_time}', flush=True)
    print('', flush=True)
    # Simple estimate: assume remaining chunks take as long as this one.
    print(f'Estimated time remaining: {chunk_time*left_chunks}', flush=True)
    print('', flush=True)

Results will be saved at fads_locus/contributions_v1


--------------- Processing chunk 1/2 ---------------



  0%|          | 0/107 [00:00<?, ?it/s]

Contributions saved in fads_locus/contributions_v1/fads_locus_contributions__k562__window_len_200__step_size_10__chr11:61787129-61842928.pt

Chunk processing time: 0:09:16.025981

Estimated time remaining: 0:09:16.025981


--------------- Processing chunk 2/2 ---------------



  0%|          | 0/107 [00:00<?, ?it/s]

Contributions saved in fads_locus/contributions_v1/fads_locus_contributions__k562__window_len_200__step_size_10__chr11:61842739-61898538.pt

Chunk processing time: 0:09:16.687396

Estimated time remaining: 0:00:00

CPU times: user 18min 31s, sys: 1.68 s, total: 18min 33s
Wall time: 18min 32s
