In [None]:
%matplotlib inline

import os as os
import numpy as np
import numpy.random as rng
import json as js
import csv as csv
import pickle as pck
import collections as col
import operator as op
import functools as fnt
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.colors import LinearSegmentedColormap
import seaborn as sns
from scipy.stats import mannwhitneyu as mwu

sns.set(style='white',
        font_scale=1.25,
        rc={'font.family': ['sans-serif'],
            'font.sans-serif': ['DejaVu Sans']})

cache_dir = '/home/pebert/.jupyter/cache'

conf_folder = '/home/pebert/work/code/mpggit/crossspecies/graphics'


def _load_conf_json(subfolder, filename):
    """Return the parsed JSON document ``conf_folder/subfolder/filename``.

    Uses a ``with`` block so the file handle is closed right after parsing.
    """
    with open(os.path.join(conf_folder, subfolder, filename), 'r') as conf_file:
        return js.load(conf_file)


plot_labels = _load_conf_json('labels', 'cs_labels.json')
plot_colors = _load_conf_json('colors', 'cs_colors.json')
plot_shapes = _load_conf_json('shapes', 'cs_shapes.json')

# Enables the `if run_lola_enrich:` plotting step.
run_lola_enrich = True

show_figures = True
save_figures = True

project_base = '/TL/deep-external01/nobackup/pebert/cloudshare/mpiinf/phd/chapter_projects/crossspecies'
supp_folder = os.path.join(project_base, 'supplement')
fig_folder = os.path.join(project_base, 'figures', 'pub')

# Each entry: (LOLA result file name, figure sub-folder ('main' / 'supp'), figure number).
# NOTE(review): 'X' looks like a placeholder for a not-yet-assigned figure number.
unaln_enrich = [('hg19_promoter_custom_unaln-genes_LOLA.tsv', 'main', '3C'),
                ('hg19_promoter_core_unaln-genes_LOLA.tsv', 'main', '3B'),
                ('mm9_promoter_custom_unaln-genes_LOLA.tsv', 'supp', '3B')]

uniq_enrich = [('hg19_promoter_custom_hepa-uniq-tp-genes_LOLA.tsv', 'main', 'X'),
               ('hg19_promoter_custom_ncd4-uniq-tp-genes_LOLA.tsv', 'main', 'X'),
               ('hg19_promoter_custom_esc-uniq-tp-genes_LOLA.tsv', 'main', '4B'),
               ('hg19_promoter_core_hepa-uniq-tp-genes_LOLA.tsv', 'main', 'X'),
               ('hg19_promoter_core_ncd4-uniq-tp-genes_LOLA.tsv', 'main', 'X'),
               ('hg19_promoter_core_esc-uniq-tp-genes_LOLA.tsv', 'main', '4A'),
               ('mm9_promoter_custom_liver-uniq-tp-genes_LOLA.tsv', 'supp', 'X'),
               ('mm9_promoter_custom_ncd4-uniq-tp-genes_LOLA.tsv', 'supp', 'X'),
               ('mm9_promoter_custom_esc-uniq-tp-genes_LOLA.tsv', 'supp', '4')]

