In [None]:
# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
import matplotlib as mpl
# Global matplotlib settings for every figure created in this notebook,
# applied in a single update call.
mpl.rcParams.update({
    'pdf.fonttype': 42,      # font type used for text in exported PDFs
    'font.family': 'Arial',  # default font family
})

In [None]:
import sys
# Add the sibling source directory to the import search path so that
# modules stored there can be imported by the cells that follow.
sys.path.extend(['./../src/'])

In [None]:
import glob
import os

import numpy as np
import pandas as pd
import seaborn as sns

import matplotlib.pyplot as plt

from scipy.stats import spearmanr

import nar170830f_predictions as forec
import resci_inout as rinout
import resci_tools as ret

In [None]:
# Output switches read by the plotting loop: figures are exported only when
# save_images is true, the per-comparison data table only when save_tables is.
save_images, save_tables = False, True

In [None]:
# Label that the plotting loop inserts into the name of every exported file.
category_of_interest = 'budget_for_attention'

In [None]:
# prediction_folders = glob.glob(
#     os.path.join(
#         rinout.get_internal_path(
#             '171202f_predict_money'), 
#         '171015_human_*{}'.format(category_of_interest)))
# Model identifier; the same literal is used as a path component in
# get_predictions.
model_name = 'zgbrh_p90_e300'

In [None]:
def get_predictions(comparison, model_name='zgbrh_p90_e300'):
    """Load the pooled target/prediction table of one prediction run.

    Parameters
    ----------
    comparison : str
        Name of the run's folder below the '171202f_predict_money'
        internal path.
    model_name : str, optional
        Name of the model sub-folder inside `comparison`. Defaults to
        'zgbrh_p90_e300', the value that was previously hard-coded, so
        existing single-argument calls behave exactly as before.

    Returns
    -------
    target : pandas.Series
        The 'target' column, indexed by 'gene_ncbi'.
    predictions : pandas.DataFrame
        Every remaining column of the table, indexed by 'gene_ncbi'.
    """
    p = os.path.join(
        rinout.get_internal_path('171202f_predict_money'),
        comparison,
        model_name,
        # NOTE(review): 'prediciton' is reproduced as-is; presumably it
        # matches the file name written upstream -- confirm before
        # correcting the spelling.
        'pooled_target_and_prediciton.csv.gz'
    )
    df = pd.read_csv(p)
    df = df.set_index('gene_ncbi')
    # pop() removes 'target' from df, so what is left are the predictions.
    target = df.pop('target')
    predictions = df

    return target, predictions


def pooling_fun(x):
    """Collapse the values in `x` to a single number: their median, with
    NaN entries ignored."""
    pooled = np.nanmedian(x)
    return pooled

In [None]:
# Prediction runs to visualise. Each entry is handed to get_predictions as
# the `comparison` folder name and is also embedded in the exported file
# names.
c = [
    '171202_human_BioExpYearhomallDis_log_budget_for_attention',
    '171202_human_BioExpDis_log_budget_for_attention',
    '171202_human_BioExpYearhomall_log_budget_for_attention',
    '171202_human_BioExp_log_budget_for_attention',
    '171202_human_Dis_log_budget_for_attention',
]

In [None]:
def make_plot(df):
    """Draw a seaborn regression jointplot of 'predicted' against 'target'.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns 'target' (x axis) and 'predicted' (y axis).

    Returns
    -------
    None
        The grid created by seaborn is not returned.
    """
    # Options for the joint (central) panel: LOWESS fit drawn in gray over
    # very small scatter markers.
    regression_options = {
        'lowess': True,
        'line_kws': {'color': 'gray'},
        'scatter_kws': {'s': 1},
    }
    # NOTE(review): `stat_func` was deprecated in seaborn 0.9 and appears to
    # be absent from later releases -- confirm the seaborn version this
    # notebook is meant to run with.
    sns.jointplot(
        data=df,
        x='target',
        y='predicted',
        kind='reg',
        joint_kws=regression_options,
        stat_func=spearmanr)

In [None]:
# For every prediction run: build a target/predicted table, draw three
# figures (hexbin, scatter with fixed y range, scatter with free y range)
# and optionally export the figures and the table.
for comparison in c:
    print(comparison)

    # Pool the prediction columns row-wise into one value per index entry
    # and keep only the rows that also have a target.
    ta, po = get_predictions(comparison)
    po = po.apply(pooling_fun, axis=1)
    # The pooled Series is unnamed, so concat labels its column 0.
    df = pd.concat(
        [ta, po], axis=1, join='inner').rename(
        columns={0: 'predicted'})

    # --- Figure 1: hexbin density with a colorbar -------------------------
    hexplot = sns.jointplot(
        x='target',
        y='predicted',
        kind='hex',
        data=df,
        gridsize=30,
        stat_func=spearmanr
    )

    plt.subplots_adjust(left=0.2, right=0.8, top=0.8, bottom=0.2)  # shrink fig so cbar is visible
    cax = hexplot.fig.add_axes([.85, .25, .05, .4])  # x, y, width, height
    plt.colorbar(cax=cax)

    if save_images:
        ret.export_image('171208f_visualize_budget_predictions/{}_hex_{}.pdf'.format(
            category_of_interest, comparison))

    plt.close()

    # --- Figure 2: regression scatter with a fixed y range ----------------
    make_plot(df)
    # The limits must be applied BEFORE exporting; previously they were set
    # after the export, so the "fixed_y" file was written without them.
    plt.ylim(2.5, 9.5)
    if save_images:
        ret.export_image('171208f_visualize_budget_predictions/{}_scatter_fixed_y_{}.pdf'.format(
            category_of_interest, comparison))
    plt.close()

    # --- Figure 3: regression scatter with automatic y range --------------
    make_plot(df)
    if save_images:
        ret.export_image('171208f_visualize_budget_predictions/{}_scatter_{}.pdf'.format(
            category_of_interest, comparison))

    # Table holding the data shown in the figures above.
    if save_tables:
        ret.export_full_frame(
            '171208f_visualize_budget_predictions/{}_{}_data.csv'.format(
                category_of_interest,
                comparison),
            df
        )

    plt.close()

In [None]:
# df.head()

In [None]:
# sns.jointplot(x='target', y='predicted', data=df, kind='reg')