In [None]:
import numpy as np
import pandas as pd
import sys, os
import random
import copy
from time import time
import matplotlib.pyplot as plt
import seaborn as sns
from utils.method import read_bic_table
from utils.eval import find_best_matches, make_known_groups, find_best_matching_biclusters
import glob
sys.path.insert(0, './evaluation/subsampling')
import settings

from run_desmond import run_DESMOND


def make_ref_groups(subtypes, annotation, exprs):
    """Build the reference sample groups of every known classification and their frequencies.

    Parameters
    ----------
    subtypes : pd.DataFrame
        Sample-indexed subtype calls. Used columns: "PAM50" and "SCMOD2" (passed to
        ``make_known_groups``) and "claudin_low" (1 marks a claudin-low sample).
    annotation : pd.DataFrame
        Sample-indexed clinical annotation. Used columns: "IHC_HER2", "IHC_ER",
        "IHC_PR" (value "Positive" marks membership) and "IHC_TNBC" (1 marks membership).
    exprs : pd.DataFrame
        Expression matrix; its columns are the sample IDs under evaluation.

    Returns
    -------
    known_groups : dict
        ``{classification: {group: set_of_sample_ids}}`` for the classifications
        "PAM50", "Luminal", "Claudin-low", "SCMOD2" and "IHC".
    freqs : dict
        ``{group: size_of_group / number_of_exprs_columns}``. Group names are
        assumed unique across classifications; a repeated name keeps the last value.

    Raises
    ------
    KeyError
        If the PAM50 groups lack "LumA" or "LumB", or a used column is missing.
    """
    all_samples = set(exprs.columns.values)

    pam50 = make_known_groups(subtypes, exprs, target_col="PAM50", verbose=False)
    # "Luminal" merges the two PAM50 luminal subtypes into a single group.
    lum = {"Luminal": pam50["LumA"].union(pam50["LumB"])}
    scmod2 = make_known_groups(subtypes, exprs, target_col="SCMOD2", verbose=False)
    claudin = {
        "Claudin-low": set(subtypes.loc[subtypes["claudin_low"] == 1, :].index.values).intersection(all_samples)
    }

    # Restrict the IHC groups to samples present in `exprs`, exactly like the
    # claudin-low group above; otherwise annotated-but-unprofiled samples would
    # inflate the group sets and their frequencies (which are relative to exprs).
    ihc = {}
    for x in ["IHC_HER2", "IHC_ER", "IHC_PR"]:
        ihc[x] = set(annotation.loc[annotation[x] == "Positive", :].index.values).intersection(all_samples)
    ihc["IHC_TNBC"] = set(annotation.loc[annotation["IHC_TNBC"] == 1, :].index.values).intersection(all_samples)

    known_groups = {"PAM50": pam50, "Luminal": lum, "Claudin-low": claudin, "SCMOD2": scmod2, "IHC": ihc}

    N = exprs.shape[1]
    freqs = {
        group: len(members) / N
        for groups in known_groups.values()
        for group, members in groups.items()
    }
    return known_groups, freqs

def calculate_perfromance(results, known_groups, freqs, all_samples,
                          classifications=None):
    """Score biclustering results against known sample groups.

    For every classification in ``known_groups``, ``find_best_matches`` (FDR 0.05)
    picks the best-matching bicluster of each group, and the Jaccard index ("J")
    of that match is collected. Then, for each entry of ``classifications``, an
    overall score is the frequency-weighted mean J of its groups::

        overall = sum(J[g] * freqs[g]) / sum(freqs[g])

    Parameters
    ----------
    results
        Biclustering result table, passed through to ``find_best_matches``.
    known_groups : dict
        ``{classification: {group: set_of_samples}}``.
    freqs : dict
        ``{group: frequency}`` weights for the overall score.
    all_samples : set
        Samples under evaluation, passed through to ``find_best_matches``.
    classifications : dict, optional
        ``{name: [group, ...]}``. Defaults to
        ``{"Intrinsic": ["Luminal", "Basal", "Her2", "Normal", "Claudin-low"]}``.

    Returns
    -------
    dict
        ``{group: J}`` for every group returned by ``find_best_matches``, plus one
        ``"overall_performance_<name>"`` entry per classification.

    Raises
    ------
    KeyError
        If a group listed in ``classifications`` has no J or no frequency.
    """
    # None instead of a dict literal default: avoids a shared mutable default.
    if classifications is None:
        classifications = {"Intrinsic": ["Luminal", "Basal", "Her2", "Normal", "Claudin-low"]}

    # One best-match table per classification, stacked so J can be looked up by group.
    matches_per_classification = [
        find_best_matches(results, groups, all_samples, FDR=0.05, verbose=False)
        for groups in known_groups.values()
    ]
    scores = pd.concat(matches_per_classification, axis=0)["J"].to_dict()

    for cl_name, groups in classifications.items():
        weighted_j = sum(scores[group] * freqs[group] for group in groups)
        norm_factor = sum(freqs[group] for group in groups)
        scores["overall_performance_" + cl_name] = weighted_j / norm_factor
    return scores