def exec_lola_enrich(all_enrichments, tissue_label):
    """Plot (and optionally save) one LOLA enrichment bar chart per input file.

    Parameters
    ----------
    all_enrichments : iterable of (str, str, str)
        Tuples ``(lola_file, hierarch, number)``: path to a LOLA result table,
        the figure sub-folder below ``fig_folder`` and the figure number used
        in the output file name.
    tissue_label : bool
        Forwarded to ``plot_enrich_barchart`` (tissue-specific bar labels).

    Returns
    -------
    bool
        Always True once every file has been processed.

    Raises
    ------
    ValueError
        If a path contains neither 'mm9' nor 'hg19'.

    Reads the module-level settings ``plot_colors``, ``save_figures``,
    ``show_figures`` and ``fig_folder``. When ``show_figures`` is False, each
    figure is closed after saving so it does not stay open in the session.
    """
    for lola_file, hierarch, number in all_enrichments:
        # e.g. '.../hg19_promoter_core_unaln-genes_LOLA.tsv' -> 'hg19_promoter_core_unaln-genes'
        lola_analysis = os.path.basename(lola_file).split('.')[0].replace('_LOLA', '')
        records, num_tests = read_enrichments(lola_file)
        if 'mm9' in lola_file:
            species, cmap_name, assm = 'mouse', 'mmu_pv', 'mm9'
        elif 'hg19' in lola_file:
            species, cmap_name, assm = 'human', 'hsa_pv', 'hg19'
        else:
            raise ValueError(lola_file)
        spec_color = plot_colors['species'][species]['rgb']
        # 6 colour levels: one per significance step returned by make_sig_steps
        cmap = make_color_map(cmap_name, (1, 1, 1), spec_color, 6, False)
        plot_title = 'LOLA enriched region sets: {}'.format(lola_analysis)

        fig, exart = plot_enrich_barchart(records, cmap, make_sig_steps(num_tests, 'qValueLog'),
                                          'qValueLog', plot_title, tissue_label, assm)

        if save_figures:
            # Build both output names from one base; replacing '.svg' anywhere in
            # the path could also rewrite a directory name.
            out_base = os.path.join(fig_folder, hierarch,
                                    'fig_{}_{}_lola_{}'.format(number, hierarch, lola_analysis))
            fig.savefig(out_base + '.svg', bbox_extra_artists=exart, bbox_inches='tight')
            fig.savefig(out_base + '.png', bbox_extra_artists=exart, bbox_inches='tight', dpi=300)
        if not show_figures:
            plt.close(fig)
    return True


def make_color_map(name, lowest, highest, levels, show=False):
    """Build a discrete linear colour map running from ``lowest`` to ``highest``.

    Parameters
    ----------
    name : str
        Name given to the colormap.
    lowest, highest : colour spec
        End-point colours of the gradient (e.g. RGB tuples).
    levels : int
        Number of discrete colours (``N``) in the lookup table.
    show : bool
        If True, draw a preview strip of the gradient in a new figure.

    Returns
    -------
    matplotlib.colors.LinearSegmentedColormap
    """
    colormap = LinearSegmentedColormap.from_list(name, [lowest, highest], N=levels)
    if not show:
        return colormap
    # two identical rows so imshow renders a visible strip
    strip = np.tile(np.linspace(0, 1, levels), (2, 1))
    _, preview_ax = plt.subplots(nrows=1)
    preview_ax.imshow(strip, aspect='auto', cmap=colormap, interpolation='nearest')
    return colormap


def make_sig_steps(num_tests, measure):
    """Return the significance thresholds used to colour the enrichment bars.

    Parameters
    ----------
    num_tests : int
        Number of tests performed. Only used for ``measure='pValueLog'``,
        where the nominal thresholds are divided by it (Bonferroni). Values
        below 1 are treated as 1 (no correction) to avoid dividing by zero.
    measure : str
        Either 'pValueLog' or 'qValueLog'.

    Returns
    -------
    list of (float, str)
        ``(-log10(threshold), label)`` pairs, from least to most significant.

    Raises
    ------
    ValueError
        If ``measure`` is not one of the supported names.
    """
    thresholds = np.array([0.05, 0.01, 0.001, 0.0001, 0.00001, 0.000001], dtype=np.float32)
    # '1e-2' means 0.01; '10e-2' would mean 0.1 in scientific notation.
    labels = ['0.05', '1e-2', '1e-3', '1e-4', '1e-5', '1e-6']
    if measure == 'pValueLog':
        thresholds = thresholds / max(num_tests, 1)
    elif measure != 'qValueLog':
        raise ValueError('Unexpected measure: {}'.format(measure))
    log_thresholds = -1 * np.log10(thresholds)
    return list(zip(log_thresholds, labels))


