In [None]:
### Number of worker processes used for the parallel computation below.
# The parallel pass only runs when KERNEL > 1; set it to 1 to just read
# existing output files and build the summary tables.
KERNEL = 20


import os
# One OpenMP thread per process. This is set before numpy and the method
# modules are imported below, presumably so their thread pools pick it up.
os.environ.update({"OMP_NUM_THREADS": "1"})

import numpy as np
import pandas as pd
import sys,os
import random
import copyx

import matplotlib.pyplot as plt
import seaborn as sns

from utils.eval import find_best_matches, generate_exprs

from methods import NMF, PCA, sparse_PCA, moCluster, MOFA2

from methods.utils import interpret_results, resultsHandler

from pathlib import Path
import multiprocessing as mp

from utils.eval import find_best_matches, make_known_groups, find_best_matching_biclusters
from contextlib import redirect_stdout




def make_ref_groups(subtypes, annotation, exprs):
    """Build reference sample groups and their relative frequencies.

    Parameters
    ----------
    subtypes : pandas.DataFrame
        Indexed by sample id; the "PAM50" and "SCMOD2" columns are turned into
        groups by ``make_known_groups`` and the "claudin_low" column (1 marks
        membership) defines the Claudin-low group.
    annotation : pandas.DataFrame
        Indexed by sample id; "IHC_HER2", "IHC_ER", "IHC_PR" ("Positive" marks
        membership) and "IHC_TNBC" (1 marks membership) define the IHC groups.
    exprs : pandas.DataFrame
        Expression matrix whose columns are the sample ids.

    Returns
    -------
    known_groups : dict
        ``{classification: {group_name: samples}}`` for the classifications
        "PAM50", "Luminal", "Claudin-low", "SCMOD2" and "IHC".
    freqs : dict
        ``{group_name: len(samples) / exprs.shape[1]}`` over all groups. If a
        group name occurs in several classifications, the last one wins.
    """
    all_samples = set(exprs.columns.values)

    pam50 = make_known_groups(subtypes, exprs, target_col="PAM50", verbose=False)
    # "Luminal" pools the PAM50 LumA and LumB groups.
    lum = {"Luminal": pam50["LumA"].union(pam50["LumB"])}
    scmod2 = make_known_groups(subtypes, exprs, target_col="SCMOD2", verbose=False)
    claudin = {
        "Claudin-low": set(subtypes.loc[subtypes["claudin_low"] == 1, :].index.values).intersection(all_samples)
    }

    # Restrict IHC groups to samples present in the expression matrix, like
    # Claudin-low above. Otherwise group sizes (and hence freqs) could count
    # annotated samples that have no expression profile.
    ihc = {}
    for marker in ["IHC_HER2", "IHC_ER", "IHC_PR"]:
        ihc[marker] = set(annotation.loc[annotation[marker] == "Positive", :].index.values).intersection(all_samples)
    ihc["IHC_TNBC"] = set(annotation.loc[annotation["IHC_TNBC"] == 1, :].index.values).intersection(all_samples)

    known_groups = {"PAM50": pam50, "Luminal": lum, "Claudin-low": claudin, "SCMOD2": scmod2, "IHC": ihc}

    n_samples = exprs.shape[1]
    freqs = {
        group: len(members) / n_samples
        for groups in known_groups.values()
        for group, members in groups.items()
    }
    return known_groups, freqs

def calculate_perfromance(results, known_groups, freqs, all_samples,
                          classifications=None):
    """Score a clustering result against reference sample groups.

    For each classification in ``known_groups``, ``find_best_matches`` (FDR 0.05)
    finds the best-matching cluster for every reference group. The "J" score of
    each match (presumably a Jaccard index) is collected per group name.
    Then, for every entry of ``classifications``, the frequency-weighted mean of
    the J scores of its groups is added under ``"overall_performance_<name>"``.

    Parameters
    ----------
    results
        Clustering result as accepted by ``find_best_matches``.
    known_groups : dict
        ``{classification: {group_name: samples}}`` (see ``make_ref_groups``).
    freqs : dict
        ``{group_name: relative_frequency}`` used as weights.
    all_samples : set
        Sample ids the groups are evaluated against.
    classifications : dict, optional
        ``{name: [group_name, ...]}``. Defaults to the "Intrinsic"
        classification (Luminal, Basal, Her2, Normal, Claudin-low).

    Returns
    -------
    dict
        ``{group_name: J, ..., "overall_performance_<name>": score, ...}``.

    Raises
    ------
    KeyError
        If a group listed in ``classifications`` has no best match or no
        frequency. This assumes ``find_best_matches`` returns one row per group;
        TODO confirm.
    ZeroDivisionError
        If all groups of a classification have zero total frequency.
        Callers in this file catch this.
    """
    # None instead of a dict literal avoids a shared mutable default argument.
    if classifications is None:
        classifications = {"Intrinsic": ["Luminal", "Basal", "Her2", "Normal", "Claudin-low"]}

    per_classification = [
        find_best_matches(results, groups, all_samples, FDR=0.05, verbose=False)
        for groups in known_groups.values()
    ]
    best_matches = pd.concat(per_classification, axis=0)["J"].to_dict()

    for cl_name, groups in classifications.items():
        weighted_sum = 0
        norm_factor = 0
        for group in groups:
            weighted_sum += best_matches[group] * freqs[group]
            norm_factor += freqs[group]
        best_matches["overall_performance_" + cl_name] = weighted_sum / norm_factor
    return best_matches

