In [14]:
%matplotlib inline

import os as os
import numpy as np
import numpy.random as rng
import json as js
import scipy.spatial.distance as dist
import scipy.cluster.hierarchy as hier
import pickle as pck
import collections as col
import operator as op
import functools as fnt
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.colors import LinearSegmentedColormap
import seaborn as sns
from scipy.stats import mannwhitneyu as mwu
import scipy as sci

# Seaborn styling defaults (white background, enlarged sans-serif font)
sns.set(style='white',
        font_scale=1.25,
        rc={'font.family': ['sans-serif'],
            'font.sans-serif': ['DejaVu Sans']})

# Input HDF files read by collect_model_stat_perf (stat_file)
# and collect_ortho_perf (ortho_pred)
fhgfs_base = '/TL/deep/fhgfs/projects/pebert/thesis'
stat_folder = os.path.join(fhgfs_base, 'projects/cross_species/processing/norm/task_summarize')
stat_file = os.path.join(stat_folder, 'agg_expstat_est.h5')
ortho_folder = os.path.join(fhgfs_base, 'projects/cross_species/processing/norm/task_ortho_pred')
ortho_pred = os.path.join(ortho_folder, 'orthopred_odb_v9.h5')

# Location of the HDF cache; set clean_cache to True to delete an
# existing cache file before the data are collected again
cache_dir = '/home/pebert/.jupyter/cache'
clean_cache = False

conf_folder = '/home/pebert/work/code/mpggit/crossspecies/graphics'

# Plot configuration files; each file is opened in a context manager
# so that its handle is closed as soon as the JSON content is parsed
with open(os.path.join(conf_folder, 'labels', 'cs_labels.json'), 'r') as conf_file:
    plot_labels = js.load(conf_file)
with open(os.path.join(conf_folder, 'colors', 'cs_colors.json'), 'r') as conf_file:
    plot_colors = js.load(conf_file)
with open(os.path.join(conf_folder, 'shapes', 'cs_shapes.json'), 'r') as conf_file:
    plot_shapes = js.load(conf_file)

# exec_pred_cluster returns immediately if this is False
run_pred_cluster = True

show_figures = True

out_folder = '/TL/deep-external01/nobackup/pebert/cloudshare/mpiinf/phd/chapter_projects/crossspecies/figures/pub'

# for dumping genesets, need promoter annotation
# (DATA_FREEZE is the leading component of the dumped file names)
DATA_FREEZE = '201709'
dir_annotation = '/TL/deep/fhgfs/projects/pebert/thesis/refdata/genemodel/subsets/protein_coding/roi_hdf'

# annotation file per species, looked up by species name
gene_annot = {'human': os.path.join(dir_annotation, 'hsa_hg19_gencode_v19.reg5p.h5'),
              'mouse': os.path.join(dir_annotation, 'mmu_mm9_gencode_vM1.reg5p.h5')}

# target folder for the dumped geneset files
out_genesets = '/TL/deep-external01/nobackup/pebert/cloudshare/mpiinf/phd/chapter_projects/crossspecies/supplement'

save_figures = False

def exec_pred_cluster():
    """
    Dump, for every species pair, the gene sets that are both labeled and
    predicted as "on" (via extract_genesets and dump_genesets).

    The true/predicted label tables are read from the HDF cache file if it
    exists; otherwise they are assembled from ortho_pred and stat_file and
    written to the cache. With clean_cache set, an existing cache file is
    deleted first.

    :return: False if run_pred_cluster is disabled, True otherwise
    """
    if not run_pred_cluster:
        return False
    cache_data = os.path.join(cache_dir, 'plot_pred_cluster.h5')
    fullmodel = 'can'
    orth_key = 'data_orth_pair'
    model_key = 'data_crp_{}_wg'.format(fullmodel)
    if clean_cache and os.path.isfile(cache_data):
        os.unlink(cache_data)
    if os.path.isfile(cache_data):
        print('Loading cached data')
        model_perf = dict()
        with pd.HDFStore(cache_data, 'r') as hdf:
            for k in hdf.keys():
                if k == '/switch_genes':
                    # not a label table, and not needed for dumping genesets
                    continue
                _, spec_a, spec_b, modeltype = k.split('/')
                # collect all model types of a species pair; assigning a new
                # dict per key would drop the tables loaded before
                model_perf.setdefault((spec_a, spec_b), dict())[modeltype] = hdf[k]
    else:
        model_perf = collect_ortho_perf(ortho_pred, fullmodel)
        model_perf, switch_genes = collect_model_stat_perf(stat_file, model_perf, 'pos', fullmodel)
        with pd.HDFStore(cache_data, 'w') as hdf:
            for (spec_a, spec_b), collected in model_perf.items():
                for modeltype, df in collected.items():
                    if df is None:
                        # placeholder that was never filled, nothing to store
                        continue
                    # HDF keys always use '/' as separator, independent of the OS
                    path = '/'.join([spec_a, spec_b, modeltype])
                    hdf.put(path, df, format='fixed')
            hdf.put('switch_genes', pd.Series(sorted(switch_genes), dtype='object'))
        print('Writing cache file predicted label clustering')
    for (spec_a, spec_b), perf in model_perf.items():
        # Select the table by name instead of by dict order, so that cached
        # and freshly collected data are treated alike.
        # NOTE(review): only one table is dumped per species pair, with the
        # ortholog-based labels taking precedence - confirm this is intended.
        if perf.get(orth_key) is not None:
            load_key = orth_key
            out_class = 'orthologs'
        else:
            load_key = model_key
            out_class = 'genes'
        genesets = extract_genesets(perf[load_key])
        dump_genesets(genesets, out_genesets, load_gene_annotation(gene_annot[spec_b]),
                      spec_b, out_class)
    return True
       

