In [None]:
import pandas as pd
import numpy as np
import scanpy as sc
import pathlib as pl
from sklearn.preprocessing import StandardScaler

import seaborn as sns
import matplotlib.pyplot as plt

import signaturescoring as ssc

from tqdm.notebook import tqdm

from statannotations.Annotator import Annotator

In [None]:
def pretty_ax(ax):
    """Apply the notebook's minimal axis styling to ``ax`` in place.

    The top and right spines are hidden, the bottom and left spines are drawn
    with a line width of 1.5, and ticks are configured so that only bottom
    tick marks are drawn while bottom and left tick labels remain visible.

    Parameters
    ----------
    ax
        Object exposing ``spines`` (indexable by side name) and
        ``tick_params`` -- in this notebook, a matplotlib Axes.
    """
    for hidden_side in ("right", "top"):
        ax.spines[hidden_side].set_visible(False)

    tick_settings = {
        "axis": "both",
        "which": "both",
        "bottom": True,
        "top": False,
        "left": False,
        "labelbottom": True,
        "labelleft": True,
    }
    ax.tick_params(**tick_settings)

    for emphasized_side in ("bottom", "left"):
        ax.spines[emphasized_side].set_linewidth(1.5)

In [None]:
def get_tpm(gencode_mapping: pd.DataFrame, bulk: pd.DataFrame) -> pd.DataFrame:
    """Length-normalise an expression table and rescale each row to one million.

    Parameters
    ----------
    gencode_mapping
        Table with ``gene_name``, ``start`` and ``end`` columns. Gene length
        is taken as ``end - start``; when a gene name occurs more than once,
        only its first occurrence is used.
    bulk
        Expression table with one row per sample and one column per gene
        name. Columns without an entry in ``gencode_mapping`` are dropped.

    Returns
    -------
    pd.DataFrame
        Same rows as ``bulk``, restricted to the shared genes, where every
        row sums to 1e6.
    """
    by_gene = gencode_mapping.set_index("gene_name")
    gene_length = by_gene["end"] - by_gene["start"]
    # Keep only the first record of every gene name.
    gene_length = gene_length[~gene_length.index.duplicated()]

    # Restrict both inputs to the genes they have in common.
    shared_genes = bulk.columns.intersection(gene_length.index)
    bulk = bulk.loc[:, shared_genes]
    gene_length = gene_length.loc[shared_genes]

    # Reads per base rather than per kilobase: the constant factor cancels in
    # the per-sample rescaling below.
    full_rpk = bulk.div(gene_length, axis="columns")

    # Per-sample scaling factor so that each row ends up summing to 1e6.
    pm_factor = full_rpk.sum(axis=1) / 1000000

    return full_rpk.div(pm_factor, axis="index")

In [None]:
# Transcription factors of interest: they are removed from the signature gene
# sets before scoring and later correlated against the resulting scores.
mTFs = [
    "KLF5",
    "ELF3",
    "SMAD3",
    "TCF7L2",
    "HMGA2",
    "BNC2",
]

In [None]:
# Load the GSE207526 expression table, discard its first data row and
# transpose it.
# NOTE(review): the first row is presumably a non-gene annotation line and the
# transposed frame is presumably samples x genes -- confirm against the file.
gex_df = (
    pd.read_csv(
        "/add/path/here/GSE207526/GSE207526_110.EAC.and.10.Normal.for.GSEA.txt",
        sep="\t",
    )
    .iloc[1:, :]
    .T
)

In [None]:
# Gene position table; `get_tpm` reads its "gene_name", "start" and "end"
# columns to derive gene lengths.
gencode_mapping = pd.read_csv(
    "/add/path/here/gencode_v41_positions.csv",
    index_col=0,
)

In [None]:
# Length-normalised, per-million-scaled expression used by all downstream cells.
tpm = get_tpm(gencode_mapping=gencode_mapping, bulk=gex_df)

In [None]:
import GEOparse

# Fetch the GEO series record (GEOparse stores its download in `destdir`).
gse = GEOparse.get_GEO(geo="GSE207526", destdir="/add/path/here")

