In [None]:
import pandas as pd 
import numpy as np
import scanpy as sc
import pathlib as pl

In [None]:
import signaturescoring as ssc

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
from statannotations.Annotator import Annotator

In [None]:
def pretty_ax(ax):
    """Apply the notebook's shared axis styling to ``ax`` in place.

    Hides the right and top spines, draws tick marks on the bottom edge only
    (tick labels stay on for both the bottom and left edges), and sets the
    bottom and left spine line width to 1.5.
    """
    for hidden_side in ("right", "top"):
        ax.spines[hidden_side].set_visible(False)

    tick_settings = {
        "axis": "both",
        "which": "both",
        "bottom": True,
        "top": False,
        "left": False,
        "labelbottom": True,
        "labelleft": True,
    }
    ax.tick_params(**tick_settings)

    for emphasized_side in ("bottom", "left"):
        ax.spines[emphasized_side].set_linewidth(1.5)

In [None]:
# Read the cell-cycle gene list, one stripped entry per line. The context
# manager closes the file handle as soon as the list is built (the bare open()
# used previously left the handle open until garbage collection).
with open('/add/path/here/regev_lab_cell_cycle_genes.txt') as gene_file:
    cell_cycle_genes = [x.strip() for x in gene_file]
# The first 43 entries are used as the S-phase list, the remainder as the G2/M list.
s_genes = cell_cycle_genes[:43]
g2m_genes = cell_cycle_genes[43:]

In [None]:
# Load the single-cell AnnData object from disk (the path is a placeholder to fill in).
adata = sc.read_h5ad("/add/path/here/Carroll_singlecell/Carroll_EAC_raw.h5ad")

In [None]:
# Keep a copy of X as loaded in the "counts" layer, before X is normalized in the next cell.
adata.layers["counts"] = adata.X.copy()

In [None]:
# Scale every cell to a total of 10,000, then log1p-transform X (both modify adata in place).
sc.pp.normalize_total(adata, target_sum=10000)
sc.pp.log1p(adata)

In [None]:
# Remove the .raw attribute from adata.
# NOTE(review): presumably so later scoring/plotting cannot fall back to .raw — confirm.
del adata.raw

In [None]:
# PCA on all cells with scanpy's default settings.
sc.tl.pca(adata)

In [None]:
# Neighbor graph (default representation) followed by a UMAP embedding of all cells.
sc.pp.neighbors(adata)
sc.tl.umap(adata)

In [None]:
# UMAP of all cells, colored by the "celltype" column of adata.obs.
sc.pl.umap(adata, color=["celltype"])

In [None]:
# "condition" is the second "_"-separated field of the "sample" column.
adata.obs["condition"] = adata.obs["sample"].str.split("_").str[1]

In [None]:
# Keep only cells whose celltype is "EAC"; .copy() yields an independent AnnData instead of a view.
subadata = adata[adata.obs.celltype.isin(["EAC"])].copy()

In [None]:
# Drop genes detected in fewer than 20 cells of the EAC subset.
sc.pp.filter_genes(subadata, min_cells=20)

In [None]:
# Cell-cycle scoring with the two gene lists built above; the resulting
# "S_score" and "G2M_score" columns are plotted further down.
sc.tl.score_genes_cell_cycle(subadata, s_genes=s_genes, g2m_genes=g2m_genes)

In [None]:
# Recompute PCA on the EAC subset only (input to the Harmony integration below).
sc.tl.pca(subadata)

In [None]:
# Harmony integration across "patient" (at most 20 Harmony iterations); the
# corrected embedding is read from obsm["X_pca_harmony"] in the next cell.
sc.external.pp.harmony_integrate(subadata, key="patient", max_iter_harmony=20)

In [None]:
# Neighbor graph on the Harmony-corrected PCs rather than the raw PCA.
sc.pp.neighbors(subadata, use_rep="X_pca_harmony")

# UMAP embedding of the EAC subset from that graph.
sc.tl.umap(subadata)

