In [1]:
import sys
import os
import subprocess
import tarfile
import shutil
import math
import time
import random
import tempfile
from functools import partial
from tqdm import tqdm
from tqdm.auto import tqdm
tqdm.pandas()

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import (random_split, DataLoader, TensorDataset, ConcatDataset)
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from mpl_toolkits import mplot3d
from Bio import motifs

from google.cloud import storage
import csv
from io import StringIO

import boda
from boda.generator.parameters import StraightThroughParameters
from boda.generator import FastSeqProp
from boda.generator.plot_tools import matrix_to_dms, ppm_to_IC, ppm_to_pwm
from boda.model.mpra_basset import MPRA_Basset
from boda.common import constants, utils

# Make the repository's `src/` directory (two levels above the notebook's
# working directory) importable; the `from main import ...` in this cell
# relies on it.
boda_src = os.path.join(os.path.dirname(os.path.dirname(os.getcwd())), 'src')
sys.path.insert(0, boda_src)

# `display`/`HTML` are imported from IPython.display: importing `display` from
# IPython.core.display is deprecated (DeprecationWarning since IPython 7.14).
from IPython.display import display, HTML
# Widen the notebook container to 98% of the browser window.
display(HTML("<style>.container { width:98% !important; }</style>"))

from main import unpack_artifact, model_fn
from pymeme import streme, parse_streme_output

In [2]:
#for fixed-length sequences
def df_to_onehot_tensor(in_df, seq_column='sequence'):
    """One-hot encode every sequence in ``in_df[seq_column]`` and stack them.

    Each sequence is encoded with ``utils.dna2tensor``; the encodings are
    combined with ``torch.stack``, so all sequences must encode to tensors of
    the same shape (i.e. fixed-length sequences).

    Returns:
        A tensor whose first dimension indexes the rows of ``in_df``.
    """
    encoded = []
    for seq in tqdm(in_df[seq_column]):
        encoded.append(utils.dna2tensor(seq))
    return torch.stack(encoded)

#for variable-length sequences
def get_onehots(in_df, seq_column='nt_sequence', extra_str='', padded_seq_len=600):
    """Pad variable-length sequences to a common length, then one-hot encode them.

    Args:
        in_df: DataFrame holding the sequences (must support ``progress_apply``,
            i.e. ``tqdm.pandas()`` has been called).
        seq_column: name of the column containing the sequence strings.
        extra_str: suffix appended to the progress messages.
        padded_seq_len: target length forwarded to ``utils.row_pad_sequence``
            (previously hard-coded; the default keeps the old value of 600).

    Returns:
        Stacked tensor of one-hot encoded, padded sequences (one per row).
    """
    padding_fn = partial(utils.row_pad_sequence,
                         in_column_name=seq_column,
                         padded_seq_len=padded_seq_len)
    print('Padding sequences' + extra_str)
    sequence_list = list(in_df.progress_apply(padding_fn, axis=1))
    print('Tokenizing sequences' + extra_str)
    onehot_sequences = torch.stack([utils.dna2tensor(subsequence)
                                    for subsequence in tqdm(sequence_list)])
    return onehot_sequences

def get_predictions(onehot_sequences, model, eval_batch_size=128, num_workers=2, extra_str='', device=None):
    """Run ``model`` over ``onehot_sequences`` in batches and return stacked predictions.

    Args:
        onehot_sequences: tensor of model inputs; the first dimension indexes samples.
        model: torch module (or callable) returning one output row per input.
        eval_batch_size: DataLoader batch size.
        num_workers: DataLoader worker count.
        extra_str: suffix appended to the progress message.
        device: device the input batches are moved to. When ``None``, the
            device of the model's first parameter is used, falling back to
            CUDA (the previous hard-coded behavior) if it has none.

    Returns:
        NumPy array of predictions concatenated along axis 0.
    """
    if device is None:
        try:
            device = next(model.parameters()).device
        except (AttributeError, StopIteration):
            device = torch.device('cuda')
    temp_dataset = TensorDataset(onehot_sequences)
    temp_dataloader = DataLoader(temp_dataset, batch_size=eval_batch_size, shuffle=False, num_workers=num_workers)
    print('Getting predictions' + extra_str)
    preds = []
    # Inference only: no_grad avoids building (and holding) an autograd graph per batch.
    with torch.no_grad():
        for local_batch in tqdm(temp_dataloader):
            preds.append(model(local_batch[0].to(device)).detach().cpu().numpy())
    preds_array = np.concatenate(preds, axis=0)
    return preds_array

