In [73]:
%matplotlib inline

import os as os
import collections as col
import itertools as itt
import pickle as pck
import time as ti

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes
from mpl_toolkits.axes_grid1.inset_locator import mark_inset

import numpy as np
import numpy.random as rng
import scipy.stats as stats
import pandas as pd
import seaborn as sns

# What does this cell do?
# Plot HSP scores vs. segment p-values
# and mark all candidate threshold
# values, derived either from rules in
# the literature or by sampling.

# Date stamp prefixed to the cache file name and to every figure file name.
date = '20180405'

# Cell switches: run the plotting pipeline at all / write figures to disk.
run_plot_score_thresholds = True
save_figures = True

# Input locations: notebook cache directory and root of the HSP run outputs.
fhgfs_base = '/TL/deep/fhgfs/projects/pebert/thesis/projects/statediff'
cache_dir = os.path.join(fhgfs_base, 'caching/notebooks')
hsp_files_folder = os.path.join(fhgfs_base, 'solidstate/deep')

# Plot theme shared by all figures of this cell.
sns.set(
    style='white',
    font_scale=1.5,
    rc={
        'font.family': ['sans-serif'],
        'font.sans-serif': ['DejaVu Sans'],
    },
)

# Output locations for the publication figures.
base_out = '/TL/deep-external01/nobackup/pebert/cloudshare/mpiinf/phd/chapter_projects/statediff'
fig_supp = os.path.join(base_out, 'figures', 'pub', 'supp')
fig_main = os.path.join(base_out, 'figures', 'pub', 'main')
fig_collect = os.path.join(base_out, 'figures', 'pub', 'collection')
                   
# Five-colour palette for the threshold lines, written as 0-255 RGB triples
# and rescaled to the 0-1 float range that matplotlib colour tuples use.
t_colors = [tuple(channel / 255 for channel in rgb)
            for rgb in [(102, 194, 165), (252, 141, 98), (141, 160, 203),
                        (231, 138, 195), (166, 216, 84)]]

    
def _read_hsp_file(fpath):
    """Load segment scores and thresholds from one HSP result file.

    Returns a tuple ``(data, thresholds, scoring, samples)``:

    - ``data``: the ``norm_nat_score``, ``segment_pv`` and ``summed_pv``
      columns of all ``/segments/...`` tables, concatenated with a fresh
      integer index;
    - ``thresholds``: the table stored under ``/segments/<scoring>/thresholds``;
    - ``scoring``: the ``<scoring>`` component of that key;
    - ``samples``: the set of all names in the ``sample1`` / ``sample2``
      columns of ``/metadata/comparisons``.

    Raises ValueError if the file has no thresholds table or no segment data.
    """
    segments = []
    thresholds = None
    scoring = None
    with pd.HDFStore(fpath, 'r') as hdf:
        md = hdf['/metadata/comparisons']
        samples = set(md['sample1']).union(set(md['sample2']))
        for key in hdf.keys():
            if not key.startswith('/segments'):
                continue
            if key.endswith('/thresholds'):
                _, _, scoring, _ = key.split('/')
                thresholds = hdf[key]
            else:
                segments.append(hdf[key])
    # Validate before concatenating, otherwise pd.concat fails with an
    # unhelpful "No objects to concatenate" error.
    if scoring is None:
        raise ValueError('No /segments/<scoring>/thresholds table in {}'.format(fpath))
    if not segments:
        raise ValueError('No segment data tables in {}'.format(fpath))
    data = pd.concat(segments, axis=0, ignore_index=True)
    data = data.loc[:, ['norm_nat_score', 'segment_pv', 'summed_pv']]
    return data, thresholds, scoring, samples


def collect_hsp_threshold_data(rootfolder, cache_file):
    """Collect HSP scores and thresholds found below *rootfolder* into *cache_file*.

    Walks *rootfolder* for directories whose path ends in ``hsp_run`` and
    reads every ``.h5`` file there whose name contains ``cmm18``. The file
    name (before the first ``.``, split on ``_``) provides the segmentation
    (field 3, 0-based) and the two compared groups (fields 5 and 8).

    For each file, the score table and the threshold table are stored under
    ``<seg>/<c1>/<c2>/<scoring>/data`` and ``.../thresholds``. The threshold
    table is extended column-wise with the sampling thresholds computed from
    the sibling ``smp_run`` directory, if collect_sampling_thresholds returns
    any.

    Everything is first written to ``<cache_file>.tmp`` and moved to
    *cache_file* only after all inputs were processed. An interrupted run
    therefore cannot leave a partial cache that a later run would mistake for
    a complete one.

    Returns *cache_file*. If no matching input file exists, no cache file is
    created.
    """
    tmp_file = cache_file + '.tmp'
    filemode = 'w'  # first write truncates any stale temporary file
    for root, _, datafiles in os.walk(rootfolder):
        if not (root.endswith('hsp_run') and datafiles):
            continue
        for fname in datafiles:
            if not fname.endswith('.h5') or 'cmm18' not in fname:
                continue
            infos = fname.split('.')[0].split('_')
            seg, c1, c2 = infos[3], infos[5], infos[8]
            data, thresholds, scoring, samples = _read_hsp_file(os.path.join(root, fname))
            sample_t = collect_sampling_thresholds(root.replace('hsp_run', 'smp_run'),
                                                   scoring, samples)
            if sample_t is not None:
                thresholds = pd.concat([thresholds, sample_t], axis=1, ignore_index=False)
            # HDF5 keys always use '/', independent of the OS path separator.
            key_base = '/'.join([seg, c1, c2, scoring])
            with pd.HDFStore(tmp_file, filemode) as hdf:
                hdf.put(key_base + '/data', data)
                hdf.put(key_base + '/thresholds', thresholds)
            filemode = 'a'
    if filemode == 'a':
        # At least one dataset was written: publish the finished cache atomically.
        os.replace(tmp_file, cache_file)
    return cache_file
    

