# Notebook to print LaTeX result tables and plot result figures

In [None]:
# imports
import pandas as pd
import numpy as np
from tqdm.notebook import tqdm
import os
import json
import scipy

#!pip install pycombat
import sys
from pycombat import Combat

import matplotlib.pyplot as plt
import utils as utils
import argparse
import socket
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import GroupKFold
import tensorflow as tf
import models as models
import evaluation as evaluation
import traceback
import paccmann_model as paccmann_model
import joblib
import warnings
warnings.filterwarnings("ignore")

%pylab inline
%load_ext autoreload
%autoreload 2

print(socket.gethostname())

## Params

In [None]:
# --- Result files ------------------------------------------------------------
# Result files are read from
#   <save_dir><save_prefix><source>_<target>_<n>_<normalize>_<train_mode>_<netpop>.joblib
save_dir = 'results/'
save_prefix = ''

# --- Data sets ---------------------------------------------------------------
source = 'gdsc'  # source data set name (part of every result file name)
targets = ['ccle', 'pancreas', 'xenografts', 'beat_aml']
# Sample counts to look up result files for, per target data set.
target_number_mapping = dict(
    ccle=[0, 10, 50, 100, 500, 1000, 5000, 10000],
    pancreas=[0, 10, 50, 100, 500, 10000],
    beat_aml=[0, 10, 50, 100, 500, 1000, 5000, 10000, 20000],
    xenografts=[0, 10, 50, 10000],
)

# Run settings; they also appear in the result file names.
flag_normalize_descriptors = True
train_mode = 'precision_oncology'
use_netpop = 'ensemble'

# --- Metrics -----------------------------------------------------------------
# Other available metrics: 'pearson', 'pearson_inhib', 'spearman', 'spearman_inhib'.
investigate_list = ['pearson_lab']
metric_mapping_dict = dict(
    pearson='Pearson correlation',
    pearson_inhib='Per-inhibitor pearson correlation',
    spearman='Spearman correlation',
    spearman_inhib='Per-inhibitor spearman correlation',
    pearson_lab='Per-specimen pearson correlation',
)

# --- Models ------------------------------------------------------------------
# Model key -> display name used in legends.
model_mapping_dict = dict(
    nn_baseline_scratch='Conv NN [scratch]',
    nn_baseline_pretrain='Conv NN [pre-train]',
    nn_paccmann_scratch='PaccMann [scratch]',
    nn_paccmann_pretrain='PaccMann [pre-train]',
    rf='Random Forest',
    tDNN_scratch='tDNN [scratch]',
    tDNN_pretrain='tDNN [pre-train]',
)

# Table columns, in order: each architecture as '<arch>_scratch' then
# '<arch>_pretrain'. The random forest ('rf') is deliberately left out.
model_names = [
    architecture + '_' + variant
    for architecture in ('nn_baseline', 'nn_paccmann', 'tDNN')
    for variant in ('scratch', 'pretrain')
]

# Metrics printed in the LaTeX table (e.g. also 'pearson_inhib').
investigate_list_small = ['pearson_lab']

# --- Table / plot formatting -------------------------------------------------
pval = 0.05  # significance level for the t-tests
decs = 3     # decimals printed in the table
invers_scale_list = [False]  # forwarded to evaluation.get_metric_dict

In [None]:
# Helpers for the per-target report below. They read the parameters defined in
# the "Params" cell (targets, target_number_mapping, model_names, pval, ...).


def _load_result_dicts(target):
    """Load the saved result files of one target for every configured sample count.

    File names are built from the module-level parameters ``save_dir``,
    ``save_prefix``, ``source``, ``flag_normalize_descriptors``, ``train_mode``
    and ``use_netpop``. Missing files are reported and skipped.

    Returns:
        (list_use_samples, all_result_dicts): the sample counts whose file
        exists and the objects loaded from those files, in the same order.
    """
    list_use_samples = []
    all_result_dicts = []
    for use_samples in target_number_mapping[target]:
        cur_load_path = save_dir + save_prefix + source + '_' + target + '_' + str(use_samples) + \
            '_' + str(flag_normalize_descriptors) + '_' + str(train_mode) + '_' + str(use_netpop) + '.joblib'
        print('cur_load_path: ' + cur_load_path)
        if not os.path.exists(cur_load_path):
            print('does not exist')
            continue
        # joblib.load unpickles the file: only load result files you produced yourself.
        all_result_dicts.append(joblib.load(cur_load_path))
        list_use_samples.append(use_samples)
    return list_use_samples, all_result_dicts


