In [3]:
%matplotlib inline

import os as os
import collections as col
import functools as fnt

import numpy as np
import pandas as pd

fhgfs_projects = '/TL/deep/fhgfs/projects/pebert/thesis/projects'
workdir = os.path.join(fhgfs_projects, 'cross_species/processing/norm/task_summarize')

outdir = '/home/pebert/temp/creepiest/cons_genes'

summ_file = os.path.join(workdir, 'train_test_perf_agg.h5')

run_status = True
run_pct_rank = False

def check_rank_consistency(slack, row):
    if (row[1] - slack) <= row[0] <= (row[1] + slack):
        return 1
    else:
        return 0

def collect_gene_status(fpath, query):

    all_genes = set()
    load_paths = col.defaultdict(list)
    with pd.HDFStore(fpath, 'r') as hdf:
        data_paths = [k for k in hdf.keys() if k.startswith('/testing/status/asig') and k.endswith('/data')]
        for k in data_paths:
            components = k.split('/')
            if components[5] == query:
                data, model = components[7], components[8]
                tissue_data = data.split('_')[4]
                tissue_model = model.split('_')[4]
                load_paths[(tissue_data, tissue_model)].append(k)

        for (tisdat, tismod), paths in load_paths.items():
            count_on = col.Counter()
            count_off = col.Counter()
            count_est_per_dataset = col.Counter()
            count_switch_ortho = {'on': col.defaultdict(col.Counter), 'off': col.defaultdict(col.Counter)}
            count_switch_nonortho = {'on': col.defaultdict(col.Counter), 'off': col.defaultdict(col.Counter)}
            num_datasets = 0
            for p in paths:
                path = p.split('/')
                data = hdf[p]
                count_est_per_dataset[path[7]] += 1

                on_idx = np.logical_and(data['true'] == 1, data['pred'] == 1)
                genes_on = (data.loc[on_idx, :]).index.tolist()
                count_on.update(genes_on)

                on_ortho_idx = np.logical_and(on_idx, data['ortho_pair_odb9'] == 1)
                genes_on_ortho = (data.loc[on_ortho_idx, :]).index.tolist()
                count_switch_ortho['on'][path[7]].update(genes_on_ortho)

                on_nonortho_idx = np.logical_and(on_idx, data['ortho_pair_odb9'] == 0)
                genes_on_nonortho = (data.loc[on_nonortho_idx, :]).index.tolist()
                count_switch_nonortho['on'][path[7]].update(genes_on_nonortho)

                off_idx = np.logical_and(data['true'] == 0, data['pred'] == 0)
                genes_off = (data.loc[off_idx, :]).index.tolist()
                count_off.update(genes_off)

                off_ortho_idx = np.logical_and(off_idx, data['ortho_pair_odb9'] == 1)
                genes_off_ortho = (data.loc[off_ortho_idx, :]).index.tolist()
                count_switch_ortho['off'][path[7]].update(genes_off_ortho)

                off_nonortho_idx = np.logical_and(off_idx, data['ortho_pair_odb9'] == 0)
                genes_off_nonortho = (data.loc[off_nonortho_idx, :]).index.tolist()
                count_switch_nonortho['off'][path[7]].update(genes_off_nonortho)

                if len(all_genes) == 0:
                    all_genes = all_genes.union(data.index.tolist())
                num_datasets += 1
            outname = '{}_gene_status_on_{}-{}.tsv'.format(query, tisdat, tismod)
            outpath = os.path.join(outdir, outname)
            dump_gene_list(count_on, all_genes, outpath)
    raise   
    retval = {'all_genes': all_genes, 'genes_on': count_on, 'genes_off': count_off,
              'total_num_est': num_datasets, 'num_est_dataset': count_est_per_dataset,
              'switch_ortho': count_switch_ortho, 'switch_nonortho': count_switch_nonortho}
    return retval


