In [1]:
import sys
import os
import subprocess
import tarfile
import shutil
from functools import partial
from tqdm import tqdm
from tqdm.auto import tqdm
tqdm.pandas()

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import (random_split, DataLoader, TensorDataset, ConcatDataset)
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from mpl_toolkits import mplot3d
from Bio import motifs

import boda
from boda.generator.parameters import StraightThroughParameters
from boda.generator import FastSeqProp
from boda.generator.plot_tools import matrix_to_dms, ppm_to_IC, ppm_to_pwm
from boda.model.mpra_basset import MPRA_Basset
from boda.common import constants, utils

boda_src = os.path.join( os.path.dirname( os.path.dirname( os.getcwd() ) ), 'src' )
sys.path.insert(0, boda_src)

from main import unpack_artifact, model_fn
from pymeme import streme, parse_streme_output

In [2]:
#for variable-length sequences
def get_onehots(in_df, seq_column='nt_sequence', extra_str='', padded_seq_len=600):
    """Pad every sequence in ``in_df`` to a fixed length and one-hot encode it.

    Args:
        in_df: DataFrame with one nucleotide sequence per row in ``seq_column``.
        seq_column: name of the column holding the sequences.
        extra_str: suffix appended to the progress messages (e.g. a dataset label).
        padded_seq_len: target length passed to ``utils.row_pad_sequence``
            (previously hard-coded to 600, which remains the default).

    Returns:
        A tensor built by ``torch.stack``-ing ``utils.dna2tensor`` of each padded
        sequence, one entry per row of ``in_df`` along a new leading dimension.

    Raises:
        ValueError: if ``in_df`` is empty. Row-wise ``apply`` on an empty frame returns
            a DataFrame rather than a list of sequences, so this fails loudly instead.
    """
    if in_df.empty:
        raise ValueError('get_onehots received an empty dataframe' + extra_str)
    padding_fn = partial(utils.row_pad_sequence,
                         in_column_name=seq_column,
                         padded_seq_len=padded_seq_len)
    print('Padding sequences' + extra_str)
    # progress_apply is DataFrame.apply with a tqdm bar, registered by tqdm.pandas().
    padded_sequences = list(in_df.progress_apply(padding_fn, axis=1))
    print('Tokenizing sequences' + extra_str)
    return torch.stack([utils.dna2tensor(sequence) for sequence in tqdm(padded_sequences)])

def get_predictions(onehot_sequences, model, eval_batch_size = 128, num_workers=2, extra_str='', device='cuda'):
    """Run ``model`` over ``onehot_sequences`` in batches and return the stacked outputs.

    Args:
        onehot_sequences: tensor of encoded sequences, batched along its first dimension.
        model: module to evaluate. The caller is responsible for putting it in eval
            mode and on ``device``.
        eval_batch_size: number of sequences per forward pass.
        num_workers: DataLoader worker processes.
        extra_str: suffix appended to the progress message.
        device: device each batch is moved to before the forward pass. The default
            ``'cuda'`` matches the previously hard-coded ``.cuda()``.

    Returns:
        numpy array of the per-batch model outputs concatenated along axis 0. Rows are
        in input order because the loader does not shuffle.
    """
    temp_dataloader = DataLoader(TensorDataset(onehot_sequences), batch_size=eval_batch_size,
                                 shuffle=False, num_workers=num_workers)
    print('Getting predictions' + extra_str)
    preds = []
    # Inference only: without no_grad an autograd graph is built for every batch,
    # which costs memory and time for nothing.
    with torch.no_grad():
        for (local_batch,) in tqdm(temp_dataloader):
            preds.append(model(local_batch.to(device)).cpu().numpy())
    return np.concatenate(preds, axis=0)

