In [67]:
%matplotlib inline

import os as os
import collections as col
import pickle as pck

import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import scipy.stats as stat

# Date stamp; used below as prefix of the cache file name and of all figure file names
date = '20180322'

# Toggle: execute plot_hsp_gumbel_fit() at the end of this cell
run_plot_gumbel_fit = True

# Toggle: write the generated figures to disk
save_figures = True

# Global seaborn styling applied to all plots created in this session
sns.set(style='white',
        font_scale=1.5,
        rc={'font.family': ['sans-serif'],
            'font.sans-serif': ['DejaVu Sans']})

# Make NumPy raise an exception for every kind of floating-point error
# instead of only emitting a warning
np.seterr(all='raise')

# Project root on the cluster file system and the folder for cached notebook data
fhgfs_base = '/TL/deep/fhgfs/projects/pebert/thesis/projects/statediff'
cache_dir = os.path.join(fhgfs_base, 'caching/notebooks')

# Same location that collect_block_maxima() walks
# NOTE(review): not referenced anywhere in the visible code - confirm it is still needed
hsp_files_folder = os.path.join(fhgfs_base, 'solidstate/deep')

# Output locations for figures; only fig_collect is used by the visible code
base_out = '/TL/deep-external01/nobackup/pebert/cloudshare/mpiinf/phd/chapter_projects/statediff'
fig_supp = os.path.join(base_out, 'figures', 'pub', 'supp')
fig_main = os.path.join(base_out, 'figures', 'pub', 'main')
fig_collect = os.path.join(base_out, 'figures', 'pub', 'collection')

def adapt_color(triple):
    """
    Convert an 'rgb(r,g,b)' string with integer channels (0-255) into a
    tuple of floats scaled to the range 0-1, rounded to two decimals.

    :param triple: color specification, e.g. 'rgb(215,25,28)'
    :return: tuple with one float per channel
    """
    # strip() removes any of the characters r, g, b, ( and ) from both ends,
    # which leaves only the comma-separated channel values
    channels = triple.strip('rgb()').split(',')
    return tuple(round(int(channel) / 255., 2) for channel in channels)

# Color palettes given as 'rgb(r,g,b)' strings and converted to float tuples
four_colors = [
    adapt_color(rgb_spec)
    for rgb_spec in ('rgb(215,25,28)', 'rgb(253,174,97)',
                     'rgb(171,217,233)', 'rgb(44,123,182)')
]
six_colors = [
    adapt_color(rgb_spec)
    for rgb_spec in ('rgb(215,48,39)', 'rgb(252,141,89)', 'rgb(254,224,144)',
                     'rgb(224,243,248)', 'rgb(145,191,219)', 'rgb(69,117,180)')
]

# Lookup of the palette by the number of groups that need a distinct color
color_scheme = {4: four_colors, 6: six_colors}


def collect_block_maxima():
    """
    Collect the maximal normalized HSP score per chromosome from all
    HSP run output files.

    Walks the 'solidstate/deep' tree below fhgfs_base and opens every file
    found in a folder whose path ends with 'hsp_run' as HDF store. Keys
    ending in '/thresholds' and keys not starting with '/sstrun' are
    ignored; all other keys are expected to have the form
    '/sstrun/<scoring>/<chrom>'.

    :return: dict keyed by (c1, c2, segment, scoring); each value maps a
     (sample1, sample2) pair to the list of per-chromosome maxima of the
     column 'norm_nat_score'
    """
    walk_root = os.path.join(fhgfs_base, 'solidstate/deep')
    maxima = dict()
    for folder, _, filenames in os.walk(walk_root):
        if not folder.endswith('hsp_run'):
            continue
        for filename in filenames:
            store_path = os.path.join(folder, filename)
            # the segment label is the name of the parent folder of 'hsp_run'
            segment = folder.split('/')[-2]
            # c1 and c2 are taken from fixed positions of the file name
            # NOTE(review): presumably the two compared groups - confirm
            name_fields = filename.split('.')[0].split('_')
            c1, c2 = name_fields[5], name_fields[8]

            tables_by_scoring = col.defaultdict(list)
            with pd.HDFStore(store_path, 'r') as store:
                for key in store.keys():
                    if key.endswith('/thresholds') or not key.startswith('/sstrun'):
                        continue
                    _empty, _run, scoring, chrom = key.split('/')
                    table = store[key]
                    table['chrom'] = chrom
                    tables_by_scoring[scoring].append(table)

            for scoring, tables in tables_by_scoring.items():
                per_pair = col.defaultdict(list)
                stacked = pd.concat(tables, axis=0, ignore_index=False)
                grouped = stacked.groupby(['sample1', 'sample2', 'chrom'])
                chrom_maxima = grouped['norm_nat_score'].max().to_dict()
                # drop the chromosome level: one list of maxima per sample pair
                for (s1, s2, _), top_score in chrom_maxima.items():
                    per_pair[(s1, s2)].append(top_score)
                maxima[(c1, c2, segment, scoring)] = per_pair
    return maxima
                    