def compare_gene_clusters(tcga_result, metabric_result, N):
    """Match two sets of gene clusters against each other in both directions.

    ``find_best_matching_biclusters`` is run TCGA -> METABRIC (keys suffixed
    ``_1``) and METABRIC -> TCGA (keys suffixed ``_2``). Rows with missing
    values, and matches sharing at most one gene, are dropped.

    Parameters
    ----------
    tcga_result, metabric_result : pandas.DataFrame
        One row per cluster.
    N : int
        Total number of genes, passed through to the matcher.

    Returns
    -------
    clust_similarity : dict
        ``n_*`` (number of clusters), ``percent_matched_*`` (fraction of
        clusters with a retained match; not multiplied by 100),
        ``n_shared_genes_*`` (genes shared in retained matches) and
        ``avg_bm_J_*`` (mean J of retained matches).
    bm, bm2 : pandas.DataFrame
        Retained matches per direction, sorted by ``n_shared`` descending.

    Can raise ZeroDivisionError when either input has no rows.
    """
    def _retained_matches(query, reference):
        # Best matches of `query` clusters in `reference`, keeping only
        # complete rows whose match shares more than one gene.
        matches = find_best_matching_biclusters(query, reference, N).dropna()
        matches = matches.loc[matches["n_shared"] > 1, :]
        return matches.sort_values(by="n_shared", ascending=False)

    bm = _retained_matches(tcga_result, metabric_result)
    bm2 = _retained_matches(metabric_result, tcga_result)

    n_first = tcga_result.shape[0]
    n_second = metabric_result.shape[0]
    clust_similarity = {
        "n_1": n_first,
        "n_2": n_second,
        "percent_matched_1": bm.shape[0] / n_first,
        "percent_matched_2": bm2.shape[0] / n_second,
        "n_shared_genes_1": bm["n_shared"].sum(),
        "n_shared_genes_2": bm2["n_shared"].sum(),
        "avg_bm_J_1": bm["J"].mean(),
        "avg_bm_J_2": bm2["J"].mean(),
    }
    return clust_similarity, bm, bm2


# Classifications used for the overall performance scores. Every group name
# listed here is looked up in the best matches and in the freqs built by
# make_ref_groups.
classifications = {"Intrinsic": ["Luminal", "Basal", "Her2", "Normal", "Claudin-low"],
                   "SCMOD2": ["ER-/HER2-", "ER+/HER2- Low Prolif", "ER+/HER2- High Prolif", "HER2+"],
                   "IHC": ["IHC_TNBC", "IHC_ER", "IHC_HER2", "IHC_PR"]}

file_metabric_annotation = '/local/DESMOND2_data/v6/preprocessed_v6/METABRIC_1904.annotation_v6.tsv'
file_metabric_expression = '/local/DESMOND2_data/v6/preprocessed_v6/METABRIC_1904_17Kgenes.log2_exprs_z_v6.tsv'
file_metabric_subtypes = '/local/DESMOND2_data/v6/preprocessed_v6/METABRIC_1904_17Kgenes.subtypes_and_signatures_v6.tsv'
file_tcga_annotation = '/local/DESMOND2_data/v6/preprocessed_v6/TCGA-BRCA_1079.Xena_TCGA_PanCan.annotation_v6.tsv'
file_tcga_expression = '/local/DESMOND2_data/v6/preprocessed_v6/TCGA-BRCA_1079_17Kgenes.Xena_TCGA_PanCan.log2_exprs_z_v6.tsv'
file_tcga_subtypes = '/local/DESMOND2_data/v6/preprocessed_v6/TCGA-BRCA_1079_17Kgenes.Xena_TCGA_PanCan.subtypes_and_signatures_v6.tsv'
file_gene_mapping = '/local/DESMOND2_data/v6/preprocessed_v6/gene_id_mapping.tsv'

# Alternative output location: '/home/hartung/data/preprocessed_v6/results'
out_dir = '/home/bba1401/data/unpast_real'

basename_t = "TCGA"
basename_m = "METABRIC"