In [None]:
# UMAP of the EAC subset colored by patient (visual check of the integration).
sc.pl.umap(subadata, color=["patient"])

In [None]:
signature_dir = pl.Path("/add/path/here")

# Build {file stem: gene symbols of the first 75 rows left after filtering}.
# Each file is read as a table indexed by gene symbol.
# NOTE(review): rows are assumed to be ranked already, best gene first — confirm.
full_sigs = {}
# sorted() makes the signature order reproducible (iterdir() order depends on
# the filesystem). Directories and hidden entries such as .DS_Store or
# .ipynb_checkpoints are skipped because they are not signature tables.
for sig_path in sorted(signature_dir.iterdir()):
    if not sig_path.is_file() or sig_path.name.startswith("."):
        continue
    sig_table = pd.read_csv(sig_path, index_col=0)
    # Exclude genes whose symbol starts with MT-, RPS or RPL before taking the top 75.
    is_excluded = sig_table.index.str.startswith(("MT-", "RPS", "RPL"))
    # to_numpy() replaces Index.ravel(), which is deprecated in recent pandas.
    full_sigs[sig_path.stem] = sig_table.index[~is_excluded][:75].to_numpy()

In [None]:
# Transcription factors of interest: removed from every signature before
# scoring, and correlated with the signature scores later in the notebook.
mTFs =['KLF5', 'ELF3', 'SMAD3', 'TCF7L2', 'HNF4G', "BNC2"]

In [None]:
# Score every signature on the EAC subset; genes listed in mTFs are removed
# from each signature first. Results are stored under "<signature>_score".
for sig_name, sig_genes in full_sigs.items():
    genes_to_score = list(np.setdiff1d(sig_genes, mTFs))
    ssc.score_signature(
        adata=subadata,
        gene_list=genes_to_score,
        method="adjusted_neighborhood_scoring",
        ctrl_size=150,
        score_name=f"{sig_name}_score",
    )

In [None]:
# Score a four-gene TF set (the mTFs list without HNF4G and BNC2) with the
# same method and control size as the signatures; stored as "mTF_score".
ssc.score_signature(adata=subadata,
                        gene_list=["KLF5","ELF3","SMAD3","TCF7L2"], 
                        method="adjusted_neighborhood_scoring", 
                        ctrl_size=150,
                        score_name="mTF_score")

In [None]:
# UMAP panels (two per row, no frame) for three signature scores plus celltype.
sc.pl.umap(subadata, 
           color=['cNMF_1_score','cNMF_3_score','cNMF_4_score',"celltype"],
           ncols=2, frameon=False)

In [None]:
# UMAP panels for the two cell-cycle scores computed earlier.
sc.pl.umap(subadata, 
           color=['S_score','G2M_score'],
           ncols=2, frameon=False)

In [None]:
# Same derivation as for adata: "condition" is the second "_"-separated field of "sample".
subadata.obs["condition"] = subadata.obs["sample"].str.split("_").str[1]

In [None]:
# UMAP panels of the EAC subset colored by condition, celltype and patient.
sc.pl.umap(subadata, 
           color=['condition',"celltype","patient"],
           ncols=2, frameon=False)

In [None]:
# Pairwise correlation (pandas default method) between three signature scores across EAC cells.
subadata.obs[["cNMF_1_score","cNMF_3_score","cNMF_4_score"]].corr()

In [None]:
# Per-cell metadata for cells whose tissue is "EAC" or "EAC.Op"; copied so it can be modified independently.
df = subadata.obs[subadata.obs.tissue.isin(["EAC","EAC.Op"])].copy()

In [None]:
# Cast patient identifiers to plain strings before they are used as the boxplot x-axis.
df["patient"] = df["patient"].astype(str)

In [None]:
# cNMF_4 score per patient, split by condition in a fixed hue order.
ax = sns.boxplot(data=df, x="patient", y="cNMF_4_score", hue="condition", hue_order=["PreTx","ICI-4W","PostTx"])
pretty_ax(ax)
# Rotate the patient labels by 45 degrees.
ax.set_xticklabels(ax.get_xticklabels(), rotation=45)

