In [None]:
import os
import gc
import time
import collections
import json
import itertools
import functools
import datetime
import warnings
import re

In [None]:
from tqdm.notebook import tqdm

In [None]:
import numpy as np
import scipy as sp
import scipy.stats
import pandas as pd

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.patches import Patch

def get_hex(colour_series:pd.Series):
    """Convert a Series of RGB tuples (floats in [0, 1]) into "#rrggbb" strings.

    Each channel is scaled with ``int(c*256)``, clamped to 0..255 and written as
    exactly two hex digits, so the result is always a valid 7-character colour
    (the previous version emitted "0" instead of "00" and "100" for c == 1.0).
    """
    def _channel_to_hex(c):
        # Clamp so that 1.0 maps to "ff" rather than overflowing to "100".
        return f"{max(0, min(255, int(c*256))):02x}"

    return colour_series.map(lambda rgb: "#" + "".join(map(_channel_to_hex, rgb)))

In [None]:
import sklearn.model_selection
import sklearn.pipeline
import sklearn.preprocessing
import sklearn.feature_selection
import sklearn.metrics
import sklearn.decomposition
import sklearn.manifold
import sklearn.linear_model
import sklearn.svm
import sklearn.ensemble
import sklearn.base
import sklearn.cluster
import umap
import lifelines

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import skorch
import skorch.scoring

In [None]:
from pathwayae.models import MLP, Autoencoder, VAE, PAAE, PAVAE, NopLayer
from pathwayae.skorch_utils import ScoredNeuralNetAutoencoder

from pathwayae.losses import AE_MSELoss, VAELoss, build_beta_schedule

from pathwayae.utils import from_log2pk, to_log2pk, fpkm_to_tpm, sample_wise_preprocess_fn, sigmoid, logcurve, logcurve_start_end

from pathwayae.pathway_utils import read_pathway_from_json_file

In [None]:
# Root of the local data tree; each dataset lives under <root>/pathwayae/<dataset>.
data_folder = os.path.expanduser("~/data/")
tcga_folder, meta_folder = (
    os.path.join(data_folder, "pathwayae", dataset) for dataset in ("tcga", "metabric")
)
# Only the TCGA folder is created here.
os.makedirs(tcga_folder, exist_ok=True)

In [None]:
# TCGA cohort code; interpolated into the TCGA expression and clinical file names below.
cancer_type = "BRCA"

In [None]:
# METABRIC expression matrix: one row per gene (Hugo symbol) in the file, one column per sample.
metabric_ensembl_counts_gex_tsv_fname = os.path.join(meta_folder, "data_mrna_agilent_microarray.txt.gz")
gex_meta = pd.read_csv(metabric_ensembl_counts_gex_tsv_fname, sep="\t", index_col="Hugo_Symbol").drop(columns="Entrez_Gene_Id")
# Transpose to samples x genes and drop every gene that has a missing value.
gex_meta = gex_meta.T.dropna(axis="columns")
gex_meta.columns = gex_meta.columns.str.upper()
# Name the axes only after transposing: rows are samples, columns are genes.
# (Renaming the index before `.T` labelled the gene axis "SampleID", which then
# showed up as the gene-axis label of the METABRIC heatmaps.)
gex_meta = gex_meta.rename_axis(index="SampleID", columns="GeneName")

# Same rescaling round-trip that is applied to the TCGA matrix.
gex_meta = to_log2pk(fpkm_to_tpm(from_log2pk(gex_meta)))

gex_meta.shape

In [None]:
# Peek at the first METABRIC sample identifiers.
gex_meta.index[:10]

In [None]:
# Load the METABRIC clinical table, indexed by patient id, and expose the
# CLAUDIN_SUBTYPE column under the name "PAM50".
# NOTE(review): the following cell repeats these four statements and overwrites
# `metabric_phenotype`, so this cell appears to be redundant.
metabric_phenotype_csv_fname = os.path.join(meta_folder, f"data_clinical_patient.txt.gz")
metabric_phenotype = pd.read_csv(metabric_phenotype_csv_fname, sep="\t", comment="#", index_col="PATIENT_ID")
metabric_phenotype.index.rename("SampleID", inplace=True)
metabric_phenotype.rename(columns={"CLAUDIN_SUBTYPE": "PAM50"}, inplace=True)


In [None]:
# Load the METABRIC clinical table and reduce it to subtype + overall survival.
metabric_phenotype_csv_fname = os.path.join(meta_folder, "data_clinical_patient.txt.gz")
metabric_phenotype = (
    pd.read_csv(metabric_phenotype_csv_fname, sep="\t", comment="#", index_col="PATIENT_ID")
    .rename_axis(index="SampleID")
    .rename(columns={"CLAUDIN_SUBTYPE": "PAM50"})
    .loc[:, ["PAM50", "OS_MONTHS", "OS_STATUS"]]
    .copy()  # own the data so the column assignments below never write into a slice
)
# Assign the recoded column back instead of `df[col].replace(..., inplace=True)`:
# the chained in-place form does not update the frame under pandas copy-on-write.
metabric_phenotype["OS_STATUS"] = metabric_phenotype["OS_STATUS"].replace({"0:LIVING":0,"1:DECEASED":1})
metabric_phenotype = metabric_phenotype.rename(columns={'OS_MONTHS':"OS_Time",'OS_STATUS':"OS_Event"})
# Months -> days, using an average month length of 365/12 days.
metabric_phenotype["OS_Time"] *= 365/12
metabric_phenotype = metabric_phenotype.dropna(subset=["PAM50"])
# Remove the subtype labels that have no entry in the classification map used later.
excluded_samples = metabric_phenotype.index[metabric_phenotype["PAM50"].isin(["claudin-low", "NC"])]
metabric_phenotype = metabric_phenotype.drop(index=excluded_samples)

for c in metabric_phenotype.columns:
    print(c, metabric_phenotype[c].count(), metabric_phenotype[c].value_counts())

metabric_phenotype.index[:10]

In [None]:
# Every remaining METABRIC sample must carry a subtype label.
assert not metabric_phenotype["PAM50"].isna().any()

In [None]:
# Distribution of the METABRIC subtype labels.
metabric_phenotype["PAM50"].hist()

In [None]:
# TCGA expression matrix: transpose to samples x Ensembl IDs and drop genes with missing values.
ensembl_fpkm_gex_tsv_fname = os.path.join(tcga_folder, f"TCGA-{cancer_type}.htseq_fpkm.tsv.gz")
gex = pd.read_csv(ensembl_fpkm_gex_tsv_fname, sep="\t", index_col="Ensembl_ID").T.dropna(axis="columns")
gex = to_log2pk(fpkm_to_tpm(from_log2pk(gex)))

# Constant genes carry no information; compute the variance once and reuse it.
gex_variance = gex.var()
zero_variance_columns = set(gex_variance[gex_variance==0].index)

gex = gex.drop(columns=list(zero_variance_columns))

with open(os.path.join(tcga_folder, "ensembl_to_gene_id.json")) as f:
    ensembl_to_gex_dict = json.load(f)

# Drop Ensembl IDs that have no gene symbol. Restrict to IDs still present in the
# matrix so that IDs removed above (or never in the file) cannot raise a KeyError.
columns_to_drop = [k for k, symbol in ensembl_to_gex_dict.items() if symbol=="" and k in gex.columns]

