In [51]:
%matplotlib inline

import os as os
import collections as col
import json as js
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.patches as mpatch

# Global matplotlib text defaults; these apply to every figure created after
# this point in the session.
mpl.rcParams['font.sans-serif'] = 'Arial'
mpl.rcParams['font.size'] = 14

# Bar color per model type; the keys are looked up by extract_ordering().
model_colors = {'seq': 'blue', 'hist': 'orange', 'full': 'green'}

# NOTE(review): hard-coded, machine-specific base folder - adjust before
# running elsewhere. Result files are read from <basedir>/input and the
# figures are written to <basedir>/output.
basedir = '/home/pebert/temp/creepiest/figures/expvalue'
inputtop = os.path.join(basedir, 'input')
output = os.path.join(basedir, 'output')

# Comparison symbols interpolated into the y-axis tick labels.
sym_grt = u'\u003E'  # ">" (greater-than sign)
sym_geq = u'\u2265'  # greater-than-or-equal-to sign

def extract_cv_std(metadata):
    """Return the CV score standard deviation of the selected model.

    The selected hyper-parameters are read from
    ``metadata['model_info']['params']``; the entries of
    ``metadata['training_info']['cv_scores']`` are then scanned for the
    first one whose ``n_estimators`` and ``min_samples_split`` both match.

    :param metadata: parsed model metadata (nested dicts)
    :return: the ``['scores']['std']`` value of the matching grid entry
    :raises ValueError: if no grid entry matches the selected parameters
    :raises KeyError: if an expected key is missing from the metadata
    """
    # this will break as soon as different CV grids are saved
    selected = metadata['model_info']['params']
    n_trees = selected['n_estimators']
    min_split = selected['min_samples_split']
    for grid_entry in metadata['training_info']['cv_scores']:
        candidate = grid_entry['params']
        if candidate['n_estimators'] != n_trees:
            continue
        if candidate['min_samples_split'] != min_split:
            continue
        return grid_entry['scores']['std']
    raise ValueError('Could not extract CV std')
            

def extract_performance(fp):
    """Read one JSON result file and return its performance value.

    If the file has ``['training_info']['best_score']`` that score is
    returned together with the matching CV standard deviation (see
    extract_cv_std). If a KeyError occurs on that path, the value of
    ``['testing_info']['performance']`` is returned instead.

    :param fp: path to a JSON metadata file
    :return: tuple ``(performance, cv_std)``; ``cv_std`` is None whenever
        the training path did not complete
    :raises KeyError: if neither the training nor the testing keys exist
    :raises ValueError: propagated from extract_cv_std if no CV grid entry
        matches the selected model parameters
    """
    # Context manager so the file handle is closed deterministically
    # instead of being left to the garbage collector.
    with open(fp, 'r') as handle:
        data = js.load(handle)
    cvstd = None
    try:
        perf = data['training_info']['best_score']
        cvstd = extract_cv_std(data)
    except KeyError:
        # no (complete) training record: fall back to the test performance
        perf = data['testing_info']['performance']
    return perf, cvstd

def collect_trained_models(inputdir):
    """Group the files of a folder by their model base id.

    The base id is the file name with its last four dot-separated
    components removed; every entry returned by ``os.listdir(inputdir)``
    is grouped (no recursion, no filtering).

    :param inputdir: folder containing the trained model files
    :return: defaultdict(list) mapping base id -> list of full paths
    :raises AssertionError: if the folder is empty
    """
    grouped = col.defaultdict(list)
    for fname in os.listdir(inputdir):
        base_id = fname.rsplit('.', 4)[0]
        grouped[base_id].append(os.path.join(inputdir, fname))
    assert grouped, 'No traindata collected'
    return grouped

def collect_test_data(inputdir):
    """Recursively group test result files by the model they were made with.

    Only file names containing the separator ``_with_`` exactly once are
    considered; the part after the separator is taken as the model file
    name, and its base id (name minus the last four dot-separated
    components) is used as the grouping key.

    :param inputdir: top-level folder that is walked recursively
    :return: defaultdict(list) mapping model base id -> list of full paths
    :raises AssertionError: if no matching file was found
    """
    grouped = col.defaultdict(list)
    for folder, _subdirs, filenames in os.walk(inputdir):
        for fname in filenames:
            pieces = fname.split('_with_')
            if len(pieces) != 2:
                # separator missing or ambiguous: not a test result file
                continue
            base_id = pieces[1].rsplit('.', 4)[0]
            grouped[base_id].append(os.path.join(folder, fname))
    assert grouped, 'No testdata collected'
    return grouped

def extract_ordering(fpath):
    """Build the sort key (and bar color) for one result file.

    The key has four components:

    1. data stage from the parent folder: 0 if it ends with ``/train``,
       1 if it ends with ``/true``, 2 otherwise
    2. 0 if the file name contains ``grt0``, 1 otherwise
    3. model type rank from the file name: 0 for ``seq``, 1 for ``hist``,
       2 otherwise (treated as the full model)
    4. the color registered in ``model_colors`` for that model type

    :param fpath: path to a result file
    :return: tuple ``(stage, threshold, model_rank, color)``
    """
    folder, fname = os.path.split(fpath)

    if folder.endswith('/train'):
        stage = 0
    elif folder.endswith('/true'):
        stage = 1
    else:
        stage = 2

    threshold = 0 if 'grt0' in fname else 1

    # first matching tag wins; anything else counts as the full model
    for model_rank, model_type in enumerate(('seq', 'hist')):
        if model_type in fname:
            break
    else:
        model_rank, model_type = 2, 'full'

    return stage, threshold, model_rank, model_colors[model_type]

