In [12]:
%matplotlib inline

import os as os
import collections as col
import itertools as itt
import pickle as pck
import json as js
import time as ti

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.lines as mlines

import numpy as np
import numpy.random as rng
import scipy.stats as stats
import pandas as pd
import seaborn as sns

# What does this do?
# Plot boxplots of minimum TPM
# log2 fold-change for genes
# stratified by HSP hits
# (body/enhancer)

date = '20180709'

run_plot_gene_enh_hsp = False
run_dump_genetrail_lists = True

save_figures = True

journal = 'ox_bioinf'
res = 'print_lo'
fig_sizes = js.load(open('/home/pebert/work/code/mpggit/statediff/annotation/misc/fig_sizes.json'))
fig_sizes = fig_sizes[journal]
resolution = fig_sizes['resolution']

segmentations = ['cmm18', 'esc18', 'ecs10']
scorings = ['emission', 'replicate']

sns.set(style='white',
        font_scale=1.5,
        rc={'font.family': ['sans-serif'],
            'font.sans-serif': ['DejaVu Sans']})

fhgfs_base = '/TL/deep/fhgfs/projects/pebert/thesis/projects/statediff'
cache_dir = os.path.join(fhgfs_base, 'caching/notebooks')

isect_folder = os.path.join(fhgfs_base, 'bedtools/deep/inv_isect')
de_gene_folder = os.path.join(fhgfs_base, 'deseq/deep')
bed_gene_folder = os.path.join(fhgfs_base, 'deseq/bed_out')

genetrail_out = os.path.join(fhgfs_base, 'genetrail/input_lists', date)

tpm_file = os.path.join(fhgfs_base, 'salmon', 'deep', 'agg_gene_tpm.h5')

base_out = '/TL/deep-external01/nobackup/pebert/cloudshare/mpiinf/phd/chapter_projects/statediff'

fig_supp = os.path.join(base_out, 'figures', 'pub', 'supp')
fig_main = os.path.join(base_out, 'figures', 'pub', 'main')
fig_collect = os.path.join(base_out, 'figures', 'pub', 'collection')
                   

def load_gene_table(fpath, genetype):
    
    genes = dict()
    with open(fpath, 'r') as table:
        _ = table.readline()
        for line in table:
            cols = line.strip().split('\t')
            try:
                name, symbol = cols[3].split('.')[0], cols[6]
            except IndexError:
                break
            log2fc = cols[4]
            try:
                gene = genes[name]
            except KeyError:
                gene = {'name': name, 'symbol': symbol,
                        'log2fc': float(log2fc), 'genetype': genetype,
                        'body_hsp': 0}
                genes[name] = gene
            if int(cols[-1]) > 0:
                gene['body_hsp'] += 1
    df = pd.DataFrame.from_records(list(genes.values()), index='name')
    return df


def load_enhancer_table(fpath):
    
    genes = dict()
    with open(fpath, 'r') as table:
        _ = table.readline()
        for line in table:
            cols = line.strip().split('\t')
            try:
                enh_id, ens_id, symbol, intra = cols[3], cols[6], cols[7], int(cols[10])
            except IndexError:
                break
            try:
                gene = genes[ens_id]
            except KeyError:
                gene = {'name': ens_id, 'enhancer': set(), 'enh_hits': set()}
                genes[ens_id] = gene
            gene['enhancer'].add((enh_id, intra))
            if int(cols[-1]) > 0:
                gene['enh_hits'].add((enh_id, intra))
                
    for k, v in genes.items():
        v['total_enh'] = len(v['enhancer'])
        v['intra_enh'] = sum([t[1] for t in v['enhancer']])
        v['inter_enh'] = v['total_enh'] - v['intra_enh']
        v['total_hits'] = len(v['enh_hits'])
        v['intra_hits'] = len([t for t in v['enh_hits'] if t[1] == 1])
        v['inter_hits'] = v['total_hits'] - v['intra_hits']
        del v['enhancer']
        del v['enh_hits']
        
    df = pd.DataFrame.from_records(list(genes.values()), index='name')
    return df
    
                