gex = gex.drop(columns=columns_to_drop)

gex.shape

In [None]:
# Group the Ensembl IDs still present in `gex` by the gene symbol they map to,
# then keep only the symbols that are hit by more than one Ensembl ID.
symbol_to_ensembl_ids = collections.defaultdict(list)
for ensembl_id, symbol in ensembl_to_gex_dict.items():
    if symbol != '' and ensembl_id in gex.columns:
        symbol_to_ensembl_ids[symbol].append(ensembl_id)

value_counts = {symbol: ids for symbol, ids in symbol_to_ensembl_ids.items() if len(ids) > 1}
# Number of ambiguous symbols, and total number of Ensembl IDs involved.
print(len(value_counts), sum(map(len, value_counts.values())))

In [None]:
# Hand-picked gene symbols that are looked up both in `value_counts` (several
# Ensembl IDs per symbol in TCGA) and as a column of the METABRIC matrix.
duplicate_and_in_both = ["MATR3", "EMG1", "TMSB15B", "BMS1P4", "POLR2J4",]

for v in duplicate_and_in_both:
    # Summary statistics and pairwise correlation of the competing TCGA columns,
    # followed by the summary of the single METABRIC column for comparison.
    display(gex[value_counts[v]].describe())
    display(gex[value_counts[v]].corr())
    display(gex_meta[v].describe())

[(v,value_counts[v]) for v in duplicate_and_in_both]

In [None]:
# Merge values that map to the same gene symbol: undo the log transform, average
# the competing Ensembl columns per sample, and re-apply the transform.
merged_columns = {
    g_name: to_log2pk(from_log2pk(gex[g_ensembl_ids], 1).mean(axis=1), 1)
    for g_name, g_ensembl_ids in value_counts.items()
}
if merged_columns:
    # Drop all ambiguous Ensembl columns and append the merged ones in a single
    # step, instead of inserting/dropping columns one gene at a time.
    ambiguous_ensembl_ids = [ensembl_id for ids in value_counts.values() for ensembl_id in ids]
    gex = pd.concat(
        [gex.drop(columns=ambiguous_ensembl_ids), pd.DataFrame(merged_columns, index=gex.index)],
        axis=1,
    )

In [None]:
# Rename the rest: map the remaining Ensembl IDs to gene symbols and label both axes.
gex = gex.rename(columns=ensembl_to_gex_dict).rename_axis(index="SampleID", columns="GeneName")

In [None]:
# (samples, genes) of the TCGA and METABRIC matrices before restricting to shared genes.
gex.shape, gex_meta.shape

In [None]:
# Gene symbols present in both datasets, sorted so both matrices get the same column order.
genes_in_both = sorted(set(gex_meta.columns).intersection(set(gex.columns)))
len(genes_in_both)

In [None]:
# Restrict both matrices to the shared genes, in the same order.
gex = gex.loc[:,genes_in_both]
gex_meta = gex_meta.loc[:,genes_in_both]
gex.shape, gex_meta.shape

In [None]:
# Both matrices must expose exactly the same genes in the same order.
# `Index.equals` also fails on a length mismatch, which a zip-based comparison
# would silently truncate away.
assert gex.columns.equals(gex_meta.columns)

In [None]:
# Folder holding the pathway definition JSON files; created if missing.
pathway_folder = os.path.join(data_folder, "pathways")
os.makedirs(pathway_folder, exist_ok=True)

In [None]:
# Gene universe of the expression matrix, plus a symbol -> column position lookup.
gex_genes = set(gex.columns.values)
gex_genes_indexer = {gene: position for position, gene in enumerate(gex.columns.values)}

def get_pathways_with_indices(pathways):
    """Translate each pathway (an iterable of gene symbols) into column positions of `gex`.

    Raises KeyError if a symbol is not a column of `gex`.
    """
    return [[gex_genes_indexer[gene] for gene in pathway] for pathway in pathways]

In [None]:
# KEGG pathway definitions, read via the project helper with the expression gene set,
# then expressed as column indices.
kegg_pathways = read_pathway_from_json_file(os.path.join(pathway_folder,"c2.cp.kegg.v7.5.1.json"), gex_genes)
kegg_pathways_with_indices = get_pathways_with_indices(kegg_pathways)
# Summary: pathway count, total pathway membership, and number of distinct genes used.
number_of_pathways = len(kegg_pathways_with_indices)
pathways_input_dimension = sum(map(len, kegg_pathways_with_indices))
number_of_input_genes = len(set(itertools.chain.from_iterable(kegg_pathways_with_indices)))
number_of_pathways, pathways_input_dimension, number_of_input_genes

In [None]:
# Hallmark pathway definitions, read via the project helper with the expression gene set,
# then expressed as column indices.
hallmark_pathways = read_pathway_from_json_file(os.path.join(pathway_folder,"h.all.v7.5.1.json"), gex_genes)
hallmark_pathways_with_indices = get_pathways_with_indices(hallmark_pathways)
# Summary: pathway count, total pathway membership, and number of distinct genes used.
number_of_pathways = len(hallmark_pathways_with_indices)
pathways_input_dimension = sum(map(len, hallmark_pathways_with_indices))
number_of_input_genes = len(set(itertools.chain.from_iterable(hallmark_pathways_with_indices)))
number_of_pathways, pathways_input_dimension, number_of_input_genes

In [None]:
# TCGA clinical matrix, indexed by sample id.
ensembl_fpkm_phenotype_tsv_fname = os.path.join(tcga_folder, f"TCGA.{cancer_type}.sampleMap_{cancer_type}_clinicalMatrix")
phenotype = pd.read_csv(ensembl_fpkm_phenotype_tsv_fname, sep="\t", index_col="sampleID")
phenotype.index = phenotype.index.rename("SampleID")
# Keep every column whose name mentions "pam50" plus the two overall-survival columns.
phenotype = phenotype[[c for c in phenotype.columns if "pam50" in c.lower()] + ['OS_Time_nature2012', 'OS_event_nature2012',]]
# Use the same survival column names as the METABRIC table.
phenotype.rename(columns={'OS_Time_nature2012':"OS_Time",'OS_event_nature2012':"OS_Event"}, inplace=True)
# Non-null count and value distribution of each kept column.
for c in phenotype.columns:
    print(phenotype[c].count(), phenotype[c].value_counts())

In [None]:
# Names of the subtype columns in the TCGA and METABRIC phenotype tables.
phenotype_clf_tgt = "PAM50Call_RNAseq"
phenotype_clf_tgt_meta = "PAM50"
# Integer code for each subtype label.
phenotype_clf_map = {
    "LumA":0,
    "LumB":1,
    "Basal":2,
    "Normal":3,
    "Her2":4,
}
# String spellings that are treated as missing: the keys are "nan", "not reported" and "".
phenotype_clf_nan = {f"{value}":np.nan for value in [np.nan, "not reported", ""]}
phenotype.columns

In [None]:
# Copy the TCGA subtype call into a dedicated column, turn the "missing" spellings
# into NaN, and drop the samples that have no subtype.
PHENOTYPE_CLF_COLUMN = "subtype"
phenotype[PHENOTYPE_CLF_COLUMN] = phenotype[phenotype_clf_tgt].replace(phenotype_clf_nan)
phenotype = phenotype.dropna(subset=[PHENOTYPE_CLF_COLUMN])
phenotype.columns

