# Motivation

The role of this notebook is to explore the biological significance of the selected tf

In [1]:
%load_ext autoreload
import pandas as pd
import numpy as np
import os
import sys

import plotly.express as px
import plotly.io as pio


import multiprocess as mp


# own libraries
SCRIPT_DIR = os.path.dirname(os.path.abspath("pcgna_processing.py"))
sys.path.append(os.path.dirname(SCRIPT_DIR))
sys.path.append('/Users/vlad/Documents/Code/York/iNet_v2/src/')

from NetworkAnalysis.ExperimentSet import ExperimentSet
from NetworkAnalysis import GraphHelper as gh
from NetworkAnalysis.utilities import clustering as cs
from NetworkAnalysis.utilities import sankey_consensus_plot as sky
from NetworkAnalysis.utilities.helpers import save_fig, survival_plot, survival_comp
from NetworkAnalysis.GraphToolExp import GraphToolExperiment as GtExp
sys.path.append(os.path.dirname("../../src"))
# Gsea libraries

pio.templates.default = "ggplot2"


pool = mp.Pool(mp.cpu_count())

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
results_path = "../../results/exp/"
data_base = "../../data/"
base_path = "../../results/"
exp_folder_tumour = "network_I/tum/" # "/integration_v2.1/ - path from iNET
exp_folder_h42 = "network_I/healthy42/"
exp_folder_h42_ctrl = "network_I/healthyControls/"

figures_path = "selective_edge_pruning/"

vu_output = pd.read_csv(f"{data_base}/metadata/VU_clustering_v3.tsv", sep="\t", index_col="Sample")

tcga_mutations_df = pd.read_csv(f"{data_base}/tumour/mutations_tcga.csv")
tcga_mutations_df = tcga_mutations_df[tcga_mutations_df["count"] != 0].set_index("gene")

tum_tpms = pd.read_csv(f"{data_base}/tumour/TPMs_selected_genes_v3_13k_gc42.tsv", sep="\t", index_col="genes")
tum_tpms_v4 = pd.read_csv(f"{data_base}/tumour/tum_TPMs_selected_genes_gc42_all_v4.tsv", sep="\t", index_col="genes")

# tf list
tf_path = f"{data_base}/metadata/TF_names_v_1.01.txt"
if os.path.exists(tf_path):
    tf_list = np.genfromtxt(fname=tf_path, delimiter="\t", skip_header=1, dtype="str")

# Analysis

## Gene expression visualisation

In [49]:
sel_tfs = pd.read_csv(f'{data_base}/tf_ctrl.csv', index_col='gene')
sel_tfs['tum_mean_expression'] = tum_tpms_v4.loc[sel_tfs.index].mean(axis=1)
sel_tfs['tum_median_expression'] = tum_tpms_v4.loc[sel_tfs.index].median(axis=1)
sel_tfs['tum_std_expression'] = tum_tpms_v4.loc[sel_tfs.index].std(axis=1)

In [50]:
sel_tfs

Unnamed: 0_level_0,mut_count,tum_median_expression,healthy_median_expression,tum_mean_expression,healthy_mean_expression,tum_std_expression,healthy_std_expression
gene,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
TMF1,9.0,10.542958,33.355197,11.538859,32.092722,6.897833,12.021861
TGIF1,4.0,122.278603,133.417372,133.043308,158.564100,68.719895,87.264486
TRERF1,15.0,5.376512,10.024822,6.198969,11.768747,4.264199,6.629848
ZBTB10,1.0,2.886254,3.889252,3.804161,7.196493,3.285030,8.219635
NFATC4,7.0,21.945165,29.020569,31.025401,30.762123,29.833537,23.082292
...,...,...,...,...,...,...,...
REPIN1,2.0,43.175231,52.293611,48.269158,50.571216,29.774059,23.306956
JUNB,3.0,248.349863,271.009715,286.124501,437.743241,181.303215,398.875691
TP63,14.0,36.632924,33.863527,46.864273,50.054863,50.317509,37.662011
ZNF224,10.0,6.923992,11.850942,8.754605,13.715299,6.660910,7.292112


In [65]:
# sel_tfs = sel_tfs[~sel_tfs.index.isin(["ELF3", "JUNB"])]
dmy_df = sel_tfs.copy(deep=True)
dmy_df['tum_mean_expression'] = sel_tfs['tum_mean_expression'] +1
dmy_df['healthy_mean_expression'] = sel_tfs['healthy_mean_expression'] +1

fig = px.scatter(
    dmy_df.reset_index(),
    x="tum_mean_expression",
    y="healthy_mean_expression",
    # error_x="tum_std_expression",
    # error_y="healthy_std_expression",
    # text="gene",
    color="mut_count",
    size="mut_count",
    hover_data="gene",
    color_continuous_scale=px.colors.sequential.Sunset_r,
    height=700,
    log_x=True,
    log_y=True,

)
# Get the top 10 genes with highest healthy median expression
top_genes_h = dmy_df.nlargest(10, "healthy_median_expression").index.tolist()
top_genes_t = dmy_df.nlargest(10, "tum_median_expression").index.tolist()
top_genes_m = dmy_df.nlargest(10, "mut_count").index.tolist()

