In [1]:
import sys
import os
import subprocess
import tarfile
import shutil
import random
from functools import partial
from tqdm import tqdm
from tqdm.auto import tqdm
tqdm.pandas()

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import (random_split, DataLoader, TensorDataset, ConcatDataset)
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from mpl_toolkits import mplot3d
from Bio import motifs
import pickle
from datetime import datetime
import scipy.stats as stats
import math

import boda
from boda.common import constants, utils

boda_src = os.path.join( os.path.dirname( os.path.dirname( os.getcwd() ) ), 'src' )
sys.path.insert(0, boda_src)

from main import unpack_artifact, model_fn
from pymeme import streme, parse_streme_output

from torch.distributions.categorical import Categorical
from boda.generator.plot_tools import matrix_to_dms, ppm_to_IC, ppm_to_pwm, counts_to_ppm

from IPython.core.display import display, HTML
display(HTML("<style>.container { width:98% !important; }</style>"))

### Get coordinates of the last window of the previous run

In [2]:
rootdir = 'gata_locus/contributions_v1'

# Load every saved contribution chunk found under `rootdir`, keyed by the
# coordinate span encoded as the last "__"-separated field of the file name.
chunk_suffix = '.pt'
chunk_dicts = {}
for subdir, dirs, files in os.walk(rootdir):
    for file in tqdm(files):
        # Only saved chunks are of interest; skip any other file os.walk finds.
        if not file.endswith(chunk_suffix):
            continue
        # Slice the extension off explicitly: str.rstrip('.pt') strips a
        # *set of characters*, not a suffix, so it could also eat trailing
        # 'p', 't' or '.' characters belonging to the name itself.
        full_coordinates = file.split('__')[-1][:-len(chunk_suffix)]
        # Join with the directory os.walk is currently visiting so that files
        # located in nested folders are opened from where they actually live.
        chunk_dicts[full_coordinates] = torch.load(os.path.join(subdir, file))

# Flatten the per-chunk window coordinates, visiting chunks in sorted key order.
# NOTE(review): this is a plain string sort; it matches genomic order only while
# all start coordinates have the same number of digits.
all_window_coordinates = [coordinate for key in sorted(chunk_dicts.keys()) for coordinate in chunk_dicts[key]['window_coordinates']]

  0%|          | 0/20 [00:00<?, ?it/s]

0it [00:00, ?it/s]

In [3]:
# Show the final window coordinate of the last chunk (chunks are visited in sorted key order).
all_window_coordinates[-1]

'chrX:49644750-49644949'

In [5]:
# Build the coordinate span of the second locus file.
# NOTE(review): the meaning of the 202 offset is not stated anywhere in this
# notebook; confirm before reusing.
real_end_idx = 49880397
real_end_idx += 202
second_file_coordinates = f'chrX:49644750-{real_end_idx}'
print(second_file_coordinates)

chrX:49644750-49880599


### Now get the contributions

In [6]:
# Location the archive is expected to unpack into, and the archive itself.
model_dir = './artifacts'
hpo_rec = 'gs://syrgoth/aip_ui_test/model_artifacts__20211113_021200__287348.tar.gz'

# Remove any previously unpacked artifacts so the model is rebuilt from a
# fresh copy of the archive.
if os.path.isdir(model_dir):
    shutil.rmtree(model_dir)
unpack_artifact(hpo_rec)

model = model_fn(model_dir)
#model.cuda()
model.eval()
# Ending the cell with a print keeps the notebook from echoing a trailing expression.
print('')

archive unpacked in ./


Loaded model from 20211113_021200 in eval mode