In [None]:
# Number of EAC cells per condition.
subadata.obs.condition.value_counts()

In [None]:
# cNMF_3 score per patient, split by condition in a fixed hue order.
ax = sns.boxplot(data=df, x="patient", y="cNMF_3_score", hue="condition", hue_order=["PreTx","ICI-4W","PostTx"])
pretty_ax(ax)
# Rotate the patient labels by 45 degrees.
ax.set_xticklabels(ax.get_xticklabels(), rotation=45)

In [None]:
# cNMF_5 score per patient, split by condition in a fixed hue order.
ax = sns.boxplot(data=df, x="patient", y="cNMF_5_score", hue="condition", hue_order=["PreTx","ICI-4W","PostTx"])
pretty_ax(ax)
# Rotate the patient labels by 45 degrees.
ax.set_xticklabels(ax.get_xticklabels(), rotation=45)

In [None]:
# Pairwise view of three signature scores (lower triangle only): KDE contours
# off the diagonal, histograms on it.
g = sns.PairGrid(subadata.obs[["cNMF_1_score","cNMF_3_score","cNMF_4_score"]], diag_sharey=False, corner=True)
g.map_lower(sns.kdeplot)
g.map_diag(sns.histplot)
# NOTE(review): savefig does not create directories; figures/external/ must already exist.
g.fig.savefig("figures/external/Carroll_cNMF_relplot.png", dpi=200, bbox_inches="tight")

In [None]:
import matplotlib
import matplotlib.cm as cm
import matplotlib.colors as mcolors

# First two Harmony-corrected PCs per cell, joined with the per-cell scores.
# The "_score" suffix is dropped from the cNMF columns for plotting.
harmony_pcs = pd.DataFrame(
    subadata.obsm["X_pca_harmony"][:, :2],
    index=subadata.obs_names,
    columns=["PC1", "PC2"],
)
score_labels = {
    "cNMF_1_score": "cNMF_1",
    "cNMF_3_score": "cNMF_3",
    "cNMF_4_score": "cNMF_4",
    "cNMF_2_score": "cNMF_2",
    "cNMF_5_score": "cNMF_5",
    "S_score": "S_score",
    "G2M_score": "G2M_score",
}
cell_scores = subadata.obs[list(score_labels)].rename(columns=score_labels)
X_pca = pd.concat([harmony_pcs, cell_scores], axis=1)

def plot_pcs_color(ax, state, data=None):
    """Scatter PC1 against PC2 on ``ax`` with points colored by ``state``.

    Colors come from the diverging 'RdBu_r' colormap with a two-slope
    normalization centered at 0; a matching colorbar is attached to ``ax``,
    the panel is titled with ``state`` and styled with ``pretty_ax``.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw into.
    state : str
        Column used for coloring; also used as the panel title.
    data : pandas.DataFrame, optional
        Table holding "PC1", "PC2" and ``state``. Defaults to the notebook-level
        ``X_pca`` looked up at call time, which is what callers relied on before.
    """
    if data is None:
        data = X_pca
    values = data[state]
    vmin, vmax = values.min(), values.max()
    # TwoSlopeNorm raises unless vmin < vcenter < vmax. When the values do not
    # straddle 0 (all positive, all negative, constant or empty), fall back to
    # limits symmetric around 0 so the panel still renders.
    if not (vmin < 0 < vmax):
        bound = max(abs(vmin), abs(vmax))
        if not bound > 0:
            bound = 1.0
        vmin, vmax = -bound, bound
    normalize = mcolors.TwoSlopeNorm(vcenter=0, vmin=vmin, vmax=vmax)
    colormap = matplotlib.colormaps['RdBu_r']
    sns.scatterplot(
        y=data["PC2"],
        x=data["PC1"],
        c=values,
        s=10,
        norm=normalize,
        cmap=colormap,
        ax=ax
    )
    mappable = cm.ScalarMappable(norm=normalize, cmap=colormap)
    mappable.set_array(values)
    # Attach the colorbar to the figure that owns `ax` instead of a global
    # `fig`, which could be unset or point at a different figure.
    ax.figure.colorbar(mappable, ax=ax)
    ax.set_title(state)
    pretty_ax(ax)