top_genes = list(set(top_genes_h + top_genes_t + top_genes_m))
top_genes.extend(["FOSL1", "FOXQ1", "ERF", "MYCL", "STAT2", "IRF7", "CIC", "MBD6", "ZNF750", "ZNF513", "BNC1"])
chg_genes = ["KMT2A", "SPEN", "MEDCOM", "ERF", "AHR", "STAT2", "MBD6", "ZNF513", "TRERF1"]
# Add text annotation for the top 10 genes
for idx, gene in enumerate(top_genes):
    x = dmy_df.loc[dmy_df.index == gene, "tum_mean_expression"].values[0]
    y = dmy_df.loc[dmy_df.index == gene, "healthy_mean_expression"].values[0]
    ay = -30
    if idx % 2:
        ax = 5
    else:
        ax = -5

    if gene in chg_genes:
        ax = -ax
        ay = -ay

    if True:
        x, y, ax, ay = np.log10(x), np.log10(y), np.log10(ax), np.log10(ay)

    fig.add_annotation(x=x, y=y, text=gene, showarrow=True, arrowhead=1, ax=ax, ay=ay)

fig.update_layout(height=900)


invalid value encountered in log10



## Comparing with known markers

In [5]:
luminal_markers = ["KRT20", "PPARG", "FOXA1", "GATA3", "SNX31", "UPK1A", "UPK2", "FGFR3"]
basal_markers = ["CD44", "KRT6A", "KRT5", "KRT14", "COL17A1"]
squamos_markers = ["DSC3", "GSDMC", "TCGM1", "PI3", "TP63"]
immune_markers = ["CD274", "PDCD1LG2", "IDO1", "CXCL11", "L1CAM", "SAA1"]
neural_diff = ["MSI1", "PLEKHG4B", "GNG4", "PEG10", "RND2", "APLP1", "SOX2", "TUBB2B"]

# TCGA markers - main paper
emt_claudin = ["ZEB1", "ZEB2", "SNAI1", "TWIST1", "CDH2", "CLDN3", "CLDN4", "CLDN7"]
ecm_muscle = ["PGM5", "DES", "C7", "SFRP4", "COMP", "SGCD"]

tcga_markers = luminal_markers + basal_markers + squamos_markers + immune_markers + neural_diff + emt_claudin + ecm_muscle

In [6]:
set(sel_tfs.index) & set(tcga_markers)

{'TP63'}

### Urothelium type markers

In [7]:
tf_diff = ["P63", "FOXA1","PPARG", "RARG", "IRF1", "ELF3", "GRHL3", "KLF5", "GATA4", "GATA6", "GATA3"]
krt = ["KRT13", "KRT14", "KRT15", "KRT20"]
upk = ["UPK1B", "UPK1A", "UPK3A", "UPK2"]
cld = ["CLDN3", "CLDN4", "CLDN5" ]

egfr_fam = ["EGFR", "ERBB2", "ERBB3", "ERBB4", "EGF", "AREG", "HBEGF","TGFA","BTC", "EREG"]
fgfr_fam = ["FGFR1", "FGFR2", "FGFR3", "FGF1", "FGF2"]
map_kpathway = ["RAS", "RAF", "MEK1", "MEK2", "MEK3", "MEK4","ERK"]
pi3_kpathway = ["PIK3C3", "PIK3R2", "PIK3C2B", "AKT1", "AKT2"]
others = ["MKI67", "MCM2", "UPK3A", "ZO1", "TJP1", "ZO2", "TJP2", "ZO3", "TJP3"]
hox_ur = ["HOXB2", "HOXB3", "HOXB5", "HOXB6", "HOXB8"]
hox_bla = ["HOXA9", "HOXA10", "HOXA11", "HOXA13"]

diff_markers = tf_diff + cld + krt + upk

uro_markers = diff_markers + egfr_fam + fgfr_fam + map_kpathway + pi3_kpathway + others + hox_ur + hox_bla

In [8]:
set(sel_tfs.index) & set(uro_markers)

{'ELF3', 'GRHL3', 'HOXB6', 'KLF5'}

### Lund type markers

In [9]:
lund_qtc1 = ["FLI1", "FOXP3", "ILKZF1", "IRF4", "IRF8", "RUNX3", "SCML4", "SPI1", "STAT4", "TBX21", "TFEC"]
lund_qtc2 = ["AEBP1", "BNC2", "GLI2", "GLIS1", "HIC1", "MSC", "PPRX1", "PPRX2", "TGFB1I1", "TWIST1"]
lund_qtc3 = ["EBF1", "HEYL", "LEF1", "MEF2C", "TCF4", "ZEB1", "ZEB2"]
lund_qtc8 = ["GATA5", "HAND1", "HAND2", "KLF16"]
lund_qtc17 = ["ARID5A", "BATF3", "VENTX"]
lund_ba_mes = lund_qtc1 + lund_qtc2 + lund_qtc3 + lund_qtc8 + lund_qtc17

lund_ba_sq = ["BRIP1", "E2F7", "FOXM1", "ZNF367", "IRF1", "SP110", "STAT1"]
lund_mes = ["TP53", "RB1", "FGFR3", "ANKHD1", "VIM", "ZEB2"]
ba_sq_inf = ["CDH3", "EGFR"]

lund_sc_ne = ["CHGA", "SYP", "ENO2", "EPCAM"] #Highly expressed

lund_markers = lund_ba_mes + lund_ba_sq + lund_mes + ba_sq_inf + lund_sc_ne

In [10]:
set(sel_tfs.index) & set(lund_markers)

{'KLF16', 'SP110', 'STAT1'}

### Immune markers