def _build_result_table(list_use_samples, all_result_dicts):
    """Evaluate every model of every loaded result file.

    Returns:
        dict ``{num_train: {model_key: metrics}}`` with an entry for every key
        of ``model_mapping_dict``. ``metrics`` holds the entries returned by
        ``evaluation.get_metric_dict``, or is ``{}`` when the model is absent
        from that result file.
    """
    result_table_dict = dict()
    for cur_key, cur_num_train_results in tqdm(zip(list_use_samples, all_result_dicts),
                                               total=len(all_result_dicts)):
        model_tables = result_table_dict.setdefault(cur_key, dict())
        min_max_scaler = None
        if 'min_max_scaler' in cur_num_train_results:
            min_max_scaler = cur_num_train_results['min_max_scaler']
        for model_key in model_mapping_dict:
            metrics = model_tables.setdefault(model_key, dict())
            if model_key not in cur_num_train_results:
                continue
            cur_model_results = cur_num_train_results[model_key]
            metrics.update(evaluation.get_metric_dict(
                cur_model_results['pred_complete'], cur_model_results['gt_complete'],
                cur_model_results['inhib_data_complete'], cur_model_results['lab_data_complete'],
                min_max_scaler, invers_scale_list=invers_scale_list,
                flag_calc_per_lab=True,
                flag_calc_per_inhib=False))
    return result_table_dict


def _mean_and_std_error(values):
    """Return the mean of ``values`` and its standard error (np.std / sqrt(n))."""
    return np.mean(values), np.std(values) / np.sqrt(len(values))


def _format_metric_row(cur_results, investigate):
    """Build the LaTeX cells of one table row, one cell per model in ``model_names``.

    Each cell shows the mean plus/minus standard error of ``<investigate>_list``,
    or ``--`` when the model has no results. A ``*`` marks a pre-trained model
    whose mean beats its scratch counterpart with a t-test p-value below ``pval``.
    The best model's cell is made bold. It also gets a dagger when it beats the
    runner-up at ``pval``.
    """
    result_lists = []    # score lists of the models that have results
    cell_index = []      # result_lists[i] is rendered in result_strings[cell_index[i]]
    result_strings = []  # one LaTeX cell per listed model, including '--' cells
    for model in model_names:
        if model not in cur_results:
            continue
        model_res = cur_results[model]
        if not model_res:
            result_strings.append('--')
            continue
        cur_list = model_res[investigate + '_list']
        mean, std_err = _mean_and_std_error(cur_list)
        result_string = str(np.round(mean, decimals=decs)) + \
            r' $\pm$ ' + str(np.round(std_err, decimals=decs))
        scratch_model = model.replace('pretrain', 'scratch')
        if 'pretrain' in model and scratch_model in model_names:
            scratch_list = cur_results.get(scratch_model, {}).get(investigate + '_list')
            # Only test when the scratch results exist and pre-training is ahead.
            if scratch_list is not None and mean > np.mean(scratch_list):
                _, pvalue = scipy.stats.ttest_ind(cur_list, scratch_list)
                if pvalue < pval:
                    result_string += '*'
        result_lists.append(cur_list)
        cell_index.append(len(result_strings))
        result_strings.append(result_string)

    if result_lists:
        order = np.argsort([np.mean(values) for values in result_lists])[::-1]
        best_cell = cell_index[order[0]]
        # Compare the best with the second-best model (needs at least two models).
        if len(order) > 1:
            _, pvalue = scipy.stats.ttest_ind(result_lists[order[0]], result_lists[order[1]])
            if pvalue < pval:
                result_strings[best_cell] += r'$\dagger$'
        result_strings[best_cell] = '\\textbf{' + result_strings[best_cell] + '}'
    return result_strings


def _print_metric_table(result_table_dict, metrics):
    """Print one LaTeX table row per loaded sample count for each metric in ``metrics``."""
    for investigate in metrics:
        print('investigate: ' + str(investigate))
        for num_train, cur_results in result_table_dict.items():
            display_number = 'all' if num_train is None else str(num_train)
            number_string = ' & '.join(_format_metric_row(cur_results, investigate))
            print(display_number + ' &' + number_string + '\\\\')