def entropy(X):
    """Shannon entropy (in nats) of the softmax distribution along dimension 1 of ``X``.

    ``X`` holds unnormalized scores (logits) and must have at least 2 dims;
    the softmax and the entropy sum are both taken over dimension 1.

    ``log_softmax`` is used instead of ``log(softmax(X))`` so that classes whose
    probability underflows to 0 contribute 0 rather than ``0 * -inf = nan``.

    Returns:
        NumPy float32 array with dimension 1 of ``X`` reduced away.
    """
    logits = torch.as_tensor(X, dtype=torch.float32).detach().cpu()
    log_p = F.log_softmax(logits, dim=1)
    return (-(log_p.exp() * log_p).sum(dim=1)).numpy()

def dna2tensor_approx(sequence_str, vocab_list=constants.STANDARD_NT):
    """One-hot encode a sequence, spreading characters outside the vocabulary uniformly.

    A character found in ``vocab_list`` gets a 1 in its row; any other
    character (presumably ambiguity codes such as 'N') gets 0.25 in every row
    of its column. Matching is case-sensitive.

    Args:
        sequence_str: sequence to encode.
        vocab_list: ordered alphabet; row i of the output corresponds to
            ``vocab_list[i]``.

    Returns:
        ``torch.Tensor`` of shape ``(len(vocab_list), len(sequence_str))``.
    """
    seq_tensor = np.zeros((len(vocab_list), len(sequence_str)))
    for letterIdx, letter in enumerate(sequence_str):
        try:
            seq_tensor[vocab_list.index(letter), letterIdx] = 1
        except ValueError:
            # `letter` is not in the vocabulary. The 0.25 fill is uniform only
            # for a 4-letter alphabet — TODO confirm for non-default vocab_list.
            seq_tensor[:, letterIdx] = 0.25
    seq_tensor = torch.Tensor(seq_tensor)
    return seq_tensor

def frame_print(string, marker='*', left_space=25):
    """Print ``string`` upper-cased inside a box drawn with ``marker``.

    The box is indented by ``left_space`` spaces and preceded and followed by
    two blank lines; every line is flushed as it is printed.
    """
    indent = ' ' * left_space
    framed = ' '.join((marker, string.upper(), marker))
    border = marker * len(framed)
    lines = ['', '', indent + border, indent + framed, indent + border, '', '']
    for line in lines:
        print(line, flush=True)
    
def decor_print(string):
    """Print ``string`` between two runs of 15 dashes, with a blank line before and after."""
    dashes = '-' * 15
    for line in ('', ' '.join((dashes, string, dashes)), ''):
        print(line, flush=True)
    
def over_max_bent(x, bias_cell=0, bending_factor=1.0):
    """Margin of the target cell over the best non-target cell after "bending".

    Values are first transformed as ``x - bending_factor * (exp(-x) - 1)``;
    the result is the transformed value at ``bias_cell`` minus the maximum of
    the transformed values at every other index of the last dimension.

    Args:
        x: tensor whose last dimension indexes cells (needs at least 2 entries).
        bias_cell: index of the target cell in the last dimension. Negative
            indices count from the end, as in normal tensor indexing.
        bending_factor: scale of the ``exp(-x) - 1`` term; 0 disables bending.

    Returns:
        Tensor with the last dimension reduced away.
    """
    x = x - bending_factor * (torch.exp(-x) - 1)
    n_cells = x.shape[-1]
    # Normalize in-range negative indices; otherwise e.g. bias_cell=-1 never
    # equals any `i` below and the target leaks into its own competitor max.
    # Out-of-range indices are left alone so indexing still raises.
    if -n_cells <= bias_cell < 0:
        bias_cell += n_cells
    target = x[..., bias_cell]
    others = [i for i in range(n_cells) if i != bias_cell]
    non_target_max = x[..., others].max(-1).values
    return target - non_target_max