def load_expression_data(s1, s2):
    
    with pd.HDFStore(tpm_file, 'r') as hdf:
        load_keys = [k for k in hdf.keys() if k.endswith(s1) or k.endswith(s2)]
        assert len(load_keys) == 2, 'data missing'
        data = pd.concat([hdf[load_keys[0]], hdf[load_keys[1]]], axis=1, ignore_index=False)
        data['min_tpm'] = data.min(axis=1)
        data = data.loc[:, ['min_tpm']].copy()
    return data
        
    
def cache_overlap_data(rootfolder, cache_file):
    """
    """
    
    defiles = [f for f in os.listdir(rootfolder) if f.startswith('deep_degenes_') and f.endswith('.tsv')]
    defiles = sorted(defiles)
    
    for defile in defiles:
        depath = os.path.join(rootfolder, defile)
        de_genes = load_gene_table(depath, 'diff')
        stpath = depath.replace('_degenes_', '_stgenes_')
        st_genes = load_gene_table(stpath, 'stable')
        all_genes = pd.concat([de_genes, st_genes], axis=0, ignore_index=False)
        enhpath = depath.replace('_degenes_', '_enh_')
        enh_genes = load_enhancer_table(enhpath)
        parts = defile.split('.')[0].split('_')
        seg, comp, scoring = parts[4], parts[5] + '_vs_' + parts[7], parts[8]
        
        tpm_data = load_expression_data(parts[5], parts[7])
        tpm_data = tpm_data.loc[tpm_data.index.isin(all_genes.index), :]
        
        joined = pd.concat([all_genes, enh_genes, tpm_data], axis=1,
                           ignore_index=False, sort=True)
        joined.fillna(0, inplace=True)
        
        store_path = os.path.join(seg, scoring, comp)
        with pd.HDFStore(cache_file, 'a') as hdf:
            hdf.put(store_path, joined, format='table')
            
    return cache_file
    
    
def prepare_plot_data(data, data_type):
    
    boxes = []
    box_labels = []
    # genes w/o body hit and no enhancer hit
    select = np.logical_and(data['body_hsp'] == 0, data['total_hits'] == 0)
    sub = data.loc[select, data_type].abs()
    boxes.append(sub)
    box_labels.append('B:0  E:0 ({})'.format(sub.size))
    
    # genes w/o body hit AND at least 1 enhancer hit
    select = np.logical_and(data['body_hsp'] == 0, data['total_hits'] == 1)
    sub = data.loc[select, data_type].abs()
    boxes.append(sub)
    box_labels.append('B:0  E:1 ({})'.format(sub.size))
    
    select = np.logical_and(data['body_hsp'] == 0, data['total_hits'] == 2)
    sub = data.loc[select, data_type].abs()
    boxes.append(sub)
    box_labels.append('B:0  E:2 ({})'.format(sub.size))
    
    select = np.logical_and(data['body_hsp'] == 0, data['total_hits'] >= 3)
    sub = data.loc[select, data_type].abs()
    boxes.append(sub)
    box_labels.append('B:0  E:3+ ({})'.format(sub.size))
    
    # genes with at least one body hit
    select = np.logical_and(data['body_hsp'] > 0, data['total_hits'] == 0)
    sub = data.loc[select, data_type].abs()
    boxes.append(sub)
    box_labels.append('B:1+  E:0 ({})'.format(sub.size))
    
    # genes with at least one body hit
    select = np.logical_and(data['body_hsp'] > 0, data['total_hits'] == 1)
    sub = data.loc[select, data_type].abs()
    boxes.append(sub)
    box_labels.append('B:1+  E:1 ({})'.format(sub.size))
    
    # genes with at least one body hit
    select = np.logical_and(data['body_hsp'] > 0, data['total_hits'] == 2)
    sub = data.loc[select, data_type].abs()
    boxes.append(sub)
    box_labels.append('B:1+  E:2 ({})'.format(sub.size))
    
    # genes with at least one body hit
    select = np.logical_and(data['body_hsp'] > 0, data['total_hits'] >= 3)
    sub = data.loc[select, data_type].abs()
    boxes.append(sub)
    box_labels.append('B:1+  E:3+ ({})'.format(sub.size))
    
    return boxes, box_labels
    
    