def collect_sampling_thresholds(basefolder, scoring, samples):
    """Compute per-chromosome mean sampling scores for one scoring scheme.

    Reads every ``.h5`` file in *basefolder*, in sorted file-name order. The
    last ``_``-separated field of the file name (before the first ``.``)
    selects the sampling type: ``smprep`` -> ``replicate``, ``smprand`` ->
    ``random``. Any other suffix raises KeyError.

    From each file, all tables whose key starts with ``/sstsmp`` and contains
    *scoring* are concatenated. The last key component is used as the
    chromosome name. For replicate sampling, only rows whose ``sample1`` and
    ``sample2`` are both in *samples* are kept. The mean ``norm_nat_score``
    per chromosome becomes a column named ``<type>_lo``.

    Returns a DataFrame indexed by chromosome with one such column per
    sampling file, or None if *basefolder* contains no ``.h5`` file.

    Raises ValueError if a sampling file holds no table matching *scoring*.
    """
    smptype_names = {'smprep': 'replicate', 'smprand': 'random'}
    columns = []
    # sorted(): column order must not depend on the directory listing order.
    for smpfile in sorted(os.listdir(basefolder)):
        if not smpfile.endswith('.h5'):
            continue
        # change after update
        smptype = smptype_names[smpfile.split('.')[0].split('_')[-1]]
        fpath = os.path.join(basefolder, smpfile)
        with pd.HDFStore(fpath, 'r') as hdf:
            # change after update
            load_keys = [k for k in hdf.keys() if k.startswith('/sstsmp') and scoring in k]
            if not load_keys:
                raise ValueError('No data to load for scoring {} (path {})'.format(scoring, basefolder))
            parts = []
            for k in load_keys:
                chrom_data = hdf[k]
                chrom_data['chrom'] = k.split('/')[-1]
                parts.append(chrom_data)
        dataset = pd.concat(parts, axis=0, ignore_index=False)
        if smptype == 'replicate':
            # Positional boolean mask. Plain `bool` is used because the
            # `np.bool` alias was removed in NumPy 1.24.
            in_s1 = np.asarray(dataset['sample1'].isin(samples), dtype=bool)
            in_s2 = np.asarray(dataset['sample2'].isin(samples), dtype=bool)
            dataset = dataset.loc[in_s1 & in_s2, :].copy()
        means = dataset.groupby('chrom')['norm_nat_score'].mean()
        means.name = smptype + '_lo'
        columns.append(means.to_frame())
    if not columns:
        return None
    return pd.concat(columns, axis=1, ignore_index=False)
    
    