def load_gene_annotation(fpath):
    """
    Read all tables of an annotation HDF file into a single DataFrame.

    Keys containing 'metadata' are skipped. The last path component of each
    remaining key is stored in the column 'chrom' of the respective table.

    :param fpath: path to the HDF file
    :return: row-wise concatenation of all tables with a fresh integer index
    """
    def tagged_tables(store):
        # yield each non-metadata table, tagged with the last component of its key
        for key in store.keys():
            if 'metadata' in key:
                continue
            table = store[key]
            _, chrom = key.rsplit('/', 1)
            table['chrom'] = chrom
            yield table

    with pd.HDFStore(fpath, 'r') as hdf:
        tables = list(tagged_tables(hdf))
    return pd.concat(tables, axis=0, ignore_index=True)


def dump_genesets(genesets, outfolder, annotation, species, out_class):
    """
    Write one tab-separated '.bed' file per sample, containing the annotation
    rows whose 'name' is part of the sample's gene set.

    Files are named
    <DATA_FREEZE>_<species>_<sample>_TP_<out_class>_promoters.bed
    and are written without header and without index.

    :param genesets: mapping of sample name to a collection of gene names
    :param outfolder: folder the files are written to
    :param annotation: DataFrame providing the columns written to the files
    :param species: species name used in the file name
    :param out_class: gene class label used in the file name
    :return: None
    """
    bed_columns = ['chrom', 'start', 'end', 'name', 'score', 'strand', 'symbol']
    for sample, genes in genesets.items():
        name_parts = [DATA_FREEZE, species, sample, 'TP', out_class, 'promoters']
        target = os.path.join(outfolder, '_'.join(name_parts) + '.bed')
        is_member = annotation['name'].isin(genes)
        annotation.loc[is_member, :].to_csv(target, sep='\t', columns=bed_columns,
                                            header=False, index=False)

    
def extract_genesets(model_perf):
    """
    Determine, per biosample, the genes that have a positive true label and
    a positive predicted label in every setting available for that sample.

    Column names are split from the right into <setting>_<sample>_<label>,
    e.g. 'EE07_TE07-EE07_TE03_mm9_ESE14_true' and
    'EE07_TE07-EE07_TE03_mm9_ESE14_pred'.

    :param model_perf: DataFrame with one 'true' and one 'pred' column per
     setting and sample
    :return: defaultdict mapping each sample to a set of index values
    """
    label_columns = col.defaultdict(dict)
    for column in model_perf.columns:
        setting, sample, kind = column.rsplit('_', 2)
        pair = label_columns[sample].setdefault(setting, {'true': '', 'pred': ''})
        pair[kind] = column

    genesets = col.defaultdict(set)
    for sample, settings in label_columns.items():
        shared = None
        for pair in settings.values():
            both_on = (model_perf[pair['true']] > 0) & (model_perf[pair['pred']] > 0)
            genes_on = set(model_perf.index[both_on.values].tolist())
            # first setting initializes the set, all others narrow it down
            shared = genes_on if shared is None else shared.intersection(genes_on)
        genesets[sample] = shared
    return genesets
    
    
