In [None]:
import pandas as pd
import numpy as np
import scanpy as sc
import pathlib as pl
from sklearn.preprocessing import StandardScaler

import seaborn as sns
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

from statannotations.Annotator import Annotator

In [None]:
def pretty_ax(ax):
    """Apply the notebook's minimal axis styling to ``ax`` in place.

    Hides the top and right spines, thickens the bottom and left spines,
    and configures ticks so that only bottom tick marks are drawn while
    both bottom and left tick labels stay visible.

    Parameters
    ----------
    ax : object exposing ``spines`` and ``tick_params`` (a matplotlib Axes).

    Returns
    -------
    None
    """
    # Remove the frame lines that do not carry an axis.
    for side in ("right", "top"):
        ax.spines[side].set_visible(False)

    # Tick marks only along the bottom edge; labels on bottom and left.
    tick_options = {
        "axis": 'both',
        "which": 'both',
        "bottom": True,
        "top": False,
        "left": False,
        "labelbottom": True,
        "labelleft": True,
    }
    ax.tick_params(**tick_options)

    # Emphasise the two remaining axis lines.
    for side in ("bottom", "left"):
        ax.spines[side].set_linewidth(1.5)

In [None]:
def get_tpm(gencode_mapping: pd.DataFrame, bulk: pd.DataFrame) -> pd.DataFrame:
    """Convert a samples x genes expression table into TPM-style values.

    Parameters
    ----------
    gencode_mapping : pd.DataFrame
        Must contain the columns ``gene_name``, ``start`` and ``end``.
        When a gene name occurs more than once, only its first row is used.
    bulk : pd.DataFrame
        Expression values with samples on the rows and gene names on the
        columns. Columns without an entry in ``gencode_mapping`` are dropped.

    Returns
    -------
    pd.DataFrame
        Same rows as ``bulk``, restricted to the shared genes; each value is
        divided by its gene length and every row is rescaled to sum to 1e6.

    Notes
    -----
    Gene length is taken as ``end - start`` (the genomic span of the gene).
    NOTE(review): if the coordinates are 1-based inclusive the span is one
    base short, and a span is not the exonic length - confirm this is intended.
    """
    by_gene = gencode_mapping.set_index("gene_name")
    gene_length = by_gene["end"] - by_gene["start"]
    # Keep the first occurrence of every duplicated gene name.
    first_occurrence = ~gene_length.index.duplicated()
    gene_length = gene_length[first_occurrence]

    # Restrict both the table and the length vector to the shared genes.
    shared_genes = bulk.columns.intersection(gene_length.index)
    counts = bulk.loc[:, shared_genes]
    lengths = gene_length.loc[counts.columns.intersection(gene_length.index)]

    # Length-normalise, then rescale each sample (row) to one million.
    rpk = counts / lengths
    per_million = rpk.sum(axis=1) / 1000000
    return rpk.div(per_million, axis=0)

In [None]:
from pydeseq2.dds import DeseqDataSet
from pydeseq2.ds import DeseqStats
from pydeseq2.utils import load_example_data

In [None]:
# Expression table for GSE207526 (tab-separated). The first data row is
# dropped and the table is transposed before use.
# NOTE(review): presumably the dropped row is a non-expression header line
# and the transpose yields samples x genes - confirm against the file.
gex_raw = pd.read_csv(
    "/add/path/here/GSE207526_110.EAC.and.10.Normal.for.GSEA.txt",
    sep="\t",
)
gex_df = gex_raw.iloc[1:, :].T

# Gene coordinate table; get_tpm reads its gene_name/start/end columns.
gencode_mapping = pd.read_csv(
    "/add/path/here/gencode_v41_positions.csv",
    index_col=0,
)

# Length-normalised, per-sample rescaled expression.
tpm = get_tpm(gencode_mapping, gex_df)

In [None]:
# Survival/clinical annotation table, re-indexed by the "FileName.GenomeScan"
# column so that it can later be joined to the GEO metadata by index.
survival_raw = pd.read_csv(
    "/add/path/here/data.SPSS.subselect.txt",
    sep="\t",
    index_col=0,
)
survival_clin = survival_raw.set_index("FileName.GenomeScan")

In [None]:
import GEOparse

# Load the GEO series GSE207526 through GEOparse; `destdir` is a placeholder
# path that must be edited before running. The per-sample records in
# `gse.gsms` are used below to build the clinical table.
gse = GEOparse.get_GEO(geo="GSE207526", destdir="/add/path/here/")

In [None]:
# Collect one metadata record per GEO sample and build the table in a single
# step, instead of creating and concatenating a one-row DataFrame per sample.
clinical_records = []
for gsm in gse.gsms.values():
    charac = gsm.metadata["characteristics_ch1"]
    clinical_records.append({
        # Sample ID is the second space-separated token of the GEO title.
        "ID": gsm.metadata["title"][0].split(" ")[1],
        "GSM_ID": gsm.metadata["geo_accession"][0],
        # NOTE(review): relies on the first/second characteristic entries
        # being "<key>: <disease>" and "<key>: <treatment>" - confirm.
        "Disease": charac[0].split(": ")[1],
        "Treatment": charac[1].split(": ")[1],
    })