In [None]:
# For every phenotype sample id, collect the expression sample ids that start with it.
_possible_mappings = {
    idx: list(gex.index[gex.index.str.startswith(idx)].values)
    for idx in phenotype.index
}
# When several expression ids match, use the lexicographically smallest one;
# phenotype ids without any match keep their original label.
_replacements = {k: min(v) for k, v in _possible_mappings.items() if len(v) > 0}
phenotype = phenotype.rename(index=_replacements)
phenotype.columns

In [None]:
# TCGA samples that have both expression data and a phenotype row.
both_index = sorted(set(phenotype.index).intersection(gex.index))
# Size and first five ids of each index, as a quick sanity check of the overlap.
[(len(idx), idx[:5],) for idx in [gex.index, phenotype.index, both_index]]

In [None]:
# METABRIC samples that have both expression data and a phenotype row.
both_index_meta = sorted(set(metabric_phenotype.index).intersection(gex_meta.index))
# Size and first five ids of each index, as a quick sanity check of the overlap.
[(len(idx), idx[:5],) for idx in [gex_meta.index, metabric_phenotype.index, both_index_meta]]

In [None]:
# Keep only the samples that have both expression and phenotype data.
gex = gex.loc[both_index]
gex_meta = gex_meta.loc[both_index_meta]

# Keep the complete phenotype tables; the survival columns are read from them later.
full_phenotype = phenotype
full_phenotype_meta = metabric_phenotype

# Build the aligned classification tables as independent copies, so the column
# assignments below never write into a slice of the full tables.
phenotype = full_phenotype.loc[both_index,[PHENOTYPE_CLF_COLUMN,]].copy()
phenotype_meta = full_phenotype_meta.loc[both_index_meta,["PAM50"]].copy()
# Integer-coded subtype.
phenotype[PHENOTYPE_CLF_COLUMN] = phenotype[PHENOTYPE_CLF_COLUMN].replace(phenotype_clf_map)
phenotype_meta[PHENOTYPE_CLF_COLUMN] = phenotype_meta["PAM50"].replace(phenotype_clf_map)
# Original string labels, taken from the full tables in both cases
# (the METABRIC line previously assigned the column to itself).
phenotype[phenotype_clf_tgt] = full_phenotype.loc[phenotype.index,phenotype_clf_tgt]
phenotype_meta[phenotype_clf_tgt_meta] = full_phenotype_meta.loc[phenotype_meta.index,phenotype_clf_tgt_meta]


(
    phenotype[PHENOTYPE_CLF_COLUMN].value_counts(), phenotype[PHENOTYPE_CLF_COLUMN].dtype, phenotype[PHENOTYPE_CLF_COLUMN].unique(), phenotype[PHENOTYPE_CLF_COLUMN].describe(),
    "\n",
    phenotype_meta[PHENOTYPE_CLF_COLUMN].value_counts(), phenotype_meta[PHENOTYPE_CLF_COLUMN].dtype, phenotype_meta[PHENOTYPE_CLF_COLUMN].unique(), phenotype_meta[PHENOTYPE_CLF_COLUMN].describe(),
)

In [None]:
# TCGA expression rows and phenotype rows must be the same samples in the same order.
# `Index.equals` also fails on a length mismatch, which a zip-based check would not notice.
assert gex.index.equals(phenotype.index)

In [None]:
# METABRIC expression rows and phenotype rows must be the same samples in the same order.
# `Index.equals` also fails on a length mismatch, which a zip-based check would not notice.
assert gex_meta.index.equals(phenotype_meta.index)

In [None]:
def plot_2d_space(df, SpaceTransformer=sklearn.decomposition.PCA, **kwargs):
    """Scatter-plot the first two components of `df` after a fitted transform.

    `SpaceTransformer` is instantiated with no arguments and must provide
    `fit_transform`; extra keyword arguments are forwarded to `sns.scatterplot`,
    whose return value is returned.
    """
    embedding = SpaceTransformer().fit_transform(df)
    first_component = embedding[:,0]
    second_component = embedding[:,1]
    return sns.scatterplot(x=first_component, y=second_component, **kwargs)

In [None]:
# Number of gene columns in the TCGA expression matrix.
genes_dim = gex.values.shape[1]
gex.values.shape[1]

In [None]:
# Map each dataset to a normal output distribution with sklearn's quantile transform,
# fitted independently on TCGA and on METABRIC.
# NOTE(review): n_quantiles is set to the larger of (n_samples, n_genes); confirm how
# sklearn treats a value that exceeds the number of samples.
gex_X = sklearn.preprocessing.quantile_transform(gex.values, n_quantiles=max(*gex.values.shape), output_distribution="normal")
gex_meta_X = sklearn.preprocessing.quantile_transform(gex_meta.values, n_quantiles=max(*gex_meta.values.shape), output_distribution="normal")

In [None]:
# One index tensor per pathway, for each pathway collection.
kegg_p = list(map(torch.tensor, kegg_pathways_with_indices))
hmrk_p = list(map(torch.tensor, hallmark_pathways_with_indices))

# Pathway collections addressable by display name.
pways_keys_lst = ["KEGG", "Hallmark Genes"]
pways_defs_lst = [kegg_p, hmrk_p]
pways_defs_dict = {key: definitions for key, definitions in zip(pways_keys_lst, pways_defs_lst)}

In [None]:
# Re-read the raw Hallmark JSON to recover pathway names and gene symbols.
hallmark_pathway_description_path = os.path.join(pathway_folder,"h.all.v7.5.1.json")

with open(hallmark_pathway_description_path, "r") as f:
    hallmark_pathway_descriptions = json.load(f)
hallmark_pathway_genes = [
    (pathway_name, description["geneSymbols"])
    for pathway_name, description in hallmark_pathway_descriptions.items()
]
hallmark_all_pathway_genes = set().union(*(symbols for pathway_name, symbols in hallmark_pathway_genes))
hallmark_common_genes = hallmark_all_pathway_genes & gex_genes
# Per pathway: only the genes that are also measured, in their original order.
hallmark_pathway_genes_with_allowed_genes = [
    (pathway_name, [gene for gene in symbols if gene in hallmark_common_genes])
    for pathway_name, symbols in hallmark_pathway_genes
]

hallmark_named_and_indexed = list(zip(hallmark_pathway_genes_with_allowed_genes, hallmark_pathways_with_indices))
hallmark_pway_names = [named[0] for named, indexed in hallmark_named_and_indexed]
# Each named pathway must have as many genes as its index-based counterpart.
assert all(len(named[1]) == len(indexed) for named, indexed in hallmark_named_and_indexed)
hallmark_pway_names = [pname.replace("HALLMARK_","") for pname in hallmark_pway_names]
len(hallmark_pway_names), hallmark_pway_names[:5]

In [None]:
# Re-read the raw KEGG JSON to recover pathway names and gene symbols.
kegg_pathway_description_path = os.path.join(pathway_folder,"c2.cp.kegg.v7.5.1.json")

with open(kegg_pathway_description_path, "r") as f:
    kegg_pathway_descriptions = json.load(f)