def append_predictions(dfs, model_paths, model_nicknames=None,
                       activity_columns=('K562_mean', 'HepG2_mean', 'SKNSH_mean')):
    """Add one prediction column per activity and per model to every dataframe in ``dfs``.

    Each model archive is unpacked with ``unpack_artifact``, loaded from ``./artifacts``
    with ``model_fn``, moved to CUDA in eval mode, and run over the one-hot encoding of
    every dataframe. Predictions are written in place as ``<prefix>pred_<nickname>``
    columns, where ``<prefix>`` is the activity name with its trailing 'mean' removed
    (e.g. 'K562_mean' -> 'K562_pred_relu').

    Args:
        dfs: list of dataframes to annotate; mutated in place.
        model_paths: artifact locations passed to ``unpack_artifact``.
        model_nicknames: unique suffix for each model's columns; defaults to '1', '2', ...
        activity_columns: activity column names. The i-th model output column is stored
            under the name derived from the i-th entry, so the order must match the
            model's outputs.

    Returns:
        dict mapping each nickname to its list of prediction column names.
    """
    model_dir = './artifacts'
    if model_nicknames is None:
        model_nicknames = [str(i) for i in range(1, len(model_paths) + 1)]
    # Validate cheap arguments before the expensive one-hot encoding below.
    assert len(model_nicknames) == len(model_paths)
    # Duplicate nicknames would silently overwrite each other's columns and dict entry.
    assert len(set(model_nicknames)) == len(model_nicknames), 'model nicknames must be unique'

    print('------------- Getting input tensors for each df -------------')
    print('')
    onehot_inputs = [get_onehots(df) for df in dfs]

    suffix = 'mean'
    prediction_columns_dict = {}
    for model_nickname, model_path in zip(model_nicknames, model_paths):
        # Start each model from an empty artifact directory so files left by the
        # previous archive cannot be picked up by model_fn.
        if os.path.isdir(model_dir):
            shutil.rmtree(model_dir)
        unpack_artifact(model_path)
        model = model_fn(model_dir)
        model.cuda()
        model.eval()
        # Remove the literal 'mean' suffix. str.rstrip('mean') strips any trailing run of
        # the characters m/e/a/n, which only worked by coincidence for these names.
        prediction_columns = [
            (name[:-len(suffix)] if name.endswith(suffix) else name) + 'pred_' + model_nickname
            for name in activity_columns
        ]
        prediction_columns_dict[model_nickname] = prediction_columns
        print('')
        print(f'------------- Getting model_{model_nickname} predictions for each df -------------')
        print('')
        for df, onehot_input in zip(dfs, onehot_inputs):
            df[prediction_columns] = get_predictions(onehot_input, model)
        # Drop the reference so this model can be freed before the next one is loaded.
        del model
    return prediction_columns_dict

def single_scatterplot(data_df, x_axis, y_axis, color_axis, fig_size=(15,8), dot_size=0.5, title='',
                       dot_alpha=0.5, style='seaborn-whitegrid', colormap='winter',
                       x_label='True', y_label='Predicted', color_label='l2fc SE', title_font_size=18,
                       title_font_weight='medium', axis_font_size=16):
    """Scatter ``y_axis`` against ``x_axis`` coloured by ``color_axis``.

    The plot also draws a dashed identity line over the value range shared by both axes
    and appends the Pearson and Spearman correlations of the two columns to the title.

    Args:
        data_df: source dataframe.
        x_axis, y_axis: column names for the horizontal and vertical axes.
        color_axis: column used to colour the points.
        fig_size, dot_size, dot_alpha, colormap: appearance of the scatter.
        title: title prefix; the correlations are appended after it.
        style: matplotlib style used while drawing. A legacy 'seaborn*' name is mapped
            to its 'seaborn-v0_8*' equivalent when the installed matplotlib no longer
            ships the old name (removed in matplotlib 3.8).
        x_label, y_label, color_label: axis and colorbar labels.
        title_font_size, title_font_weight, axis_font_size: font settings.

    Returns:
        The matplotlib Axes holding the scatter.
    """
    if (isinstance(style, str) and style.startswith('seaborn')
            and style not in plt.style.available):
        renamed_style = style.replace('seaborn', 'seaborn-v0_8', 1)
        if renamed_style in plt.style.available:
            style = renamed_style

    with plt.style.context(style):
        fig, ax = plt.subplots(figsize=fig_size)
        data_df.plot(kind='scatter', x=x_axis, y=y_axis, c=color_axis, ax=ax,
                     alpha=dot_alpha, s=dot_size, colormap=colormap)
        ax.set_xlabel(x_label, fontsize=axis_font_size)
        ax.set_ylabel(y_label, fontsize=axis_font_size)

        # If pandas drew a colorbar, it is the figure's second axes; it may be absent,
        # e.g. when the colour column is not numeric.
        extra_axes = fig.get_axes()[1:]
        if extra_axes:
            extra_axes[0].set_ylabel(color_label, fontsize=axis_font_size)

        # Identity line over the range covered by both columns.
        pair = data_df[[x_axis, y_axis]]
        x_min, y_min = pair.min().to_numpy()
        x_max, y_max = pair.max().to_numpy()
        min_point, max_point = max(x_min, y_min), min(x_max, y_max)
        ax.plot((min_point, max_point), (min_point, max_point), color='black', linestyle='--', alpha=0.5)

        # .iloc[0, 1] is the x/y off-diagonal entry. The previous `[x_axis][1]` used
        # integer-as-position fallback on a string-indexed Series (deprecated in pandas 2).
        pearson = round(pair.corr(method='pearson').iloc[0, 1], 2)
        spearman = round(pair.corr(method='spearman').iloc[0, 1], 2)

        title = f'{title}  |  Pearson={pearson}  Spearman={spearman}'
        ax.set_title(title, fontdict={'fontsize': title_font_size, 'fontweight': title_font_weight}, pad=15)
    return ax