def _plot_metric_curves(result_table_dict, target, investigate, use_model_names):
    """Plot mean plus/minus standard error of one metric against the sample count.

    Draws one curve per model in ``use_model_names`` that has results for at
    least one sample count. Sample counts without results for a model are left
    as gaps (NaN). The figure is saved to
    ``plots/<train_mode>_<target>_<investigate>.pdf``.
    """
    num_trains = list(result_table_dict)
    if not num_trains:
        print('no results for target ' + str(target) + ', skipping plot')
        return
    fig_filename = 'plots/' + train_mode + '_' + target + '_' + investigate + '.pdf'
    x_positions = np.arange(len(num_trains))
    # The largest loaded sample count is labelled 'all' on the x axis.
    key_vals = num_trains[:-1] + ['all']
    for model_name in use_model_names:
        per_count = [result_table_dict[num_train].get(model_name) for num_train in num_trains]
        if not any(per_count):
            continue
        mean_vals = []
        std_err_vals = []
        for model_res in per_count:
            if model_res:
                mean, std_err = _mean_and_std_error(model_res[investigate + '_list'])
            else:
                mean, std_err = np.nan, np.nan
            mean_vals.append(mean)
            std_err_vals.append(std_err)
        plt.errorbar(x_positions, mean_vals, std_err_vals, label=model_mapping_dict[model_name])

    plt.xticks(x_positions, key_vals)
    plt.xlabel('Number of used examples for target data set')
    plt.ylabel(metric_mapping_dict[investigate])
    plt.legend(ncol=2, loc=4)
    os.makedirs(os.path.dirname(fig_filename), exist_ok=True)
    plt.savefig(fig_filename)
    plt.show()


use_model_names = ['nn_baseline_scratch',
                   'nn_paccmann_scratch',
                   'tDNN_scratch',
                   'nn_baseline_pretrain',
                   'nn_paccmann_pretrain',
                   'tDNN_pretrain']

for target in targets:
    # Kept at module level: the scatter-plot cell reads all_result_dicts.
    list_use_samples, all_result_dicts = _load_result_dicts(target)

    print('source: ' + str(source))
    print('target: ' + str(target))
    result_table_dict = _build_result_table(list_use_samples, all_result_dicts)

    # NOTE: the former (disabled) "PRINT TABLE" over investigate_list and
    # "PRINT TABLE ONLY MSE" sections can be reproduced with
    # _print_metric_table(result_table_dict, [...]). For error metrics such as
    # MSE, the bold/dagger/star markers favour the *highest* mean.
    print('################################################')
    print('#')
    print('# PRINT TABLE ONLY PEARSON')
    print('#')
    print('################################################')
    _print_metric_table(result_table_dict, investigate_list_small)

    print('################################################')
    print('#')
    print('# PLOT FIGURE')
    print('#')
    print('################################################')
    for investigate in investigate_list:
        _plot_metric_curves(result_table_dict, target, investigate, use_model_names)

In [None]:
# Scatter plots of prediction vs. ground truth for the last loaded result file
# of the last target processed in the previous cell (all_result_dicts[-1]).
plot_models = ['tDNN_pretrain', 'nn_baseline_pretrain',
               'nn_paccmann_pretrain', 'rf']


def _flatten_folds(per_fold_values):
    """Concatenate a sequence of arrays (presumably one per fold) into one flat 1-D array.

    Ground truth and predictions are flattened the same way, so ``(n, 1)``
    shaped arrays do not break ``scipy.stats.pearsonr``.
    """
    return np.concatenate([np.ravel(values) for values in per_fold_values])


if not all_result_dicts:
    print('no result files were loaded; nothing to plot')
else:
    last_results = all_result_dicts[-1]
    for plot_model in plot_models:
        if plot_model not in last_results:
            # e.g. 'rf' when the random forest was not part of the run
            print('no results for ' + plot_model + ', skipping')
            continue
        gt_completes = _flatten_folds(last_results[plot_model]['gt_complete'])
        pred_completes = _flatten_folds(last_results[plot_model]['pred_complete'])
        pearson = scipy.stats.pearsonr(pred_completes, gt_completes)[0]
        plt.scatter(gt_completes, pred_completes,
                    label=model_mapping_dict[plot_model] + ' [' +
                    str(np.round(pearson, decimals=3)) + ']',
                    alpha=0.5)
        plt.legend()
        plt.xlabel('GT')
        plt.ylabel('Prediction')
        plt.show()