# Harmony PC1/PC2 colored by cNMF_3, cNMF_1 and cNMF_4 (saved to disk).
fig, ax = plt.subplots(1,3, figsize=(15,4))
flatax = ax.flatten()
for panel, state in zip(flatax, ("cNMF_3", "cNMF_1", "cNMF_4")):
    plot_pcs_color(panel, state)
fig.tight_layout()
fig.savefig("figures/external/Carroll_PC_wCNMF_score.png", dpi=200, bbox_inches="tight")

# Same embedding colored by cNMF_2 and cNMF_5 (displayed only).
fig, ax = plt.subplots(1,2, figsize=(10,4))
flatax = ax.flatten()
for panel, state in zip(flatax, ("cNMF_2", "cNMF_5")):
    plot_pcs_color(panel, state)
fig.tight_layout()

# Same embedding colored by the two cell-cycle scores (displayed only).
fig, ax = plt.subplots(1,2, figsize=(10,4))
flatax = ax.flatten()
for panel, state in zip(flatax, ("S_score", "G2M_score")):
    plot_pcs_color(panel, state)
fig.tight_layout()

In [None]:
# Per-cell expression of the TFs of interest (taken from subadata.X), indexed
# by cell. sc.get.obs_df works for both sparse and dense X, whereas the previous
# `.X.copy().toarray()` only worked for sparse matrices and made an extra copy.
TF_expr = sc.get.obs_df(subadata, keys=mTFs)
# Put three signature scores next to the TF expression values.
TF_expr = pd.concat([TF_expr,subadata.obs[["cNMF_1_score","cNMF_3_score","cNMF_4_score"]]],axis=1)

# Correlation of each score (rows) with each TF (columns).
heatmap_df = TF_expr.corr().loc[["cNMF_3_score","cNMF_1_score","cNMF_4_score"],mTFs]

In [None]:
# Annotated heatmap of the score-vs-TF correlations, diverging palette centered at 0.
fig, ax = plt.subplots(1,1,figsize=(5,2))
sns.heatmap(data=heatmap_df, annot=heatmap_df, cmap="vlag", center=0, ax=ax, fmt=".2f")
# Labels follow the row order of heatmap_df (cNMF_3, cNMF_1, cNMF_4).
ax.set_yticklabels(["cNMF_3","cNMF_1","cNMF_4"])
fig.savefig("figures/external/Carroll_heatmap_cNMF_TF_corr.png", dpi=300, bbox_inches="tight")

In [None]:
# NOTE: `df` is rebound here; it no longer holds the per-cell metadata table used for the boxplots.
df = subadata.obs[["cNMF_3_score","cNMF_1_score","cNMF_4_score","mTF_score"]]
# Single-row table: correlation of mTF_score with each of the three signature scores.
heatmap_df = df.corr().loc[["mTF_score"],["cNMF_3_score","cNMF_1_score","cNMF_4_score"]]
fig, ax = plt.subplots(1,1,figsize=(2,0.5))
sns.heatmap(data=heatmap_df, annot=heatmap_df, cmap="vlag", center=0, ax=ax)
ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
# Labels follow the column order of heatmap_df (cNMF_3, cNMF_1, cNMF_4).
ax.set_xticklabels(["cNMF_3","cNMF_1","cNMF_4"], rotation=45, ha="right")
fig.savefig("figures/external/Carroll_heatmap_cNMF_mTFscore_corr.png", dpi=300, bbox_inches="tight")

# Read bulk RNA-seq data

In [None]:
# Clinical table; its index is later intersected with the patient identifiers parsed from sample names.
clinical = pd.read_csv("/add/path/here/carroll_clinical.csv", index_col=0)