def collect_model_stat_perf(fpath, data_collect, scenario, model_type):
    """
    Add the true/predicted class labels of all human/mouse model
    applications stored in an HDF file to data_collect.

    Only keys starting with /<scenario>/<model_type> and ending with /data
    are read, e.g.
    /pos/can/mouse/human/EE12_TS25_mm9_liver/EE12_TD21_hg19_hepa/data
    Keys naming one of GM12878, CH12, K562 or MEL are skipped.

    :param fpath: path to the HDF file
    :param data_collect: dict keyed by (spec_a, spec_b) holding one dict of
     tables per species pair; modified in place
    :param scenario: first component of the keys to load
    :param model_type: second component of the keys to load
    :return: data_collect and the union of all index values with
     'switching' > 0
    """
    excluded_samples = ('GM12878', 'CH12', 'K562', 'MEL')
    species_pairs = (('human', 'mouse'), ('mouse', 'human'))
    key_prefix = '/'.join(['', scenario, model_type])
    model = 'data_crp_{}_wg'.format(model_type)

    switching_genes = set()
    with pd.HDFStore(fpath, 'r') as hdf:
        for key in hdf.keys():
            if not key.startswith(key_prefix) or not key.endswith('/data'):
                continue
            if any(name in key for name in excluded_samples):
                continue
            _, _, _, spec_a, spec_b, data_a, data_b, _ = key.split('/')
            if (spec_a, spec_b) not in species_pairs:
                continue
            table = hdf[key]
            # switching genes are defined based on whole dataset, so this
            # should always be the same set... if I am not mistaken...
            is_switching = (table['switching'] > 0).values
            switching_genes = switching_genes.union(table.index[is_switching].tolist())

            prefix = data_a.rsplit('_', 2)[0]
            new_names = {'true_class': '{}-{}_{}'.format(prefix, data_b, 'true'),
                         'pred_class': '{}-{}_{}'.format(prefix, data_b, 'pred')}
            labels = table[['true_class', 'pred_class']].rename(columns=new_names)

            tables = data_collect.setdefault((spec_a, spec_b), {model: None})
            known = tables[model]
            if known is None:
                tables[model] = labels
            else:
                tables[model] = pd.concat([known, labels], ignore_index=False, axis=1)
    return data_collect, switching_genes


def collect_ortho_perf(fpath, model):
    """
    Collect true/predicted labels for all human/mouse sample pairs stored
    below /pos/pair in an HDF file.

    A key splits into components like
    ['', 'pos', 'pair', 'mouse', 'human', 'TS31_mm9_ncd4', 'TD29_hg19_ncd4', 'data']
    Keys naming one of GM12878, CH12, K562 or MEL are skipped. The label
    columns are 1 where the value in the respective sample column is >= 1,
    0 otherwise ('true' from the second sample, 'pred' from the first one).

    :param fpath: path to the HDF file
    :param model: model name, only used to create the (empty) entry
     'data_crp_<model>_wg' for each species pair
    :return: dict keyed by (spec_a, spec_b); the label table is stored under
     'data_orth_pair', indexed by the column '<spec_b>_name'
    """
    excluded_samples = ('GM12878', 'CH12', 'K562', 'MEL')
    species_pairs = (('mouse', 'human'), ('human', 'mouse'))
    model_slot = 'data_crp_{}_wg'.format(model)

    collector = dict()
    with pd.HDFStore(fpath, 'r') as hdf:
        for key in hdf.keys():
            if not (key.startswith('/pos/pair') and key.endswith('/data')):
                continue
            if any(name in key for name in excluded_samples):
                continue
            _, _, _, spec_a, spec_b, data_a, data_b, _ = key.split('/')
            if (spec_a, spec_b) not in species_pairs:
                continue
            table = hdf[key]
            prefix = data_a.split('_')[0]
            label_true = '{}-{}_{}'.format(prefix, data_b, 'true')
            label_pred = '{}-{}_{}'.format(prefix, data_b, 'pred')
            table[label_true] = (table[data_b] >= 1).values.astype(np.int8)
            table[label_pred] = (table[data_a] >= 1).values.astype(np.int8)
            table.index = table[spec_b + '_name']
            labels = table[[label_true, label_pred]].copy()

            tables = collector.setdefault((spec_a, spec_b),
                                          {'data_orth_pair': None, model_slot: None})
            known = tables['data_orth_pair']
            if known is None:
                tables['data_orth_pair'] = labels
                continue
            assert not any([c in known.columns for c in labels.columns]), 'Duplicate columns'
            tables['data_orth_pair'] = pd.concat([known, labels], ignore_index=False, axis=1)
    return collector
                

# Run the cell's analysis; exec_pred_cluster returns False if
# run_pred_cluster is disabled, so the message is only printed
# when the function actually did its work
execd = exec_pred_cluster()
if execd:
    print('Prediction clustering created')

Loading cached data
Prediction clustering created
