In [3]:
%matplotlib inline

import os as os
import collections as col
import itertools as itt
import pickle as pck
import json as js
import time as ti

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

import numpy as np
import numpy.random as rng
import scipy.stats as stats
import pandas as pd
import seaborn as sns

# What does this do?
# For every DESeq2 comparison, bin the differentially expressed (DIFF) genes
# by the number of HSPs overlapping them (0, 1, 2, 3+) and export the genes
# without any HSP overlap (foreground) and with exactly one overlap
# (background) as input lists for a miRNA enrichment analysis (GeneTrail).
# Boxplots of gene length and minimum expression (TPM) per overlap bin can
# also be drawn, but plotting is currently disabled in plot_hsp_gene_ovl.

date = '20180622'

run_plot_hsp_gene_ovl = True

save_figures = False

journal = 'ox_bioinf'
res = 'print_lo'
# Context manager so the JSON file handle is closed deterministically.
with open('/home/pebert/work/code/mpggit/statediff/annotation/misc/fig_sizes.json') as fig_size_file:
    fig_sizes = js.load(fig_size_file)
fig_sizes = fig_sizes[journal]
resolution = fig_sizes['resolution']

sns.set(style='white',
        font_scale=1.5,
        rc={'font.family': ['sans-serif'],
            'font.sans-serif': ['DejaVu Sans']})

# input locations
fhgfs_base = '/TL/deep/fhgfs/projects/pebert/thesis/projects/statediff'
cache_dir = os.path.join(fhgfs_base, 'caching/notebooks')

hsp_gene_ovl_folder = os.path.join(fhgfs_base, 'bedtools/deep/gene_isect')
hsp_ensreg_ovl_folder = os.path.join(fhgfs_base, 'bedtools/deep/ensreg_isect')
hsp_enh_ovl_folder = os.path.join(fhgfs_base, 'bedtools/deep/enh_isect')
de_gene_folder = os.path.join(fhgfs_base, 'deseq/deep')
bed_gene_folder = os.path.join(fhgfs_base, 'deseq/bed_out')

tpm_file = os.path.join(fhgfs_base, 'salmon', 'deep', 'agg_gene_tpm.h5')

# output locations
base_out = '/TL/deep-external01/nobackup/pebert/cloudshare/mpiinf/phd/chapter_projects/statediff'
genetrail_out = os.path.join(fhgfs_base, 'genetrail/input_lists')
fig_supp = os.path.join(base_out, 'figures', 'pub', 'supp')
fig_main = os.path.join(base_out, 'figures', 'pub', 'main')
fig_collect = os.path.join(base_out, 'figures', 'pub', 'collection')
                   

def get_gene_counts(comparison):
    """Count stable and differential genes of one DESeq2 comparison.

    Reads ``deseq2_<comparison>_stable_body.bed`` and
    ``deseq2_<comparison>_diff_body.bed`` from ``bed_gene_folder``. Both are
    tab-separated files with a header line.

    Args:
        comparison (str): comparison label as used in the file names,
            e.g. ``'HG_vs_He'``.

    Returns:
        tuple: ``(num_diff, num_stable, diff_genes)`` where ``diff_genes`` is a
        DataFrame of the differential genes with columns:

        * ``de_name``: leading run of word characters of the ``name`` column
          (so a ``.version`` suffix, if present, is dropped)
        * ``gene_type``: always ``'DIFF'``
        * ``de_log2fc``: copy of ``log2fc``
        * ``gene_length``: ``end - start``
    """
    stable_path = os.path.join(bed_gene_folder, 'deseq2_{}_stable_body.bed'.format(comparison))
    num_stable = pd.read_csv(stable_path, sep='\t', header=0).shape[0]

    diff_path = os.path.join(bed_gene_folder, 'deseq2_{}_diff_body.bed'.format(comparison))
    df = pd.read_csv(diff_path, sep='\t', header=0)
    num_diff = df.shape[0]

    # Raw string: '\w' in a plain literal is an invalid escape sequence.
    # expand=False returns a Series, which is what a single column needs.
    df['de_name'] = df['name'].str.extract(r'(?P<ENSID>\w+)', expand=False)
    df['gene_type'] = 'DIFF'
    df['de_log2fc'] = df['log2fc']
    df['gene_length'] = df['end'] - df['start']

    return num_diff, num_stable, df[['de_name', 'gene_type', 'de_log2fc', 'gene_length']]