clinical_df = pd.DataFrame(
    clinical_records, columns=["ID", "GSM_ID", "Disease", "Treatment"]
).set_index("ID")

# Short tissue labels: tumour (EAC) vs. the healthy squamous tissue (NAT).
clinical_df["Status"] = clinical_df["Disease"].replace({
    "esophageal adenocarcinoma": "EAC",
    "healthy squamous cell tissue from patient with esophageal adenocarcinoma": "NAT",
})

# Study label is the part of the sample ID before the first ".".
clinical_df["Study"] = clinical_df.index.str.split(".").str[0]

# Attach survival/clinical annotations by sample ID (outer join on the index).
clinical_df = pd.concat([clinical_df, survival_clin], axis=1)

# Event indicator for survival analysis: 1 = deceased, 0 = alive or censored.
clinical_df["OS.status"] = clinical_df["OS.status"].replace({"Alive or censored": 0, "Deceased": 1})

# Collapse detailed AJCC sub-stages to II / III / IV; unstageable -> missing.
clinical_df["Stage"] = clinical_df["AJCC_baseline_detailed"].replace({
    "stage IIIA": "III",
    "stage IV": "IV",
    "stage IIIB": "III",
    "stage IIB": "II",
    "stage IIIC": "III",
    "stage IIA": "II",
    "not possible to stage": np.nan,
})

# Keep tumour samples only; this also drops rows that exist only in
# survival_clin, because their Status is missing after the outer join.
clinical_df = clinical_df[clinical_df.Status == "EAC"]

In [None]:
# Restrict the TPM table to the samples kept in clinical_df (filtered to
# Status == "EAC" above), in clinical_df's row order. `.loc` raises KeyError
# if a clinical sample ID has no matching row in tpm.
tpm = tpm.loc[clinical_df.index]

In [None]:
signature_dir = pl.Path("/add/path/here/signatures_canceronly/")

# Gene-name prefixes removed from every signature.
EXCLUDED_GENE_PREFIXES = ("MT-", "RPS", "RPL")

# Map signature name (file stem) -> array of gene names.
# Files are visited in sorted order because iterdir() gives no ordering
# guarantee; this keeps the dictionary order reproducible between runs.
full_sigs = {}
for s in sorted(signature_dir.iterdir()):
    # Each file stores the gene names in a column called "0".
    genes = pd.read_csv(s, index_col=0).set_index("0").index
    genes = genes[~genes.str.startswith(EXCLUDED_GENE_PREFIXES)]
    # to_numpy() always yields a plain ndarray; the return type of
    # Index.ravel() is not stable across pandas versions.
    full_sigs[s.stem] = genes.to_numpy()

In [None]:
import gseapy as gp
# Score every signature in full_sigs per sample with single-sample GSEA.
# The TPM table is passed transposed, and a copy of the signature dict is
# handed over so the original full_sigs object is not the one given to gseapy.
ss = gp.ssgsea(data=tpm.T,
               gene_sets=full_sigs.copy(),
               outdir=None,
               sample_norm_method='rank', # 'custom' would instead use the raw values of `data` as-is
               no_plot=True)

# NOTE(review): downstream cells select columns such as "cNMF_1" from
# state_score, i.e. they expect samples on rows and signatures on columns.
# Whether `ss.res2d.T` has that layout depends on the installed gseapy
# version - confirm before upgrading gseapy.
state_score = ss.res2d.T

In [None]:
# Join signature scores with the clinical annotations by sample ID.
# The left-hand side is rebuilt from the ssGSEA result instead of from
# state_score itself, so re-running this cell does not append the clinical
# columns a second time (duplicate column names would break later lookups
# such as state_score["OS.days"]).
state_score = pd.concat([ss.res2d.T, clinical_df], axis=1)

In [None]:
import matplotlib
import matplotlib.cm as cm
import matplotlib.colors as mcolors

from sklearn.decomposition import PCA
# Two-component PCA of the z-scored expression table.
pca = PCA(n_components=2)

# Z-score every gene across samples; genes whose z-scores contain NaN
# (e.g. zero variance) are dropped. The values are derived from `tpm`
# despite the variable name.
std_fpkm = tpm.sub(tpm.mean()).div(tpm.std()).dropna(axis=1)

pc_names = [f"PC{i}" for i in range(1, 3)]
pc_scores = pca.fit_transform(std_fpkm)
X_pca = pd.DataFrame(pc_scores, index=std_fpkm.index, columns=pc_names)

# Append the five signature scores and relabel them with mathtext subscripts
# so they can be used directly as panel titles.
score_columns = ["cNMF_1", "cNMF_2", "cNMF_3", "cNMF_4", "cNMF_5"]
X_pca = pd.concat([X_pca, state_score[score_columns]], axis=1)
X_pca.columns = [
    "PC1", "PC2",
    "cNMF$_{1}$", "cNMF$_{2}$", "cNMF$_{3}$", "cNMF$_{4}$", "cNMF$_{5}$",
]

