In [12]:
%matplotlib inline

import os as os
import collections as col
import json as js
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.patches as mpatch

# Global matplotlib text settings shared by every figure in this notebook.
mpl.rcParams.update({
    'font.sans-serif': 'Arial',
    'font.size': 14,
})

# Bar/legend colour per model type; the keys are matched against file names
# in extract_ordering.
model_colors = dict(seq='blue', hist='orange', full='green')

# Input/output folders, all located below a single base directory.
basedir = '/home/pebert/temp/creepiest/figures/expclass'
inputtop, output = (os.path.join(basedir, sub) for sub in ('input', 'output'))

# Unicode symbols used in the x tick labels: "greater than" and
# "greater than or equal to".
sym_grt = u'\u003E'
sym_geq = u'\u2265'

def extract_cv_std(metadata):
    """Return the CV score standard deviation of the selected model.

    The final model's ``n_estimators`` and ``min_samples_split`` values
    (``metadata['model_info']['params']``) are looked up in the grid stored
    under ``metadata['training_info']['cv_scores']``; the ``scores['std']``
    entry of the first grid point with both values equal is returned.

    Raises ValueError if no grid point matches, KeyError if an expected
    key is missing.
    """
    # this will break as soon as different CV grids are saved
    final_params = metadata['model_info']['params']
    n_trees = final_params['n_estimators']
    min_split = final_params['min_samples_split']
    for grid_entry in metadata['training_info']['cv_scores']:
        candidate = grid_entry['params']
        if candidate['n_estimators'] != n_trees or candidate['min_samples_split'] != min_split:
            continue
        return grid_entry['scores']['std']
    raise ValueError('Could not extract CV std')
            

def extract_performance(fp):
    """Read a JSON metadata file and return ``(performance, cv_std)``.

    If the file has ``training_info`` with a ``best_score``, that score is
    returned together with the CV standard deviation obtained from
    ``extract_cv_std``. If any of the required keys is missing (KeyError),
    the value under ``testing_info`` -> ``performance`` is returned instead
    and ``cv_std`` is None.

    :param fp: path to the JSON file
    :return: tuple (performance, cv_std or None)
    """
    # Context manager makes sure the file handle is closed again; the
    # previous js.load(open(...)) left that to the garbage collector.
    with open(fp, 'r') as dump:
        data = js.load(dump)
    cvstd = None
    try:
        perf = data['training_info']['best_score']
        cvstd = extract_cv_std(data)
    except KeyError:
        # no (complete) training record, treat the file as a test run
        perf = data['testing_info']['performance']
    return perf, cvstd

def collect_trained_models(inputdir):
    """Group the files in ``inputdir`` by their base model id.

    The base id is the file name with (at most) its last four
    dot-separated fields removed. Returns a defaultdict(list) mapping
    base id -> full paths of all files sharing that id.
    """
    by_model = col.defaultdict(list)
    for fname in os.listdir(inputdir):
        base_id = fname.rsplit('.', 4)[0]
        full_path = os.path.join(inputdir, fname)
        by_model[base_id].append(full_path)
    return by_model

def collect_test_data(inputdir):
    """Group test result files in ``inputdir`` by the model applied to them.

    Only file names containing the separator ``_with_`` exactly once are
    considered; everything else is skipped. The part after the separator is
    reduced to the base model id (its last four dot-separated fields are
    removed) and used as key. Returns a defaultdict(list) mapping base
    id -> full paths.
    """
    by_model = col.defaultdict(list)
    for fname in os.listdir(inputdir):
        pieces = fname.split('_with_')
        if len(pieces) != 2:
            # separator missing or ambiguous
            continue
        base_id = pieces[1].rsplit('.', 4)[0]
        by_model[base_id].append(os.path.join(inputdir, fname))
    return by_model

def extract_ordering(fpath):
    """Build the sort key (and bar colour) for one result file.

    The returned 4-tuple is derived from substrings of the file's base name:

    0. 1 if the name contains '_with_', else 0
    1. 0 if the name contains 'grt0', else 1
    2. 0 for 'seq', 1 for 'hist', 2 for everything else
    3. the entry of ``model_colors`` belonging to position 2
    """
    fname = os.path.basename(fpath)
    phase_rank = 1 if '_with_' in fname else 0
    threshold_rank = 0 if 'grt0' in fname else 1
    # 'seq' is checked before 'hist'; anything else counts as full model
    if 'seq' in fname:
        model_rank, model_key = 0, 'seq'
    elif 'hist' in fname:
        model_rank, model_key = 1, 'hist'
    else:
        model_rank, model_key = 2, 'full'
    ordering = (phase_rank, threshold_rank, model_rank, model_colors[model_key])
    assert len(ordering) == 4, '?!.WTF.$'
    return ordering

def format_title(infile, model=True):
    """Assemble one line of the figure title from a result file name.

    Example file name:
    ENCSR000EYP_hg19_H1hESC_mRNA_BWCALT.genes.7xBBBRD.to.mm9.rfcls.seq.uw.grt0.json

    The first four '_'-separated fields and pieces of the '.'-separated
    fields after the first one are combined. With ``model`` true the line
    starts with 'Train:' and ends with '[to <assembly>]', otherwise it
    starts with 'Test:' and ends with '[from <assembly>]'.
    """
    fname = os.path.basename(infile)
    if model:
        prefix, keyword, bracket = 'Train:', 'to', '[to '
    else:
        prefix, keyword, bracket = 'Test:', 'from', '[from '
    dotted = fname.split('.')[1:]
    # the assembly follows the direction keyword if that sits at index 3,
    # otherwise index 3 itself is taken
    assembly = dotted[4] if dotted[3] == keyword else dotted[3]
    sample = fname.split('_')[:4]
    marks = ''.join(dotted[1][:2]) + 'hm'
    fields = [prefix, sample[0], sample[1], marks,
              sample[2], sample[3], bracket + assembly + ']']
    return ' '.join(fields)
   