In [12]:
# gtex_df = pd.read_csv('gs://syrgoth/data/MPRA_GTEX.txt', sep=" ", low_memory=False)
# gtex_noShrink_df = pd.read_csv('gs://syrgoth/data/MPRA_GTEX_cellDisp_noShrink.txt', sep=" ", low_memory=False)
# gtex_Shrink_df = pd.read_csv('gs://syrgoth/data/MPRA_GTEX_cellDisp_Shrink.txt', sep=" ", low_memory=False)

all_boda_df = pd.read_csv('gs://syrgoth/data/MPRA_ALL_no_cutoffs.txt', sep=" ", low_memory=False)
# Manually overwrite the sequence stored at row label 345812.
# NOTE(review): the reason for this patch and the source of the replacement sequence are not
# recorded here; document them. `.at` silently appends a new row when the label is missing,
# so check that the row exists before patching.
PATCHED_ROW_LABEL = 345812
assert PATCHED_ROW_LABEL in all_boda_df.index, \
    f'row {PATCHED_ROW_LABEL} not in all_boda_df; the sequence patch no longer applies'
all_boda_df.at[PATCHED_ROW_LABEL, 'nt_sequence'] = 'TGTAGAAAAAAATATATATATATATGAACAACGCATAATCCTGGAAATATAAGGAAAAATTAAATTTTCTCCTCTGGGAAAAATTTATACAGTAATGATTCTTGCTCTTTAATTTTTGTTTGAAAGAAATCTAGACATTTAAAAAACCCCAGTGGTAGAATTGTCTTGTTAAAAAGGGACATCAAGTAAAAGGCCAGGGG'


In [16]:
# One entry per trained model: nickname -> packaged artifact archive.
model_artifacts = {
    'relu': 'gs://syrgoth/aip_ui_test/model_artifacts__20211113_021200__287348.tar.gz',
    'relu6': 'gs://syrgoth/aip_ui_test/model_artifacts__20211110_194934__672830.tar.gz',
    'relu_HD': 'gs://syrgoth/aip_ui_test/model_artifacts__20211119_011437__338420.tar.gz',
}
model_nicknames = list(model_artifacts)
model_paths = list(model_artifacts.values())

# performance_dfs = [gtex_df, gtex_noShrink_df, gtex_Shrink_df]
performance_dfs = [all_boda_df]

prediction_columns_dict = append_predictions(performance_dfs, model_paths, model_nicknames=model_nicknames)

------------- Getting input tensors for each df -------------

Padding sequences


  0%|          | 0/813051 [00:00<?, ?it/s]

Tokenizing sequences


  0%|          | 0/813051 [00:00<?, ?it/s]