# Collect one record per GEO sample. The expression matrix `gex_df` loaded
# earlier is deliberately left untouched here: rebinding it to an empty list
# would break any re-run of the TPM computation.
clinical_records = []
for gsm in gse.gsms.values():
    charac = gsm.metadata["characteristics_ch1"]
    clinical_records.append(
        {
            # Second whitespace-separated token of the sample title.
            "ID": gsm.metadata["title"][0].split(" ")[1],
            "GSM_ID": gsm.metadata["geo_accession"][0],
            # Characteristics are "key: value" strings; keep the value part.
            "Disease": charac[0].split(": ")[1],
            "Treatment": charac[1].split(": ")[1],
        }
    )

# Build the frame once instead of concatenating one-row frames per sample.
clinical_df = pd.DataFrame(
    clinical_records, columns=["ID", "GSM_ID", "Disease", "Treatment"]
).set_index("ID")

# Short tissue labels used as plotting categories.
clinical_df["Status"] = clinical_df["Disease"].replace(
    {
        "esophageal adenocarcinoma": "EAC",
        "healthy squamous cell tissue from patient with esophageal adenocarcinoma": "NAT",
    }
)

# Prefix of the sample ID before the first "."
clinical_df["Study"] = clinical_df.index.str.split(".").str[0]

# Survival / staging annotations, keyed by the same sample identifier.
survival_clin = pd.read_csv(
    "/add/path/here/data.SPSS.subselect.txt", sep="\t", index_col=0
).set_index("FileName.GenomeScan")

# Column-wise join on the sample ID; samples present in only one table get NaN.
clinical_df = pd.concat([clinical_df, survival_clin], axis=1)

# Event indicator for survival analysis: 1 = deceased, 0 = alive or censored.
clinical_df["OS.status"] = clinical_df["OS.status"].replace(
    {"Alive or censored": 0, "Deceased": 1}
)

# Collapse detailed AJCC sub-stages into II / III / IV; unstageable -> NaN.
# Labels not listed here are passed through unchanged by `replace`.
clinical_df["Stage"] = clinical_df["AJCC_baseline_detailed"].replace(
    {
        "stage IIIA": "III",
        "stage IV": "IV",
        "stage IIIB": "III",
        "stage IIB": "II",
        "stage IIIC": "III",
        "stage IIA": "II",
        "not possible to stage": np.nan,
    }
)

In [None]:
signature_dir = pl.Path("/add_path_here/")

# Map signature name (file stem) -> array of its top 100 genes.
# Files are visited in sorted order so the result is deterministic, and
# sub-directories / hidden files (e.g. .DS_Store, .ipynb_checkpoints) are
# skipped instead of being handed to `read_csv`.
full_sigs = {}
for sig_path in sorted(signature_dir.iterdir()):
    if not sig_path.is_file() or sig_path.name.startswith("."):
        continue
    ranked_genes = pd.read_csv(sig_path, index_col=0)
    # Drop genes whose name starts with MT-, RPS or RPL before taking the top 100.
    ranked_genes = ranked_genes[
        ~ranked_genes.index.str.startswith(("MT-", "RPS", "RPL"))
    ]
    # `.to_numpy()` replaces the deprecated `Index.ravel()`; both yield an ndarray.
    full_sigs[sig_path.stem] = ranked_genes.head(100).index.to_numpy()

In [None]:
# Standardise every gene column of the TPM table.
# Note: despite its name, `std_fpkm` holds standardised TPM values.
ss = StandardScaler()
std_fpkm = pd.DataFrame(
    ss.fit_transform(tpm),
    index=tpm.index,
    columns=tpm.columns,
)

# Score each signature as the per-sample mean of its standardised genes,
# restricted to genes present in the data and excluding the TFs in `mTFs`.
per_signature_scores = []
for sig, genes in full_sigs.items():
    measured_genes = std_fpkm.columns.intersection(genes)
    selgenes = np.setdiff1d(measured_genes, mTFs)
    per_signature_scores.append(std_fpkm[selgenes].mean(axis=1).rename(sig))

# One column per signature, one row per sample.
state_score = pd.concat(per_signature_scores, axis=1)

In [None]:
# Attach the clinical annotations to the signature scores (joined on sample ID).
# Any clinical columns already present are dropped first so that re-running
# this cell does not append a second copy of every clinical column (which
# would turn e.g. state_score["Status"] into a DataFrame).
state_score = pd.concat(
    [state_score.drop(columns=clinical_df.columns, errors="ignore"), clinical_df],
    axis=1,
)

In [None]:
from itertools import combinations