kegg_pathway_genes = [
    (pathway_name, description["geneSymbols"])
    for pathway_name, description in kegg_pathway_descriptions.items()
]
kegg_all_pathway_genes = set().union(*(symbols for pathway_name, symbols in kegg_pathway_genes))
kegg_common_genes = kegg_all_pathway_genes & gex_genes
# Per pathway: only the genes that are also measured, in their original order.
kegg_pathway_genes_with_allowed_genes = [
    (pathway_name, [gene for gene in symbols if gene in kegg_common_genes])
    for pathway_name, symbols in kegg_pathway_genes
]

kegg_named_and_indexed = list(zip(kegg_pathway_genes_with_allowed_genes, kegg_pathways_with_indices))
kegg_pway_names = [named[0] for named, indexed in kegg_named_and_indexed]
# Each named pathway must have as many genes as its index-based counterpart.
assert all(len(named[1]) == len(indexed) for named, indexed in kegg_named_and_indexed)
kegg_pway_names = [pname.replace("KEGG_","") for pname in kegg_pway_names]
len(kegg_pway_names), kegg_pway_names[:5]

In [None]:
# --- Analysis configuration -------------------------------------------------
# Which pathway collection and which trained model to analyse. The bracketed
# groups in the model name are parsed into layer sizes in a later cell.
pathway_set_to_use = "KEGG"
base_model_name_to_use = "PAAE-[32]-[64]"

# (pathway name, measured genes) lists per collection.
pathway_set_to_use_dict = {
    "Hallmark Genes": hallmark_pathway_genes_with_allowed_genes,
    "KEGG": kegg_pathway_genes_with_allowed_genes,
}

# Per-collection `k` settings for the different plot types.
# NOTE(review): the three maps below hold identical values and are not read in
# the visible part of the notebook — confirm they are still needed.
clustermap_k_map = {
    "KEGG": 32,
    "Hallmark Genes": 50
}

featuremap_k_map = {
    "KEGG": 32,
    "Hallmark Genes": 50
}

violinplot_k_map = {
    "KEGG": 32,
    "Hallmark Genes": 50
}

# Pathways that get a detailed per-gene analysis, per collection.
pways_to_look_at_map = {
    "KEGG": ["KEGG_GLYCOSAMINOGLYCAN_BIOSYNTHESIS_KERATAN_SULFATE", "KEGG_SPHINGOLIPID_METABOLISM", "KEGG_VALINE_LEUCINE_AND_ISOLEUCINE_DEGRADATION", "KEGG_P53_SIGNALING_PATHWAY", "KEGG_PANCREATIC_CANCER", "KEGG_GLIOMA", "KEGG_MISMATCH_REPAIR"],
    "Hallmark Genes": ["HALLMARK_P53_PATHWAY", "HALLMARK_UV_RESPONSE_DN", "HALLMARK_CHOLESTEROL_HOMEOSTASIS", "HALLMARK_APICAL_SURFACE", "HALLMARK_MTORC1_SIGNALING"]
}

# Resolve the selections for the chosen pathway collection.
pway_names_and_genes = pathway_set_to_use_dict[pathway_set_to_use]
pways_to_look_at = pways_to_look_at_map[pathway_set_to_use]
base_model_name = f"{base_model_name_to_use} ({pathway_set_to_use})"
base_model_name

In [None]:
model_name = base_model_name
# File-system friendly variant of the model name: brackets of "(...)" and dashes
# become underscores, spaces are removed.
model_fname = model_name.replace('(','_').replace(')','_').replace('-','_').replace(' ','')
# The model name carries two "[a,b,...]" groups: pathway-encoder hidden sizes first,
# then the hidden sizes followed by the encoding size. A raw string is used for the
# pattern because "\[" is an invalid escape sequence in a normal string literal.
bracketed_dim_groups = re.findall(r"\[(.*?)\]", model_name)
pathway_hidden_dims = list(map(int, bracketed_dim_groups[0].split(",")))
hidden_and_enc_dims = list(map(int, bracketed_dim_groups[1].split(",")))
model_fname, hidden_and_enc_dims, pathway_hidden_dims

In [None]:
# Output locations (relative to the working directory): model results, and one
# image sub-folder per export format.
results_folder = "results/metabric"
images_folder = "images/metabric"
SAVING_FORMATS = ["png", "pdf", "svg"]

output_folders = [results_folder, images_folder]
output_folders.extend(os.path.join(images_folder, fmt) for fmt in SAVING_FORMATS)
for folder in output_folders:
    os.makedirs(folder, exist_ok=True)

In [None]:
# Pick the model class from the model name.
if "PAVAE" in model_name:
    ModelClass = PAVAE
elif "PAAE" in model_name:
    ModelClass = PAAE
else:
    raise ValueError(f"{model_name} does not exist")

# Both model classes take the same constructor arguments; the input width is
# always taken from the matrix that is actually fed to the model.
genes_dim = gex_X.shape[1]
model_kwargs = {
    "genes_dim": genes_dim,
    "pathway_definitions": pways_defs_dict[pathway_set_to_use],
    "pathway_hidden_dims": pathway_hidden_dims,
    "hidden_dims": hidden_and_enc_dims[:-1],
    "encoding_dim": hidden_and_enc_dims[-1],
}

model = ModelClass(**model_kwargs)
# map_location="cpu": the inputs below are CPU tensors and the outputs are turned
# into numpy arrays, so the weights must load on CPU even if they were saved from a GPU.
model.load_state_dict(torch.load(os.path.join(results_folder,f"{model_fname}.pt"), map_location="cpu"))
model.eval()

# Build each input tensor once and reuse it.
gex_X_tensor = torch.tensor(gex_X, dtype=torch.float32)
gex_meta_X_tensor = torch.tensor(gex_meta_X, dtype=torch.float32)

# Inference only: no autograd graph is needed.
with torch.no_grad():
    gex_a = model.get_pathway_activities(gex_X_tensor).detach().numpy()
    gex_meta_a = model.get_pathway_activities(gex_meta_X_tensor).detach().numpy()

    # The latent code is the first element returned by `encode` (PAVAE) or by the
    # forward pass (PAAE).
    if ModelClass is PAVAE:
        gex_z = model.encode(gex_X_tensor)[0].detach().numpy()
        gex_meta_z = model.encode(gex_meta_X_tensor)[0].detach().numpy()
    else:
        gex_z = model(gex_X_tensor)[0].detach().numpy()
        gex_meta_z = model(gex_meta_X_tensor)[0].detach().numpy()

len(model.pathway_encoders), gex_a.shape, gex_z.shape, gex_meta_a.shape, gex_meta_z.shape

In [None]:
def get_simple_feature_importance(E:MLP) -> np.ndarray:
    """Collapse an encoder into one signed weight per input feature.

    Multiplies the weight matrices of `E.layers` from the first layer to the last
    (activations and biases are ignored) and squeezes singleton dimensions.
    Assumes every entry of `E.layers` exposes a `.weight` tensor — TODO confirm.
    """
    combined = E.layers[0].weight.detach().numpy()
    for layer in E.layers[1:]:
        combined = layer.weight.detach().numpy() @ combined
    return combined.squeeze()

def get_simple_abs_feature_importance(E:MLP) -> np.ndarray:
    """Collapse an encoder into one non-negative weight per input feature.

    Multiplies the element-wise absolute weight matrices of `E.layers` from the
    first layer to the last (activations and biases are ignored) and squeezes
    singleton dimensions.
    Assumes every entry of `E.layers` exposes a `.weight` tensor — TODO confirm.
    """
    combined = np.abs(E.layers[0].weight.detach().numpy())
    for layer in E.layers[1:]:
        combined = np.abs(layer.weight.detach().numpy()) @ combined
    return combined.squeeze()