In [7]:
class mpra_predictor(nn.Module):
    """Wrap a model so that short inputs are padded with constant flanks.

    The flanks are taken from ``constants.MPRA_UPSTREAM`` (its tail) and
    ``constants.MPRA_DOWNSTREAM`` (its head) and concatenated around every
    input along ``cat_axis`` before the wrapped model is called.

    Args:
        model: Callable (typically an ``nn.Module``) applied to the padded input.
        pred_idx: Index selected on dim 1 of the model output.
        ini_in_len: Length of the inputs passed to ``forward``.
        model_in_len: Input length expected by ``model``.
        cat_axis: Axis along which flanks and input are concatenated.

    Raises:
        ValueError: If ``model_in_len`` is smaller than ``ini_in_len``.
    """

    def __init__(self,
                 model,
                 pred_idx=0,
                 ini_in_len=200,
                 model_in_len=600,
                 cat_axis=-1):
        super().__init__()
        self.model = model
        self.pred_idx = pred_idx
        self.ini_in_len = ini_in_len
        self.model_in_len = model_in_len
        self.cat_axis = cat_axis

        # Best effort: switch the wrapped model to inference mode. Objects
        # without an `eval` method are accepted as they are; any other error
        # is no longer silently swallowed.
        try:
            self.model.eval()
        except AttributeError:
            pass

        self.register_flanks()

    def forward(self, x):
        """Pad ``x`` with the registered flanks and return the selected model output.

        Assumes ``x`` is a 3-D batch whose dim 0 is the batch dimension
        (the flanks are repeated ``x.shape[0]`` times along dim 0) - TODO confirm
        the remaining layout against ``utils.dna2tensor``.
        """
        batch = x.shape[0]
        pieces = [self.left_flank.repeat(batch, 1, 1), x, self.right_flank.repeat(batch, 1, 1)]
        in_tensor = torch.cat(pieces, dim=self.cat_axis)
        out_tensor = self.model(in_tensor)[:, self.pred_idx]
        return out_tensor

    def register_flanks(self):
        """Build the left/right flank tensors and register them as buffers.

        The missing length is split so that the left flank gets the floor half
        and the right flank the remainder (one longer when the gap is odd).
        """
        missing_len = self.model_in_len - self.ini_in_len
        if missing_len < 0:
            raise ValueError(
                f'model_in_len ({self.model_in_len}) must be >= ini_in_len ({self.ini_in_len})'
            )
        left_len = missing_len // 2
        right_len = missing_len - left_len
        # A plain `[-left_len:]` slice returns the WHOLE sequence when
        # left_len == 0 (because `[-0:]` is `[0:]`), so the empty case is
        # handled explicitly.
        if left_len > 0:
            left_seq = constants.MPRA_UPSTREAM[-left_len:]
        else:
            left_seq = constants.MPRA_UPSTREAM[:0]
        right_seq = constants.MPRA_DOWNSTREAM[:right_len]
        left_flank = utils.dna2tensor(left_seq).unsqueeze(0)
        right_flank = utils.dna2tensor(right_seq).unsqueeze(0)
        self.register_buffer('left_flank', left_flank)
        self.register_buffer('right_flank', right_flank)

        
def df_to_onehot_tensor(in_df, seq_column='sequence'):
    """Encode every entry of ``in_df[seq_column]`` and stack the results.

    Each entry is passed through ``utils.dna2tensor``; the encoded tensors are
    stacked along a new leading dimension. A progress bar tracks the loop.
    """
    encoded = []
    for subsequence in tqdm(in_df[seq_column]):
        encoded.append(utils.dna2tensor(subsequence))
    return torch.stack(encoded)

def fasta_to_tensor(file_name):
    """Read a FASTA file and return its sequences as one stacked tensor.

    Lines starting with '>' open a new record; every following line is
    appended to that record until the next header. Surrounding whitespace
    (including '\\r' from Windows line endings) is stripped and blank lines
    are ignored. A repeated header replaces the earlier record of that name.

    Args:
        file_name: Path of the FASTA file.

    Returns:
        ``torch.stack`` (dim 0) of ``utils.dna2tensor`` applied to each record,
        in file order. All records must therefore encode to the same shape.

    Raises:
        ValueError: If sequence data appears before the first header line.
    """
    records = {}
    my_id = None
    with open(file_name, 'r') as f:
        for line in f:
            line_str = line.strip()
            if not line_str:
                continue
            if line_str.startswith('>'):
                my_id = line_str.lstrip('>').strip()
                # Collect the pieces in a list and join once at the end.
                records[my_id] = []
            elif my_id is None:
                raise ValueError(
                    f'{file_name!r}: sequence data found before the first ">" header'
                )
            else:
                records[my_id].append(line_str)
    seq_tensors = [utils.dna2tensor(''.join(parts)) for parts in records.values()]
    return torch.stack(seq_tensors, dim=0)

