In [None]:
import pandas as pd
from pathlib import Path

import kaplanmeier as km
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
# Presumably the dots-per-inch for exported figures; it is not referenced in the
# cells visible here — TODO confirm it is still needed.
DPI=300

# IPython magics: render inline matplotlib output in 'retina' format and show
# figures directly below the cell that produced them.
%config InlineBackend.figure_format = 'retina'
%matplotlib inline

In [None]:
# Directory holding the microenvironment gene-expression results that this
# notebook reads from.
DIR2LOAD = Path("/data/BCI-CRC/nasrine/data/CRC/spatial/CRC_LM_VISIUM/CRC_LM_VISIUM_04_08_09_11/cell2loc_spatialde2/concat_withWu2022/microenvs_geneexp/")

# Outputs are nested below the input directory: tables in `survival_analysis`,
# figures one level deeper in `Figures`.
DIR2SAVE = DIR2LOAD / "survival_analysis/"
FIG2SAVE = DIR2SAVE / "Figures/"

# Create both output folders up front; existing folders are left untouched.
DIR2SAVE.mkdir(parents=True, exist_ok=True)
FIG2SAVE.mkdir(parents=True, exist_ok=True)

In [None]:
### Get DEGs
# Read the differential-expression workbook (first sheet) from the input directory.
file = DIR2LOAD.joinpath('concat_withWu2022_DE_microenvs.xlsx')
df = pd.read_excel(file)
# One entry per column: {column name -> list of that column's values, in row order}.
# NOTE(review): presumably each column is one microenvironment's ranked gene list;
# if the columns have unequal lengths the shorter lists will contain NaN padding — confirm.
signatures_dict = df.to_dict(orient='list')

In [None]:
# Read bulk expression data and survival data

def kmplot_bulk(expression_path, survival_path, survival_col_name, survival_dict):
    """Load bulk expression and survival tables and align them on patient ID.

    Despite its name this function does not plot anything; it prepares the two
    inputs expected by ``kmplot``.

    Parameters
    ----------
    expression_path : str or path-like
        Parquet file read with ``pd.read_parquet``. Every column except the
        last is treated as a patient ID; the last column must be named
        ``"Gene"``.
    survival_path : str or path-like
        Tab-separated file containing a ``"Patient ID"`` column and the
        ``survival_col_name`` column.
    survival_col_name : str
        Column of the survival table whose values are recoded.
    survival_dict : dict
        Mapping applied to ``survival_col_name`` via ``Series.replace``
        (e.g. status labels -> 0/1).

    Returns
    -------
    tuple of (pandas.DataFrame, pandas.DataFrame)
        ``expression`` restricted to the patients present in the NA-free
        survival table (plus the ``"Gene"`` column, kept last), and the
        survival table with rows containing any missing value removed.
    """
    expression = pd.read_parquet(expression_path)
    survival = pd.read_csv(survival_path, sep="\t")
    survival[survival_col_name] = survival[survival_col_name].replace(survival_dict)

    # Remove NA rows from survival data (a row is dropped if ANY column is missing)
    print(f"\nBefore removing NA from survival data: {survival.shape}")
    survival = survival.dropna(axis=0)
    print(f"After removing NA from survival data: {survival.shape}")

    # Build the lookup once as a set: the previous version rebuilt a list of all
    # patient IDs for every expression column, which was quadratic.
    patient_columns = list(expression.columns[:-1])
    patients_with_survival = set(survival["Patient ID"])
    patients_to_keep = [p for p in patient_columns if p in patients_with_survival]
    print(f"Patients in GEX but not survial data: {len(patient_columns) - len(patients_to_keep)}")

    # The last column ("Gene") contains gene symbols and is kept as the final column
    expression = expression[patients_to_keep + ["Gene"]]

    print(f"Shape survival data: {survival.shape}")
    print(f"Shape expression data: {expression.shape}\n")

    return expression, survival