def read_enrichments(fpath):
    """Read a LOLA result table and keep the rows worth plotting.

    Parameters
    ----------
    fpath : str
        Path to a tab-separated LOLA output table.

    Returns
    -------
    records : list of dict
        Selected columns of each kept row (string values as read) plus
        ``qValueLog`` = -log10(qValue), sorted by ``logOddsRatio`` descending.
        Rows are dropped when ``support`` < 5 or the description is
        'sheffield_dnase'.
    num_tests : int
        Number of rows whose ``b``, ``c`` and ``d`` counts are all non-zero,
        counted *before* the support / description filters.
    """
    keep_columns = ['logOddsRatio', 'pValueLog', 'description',
                    'tissue', 'cellType', 'antibody', 'support',
                    'collection', 'dbSet', 'qValue']
    records = []
    num_tests = 0
    with open(fpath, 'r', newline='') as table:
        for row in csv.DictReader(table, delimiter='\t'):
            if any(int(row[cell]) == 0 for cell in ('b', 'c', 'd')):
                continue
            num_tests += 1
            if int(row['support']) < 5:
                continue
            # special: unclear what this stuff is, no tissue or cell type specified
            if row['description'] == 'sheffield_dnase':
                continue
            rec = {k: row[k] for k in keep_columns}
            rec['qValueLog'] = -1 * np.log10(float(rec['qValue']))
            records.append(rec)
    records.sort(key=lambda d: float(d['logOddsRatio']), reverse=True)
    return records, num_tests


def annotate_items(records, cmap, sig_levels, tissue, assembly, measure):
    """
    Colour, label and group the significant enrichment records for plotting.

    Records whose significance is below the first threshold are skipped. Every
    other record gets a colour and a significance label from its threshold
    bin, and a bar label built from its description / tissue / cell type /
    antibody / collection. Records sharing (bar label, significance label) are
    grouped and the groups are passed to ``merge_entries``.

    Parameters
    ----------
    records : list of dict
        Rows as produced by ``read_enrichments``. NOTE: each kept record is
        mutated in place (keys 'color' and 'siglabel' are added).
    cmap : callable
        Colormap; called with the integer index of the significance level.
    sig_levels : list of (float, str)
        (threshold, label) pairs; presumably in increasing threshold order
        as returned by ``make_sig_steps`` -- the binning loop relies on it.
    tissue : bool
        If truthy, records with a specific cell type and a tissue other than
        'NA' / 'whole-sample' / 'adlt' are labelled '<tissue>_<dbSet>'.
    assembly : str
        Genome assembly; 'mm9' switches ENCODE DNase records to
        '<tissue>_<dbSet>_DNase' labels.
    measure : str
        Record key holding the significance value (e.g. 'qValueLog').

    Returns
    -------
    The ``(records, y_limit)`` pair returned by ``merge_entries``.
    """
    plot_records = col.defaultdict(list)
    for rec in records:
        sig = float(rec[measure])
        if sig < sig_levels[0][0]:
            # - comparison in log space
            # - not significant at 5pct level
            continue
        # Bin sig: take the first level idx whose NEXT threshold is >= sig.
        # If sig exceeds every threshold, the final iteration's IndexError
        # (no next threshold) assigns the last, most significant level.
        for idx, (_, siglab) in enumerate(sig_levels):
            try:
                if sig <= sig_levels[idx+1][0]:
                    c = cmap(idx)
                    l = siglab
                    break
            except IndexError:
                c = cmap(idx)
                l = siglab
        rec['color'] = c
        rec['siglabel'] = l
        # Short description: full text for UCSC features, otherwise the first
        # word, with several collection-specific rewrites below.
        if rec['collection'] == 'ucsc_features':
            desc = rec['description']
        else:
            desc = rec['description'].split()[0]
        if desc == 'UCSC':
            desc = rec['description'].replace('UCSC ', '')
        if desc == 'DNase':
            desc = 'ENCODE ' + rec['description']
        if rec['description'] == 'Mouse Open Regulatory Annotation':
            desc = 'OpenRegAnno'
        if rec['description'].startswith('Mouse Fantom'):
            desc = 'Fantom enhancer'
        if desc == 'CpG':
            desc = rec['description']
        if desc.startswith('PhastCons'):
            desc = 'PhastCons elements'
        # Bar label: prefer cell type, then tissue, then the short description.
        if rec['cellType'] in ['NA', 'whole-sample', 'adlt']:
            if rec['tissue'] == 'NA' or rec['tissue'] == 'whole-sample':
                label = desc
            else:
                label = rec['tissue']
        else:
            label = rec['cellType']
            if tissue and rec['tissue'] not in ['NA', 'whole-sample', 'adlt']:
                label = rec['tissue'] + '_' + rec['dbSet']
                if 'dnase' in rec['collection']:
                    label += '_DNase'
        if label.startswith('metastatic prostate cancer'):
            label = 'prostate-cancer'
        if label.startswith('embryonic'):
            label += '_' + rec['cellType']
        if rec['antibody'] != 'NA':
            label += '_' + rec['antibody']
        # Collection-specific overrides applied last.
        if rec['collection'] == 'sheffield_dnase':
            label = simplify_sheffield(rec['description'], rec['dbSet'])
        if rec['collection'] == 'encode_segmentation':
            label += '_' + desc
        if assembly == 'mm9' and rec['collection'] == 'encode_dnase':
            label = '{}_{}_DNase'.format(rec['tissue'], rec['dbSet'])
        # Group on the label with the '<dbSet>_' part removed, so records that
        # differ only in dbSet end up in the same group (and are averaged by
        # merge_entries). The displayed label is the stripped one as well.
        if rec['dbSet'] in label:
            lut_label = label.replace(rec['dbSet'] + '_', '')
            plot_records[(lut_label, l)].append(rec)
        else:
            plot_records[(label, l)].append(rec)
    records, y_limit = merge_entries(plot_records)
    return records, y_limit