In [None]:
# Gene annotation table with at least "gene_id" and "gene_name" columns.
gencode = pd.read_csv("/add/path/here/gencode_v41_positions.csv",index_col=0)

# "gen_red" is gene_id truncated at the first "." (presumably dropping the version suffix — confirm).
gencode["gen_red"] = gencode["gene_id"].str.split(".").str[0]

In [None]:
# Lookup from truncated gene id to gene name, used to relabel the bulk matrix.
mapping = gencode.set_index("gen_red")["gene_name"].to_dict()

In [None]:
# Tab-separated count table with the first column as index.
bulk = pd.read_csv("/add/path/here/Carroll_singlecell/LUD2015-005_RNAseq_featureCounts.tsv",sep="\t",index_col=0)

# The first five columns are kept as per-gene annotation ("Length" is used for
# the TPM computation below).
gene_info = bulk.iloc[:,:5].copy()

# The remaining columns are treated as per-sample counts.
# NOTE(review): assumes exactly five annotation columns precede the samples — confirm.
bulk = bulk.iloc[:,5:].copy()

In [None]:
# log1p of the summed counts of every sample.
logcounts = np.log1p(bulk.sum())

# TPM: divide each gene by its length, then rescale every sample to sum to 1e6.
length_normalized = bulk.div(gene_info["Length"], axis=0)
tpm = length_normalized.div(length_normalized.sum(), axis=1) * 1000000

# Relabel genes through `mapping` where an entry exists and put samples on rows.
tpm = tpm.rename(index=mapping).T

# Keep genes that are zero in at most half of the samples.
zeros_per_gene = (tpm == 0).sum()
tpm = tpm.loc[:, zeros_per_gene <= 0.5 * tpm.shape[0]]

# Z-score each gene across all remaining samples.
gene_mean, gene_std = tpm.mean(), tpm.std()
std_tpm = (tpm - gene_mean) / gene_std

# Sample names are "_"-separated: field 0 = patient, 1 = condition, 2 = tissue.
sample_fields = std_tpm.index.str.split("_")
std_tpm["condition"] = sample_fields.str[1]
std_tpm["patient"] = sample_fields.str[0]
std_tpm["tissue"] = sample_fields.str[2]

# Only samples labelled "Tumor" are analysed further.
std_tpm = std_tpm[std_tpm["tissue"] == "Tumor"]

In [None]:
# Bulk score per sample and signature: mean z-scored TPM over the signature
# genes present in std_tpm, with the genes in mTFs removed.
per_signature = []
for sig_name, sig_genes in full_sigs.items():
    present_genes = std_tpm.columns.intersection(sig_genes)
    kept_genes = np.setdiff1d(present_genes, mTFs)
    per_signature.append(std_tpm[kept_genes].mean(axis=1).rename(sig_name))
state_score = pd.concat(per_signature, axis=1)

# Sample names are "_"-separated: field 0 = patient, 1 = condition, 2 = tissue.
sample_fields = state_score.index.str.split("_")
state_score["condition"] = sample_fields.str[1]
state_score["patient"] = sample_fields.str[0]
state_score["tissue"] = sample_fields.str[2]
# Aligned on the sample index.
state_score["logcounts"] = logcounts

In [None]:
# Patients present both in the clinical table and in the bulk samples.
common_patients = clinical.index.intersection(state_score.patient.unique())

# .copy() makes the filtered table independent of the original frame, so the
# column assignment below cannot trigger SettingWithCopyWarning or write into
# a temporary slice.
state_score = state_score[state_score.patient.isin(common_patients)].copy()

# Ordered categories give sorting and plotting a fixed condition order; any
# value outside these three categories becomes NaN.
state_score["condition"] = pd.Categorical(state_score["condition"], ["PreTx","ICI-4W","PostTx"])

In [None]:
import matplotlib
import matplotlib.cm as cm
import matplotlib.colors as mcolors