def get_isect_headers(hsp_header):
    """Return column headers for the gene, ensreg and enhancer intersect tables.

    Each header is ``hsp_header`` followed by the prefixed annotation columns
    (``de_``, ``rgb_`` and ``enh_`` respectively) and a trailing ``overlap``
    column. The gene table additionally has ``gene_type`` directly after the
    HSP columns.

    Args:
        hsp_header (list of str): HSP column names (the caller in this file
            passes ``'hsp_'``-prefixed names).

    Returns:
        tuple of list: ``(gene_table_header, ensreg_table_header,
        enh_table_header)``.
    """
    def prefixed(prefix, column_spec):
        # whitespace-separated column spec -> list of prefixed column names
        return [prefix + column for column in column_spec.split()]

    de_cols = prefixed('de_', "chrom  start   end name    log2fc  strand  symbol  pv_adj")
    rgb_cols = prefixed('rgb_', "chrom start end name score strand feature")
    enh_cols = prefixed('enh_', "chrom start end GHid enhancer_score is_elite cluster_id name symbol assoc_score enh_gene_dist")

    return (
        hsp_header + ['gene_type'] + de_cols + ['overlap'],
        hsp_header + rgb_cols + ['overlap'],
        hsp_header + enh_cols + ['overlap'],
    )
    
    
def load_expression_data(gene_names, c1, c2):
    """Collect the per-gene minimum TPM over the samples of two conditions.

    Every table in the HDF5 store ``tpm_file`` whose key ends with ``c1`` or
    ``c2`` is restricted to the rows whose index is in ``gene_names``. The
    tables are concatenated column-wise (aligned on the index), and the
    row-wise minimum over all their columns is taken (NaN is skipped).

    Args:
        gene_names: gene identifiers matched against the table index.
        c1, c2 (str): key suffixes to load. In this file they are the two
            sides of a comparison label, e.g. ``'HG'`` and ``'He'`` for
            ``'HG_vs_He'``.

    Returns:
        pandas.DataFrame: columns ``de_name`` (the index value) and ``tpm``.

    Raises:
        ValueError: if no key in ``tpm_file`` ends with ``c1`` or ``c2``.
    """
    tables = []
    with pd.HDFStore(tpm_file, 'r') as hdf:
        for k in hdf.keys():
            if k.endswith(c1) or k.endswith(c2):
                data = hdf[k]
                tables.append(data.loc[data.index.isin(gene_names), :])
    if not tables:
        raise ValueError('No TPM tables with key suffix {} or {} in {}'.format(c1, c2, tpm_file))

    merged = pd.concat(tables, axis=1, ignore_index=False)
    # Take the minimum BEFORE adding the string name column: a row-wise min
    # over mixed str/float columns raises in pandas >= 2.0 (numeric_only=False).
    merged['tpm'] = merged.min(axis=1)
    merged['de_name'] = merged.index
    return merged[['de_name', 'tpm']]
    
    