def annotate_files(allfiles):
    """Sort result files for plotting and collect values, colours and titles.

    Files are ordered by the key from ``extract_ordering`` (ties broken by
    path). Returns a 4-tuple:

    - list of ``extract_performance`` results, in plotting order
    - list of bar colours (last element of each ordering key)
    - title line built from the first file in that order
    - title line built from the last file (``format_title`` with model=False)
    """
    ranked = sorted((extract_ordering(path), path) for path in allfiles)
    model_title = format_title(ranked[0][1])
    sample_title = format_title(ranked[-1][1], False)
    bar_colors = [key[3] for key, _ in ranked]
    performances = [extract_performance(path) for _, path in ranked]
    return performances, bar_colors, model_title, sample_title
    

def make_barchart(trainfiles, testfiles, outdir):
    """Draw one bar chart comparing training (CV) and test performance.

    Expects six training files and six matching test files. The twelve
    bars are arranged in four groups of three; values, colours and the
    two title lines come from ``annotate_files``. The figure is written
    as PNG and SVG to ``outdir``, named after the first test file with its
    last two dot-separated fields removed.

    :param trainfiles: list of paths to training metadata (JSON)
    :param testfiles: list of paths to test metadata (JSON); if empty,
        nothing is plotted
    :param outdir: folder the figure files are written to
    :return: None
    """
    if not testfiles:
        return
    assert len(trainfiles) == len(testfiles), 'Too many files: {}\nvs\n{}'.format(trainfiles, testfiles)
    assert len(trainfiles) == 6, 'Unexpected number of trained models: {}'.format(trainfiles)
    outname = os.path.basename(testfiles[0]).rsplit('.', 2)[0]
    # four groups of three bars, one empty slot between groups
    xvals = [1, 2, 3, 5, 6, 7, 9, 10, 11, 13, 14, 15]
    xlim, ylim = (0, 16), (0, 1.01)
    model_alpha = 0.75
    yvals, colors, mt, st = annotate_files(trainfiles + testfiles)
    fig, ax = plt.subplots(figsize=(8, 5))
    try:
        ax.spines['top'].set_color('none')
        ax.spines['right'].set_color('none')
        ax.set_title('{}\n{}'.format(mt, st), y=1.05)
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.set_ylabel('RandomForest AUC')
        # reference line at AUC 0.5; xmin/xmax are axes fractions (0..1),
        # not data coordinates
        ax.axhline(y=0.5, xmin=0., xmax=1., linewidth=2,
                   color='gray', linestyle='dashed', alpha=0.5)
        error_config = {'ecolor': 'black', 'capsize': 3,
                        'barsabove': True}
        for x, (perf, cvstd), c in zip(xvals, yvals, colors):
            if cvstd is None:
                # no CV information available: plain bar
                ax.bar(x, perf, width=0.5, color=c, yerr=None, alpha=model_alpha)
            else:
                # one-sided error bar: nothing below, CV std above
                ax.bar(x, perf, width=0.5, color=c, yerr=[[0], [cvstd]],
                       alpha=model_alpha, error_kw=error_config)

        x_adj = 0.225
        ax.set_xticks([v + x_adj for v in [2, 6, 10, 14]])
        ax.set_xticklabels(['TPM {} 0'.format(sym_grt),
                            'TPM {} 1'.format(sym_geq),
                            'TPM {} 0'.format(sym_grt),
                            'TPM {} 1'.format(sym_geq)], fontsize=14)
        # tick_params expects booleans; the strings 'on'/'off' are both
        # truthy in current matplotlib and would switch everything on
        ax.tick_params(axis='x', direction='out', color='black', width=1, length=5,
                       bottom=False, labelbottom=True, labeltop=False, top=False)
        ax.tick_params(axis='y', direction='out', color='black', width=1, length=5,
                       labelleft=True, left=True, labelright=False, right=False)
        # legend colours are taken from the same mapping as the bar colours
        seq_leg = mpatch.Patch(color=model_colors['seq'], label='Seq.')
        hist_leg = mpatch.Patch(color=model_colors['hist'], label='Hist.')
        full_leg = mpatch.Patch(color=model_colors['full'], label='Full')
        ax.legend(handles=[seq_leg, hist_leg, full_leg],
                  bbox_to_anchor=(1.25, 0.6))
        ax.text(2.5, 0.95, 'Training (CV)')
        ax.text(11, 0.95, 'Testing')
        for ext in ['.png', '.svg']:
            outpath = os.path.join(outdir, outname + ext)
            fig.savefig(outpath, dpi=300, bbox_inches='tight', pad_inches=0.25)
    finally:
        # close exactly this figure, also if plotting or saving failed
        plt.close(fig)
    return
    

# run notebook

trained_models = collect_trained_models(os.path.join(inputtop, 'train'))
testdata = collect_test_data(os.path.join(inputtop, 'apply'))

# One chart per trained model group and per block of six test files.
for k, v in trained_models.items():
    # Sort once and use the sorted list in both branches: make_barchart
    # names its output after the first test file, so passing the unsorted
    # directory listing made the file name depend on os.listdir order.
    all_testfiles = sorted(testdata[k])
    if len(all_testfiles) > 6:
        # several test data sets for the same model: plot them in
        # consecutive blocks of six
        assert len(all_testfiles) % 6 == 0, 'IDK'
        for idx in range(0, len(all_testfiles), 6):
            _ = make_barchart(v, all_testfiles[idx:idx+6], output)
    else:
        _ = make_barchart(v, all_testfiles, output)

print('Done')


Done