def dna2tensor_approx(sequence_str, vocab_list=constants.STANDARD_NT, N_value=0.25):
    """One-hot encode a sequence, tolerating characters outside the vocabulary.

    Args:
        sequence_str: Sequence to encode, one character per position.
        vocab_list: Vocabulary; a character's row is its index in this list.
        N_value: Value written to every row of a column whose character is not
            in ``vocab_list`` (matching is exact, so case matters).

    Returns:
        ``torch.Tensor`` of shape ``(len(vocab_list), len(sequence_str))``.
    """
    seq_tensor = np.zeros((len(vocab_list), len(sequence_str)))
    for letterIdx, letter in enumerate(sequence_str):
        try:
            seq_tensor[vocab_list.index(letter), letterIdx] = 1
        except ValueError:
            # `.index` raises ValueError for an unknown character; only that
            # case is treated as "ambiguous", other errors are not hidden.
            seq_tensor[:, letterIdx] = N_value
    return torch.Tensor(seq_tensor)

def frame_print(string, marker='*', left_space=25):
    """Print ``string`` upper-cased inside a box of ``marker`` characters.

    The box is indented by ``left_space`` spaces and surrounded by two empty
    lines above and below. Every line is flushed immediately.
    """
    indent = ' ' * left_space
    banner = ' '.join((marker, string.upper(), marker))
    rule = len(banner) * marker
    for text in ('', '', indent + rule, indent + banner, indent + rule, '', ''):
        print(text, flush=True)
    
def decor_print(string):
    """Print ``string`` between two 15-dash rules, with an empty line before and after."""
    dashes = '-' * 15
    headline = ' '.join((dashes, string, dashes))
    for text in ('', headline, ''):
        print(text, flush=True)
    