def cache_gene_ovl_data(rootfolder, cache_file):
    """Build per-gene HSP overlap tables and store them in an HDF5 cache.

    Every file in ``rootfolder`` is read as a tab-separated HSP/gene
    intersection table whose first line is a header. File names are parsed
    like ``deep_hsp_ovl_degenes_cmm18_HG_vs_He_emission.tsv``:

    * ``_``-field 4 is the tool
    * the last field is the scoring
    * fields 5-7 form the comparison

    For each file, the table is processed as follows:

    * DIFF genes without any HSP hit are added as placeholder rows with
      ``hsp_name == 'empty'``.
    * DIFF genes get a ``de_group`` by absolute log2 fold change:
      1 = below the 25th percentile, 3 = above the 75th, 2 = in between.
      STABLE genes get 0.
    * Minimum TPM values (DIFF genes only) and the number of HSP rows per
      gene (``count_hsp_gene``) are attached.
    * The result is stored under the key ``<tool>/<scoring>/<comparison>``.

    Args:
        rootfolder (str): folder with the intersection tables.
        cache_file (str): path of the HDF5 file; opened with mode 'w', so an
            existing file is overwritten.

    Returns:
        str: ``cache_file``.

    Raises:
        ValueError: if a processed table still contains NA entries.
    """
    col_select = ['hsp_name', 'hsp_nat_score_lnorm',
                  'gene_type', 'de_name', 'de_log2fc',
                  'overlap', 'gene_length']

    hsp_header, gene_table_header = None, None

    with pd.HDFStore(cache_file, 'w') as hdf:
        for tsv in os.listdir(rootfolder):
            fpath = os.path.join(rootfolder, tsv)
            if hsp_header is None:
                # HSP header taken from the first file: first field renamed
                # to 'chrom', all fields prefixed with 'hsp_'
                with open(fpath, 'r') as table:
                    hd = table.readline().strip().split('\t')
                    hsp_header = ['chrom'] + hd[1:]
                    hsp_header = ['hsp_' + h for h in hsp_header]
                    gene_table_header, _, _ = get_isect_headers(hsp_header)
            # deep_hsp_ovl_degenes_cmm18_HG_vs_He_emission.tsv
            infos = tsv.split('.')[0].split('_')
            tool, scoring = infos[4], infos[-1]
            comparison = '_'.join(infos[5:8])
            num_diff, num_stable, all_de_genes = get_gene_counts(comparison)

            df = pd.read_csv(fpath, sep='\t', names=gene_table_header, skiprows=1)
            # remove HSPs not overlapping any gene
            df = df.loc[df['de_name'] != '.', :].copy()
            # raw string regex; expand=False yields a Series for the column
            df['de_name'] = df['de_name'].str.extract(r'(?P<ENSID>\w+)', expand=False)
            df['gene_length'] = df['de_end'] - df['de_start']
            df = df.loc[:, col_select].copy()

            # add DIFF genes without any HSP overlap as placeholder rows
            non_hits = all_de_genes.loc[~all_de_genes['de_name'].isin(df['de_name']), :].copy()
            non_hits['hsp_name'] = 'empty'
            df = pd.concat([df, non_hits], ignore_index=False, axis=0, sort=True)
            df.fillna(0, inplace=True)

            df['de_group'] = -1
            df['tpm_group'] = 0
            df['tpm'] = -1
            # NOTE(review): when non_hits is non-empty, this also counts the
            # 'empty' placeholder as an HSP - confirm this is intended.
            df['total_hsp_num'] = df['hsp_name'].unique().size
            df.loc[df['gene_type'] == 'STABLE', 'de_group'] = 0

            # treat DE genes separately for convenience
            diffs = df.loc[df['gene_type'] == 'DIFF', :].copy()
            diffs['de_log2fc'] = diffs['de_log2fc'].astype(np.float32).abs()
            # quartile split of |log2fc| (computed over rows, i.e. genes with
            # several HSP hits contribute several times)
            lower, upper = np.percentile(diffs['de_log2fc'], [25, 75])
            diffs.loc[diffs['de_log2fc'] < lower, 'de_group'] = 1
            diffs.loc[diffs['de_log2fc'] > upper, 'de_group'] = 3
            diffs.loc[diffs['de_group'] == -1, 'de_group'] = 2
            diffs.drop('tpm', axis=1, inplace=True)

            # load TPM values for genes
            diff_tpm = load_expression_data(diffs['de_name'], infos[5], infos[7])
            diffs = diffs.merge(diff_tpm, on=['de_name'], how='outer')

            # count HSP overlaps per gene (= rows per gene); placeholder rows
            # of genes without any hit are reset to 0. The frame is built
            # explicitly because the column name produced by
            # value_counts().to_frame() changed in pandas 2.0 ('count').
            gene_hits = df['de_name'].value_counts()
            hsp_gene_ovl = pd.DataFrame({'de_name': gene_hits.index,
                                         'count_hsp_gene': gene_hits.values})
            hsp_gene_ovl.loc[hsp_gene_ovl['de_name'].isin(non_hits['de_name']), 'count_hsp_gene'] = 0

            # merge data back together
            diffs.drop(['de_log2fc'], axis=1, inplace=True)
            df.drop(['de_log2fc'], axis=1, inplace=True)
            df = df.loc[df['gene_type'] != 'DIFF', :].copy()
            df = pd.concat([df, diffs], axis=0, ignore_index=False, sort=True)
            df.reset_index(drop=True, inplace=True)
            df.loc[df['gene_type'] == 'DIFF', 'total_num'] = num_diff
            df.loc[df['gene_type'] == 'STABLE', 'total_num'] = num_stable

            # joins on the shared 'de_name' column
            df = df.merge(hsp_gene_ovl, how='outer')

            # explicit check: an assert would be stripped under python -O
            if df.isnull().values.any():
                raise ValueError('NA entries')

            cache_path = os.path.join(tool, scoring, comparison)
            hdf.put(cache_path, df, format='table')

    return cache_file
    
    