def plot_pcs_color(ax, state, data=None):
    """Scatter PC1 vs. PC2 on ``ax``, coloured by the ``state`` column.

    Colours use the diverging 'RdBu_r' map centred on the median of the
    state values, and a matching colorbar is attached to ``ax``.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    state : str
        Column of ``data`` used for colouring; also used as the panel title.
    data : pd.DataFrame, optional
        Table with "PC1", "PC2" and ``state`` columns. Defaults to the
        notebook-level ``X_pca`` so existing two-argument calls keep working.

    Raises
    ------
    ValueError
        From ``TwoSlopeNorm`` when the median is not strictly between the
        minimum and maximum of the state values.
    """
    if data is None:
        data = X_pca
    values = data[state]

    normalize = mcolors.TwoSlopeNorm(
        vcenter=values.median(), vmin=values.min(), vmax=values.max()
    )
    colormap = matplotlib.colormaps['RdBu_r']
    sns.scatterplot(
        y=data["PC2"],
        x=data["PC1"],
        c=values,
        s=10,
        norm=normalize,
        cmap=colormap,
        ax=ax
    )

    scalar_mappable = cm.ScalarMappable(norm=normalize, cmap=colormap)
    scalar_mappable.set_array(values)
    # Attach the colorbar to the figure that owns `ax` rather than to a
    # notebook-level `fig`, which may be undefined or belong to another cell.
    ax.figure.colorbar(scalar_mappable, ax=ax)
    ax.set_title(state)
    pretty_ax(ax)

# One PCA panel per signature, laid out in a single row.
fig, ax = plt.subplots(1,5, figsize=(15,2))
flatax = ax.flatten()

panel_states = ["cNMF$_{1}$", "cNMF$_{2}$", "cNMF$_{3}$", "cNMF$_{4}$", "cNMF$_{5}$"]
for panel_ax, panel_state in zip(flatax, panel_states):
    plot_pcs_color(panel_ax, panel_state)

fig.tight_layout()


# Survival data

In [None]:
from lifelines import KaplanMeierFitter, CoxPHFitter
from lifelines.statistics import logrank_test

In [None]:
# Signature scores tested for association with overall survival.
states = ["cNMF_1","cNMF_2","cNMF_3","cNMF_4","cNMF_5"]

# Samples with complete survival data and complete scores for ALL five
# signatures (previously only cNMF_1..3 were checked here although five
# models are fitted), so every model uses the same cohort.
clin = state_score[states+["OS.days","OS.status"]].dropna()

# One univariate Cox model per signature, on the z-scored signature score,
# so each hazard ratio is per one standard deviation of the score.
cox_results = []
for state in states:
    expr = state_score[state]
    expr = (expr - expr.mean())/expr.std()
    # Cast to float so the model always receives numeric columns, even if
    # upstream steps left an object dtype; fails loudly on stray strings.
    cox_clin = pd.concat([clin[["OS.status","OS.days"]],expr],axis=1).dropna().astype(float)

    cph = CoxPHFitter()
    cph.fit(cox_clin, duration_col="OS.days", event_col="OS.status")
    cox_results.append(cph.summary)

cox_results = pd.concat(cox_results)

# to_numpy() replaces the deprecated Series.ravel().
ps = cox_results["p"].to_numpy()
x = cox_results["exp(coef)"].to_numpy()
lower = cox_results["exp(coef) lower 95%"].to_numpy()
upper = cox_results["exp(coef) upper 95%"].to_numpy()
# Asymmetric error-bar lengths around the hazard ratio.
ci = [x - lower, upper - x]
names = cox_results.index
# Red: whole 95% CI above 1; blue: whole CI below 1; black: CI includes 1.
colorlist = ["red" if lo > 1 else ("blue" if hi < 1 else "black") for lo, hi in zip(lower, upper)]

# Forest plot: first signature at the top, hazard ratio on the x axis.
fig, ax = plt.subplots(1,1,figsize=(3,3))
ax.errorbar(x,np.arange(0,len(x))[::-1],
            xerr=ci, marker="s",
            linewidth=0,
            elinewidth=2,
            ecolor=colorlist,
            markerfacecolor="black",
            markeredgecolor="black")
ax.spines[["bottom","left"]].set_linewidth(2)
ax.spines[["top","right"]].set_visible(False)
# Reference line at hazard ratio 1 (no association).
ax.vlines(1,ymin=ax.get_ylim()[0],ymax=ax.get_ylim()[1],linestyle="--",color="grey")
ax.set_yticks(np.arange(0,len(x)))
# Labels are reversed to match the reversed y positions used above.
ax.set_yticklabels(["cNMF$_{1}$","cNMF$_{2}$","cNMF$_{3}$","cNMF$_{4}$","cNMF$_{5}$"][::-1])
for i,p in enumerate(ps[::-1]):
    ax.text(ax.get_xlim()[1], i, f"p={p:.2e}")

# Create the output folder if needed so savefig does not fail on a fresh checkout.
cox_plot_path = pl.Path("figures/survival_GSE207526_cox_plot.svg")
cox_plot_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(cox_plot_path, dpi=200, bbox_inches="tight")