In [None]:
# Signed per-gene weight of every pathway encoder, paired with (pathway name, genes).
pathway_importances = [
    (name_and_genes, get_simple_feature_importance(encoder))
    for name_and_genes, encoder in zip(pway_names_and_genes, model.pathway_encoders)
]
pathways = [name_and_genes[0] for name_and_genes, importance in pathway_importances]
genes = sorted(set(itertools.chain.from_iterable(
    name_and_genes[1] for name_and_genes, importance in pathway_importances
)))
gene_to_idx_map = {gene: column for column, gene in enumerate(genes)}

# pathways x genes matrix; NaN marks genes that do not belong to the pathway.
feature_importance_matrix = np.full((len(pathway_importances), len(genes)), np.nan)

for pwy_idx, ((pwy, gen), imp) in enumerate(pathway_importances):
    feature_importance_matrix[pwy_idx, [gene_to_idx_map[g] for g in gen]] = imp

pathway_feature_importance_df = pd.DataFrame(feature_importance_matrix, index=pathways, columns=genes)
pathway_feature_importance_df

In [None]:
# Absolute-value per-gene weight of every pathway encoder, paired with (pathway name, genes).
pathway_importances = [
    (name_and_genes, get_simple_abs_feature_importance(encoder))
    for name_and_genes, encoder in zip(pway_names_and_genes, model.pathway_encoders)
]
pathways = [name_and_genes[0] for name_and_genes, importance in pathway_importances]
genes = sorted(set(itertools.chain.from_iterable(
    name_and_genes[1] for name_and_genes, importance in pathway_importances
)))
gene_to_idx_map = {gene: column for column, gene in enumerate(genes)}

# pathways x genes matrix; NaN marks genes that do not belong to the pathway.
feature_importance_matrix = np.full((len(pathway_importances), len(genes)), np.nan)

for pwy_idx, ((pwy, gen), imp) in enumerate(pathway_importances):
    feature_importance_matrix[pwy_idx, [gene_to_idx_map[g] for g in gen]] = imp

pathway_feature_importance_abs_df = pd.DataFrame(feature_importance_matrix, index=pathways, columns=genes)
pathway_feature_importance_abs_df

In [None]:
# Candidate values 2, 4, ..., imp_k_max; the cells below use max(imp_k_range) and
# keep half that many genes per pathway.
imp_k_max = 20
imp_k_range = list(range(2, imp_k_max+1, 2))