def format_title(infile, model=True):
    """Derive a one-line figure title from a result file name.

    Example file name:
    ENCSR000EYP_hg19_H1hESC_mRNA_BWCALT.genes.7xBBBRD.to.mm9.rfcls.seq.uw.grt0.json

    The first four underscore-separated fields and the dot-separated
    fields after the first dot are used. The assembly is the field at
    dot-index 4 if index 3 holds the direction keyword (``to`` for model
    titles, ``from`` for test titles), else the field at index 3.

    :param infile: path or name of the result file
    :param model: if true build the ``Train:`` title, else the ``Test:`` one
    :return: the formatted title string
    :raises IndexError: if the file name has too few fields
    """
    fname = os.path.basename(infile)
    if model:
        prefix, direction = 'Train:', 'to'
    else:
        prefix, direction = 'Test:', 'from'

    dot_fields = fname.split('.')[1:]
    sample_fields = fname.split('_')[:4]

    assembly = dot_fields[4] if dot_fields[3] == direction else dot_fields[3]
    # first two characters of the second dot field, e.g. "7x" of "7xBBBRD"
    # (presumably the number of histone marks - confirm)
    mark_count = dot_fields[1][:2] + 'hm'

    title_parts = [prefix, sample_fields[0], sample_fields[1], mark_count,
                   sample_fields[2], sample_fields[3],
                   '[{} {}]'.format(direction, assembly)]
    return ' '.join(title_parts)
   

def annotate_files(allfiles):
    """Sort result files into plotting order and collect their plot data.

    Files are sorted by ``(extract_ordering(path), path)``. The model title
    is derived from the first file of the sorted list and the sample title
    from the last one.

    :param allfiles: iterable of result file paths (must not be empty)
    :return: tuple ``(yvals, colors, model_title, sample_title)`` where
        ``yvals`` holds one extract_performance() result per file and
        ``colors`` the matching bar colors, both in sorted order
    """
    ranked = [(extract_ordering(path), path) for path in allfiles]
    ranked.sort()
    model_title = format_title(ranked[0][1])
    sample_title = format_title(ranked[-1][1], False)
    colors = [sort_key[3] for sort_key, _ in ranked]
    yvals = [extract_performance(path) for _, path in ranked]
    return yvals, colors, model_title, sample_title
    

def make_barchart(trainfiles, testfiles, outdir):
    """Plot model performances as a horizontal bar chart and save the figure.

    One bar is drawn per result file, in the order given by
    annotate_files(). Bars with a CV standard deviation get a one-sided
    (upper) error bar. The figure is written to ``outdir`` as PNG and SVG;
    the output name is the base name of ``testfiles[0]`` without its last
    two dot-separated components.

    The fixed layout (axis limits, tick positions, separator lines, text
    labels) is laid out for 18 bars, which the assertions below enforce.

    :param trainfiles: list of the 6 trained model metadata files
    :param testfiles: list of the 12 matching test result files; if empty,
        nothing is plotted
    :param outdir: folder the figure files are written to
    :return: None
    :raises AssertionError: if the number of files is not 6 + 12
    """
    if not testfiles:
        return
    assert (len(trainfiles) * 2) == len(testfiles), 'Too many files: {}\nvs\n{}'.format(trainfiles, testfiles)
    assert len(trainfiles) == 6, 'Unexpected number of trained models: {}'.format(trainfiles)
    outname = os.path.basename(testfiles[0]).rsplit('.', 2)[0]
    bar_alpha = 0.75
    scores, colors, mt, st = annotate_files(trainfiles + testfiles)
    # one bar every 0.65 units, starting at 0.25
    positions = [0.25 + 0.65 * i for i in range(len(scores))]
    ylim, xlim = (0, 12.25), (-0.3, 1.01)
    fig, ax = plt.subplots(figsize=(5, 11))
    try:
        ax.spines['top'].set_color('none')
        ax.spines['right'].set_color('none')
        ax.set_title('{}\n{}'.format(mt, st), y=1.05)
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.set_xlabel('RandomForest R2')
        error_config = {'ecolor': 'black', 'capsize': 3,
                        'barsabove': True}
        for (perf, cvstd), ypos, color in zip(scores, positions, colors):
            if cvstd is None:
                ax.barh(ypos, perf, height=0.5, color=color, alpha=bar_alpha)
            else:
                # asymmetric error: nothing below, one CV std above
                ax.barh(ypos, perf, height=0.5, color=color,
                        xerr=[[0], [cvstd]], alpha=bar_alpha,
                        error_kw=error_config)

        ax.set_yticks([1, 3, 5, 7, 9, 11])
        ax.set_yticklabels(['TPM {} 0'.format(sym_grt),
                            'TPM {} 1'.format(sym_geq),
                            'True TPM {} 0'.format(sym_grt),
                            'True TPM {} 1'.format(sym_geq),
                            'Est. TPM {} 0'.format(sym_grt),
                            'Est. TPM {} 1'.format(sym_geq)], fontsize=14)
        # Booleans instead of 'on'/'off': current Matplotlib treats any
        # non-empty string as True, which would draw the "off" ticks too.
        ax.tick_params(axis='x', direction='out', color='black', width=1,
                       length=5, bottom=False, labelbottom=True,
                       labeltop=False, top=False)
        ax.tick_params(axis='y', direction='out', color='black', width=1,
                       length=5, labelleft=True, left=True,
                       labelright=False, right=False)
        # legend colors come from the same table that colors the bars
        seq_leg = mpatch.Patch(color=model_colors['seq'], label='Seq.')
        hist_leg = mpatch.Patch(color=model_colors['hist'], label='Hist.')
        full_leg = mpatch.Patch(color=model_colors['full'], label='Full')
        ax.legend(handles=[full_leg, hist_leg, seq_leg],
                  bbox_to_anchor=(1.4, 0.6))
        # dashed separators between the three groups of six bars
        for ysep in (4.05, 7.95):
            ax.axhline(y=ysep, xmin=-0.25, xmax=1, linewidth=2,
                       color='gray', linestyle='dashed', alpha=0.5)
        ax.text(0.85, 11, 'Testing (Est.)', rotation=-90)
        ax.text(0.85, 7, 'Testing (True)', rotation=-90)
        ax.text(0.85, 3, 'Training (CV)', rotation=-90)
        for ext in ['.png', '.svg']:
            outpath = os.path.join(outdir, outname + ext)
            fig.savefig(outpath, dpi=300, bbox_inches='tight', pad_inches=0.25)
    finally:
        # close exactly this figure, even if drawing or saving failed
        plt.close(fig)
    return
    