def compare_gene_clusters(tcga_result, metabric_result, N):
    """Compare two bicluster sets by matching their gene clusters in both directions.

    Matches are computed TCGA -> METABRIC and then METABRIC -> TCGA with
    ``find_best_matching_biclusters``. In each direction, rows with missing values
    and matches sharing at most one gene are discarded, and the remaining matches
    are ordered by ``n_shared`` (descending).

    Parameters
    ----------
    tcga_result, metabric_result : pd.DataFrame
        Bicluster tables, one row per bicluster.
    N : int
        Total number of genes, forwarded to ``find_best_matching_biclusters``.

    Returns
    -------
    clust_similarity : dict
        Suffix ``_1`` refers to TCGA -> METABRIC, ``_2`` to METABRIC -> TCGA:
        ``n_*`` number of biclusters (rows) in each input; ``percent_matched_*``
        number of kept matches divided by that count (a fraction, not a percent);
        ``n_shared_genes_*`` total ``n_shared`` over kept matches;
        ``avg_bm_J_*`` mean ``J`` of kept matches.
    bm, bm2 : pd.DataFrame
        The kept matches, TCGA -> METABRIC and METABRIC -> TCGA respectively.
    """
    def kept_matches(source, target):
        # Best match of each `source` bicluster in `target`: complete rows that
        # share more than one gene, largest overlap first.
        matches = find_best_matching_biclusters(source, target, N).dropna()
        return matches.loc[matches["n_shared"] > 1, :].sort_values(by="n_shared", ascending=False)

    forward = kept_matches(tcga_result, metabric_result)
    backward = kept_matches(metabric_result, tcga_result)

    n_tcga = tcga_result.shape[0]
    n_metabric = metabric_result.shape[0]
    clust_similarity = {
        "n_1": n_tcga,
        "n_2": n_metabric,
        "percent_matched_1": forward.shape[0] / n_tcga,
        "percent_matched_2": backward.shape[0] / n_metabric,
        "n_shared_genes_1": forward["n_shared"].sum(),
        "n_shared_genes_2": backward["n_shared"].sum(),
        "avg_bm_J_1": forward["J"].mean(),
        "avg_bm_J_2": backward["J"].mean(),
    }
    return clust_similarity, forward, backward




In [None]:
def run_on_subsampled(dataset_name, subsample_factor, min_samples, seeds=None):
    """Run UnPaSt on every subsampled copy of a dataset and save one result table per seed.

    Subsample folders are expected at
    ``<settings.OUTPUT_FOLDER>/<dataset_name>/seed=<seed>/factor=<subsample_factor>/min_samples=<min_samples>``
    and must contain ``subtypes.tsv``, ``annotation.tsv`` and ``expression.tsv``.
    The table returned by ``_run_unpast`` is written to
    ``unpast_result_seed=<seed>.tsv`` inside the same folder.

    Parameters
    ----------
    dataset_name : str
        Dataset sub-folder of ``settings.OUTPUT_FOLDER`` (e.g. "TCGA", "METABRIC").
    subsample_factor : float
        Value of the ``factor=`` path component.
    min_samples : int
        Value of the ``min_samples=`` path component.
    seeds : list or None
        Subsampling seeds to process. ``None`` (default) processes every
        ``seed=*`` folder on disk; with a list, only the listed seeds whose
        folders exist are processed (missing ones are reported on stderr).
    """
    assert seeds is None or isinstance(seeds, list)
    dataset_dir = os.path.join(settings.OUTPUT_FOLDER, dataset_name)
    # Path components shared by every subsample folder of this configuration.
    config_parts = (f'factor={subsample_factor}', f'min_samples={min_samples}')

    if seeds is None:
        # Discover all seeds on disk; sort so runs happen in a reproducible
        # order (glob's order is filesystem-dependent).
        seed_folders = sorted(glob.glob(os.path.join(dataset_dir, 'seed=*', *config_parts)))
        print(f'Found {len(seed_folders)} seeds for {dataset_name} with factor {subsample_factor} and min_samples {min_samples}')
    else:
        # Map the requested seeds onto the same layout the glob above matches,
        # keeping only folders that exist (as glob would).
        seed_folders = []
        for requested in seeds:
            folder = os.path.join(dataset_dir, f'seed={requested}', *config_parts)
            if os.path.isdir(folder):
                seed_folders.append(folder)
            else:
                print(f'Skipping seed {requested}: {folder} does not exist', file=sys.stderr)
    print(f'Found {len(seed_folders)}. {seed_folders}')

    for seed_folder in seed_folders:
        print('Running unpast for seed_folder', seed_folder)
        seed = _seed_from_folder(seed_folder)
        expression_path = os.path.join(seed_folder, 'expression.tsv')

        subtype_df = pd.read_csv(os.path.join(seed_folder, 'subtypes.tsv'), sep="\t", index_col=0)
        annotation_df = pd.read_csv(os.path.join(seed_folder, 'annotation.tsv'), sep="\t", index_col=0)
        exprs_df = pd.read_csv(expression_path, sep="\t", index_col=0)

        run_results_df = _run_unpast(dataset_name, expression_path, subtype_df, annotation_df, exprs_df)

        filename = os.path.join(seed_folder, f'unpast_result_seed={seed}.tsv')
        print('Saving outfile', filename)
        run_results_df.to_csv(filename, sep='\t')