In [None]:
# Print LaTeX table rows: per pathway, the genes with the largest absolute weight
# (strongest first) and their signed weights with two decimals.
for pway_name in pways_to_look_at:
    imp_k = max(imp_k_range)
    pway_weights = pathway_feature_importance_df.loc[pway_name]
    # Ascending sort on |weight|; the tail holds the imp_k/2 strongest genes.
    most_strong_noabs = pway_weights.dropna().abs().sort_values()[-imp_k//2:]
    most_important_anpw = pway_weights[most_strong_noabs.index].sort_values(key=np.abs, ascending=False)
    latex_pway_name = pway_name.replace("HALLMARK_","").replace("KEGG_","").replace("_","\\_")
    print("\\multirow{2}{*}{" + latex_pway_name + "}")
    print("", *most_important_anpw.index, sep=" & ", end=" \\\\\n")
    print("", *(f"{weight:+.2f}" for weight in most_important_anpw.values), sep=" & ", end=" \\\\\n")
    print("\\midrule")

In [None]:
# Gene symbols of the PAM50 panel as used in this notebook.
pam_50_genes = ["FOXC1", "MIA", "KNTC2", "CEP55", "ANLN", "MELK", "GPR160", "TMEM45B", "ESR1", "FOXA1", "ERBB2", "GRB7", "FGFR4", "BLVRA", "BAG1", "CDC20", "CCNE1", "ACTR3B", "MYC", "SFRP1", "KRT14", "KRT17", "KRT5", "MLPH", "CCNB1", "CDC6", "TYMS", "UBE2T", "RRM2", "MMP11", "CXXC5", "ORC6L", "MDM2", "KIF2C", "PGR", "MKI67", "BCL2", "EGFR", "PHGDH", "CDH3", "NAT1", "SLC39A6", "MAPT", "UBE2C", "PTTG1", "EXO1", "CENPF", "CDCA1", "MYBL2", "BIRC5"]
# How many panel genes are covered by at least one pathway of the selected collection.
sum(1 for gene in pam_50_genes if gene in pathway_feature_importance_df.columns)

In [None]:
# For each inspected pathway, list which of its strongest genes belong to the PAM50 panel.
for pway_name in pways_to_look_at:
    imp_k = max(imp_k_range)
    # The imp_k/2 genes with the largest absolute weight (ascending sort, take the tail).
    most_strong_noabs = abs(pathway_feature_importance_df.loc[pway_name].dropna()).sort_values()[-imp_k//2:]
    # Same genes with their signed weights, strongest first.
    most_important_anpw = pathway_feature_importance_df.loc[pway_name,most_strong_noabs.index].sort_values(key=lambda x:np.abs(x), ascending=False)
    print(pway_name)
    print(*filter(lambda x: x in pam_50_genes, most_important_anpw.index))

In [None]:
# Reference value: log10 of a p-value close to the 0.05 threshold (about -1.3), for
# reading the log10 p-values printed and plotted below.
np.log10(0.0501187)

In [None]:
# Univariate Cox proportional-hazards screen of the strongest genes of each inspected
# pathway, run separately on TCGA and METABRIC.
for pway_name in pways_to_look_at:
    imp_k = max(imp_k_range)
    # The imp_k/2 genes with the largest absolute weight, then ordered strongest first.
    most_strong_noabs = abs(pathway_feature_importance_df.loc[pway_name].dropna()).sort_values()[-imp_k//2:]
    most_important_anpw = pathway_feature_importance_df.loc[pway_name,most_strong_noabs.index].sort_values(key=lambda x:np.abs(x), ascending=False)
    # One single-covariate Cox model per gene; samples with missing survival data are dropped.
    p_values_tcga = [
        lifelines.CoxPHFitter().fit(gex.loc[:,[feat]].join(full_phenotype[["OS_Time", "OS_Event"]]).dropna(), duration_col="OS_Time", event_col="OS_Event").summary.loc[feat,"p"]
        for feat in most_important_anpw.index
    ]
    p_values_meta = [
        lifelines.CoxPHFitter().fit(gex_meta.loc[:,[feat]].join(full_phenotype_meta[["OS_Time", "OS_Event"]]).dropna(), duration_col="OS_Time", event_col="OS_Event").summary.loc[feat,"p"]
        for feat in most_important_anpw.index
    ]
    # (gene, TCGA p-value, METABRIC p-value). Deliberately not named `all`: rebinding
    # the builtin would break every `assert all(...)` cell when it is re-run.
    gene_p_values = list(zip(most_important_anpw.index, p_values_tcga, p_values_meta))
    significant_both = [x for x in gene_p_values if x[1]<=0.05 and x[2]<=0.05]
    significant_any = [x for x in gene_p_values if x[1]<=0.05 or x[2]<=0.05]
    significant_tcga = [x for x in gene_p_values if x[1]<=0.05]
    significant_meta = [x for x in gene_p_values if x[2]<=0.05]
    # Fraction of the inspected genes that is significant in: both / any / TCGA / METABRIC.
    print(pway_name, *map(lambda x: len(x)/len(most_important_anpw.index), [significant_both, significant_any, significant_tcga, significant_meta]))
    print(*map(lambda x: f"{x[0]} {np.log10(x[1]):.1f} {np.log10(x[2]):.1f}", gene_p_values))
    print(*map(lambda x: x[0], significant_both))
    print(*map(lambda x: x[0], significant_tcga))
    print(*map(lambda x: x[0], significant_meta))
    print()

In [None]:
# One bar chart per inspected pathway: signed weights of its strongest genes,
# saved in every export format.
for pway_name in pways_to_look_at:
    imp_k = max(imp_k_range)
    # The imp_k/2 genes with the largest absolute weight (ascending sort, take the tail).
    most_strong_noabs = abs(pathway_feature_importance_df.loc[pway_name].dropna()).sort_values()[-imp_k//2:]
    # Here the genes are ordered by signed weight (largest positive first), not by magnitude.
    most_important_anpw = pathway_feature_importance_df.loc[pway_name,most_strong_noabs.index].sort_values(ascending=False)
    # Colour the bars by sign; dodge=False keeps one bar per gene.
    sns.barplot(x=most_important_anpw.values, y=most_important_anpw.index, hue=[("+" if b else "-") for b in (most_important_anpw.values>0)], dodge=False)
    plt.xlabel("Neural Path Weight")
    # Replace the automatic hue legend with an empty one.
    plt.legend([],[])
    plt.title(pway_name)
    for fmt in SAVING_FORMATS: plt.savefig(os.path.join(images_folder,fmt,f"importantfeatures_mostimportant_{model_fname}_{pway_name}.{fmt}"), bbox_inches="tight")
    # Close the figure so the next pathway starts from a clean canvas.
    plt.close()

In [None]:
# One histogram per inspected pathway: distribution of the signed weights of all
# genes in the pathway, saved in every export format.
for pway_name in pways_to_look_at:
    # NOTE(review): imp_k is assigned but not used in this loop.
    imp_k = max(imp_k_range)
    pathway_feature_importance_df.loc[pway_name].dropna().hist()
    plt.xlabel("Neural Path Weight")
    for fmt in SAVING_FORMATS: plt.savefig(os.path.join(images_folder,fmt,f"importantfeatures_weightdist_{model_fname}_{pway_name}.{fmt}"), bbox_inches="tight")
    plt.close()

In [None]:
# Print LaTeX table rows: per pathway, its strongest genes (by absolute weight,
# strongest first) and their signed weights with six decimals.
for pway_name in pways_to_look_at:
    imp_k = max(imp_k_range)
    # The imp_k/2 genes with the largest absolute weight (ascending sort, take the tail).
    most_strong_noabs = abs(pathway_feature_importance_df.loc[pway_name].dropna()).sort_values()[-imp_k//2:]
    most_important_anpw = pathway_feature_importance_df.loc[pway_name,most_strong_noabs.index].sort_values(key=lambda x:np.abs(x), ascending=False)
    # Strip the collection prefix and escape underscores for LaTeX.
    print("\\multirow{2}{*}{"+pway_name.replace("HALLMARK_","").replace("KEGG_","").replace("_","\\_")+"}")
    print("", *most_important_anpw.index, sep=" & ", end=" \\\\\n")
    print("", *map(lambda x: f"{x:+.6f}", most_important_anpw.values), sep=" & ", end=" \\\\\n")
    print("\\midrule")

In [None]:
# Categorical palette for the subtype labels: one colour per distinct TCGA label.
unique_phenotypes = phenotype[phenotype_clf_tgt].unique()

phenotype_pal = sns.color_palette(n_colors=len(unique_phenotypes))
phenotype_lut = dict(zip(unique_phenotypes, phenotype_pal))

# Categorical palette for the inspected pathways.
unique_pways = [p for p in pways_to_look_at]

# Request enough colours for subtypes + pathways and skip the first ones, so the
# pathway colours do not repeat the subtype colours.
pway_pal = sns.palettes.color_palette(sns.color_palette(n_colors=len(unique_pways)+len(unique_phenotypes))[len(unique_phenotypes):])
pway_lut = dict(zip(unique_pways, pway_pal))

display(phenotype_pal)
display(pway_pal)


In [None]:
# Accumulators filled across all inspected pathways: for every selected gene, the
# pathway it was selected from and the gene symbol itself (parallel lists).
genes_in_all_pathways_src_pathway = []
genes_in_all_pathways = []

for pway_name in pways_to_look_at:
    imp_k = max(imp_k_range)
    # The imp_k/2 genes with the largest absolute weight, then ordered strongest first.
    most_strong_noabs = abs(pathway_feature_importance_df.loc[pway_name].dropna()).sort_values()[-imp_k//2:]
    most_important_anpw = pathway_feature_importance_df.loc[pway_name,most_strong_noabs.index].sort_values(key=lambda x:np.abs(x), ascending=False)
    genes_in_all_pathways_src_pathway.extend([pway_name]*len(most_important_anpw.index))
    genes_in_all_pathways.extend(most_important_anpw.index)

    # One clustered heatmap (genes x samples) per dataset for this pathway.
    for gex_df, pheno_series, dset_label in [[gex, phenotype[phenotype_clf_tgt], "TCGA"], [gex_meta, phenotype_meta[phenotype_clf_tgt_meta], "Metabric"]]:
        # Per-sample colour derived from the subtype label.
        phenotype_colors = pd.Series(pheno_series, name="PAM50").map(phenotype_lut)

        plot_df = gex_df.loc[:,most_important_anpw.index]

        # Samples (columns) are clustered; genes (rows) keep the importance order.
        g = sns.clustermap(plot_df.T, 
                        col_colors=phenotype_colors,
                        col_cluster=True,
                        row_cluster=False,
                        cmap="vlag",
                        cbar_pos=(0.1, .2, .03, .4),
        )
        # NOTE(review): old_yticks / new_yticks are computed but never applied in this cell.
        old_yticks = g.ax_heatmap.get_yticks()
        new_yticks = np.arange(min(old_yticks),min(old_yticks)+len(plot_df.columns),1)
        g.ax_col_colors.set_title(pway_name)
        g.ax_heatmap.set_xticklabels([])
        g.ax_heatmap.set_xlabel(dset_label)
        # Manual subtype legend; passing the dict as labels uses its keys.
        handles = [Patch(facecolor=phenotype_lut[name]) for name in phenotype_lut]
        g.ax_row_dendrogram.remove()
        g.ax_col_dendrogram.remove()
        plt.legend(
            handles,
            phenotype_lut,
            title='PAM50',
            bbox_to_anchor=(0.2, 0.7),
            bbox_transform=plt.gcf().transFigure,
            loc='right')
        for fmt in SAVING_FORMATS: plt.savefig(os.path.join(images_folder,fmt,f"importantfeatures_heatmap_{model_fname}_{pway_name}_{dset_label}.{fmt}"), bbox_inches="tight")
        plt.close()

In [None]:
# Keep at most the first 50 collected genes. The lists are in pathway order
# (strongest first within each pathway), not globally ranked by weight.
plot_genes_in_all_pathways_src_pathway = genes_in_all_pathways_src_pathway[:50]
plot_genes_in_all_pathways = genes_in_all_pathways[:50]

# One combined clustered heatmap per dataset, with genes coloured by source pathway.
for gex_df, pheno_series, dset_label in [[gex, phenotype[phenotype_clf_tgt], "TCGA"], [gex_meta, phenotype_meta[phenotype_clf_tgt_meta], "Metabric"]]:
    # Per-sample subtype colour and per-gene pathway colour.
    phenotype_colors = pd.Series(pheno_series, name="PAM50").map(phenotype_lut)
    pway_colors = pd.Series(plot_genes_in_all_pathways_src_pathway, index=plot_genes_in_all_pathways, name="Pway").map(pway_lut)

    plot_df:pd.DataFrame = gex_df.loc[:,plot_genes_in_all_pathways]

    # Both samples (columns) and genes (rows) are clustered here.
    g = sns.clustermap(plot_df.T, 
                    col_colors=phenotype_colors,
                    col_cluster=True,
                    row_colors=pway_colors,
                    row_cluster=True,
                    cmap="vlag",
                    cbar_pos=(0.1, .2, .03, .4),
    )
    old_yticks = g.ax_heatmap.get_yticks()
    old_yticklabels = g.ax_heatmap.get_yticklabels()
    # Place one tick per gene row and label it following the clustered row order.
    new_yticks = np.arange(min(old_yticks),min(old_yticks)+len(plot_df.columns),1)
    g.ax_heatmap.set_yticks(new_yticks, labels = plot_df.columns[g.dendrogram_row.reordered_ind])
    g.ax_heatmap.set_xticklabels([])
    g.ax_heatmap.set_xlabel(dset_label)
    g.ax_row_dendrogram.remove()
    # Manual subtype legend; passing the dict as labels uses its keys.
    phenotype_handles = [Patch(facecolor=phenotype_lut[name]) for name in phenotype_lut]
    plt.legend(
        phenotype_handles,
        phenotype_lut,
        title='PAM50',
        bbox_to_anchor=(0.2, 0.7),
        bbox_transform=plt.gcf().transFigure,
        loc='right')
    # NOTE(review): the file name says "5pways" regardless of how many pathways are inspected.
    for fmt in SAVING_FORMATS: plt.savefig(os.path.join(images_folder,fmt,f"importantfeatures_heatmap_{model_fname}_5pways_{dset_label}.{fmt}"), bbox_inches="tight")
    plt.close()

In [None]:
# Print LaTeX table rows: per pathway, its strongest genes (by absolute weight,
# strongest first) and their signed weights with two decimals.
for pway_name in pways_to_look_at:
    imp_k = max(imp_k_range)
    # The imp_k/2 genes with the largest absolute weight (ascending sort, take the tail).
    most_strong_noabs = abs(pathway_feature_importance_df.loc[pway_name].dropna()).sort_values()[-imp_k//2:]
    most_important_anpw = pathway_feature_importance_df.loc[pway_name,most_strong_noabs.index].sort_values(key=lambda x:np.abs(x), ascending=False)
    # Strip the collection prefix and escape underscores for LaTeX.
    print("\\multirow{2}{*}{"+pway_name.replace("HALLMARK_","").replace("KEGG_","").replace("_","\\_")+"}")
    print("", *most_important_anpw.index, sep=" & ", end=" \\\\\n")
    print("", *map(lambda x: f"{x:+.2f}", most_important_anpw.values), sep=" & ", end=" \\\\\n")
    print("\\midrule")

In [None]:
n_clusters = 3
only_min_max = True
use_quantiles = True
q_values = [i/n_clusters for i in range(n_clusters+1)]
time_limit = 365*5


p_values_dict = {
    "feat": [],
    "feat_idx": [],
    "pway": [],
    "tcga-coxph": [],
    "tcga-logrank": [],
    "survf-t": [],
    "tcga-survf-lo": [],
    "tcga-survf-hi": [],
    "meta-coxph": [],
    "meta-logrank": [],
    "meta-survf-lo": [],
    "meta-survf-hi": [],
}

for pway_name in pways_to_look_at:
    imp_k = max(imp_k_range)
    most_strong_noabs = abs(pathway_feature_importance_df.loc[pway_name].dropna()).sort_values()[-imp_k//2:]
    most_important_anpw = pathway_feature_importance_df.loc[pway_name,most_strong_noabs.index].sort_values(key=lambda x:np.abs(x), ascending=False)
    p_values_tcga = [
        lifelines.CoxPHFitter().fit(gex.loc[:,[feat]].join(full_phenotype[["OS_Time", "OS_Event"]]).dropna(), duration_col="OS_Time", event_col="OS_Event").summary.loc[feat,"p"]
        for feat in most_important_anpw.index
    ]
    p_values_meta = [
        lifelines.CoxPHFitter().fit(gex_meta.loc[:,[feat]].join(full_phenotype_meta[["OS_Time", "OS_Event"]]).dropna(), duration_col="OS_Time", event_col="OS_Event").summary.loc[feat,"p"]
        for feat in most_important_anpw.index
    ]
    all = list(zip(most_important_anpw.index, p_values_tcga, p_values_meta))

    for feat_idx, (feat, p_tcga, p_meta) in enumerate(all):
        fig, axes = plt.subplots(nrows=1, ncols=2, sharex=True, sharey=True)
        kmfs_dict = {}
        # For each cohort: split the samples into expression clusters for `feat`
        # (cluster 0 = lowest expression, n_clusters-1 = highest), run a logrank
        # test between the clusters, and draw one Kaplan-Meier curve per cluster
        # on that cohort's axis.
        for ax_idx, (gex_df, pheno_df, label) in enumerate([[gex, full_phenotype, "TCGA"], [gex_meta, full_phenotype_meta, "Metabric"]]):
            if use_quantiles:
                feat_values = gex_df.loc[:,feat]
                quantiles = [feat_values.quantile(q=q) for q in q_values]
                last_bin_idx = len(quantiles) - 2
                # One boolean column per quantile bin. Bins are half-open
                # [q_low, q_high) except the last one, which is closed so that a
                # sample equal to the top quantile is not left without a bin
                # (argmax over an all-False row would silently return 0, i.e. "low").
                bin_membership = np.stack(
                    [
                        np.logical_and(
                            q_low<=feat_values,
                            (feat_values<=q_high) if bin_idx==last_bin_idx else (feat_values<q_high)
                        )
                        for bin_idx, (q_low, q_high) in enumerate(zip(quantiles[:-1],quantiles[1:]))
                    ],
                    axis=1
                )
                cluster_values = np.argmax(bin_membership, axis=1)
            else:
                # NOTE(review): KMeans is not seeded here, so cluster assignments
                # may differ between runs -- consider passing random_state.
                kmeans = sklearn.cluster.KMeans(n_clusters).fit(gex_df.loc[:,[feat]])
                raw_cluster_ids = kmeans.predict(gex_df.loc[:,[feat]])
                # Relabel clusters by the rank of their centre. argsort(centres)
                # answers "which cluster has the k-th smallest centre"; we need the
                # inverse permutation ("what rank does cluster k have"), which is
                # argsort applied twice. The two coincide only for n_clusters <= 2.
                centre_rank = np.argsort(np.argsort(kmeans.cluster_centers_.flatten()))
                cluster_values = centre_rank[raw_cluster_ids]
            # Logrank test
            logrank_pheno_df = pheno_df.loc[gex_df.index,["OS_Time","OS_Event"]]
            # Compare only the two extreme clusters when only_min_max is set,
            # otherwise compare all clusters.
            logrank_cluster_indexing = (
                np.logical_or(cluster_values==0,cluster_values==n_clusters-1)
                    if only_min_max else
                np.ones_like(cluster_values, dtype=bool)
            )
            logrank_p_result = lifelines.statistics.multivariate_logrank_test(
                event_durations = logrank_pheno_df.loc[logrank_cluster_indexing,"OS_Time"],
                event_observed = logrank_pheno_df.loc[logrank_cluster_indexing,"OS_Event"],
                groups = cluster_values[logrank_cluster_indexing],
                t_0 = time_limit,
            )
            # Cluster labels: human-readable names for 2-4 clusters, plain indices otherwise.
            cluster_labels = {
                2: ["low", "high"],
                3: ["low", "medium", "high"],
                4: ["low", "medium-low", "medium-high", "high"],
            }.get(n_clusters, list(range(n_clusters)))
            for i in ([0,n_clusters-1] if only_min_max else range(n_clusters)):
                cluster_idx = gex_df.index[cluster_values==i]
                plot_pheno = pheno_df.loc[cluster_idx,["OS_Time","OS_Event"]].dropna()
                kmf = lifelines.KaplanMeierFitter().fit(
                    plot_pheno["OS_Time"],
                    plot_pheno["OS_Event"],
                    label = cluster_labels[i],
                )
                # Keep the fitted estimator so survival values can be read off after this loop.
                kmfs_dict[(label,i)] = kmf
                # A negative time_limit means "no limit": plot up to the last observed time.
                kmf.plot(loc=slice(0,time_limit if time_limit>=0 else plot_pheno["OS_Time"].max()), ax=axes[ax_idx])
            axes[ax_idx].set_title(f"{label}" + " ($\\log_{10} p = " + f"{np.log10(logrank_p_result.p_value):.1f}$)")
            logrank_key = "tcga-logrank" if label == "TCGA" else "meta-logrank"
            p_values_dict[logrank_key].append(logrank_p_result.p_value)
        # Record the per-feature statistics (CoxPH p-values come from earlier in the loop body).
        p_values_dict["tcga-coxph"].append(p_tcga)
        p_values_dict["meta-coxph"].append(p_meta)
        # Latest time point covered by every fitted Kaplan-Meier curve in both
        # cohorts, so the survival estimates compared below are all defined there.
        # Intermediate clusters are 1 .. n_clusters-2 inclusive; they are only
        # fitted when only_min_max is False.
        fitted_clusters = [0, n_clusters-1] if only_min_max else list(range(n_clusters))
        last_surv_times = [
            max(kmfs_dict[(cohort,i)].survival_function_.index)
            for cohort in ("TCGA", "Metabric")
            for i in fitted_clusters
        ]
        # A negative time_limit means "no limit" (same convention as the plotting
        # code in this loop), so it must not cap the comparison time.
        smallest_last_surv_time = min(
            ([time_limit] if time_limit>=0 else []) + last_surv_times
        )
        p_values_dict["survf-t"].append(smallest_last_surv_time)
        # Survival probability of the lowest / highest expression cluster at that common time.
        p_values_dict["tcga-survf-lo"].append(kmfs_dict[("TCGA",0)].predict(smallest_last_surv_time))
        p_values_dict["tcga-survf-hi"].append(kmfs_dict[("TCGA",n_clusters-1)].predict(smallest_last_surv_time))
        p_values_dict["meta-survf-lo"].append(kmfs_dict[("Metabric",0)].predict(smallest_last_surv_time))
        p_values_dict["meta-survf-hi"].append(kmfs_dict[("Metabric",n_clusters-1)].predict(smallest_last_surv_time))
        p_values_dict["feat"].append(feat)
        p_values_dict["feat_idx"].append(feat_idx)
        p_values_dict["pway"].append(pway_name)
        fig.suptitle(
            f"{pway_name} - {feat}"
        )
        # Lay out once, after the suptitle exists, so it is accounted for.
        fig.tight_layout()
        for fmt in SAVING_FORMATS:
            fig.savefig(os.path.join(images_folder,fmt,f"kaplanmeier_{model_fname}_{pway_name}_{feat}.{fmt}"), bbox_inches="tight")
        # Close the figure to avoid accumulating open figures across features.
        plt.close(fig)
    


In [None]:
# Summarise the per-feature survival statistics: keep features from the first
# five pathways of interest that are logrank-significant in TCGA, then check
# whether they are also significant in Metabric and whether the high-vs-low
# survival difference has the same sign in both cohorts.
logrank_alpha = 0.05  # significance threshold applied to the logrank p-values

p_values_df = pd.DataFrame(p_values_dict)

# .copy() after each boolean filter so the column assignments below write to an
# independent frame rather than a slice (avoids SettingWithCopyWarning).
p_values_df = p_values_df.loc[p_values_df["pway"].isin(list(pways_to_look_at[:5]))].copy()

p_values_df["tcga-logrank-significant"] = p_values_df["tcga-logrank"]<logrank_alpha
p_values_df = p_values_df.loc[p_values_df["tcga-logrank-significant"]].copy()
p_values_df["meta-logrank-significant"] = p_values_df["meta-logrank"]<logrank_alpha

p_values_df["both-logrank-significant"] = np.logical_and(p_values_df["tcga-logrank-significant"], p_values_df["meta-logrank-significant"])

# Survival difference between the highest and lowest expression cluster, per cohort.
p_values_df["tcga-survf-a"] = p_values_df["tcga-survf-hi"]-p_values_df["tcga-survf-lo"]
p_values_df["meta-survf-a"] = p_values_df["meta-survf-hi"]-p_values_df["meta-survf-lo"]

# True when the direction of the survival difference agrees between cohorts.
p_values_df["survf-match"] = np.sign(p_values_df["tcga-survf-a"])==np.sign(p_values_df["meta-survf-a"])

p_values_df["both-logrank-significant-and-survf-match"] = np.logical_and(p_values_df["both-logrank-significant"],p_values_df["survf-match"])
p_values_df = p_values_df.sort_values(by=["both-logrank-significant", "feat"], axis="index", ascending=[False,True])
print(model_fname)
# Fraction of TCGA-significant features that replicate in Metabric with matching direction.
print(p_values_df["both-logrank-significant-and-survf-match"].mean())
print(p_values_df["both-logrank-significant-and-survf-match"].describe())
display(p_values_df[["feat","pway","both-logrank-significant-and-survf-match","both-logrank-significant","survf-match","tcga-logrank","meta-logrank","survf-t","tcga-survf-lo","tcga-survf-hi","meta-survf-lo","meta-survf-hi",]])

In [None]:
# Logrank p-values of both cohorts on a log10 scale, indexed by (feature, pathway)
# and rounded to one decimal.
(
    p_values_df
    .set_index(["feat","pway"])
    .loc[:, ["tcga-logrank","meta-logrank"]]
    .apply(np.log10)
    .round(1)
)