from sklearn.decomposition import PCA
# Expression matrix for PCA: metadata columns removed, and gene columns that
# contain any NaN dropped (PCA cannot handle missing values).
pca_input = std_tpm.drop(["condition","patient","tissue"],axis=1).dropna(axis=1)

# PCA cannot return more components than min(n_samples, n_features); cap the
# requested 50 so the cell also runs on smaller cohorts or gene sets.
n_components = min(50, *pca_input.shape)
pca = PCA(n_components=n_components)
X_pca = pd.DataFrame(
    pca.fit_transform(pca_input),
    index=pca_input.index,
    # Column labels follow the number of components actually computed.
    columns=[f"PC{i}" for i in range(1, n_components + 1)],
)

# Join the per-sample signature scores and metadata; samples absent from
# state_score get NaN in those columns.
X_pca = pd.concat([X_pca, state_score],axis=1)

def plot_pcs_color(ax, state, data=None):
    """Scatter PC1 against PC2 on ``ax`` with points colored by ``state``.

    Colors come from the diverging 'RdBu_r' colormap with a two-slope
    normalization centered at 0; a matching colorbar is attached to ``ax``,
    the panel is titled with ``state`` and styled with ``pretty_ax``.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw into.
    state : str
        Column used for coloring; also used as the panel title.
    data : pandas.DataFrame, optional
        Table holding "PC1", "PC2" and ``state``. Defaults to the notebook-level
        ``X_pca`` looked up at call time, which is what callers relied on before.
    """
    if data is None:
        data = X_pca
    values = data[state]
    vmin, vmax = values.min(), values.max()
    # TwoSlopeNorm raises unless vmin < vcenter < vmax. When the values do not
    # straddle 0 (all positive, all negative, constant or empty), fall back to
    # limits symmetric around 0 so the panel still renders.
    if not (vmin < 0 < vmax):
        bound = max(abs(vmin), abs(vmax))
        if not bound > 0:
            bound = 1.0
        vmin, vmax = -bound, bound
    normalize = mcolors.TwoSlopeNorm(vcenter=0, vmin=vmin, vmax=vmax)
    colormap = matplotlib.colormaps['RdBu_r']
    sns.scatterplot(
        y=data["PC2"],
        x=data["PC1"],
        c=values,
        s=10,
        norm=normalize,
        cmap=colormap,
        ax=ax
    )
    mappable = cm.ScalarMappable(norm=normalize, cmap=colormap)
    mappable.set_array(values)
    # Attach the colorbar to the figure that owns `ax` instead of a global
    # `fig`, which could be unset or point at a different figure.
    ax.figure.colorbar(mappable, ax=ax)
    ax.set_title(state)
    pretty_ax(ax)

# Bulk PC1/PC2 colored by cNMF_3, cNMF_1 and cNMF_4 (saved to disk).
fig, ax = plt.subplots(1,3, figsize=(11,3))
flatax = ax.flatten()
for panel, state in zip(flatax, ("cNMF_3", "cNMF_1", "cNMF_4")):
    plot_pcs_color(panel, state)
fig.tight_layout()
fig.savefig("figures/external/GSE207526_PC_wCNMF_score.svg", dpi=200, bbox_inches="tight")

# Same embedding colored by cNMF_2 and cNMF_5 (displayed only).
fig, ax = plt.subplots(1,2, figsize=(6,2))
flatax = ax.flatten()
for panel, state in zip(flatax, ("cNMF_2", "cNMF_5")):
    plot_pcs_color(panel, state)
fig.tight_layout()

In [None]:
# Bulk PC1/PC2 colored by condition, with the top and right spines hidden.
ax = sns.scatterplot(data=X_pca, x="PC1", y="PC2", hue="condition")
ax.spines[['right', 'top']].set_visible(False)

In [None]:
# Bulk PC1/PC2 colored by logcounts (log1p of the summed counts per sample).
ax = sns.scatterplot(data=X_pca, x="PC1", y="PC2", hue="logcounts")
ax.spines[['right', 'top']].set_visible(False)