def merge_entries(rec_collect, max_entries=30):
    """Collapse grouped records into one bar each and keep the strongest ones.

    Parameters
    ----------
    rec_collect : dict
        Maps ``(label, siglabel)`` to a non-empty list of record dicts. Each
        record must provide 'logOddsRatio' (number or numeric string) and
        'color'.
    max_entries : int, optional
        Maximum number of bars returned (default 30).

    Returns
    -------
    plot_records : list of tuple
        ``(log_odds_ratio, error, label, color)`` sorted by log odds ratio in
        descending order and truncated to ``max_entries``. A single-record
        group has error 0. A larger group gets the mean log odds ratio and its
        population standard deviation. The colour comes from the group's
        first record.
    ylim : float
        Largest ``log_odds_ratio + error`` over the returned bars, floored
        at 0.
    """
    plot_records = []
    for (label, _siglabel), group in rec_collect.items():
        # Parse as float64; float16 keeps only ~3 significant digits.
        values = [float(rec['logOddsRatio']) for rec in group]
        if len(values) == 1:
            oddsratio, err = values[0], 0.0
        else:
            oddsratio, err = float(np.mean(values)), float(np.std(values))
        plot_records.append((oddsratio, err, label, group[0]['color']))
    plot_records.sort(key=lambda x: x[0], reverse=True)
    plot_records = plot_records[:max_entries]
    # Only the bars that will actually be drawn determine the y-limit.
    ylim = max([0] + [oddr + err for oddr, err, _, _ in plot_records])
    return plot_records, ylim


def simplify_sheffield(description, dbset):
    """Turn a Sheffield DNase region-set description into a compact label.

    Single-term descriptions (no ';') become '<term>_<dbset>_DNase', with
    special cases for stem cells, SK-N-SH and hematopoietic cells.
    Multi-term descriptions become a prefix built from the keywords 'weak',
    'stem', 'hematopoietic' and 'liver' (in that order), followed by
    '<dbset>_DNase' when fewer than 7 characters remain after removing those
    keywords, or by 'var_<dbset>_DNase' otherwise.
    """
    text = description.lower()
    if ';' in text:
        prefixes = []
        for keyword, short in (('weak', 'weak_'), ('stem', 'stem_'),
                               ('hematopoietic', 'hema_'), ('liver', 'liver_')):
            if keyword in text:
                prefixes.append(short)
                text = text.replace(keyword, '')
        # a short remainder (e.g. only ';' separators left) gets no 'var_' marker
        suffix = '{}_DNase' if len(text) < 7 else 'var_{}_DNase'
        return ''.join(prefixes) + suffix.format(dbset)
    if 'stem' in text:
        return 'stem_var_{}_DNase'.format(dbset)
    text = text.replace('sk-n-sh', 'SKnSH')
    if text == 'hematopoietic':
        return 'hema_{}_DNase'.format(dbset)
    return '{}_{}_DNase'.format(text, dbset)
    