def make_probplot(groups, colors, figtitle):
    """
    Create a probability plot of the pooled values of all groups against
    a fitted Gumbel distribution (scipy.stats.gumbel_r).

    The values of all groups are pooled and sorted, the distribution
    parameters are estimated from the pooled values, and each value is
    drawn at its theoretical quantile, colored by the group it belongs to.
    The correlation coefficient returned by scipy.stats.probplot is
    printed in the upper left corner of the plot.

    :param groups: mapping of (name1, name2) pairs to a list of values
    :param colors: sequence of colors; group number i (in iteration order
     of groups) is drawn with colors[i], additional colors are ignored
    :param figtitle: title of the plot
    :return: the figure and an empty list (placeholder for extra artists)
    """
    merged = []
    leg_labels = dict()
    for idx, (pair, vals) in enumerate(groups.items()):
        p1, p2 = pair
        # short legend label built from fixed character positions of the names
        # NOTE(review): assumes a fixed-width naming scheme - confirm
        p1 = p1[7:9] + p1[3]
        p2 = p2[7:9] + p2[3]
        leg_labels[idx] = p1 + ' v ' + p2
        merged.extend((v, colors[idx], idx) for v in vals)
    # sorting puts the values in the same order as the theoretical
    # quantiles returned by probplot, so both can be zipped below
    merged.sort()
    merged_vals = [t[0] for t in merged]

    fig, ax = plt.subplots(figsize=(10, 10))

    xlabel = 'Theoretical quantiles'
    ylabel = 'Observed values'

    est_param = stat.gumbel_r.fit(merged_vals)
    (osm, _), (_, _, r_value) = stat.probplot(merged_vals, est_param,
                                              dist='gumbel_r', fit=True,
                                              plot=ax)
    # hide the default markers; they are redrawn per group below
    ax.get_lines()[0].set_visible(False)
    # trendline
    ax.get_lines()[1].set_color('black')

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(figtitle)

    # one point cloud per group that actually exists; sizing this by the
    # palette instead would fail on the legend label lookup whenever the
    # palette has more entries than there are groups
    point_clouds = {idx: {'x': [], 'y': []} for idx in leg_labels}
    for x, (y, _, idx) in zip(osm, merged):
        point_clouds[idx]['x'].append(x)
        point_clouds[idx]['y'].append(y)

    for idx, cloud in point_clouds.items():
        ax.scatter(cloud['x'], cloud['y'], edgecolors='black',
                   marker='o', color=colors[idx], s=100, zorder=3,
                   label=leg_labels[idx])

    ax.legend(loc='lower right')

    r_label = str(np.round(r_value, 3))
    ax.text(0.1, 0.9, '$R$ = ' + r_label, fontsize=16, transform=ax.transAxes)

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    return fig, []
    
    
def plot_hsp_gumbel_fit():
    """
    Create one Gumbel probability plot per (c1, c2, segment, scoring)
    combination of HSP block maxima.

    The block maxima are loaded from the pickle cache file if it exists;
    otherwise they are collected via collect_block_maxima() and written
    to the cache. If save_figures is set, every plot is saved as SVG, PDF
    and PNG (300 dpi) in the folder fig_collect.

    :return: 0
    """
    cache_file = os.path.join(cache_dir, '{}_plot_gumbel_hsps.pck'.format(date))
    if os.path.isfile(cache_file):
        # NOTE: unpickling executes arbitrary code; this is only acceptable
        # because the cache file is written by this function itself
        with open(cache_file, 'rb') as cache:
            hsps = pck.load(cache)
    else:
        hsps = collect_block_maxima()
        with open(cache_file, 'wb') as cache:
            pck.dump(hsps, cache)
    # per output format: file extension and format-specific savefig options
    out_formats = (('.svg', {}), ('.pdf', {}), ('.png', {'dpi': 300}))
    for (c1, c2, seg, score), values in hsps.items():
        # the palette is selected by the number of sample pairs
        colors = color_scheme[len(values)]
        fig_title = '{} vs {} Gumbel fit to HSP scores for {} / {}'.format(c1, c2, seg, score)
        fig, exart = make_probplot(values, colors, fig_title)
        if save_figures:
            outname = '{}_fig_X_hsp_gumbelfit_{}_{}_{}-vs-{}'.format(date, seg, score, c1, c2)
            # savefig has no 'extra_artists' parameter; the documented name is
            # 'bbox_extra_artists'. An empty list is turned into None so that
            # matplotlib keeps using its default set of artists for the
            # tight bounding box.
            bbox_artists = exart if exart else None
            for ext, fmt_opts in out_formats:
                fig.savefig(os.path.join(fig_collect, outname + ext),
                            bbox_inches='tight',
                            bbox_extra_artists=bbox_artists,
                            **fmt_opts)
        plt.close(fig)
    return 0
        
    
# Create the Gumbel fit plots only if enabled via the toggle set near the top of the cell
if run_plot_gumbel_fit:
    plot_hsp_gumbel_fit()