def collect_gene_pct_rank(fpath, query):
    count_consistent = col.Counter()
    num_datasets = 0
    all_genes = set()
    rank_cons = fnt.partial(check_rank_consistency, *(0.025,))
    with pd.HDFStore(fpath, 'r') as hdf:
        load_keys = [k for k in hdf.keys() if k.startswith('/testing/level/all') and k.endswith('/data')]
        for k in load_keys:
            path = k.split('/')
            if path[5] == query:
                data = hdf[k]
                if len(all_genes) == 0:
                    all_genes = all_genes.union(data.index.tolist())
                on_idx = np.logical_and(data['true'] >= 1, data['pred'] >= 1)
                tpms = data.loc[on_idx, ['true', 'pred']]
                tpms = tpms.rank(axis=0, method='dense', ascending=True, pct=True)
                tpms['consistent'] = tpms.apply(rank_cons, axis=1, raw=True, reduce=True)
                cons_set = tpms.loc[tpms['consistent'] == 1, :]
                count_consistent.update(cons_set.index.tolist())
    return all_genes, count_consistent, num_datasets
            


def dump_gene_list(genes, all_genes, outpath):
    sort_list = []
    for g in all_genes:
        c = genes[g]
        sort_list.append((g, c))
    sort_list = sorted(sort_list, key=lambda x: (x[1], x[0]), reverse=True)
    sort_list = '\n'.join(['\t'.join([t[0], str(t[1])]) for t in sort_list])
    with open(outpath, 'w') as out:
        _ = out.write(sort_list + '\n')
    return


def identify_switching_genes(est_per_data, on_genes, off_genes):
    
    always_on_genes = set()
    always_off_genes = set()
    for dset, genes in on_genes.items():
        num_est = est_per_data[dset]
        for name, count in genes.most_common():
            if count < num_est:
                break
            else:
                always_on_genes.add(name)
        for name, count in off_genes[dset].most_common():
            if count < num_est:
                break
            else:
                always_off_genes.add(name)
    switch_genes = always_on_genes.intersection(always_off_genes)
    return switch_genes               
                

def collect_gene_status_lists():
    hsa = collect_gene_status(summ_file, 'human')

    dump_gene_list(hsa['genes_on'], hsa['all_genes'], os.path.join(outdir, 'human_genes_status_on.tsv'))
    dump_gene_list(hsa['genes_off'], hsa['all_genes'], os.path.join(outdir, 'human_genes_status_off.tsv'))

    hsa_switching_ortho = identify_switching_genes(hsa['num_est_dataset'], hsa['switch_ortho']['on'],
                                                   hsa['switch_ortho']['off'])
    with open(os.path.join(outdir, 'human_genes_switch_ortho.tsv'), 'w') as dump:
        _ = dump.write('\n'.join(sorted(hsa_switching_ortho)))

    hsa_switching_nonortho = identify_switching_genes(hsa['num_est_dataset'], hsa['switch_nonortho']['on'],
                                                   hsa['switch_nonortho']['off'])
    with open(os.path.join(outdir, 'human_genes_switch_non-ortho.tsv'), 'w') as dump:
        _ = dump.write('\n'.join(sorted(hsa_switching_nonortho)))
    
    mmu = collect_gene_status(summ_file, 'mouse')

    dump_gene_list(mmu['genes_on'], mmu['all_genes'], os.path.join(outdir, 'mouse_genes_status_on.tsv'))
    dump_gene_list(mmu['genes_off'], mmu['all_genes'], os.path.join(outdir, 'mouse_genes_status_off.tsv'))
    return


def collect_gene_pct_rank_lists():

    hsa = collect_gene_pct_rank(summ_file, 'human')
    dump_gene_list(hsa['genes_on'], hsa['all_genes'], os.path.join(outdir, 'human_genes_pct-rank.tsv'))
    
    mmu = collect_gene_pct_rank(summ_file, 'mouse')
    dump_gene_list(mmu['genes_on'], mmu['all_genes'], os.path.join(outdir, 'mouse_genes_pct-rank.tsv'))
    return

if run_status:
    collect_gene_status_lists()

if run_pct_rank:
    collect_gene_pct_rank_lists()

RuntimeError: No active exception to reraise