In [None]:
def kmplot(expression, signatures_specific, survival, savepath, saveprefix, n_genes_survival=50):
    """Run a median-split Kaplan-Meier analysis for every gene signature.

    For each signature, the per-patient score is the mean expression of the
    signature's first ``n_genes_survival`` genes that are present in
    ``expression``. Patients are sorted by score and split into a "Low" half
    and a "High" half (with an odd number of patients the middle one is
    "Low"). The two groups are compared with ``km.fit`` and the curves are
    saved as ``{saveprefix}_{key}.pdf``. All per-patient scores are written to
    ``Expression_signatures_{saveprefix}.xlsx``, one sheet per signature.

    Parameters
    ----------
    expression : pandas.DataFrame
        One column per patient plus a ``"Gene"`` column of gene symbols.
    signatures_specific : dict
        Maps signature name -> ordered list of gene symbols.
    survival : pandas.DataFrame
        Must contain a ``"Patient ID"`` column. The event/censoring indicator
        and the time-to-event are read POSITIONALLY from the third and fourth
        columns respectively.
        NOTE(review): this assumes a layout such as
        [study, "Patient ID", status, months] — confirm for new input files.
    savepath : str or path-like
        Directory receiving the PDF figures and the Excel workbook.
    saveprefix : str
        Prefix used in every output file name.
    n_genes_survival : int, default 50
        Number of leading genes of each signature used for the score.

    Returns
    -------
    None
        Results are printed and written to disk.
    """
    pValues = dict()
    expression_signatures_dict = dict()
    for key in signatures_specific:
        # Subset bulk expression data to the leading genes of this signature
        top_genes = signatures_specific[key][0:n_genes_survival]
        expression_filtered = expression[expression["Gene"].isin(top_genes)]
        if expression_filtered.empty:
            # Without any matching gene every score would be NaN and the
            # Low/High split would be arbitrary, so skip instead of plotting noise.
            print(f"Skipping {key}: none of its genes are present in the expression data")
            continue

        # Drop the gene-symbol column by name rather than by position, then
        # average over genes to get one score per patient (ascending order).
        expression_signatures = (
            expression_filtered.drop(columns="Gene").mean(axis=0).sort_values().to_frame()
        )
        expression_signatures.columns = ["Signature_score"]

        # Median split on the sorted scores. The column keeps its historical
        # name "Tertiles" because it is written to the output workbook.
        # Building the labels explicitly also handles a single patient
        # correctly (a negative-zero slice would have labelled everyone "High").
        n_patients = len(expression_signatures)
        n_high = n_patients // 2
        expression_signatures["Tertiles"] = ["Low"] * (n_patients - n_high) + ["High"] * n_high

        # Merge survival data with the per-patient signature scores
        survival_km = survival.merge(
            expression_signatures, how="inner", left_on="Patient ID", right_index=True
        )

        # Compute survival: time-to-event (4th column), event indicator (3rd column), group label
        results = km.fit(survival_km.iloc[:, 3],
                         survival_km.iloc[:, 2],
                         survival_km['Tertiles'])
        pValues[key] = results['logrank_P']
        expression_signatures_dict[key] = expression_signatures

        # Report and save the figure for every signature (significant or not)
        print(key, pValues[key])
        with plt.style.context('default'):
            km.plot(results, full_ylim=True,
                    savepath=Path(savepath, f"{saveprefix}_{key}.pdf"),
                    cii_lines=None)

    # Written as .xlsx: pandas >= 2.0 no longer ships a writer for the legacy
    # .xls format, so the previous ".xls" target raised on current installations.
    # Skipped when nothing was scored, because a workbook needs at least one sheet.
    if expression_signatures_dict:
        with pd.ExcelWriter(Path(savepath, f"Expression_signatures_{saveprefix}.xlsx")) as writer:
            for key, scores in expression_signatures_dict.items():
                scores.to_excel(writer, sheet_name=f"{key}")

    print("\nSignificant results:")
    for key in pValues:
        if pValues[key] < 0.05:
            print(f"\t{key}")

In [None]:
# Progression-free survival (PFS) run.
#   expression_path   : parquet file of expression values, one column per unique patient,
#                       with the gene symbols in the final column
#   survival_path     : tab-separated survival table
#   survival_col_name : column holding the censoring/event label
#   survival_dict     : recodes that label to integers (0 = censored, 1 = progression)
expression, survival = kmplot_bulk(
    expression_path="/data/BCI-CRC/SO/data/public/TCGA/COADREAD_TPonly_uniquePatients_zscore.parquet",
    survival_path="/data/BCI-CRC/SO/data/public/TCGA/cBioPortal/KM_Plot__Progression_Free_(months).txt",
    survival_col_name="PFS_STATUS",
    survival_dict={"0:CENSORED": 0, "1:PROGRESSION": 1},
)

# Kaplan-Meier plots (kaplanmeier package) using the first 30 genes of each signature
kmplot(
    expression,
    signatures_dict,
    survival,
    savepath=FIG2SAVE,
    saveprefix="niches_PFS_30",
    n_genes_survival=30,
)



In [None]:
# Overall survival (OS) run.
#   expression_path   : parquet file of expression values, one column per unique patient,
#                       with the gene symbols in the final column
#   survival_path     : tab-separated survival table
#   survival_col_name : column holding the censoring/event label
#   survival_dict     : recodes that label to integers (0 = living, 1 = deceased)
# Note: this rebinds `expression` and `survival` from the previous cell.
expression, survival = kmplot_bulk(
    expression_path="/data/BCI-CRC/SO/data/public/TCGA/COADREAD_TPonly_uniquePatients_zscore.parquet",
    survival_path="/data/BCI-CRC/SO/data/public/TCGA/cBioPortal/KM_Plot__Overall_(months).txt",
    survival_col_name="OS_STATUS",
    survival_dict={"0:LIVING": 0, "1:DECEASED": 1},
)

# Kaplan-Meier plots (kaplanmeier package) using the first 100 genes of each signature
# NOTE(review): the PFS run uses 30 genes per signature — confirm the difference is intended.
kmplot(
    expression,
    signatures_dict,
    survival,
    savepath=FIG2SAVE,
    saveprefix="niches_OS_100",
    n_genes_survival=100,
)