def _highest_cap_y(boxplot_artists):
    """Return the largest y value over all whisker caps of a boxplot, at least 0."""
    top = 0
    for cap in boxplot_artists['caps']:
        top = max(top, max(cap.get_ydata()))
    return top


def create_boxplot(datapoints, fkey):
    """Draw gene length and minimum-TPM boxplots per HSP-overlap bin.

    One box per distinct value of ``datapoints['bin']`` (sorted). The top
    panel shows log10 gene length; the bottom panel shows minimum TPM. The
    x tick labels read ``0, 1, 2, 3+`` with the 95th TPM percentile of the
    bin in parentheses.

    Args:
        datapoints (pandas.DataFrame): needs columns ``bin``, ``gene_length``
            and ``tpm``.
        fkey (str): figure key used in the title ``'Figure <fkey>'``.

    Returns:
        tuple: ``(fig, extra_artists)``. ``extra_artists`` is always an empty
        list; it is passed on to ``savefig(..., extra_artists=...)`` by the
        caller.
    """
    boxcolor = 'dimgrey'
    medcolor = 'grey'
    box_style = dict(sym="",
                     medianprops={'color': medcolor, 'linewidth': 2},
                     boxprops={'color': boxcolor, 'linewidth': 1},
                     whiskerprops={'color': boxcolor, 'linewidth': 1},
                     capprops={'color': boxcolor, 'linewidth': 1})

    len_dist = []
    tpm_dist = []
    pct_scores = []
    for b in sorted(datapoints['bin'].unique()):
        sub = datapoints.loc[datapoints['bin'] == b, :]
        len_dist.append(np.log10(sub['gene_length']))
        tpms = sub['tpm']
        # NOTE(review): rounds to 2 decimals, then int() truncates; if
        # integer rounding was intended, int(np.round(x)) is the right call.
        pct_scores.append(int(np.round(stats.scoreatpercentile(tpms, 95), 2)))
        tpm_dist.append(tpms)

    fig, (ax_top, ax_bottom) = plt.subplots(figsize=fig_sizes['one_col']['double'],
                                            nrows=2, ncols=1, sharex=True, sharey=False,
                                            gridspec_kw={'height_ratios': [1, 2]})
    fig.subplots_adjust(hspace=0.005)

    # top panel: gene length distribution (log10)
    bb_len = ax_top.boxplot(len_dist, **box_style)
    top_max_y = _highest_cap_y(bb_len)
    ax_top.set_ylim(1, np.ceil(top_max_y) + 0.5)
    ax_top.set_ylabel('Gene length\n(log10)', fontsize=14)
    top_yticks = np.arange(2, top_max_y, step=2, dtype=np.int32)
    ax_top.set_yticks(top_yticks)
    ax_top.set_yticklabels(list(map(str, top_yticks)), fontsize=12)
    ax_top.spines['bottom'].set_visible(False)
    fig_title = ax_top.set_title('Figure {}'.format(fkey), fontsize=16)
    fig_title.set_position((0.15, 1.01))

    # bottom panel: minimum TPM distribution
    bb_tpm = ax_bottom.boxplot(tpm_dist, **box_style)
    bottom_max_y = _highest_cap_y(bb_tpm) + 0.5
    ax_bottom.set_ylim(-0.1, bottom_max_y)
    bottom_yticks = np.arange(0, bottom_max_y, step=2, dtype=np.int16)
    ax_bottom.set_yticks(bottom_yticks)
    ax_bottom.set_yticklabels(list(map(str, bottom_yticks)), fontsize=12)
    ax_bottom.set_ylabel('Minimum TPM', fontsize=14, labelpad=10)

    # Labels assume the 4 bins 0 / 1 / 2 / 3+ hits are all present; the last
    # label previously showed the bin index instead of the TPM percentile.
    xlabels = []
    for c, s in zip(range(4), pct_scores):
        hits = '3+' if c == 3 else str(c)
        xlabels.append('{}\n({})'.format(hits, s))

    ax_bottom.set_xticklabels(xlabels, fontsize=12)
    ax_bottom.set_xlabel('HSP hits per DE gene\n(TPM 95th %ile)', fontsize=14)
    return fig, []

            