In [11]:
b_cells = ["BCL2", "BCL6", "CD19", "CD1D", "CD22", "CD24", "CD27", "CD274","CD34", "CD38", "CD40","CD44","CD5","CD53","CD69","CD72", "CD79A", "CD79B", "CD80", "CD86", "CD93", "CR2", "CXCR4", 'CXCR5',"FAS","FCER2", "FCRL4" "HAVCR1","IL10", 'IL2RA','IL7R','IRF4','ITGAX', 'LILRB1','MME','MS4A1','NT5E','PDCD1LG2','PRDM1','PTPRC','SDC1','SPN','TFRC','TLR9','TNFRSF13B','TNFRSF13C','TNFRSF17','XBP1']
t_cells = ['CD4', 'CD8', 'CCR4', 'CCR5', 'CCR6', 'CCR7', 'CCR10', 'CD127', 'CD27', 'CD28', 'CD38', 'CD58', 'CD69', 'CTLA4', 'CXCR3', 'FAS', 'IL2RA',
        'IL2RB', 'ITGAE', 'ITGAL', 'KLRB1', 'NCAM1', 'PECAM1', 'PTGDR2', 'SELL', 'IFNG', 'IL10', 'IL13', 'IL17A', 'IL2', 'IL21','IL22', 'IL25', 'IL26', 'IL4', 'IL5', 'IL9', 'TGFB1', 'TNF', 'AHR', 'EOMES','FOXO4', 'FOXP1', 'FOXP3', 'GATA3','IRF4', 'LEF1', 'PRDM1', 'RORC','STAT4', 'TBX21','TCF7', 'GZMA']

nk_cells = ['B3GAT1','CCR7','CD16','CD2','CD226','CD244','CD27','CD300A','CD34','CD58','CD59','CD69','CSF2','CX3CR1','CXCR1','CXCR3','CXCR4','EOMES','GZMB','ICAM1','IFNG','IL1R1','IL22','IL2RB','IL7R','ITGA1','Itga2','ITGAL','ITGAM','ITGB2','KIR2DL1','KIR2DL2','KIT','Klrb1c','KLRC1','KLRC2','KLRD1','KLRF1','KLRG1','KLRK1','LILRB1','Klra4','Klra8','NCAM1','NCR1','NCR2','NCR3','PRF1','SELL','SIGLEC7','SLAMF6','SPN','TBX21','TNF']

macrophages_cells = [ 'ADGRE1','CCR2','CD14','CD68','CSF1R','Ly6c1','MARCO','MRC1','NOS2','PPARG','SIGLEC1','TLR2','ARG1','CD163','CD200R1','CD80','CD86','CLEC10A','CLEC7A','CSF2','CX3CR1','FCGR1A','ITGAM','MERTK','PDCD1LG2','Retnla','TNF','CCL22','CD36','CD40','IL10','IL1B','IL6','LGALS3','TLR4','CCL2','CCR5','CD209','CD63','CD86','CSF1','CXCL2','FCGR3A','IFNG','IL4','IRF4','ITGAX','MSR1','PDGFB','PTPRC','STAT6','TIMD4','Chil3','CLEC6A','IL1R1','ITGB2','PDCD1LG2','TLR7']

monocyte_cells = ['CD14','CD16','CSF1R','CX3CR1','ITGAM','ITGAX','LY6C1','CCR2','CXCR4','FCGR1A','SELL','SPN','ADGRE1','CCR7','TNF','CD86','IL10','IL1B','MERTK','TREML4','CD209','NR4A1','Ly6a','PTPRC','IL3RA','CD27','CCR5','CD32','CD1A','MRC1','ITGB3','CD9','CXCR6','CCR1','FLT3','KLF2','CLEC12A','CCR6','CCR8','CD68','CLEC7A','KIT','MAF','MAFB','SPI1','CD1C','PPARG','CEBPB','ITGAE','TEK']


immune_markers = b_cells + t_cells + nk_cells + macrophages_cells + monocyte_cells

In [12]:
set(sel_tfs.index) & set(immune_markers)

{'AHR', 'BCL6'}

# Morpheus

## Prepare for Morpheus
Outliers for standard log2 TPMS and norm of log, after applying agglomerative clustering with 1-pearson correlation.

```Python
outliers_log2 = ['TCGA-C4-A0EZ', 'TCGA-DK-AA6W', 'TCGA-G2-A2EL', 'TCGA-BL-A3JM', 'TCGA-XF-A9T2', 'TCGA-XF-AAMH', 'TCGA-XF-A9ST', 'TCGA-GC-A4ZW', 'TCGA-HQ-A2OF', 'TCGA-DK-AA6T', 'TCGA-BT-A2LA', 'TCGA-XF-AAN7', 'TCGA-FJ-A871', 'TCGA-CF-A3MF']
```

```Python
norm_outliers = ['TCGA-2F-A9KW', 'TCGA-XF-A9ST', 'TCGA-BL-A3JM', 'TCGA-XF-A9T2', 'TCGA-XF-AAMH', 'TCGA-DK-AA6T', 'TCGA-BT-A2LA', 'TCGA-XF-AAN7', 'TCGA-FJ-A871', 'TCGA-C4-A0EZ', 'TCGA-DK-AA6W', 'TCGA-G2-A2EL']
```

<!-- ![alt text](selective_edge_pruning/sel_tf_log2.png)
![alt text](selective_edge_pruning/sel_tf_norm_log2.png) -->