In [None]:
# One "mako" color per patient. The palette is sized from the data: the
# previous fixed size of 33 raised IndexError once more than 33 patients were
# present, and used only part of the colormap range with fewer.
patient_ids = state_score.patient.unique()
custom_palette = sns.color_palette("mako", len(patient_ids))
patcolors = dict(zip(patient_ids, custom_palette))

In [None]:
# Order rows by patient, then by condition (categorical order PreTx, ICI-4W, PostTx).
state_score = state_score.sort_values(by=["patient","condition"])

In [None]:
marker_map = {"PreTx": "o", "ICI-4W": "v", "PostTx": "s"}

# Sort here as well, so the cell does not depend on an earlier sort having been run.
trajectories = state_score.sort_values(by=["patient", "condition"])

# One line per patient in the cNMF_3 / cNMF_4 plane.
# - sort=False connects each patient's samples in row order, i.e. in condition
#   order; seaborn's default (sort=True) would connect them by increasing x.
# - estimator=None draws every observation without aggregating.
# - The former `markers=` argument is dropped: seaborn ignores it without a
#   `style` variable, and building it relied on the deprecated Series.ravel().
ax = sns.lineplot(data=trajectories, x="cNMF_3", y="cNMF_4", hue="patient",
                  palette=patcolors, sort=False, estimator=None)

# Overlay the samples with one marker shape per condition. Iterating over
# marker_map (not over the observed values) avoids a KeyError for samples whose
# condition is missing or not one of the three known labels.
for condition, marker in marker_map.items():
    condition_rows = trajectories[trajectories.condition == condition]
    if condition_rows.empty:
        continue
    sns.scatterplot(data=condition_rows,
                    x="cNMF_3", y="cNMF_4",
                    marker=marker, hue="patient", palette=patcolors,
                    legend=None, ax=ax)

ax.spines[['right', 'top']].set_visible(False)
# axhline/axvline span the whole axes regardless of the current data limits.
ax.axhline(0, linestyle="dashed", color="grey")
ax.axvline(0, linestyle="dashed", color="grey")
ax.legend(ncols=2, bbox_to_anchor=(1,1,0,0), frameon=False)

In [None]:
# Bulk cNMF_4 score by condition, in a fixed condition order.
sns.boxplot(data=state_score, y="cNMF_4", x="condition", order=["PreTx","ICI-4W","PostTx"])

In [None]:
# Bulk cNMF_5 score by condition, in a fixed condition order.
sns.boxplot(data=state_score, y="cNMF_5", x="condition", order=["PreTx","ICI-4W","PostTx"])

In [None]:
# Restrict the z-scored tumor samples to the "PreTx" condition.
red_tpm = std_tpm[std_tpm["condition"]=="PreTx"]

In [None]:
# Recompute the per-sample signature scores on the PreTx samples only: mean
# z-scored TPM over the signature genes present, with the genes in mTFs removed.
# NOTE: this rebinds `state_score`, replacing the all-condition table built earlier.
per_signature = []
for sig_name, sig_genes in full_sigs.items():
    present_genes = red_tpm.columns.intersection(sig_genes)
    kept_genes = np.setdiff1d(present_genes, mTFs)
    per_signature.append(red_tpm[kept_genes].mean(axis=1).rename(sig_name))
state_score = pd.concat(per_signature, axis=1)

# Sample names are "_"-separated: field 0 = patient, 1 = condition, 2 = tissue.
sample_fields = state_score.index.str.split("_")
state_score["condition"] = sample_fields.str[1]
state_score["patient"] = sample_fields.str[0]
state_score["tissue"] = sample_fields.str[2]

# Keep only patients that also appear in the clinical table.
common_patients = clinical.index.intersection(state_score.patient.unique())
has_clinical = state_score.patient.isin(common_patients)
state_score = state_score[has_clinical]

# Re-index by patient identifier (first field of the sample name).
state_score.index = state_score.index.str.split("_").str[0]