In [3]:
model_dir = './artifacts'
# Start from a clean directory so files from a previous unpack are not reused.
if os.path.isdir(model_dir):
    shutil.rmtree(model_dir)

# Packed model artifact on GCS. unpack_artifact (imported from main) is
# expected to populate model_dir — TODO confirm against src/main.py.
hpo_rec = 'gs://syrgoth/aip_ui_test/model_artifacts__20211113_021200__287348.tar.gz'
unpack_artifact(hpo_rec)

# Load the model, move it to the GPU and switch to inference mode.
model = model_fn(model_dir)
model.cuda()
model.eval()
print()

archive unpacked in ./


Loaded model from 20211113_021200 in eval mode



In [4]:
# One-hot flanks attached around every insert before prediction: the last
# `flank_len` nt of MPRA_UPSTREAM and the first `flank_len` nt of
# MPRA_DOWNSTREAM, each given a leading dimension of size 1 so it can be
# repeated once per sequence in a batch.
flank_len = 200
left_flank = utils.dna2tensor(constants.MPRA_UPSTREAM[-flank_len:]).unsqueeze(0)
right_flank = utils.dna2tensor(constants.MPRA_DOWNSTREAM[:flank_len]).unsqueeze(0)

In [5]:
def _parse_fasta_text(text):
    """Parse FASTA text into a list of ``(header, upper-cased sequence)`` pairs.

    The header is the whole '>' line with leading '>' characters and
    surrounding whitespace removed. Blank lines are skipped. Every record is
    kept in file order, even if a header repeats.

    Raises:
        ValueError: if sequence data appears before the first header line.
    """
    records = []
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line:
            continue
        if line.startswith('>'):
            records.append((line.lstrip('>').strip(), []))
        elif not records:
            raise ValueError('FASTA sequence data found before the first header line')
        else:
            records[-1][1].append(line.upper())
    return [(header, ''.join(chunks)) for header, chunks in records]