In [None]:
# Group sizes, used in the tick labels.
vc = state_score["Status"].value_counts().to_dict()

# Compare exactly the groups that are plotted. Deriving the pairs from
# `order` (rather than from the unique values in the data) keeps a NaN
# status, or any label not in `order`, from reaching the Annotator.
order = ["NAT", "EAC"]
pairs = list(combinations(order, 2))

for score in ["cNMF_3", "cNMF_4", "cNMF_1", "cNMF_5"]:

    fig, ax = plt.subplots(1, 1, figsize=(6, 5))
    # Draw explicitly on the axes created above instead of the implicit
    # "current axes".
    sns.boxplot(data=state_score, x="Status", y=score, order=order, ax=ax)
    annot = Annotator(
        ax,
        pairs=pairs,
        data=state_score, x="Status", y=score, order=order
    )
    # Pairwise Mann-Whitney test, shown as stars, without multiple-testing
    # correction.
    annot.configure(
        test="Mann-Whitney",
        loc="inside",
        text_format="star",
        show_test_name=False,
        verbose=2,
        comparisons_correction=None,
        fontsize=10,
    )
    annot.apply_test()
    _, test_results = annot.annotate()
    pretty_ax(ax)
    # Tick label = group name plus its sample count.
    ax.set_xticklabels(
        [el.replace(" ", "\n") + f"\nN={vc.get(el, 0)}" for el in order]
    )
    ax.set_xlabel("")

In [None]:
# Group sizes, used in the tick labels.
vc = state_score["Stage"].value_counts().to_dict()

# Compare exactly the stages that are plotted. Deriving the pairs from
# `order` (rather than from the unique values in the data) keeps stage labels
# that are not in `order` from reaching the Annotator.
order = ["II", "III", "IV"]
pairs = list(combinations(order, 2))

for score in ["cNMF_3", "cNMF_4", "cNMF_1", "cNMF_5"]:

    fig, ax = plt.subplots(1, 1, figsize=(6, 5))
    # Draw explicitly on the axes created above instead of the implicit
    # "current axes".
    sns.boxplot(data=state_score, x="Stage", y=score, order=order, ax=ax)
    annot = Annotator(
        ax,
        pairs=pairs,
        data=state_score, x="Stage", y=score, order=order
    )
    # Pairwise Mann-Whitney tests, shown as stars, without multiple-testing
    # correction.
    annot.configure(
        test="Mann-Whitney",
        loc="inside",
        text_format="star",
        show_test_name=False,
        verbose=2,
        comparisons_correction=None,
        fontsize=10,
    )
    annot.apply_test()
    _, test_results = annot.annotate()
    pretty_ax(ax)
    # Tick label = stage plus its sample count.
    ax.set_xticklabels(
        [el.replace(" ", "\n") + f"\nN={vc.get(el, 0)}" for el in order]
    )
    ax.set_xlabel("")

In [None]:
# cNMF_3 vs cNMF_4 score per sample, coloured by tissue status.
ax = sns.scatterplot(data=state_score, x="cNMF_3", y="cNMF_4", hue="Status")
ax.spines[["right", "top"]].set_visible(False)

# Dashed zero lines spanning the current axis limits.
x_lo, x_hi = ax.get_xlim()
ax.hlines(y=0, xmin=x_lo, xmax=x_hi, linestyles="dashed", color="grey")
# Read the y limits only after the horizontal line has been drawn.
y_lo, y_hi = ax.get_ylim()
ax.vlines(x=0, ymin=y_lo, ymax=y_hi, linestyles="dashed", color="grey")

In [None]:
# Split samples by the sign of their cNMF_3 / cNMF_4 scores. A score of
# exactly 0 counts as "high" in every group, so the three sets below plus the
# both-negative remainder partition all samples with non-missing scores.
cNMF_3_high = state_score["cNMF_3"] >= 0
cNMF_4_high = state_score["cNMF_4"] >= 0

# High cNMF_3 only.
cNMF_3_patients = state_score[cNMF_3_high & (state_score["cNMF_4"] < 0)].index

# High cNMF_4 only.
cNMF_4_patients = state_score[cNMF_4_high & (state_score["cNMF_3"] < 0)].index

# High in both programs (previously strict ">" on both, which left samples
# lying exactly on an axis unassigned).
cNMF_mixed_patients = state_score[cNMF_4_high & cNMF_3_high].index