# run notebook

trained_models = collect_trained_models(os.path.join(inputtop, 'train'))
testdata = collect_test_data(os.path.join(inputtop, 'apply'))

# make_barchart() needs exactly 12 test files per figure (twice the number
# of the 6 trained models), so the test files are processed in chunks of 12.
files_per_chart = 12

for k, v in trained_models.items():
    # .get() instead of indexing: do not insert empty entries into the
    # defaultdict for models that have no test data
    all_testfiles = sorted(testdata.get(k, []), key=os.path.basename)
    if not all_testfiles:
        print('No testdata: {}'.format(k))
        continue
    if len(all_testfiles) < files_per_chart:
        # explicit error; a bare "raise" has no active exception here
        raise RuntimeError('Too few testdata for {}: expected a multiple of {}, '
                           'found {}'.format(k, files_per_chart, len(all_testfiles)))
    assert len(all_testfiles) % files_per_chart == 0, \
        'Number of test files for {} is not a multiple of {}: {}'.format(
            k, files_per_chart, len(all_testfiles))
    # always plot from the sorted list so chunking and output names are
    # deterministic, also when there is exactly one chunk
    for idx in range(0, len(all_testfiles), files_per_chart):
        make_barchart(v, all_testfiles[idx:idx + files_per_chart], output)

print('Done')


No testdata: ENCSR000EYR_hg19_HepG2_mRNA_BWCALT.genes.7xBBBRD.to.canFam3.rfreg
No testdata: ENCSR962TBJ_hg19_H1hESC_mRNA_BWCALT.genes.7xBBBRD.to.canFam3.rfreg
No testdata: ENCSR077AZT_hg19_GM12878_mRNA_BWCALT.genes.7xBBBRD.to.bosTau7.rfreg
No testdata: ENCSR000EYR_hg19_HepG2_mRNA_BWCALT.genes.7xBBBRD.to.susScr2.rfreg
No testdata: ENCSR000CHA_mm9_kidney_mRNA_BRUCSD.genes.5xBRUCSD.to.hg19.rfreg
No testdata: ENCSR000EYP_hg19_H1hESC_mRNA_BWCALT.genes.7xBBBRD.to.bosTau7.rfreg
No testdata: SRX211657_mm9_kidney_mRNA_EWHA.genes.5xBRUCSD.to.hg19.rfreg
No testdata: ENCSR000EYP_hg19_H1hESC_mRNA_BWCALT.genes.7xBBBRD.to.canFam3.rfreg
No testdata: ENCSR000EYO_hg19_K562_mRNA_BWCALT.genes.7xBBBRD.to.bosTau7.rfreg
No testdata: ENCSR000EYN_hg19_GM12878_mRNA_BWCALT.genes.7xBBBRD.to.canFam3.rfreg
No testdata: ENCSR000CHM_mm9_MEL_mRNA_BRUCSD.genes.6xBRUCSD.to.susScr2.rfreg
No testdata: ENCSR077AZT_hg19_GM12878_mRNA_BWCALT.genes.7xBBBRD.to.canFam3.rfreg
No testdata: ENCSR000EYO_hg19_K562_mRNA_BWCALT.genes.7