def _dump_mirna_gene_list(data, bin_id, kind, seg, scoring, comp, ulen):
    """Write the unique genes of one overlap bin as a GeneTrail input list.

    Prints the number of genes and the 25th/50th/75th percentiles of their
    mean 3'UTR lengths (looked up in ``ulen``). It then writes the gene names,
    one per line without a header, to
    ``<genetrail_out>/mirna_<kind>_<seg>_<scoring>_<comp>.tsv``.
    """
    out_path = os.path.join(genetrail_out, 'mirna_{}_{}_{}_{}.tsv'.format(kind, seg, scoring, comp))
    genes = data.loc[data['bin'] == bin_id, ['de_name']].drop_duplicates()
    print(genes.shape[0])
    utr_lens = ulen.loc[ulen['name'].isin(genes['de_name']), 'utr_length']
    print('{}: '.format(kind.upper()), np.percentile(utr_lens, [25, 50, 75]))
    genes.to_csv(out_path, sep='\t', header=False, index=False, mode='w')


def plot_hsp_gene_ovl(*, make_plots=False):
    """Export miRNA enrichment gene lists per comparison, optionally plotting.

    Steps:

    * Ensure the HSP/gene overlap cache exists. It is rebuilt via
      ``cache_gene_ovl_data`` when missing or smaller than 10 kB.
    * Compute mean 3'UTR lengths per gene name.
    * For every cached comparison of the 'cmm18' and 'ecs18' segmentations
      with 'emission' scoring, bin the DIFF genes by their number of
      overlapping HSPs: bin 1 = none, 2 = one, 3 = two, 4 = three or more.
    * Write the genes of bin 1 (foreground) and bin 2 (background) as
      GeneTrail input lists.

    Args:
        make_plots (bool): additionally draw boxplots per segmentation and,
            if ``save_figures`` is set, save them as PDF/PNG. Defaults to
            ``False`` (gene lists only).

    Returns:
        int: always 0.
    """
    cache_file = os.path.join(cache_dir, '{}_plot_hsp_gene-ovl_tpm-bin.h5'.format(date))
    if not os.path.isfile(cache_file) or os.stat(cache_file).st_size < 10e3:
        cache_gene_ovl_data(hsp_gene_ovl_folder, cache_file)
    else:
        print('Assuming cache file is valid')

    utr_file = '/home/pebert/work/code/github/gencode_regions/gtf_out/3UTRs.bed'
    utrs = pd.read_csv(utr_file, header=None, sep='\t',
                       names=['chrom', 'start', 'end', 'name', 'score', 'strand'])
    utrs['utr_length'] = utrs['end'] - utrs['start']
    ulen = utrs.groupby(['name'], as_index=False)['utr_length'].mean()

    with pd.HDFStore(cache_file, 'r') as hdf:
        cached_keys = list(hdf.keys())
        # 'ecs10' segmentation and 'replicate' scoring are not processed
        for seg in ['cmm18', 'ecs18']:
            for scoring in ['emission']:
                load_keys = [k for k in cached_keys if seg in k and scoring in k]

                collector = []
                for k in load_keys:
                    comp = k.split('/')[-1]
                    print(comp)
                    data = hdf[k]
                    # .copy(): a 'bin' column is added below
                    data = data.loc[data['gene_type'] == 'DIFF',
                                    ['de_name', 'tpm', 'count_hsp_gene', 'gene_length']].copy()

                    # right=False: 0 hits -> bin 1, 1 -> 2, 2 -> 3, >= 3 -> 4
                    bins = [0, 1, 2, 3, data['count_hsp_gene'].max() + 1]
                    data['bin'] = np.digitize(data['count_hsp_gene'],
                                              bins=bins, right=False)

                    # miRNA enrichment input: no overlap (foreground) vs.
                    # exactly one overlap (background)
                    _dump_mirna_gene_list(data, 1, 'fg', seg, scoring, comp, ulen)
                    _dump_mirna_gene_list(data, 2, 'bg', seg, scoring, comp, ulen)

                    collector.append(data)

                if not make_plots or not collector:
                    continue
                plot_data = pd.concat(collector, axis=0, ignore_index=False)

                if seg == 'ecs18':
                    fk = 'SX'
                    subfolder = fig_supp
                else:
                    fk = 'X'
                    subfolder = fig_main

                fig, exart = create_boxplot(plot_data, fk)

                if save_figures:
                    outname = 'fig_{}_hsp-gene_tpm_{}_{}'.format(fk, seg, scoring)

                    out_pdf = os.path.join(subfolder, outname + '.pdf')
                    fig.savefig(out_pdf, bbox_inches='tight', extra_artists=exart)
                    out_png = os.path.join(subfolder, outname + '.png')
                    fig.savefig(out_png, bbox_inches='tight', extra_artists=exart, dpi=resolution[res])
    return 0



    