def generate_bins(dataset, stepsize):
    """
    """
    fc_min = dataset['log2fc'].min()
    assert fc_min < 0, 'That is unexpected...'
    fc_lower = (fc_min // stepsize) * stepsize - stepsize
    fc_max = dataset['log2fc'].max()
    fc_upper = (fc_max // stepsize) * stepsize + stepsize
    bins = np.arange(fc_lower, fc_upper + 0.1, stepsize)
    bin_num = [i for i, b in enumerate(bins, start=1)]
    return bins, bin_num
    

def compute_densities(dataset, bins, hit_type):
    
    histograms = []
    labels = []
    lo_bounds = [0, 2, 5, 8, 11]
    hi_bounds = [0, 3, 6, 9, 1000]
    # no overlap
    for lo, hi in zip(lo_bounds, hi_bounds):
        select = np.logical_and(dataset[hit_type] >= lo, dataset[hit_type] <= hi)
        group = dataset.loc[select, 'log2fc']
        if group.size == 0:
            hist = np.zeros_like(histograms[0])
            histograms.append(hist)
        else:
            hist, bin_edges = np.histogram(group, bins, density=True)
            histograms.append(np.concatenate([[0], hist]))
        if lo == 0:
            labels.append('0: {}'.format(group.size))
        elif lo == max(lo_bounds):
            labels.append('{}+: {}'.format(lo, group.size))
        else:
            labels.append('{}-{}: {}'.format(lo, hi, group.size))
    
    return histograms, labels
    
    
    
def create_cdf_plot(datapoints, comp, fkey):
    
    fig, axes = plt.subplots(figsize=fig_sizes['two_col']['square'],
                             nrows=3, ncols=2,
                             sharex=False, sharey=False)
    stepsize = 0.5
    bins, bin_nums = generate_bins(datapoints, stepsize)
    
    plt.subplots_adjust(wspace=0.05, hspace=0.05)
    
    data_colors = ['dimgrey', 'dodgerblue', 'limegreen', 'orange', 'magenta']
    
    # for all enhancers
    total_hist, total_labels = compute_densities(datapoints, bins, 'total_hits')
    # for intragenic
    intra_hist, intra_labels = compute_densities(datapoints, bins, 'intra_hits')
    # for intragenic
    inter_hist, inter_labels = compute_densities(datapoints, bins, 'inter_hits')
    row_title = {0: 'anywhere', 1: 'intragenic', 2: 'intergenic'}
    
    s1, s2 = comp.split('_vs_')
    for row_idx, (data, lab) in enumerate([(total_hist, total_labels),
                                           (intra_hist, intra_labels),
                                           (inter_hist, inter_labels)]):
        ax = axes[row_idx, 0]
        if row_idx == 0:
            tt = ax.set_title('HSP gene enhancer hits: {} v {}'.format(s1, s2), fontsize=14)
            tt.set_position((1, 1))
        
        leg_handles = []
        for c, hist, l in zip(data_colors, data, lab):
            if c == 'dimgrey':
                style = 'dashed'
                zorder = 0
            else:
                style = 'solid'
                zorder = 2
            assert np.isclose(hist.sum() * stepsize, 1., atol=1e-6), 'CDF not 1'
            ax.plot(bins, np.cumsum(hist) * stepsize,
                    lw=2, ls=style, c=c, zorder=zorder,
                    label=l)
#             ax.axhline(0.75, 0, 1, c='lightgrey',
#                        ls='dashed', zorder=0, lw=1)
#             ax.axhline(0.5, 0, 1, c='lightgrey',
#                        ls='dashed', zorder=0, lw=1)
#             ax.axhline(0.25, 0, 1, c='lightgrey',
#                        ls='dashed', zorder=0, lw=1)
            
            ax.axvline(-3, 0, 1, c='black', alpha=0.5,
                       ls='dashed', zorder=0, lw=1)
            ax.axvline(3, 0, 1, c='black', alpha=0.5,
                       ls='dashed', zorder=0, lw=1)
            leg_patch = mlines.Line2D([], [], marker='s', markersize=7,
                                      lw=0, color=c, label=l)
            leg_handles.append(leg_patch)
        
        if comp == 'HG_vs_Mo':
            ax_leg = ax.legend(loc='right', handles=leg_handles,
                               fontsize=10, bbox_to_anchor=(1.05, 0.4))
        else:
            ax_leg = ax.legend(loc='best', handles=leg_handles,
                               fontsize=10)
        tt = ax_leg.set_title(row_title[row_idx], prop={'size': 12})
        
        if row_idx == 1:
            ax.set_ylabel('Cumulative probability', fontsize=14)
                
        if row_idx == 2:
            neg_ticks = np.arange(0, bins.min(), -10)[::-1]
            pos_ticks = np.arange(0, bins.max(), 10)

            x_ticks = np.concatenate([neg_ticks, pos_ticks[1:]])
            x_ticks = x_ticks.astype(np.int8)
            ax.set_xticks(x_ticks)
            ax.set_xticklabels(map(str, x_ticks), fontsize=12)
            tt = ax.set_xlabel('Gene expression fold change (log2)', fontsize=14)
            tt.set_position((1, 0))
        else:
            ax.set_xticks([])
            
        ax.set_xlim(bins.min(), bins.max())

        y_ticks = [0., 0.5, 1]
        ax.set_yticks(y_ticks)
        ax.set_yticklabels(map(str, y_ticks), fontsize=12)
        ax.set_ylim(-0.05, 1.05)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

            
        ax = axes[row_idx, 1]
        
        for c, hist in zip(data_colors, data):
            if c == 'dimgrey':
                style = 'dashed'
                zorder = 0
            else:
                style = 'solid'
                zorder = 2
            assert np.isclose(hist.sum() * stepsize, 1., atol=1e-6), 'CDF not 1'
            ax.plot(bins, np.cumsum(hist) * stepsize,
                    lw=2, ls=style, c=c, zorder=zorder)
            ax.axhline(0.75, 0, 1, c='lightgrey',
                       ls='dashed', zorder=0, lw=1)
            ax.axhline(0.5, 0, 1, c='lightgrey',
                       ls='dashed', zorder=0, lw=1)
            ax.axhline(0.25, 0, 1, c='lightgrey',
                       ls='dashed', zorder=0, lw=1)
            
            ax.axvline(-3, 0, 1, c='black', alpha=0.5,
                       ls='dashed', zorder=0, lw=1)
            ax.axvline(3, 0, 1, c='black', alpha=0.5,
                       ls='dashed', zorder=0, lw=1)
            
        ax.set_yticks([])
        if row_idx == 2:
            x_ticks = [-3, 0, 3]
            ax.set_xticks(x_ticks)
            ax.set_xticklabels(map(str, x_ticks), fontsize=12)
        else:
            ax.set_xticks([])
        
        ax.set_xlim(-5, 5)
        #ax.set_ylim(-0.05, 1.05)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

    return fig, []
    
    
def create_boxplot(datapoints, comp, fkey):
    """
    """
    boxcolor = 'dimgrey'
    medcolor = 'grey'
    median_props = {'color': medcolor, 'linewidth': 2}
    box_props = {'color': boxcolor, 'linewidth': 1}
    whisker_props = {'color': boxcolor, 'linewidth': 1}
    cap_props = {'color': boxcolor, 'linewidth': 1}
        
    
    fig, axes = plt.subplots(figsize=fig_sizes['two_col']['square'],
                             nrows=2, ncols=2, sharex=False, sharey=False,
                             gridspec_kw={'height_ratios': [1, 2]})
    data_lut = {(0, 0): ('diff', 'log2fc'), (0, 1): ('stable', 'log2fc'),
                (1, 0): ('diff', 'min_tpm'), (1, 1): ('stable', 'min_tpm')}
        
    plt.subplots_adjust(hspace=0.05, wspace=0.2)
    s1, s2 = comp.split('_vs_')
    for row in range(2):
        for col in range(2):
            ax = axes[row, col]
            gene_type, data_type = data_lut[(row, col)]
            
            subset = datapoints.loc[datapoints['genetype'] == gene_type, :].copy()
            
            if row == 0 and col == 0:
                ax.set_title('{} v {}\nDEG N={}'.format(s1, s2, subset.shape[0]),
                             fontsize=14)
                ax.set_ylabel('abs ( log2 fc )', fontsize=14)
            elif row == 0 and col == 1:
                ax.set_title('{} v {}\nstable N={}'.format(s1, s2, subset.shape[0]),
                             fontsize=14)
                ax.axhline(2, 0, 1, lw=2, ls='dashed',
                           color='lightgrey', zorder=0)
            else:
                pass
            
            if row == 1 and col == 0:
                ax.set_ylabel('minimum expression (TPM)', fontsize=14)
                xl_pos = ax.set_xlabel('HSP hits in gene bodies (B) or enhancers (E)', fontsize=14)
                xl_pos.set_position((1, 0))
                                    
            boxes, boxlabels = prepare_plot_data(subset, data_type)
                        
            ax.set_xlim(0, len(boxes) + 1)
                       
            bb = ax.boxplot(boxes, sym="", labels=None, widths=0.35,
                            medianprops=median_props, boxprops=box_props,
                            whiskerprops=whisker_props, capprops=cap_props)
            max_y = 0
            for cap in bb['caps']:
                max_y = max(max_y, cap.get_ydata()[0])
            if max_y < 12:
                ax.set_ylim(-0.5, max_y // 2 * 2 + 2)
                yticks = np.arange(0, np.ceil(max_y), 3, dtype=np.int8)
            else:
                ax.set_ylim(-0.5, max_y // 5 * 5 + 5)
                yticks = np.arange(0, np.ceil(max_y), 5, dtype=np.int8)
            
            ax.set_yticks(yticks)
            ax.set_yticklabels(map(str, yticks), fontsize=12)
            
            if row == 0:
                ax.set_xticks([])
            if row == 1:
                ax.set_xticklabels(boxlabels, fontsize=12, rotation=90)
            if col == 0:
                ax.axvline(1, 0, 1, lw=25, ls='solid', color='red',
                           alpha=0.1, zorder=0)
            
    return fig, []


def dump_genetrail_lists():
    cache_file = os.path.join(cache_dir, '{}_plot_gene-enh-hsp_tpm-bin.h5'.format(date))
    if not os.path.isfile(cache_file):
        raise RuntimeError('No cache file detected')
    dump_folder = genetrail_out + '_miRNA'
    os.makedirs(dump_folder, exist_ok=True)
        
    with pd.HDFStore(cache_file, 'r') as hdf:
        cached_keys = list(hdf.keys())
        for seg in segmentations:
            if seg != 'cmm18':
                continue
            for scoring in scorings:
                if scoring != 'emission':
                    continue
                load_keys = [k for k in cached_keys if seg in k and scoring in k]
                
                for k in load_keys:
                    comp = k.split('/')[-1]
                    if comp not in ['Ma_vs_Mo', 'He_vs_Mo']:
                        continue
                    print('Dumping GeneTrail lists for {}'.format(comp))
                    data = hdf[k]
                    
                    base_file = '_'.join([seg, scoring, comp, '{}']) + '.tsv'
                    deg = data.loc[data['genetype'] == 'diff', :].copy()
                    
                    select_miss = np.logical_and(deg['body_hsp'] == 0,
                                                 deg['total_hits'] == 0)
                    miss = deg.loc[select_miss, :]
                    miss_file = os.path.join(dump_folder, base_file.format('deg-miss'))
                    with open(miss_file, 'w') as dump:
                        _ = dump.write('\n'.join(miss.index.tolist()) + '\n')
                    
                    select_hit = np.logical_or(deg['body_hsp'] > 0,
                                               deg['total_hits'] > 0)
                    hit = deg.loc[select_hit, :]                    
                    hit_file = os.path.join(dump_folder, base_file.format('deg-hit'))
                    with open(hit_file, 'w') as dump:
                        _ = dump.write('\n'.join(hit.index.tolist()) + '\n')
                        
                    bmiss = deg.loc[deg['body_hsp'] == 0, :]
                    bmiss_file = os.path.join(dump_folder, base_file.format('deg-body-miss'))
                    with open(bmiss_file, 'w') as dump:
                        _ = dump.write('\n'.join(bmiss.index.tolist()) + '\n')
                    
                    bhit = deg.loc[deg['body_hsp'] > 0, :]
                    bhit_file = os.path.join(dump_folder, base_file.format('deg-body-hit'))
                    with open(bhit_file, 'w') as dump:
                        _ = dump.write('\n'.join(bhit.index.tolist()) + '\n')
                                            
                    stg = data.loc[data['genetype'] == 'stable', :].copy()
                    stg_file = os.path.join(dump_folder, base_file.format('seg-any'))
                    with open(stg_file, 'w') as dump:
                        _ = dump.write('\n'.join(stg.index.tolist()) + '\n')
                        
                    stgmiss = stg.loc[stg['body_hsp'] == 0, :]
                    stgmiss_file = os.path.join(dump_folder, base_file.format('seg-body-miss'))
                    with open(stgmiss_file, 'w') as dump:
                        _ = dump.write('\n'.join(stgmiss.index.tolist()) + '\n')

    return


def print_mirna_info(genes, mirna, cell):
    
    print('Cell ', cell)
    select = np.logical_and(mirna['name'].isin(genes.index),
                            mirna['celltype'] == cell)
    sub = mirna.loc[select, ['norm_count', 'mirna_id']]
    print('Unique miRNAs ', sub['mirna_id'].unique().size)
    dist = np.percentile(sub['norm_count'], [25, 50, 75, 95, 99])
    print('Count distribution ', dist)
    return


def make_mirna_scatter(fg, bg, mirna, s1, s2, fk='X'):
    
    fig, (fg_ax, bg_ax) = plt.subplots(figsize=fig_sizes['two_col']['half'],
                                       nrows=1, ncols=2, sharex=False, sharey=False)
    
    fg_mirna = mirna.loc[mirna['name'].isin(fg['name']), ['name', 'mirna_id', 'norm_count', 'celltype']]
    fg_ids = set(fg_mirna['mirna_id'].unique())
    print('FG ', len(fg_ids))
    fg_s1 = fg_mirna.loc[fg_mirna['celltype'] == s1, ['name', 'norm_count']]
    fg_s1['norm_count'] = np.log10(fg_s1['norm_count'] + 1)
    fg_s1 = fg_s1.merge(fg, on='name', how='outer')
    fg_s1.fillna({'norm_count': 0}, inplace=True)
        
    fg_s2 = fg_mirna.loc[fg_mirna['celltype'] == s2, ['name', 'norm_count']]
    fg_s2['norm_count'] = -1 * np.log10(fg_s2['norm_count'] + 1)
    fg_s2 = fg_s2.merge(fg, on='name', how='outer')
    fg_s2.fillna({'norm_count': 0}, inplace=True)
    
    fg_ax.set_xlim(min(fg_s1['log2fc'].min(), fg_s2['log2fc'].min() - 1),
                   max(fg_s1['log2fc'].max(), fg_s2['log2fc'].max()) + 1)
    fg_ax.set_ylim(fg_s2['norm_count'].min() - 0.5,
                   fg_s1['norm_count'].max() + 0.5)
    
    fg_ax.scatter(fg_s1['log2fc'], fg_s1['norm_count'],
                  c='dodgerblue', alpha=0.7,
                  marker='.', s=20)
    
    fg_ax.scatter(fg_s2['log2fc'], fg_s2['norm_count'],
                  c='orange', alpha=0.7,
                  marker='.', s=20)
    
    
    
    bg_mirna = mirna.loc[mirna['name'].isin(bg['name']), ['name', 'mirna_id', 'norm_count', 'celltype']]
    bg_ids = set(bg_mirna['mirna_id'].unique())
    print('BG ', len(bg_ids))
    bg_s1 = bg_mirna.loc[bg_mirna['celltype'] == s1, ['name', 'norm_count']]
    bg_s1['norm_count'] = np.log10(bg_s1['norm_count'] + 1)
    bg_s1 = bg_s1.merge(bg, on='name', how='outer')
    bg_s1.fillna({'norm_count': 0}, inplace=True)
        
    bg_s2 = bg_mirna.loc[bg_mirna['celltype'] == s2, ['name', 'norm_count']]
    bg_s2['norm_count'] = -1 * np.log10(bg_s2['norm_count'] + 1)
    bg_s2 = bg_s2.merge(bg, on='name', how='outer')
    bg_s2.fillna({'norm_count': 0}, inplace=True)
    
    bg_ax.set_xlim(min(bg_s1['log2fc'].min(), bg_s2['log2fc'].min() - 1),
                   max(bg_s1['log2fc'].max(), bg_s2['log2fc'].max()) + 1)
    bg_ax.set_ylim(bg_s2['norm_count'].min() - 0.5,
                   bg_s1['norm_count'].max() + 0.5)
    
    bg_ax.scatter(bg_s1['log2fc'], bg_s1['norm_count'],
                  c='dodgerblue', alpha=0.7,
                  marker='.', s=20)
    
    bg_ax.scatter(bg_s2['log2fc'], bg_s2['norm_count'],
                  c='orange', alpha=0.7,
                  marker='.', s=20)
    print(len(fg_ids.intersection(bg_ids)))
    
    
    return fig, []


def plot_mirna_changes():
    cache_file = os.path.join(cache_dir, '{}_plot_gene-enh-hsp_tpm-bin.h5'.format(date))
    if not os.path.isfile(cache_file):
        raise RuntimeError('No cache file detected')
        
    mirna_cache = os.path.join(fhgfs_base, 'caching', 'notebooks', '20180711_mirna_exp_targets.h5')
    if not os.path.isfile(mirna_cache):
        raise RuntimeError('No miRNA annotation detected')
    print('Cached data available...')
    
    with pd.HDFStore(cache_file, 'r') as hdf:
        cached_keys = list(hdf.keys())
        for seg in segmentations:
            if seg != 'cmm18':
                continue
            for scoring in scorings:
                if scoring != 'emission':
                    continue
                load_keys = [k for k in cached_keys if seg in k and scoring in k]
                
                for k in load_keys:
                    comp = k.split('/')[-1]
                    if comp not in ['Ma_vs_Mo']:
                        continue
                    s1, s2 = comp.split('_vs_')
                    data = hdf[k]
                    de_mirna_file = os.path.join(base_out, 'de_mirna.tsv')
                    de_mirna = pd.read_csv(de_mirna_file, sep='\t', header=None,
                                           names=['mirna_id', 'log2fc'])
#                     print(de_mirna.shape)
#                     print('Loading miRNA')
#                     with pd.HDFStore(mirna_cache, 'r') as cache:
#                         mirna = cache['mirna_exp/' + comp]
#                         mirna = mirna.loc[mirna['mirna_id'].isin(de_mirna['mirna_id']), :].copy()
#                         mirna.drop_duplicates(['mirna_id', 'name'], inplace=True)
                    ref_folder = os.path.join(fhgfs_base, 'references')
                    trg_scan = 'TargetScan_v7.2_default_predictions.txt.zip'
                    mirna = pd.read_csv(os.path.join(ref_folder, trg_scan), sep='\t', header=0)
                    mirna = mirna.loc[mirna['Species ID'] == 9606, :].copy()
                    mirna['name'] = mirna['Gene ID'].str.extract('(?P<ENSID>\w+)', expand=False)
                    mirna['enst'] = mirna['Transcript ID'].str.extract('(?P<ENSID>\w+)', expand=False)
                    mirna['hit_start'] = mirna['UTR start']
                    mirna['hit_end'] = mirna['UTR end']
                    mirna['mirna_family'] = 'hsa-' + mirna['miR Family']
                    mirna = mirna.loc[:, ['mirna_family', 'name', 'enst', 'hit_start', 'hit_end']]
                    
                    split_set = []
                    de_mirna_ids = set(de_mirna['mirna_id'])
                    for row in mirna.itertuples():
                        members = row.mirna_family.split('/')
                        for m in members:
                            if m.startswith('hsa-miR'):
                                keep = m
                            else:
                                keep = 'hsa-miR-' + row.mirna_family
                            if keep in de_mirna_ids and row.name in data.index:
                                split_set.append((keep, row.name, row.enst, row.hit_start, row.hit_end))
                    mirna = pd.DataFrame(split_set, columns=['mirna_id', 'name', 'transcript',
                                                             'hit_start', 'hit_end'])
                    mirna.drop_duplicates(['mirna_id', 'transcript', 'hit_start'])
                    print(col.Counter(mirna['mirna_id']).most_common(10))                                                     
                    data['group'] = 'empty'
                    data['name'] = data.index
                    data.reset_index(drop=True, inplace=True)
                                        
                    group_a = np.logical_and(data['genetype'] == 'diff',
                                             np.logical_or(data['body_hsp'] > 0,
                                                           data['total_hits'] > 0))
                    data.loc[group_a, 'group'] = 'deg_hsp-hit'
                    
                    group_b = np.logical_and(data['genetype'] == 'stable',
                                             np.logical_or(data['body_hsp'] > 0,
                                                           data['total_hits'] > 0))
                    data.loc[group_b, 'group'] = 'stable_hsp-hit'
                    
                    group_c = np.logical_and(data['genetype'] == 'diff',
                                             np.logical_and(data['body_hsp'] == 0,
                                                            data['total_hits'] == 0))
                    data.loc[group_c, 'group'] = 'deg_no-hit'
                    
                    group_d = np.logical_and(data['genetype'] == 'stable',
                                             np.logical_and(data['body_hsp'] == 0,
                                                            data['total_hits'] == 0))
                    data.loc[group_d, 'group'] = 'stable_no-hit'
                    
                    src_group = col.Counter(data['group'])

                    data = data.loc[data['name'].isin(mirna['name']), :].copy()
                    group_count = col.Counter(data['group'])
                    data = data.merge(mirna, on='name', how='outer')
                    print(data.shape)
                    hit_counts = col.Counter(data['group'])
                                        
                    for g, hits in hit_counts.items():
                        gsize = group_count[g]
                        print('Group ', g)
                        print('Group size ', gsize, '({})'.format(src_group[g]))
                        print('Target hits in group ', hits)
                        print('Avg. target hits per gene ', np.round(hits / gsize, 2))
                        print('====')
                    
                    
#                     fig, exart = make_mirna_scatter(fg, bg, mirna, s1, s2, 'X')
    return


if run_dump_genetrail_lists:
    #dump_genetrail_lists()
    plot_mirna_changes()
    
    

Cached data available...
[('hsa-miR-1-3p', 983), ('hsa-miR-22-3p', 659), ('hsa-miR-221-3p', 532), ('hsa-miR-132-3p', 514), ('hsa-miR-212-5p', 447), ('hsa-miR-150-5p', 369), ('hsa-miR-342-3p', 326), ('hsa-miR-409-3p', 325), ('hsa-miR-486-5p', 178), ('hsa-miR-210-3p', 48)]
(4400, 17)
Group  stable_hsp-hit
Group size  280 (1697)
Target hits in group  391
Avg. target hits per gene  1.4
====
Group  deg_hsp-hit
Group size  79 (438)
Target hits in group  119
Avg. target hits per gene  1.51
====
Group  deg_no-hit
Group size  340 (1971)
Target hits in group  482
Avg. target hits per gene  1.42
====
Group  stable_no-hit
Group size  2391 (15985)
Target hits in group  3408
Avg. target hits per gene  1.43
====