def plot_enrich_barchart(records, color_map, sig_levels, measure, title, tissue, assembly):
    """Draw a bar chart of the significant LOLA enrichments.

    Parameters
    ----------
    records : list of dict
        Enrichment rows (see ``read_enrichments``); passed to ``annotate_items``.
    color_map : callable
        Colormap indexed by significance level.
    sig_levels : list of (float, str)
        Significance thresholds and their legend labels.
    measure : str
        Record key holding the significance value (e.g. 'qValueLog').
    title : str
        Axes title.
    tissue : bool
        Forwarded to ``annotate_items`` (tissue-specific bar labels).
    assembly : str
        Genome assembly name, forwarded to ``annotate_items``.

    Returns
    -------
    fig : matplotlib.figure.Figure
    extra_artists : list
        ``[legend]``; pass as ``bbox_extra_artists`` when saving so the
        significance legend is not clipped.

    If no record passes the significance filter, an empty chart (axes,
    title and legend) is returned instead of raising ``IndexError``.
    """
    bars, y_limit = annotate_items(records, color_map, sig_levels, tissue, assembly, measure)

    oddsratios = [b[0] for b in bars]
    error_high = [b[1] for b in bars]
    labels = [b[2] for b in bars]
    colors = [b[3] for b in bars]
    n_bars = len(bars)

    fig, ax = plt.subplots(figsize=(8, 6) if n_bars < 10 else (10, 8))

    ylim = np.ceil(max(max(oddsratios, default=0), y_limit))
    if ylim <= 0:
        # avoid a degenerate (0, 0) y-range, e.g. when there are no bars
        ylim = 1
    ax.set_ylim(0, ylim)

    bar_width = 0.5
    # bar centres at 0.5, 1.25, 2.0, ... (0.75 apart)
    x_ticks = 0.5 + 0.75 * np.arange(n_bars)
    ax.set_xlim(0, x_ticks[-1] + 0.5 if n_bars else 1)

    if n_bars:
        # one-sided error bars: only the upper error is drawn
        ax.bar(x_ticks, oddsratios, bar_width, color=colors,
               yerr=[[0] * n_bars, error_high],
               edgecolor='black', linewidth=1)
    ax.set_xticks(x_ticks)
    ax.set_xticklabels(labels, rotation='vertical')

    # legend: one patch per significance level, coloured like the bars
    bar_shading = [mpatches.Patch(facecolor=color_map(idx), label=lab, edgecolor='black')
                   for idx, (_, lab) in enumerate(sig_levels)]
    leg_pv = ax.legend(handles=bar_shading, loc=1, title='adj. p-val < x',
                       fontsize=12, bbox_to_anchor=(1.0, 0.975))

    ax.set_ylabel('odds ratio', fontsize=14)
    ax.set_title(title, fontsize=16)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    return fig, [leg_pv]


if run_lola_enrich:
    lola_dir = os.path.join(supp_folder, 'lola')
    # Resolve the bare LOLA file names against the supplement folder and keep
    # the (hierarchy, figure number) tags unchanged. The unaln sets are plotted
    # without tissue-specific labels, the uniq sets with them.
    unaln_genes = [(os.path.join(lola_dir, fname), section, fig_no)
                   for fname, section, fig_no in unaln_enrich]
    exec_lola_enrich(unaln_genes, False)
    uniq_genes = [(os.path.join(lola_dir, fname), section, fig_no)
                  for fname, section, fig_no in uniq_enrich]
    exec_lola_enrich(uniq_genes, True)
    print('LOLA enrichment plotted')