def fasta_blob_to_prediction(blob, model, left_flank, right_flank, eval_batch_size=512, pre_batch_size=50000, num_workers=2, device=None):
    """Download a FASTA blob and predict every sequence with the flanks attached.

    Args:
        blob: object exposing ``download_as_string()`` (e.g. a GCS blob) that
            returns UTF-8 encoded FASTA bytes.
        model: torch module applied to flank + sequence + flank one-hot inputs;
            it must return 3 outputs per sequence.
        left_flank, right_flank: one-hot tensors with a leading dimension of 1,
            repeated per sequence and concatenated along the last dimension.
        eval_batch_size: batch size used for model calls.
        pre_batch_size: number of sequences one-hot encoded at a time.
        num_workers: DataLoader worker count.
        device: device input batches are moved to. When ``None``, the device
            of the model's first parameter is used, falling back to CUDA.

    Returns:
        DataFrame with columns ``ID``, ``nt_sequence``, ``K562_pred``,
        ``HepG2_pred`` and ``SKNSH_pred`` — one row per FASTA record (empty
        if the FASTA has no records).
    """
    # download_as_string is kept for compatibility with older
    # google-cloud-storage releases (newer ones prefer download_as_bytes).
    fasta_text = blob.download_as_string().decode('utf-8')
    temp_df = pd.DataFrame(_parse_fasta_text(fasta_text), columns=['ID', 'nt_sequence'])
    pred_columns = ['K562_pred', 'HepG2_pred', 'SKNSH_pred']
    df_len = len(temp_df)
    print(f'Getting {df_len:,} predictions', flush=True)
    if df_len == 0:
        # Nothing to predict; np.concatenate([]) below would raise.
        return temp_df.reindex(columns=['ID', 'nt_sequence'] + pred_columns)

    if device is None:
        try:
            device = next(model.parameters()).device
        except (AttributeError, StopIteration):
            device = torch.device('cuda')

    preds = []
    # Inference only: no_grad avoids building an autograd graph per batch.
    with torch.no_grad():
        for batch_start in tqdm(range(0, df_len, pre_batch_size)):
            sub_temp_df = temp_df.iloc[batch_start:batch_start + pre_batch_size]
            onehot_sequences = torch.stack([dna2tensor_approx(subsequence)
                                            for subsequence in sub_temp_df['nt_sequence']])
            n_seqs = onehot_sequences.shape[0]
            input_tensor = torch.cat([left_flank.repeat(n_seqs, 1, 1),
                                      onehot_sequences,
                                      right_flank.repeat(n_seqs, 1, 1)], dim=-1)
            temp_dataloader = DataLoader(TensorDataset(input_tensor), batch_size=eval_batch_size,
                                         shuffle=False, num_workers=num_workers)
            for local_batch in temp_dataloader:
                preds.append(model(local_batch[0].to(device)).detach().cpu().numpy())
    temp_df[pred_columns] = np.concatenate(preds, axis=0)
    return temp_df

In [6]:
%%time

eval_batch_size = 512
fata_names = ['cisX_tall_mel_ase_variants_200.fa', 'hbe1_mpra_sat_mut.fa']
# rootdir = 'data/alkes/fastas/'
# targetdir = 'gs://syrgoth/data/alkes/predictions_v1'
rootdir = 'data/cosmic/fastas/'
targetdir = 'gs://syrgoth/data/cosmic/predictions'

bucket = storage.Client().get_bucket('syrgoth')
for blob in bucket.list_blobs(prefix=rootdir):
    filepath = blob.name
    if filepath.endswith('.fa'): 
        base_name = os.path.basename(blob.name)
        if base_name in fata_names:
            out_file_name = base_name.rstrip('.fa') + '_pred.txt'
            cloud_target = os.path.join(targetdir, out_file_name)

            decor_print(f'Parsing {base_name}')
            pred_df = fasta_blob_to_prediction(blob=blob,
                                             model=model,
                                             left_flank=left_flank,
                                             right_flank=right_flank,
                                             eval_batch_size=512)
            with tempfile.TemporaryDirectory() as tmpdir:
                temp_loc = os.path.join(tmpdir, base_name)  
                pred_df.to_csv(temp_loc, index=None, sep='\t', float_format='%.15f')                
                subprocess.check_call(
                    ['gsutil', 'cp', temp_loc, cloud_target]
                )
                print('Predictions saved in ' + cloud_target, flush=True)
                print('', flush=True)


--------------- Parsing cisX_tall_mel_ase_variants_200.fa ---------------

Getting 352 predictions


  0%|          | 0/1 [00:00<?, ?it/s]

  return torch.max_pool1d(input, kernel_size, stride, padding, dilation, ceil_mode)


Predictions saved in gs://syrgoth/data/cosmic/predictions/cisX_tall_mel_ase_variants_200_pred.txt


--------------- Parsing hbe1_mpra_sat_mut.fa ---------------

Getting 2,478 predictions


  0%|          | 0/1 [00:00<?, ?it/s]

Predictions saved in gs://syrgoth/data/cosmic/predictions/hbe1_mpra_sat_mut_pred.txt

CPU times: user 768 ms, sys: 239 ms, total: 1.01 s
Wall time: 5.42 s