def isg_contributions(sequences,
                      predictor,
                      num_steps=50,
                      num_samples=20,
                      eval_batch_size=1024,
                      theta_factor=15):
    """Compute per-position contribution maps by averaging sampled gradients along a line.

    For every batch, the inputs are scaled by ``theta_factor`` to obtain target
    logits. At ``num_steps + 1`` evenly spaced points between zero and those
    logits, ``num_samples`` one-hot sequences are sampled from the softmax of
    the point logits, the predictor is evaluated on them, and the gradient of
    the sample-averaged prediction with respect to the point logits is taken.
    The gradients are averaged over the points and multiplied by the target
    logits.

    Args:
        sequences: Tensor of inputs; assumed to be (N, 4, L) one-hot-like
            encodings - TODO confirm (the code samples 4 classes along dim -2
            and repeats each batch with three trailing dims).
        predictor: Callable mapping a batch of sequences to one value per sequence.
        num_steps: Number of intervals along the line (``num_steps + 1`` points).
        num_samples: Sequences sampled per input at each point.
        eval_batch_size: Approximate number of sampled sequences per predictor
            call; the loader batch size is ``eval_batch_size // num_samples``.
        theta_factor: Multiplier turning the inputs into target logits.

    Returns:
        Concatenation (dim 0) of the per-batch maps, one map per input sequence.

    Notes:
        Requires CUDA: every batch is moved with ``.cuda()``.
        NOTE(review): the target logits require grad, so the returned product
        still carries autograd history - consider whether callers want it detached.
    """
    
    # Each input expands into num_samples sequences, so shrink the loader batch accordingly.
    batch_size = eval_batch_size // num_samples
    temp_dataset = TensorDataset(sequences)
    temp_dataloader = DataLoader(temp_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    all_salient_maps = []
    for local_batch in tqdm(temp_dataloader):
        # End point of the line in logit space.
        target_thetas = (theta_factor * local_batch[0].cuda()).requires_grad_()
        line_gradients = []
        for i in range(0, num_steps + 1):
            # Point i/num_steps of the way from zero to the target logits.
            point_thetas = (i / num_steps * target_thetas)
            point_distributions = F.softmax(point_thetas, dim=-2)

            # Categorical treats the last dim as the categories, hence the transposes.
            nucleotide_probs = Categorical(torch.transpose(point_distributions, -2, -1))
            sampled_idxs = nucleotide_probs.sample((num_samples, ))
            sampled_nucleotides_T = F.one_hot(sampled_idxs, num_classes=4)
            sampled_nucleotides = torch.transpose(sampled_nucleotides_T, -2, -1)
            distribution_repeater = point_distributions.repeat(num_samples, *[1 for i in range(3)])
            # Forward value equals the one-hot sample; the gradient flows through
            # the (non-detached) probabilities.
            sampled_nucleotides = sampled_nucleotides - distribution_repeater.detach() + distribution_repeater 
            # Merge the sample and batch dims for a single predictor call.
            samples = sampled_nucleotides.flatten(0,1)

            preds = predictor(samples)
            # Restore the (num_samples, batch) layout and average over the samples.
            point_predictions = preds.unflatten(0, (num_samples, target_thetas.shape[0])).mean(dim=0)
            point_gradients = torch.autograd.grad(point_predictions.sum(), inputs=point_thetas, retain_graph=True)[0]
            line_gradients.append(point_gradients)
            
        # Average gradient along the line, scaled by the target logits.
        gradients = torch.stack(line_gradients).mean(dim=0) 
        all_salient_maps.append(gradients * target_thetas)
    return torch.cat(all_salient_maps)

In [8]:
# Wrap the loaded model so it returns output index 0 and move the wrapper to the GPU.
# NOTE(review): index 0 is presumably the K562 output, judging by the variable name - confirm.
k562_predictor = mpra_predictor(model=model, pred_idx=0).cuda()

In [36]:
#! gsutil cp gs://syrgoth/data/locus_select/chrX-49,644,750-49,880,599.txt ./

gata_chr = 'X'
left_jump = 10
gata_locus_file = 'chrX-49,644,750-49,880,599.txt'

# Concatenate every non-header line of the file into a single sequence string.
with open(gata_locus_file) as f:
    gata_locus_str = ''.join(line.strip() for line in f if not line.startswith('>'))

# Skip the first `left_jump` characters, then check the remaining length
# against the expected number of positions.
gata_locus_str = gata_locus_str[left_jump:]
print(len(gata_locus_str), len(range(49644750+left_jump, 49880600)))

235840 235840


In [37]:
# One-hot encode the locus; characters outside the vocabulary become all-zero columns (N_value=0.).
gata_locus_tensor = dna2tensor_approx(gata_locus_str, N_value=0.)
gata_locus_tensor.shape

torch.Size([4, 235840])

In [38]:
# Slide a fixed-width window along the encoded locus and stack all windows
# into one tensor (one window per row of dim 0).
window_len = 200
step_size = 10
locus_tensor_windows = []
for start in range(0, gata_locus_tensor.shape[1]-window_len+1, step_size):
    locus_tensor_windows.append(gata_locus_tensor[:, start:start+window_len])
locus_tensor_windows = torch.stack(locus_tensor_windows)

In [39]:
# Coordinate of the first retained position of the locus.
gata_locus_start = 49644750 + left_jump

# One "chr:start-end" label (inclusive end) per window start offset.
windows_coordinates = []
for start in range(0, gata_locus_tensor.shape[1]-window_len+1, step_size):
    windows_coordinates.append(f'chr{gata_chr}:{gata_locus_start + start}-{gata_locus_start + start + window_len-1}')

In [42]:
# Take the first 1020 windows as a small example chunk and inspect its shape.
chunk_example = locus_tensor_windows[:1020,...]
chunk_example.shape

torch.Size([1020, 4, 200])

In [55]:
%%time
# Compute contribution maps for all windows, `chunk_size` windows at a time,
# saving each chunk (maps + matching window coordinates) to its own file.
data_tensor = locus_tensor_windows
chunk_size = 10002       # windows per saved file
eval_batch_size = 1040   # sampled sequences per predictor call (see isg_contributions)

cell_type = 'k562'
targetdir = 'gata_locus/contributions_v1'

# torch.save does not create missing folders; make sure the target exists so a
# fresh checkout does not fail after the first (long) chunk has been computed.
os.makedirs(targetdir, exist_ok=True)

print(f'Results will be saved at {targetdir}', flush=True)
print('', flush=True)

num_chunks = math.ceil(data_tensor.shape[0] / chunk_size)
processed_chunks = 0
for i in range(0, data_tensor.shape[0], chunk_size):
    start_time = datetime.now()

    decor_print(f'Processing chunk {processed_chunks+1}/{num_chunks}')

    chunk = data_tensor[i:i+chunk_size, ...]

    salient_maps = isg_contributions(chunk, k562_predictor, eval_batch_size=eval_batch_size)
    # Coordinates are sliced with the actual chunk length so the last,
    # shorter chunk stays aligned with its windows.
    coordinate_list = windows_coordinates[i:i+chunk.shape[0]]

    save_dict = {}
    save_dict['window_contributions'] = salient_maps
    save_dict['window_coordinates'] = coordinate_list

    # File name encodes the span from the first window's start to the last window's end.
    first_coordinate = coordinate_list[0].split('-')[0]
    last_coordinate = coordinate_list[-1].split('-')[1]
    chunk_name = f'gata_locus_contributions__{cell_type}__window_len_{window_len}__step_size_{step_size}'
    chunk_name += f'__{first_coordinate}-{last_coordinate}' + '.pt'

    save_path = os.path.join(targetdir, chunk_name)
    torch.save(save_dict, save_path)

    print(f'Contributions saved in {save_path}')
    print('', flush=True)

    processed_chunks += 1
    left_chunks = num_chunks - processed_chunks
    end_time = datetime.now()
    chunk_time = end_time - start_time

    # The estimate extrapolates from the duration of the chunk just finished.
    print(f'Chunk processing time: {chunk_time}', flush=True)
    print('', flush=True)
    print(f'Estimated time remaining: {chunk_time*left_chunks}', flush=True)
    print('', flush=True)

Results will be saved at gata_locus/contributions_v1


--------------- Processing chunk 1/3 ---------------



  0%|          | 0/193 [00:00<?, ?it/s]

Contributions saved in gata_locus/contributions_v1/gata_locus_contributions__k562__window_len_200__step_size_10__chrX:49644760-49744969.pt

Chunk processing time: 0:16:38.003676

Estimated time remaining: 0:33:16.007352


--------------- Processing chunk 2/3 ---------------



  0%|          | 0/193 [00:00<?, ?it/s]

Contributions saved in gata_locus/contributions_v1/gata_locus_contributions__k562__window_len_200__step_size_10__chrX:49744780-49844989.pt

Chunk processing time: 0:16:39.534007

Estimated time remaining: 0:16:39.534007


--------------- Processing chunk 3/3 ---------------



  0%|          | 0/69 [00:00<?, ?it/s]

Contributions saved in gata_locus/contributions_v1/gata_locus_contributions__k562__window_len_200__step_size_10__chrX:49844800-49880599.pt

Chunk processing time: 0:05:56.229789

Estimated time remaining: 0:00:00

CPU times: user 39min 12s, sys: 3.04 s, total: 39min 15s
Wall time: 39min 13s