archive unpacked in ./


Loaded model from 20211113_021200 in eval mode

------------- Getting model_relu predictions for each df -------------

Getting predictions


  0%|          | 0/6352 [00:00<?, ?it/s]

  return torch.max_pool1d(input, kernel_size, stride, padding, dilation, ceil_mode)


Loaded model from 20211110_194934 in eval mode

------------- Getting model_relu6 predictions for each df -------------

Getting predictions


archive unpacked in ./


  0%|          | 0/6352 [00:00<?, ?it/s]

Loaded model from 20211119_011437 in eval mode

------------- Getting model_relu_HD predictions for each df -------------

Getting predictions


archive unpacked in ./


  0%|          | 0/6352 [00:00<?, ?it/s]

In [17]:
avg_prediction_columns = ['K562_pred_aggreg', 'HepG2_pred_aggreg', 'SKNSH_pred_aggreg']
# Ensemble prediction: for each slot k, the row-wise mean over every model's k-th
# prediction column (slot k of each model's column list is the same activity).
for cell_idx, aggreg_column in enumerate(avg_prediction_columns):
    for df in performance_dfs:
        member_columns = [column_list[cell_idx] for column_list in prediction_columns_dict.values()]
        df[aggreg_column] = df[member_columns].mean(axis=1)

In [None]:
# column_drop_list = [column for sublist in prediction_columns_dict.values() for column in sublist]

# for df in performance_dfs:
#     df.drop(column_drop_list, axis=1, inplace=True)

In [20]:
# Display the mapping {model nickname: [K562, HepG2, SKNSH] prediction column names}.
prediction_columns_dict

{'relu': ['K562_pred_relu', 'HepG2_pred_relu', 'SKNSH_pred_relu'],
 'relu6': ['K562_pred_relu6', 'HepG2_pred_relu6', 'SKNSH_pred_relu6'],
 'relu_HD': ['K562_pred_relu_HD', 'HepG2_pred_relu_HD', 'SKNSH_pred_relu_HD']}

In [25]:
#pd.set_option('display.max_columns', None)

# Inspect the current columns, including the per-model and aggregated prediction columns.
all_boda_df.columns

Index(['HepG2_mean', 'HepG2_std', 'ID_count', 'IDs', 'K562_mean', 'K562_std',
       'OL', 'OL_count', 'SKNSH_mean', 'SKNSH_std', 'chr', 'class',
       'ctrl_mean_hepg2', 'ctrl_mean_k562', 'ctrl_mean_sknsh', 'data_project',
       'exp_mean_hepg2', 'exp_mean_k562', 'exp_mean_sknsh', 'lfcSE_hepg2',
       'lfcSE_k562', 'lfcSE_sknsh', 'nt_sequence', 'K562_pred_relu',
       'HepG2_pred_relu', 'SKNSH_pred_relu', 'K562_pred_relu6',
       'HepG2_pred_relu6', 'SKNSH_pred_relu6', 'K562_pred_relu_HD',
       'HepG2_pred_relu_HD', 'SKNSH_pred_relu_HD', 'K562_pred_aggreg',
       'HepG2_pred_aggreg', 'SKNSH_pred_aggreg'],
      dtype='object')

In [24]:
#file_names = ['MPRA_GTEX_pred.txt', 'MPRA_GTEX_cellDisp_noShrink_pred.txt', 'MPRA_GTEX_cellDisp_Shrink_pred.txt']
file_names = ['MPRA_ALL_no_cutoffs_pred.txt']

#performance_dfs = [gtex_df, gtex_noShrink_df, gtex_Shrink_df]
# Pair each dataframe with its output file explicitly. zip would silently drop unmatched
# entries, so require one name per dataframe. total= lets tqdm show real progress
# (enumerate alone has no length, hence the old "0it" bar).
assert len(file_names) == len(performance_dfs), 'expected one output file name per dataframe'
for file_name, df in tqdm(zip(file_names, performance_dfs), total=len(performance_dfs)):
    df.to_csv(file_name, index=False, sep=' ')

0it [00:00, ?it/s]