In [None]:
# Label every sample with its quadrant class; samples in none of the three
# index sets keep the string "None".
state_score["PatClass"] = "None"
for class_label, class_members in (
    ("cNMF_3", cNMF_3_patients),
    ("cNMF_4", cNMF_4_patients),
    ("Mixed", cNMF_mixed_patients),
):
    state_score.loc[class_members, "PatClass"] = class_label

In [None]:
# cNMF_3 vs cNMF_4 score per sample, coloured by study prefix.
ax = sns.scatterplot(data=state_score, x="cNMF_3", y="cNMF_4", hue="Study")
ax.spines[["right", "top"]].set_visible(False)

# Dashed zero lines spanning the current axis limits.
x_lo, x_hi = ax.get_xlim()
ax.hlines(y=0, xmin=x_lo, xmax=x_hi, linestyles="dashed", color="grey")
# Read the y limits only after the horizontal line has been drawn.
y_lo, y_hi = ax.get_ylim()
ax.vlines(x=0, ymin=y_lo, ymax=y_hi, linestyles="dashed", color="grey")

In [None]:
import matplotlib
import matplotlib.cm as cm
import matplotlib.colors as mcolors

from sklearn.decomposition import PCA
# PCA of the standardised expression matrix. A fixed `random_state` makes the
# decomposition reproducible across kernel restarts whenever scikit-learn
# selects its randomized SVD solver.
pca = PCA(n_components=50, random_state=0)
X_pca = pca.fit_transform(std_fpkm)
# Name the columns PC1..PCk from the fitted model rather than repeating the
# component count by hand.
X_pca = pd.DataFrame(
    X_pca,
    index=std_fpkm.index,
    columns=[f"PC{i}" for i in range(1, pca.n_components_ + 1)],
)

# Add signature scores and clinical annotations next to the PCs (joined on
# sample ID).
X_pca = pd.concat([X_pca, state_score], axis=1)

def plot_pcs_color(ax, state):
    """Scatter PC1 vs PC2 on ``ax``, colouring points by the ``state`` column.

    Reads the notebook-level ``X_pca`` table. Colours use the diverging
    ``RdBu_r`` map centred on 0 and stretched to the observed min/max of the
    column; a matching colorbar is attached to the figure that owns ``ax``.

    Parameters
    ----------
    ax
        Matplotlib axes to draw on.
    state
        Name of the ``X_pca`` column used for colouring; also used as title.

    Notes
    -----
    ``TwoSlopeNorm`` requires ``vmin < vcenter < vmax``, so the column is
    expected to contain values on both sides of 0.
    """
    values = X_pca[state]
    normalize = mcolors.TwoSlopeNorm(vcenter=0, vmin=values.min(), vmax=values.max())
    colormap = matplotlib.colormaps['RdBu_r']
    sns.scatterplot(
        y=X_pca["PC2"],
        x=X_pca["PC1"],
        c=values,
        s=10,
        norm=normalize,
        cmap=colormap,
        ax=ax
    )
    # Standalone mappable so the colorbar reflects the same norm / colormap.
    mappable = cm.ScalarMappable(norm=normalize, cmap=colormap)
    mappable.set_array(values)
    # Use the figure owning `ax` instead of whatever global `fig` happens to
    # be bound when the function is called.
    ax.figure.colorbar(mappable, ax=ax)
    ax.set_title(state)
    pretty_ax(ax)

# Panel 1: PC1/PC2 coloured by the cNMF_3, cNMF_1 and cNMF_4 scores.
fig, ax = plt.subplots(1,3, figsize=(11,3))
flatax = ax.flatten()

plot_pcs_color(flatax[0], "cNMF_3")
plot_pcs_color(flatax[1], "cNMF_1")
plot_pcs_color(flatax[2], "cNMF_4")
fig.tight_layout()
# Create the output folder if needed so the save does not fail on a fresh
# checkout.
pc_fig_path = pl.Path("figures/external/GSE207526_PC_wCNMF_score.png")
pc_fig_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(pc_fig_path, dpi=200, bbox_inches="tight")

# Panel 2: the remaining two programs (displayed only, not saved).
fig, ax = plt.subplots(1,2, figsize=(6,2))
flatax = ax.flatten()

plot_pcs_color(flatax[0], "cNMF_2")
plot_pcs_color(flatax[1], "cNMF_5")