def _seed_from_folder(seed_folder):
    """Return the ``<value>`` of the ``seed=<value>`` path component of ``seed_folder``.

    Splits on the platform separator (not a hard-coded '/').

    Raises
    ------
    ValueError
        If no path component starts with ``seed=``.
    """
    for part in os.path.normpath(seed_folder).split(os.sep):
        if part.startswith('seed='):
            return part[len('seed='):]
    raise ValueError(f"no 'seed=' component in folder path {seed_folder!r}")
        

def _run_unpast(dataset_name, expression_path, subtype_df, annoation_df, exprs_df):
    classifications={"Intrinsic":["Luminal","Basal","Her2","Normal","Claudin-low"],
                "SCMOD2":["ER-/HER2-","ER+/HER2- Low Prolif","ER+/HER2- High Prolif","HER2+"],
                "IHC":["IHC_TNBC","IHC_ER","IHC_HER2","IHC_PR"]}

    known_groups, freqs = make_ref_groups(subtype_df, annoation_df, exprs_df)
    
    n_runs = 5
    seeds = []
    random.seed(101)
    for i in range(n_runs):
        seeds.append(random.randint(0, 1000000))
    print("generate ",n_runs," seeds",seeds)

    best_params = settings.UNPAST_BEST_PARAMS['OVERALL']

    ### Louvain 
    out_dir= '/'.join(expression_path.split('/')[:-1])
    out_dir = os.path.join(out_dir, ';'.join(['='.join([str(a), str(b)]) for a, b in settings.UNPAST_BEST_PARAMS['OVERALL'].items()]))
    modularities = [0,0.3,0.4,0.5,0.6,0.7,0.8,0.9]
    print(out_dir)
    if not os.path.exists(out_dir):
            os.makedirs(out_dir)
            
    subt = []
    clustering_similarities = []
    for run in range(n_runs):
        seed = seeds[run]
        print("Running unpast: Iteration ",run, 'with', best_params, expression_path)

        # save parameters as a ;-separated string
        params = f"bin={best_params['bin_method']};pval={best_params['pval']}"
        params += f";clust={best_params['clust_method']};ds={best_params['ds']}"
        params_dict = {"parameters":params, "seed":seed,"run":run}

        ### running TCGA or reading results
        try:
            result = run_DESMOND(
                expression_path, 
                dataset_name, 
                **best_params,
                out_dir=out_dir,
                save=True, 
                load = True,
                ceiling =3,
                min_n_samples = 5,
                cluster_binary=False,
                seed = seed,
                verbose = False, 
                plot_all = False,
                merge = 1)
            # find the best matches between TCGA biclusters and subtypes
            # and calculate overall performance == weighted sum of Jaccard indexes
            performance = calculate_perfromance(result, known_groups,
                                                  freqs, set(exprs_df.columns.values),
                                                  classifications=classifications)
            performance.update(params_dict)
            performance["time"] = time
            subt.append(performance)
            failed = False
        except Exception as e:
            print(e)
            print("biclustering failed with ",seed, params,file = sys.stderr)
            failed = True
            subt.append(params_dict)

    return pd.DataFrame.from_records(subt)
    
    

In [None]:
# Each cell runs UnPaSt on all subsamples of METABRIC and/or TCGA for one
# subsampling factor. `min_samples` must match the `min_samples=` folder names,
# because run_on_subsampled globs for them.
# NOTE(review): the METABRIC call for factor 0.5 is commented out, so this cell
# only processes TCGA — presumably METABRIC was already done; confirm.
# run_on_subsampled('METABRIC', .5, 10)
run_on_subsampled('TCGA', .5, 10)

In [None]:
# Subsampling factor 0.3, both datasets.
run_on_subsampled('METABRIC', .3, 10)
run_on_subsampled('TCGA', .3, 10)

In [None]:
# Subsampling factor 0.1, both datasets.
run_on_subsampled('METABRIC', .1, 10)
run_on_subsampled('TCGA', .1, 10)

In [None]:
# Subsampling factor 0.05 uses min_samples=0 (every other cell uses 10).
# NOTE(review): confirm the factor=0.05 subsamples were generated with
# min_samples=0; otherwise the glob finds no folders and nothing runs.
run_on_subsampled('METABRIC', .05, 0)
run_on_subsampled('TCGA', .05, 0)

In [None]:
# Subsampling factor 0.4, both datasets.
run_on_subsampled('METABRIC', .4, 10)
run_on_subsampled('TCGA', .4, 10)

In [None]:
# Subsampling factor 0.2, both datasets.
run_on_subsampled('METABRIC', .2, 10)
run_on_subsampled('TCGA', .2, 10)