def create_score_scatter(data, thresholds, title):
    """Scatter HSP scores against segment p-values and mark score thresholds.

    Parameters
    ----------
    data : pandas.DataFrame
        Must provide ``norm_nat_score`` (x axis) and ``segment_pv`` (y axis,
        labelled as -log10 p-value).
    thresholds : pandas.DataFrame
        For each of ``ferreira_lo``, ``loretan_lo``, ``quantile_lo``,
        ``replicate_lo`` and ``random_lo`` that is present, the column mean is
        drawn as a dashed vertical line: a short tick in the main plot and a
        full-height line in a zoomed inset. Missing columns and NaN means are
        skipped.
    title : str
        Plot title.

    Returns
    -------
    tuple
        ``(fig, extra_artists)``: the figure and a (currently empty) list of
        extra artists for ``bbox_extra_artists`` when saving.
    """
    fig, ax = plt.subplots(figsize=(8, 8))

    xvals = data['norm_nat_score']
    yvals = data['segment_pv']

    # Inset magnification and inset y-range both depend on the largest score.
    xmax = xvals.max()
    if xmax < 2000:
        zoom_factor, yn = 5, 30
    elif xmax < 5000:
        zoom_factor, yn = 15, 40
    else:
        zoom_factor, yn = 20, 50

    ax.set_ylim(-10, yvals.max() + 25)
    ax.scatter(xvals, yvals, s=10, c='dodgerblue', marker='o', label=None)

    axins = zoomed_inset_axes(ax, zoom_factor, loc='lower right')
    axins.scatter(xvals, yvals, s=5, c='dodgerblue', marker='o', label=None)

    threshold_columns = ['ferreira_lo', 'loretan_lo', 'quantile_lo',
                         'replicate_lo', 'random_lo']
    t_vals = []
    for t, c in zip(threshold_columns, t_colors):
        # e.g. replicate/random columns exist only if sampling results were found
        if t not in thresholds:
            continue
        t_val = thresholds[t].mean()
        # An all-NaN column has no line to draw, and a NaN would make the
        # inset x-limits below invalid.
        if not np.isfinite(t_val):
            continue
        t_vals.append(t_val)
        t_label = t.split('_')[0].capitalize()
        ax.axvline(t_val, ymin=0.01, ymax=0.1,
                   color=c, linestyle='dashed', linewidth=3,
                   label=t_label, zorder=0)
        axins.axvline(t_val, ymin=0.01, ymax=0.99,
                      color=c, linestyle='dashed', linewidth=3,
                      label=t_label, zorder=0)

    if t_vals:
        # Pad the inset by 25 score units around the outermost thresholds.
        x0 = max(-5, min(t_vals) - 25)
        xn = max(t_vals) + 25
    else:
        # No usable threshold: use the window a single threshold at 0 would get.
        x0, xn = -5, 25
    axins.set_xlim(x0, xn)
    axins.set_ylim(-5, yn)

    axins.yaxis.set_visible(False)
    axins.xaxis.set_visible(False)
    # fc='none': recent matplotlib fills the marker box whenever a face colour
    # is passed, which would cover the zoomed-in points with a black box.
    mark_inset(ax, axins, loc1=2, loc2=3, fc='none', ec='k')

    ax.legend(loc='upper left')
    ax.set_xlabel('HSP score')
    ax.set_ylabel('-log10 (p-value)')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    tt = ax.set_title(title)
    tt.set_position([0.5, 1.01])
    return fig, []
    
            
def plot_hsp_score_thresholds():
    """Render one score/threshold plot per cached CMM18 comparison.

    The cache file ``<date>_plot_hsp_score_thresholds.h5`` in ``cache_dir``
    is built with collect_hsp_threshold_data if it does not exist yet. Every
    ``/cmm18/<c1>/<c2>/<scoring>/data`` table in it is plotted by
    create_score_scatter together with its sibling ``.../thresholds`` table.
    If ``save_figures`` is set, each figure is written to ``fig_collect`` as
    SVG, PDF and 300-dpi PNG. Every figure is closed afterwards, even if
    saving fails.

    Returns 0.
    """
    cache_file = os.path.join(cache_dir, '{}_plot_hsp_score_thresholds.h5'.format(date))
    if not os.path.isfile(cache_file):
        os.makedirs(cache_dir, exist_ok=True)
        cache_file = collect_hsp_threshold_data(hsp_files_folder, cache_file)
    if save_figures:
        os.makedirs(fig_collect, exist_ok=True)

    with pd.HDFStore(cache_file, 'r') as hdf:
        for k in hdf.keys():
            if not (k.startswith('/cmm18') and k.endswith('/data')):
                continue
            scores = hdf[k]
            # Swap only the trailing key component; str.replace would also
            # rewrite any 'data' substring inside group or scoring names.
            thres = hdf[k[:-len('data')] + 'thresholds']
            _, tool, c1, c2, scoring, _ = k.split('/')
            plt_title = 'HSP score thresholds: {} vs {} ({} {} scoring)'.format(c1, c2, tool.upper(), scoring)
            fig, exart = create_score_scatter(scores, thres, plt_title)
            try:
                if save_figures:
                    outname = '{}_fig_X_hsp_thresholds_{}_{}_{}_vs_{}'.format(date, tool, scoring, c1, c2)
                    out_base = os.path.join(fig_collect, outname)
                    # The savefig keyword is 'bbox_extra_artists'. The former
                    # 'extra_artists' is not a savefig parameter and raises
                    # TypeError in recent matplotlib.
                    for ext, extra_kw in (('.svg', {}), ('.pdf', {}), ('.png', {'dpi': 300})):
                        fig.savefig(out_base + ext, bbox_inches='tight',
                                    bbox_extra_artists=exart, **extra_kw)
            finally:
                plt.close(fig)
    return 0
     
    
# Cell entry point: build the cache if it is missing and draw every plot,
# unless the switch defined at the top of the cell disables the run.
if bool(run_plot_score_thresholds):
    plot_hsp_score_thresholds()