# Execute this cell's main routine only when the run_plot_hsp_gene_ovl
# toggle set near the top of the cell is enabled.
if run_plot_hsp_gene_ovl:
    plot_hsp_gene_ovl()


Assuming cache file is valid
HG_vs_He
3164
FG:  [ 253.29166667  524.87121212 1104.525     ]
1096
BG:  [ 287.46052632  626.6        1391.875     ]
HG_vs_Ma
2564
FG:  [ 241.66666667  493.         1054.66666667]
1906
BG:  [ 297.    603.25 1298.  ]
HG_vs_Mo
3467
FG:  [ 249.89285714  512.04166667 1077.5       ]
2103
BG:  [ 289.375  600.    1307.125]
He_vs_Ma
2551
FG:  [ 246.40625     518.375      1140.45833333]
1586
BG:  [ 295.90384615  578.5        1259.125     ]
He_vs_Mo
3001
FG:  [ 258.    526.75 1133.  ]
1806
BG:  [ 289.    594.5  1229.05]
Ma_vs_Mo
2159
FG:  [ 298.71296296  612.125      1266.375     ]
194
BG:  [ 341.19047619  721.98076923 1342.75      ]
HG_vs_He
4147
FG:  [ 268.4   568.25 1197.  ]
500
BG:  [ 302.83333333  624.35       1308.        ]
HG_vs_Ma
3498
FG:  [ 259.1875      524.58333333 1139.18333333]
1359
BG:  [ 314.5  638.4 1373. ]
HG_vs_Mo
4586
FG:  [ 259.45833333  537.65       1140.75      ]
1401
BG:  [ 315.57954545  648.28571429 1396.06818182]
He_vs_Ma
3187
FG:  [ 266.475