# Expression values (z-scored per the file names) are clipped to
# [-EXPRS_CLIP, EXPRS_CLIP] before the reference groups are built.
EXPRS_CLIP = 3

m_subtypes = pd.read_csv(file_metabric_subtypes, sep="\t", index_col=0)
m_annotation = pd.read_csv(file_metabric_annotation, sep="\t", index_col=0)

t_subtypes = pd.read_csv(file_tcga_subtypes, sep="\t", index_col=0)
t_annotation = pd.read_csv(file_tcga_annotation, sep="\t", index_col=0)

exprs_t = pd.read_csv(file_tcga_expression, sep="\t", index_col=0).clip(lower=-EXPRS_CLIP, upper=EXPRS_CLIP)
exprs_m = pd.read_csv(file_metabric_expression, sep="\t", index_col=0).clip(lower=-EXPRS_CLIP, upper=EXPRS_CLIP)

known_groups_t, freqs_t = make_ref_groups(t_subtypes, t_annotation, exprs_t)
known_groups_m, freqs_m = make_ref_groups(m_subtypes, m_annotation, exprs_m)

def _score_run(result, runtime, comb, known_groups, freqs, all_samples):
    """Score one method run against its dataset's reference groups.

    Returns the dict from ``calculate_perfromance`` extended with 'parameters'
    (``comb['output_path']``), 'run' (``comb['random_state']``) and 'time'
    (``runtime``). If scoring raises ZeroDivisionError, an empty dict is
    returned instead, so the run still contributes one (empty) record.
    """
    try:
        performance = calculate_perfromance(result, known_groups,
                                            freqs, all_samples,
                                            classifications=classifications)
        performance.update({'parameters': comb['output_path'], 'run': comb['random_state']})
        performance['time'] = runtime
    except ZeroDivisionError:
        performance = {}
    return performance


# NOTE(review): despite its name, this log receives the stdout of every method
# in METHODS, not only moCluster.
with open('mocluster_log.txt', 'w') as f:
    with redirect_stdout(f):

        METHODS = [NMF, sparse_PCA]  # [NMF, sparse_PCA, moCluster, MOFA2]
        for METHOD in METHODS:
            method_name = METHOD.__name__.split('.')[-1]

            #### Preparation
            # METABRIC
            file_path_m = file_metabric_expression
            output_path_m = os.path.join(out_dir, basename_m, method_name)
            ground_truth_file_m = file_metabric_annotation
            combinations_m = METHOD.generate_arg_list(file_path_m, output_path_m, ground_truth_file_m)
            # TCGA
            file_path_t = file_tcga_expression
            output_path_t = os.path.join(out_dir, basename_t, method_name)
            ground_truth_file_t = file_tcga_annotation
            combinations_t = METHOD.generate_arg_list(file_path_t, output_path_t, ground_truth_file_t)

            # Sanity check, done before the expensive parallel pass. Runs are
            # paired with zip() below, which would silently drop unmatched
            # combinations. An explicit raise (not assert) survives `python -O`.
            if len(combinations_m) != len(combinations_t):
                raise ValueError(
                    f"{method_name}: {len(combinations_m)} METABRIC vs "
                    f"{len(combinations_t)} TCGA parameter combinations"
                )

            #### Compute in parallel
            # Option to compute the results in parallel; methods store their results.
            # The 'Run' step below then reads the existing results and evaluates them.
            if KERNEL > 1:
                with mp.Pool(KERNEL) as pool:
                    pool.map(METHOD.run_real, combinations_m + combinations_t)

            #### Run
            # Methods compute results, or read existing ones.
            subt_t = []
            subt_m = []
            clustering_similarities = []  # not filled here; kept for later cells
            for comb_m, comb_t in zip(combinations_m, combinations_t):
                result_m, runtime_m = METHOD.run_real(comb_m)
                result_t, runtime_t = METHOD.run_real(comb_t)

                performance_m = _score_run(result_m, runtime_m, comb_m,
                                           known_groups_m, freqs_m, set(exprs_m.columns.values))
                subt_m.append(performance_m)

                performance_t = _score_run(result_t, runtime_t, comb_t,
                                           known_groups_t, freqs_t, set(exprs_t.columns.values))
                subt_t.append(performance_t)

            # Save results. Make sure the target directories exist, in case
            # the methods did not create them.
            os.makedirs(output_path_m, exist_ok=True)
            os.makedirs(output_path_t, exist_ok=True)
            pd.DataFrame.from_records(subt_m).to_csv(os.path.join(output_path_m, f'{method_name}_METABRIC.tsv'), sep="\t")
            pd.DataFrame.from_records(subt_t).to_csv(os.path.join(output_path_t, f'{method_name}_TCGA.tsv'), sep="\t")