In [13]:
outliers_log2 = ['TCGA-C4-A0EZ', 'TCGA-DK-AA6W', 'TCGA-G2-A2EL', 'TCGA-BL-A3JM', 'TCGA-XF-A9T2', 'TCGA-XF-AAMH', 'TCGA-XF-A9ST', 'TCGA-GC-A4ZW', 'TCGA-HQ-A2OF', 'TCGA-DK-AA6T', 'TCGA-BT-A2LA', 'TCGA-XF-AAN7', 'TCGA-FJ-A871', 'TCGA-CF-A3MF']

norm_outliers = ['TCGA-2F-A9KW', 'TCGA-XF-A9ST', 'TCGA-BL-A3JM', 'TCGA-XF-A9T2', 'TCGA-XF-AAMH', 'TCGA-DK-AA6T', 'TCGA-BT-A2LA', 'TCGA-XF-AAN7', 'TCGA-FJ-A871', 'TCGA-C4-A0EZ', 'TCGA-DK-AA6W', 'TCGA-G2-A2EL']

cmn_outliars = set(outliers_log2) & set(norm_outliers)

print(f"### Num outliers for standard log2 {len(outliers_log2)}.\n --> {outliers_log2}")
print(f"### Num outliers for norm log2 {len(outliers_log2)}.\n --> {outliers_log2}")
print(f"### Common outliers *{len(cmn_outliars)}*.\n --> {cmn_outliars}")

### Num outliers for standard log2 14.
 --> ['TCGA-C4-A0EZ', 'TCGA-DK-AA6W', 'TCGA-G2-A2EL', 'TCGA-BL-A3JM', 'TCGA-XF-A9T2', 'TCGA-XF-AAMH', 'TCGA-XF-A9ST', 'TCGA-GC-A4ZW', 'TCGA-HQ-A2OF', 'TCGA-DK-AA6T', 'TCGA-BT-A2LA', 'TCGA-XF-AAN7', 'TCGA-FJ-A871', 'TCGA-CF-A3MF']
### Num outliers for norm log2 14.
 --> ['TCGA-C4-A0EZ', 'TCGA-DK-AA6W', 'TCGA-G2-A2EL', 'TCGA-BL-A3JM', 'TCGA-XF-A9T2', 'TCGA-XF-AAMH', 'TCGA-XF-A9ST', 'TCGA-GC-A4ZW', 'TCGA-HQ-A2OF', 'TCGA-DK-AA6T', 'TCGA-BT-A2LA', 'TCGA-XF-AAN7', 'TCGA-FJ-A871', 'TCGA-CF-A3MF']
### Common outliers *11*.
 --> {'TCGA-BT-A2LA', 'TCGA-G2-A2EL', 'TCGA-BL-A3JM', 'TCGA-FJ-A871', 'TCGA-DK-AA6T', 'TCGA-XF-A9T2', 'TCGA-XF-AAN7', 'TCGA-XF-A9ST', 'TCGA-C4-A0EZ', 'TCGA-XF-AAMH', 'TCGA-DK-AA6W'}


In [57]:
dmy_df = tum_tpms_v4.loc[sel_tfs.index]
dmy_df = np.log2(dmy_df + 1)

sel_metadata = ['KMeans_labels_6', 'consensus', 'TCGA408_classifier', 'Lund2017.subtype', 'ESTIMATE_score', "Immune_score", "Stromal_score", 'tumor_stage']
dmy_df = pd.concat([vu_output[sel_metadata].T, dmy_df], axis=0).dropna(axis=1)

# Remove outliers
dmy_df = dmy_df[list(set(dmy_df.columns) - set(cmn_outliars))]

# Adding notes on the genes
dmy_df['Notes'] = ''
dmy_df.loc[dmy_df.index.isin(lund_markers), "Notes"] = 'Lund marker'
dmy_df.loc[dmy_df.index.isin(immune_markers), "Notes"] = 'Immune marker'
dmy_df.loc[dmy_df.index.isin(uro_markers), "Notes"] = 'Uro markers'
dmy_df.loc[dmy_df.index.isin(tcga_markers), "Notes"] = 'TCGA markers'

dmy_df = dmy_df[['Notes'] + list(dmy_df.columns[:-1])]
dmy_df.to_csv(f"{figures_path}/log2_sel_tfs_no_outliers.tsv", sep='\t')


DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`



## Importing Morpheus

In [58]:
morpheus_path = f"{figures_path}/morpheus/"

morp_df = pd.read_csv(f"{morpheus_path}/15_CS_norm_log2_sel_tfs_no_outliers.gct", sep="\t", skiprows=2)
columns = morp_df["id"]
morp_df = morp_df.drop(columns=["Notes"]).transpose()
morp_df.columns = columns
morp_df = morp_df.iloc[2:, :]

morp_df['dendrogram_cut'] = morp_df['dendrogram_cut'].astype(float).astype(str)

keep_clusters = []
# Size of 1%
size_th = round(morp_df.shape[0] * 0.01)
for cluster, size in morp_df['dendrogram_cut'].value_counts().items():
    if size > 5:
        keep_clusters.append(cluster)

# Drop clusters that are smaller than 1% of the cohort size
morp_df = morp_df.loc[morp_df['dendrogram_cut'].isin(keep_clusters)]
morp_df.shape

(378, 107)

In [16]:
#Sankey
morp_df.rename_axis("sample", axis="columns", inplace=True)
reorder_cols = [
    "TCGA408_classifier",
    "dendrogram_cut",
    # "KMeans_labels_6",
    'Lund2017.subtype',
    # "consensus",
]
meta, sky_fig = sky.main(df=morp_df, reorder_cols=reorder_cols, title='MIBC stratification based on the TF from selective edge pruning', retMeta=True)
sky_fig.update_layout(height=700)

## Dumbell plots

In [17]:
import plotly.graph_objects as go
import matplotlib.pyplot as plt


# Main function to see the differences
def dumbell_plots(morp_df: pd.DataFrame, tum_df: pd.DataFrame, sel_tfs: pd.DataFrame, cls_1="", cls_2="", markers=[], log=False):

    cluster_1, cluster_2 = int(cls_1.split("_")[-1]), int(cls_2.split("_")[-1])

    samples_1 = morp_df[morp_df['dendrogram_cut'] == cluster_1].index
    samples_2 = morp_df[morp_df['dendrogram_cut'] == cluster_2].index

    # Reconstruct the TPM
    dmy_df = tum_df.loc[tum_df.index.isin(sel_tfs.index)]
    df_1 = dmy_df[samples_1]
    df_1[cls_1] = df_1.mean(axis=1)

    dmy_df = tum_df.loc[tum_df.index.isin(sel_tfs.index)]
    df_2 = dmy_df[samples_2]
    df_2[cls_2] = df_2.mean(axis=1)

    comb_df = pd.concat([df_1[cls_1], df_2[cls_2]], axis=1)

    # Prepare for plotting
    plot_data = {"line_x": [], "line_y": [], cls_1: [], cls_2: []}
    genes = list(comb_df.index)

    y_axis_title = 'TPM mean'
    if log:
        comb_df[cls_1] = np.log2(comb_df[cls_1] + 1)
        comb_df[cls_2] = np.log2(comb_df[cls_2] + 1)
        y_axis_title = 'Log2(TPM + 1) mean '

    # generate the data for plots
    for gene, row in comb_df.iterrows():

        val_1, val_2 = row[cls_1], row[cls_2]
        # The data for two scatter points cls_1 and cls_2
        plot_data[cls_1].extend([val_1])
        plot_data[cls_2].extend([val_2])
        # The line between the two clases
        plot_data["line_y"].extend(
[
                val_1,
                val_2,
                None,
            ]
        )
        plot_data["line_x"].extend([gene, gene, None])

    # plotting
    fig = go.Figure(
    data=[
        go.Scatter(
            x=plot_data["line_x"],
            y=plot_data["line_y"],
            mode="lines",
            showlegend=False,
            marker=dict(
                color="grey"
            )
        ),
        go.Scatter(
            y=plot_data[cls_1],
            x=genes,
            mode="markers",
            name=cls_1,
            marker=dict(
                # color="green",
                size=10
            )
            
        ),
        go.Scatter(
            y=plot_data[cls_2],
            x=genes,
            mode="markers",
            name=cls_2,
            marker=dict(
                # color="blue",
                size=10
            )   
        ),
        ]
    )

    # Add title and change the figure size
    fig.update_layout(
        title=f"Changes between {cls_1} and {cls_2}",
        yaxis_title=y_axis_title,
        xaxis_title="Gene",
        legend_itemclick=False
    )

    fig.update_xaxes()
    # Add markers

    annotations = []
    for i, marker in enumerate(markers):
        sel_df = comb_df.loc[marker]
        y = sel_df.max()
        x = marker


        # avoid overlapping
        xanchor = 'right' if i % 2 == 0 else 'left'
        ax = -10 if i % 2 == 0 else 10

        fig.add_annotation(
                x=x,
                y=y,
                text=marker,
                showarrow=True,
                ax = ax,  # No horizontal offset
                xanchor=xanchor
            )
        annotations.append({
                "x": x,
                "y": y,
                "text": marker,
                "showarrow": True,
                "ax": ax,  
                "xanchor": xanchor, 
            })
        
    return fig, comb_df, annotations

In [59]:
morp_df['dendrogram_cut'] = morp_df['dendrogram_cut'].astype(float).astype(int)

cluster_1, label_1 = 3, 'mes-like'
cluster_2, label_2 = 13, 'lumInf'
cls_1, cls_2 = f'{label_1}_{cluster_1}', f'{label_2}_{cluster_2}'
markers = [  'GRHL3', 'MYCL', "BNC1", 'ELF3', 'ZBTB7C', 'STAT1', 'HOXB6',  'ZNF750','MECOM',  'FOXQ1']

fig1, df, ann1 = dumbell_plots(morp_df=morp_df, tum_df=tum_tpms_v4, sel_tfs=sel_tfs, cls_1=cls_1, cls_2=cls_2, markers=markers, log=True)
title1 = f'{cls_1} vs {cls_2}'
fig1.show()

In [60]:
# Basal large vs Luminal
log = True
cluster_1, label_1 = 13, 'luminal'
cluster_2, label_2 = 4, 'basal'
cls_1, cls_2 = f'{label_1}_{cluster_1}', f'{label_2}_{cluster_2}'
markers = ['EGR1', 'BNC1', 'ELF3', 'FOSL1', 'MYCL', 'FOXQ1', "GRHL3", 'JRK']

fig1, df, ann1 = dumbell_plots(morp_df=morp_df, tum_df=tum_tpms_v4, sel_tfs=sel_tfs, cls_1=cls_1, cls_2=cls_2, markers=markers, log=True)
title1 = f'{cls_1} vs {cls_2}'

# Lum vs LumInf
cluster_1, label_1 = 13, 'luminal'
cluster_2, label_2 = 12, 'lumInf'
cls_1, cls_2 = f'{label_1}_{cluster_1}', f'{label_2}_{cluster_2}'
markers = ['EGR1', 'TP63', 'ELF3']

fig2, df, ann2 = dumbell_plots(morp_df=morp_df, tum_df=tum_tpms_v4, sel_tfs=sel_tfs, cls_1=cls_1, cls_2=cls_2, markers=markers, log=True)
title2= f'{cls_1} vs {cls_2}'

# Small vs LumInf
cluster_1, label_1 = 5, 'smallBasal'
cluster_2, label_2 = 12, 'lumInf'
cls_1, cls_2 = f'{label_1}_{cluster_1}', f'{label_2}_{cluster_2}'
markers = ['BNC1', 'ELF3', 'MYCL', "GRHL3",'HES2', 'JRK', 'TP63', 'MSX2', 'IRF6', 'HOXB6']

fig3, df, ann3 = dumbell_plots(morp_df=morp_df, tum_df=tum_tpms_v4, sel_tfs=sel_tfs, cls_1=cls_1, cls_2=cls_2, markers=markers, log=True)
title3 = f'{cls_1} vs {cls_2}'

# Mes-like vs Basal
cluster_1, label_1 = 3, 'mesLike'
cluster_2, label_2 = 4, 'basal'
cls_1, cls_2 = f'{label_1}_{cluster_1}', f'{label_2}_{cluster_2}'
markers = ["GRHL3",'BNC1', 'ELF3', 'MYCL','HES2', 'JRK', 'TP63', 'IRF6', 'STAT1', 'ZBTB7C', 'ZNF750', "EGR1", 'JUN', 'JUNB']

fig4, df, ann4 = dumbell_plots(morp_df=morp_df, tum_df=tum_tpms_v4, sel_tfs=sel_tfs, cls_1=cls_1, cls_2=cls_2, markers=markers, log=True)
title4 = f'{cls_1} vs {cls_2}'

#Mes-like vs small basal
cluster_1, label_1 = 3, 'mesLike'
cluster_2, label_2 = 5, 'smallBasal'
cls_1, cls_2 = f'{label_1}_{cluster_1}', f'{label_2}_{cluster_2}'
markers = ['REL',  'HES2', 'MYCL', 'GRHL3', 'KLF5', "ZBTB7C", 'TP63', 'IRF6', 'MECOM',  'ZNF750', 'FOXQ1', 'ELF3', 'MSX2','BNC1', 'STAT1',]

fig5, df, ann5 = dumbell_plots(morp_df=morp_df, tum_df=tum_tpms_v4, sel_tfs=sel_tfs, cls_1=cls_1, cls_2=cls_2, markers=markers, log=True)
title5= f'{cls_1} vs {cls_2}'


# Small basal vs Basla
cluster_1, label_1 = 4, 'basal'
cluster_2, label_2 = 5, 'smallBasal'
cls_1, cls_2 = f'{label_1}_{cluster_1}', f'{label_2}_{cluster_2}'
markers = [ 'MSX2', 'MAFG', 'ZXDB', 'TP63', "ZBTB10", 'ZBTB7C']

fig6, df, ann6 = dumbell_plots(morp_df=morp_df, tum_df=tum_tpms_v4, sel_tfs=sel_tfs, cls_1=cls_1, cls_2=cls_2, markers=markers, log=True)
title6 = f'{cls_1} vs {cls_2}'

In [63]:
num_cols=2
subplots_config = {
    "num_cols": num_cols,
    "shared_x": False,
    "shared_y": False,
    "h_spacing": 0.05,
    "v_spacing": 0.1,
    "main_title": "Gene differences",
    "height": 1200,
    "width": None,
    "y_title": None,
    "x_title": None,
    "specs": None,
}

figs, titles = [fig1, fig2, fig3, fig4, fig5, fig6], [title1, title2, title3, title4, title5, title6]
annotations = [ann1, ann2, ann3, ann4, ann5, ann6]

# Defining the trace colors
traces_names = ["mesLike_3", 'basal_4', "luminal_13", 'lumInf_12', "smallBasal_5"]
trace_colors = {}
for idx, name in enumerate(traces_names):
     trace_colors[name] = px.colors.qualitative.Plotly[idx]

# Making sure that we only display the traces once
displayed_legends = set()
for fig in figs:
    for trace in fig.data:
        if trace.name:
            trace.update(marker=dict(color=trace_colors[trace.name]))

            # Manage legend entries
            if trace.name not in displayed_legends:
                displayed_legends.add(trace.name)
                trace.showlegend = True  # Show legend for this trace
            else:
                trace.showlegend = False  # Hide legend for this trace

fig = gh.helper_multiplots(figs, titles, subplots_config)

# Adding the annotations to the subtplots
idx_row, idx_col = 1, 1
for i, ann in enumerate(annotations):
    for elem in ann:
        fig.add_annotation(elem, row=idx_row, col=idx_col)
    if idx_col % num_cols == 0:
            idx_col = 0
            idx_row += 1
    idx_col += 1


fig = fig.update_layout(showlegend=True)
# fig.show()
# save_fig(name="p0_elbowMethod_4K", fig=fig, base_path=figures_path, width=None, height=400)

### Single plot version

In [21]:
# Single scatter plot in plotly
def plot_cluster_means(morp_df: pd.DataFrame, tum_df: pd.DataFrame, sel_tfs: pd.DataFrame, cls_1="", cls_2="", markers=[], log=False):

    cluster_1, cluster_2 = int(cls_1.split("_")[-1]), int(cls_2.split("_")[-1])

    samples_1 = morp_df[morp_df['dendrogram_cut'] == cluster_1].index
    samples_2 = morp_df[morp_df['dendrogram_cut'] == cluster_2].index

    # Reconstruct the TPM
    dmy_df = tum_df.loc[tum_df.index.isin(sel_tfs.index)]
    df_1 = dmy_df[samples_1]
    df_1[cls_1] = df_1.mean(axis=1)

    dmy_df = tum_df.loc[tum_df.index.isin(sel_tfs.index)]
    df_2 = dmy_df[samples_2]
    df_2[cls_2] = df_2.mean(axis=1)

    comb_df = pd.concat([df_1[cls_1], df_2[cls_2]], axis=1)
    fig = px.scatter(comb_df.reset_index(), x=cls_1, y=cls_2, hover_data='genes', title=f'{cls_1} vs {cls_2}', log_x=log, log_y=log, trendline='ols', trendline_color_override='red', trendline_options=dict(log_x=log, log_y=log))

    for marker in markers:
        dmy = comb_df.loc[marker]
        x, y = dmy.values[0], dmy.values[1]
        if log:
            x, y = np.log10(x), np.log10(y)
            
        fig.add_annotation(
            x=x,
            y=y,
            text=dmy.name,
            showarrow=False,
            xanchor="right",
        )
    return fig, comb_df

# Matplotlib functions - useful for multiplots
def plot_cluster_means_2(ax, morp_df, tum_df, sel_tfs, cls_1="", cls_2="", markers=[], log=False):
    cluster_1 = int(cls_1.split("_")[-1])
    cluster_2 = int(cls_2.split("_")[-1])

    samples_1 = morp_df[morp_df['dendrogram_cut'] == cluster_1].index
    samples_2 = morp_df[morp_df['dendrogram_cut'] == cluster_2].index

    # Reconstruct the TPM
    dmy_df = tum_df.loc[tum_df.index.isin(sel_tfs.index)]
    df_1 = dmy_df.loc[:, samples_1].mean(axis=1).rename(cls_1)
    df_2 = dmy_df.loc[:, samples_2].mean(axis=1).rename(cls_2)

    comb_df = pd.concat([df_1, df_2], axis=1)

    # Plotting
    x_values = comb_df[cls_1]
    y_values = comb_df[cls_2]
    
    if log:
        x_values = np.log10(x_values + 1)  # +1 to handle log(0) cases
        y_values = np.log10(y_values + 1)
        xlabel = f'Log10({cls_1})'
        ylabel = f'Log10({cls_2})'
        title = f'Log10-scaled: {cls_1} vs {cls_2}'
    else:
        xlabel = cls_1
        ylabel = cls_2
        title = f'{cls_1} vs {cls_2}'

    ax.scatter(x_values, y_values)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)

    # Adding trend line
    z = np.polyfit(x_values, y_values, 1)  # Fit a first degree polynomial (linear fit)
    p = np.poly1d(z)  # Create the polynomial object to evaluate
    ax.plot(x_values, p(x_values), "r--")  # Plot the trend line

    # Annotations
    for marker in markers:
        x = x_values.get(marker, None)
        y = y_values.get(marker, None)
        if x is not None and y is not None:
            ax.annotate(marker, (x, y), textcoords="offset points", xytext=(0,10), ha='center')

    # Adding gridlines
    ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.7)

# save_fig(name="test", fig=fig, base_path=figures_path, width=1400, height=700)

In [22]:
if 0:
    cluster_1, label_1 = 9, 'basal'
    cluster_2, label_2 = 7, 'luminf'
    cls_1, cls_2 = f'{label_1}_{cluster_1}', f'{label_2}_{cluster_2}'

    markers = ['TP63', 'HES2', 'MSX2', "MYCL", 'ZSCAN16', 'MAFF', "IRF7", 'IRF6', 'KLF5', "ETS2"]

    dmy_df = sel_tfs.loc[~sel_tfs.index.isin(["BNC1"])]
    fig, df = plot_cluster_means(morp_df=morp_df, tum_df=tum_tpms_v4, sel_tfs=dmy_df, cls_1=cls_2, cls_2=cls_1, markers=markers, log=True)
    # fig.show()

### Matplotlib version

In [105]:
if 0:
    fig, axs = plt.subplots(1, 2, figsize=(14, 6))  # 1 row, 2 columns

    cluster_1, label_1 = 10, 'basal'
    cluster_2, label_2 = 13, 'basal'
    cls_1, cls_2 = f'{label_1}_{cluster_1}', f'{label_2}_{cluster_2}'
    markers = ["IRF6", "TP63", "GRHL3", "HES2", 'BNC1', "REL", "ZBTB7C", "STAT1", "ELF3", "JUNB", "ZNF750", "AHR", "MYCL", 'REPIN1']
    fig1, df = plot_cluster_means_2(axs[0], morp_df=morp_df, tum_df=tum_tpms_v4, sel_tfs=sel_tfs, cls_1=cls_1, cls_2=cls_2, markers=markers, log=True)

    # Repeat for the second figure
    cluster_1, cluster_2, label_1, label_2 = 10, 9, 'basal', 'basal'
    cls_1, cls_2 = f'{label_1}_{cluster_1}', f'{label_2}_{cluster_2}'
    markers = ["MSX2", "ZBTB7C", "ELF3", "KLF5", "TP63", "MECOM", 'NR4A2']
    fig2, df = plot_cluster_means_2(axs[1], morp_df=morp_df, tum_df=tum_tpms_v4, sel_tfs=sel_tfs, cls_1=cls_1, cls_2=cls_2, markers=markers, log=True)


# Survival analysis

In [112]:
from lifelines.statistics import multivariate_logrank_test

def prep_survival(df, cs_model="RawKMeans", label="SBM"):
    
    colors_net, color_map = px.colors.qualitative.G10, {}
    for idx, val in enumerate(df[cs_model].unique()):
        color_map[val] = colors_net[idx]

    df[cs_model] = df[cs_model].astype(str)
    fig = survival_plot(df.drop(columns=["days_to_last_follow_up", "days_to_death"]), vu_output, classifier=cs_model, color_map=color_map)
    fig = fig.update_layout(title="{}. Survival analysis for {}".format(label, cs_model))

    return fig, color_map

def survival_sig(df, model):
    df = df.reset_index().rename(columns={"index": "Sample"}).copy(deep=True)
    classifier = model

    dmy = df[["days_to_last_follow_up", "days_to_death", classifier]].replace("--", 0).astype(int)
    dmy["last_contact"] = dmy[["days_to_last_follow_up", "days_to_death"]].max(axis=1).div(30)

    labels = list(df[model].unique())
    dmy = dmy[dmy[classifier].isin(labels)]
    print(labels)

    results = multivariate_logrank_test(dmy["last_contact"], dmy[classifier], dmy["days_to_death"])
    display(results.print_summary())
    print("{0:.6f}".format(results.p_value))

# add the survival metadata
tcga_metadata = pd.read_csv(f"{data_base}/tumour/TCGA_metadata.tsv", sep="\t", index_col="Sample")
morp_df['days_to_last_follow_up'] = tcga_metadata['days_to_last_follow_up']
morp_df['days_to_death'] = tcga_metadata['days_to_death']

In [108]:
cluster_model = 'dendrogram_cut'
fig, dendo_color_map = prep_survival(morp_df, cs_model=cluster_model, label="CS_15")
# save_fig(name="Survival_plot_reward", fig=fig, base_path=figures_path, width=1400, height=600)

fig.update_layout(
    legend=dict(
        orientation="h",
        title="Network subtype",
        yanchor="bottom",
        y=0.9,
        xanchor="center",
        x=0.5,
        bgcolor="rgba(0,0,0,0)",
        font=dict(size=16, color="#003366"),
    ),
    title="",
    template="ggplot2",  # "ggplot2", "plotly_white"
    # paper_bgcolor="rgba(0,0,0,0)",
    # plot_bgcolor="rgba(0,0,0,0)",
    xaxis=dict(tickfont=dict(size=16)),
    yaxis=dict(tickfont=dict(size=16)),
    font=dict(size=16),
    height=700
)
fig.show()


In [113]:
morp_df[cluster_model] = morp_df[cluster_model].astype(float).astype(int)
survival_sig(morp_df, model=cluster_model)

[3, 4, 5, 12, 13]


0,1
t_0,-1
null_distribution,chi squared
degrees_of_freedom,4
test_name,multivariate_logrank_test

Unnamed: 0,test_statistic,p,-log2(p)
0,22.78,<0.005,12.8


None

0.000140


### Comparing survival

In [99]:
cluster_model = "dendrogram_cut"
comp_model = 'TCGA408_classifier'

colors_ref = px.colors.qualitative.Pastel2
color_map = {
    "LumP": colors_ref[0],
    "Lum Inf/Ns": colors_ref[1],
    "High IFNG": colors_ref[2],
    "Low IFNG": colors_ref[3],
    "Med IFNG": colors_ref[4],
    "Ne": colors_ref[5],
}

color_map_grey = {label: "grey" for label in morp_df[comp_model].unique()}
color_map = dict(color_map, **color_map_grey)

# choose the subtypes for each to compare if needed
select_labels_1, select_labels_2 = None, None

colors_net = px.colors.qualitative.G10
morp_df[cluster_model] = morp_df[cluster_model].astype(str)
for idx, val in enumerate(morp_df[cluster_model].unique()):
    color_map[val] = colors_net[idx]

fig = survival_comp(
    morp_df.drop(columns=["days_to_last_follow_up", "days_to_death"]),
    vu_output,
    classifier_1=cluster_model,
    classifier_2=comp_model,
    selected_labels_1=select_labels_1,
    selected_labels_2=select_labels_2,
    color_map=color_map,
)
fig = fig.update_layout(title="Survival analysis {}".format("VU + in-situ"))
fig.update_layout(height=900)
# save_fig(name="Survival_plot_reward", fig=fig, base_path=figures_path, width=1400, height=600)

## Apply clustering analysis

In [None]:
plot_data = tum_tpms_v4.loc[sel_tfs.index]
plot_data = np.log2(plot_data + 1)

gh.find_pcs((plot_data))

Sum of 90% variance at PC: 15
Change < 1% at PC: 5


In [None]:
selected_clusters = ["Birch", "RawKMeans", "GaussianMixture", "Ward", "SpectralClustering", "Avg"]

# run experiments
outputs, _, all_metrics, _ = cs.compare_exp(
    plot_data, rob_comp=None, n_clusters=None, selected_clusters=selected_clusters, show_figures=False, show_consensus=True, pca_data=False, n_comp=15,
)
outputs.set_index("Sample", inplace=True)

show_figs=False
if show_figs:
    # Plot the metrics
    fig = cs.display_metrics(all_metrics, f"Cluster metrics for Selected TF", show_individual=False, verbose=True)
    gh.plot_individual_metric(all_metrics, pca=False, offset_db=4)

Variation per principal component [0.63955238 0.09883797] and the sum 73.84%