fig.tight_layout()

In [None]:
# Scree plot: explained-variance ratio per principal component. The x axis is
# 1-based to match the PC1..PCk column names and its length follows the fitted
# model instead of a hard-coded 50.
n_fitted_pcs = len(pca.explained_variance_ratio_)
ax = sns.scatterplot(
    x=np.arange(1, n_fitted_pcs + 1),
    y=pca.explained_variance_ratio_,
)
ax.set_xlabel("Principal component")
ax.set_ylabel("Explained variance ratio")

In [None]:
# Pairwise correlation between the two program scores and the first two PCs.
X_pca.loc[:, ["cNMF_3", "cNMF_4", "PC1", "PC2"]].corr()

In [None]:
# PC1 vs PC2, coloured by tissue status.
fig, ax = plt.subplots(1, 1)
sns.scatterplot(data=X_pca, x="PC1", y="PC2", hue="Status", ax=ax)

In [None]:
# PC1 vs PC2, coloured by study prefix.
fig, ax = plt.subplots(1, 1)
sns.scatterplot(data=X_pca, x="PC1", y="PC2", hue="Study", ax=ax)

# Survival data

In [None]:
from lifelines import KaplanMeierFitter, CoxPHFitter
from lifelines.statistics import logrank_test


In [None]:
# Survival cohort: samples with all five scores, OS time, OS status and class.
clin = state_score[[f"cNMF_{i}" for i in range(1,6)]+["OS.days","OS.status","PatClass"]].dropna()

for i in range(1,6):
    metasig = f"cNMF_{i}"
    # Take the scores from the survival cohort itself, so that the quantile
    # cut-offs are not shifted by samples without survival data and the
    # boolean masks share `clin`'s index.
    scores = clin[metasig]

    # Bottom 30% -> low-score group (key 0), top 30% -> high-score group (key 1).
    strata = {
        0: scores <= scores.quantile(0.3),
        1: scores >= scores.quantile(0.7),
    }
    duration, event = {},{}
    for high, stratification in strata.items():
        df = clin[stratification]
        # `.to_numpy()` replaces the deprecated `Series.ravel()`.
        duration[high] = df["OS.days"].to_numpy()
        event[high] = df["OS.status"].to_numpy()

    fig, ax = plt.subplots(1,1)
    kmf = KaplanMeierFitter()
    # The same fitter is re-fitted per group; each curve is drawn right after
    # its fit.
    kmf.fit(duration[0], event[0], label='Low score')
    kmf.plot(show_censors=True,c="r",ax=ax,ci_alpha=0.1)
    kmf.fit(duration[1], event[1], label='High score')
    kmf.plot(show_censors=True,c="b",ax=ax)
    pretty_ax(ax)
    ax.set_ylabel("OS")
    ax.set_xlabel("Time to event")
    # Log-rank test between the low- and high-score groups.
    results=logrank_test(duration[0],duration[1],event_observed_A=event[0], event_observed_B=event[1])
    results.print_summary()
    ax.text(0.75*ax.get_xlim()[1],0.8,f"p={results.p_value:.1e}",fontsize=13)

# Correlation with TFs

In [None]:
# Side-by-side table of three program scores and the TPM of the TFs in `mTFs`
# (joined on sample ID).
program_scores = state_score[["cNMF_1", "cNMF_3", "cNMF_4"]]
tf_expression = tpm[mTFs]
corr_df = pd.concat([program_scores, tf_expression], axis=1)

In [None]:
# Full correlation matrix, then keep program rows x TF columns for the heatmap.
full_corr = corr_df.corr()
heatmap_df = full_corr.loc[["cNMF_3", "cNMF_1", "cNMF_4"], mTFs]

In [None]:
# Annotated heatmap of program-vs-TF correlations, diverging colours centred
# on 0.
fig, ax = plt.subplots(1,1,figsize=(4,1.5))
sns.heatmap(data=heatmap_df, annot=heatmap_df, cmap="vlag", center=0, ax=ax)
# Create the output folder if needed so the save does not fail on a fresh
# checkout.
# NOTE(review): this figure is written under "figures/celllines/" although the
# data come from GSE207526 -- confirm the intended folder.
heatmap_path = pl.Path("figures/celllines/heatmap_cNMF_TF_corr.png")
heatmap_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(heatmap_path, dpi=300, bbox_inches="tight")