In [None]:
import scanpy as sc
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import palettable

import pathlib as pl

from tqdm.notebook import tqdm

In [None]:
import scib

from scipy.stats import fisher_exact

In [None]:
def pretty_ax(ax):
    """Reduce an axes to thick bottom/left spines without tick labels.

    The top and right spines are hidden, all tick labels and the left/top
    tick marks are switched off (bottom tick marks stay on), and the bottom
    and left spines are drawn with a line width of 1.5. The axes is modified
    in place and nothing is returned.
    """
    for side in ("right", "top"):
        ax.spines[side].set_visible(False)

    # Applies to both axes and to major as well as minor ticks.
    tick_settings = dict(
        axis="both",
        which="both",
        bottom=True,
        top=False,
        left=False,
        labelbottom=False,
        labelleft=False,
    )
    ax.tick_params(**tick_settings)

    for side in ("bottom", "left"):
        ax.spines[side].set_linewidth(1.5)

# Load data

In [None]:
# Read the snRNA-seq AnnData object from `full_cohort.h5ad` (placeholder path — edit before running).
adata = sc.read_h5ad("/add/path/here/full_cohort.h5ad")

In [None]:
# Clinical table; the first CSV column is used as the index.
clinical = pd.read_csv("/add/path/here/EAC_clinical_info.csv",index_col=0)
# NOTE(review): treatment_mapping is not referenced anywhere else in the visible notebook.
treatment_mapping = {"Neoadjuvant CROSS": "Neoadj. chemo", "Neoadjuvent carboplatin": "Neoadj. chemo"}
# A row is flagged metastatic when "Tumor?" equals "Yes " and "Site" contains "metastasis".
# NOTE(review): the compared value "Yes " has a trailing space — presumably it mirrors the raw CSV; confirm.
metastatic = (clinical["Tumor?"]=="Yes ") & (clinical["Site"].str.contains("metastasis"))
metastatic.name = "Metastatic?"
clinical["Metastatic?"] = metastatic

# Collapse the raw "Site" entries into a coarse anatomical location.
clinical["Location"] = clinical["Site"].replace(
    {"GEJ": "Esophagus/GEJ", "Esophagus": "Esophagus/GEJ"}
)

# Any site mentioning one of these organs is relabeled with the organ name.
# The write goes through .loc on the frame itself: the chained form
# clinical["Location"][mask] = ... assigns into an intermediate object, which
# raises SettingWithCopyWarning and has no effect at all under copy-on-write.
for pattern, location in (
    ("Liver", "Liver"),
    ("Adrenal", "Adrenal gland"),
    ("Peritoneal", "Peritoneum"),
):
    # na=False: a missing site never matches (a NaN in the mask would make .loc raise).
    matches = clinical["Location"].str.contains(pattern, na=False)
    clinical.loc[matches, "Location"] = location

# Translate the raw "Grade/stage" strings into a compact stage label.
# Entries that are not keys of the mapping are passed through unchanged.
stage_labels = {
    "Stage IV ": "IV",
    "Stage IV": "IV",
    "Moderately differentiated; ypT1aN0": "I",
    "Moderately differentiated; pT1aN0": "I",
    "Poorly differentiated; ypT2N0": "II",
    "Presented with stage III became stage IV during esophagectomy when pleural metastases were identified": "III/IV",
}
clinical["Stage"] = clinical["Grade/stage"].replace(stage_labels)

# Treatment per sample, assigned by row position.
# NOTE(review): this list is matched to the rows in the CSV's original order
# (it is assigned before the sort below) — confirm it lines up with the file.
clinical["Treatment"] = ["Neoadj. chemo",
                         "None",
                         "Neoadj. chemo + ICI + RT",
                         "None",
                         "None",
                         "Chemo + HER2 targeted + ICI", 
                         "Neoadj. chemo + HER2 targeted", 
                         "Neoadj. chemo + ICI", 
                         "None",
                         "Neoadj. chemo + VEGFR2i"]

# Only the "HER 2 1+" entry is relabeled; other HER2 values are kept as they are.
clinical["HER2 status"] = clinical["HER2"].replace({"HER 2 1+": "1+/equivocal"})

# Row order after this sort is what the later per-sample plots use for their x axis.
clinical = clinical.sort_values(by=["Tumor?","Metastatic?","Location"])

# NOTE(review): unlike "Treatment", these scores are assigned AFTER the sort,
# so they are matched to the sorted row order — confirm this is intended.
clinical["PD-L1 CPS score"] = [0,2,7,2,3,24,0,8,3,15]

In [None]:
# Refined cell-type labels; the first CSV column is used as the index.
refined_annotations = pd.read_csv("/add/path/here/refined_annotations.csv",index_col=0)

# Give the single label column a fixed name, whatever it is called in the file.
refined_annotations.columns = ["refined_annotations"]

# Second annotation table; later cells read its "sample_id", "Corrected label"
# and "refined_wcancer" columns.
refined_wcancer = pd.read_csv("/add/path/here/refined_wCNMF_programs_and_sampleid.csv",index_col=0)

In [None]:
# Cells per (sample, corrected label) pair, pivoted to a samples x labels table.
count_df = refined_wcancer.value_counts(subset=["sample_id", "Corrected label"]).unstack()

In [None]:
# Row-normalise: fraction of each sample's cells that falls into each label.
count_df.div(count_df.sum(axis=1), axis=0)

In [None]:
# Fixed color per sample id: the eight Dark2 colors first, then the first two
# Paired colors, plus a neutral color for the "NA" entry.
colorlist = palettable.colorbrewer.qualitative.Dark2_8.mpl_colors
colorlistbis = palettable.colorbrewer.qualitative.Paired_3.mpl_colors
colormapping_pat = dict(
    zip(
        [
            'Aguirre_EGSFR1982',
            "Aguirre_EGSFR2218",
            "CCG1153_4411",
            "Aguirre_EGSFR1938",
            "Aguirre_EGSFR0074",
            "Aguirre_EGSFR0128",
            "Aguirre_EGSFR1732",
            "Aguirre_EGSFR0148",
            "CCG1153_4496262",
            "CCG1153_6640539",
        ],
        [*colorlist[:8], *colorlistbis[:2]],
    )
)
colormapping_pat["NA"] = "whitesmoke"

# snRNA-seq

In [None]:
# Refined cell-type label -> high-level compartment.
# The entry order is kept exactly as written because the mapping is passed to
# Series.replace later on.
# NOTE(review): "Adipocytes" map to "Stromal/Muscle" while the muscle types map
# to "Muscle", giving two separate categories — confirm this is intended.
highlevel_refined = dict(
    [
        ("Hepatocyte", "Epithelial"),
        ("Carcinoma", "Carcinoma"),
        ("Fibroblast", "Fibroblast"),
        ("Quiescent endothelial cells", "Endothelial"),
        ("Smooth muscle", "Muscle"),
        ("Skeletal muscle", "Muscle"),
        ("TAM2", "Myeloid"),
        ("TAM3", "Myeloid"),
        ("TCD4", "Lymphoid"),
        ("Inflammatory CAF", "Fibroblast"),
        ("Adipose CAF", "Fibroblast"),
        ("HGF-CAF", "Fibroblast"),
        ("TAM1", "Myeloid"),
        ("Myeloid-HighMT", "Unknown/technical"),
        ("Angiogenic EC", "Endothelial"),
        ("Quiescent EC", "Endothelial"),
        ("Venous EC", "Endothelial"),
        ("TCD8", "Lymphoid"),
        ("B", "Lymphoid"),
        ("DC", "Myeloid"),
        ("Hepatic EC", "Endothelial"),
        ("Kupffer cells", "Myeloid"),
        ("NK", "Lymphoid"),
        ("Treg", "Lymphoid"),
        ("StrMus-HighMT", "Unknown/technical"),
        ("T-HighMT", "Unknown/technical"),
        ("Mast", "Myeloid"),
        ("Adipocytes", "Stromal/Muscle"),
        ("Endo-HighMT", "Unknown/technical"),
    ]
)

# Attach the refined labels to the cells, aligned on the obs index.
# Plain column assignment replaces `adata.obs = pd.concat([...], axis=1)`:
#  - it is idempotent, whereas concat appended a duplicate column on every
#    re-run of this cell (after which `adata.obs.refined_annotations` is a
#    DataFrame rather than a Series);
#  - it can never change the length or order of the obs index, which an
#    outer-joining concat does when the two indexes differ.
# Cells without an entry in the annotation tables get NaN.
adata.obs["refined_annotations"] = refined_annotations["refined_annotations"].reindex(adata.obs_names)
adata.obs["refined_wcancer"] = refined_wcancer["refined_wcancer"].reindex(adata.obs_names)

# High-level compartment of every cell; labels missing from the mapping pass through unchanged.
adata.obs["highlevel_refined"] = adata.obs["refined_annotations"].replace(highlevel_refined)

In [None]:
# Cell-cycle scoring with human gene lists; later cells read the
# "S_score", "G2M_score" and "phase" columns of adata.obs.
scib.preprocessing.score_cell_cycle(adata, organism="human")

## Patient-level distributions

In [None]:
# Fraction of each sample's cells in every high-level compartment.
patlevel_counts = adata.obs[["sample_id","highlevel_refined"]].groupby(by="sample_id").value_counts(normalize=True)
# Convert to whole percentages. Scale first, round second: the previous
# `round(2)*100` leaves float artefacts (0.29*100 == 28.999999999999996) that
# the integer cast below truncates to 28 instead of 29.
patlevel_counts = (patlevel_counts * 100).round()

df = patlevel_counts.unstack(level=-1)

# Keep the samples in the order of the sorted clinical table; a compartment
# absent from a sample counts as 0 %.
df = df.loc[clinical.index.intersection(df.index)].fillna(0).astype(int)

# One colorblind-palette color per high-level compartment (at most 10), plus "NA".
colorlist = sns.color_palette("colorblind", 10)
ctlist = adata.obs.highlevel_refined.unique()
colormapping = {ct: colorlist[i] for i,ct in enumerate(ctlist)}
colormapping["NA"] = "whitesmoke"

In [None]:
def add_clinical_info(ax, y=100, h=10,
                      groups=((0, 1, "NAT"), (2, 4, "Primary"), (5, 9, "Metastatic")),
                      col='k'):
    """Draw a labeled bracket above each group of bars.

    Parameters
    ----------
    ax : matplotlib axes
        Axes holding the bar plot; bars are addressed by their integer x
        position (first bar: 0).
    y : number, default 100
        Height, in data units, at which the bracket legs start.
    h : number, default 10
        Height of the bracket; the label is written at ``y + 1.3 * h``.
    groups : iterable of (first, last, label)
        Inclusive bar positions spanned by each bracket, and its text.
        NOTE(review): the defaults assume ten bars ordered as 2 NAT,
        3 primary and 5 metastatic samples — confirm against the plotted frame.
    col : color, default 'k'
        Color of both the bracket and the label.

    Returns
    -------
    The ``ax`` that was passed in.
    """
    for x1, x2, label in groups:
        # Bracket: slanted legs from y up to y+h, flat top across the group.
        ax.plot([x1-0.4, x1-0.3, x2+0.3, x2+0.4], [y, y+h, y+h, y], lw=1.5, c=col)
        ax.text((x1+x2)*.5, y+1.3*h, label, ha='center', va='bottom', color=col)

    return ax

In [None]:
# Stacked bar chart of the per-sample composition held in `df`.
fig, ax = plt.subplots(1, 1, figsize=(6, 3))
df.plot(kind="bar", stacked=True, color=colormapping, ax=ax)
ax.legend(bbox_to_anchor=(1.05, 1), frameon=False)
ax.spines[["right", "top"]].set_visible(False)
ax.set_xlabel("")
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")

# Write the value inside every segment; segments lower than 5 get an empty
# label so that thin slices stay readable.
for c in ax.containers:
    labels = []
    for v in c:
        height = v.get_height()
        labels.append(int(height) if height >= 5 else "")
    ax.bar_label(c, labels=labels, label_type="center", fmt="%0.0f", color="white")

ax = add_clinical_info(ax)
fig.savefig("figures/barplot_pat_tme_highlevel_scaled.png", dpi=300, bbox_inches="tight")

In [None]:
# Same table as the scaled version, but with absolute cell counts per sample.
patlevel_counts = adata.obs[["sample_id","highlevel_refined"]].groupby(by="sample_id").value_counts()
df = patlevel_counts.unstack(level=-1)

# Keep the samples in the order of the sorted clinical table.
df = df.loc[clinical.index.intersection(df.index)]

# Rebuilds the compartment -> color mapping exactly as in the scaled cell.
colorlist = sns.color_palette("colorblind", 10)
ctlist = adata.obs.highlevel_refined.unique()
colormapping = {ct: colorlist[i] for i,ct in enumerate(ctlist)}
colormapping["NA"] = "whitesmoke"

In [None]:
def add_clinical_info(ax, y=16000, h=1000,
                      groups=((0, 1, "NAT"), (2, 4, "Primary"), (5, 9, "Metastatic")),
                      col='k'):
    """Draw a labeled bracket above each group of bars (absolute-count scale).

    Parameters
    ----------
    ax : matplotlib axes
        Axes holding the bar plot; bars are addressed by their integer x
        position (first bar: 0).
    y : number, default 16000
        Height, in data units, at which the bracket legs start.
    h : number, default 1000
        Height of the bracket; the label is written at ``y + 1.3 * h``.
    groups : iterable of (first, last, label)
        Inclusive bar positions spanned by each bracket, and its text.
        NOTE(review): the defaults assume ten bars ordered as 2 NAT,
        3 primary and 5 metastatic samples — confirm against the plotted frame.
    col : color, default 'k'
        Color of both the bracket and the label.

    Returns
    -------
    The ``ax`` that was passed in.
    """
    for x1, x2, label in groups:
        # Bracket: slanted legs from y up to y+h, flat top across the group.
        ax.plot([x1-0.4, x1-0.3, x2+0.3, x2+0.4], [y, y+h, y+h, y], lw=1.5, c=col)
        ax.text((x1+x2)*.5, y+1.3*h, label, ha='center', va='bottom', color=col)

    return ax

In [None]:
# Stacked bar chart of absolute cell counts per sample (no segment labels).
fig, ax = plt.subplots(1,1,figsize=(6,3))
df.plot(kind = 'bar', stacked = True, color=colormapping, ax=ax,)
ax.legend(bbox_to_anchor=(1.05, 1), frameon=False)
ax.spines[['right', 'top']].set_visible(False)
ax.set_xlabel("")
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
# Uses whichever add_clinical_info definition was executed last.
ax = add_clinical_info(ax)
fig.savefig("figures/barplot_pat_tme_highlevel_nonscaled.png", dpi=300, bbox_inches="tight")

## UMAP viz

In [None]:
# PCA with scanpy's default settings.
sc.tl.pca(adata)

# Harmony batch correction across samples; no basis / output key is given, so
# scanpy's defaults apply — the next step reads the result from "X_pca_harmony".
sc.external.pp.harmony_integrate(adata, key="sample_id", max_iter_harmony=20)

# Neighbor graph on the Harmony-corrected representation.
sc.pp.neighbors(adata, use_rep="X_pca_harmony")

sc.tl.umap(adata)

In [None]:
# UMAP colored by high-level compartment.
# NOTE(review): the file name says "scRNA" while the sibling figures use "snRNA" — confirm.
fig = sc.pl.umap(adata, color=["highlevel_refined"], palette=colormapping, frameon=False, ncols=1, return_fig=True)
fig.savefig("figures/highlevel_refined_scRNA_umap.png", dpi=300, bbox_inches="tight")

In [None]:
# Number of cells per high-level compartment.
adata.obs.highlevel_refined.value_counts()

In [None]:
# UMAP colored by sample, using the fixed per-sample colors.
fig = sc.pl.umap(adata, color=["sample_id"], frameon=False, ncols=1, return_fig=True, palette=colormapping_pat)
fig.savefig("figures/highlevel_refined_snRNA_umap_sampleid.png", dpi=300, bbox_inches="tight")

In [None]:
# UMAP colored by the refined (fine-grained) labels; no palette is passed, so scanpy's default colors are used.
fig = sc.pl.umap(adata, color=["refined_annotations"], frameon=False, ncols=1, return_fig=True)
fig.savefig("figures/lowlevel_refined_snRNA_umap.png", dpi=300, bbox_inches="tight")

In [None]:
import signaturescoring as ssc

# One marker table per cNMF program, read from cNMF_<n>.csv with the first column as index.
marker_genes = {
    cl: pd.read_csv(f"/add/path/here/cNMF_{cl}.csv", index_col=0)
    for cl in "12345"
}

# Score every program on all cells using the first 100 index entries of its
# table; the result is stored under "cNMF_<n>_score".
for prog, markers in marker_genes.items():
    top_genes = list(markers.head(100).index.ravel())
    ssc.score_signature(
        adata=adata,
        gene_list=top_genes,
        method="adjusted_neighborhood_scoring",
        ctrl_size=150,
        score_name=f"cNMF_{prog}_score",
    )

In [None]:
# 3x2 grid: one boxplot per cNMF program score, split by high-level
# compartment; the sixth panel stays empty.
fig, axs = plt.subplots(3, 2, figsize=(10, 10))
flatax = axs.flatten()
for i, ax in enumerate(flatax[:-1]):
    sns.boxplot(
        data=adata.obs,
        x="highlevel_refined",
        y=f"cNMF_{i+1}_score",
        palette=colormapping,
        ax=ax,
    )
    ax.spines[["right", "top"]].set_visible(False)
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
    # Dashed reference line at score 0 across the current x range.
    left, right = ax.get_xlim()
    ax.hlines(y=0, xmin=left, xmax=right, linestyles="dashed", color="grey")
    ax.set_xlabel("")
flatax[-1].axis("off")
fig.tight_layout()
fig.savefig("figures/malignant/boxplot_cNMF_score_per_celltype.png", dpi=250, bbox_inches="tight")

## Cell cycle

In [None]:
# Correlation (pandas default) of the cNMF_2 score with the two cell-cycle
# scores, shown as a 1x2 heatmap on a fixed [-1, 1] scale.
fig, ax = plt.subplots(1,1,figsize=(2,1))
sns.heatmap(adata.obs[["S_score","G2M_score","cNMF_2_score"]].corr().loc[["cNMF_2_score"],["S_score","G2M_score"]],
            annot=True, cmap="vlag", center=0, vmin=-1, vmax=1,
            ax=ax)
ax.set_xticklabels(ax.get_xticklabels(), rotation=45)
ax.set_yticklabels(ax.get_yticklabels(), rotation=0)

In [None]:
# Carcinoma cells only: cross-tabulate malignant state against cell-cycle phase.
df = adata.obs[adata.obs.highlevel_refined=="Carcinoma"].copy()
# fill_value=0: a state/phase combination without any cell must count as 0.
# With the default NaN, the "Cycling" sum and the table handed to
# fisher_exact would contain NaN.
confusion = df[["refined_wcancer","phase"]].value_counts().unstack(fill_value=0)

confusion["Cycling"] = confusion["G2M"] + confusion["S"]

# For each state, a 2x2 table [all other carcinoma cells; this state] x
# [G1, Cycling]. The resulting odds ratio is > 1 when the state holds
# relatively more cycling (S or G2M) cells than the rest. Only the odds ratio
# (element 0 of the result) is kept; the p-value is discarded.
cycling_OR = {}
for state in ["cNMF_1","cNMF_3","cNMF_4"]:
    conting = pd.concat([confusion.drop([state]).sum().to_frame().T,confusion.loc[[state],:]])
    cycling_OR[state] = [fisher_exact(conting.loc[:,['G1','Cycling']])[0]]

In [None]:
# One-row heatmap of the cycling odds ratios, centered on 1 (no enrichment)
# with the color scale clipped to [0, 2].
fig, ax = plt.subplots(1,1,figsize=(2,1))
sns.heatmap(data=pd.DataFrame(cycling_OR, index=["Cycling OR"]),annot=True,fmt=".1f",
            cmap="vlag",center=1,vmin=0,vmax=2,ax=ax)
ax.set_xticklabels(ax.get_xticklabels(), rotation=45)
ax.set_yticklabels(ax.get_yticklabels(), rotation=0)

## Subset of malignant cells

In [None]:
# Colors for the malignant states: Set1 for cNMF_1..cNMF_5, grey for the two
# labels that do not correspond to a program.
colorlist = palettable.colorbrewer.qualitative.Set1_5.mpl_colors
colormapping_mal = {f"cNMF_{k + 1}": colorlist[k] for k in range(5)}
for unassigned in ("cNMF_Outlier", "Carcinoma_undefined"):
    colormapping_mal[unassigned] = "grey"

In [None]:
# Independent copy restricted to the cells labeled "Carcinoma".
subadata = adata[adata.obs.highlevel_refined=="Carcinoma"].copy()

In [None]:
# Recompute PCA -> Harmony (per sample) -> neighbors -> UMAP on the carcinoma subset only.
sc.tl.pca(subadata)
sc.external.pp.harmony_integrate(subadata, key="sample_id", basis="X_pca", max_iter_harmony=20)
sc.pp.neighbors(subadata, use_rep="X_pca_harmony")
sc.tl.umap(subadata)

In [None]:
import signaturescoring as ssc

# One marker table per cNMF program, read from cNMF_<n>.csv with the first column as index.
marker_genes = {
    cl: pd.read_csv(f"/add/path/here/cNMF_{cl}.csv", index_col=0)
    for cl in "12345"
}

# Re-score every program within the carcinoma subset, using the first 100
# index entries of its table; stored under "cNMF_<n>_score" in subadata.obs.
for prog, markers in marker_genes.items():
    top_genes = list(markers.head(100).index.ravel())
    ssc.score_signature(
        adata=subadata,
        gene_list=top_genes,
        method="adjusted_neighborhood_scoring",
        ctrl_size=150,
        score_name=f"cNMF_{prog}_score",
    )

In [None]:
import matplotlib
import matplotlib.cm as cm
import matplotlib.colors as mcolors
# First two Harmony-corrected components per carcinoma cell, followed by the
# five program scores under short column names (order: 1, 3, 4, 2, 5).
programs = (1, 3, 4, 2, 5)
X_pca = pd.concat(
    [
        pd.DataFrame(
            subadata.obsm["X_pca_harmony"][:, :2],
            index=subadata.obs_names,
            columns=["PC1", "PC2"],
        ),
        subadata.obs[[f"cNMF_{k}_score" for k in programs]],
    ],
    axis=1,
)
X_pca.columns = ["PC1", "PC2"] + [f"cNMF_{k}" for k in programs]

def plot_pcs_color(ax, state):
    """Scatter the cells on PC1/PC2, colored by one cNMF program score.

    Parameters
    ----------
    ax : matplotlib axes
        Target axes. A colorbar is added to the figure that owns ``ax``.
    state : str
        Column of the module-level ``X_pca`` frame used as the color value
        (e.g. ``"cNMF_3"``); it is also used as the panel title.

    Notes
    -----
    The diverging colormap is centered on the midpoint between the 25th and
    75th percentile of the score, not on 0. ``TwoSlopeNorm`` needs
    ``vmin < vcenter < vmax``, so a constant score column raises an error.
    """
    scores = X_pca[state]
    vmin, vmax = scores.min(), scores.max()
    vcenter = (scores.quantile(0.75) + scores.quantile(0.25))/2
    normalize = mcolors.TwoSlopeNorm(vcenter=vcenter, vmin=vmin, vmax=vmax)
    colormap = matplotlib.colormaps['RdBu_r']
    sns.scatterplot(
        y=X_pca["PC2"],
        x=X_pca["PC1"],
        c=scores,
        s=1,
        norm=normalize,
        cmap=colormap,
        ax=ax
    )
    # Stand-alone mappable so the colorbar reflects the same norm and colormap.
    scalarmappaple = cm.ScalarMappable(norm=normalize, cmap=colormap)
    scalarmappaple.set_array(scores)
    ax.set_title(state)
    # Attach the colorbar to the axes' own figure instead of the notebook-global
    # `fig`: the function then works for any axes, including calls made before
    # (or after) `fig` has been rebound to a different figure.
    ax.figure.colorbar(scalarmappaple, ax=ax)
    pretty_ax(ax)

# Figure 1: programs 3, 1 and 4 side by side (saved to disk).
fig, ax = plt.subplots(1,3, figsize=(15,4))
flatax = ax.flatten()

plot_pcs_color(flatax[0], "cNMF_3")
plot_pcs_color(flatax[1], "cNMF_1")
plot_pcs_color(flatax[2], "cNMF_4")
fig.tight_layout()
fig.savefig("figures/malignant/PC_wCNMF_score.png", dpi=300, bbox_inches="tight")

# Figure 2: programs 2 and 5. This figure is only displayed; it is not written to disk.
fig, ax = plt.subplots(1,2, figsize=(10,4))
flatax = ax.flatten()

plot_pcs_color(flatax[0], "cNMF_2")
plot_pcs_color(flatax[1], "cNMF_5")
fig.tight_layout()

In [None]:
# Carcinoma-only UMAP colored by malignant state.
fig = sc.pl.umap(subadata, color=["refined_wcancer"], palette=colormapping_mal, frameon=False, ncols=1, return_fig=True)
fig.savefig("figures/malonly_cNMF_harmony_snRNA_umap.png", dpi=300, bbox_inches="tight")

In [None]:
# Carcinoma-only UMAP colored by sample.
fig = sc.pl.umap(subadata, color=["sample_id"], frameon=False, ncols=1, palette=colormapping_pat, return_fig=True)
fig.savefig("figures/malonly_sampleid_harmony_snRNA_umap.png", dpi=300, bbox_inches="tight")

In [None]:
# Carcinoma-only UMAP colored by two QC covariates read from subadata.obs.
fig = sc.pl.umap(subadata, color=["log1p_total_counts","pct_counts_mt"], frameon=False, ncols=1, return_fig=True)
fig.savefig("figures/malonly_technical_harmony_snRNA_umap.png", dpi=300, bbox_inches="tight")

In [None]:
# Ten transcription factors plotted in the following cells
# (presumably "master TFs", going by the variable name — TODO confirm their origin).
mTFs = ["KLF5","NFE2L1","MXD1","PPARD","SMAD3","KLF6","TCF7L2","ATF3","EHF","GRHL2"]

In [None]:
# Carcinoma-only UMAP, one panel per transcription factor (two panels per row).
fig = sc.pl.umap(subadata, color=mTFs,
                 frameon=False, ncols=2, return_fig=True)
fig.savefig("figures/malonly_cNMF_harmony_mTF_expression.png", dpi=300, bbox_inches="tight")

In [None]:
# Dense cells x TF expression table; .toarray() requires subadata.X to be a sparse matrix.
TF_expr = pd.DataFrame(subadata[:,mTFs].X.copy().toarray(),index=subadata.obs_names,columns=mTFs)

# Add each cell's malignant state so it can serve as the grouping variable.
TF_expr = pd.concat([TF_expr,subadata.obs["refined_wcancer"]],axis=1)

# 2x5 grid: one boxplot per TF (the grid size matches the ten entries of mTFs).
fig, ax = plt.subplots(2,5, figsize=(15,5))
flatax=ax.flatten()
for i,axi in enumerate(flatax):
    sns.boxplot(data=TF_expr,x="refined_wcancer",y=mTFs[i],ax=axi, palette=colormapping_mal)
    axi.set_xticklabels(axi.get_xticklabels(), rotation=45)
    axi.set_xlabel("")
    axi.spines[["top","right"]].set_visible(False)
fig.tight_layout()
# NOTE(review): the file name repeats that of the technical-covariate UMAP (in
# a different directory), although this figure is a TF boxplot grid — looks
# like a copy-paste leftover; confirm the intended name.
fig.savefig("figures/malignant/malonly_technical_harmony_snRNA_umap.png",dpi=300,bbox_inches="tight")

# snATAC-seq

In [None]:
# Read the snATAC-seq AnnData object from `combined_atac.h5ad` (placeholder path — edit before running).
atac = sc.read_h5ad("/add/path/here/combined_atac.h5ad")

In [None]:
# Derive the high-level annotation of the ATAC cells from `refined_wcancer`:
# refined labels go through the same mapping as the RNA data, then the
# malignant-state labels (cNMF_1..cNMF_5, Carcinoma_undefined) are collapsed
# into a single "Carcinoma" class.
#
# The column is written with [] on purpose. Attribute assignment
# (`atac.obs.highlevel_annotation = ...`) only updates a column that already
# exists; for a new name pandas sets a plain Python attribute on the frame,
# so the column would be missing from atac.obs[...] lookups, from plotting
# and from any copy or subset.
atac.obs["highlevel_annotation"] = atac.obs["refined_wcancer"].replace(highlevel_refined)
atac.obs["highlevel_annotation"] = atac.obs["highlevel_annotation"].replace({f"cNMF_{i}": "Carcinoma" for i in range(1,6)})
atac.obs["highlevel_annotation"] = atac.obs["highlevel_annotation"].replace({"Carcinoma_undefined": "Carcinoma"})
# NOTE(review): "cNMF_Outlier" (a key of colormapping_mal) is not remapped
# here — confirm whether that label can occur in the ATAC data.

In [None]:
# Keep the first 40 columns of the stored LSI embedding.
atac.obsm["X_lsi_red"] = atac.obsm["X_lsi"][:,:40]

In [None]:
# Harmony across samples on the reduced LSI embedding. No output key is
# given; the next cell reads the result from "X_pca_harmony", so despite its
# name that slot presumably holds corrected LSI components here.
sc.external.pp.harmony_integrate(atac, key="sample_id", basis="X_lsi_red", max_iter_harmony=20)

In [None]:
# Neighbor graph on the Harmony output.
sc.pp.neighbors(atac, use_rep="X_pca_harmony")

In [None]:
# UMAP embedding computed from the neighbor graph.
sc.tl.umap(atac)

In [None]:
# ATAC UMAP colored by high-level annotation, reusing the compartment colors built from the RNA data.
fig = sc.pl.umap(atac, color=["highlevel_annotation"], palette=colormapping, frameon=False, ncols=1, return_fig=True)
fig.savefig("figures/highlevel_refined_snATAC_umap.png", dpi=300, bbox_inches="tight")

In [None]:
# Number of ATAC cells per high-level annotation.
atac.obs.highlevel_annotation.value_counts()

In [None]:
# ATAC UMAP colored by the "dataset" column with the per-sample colors.
# NOTE(review): the other sample-level plots use "sample_id" and the file name
# says "sampleid" — confirm that "dataset" holds the same ids as the keys of
# colormapping_pat.
fig = sc.pl.umap(atac, color=["dataset"], frameon=False, ncols=1, palette=colormapping_pat, return_fig=True)
fig.savefig("figures/highlevel_refined_snATAC_umap_sampleid.png", dpi=300, bbox_inches="tight")

## Patient-level distributions

In [None]:
# Fraction of each sample's ATAC cells in every high-level annotation.
patlevel_counts = atac.obs[["sample_id","highlevel_annotation"]].groupby(by="sample_id").value_counts(normalize=True)
# Convert to whole percentages. Scale first, round second: the previous
# `round(2)*100` leaves float artefacts (0.29*100 == 28.999999999999996) that
# the integer cast below truncates to 28 instead of 29.
patlevel_counts = (patlevel_counts * 100).round()

df = patlevel_counts.unstack(level=-1)

# Keep the samples in the order of the sorted clinical table; an annotation
# absent from a sample counts as 0 %.
df = df.loc[clinical.index.intersection(df.index)].fillna(0).astype(int)

# Colors are derived from the RNA compartments (adata), which keeps the two
# modalities on the same color code.
colorlist = sns.color_palette("colorblind", 10)
ctlist = adata.obs.highlevel_refined.unique()
colormapping = {ct: colorlist[i] for i,ct in enumerate(ctlist)}
colormapping["NA"] = "whitesmoke"

In [None]:
def add_clinical_info(ax, y=100, h=10,
                      groups=((0, 1, "NAT"), (2, 4, "Primary"), (5, 9, "Metastatic")),
                      col='k'):
    """Draw a labeled bracket above each group of bars (percentage scale).

    Parameters
    ----------
    ax : matplotlib axes
        Axes holding the bar plot; bars are addressed by their integer x
        position (first bar: 0).
    y : number, default 100
        Height, in data units, at which the bracket legs start.
    h : number, default 10
        Height of the bracket; the label is written at ``y + 1.3 * h``.
    groups : iterable of (first, last, label)
        Inclusive bar positions spanned by each bracket, and its text.
        NOTE(review): the defaults assume ten bars ordered as 2 NAT,
        3 primary and 5 metastatic samples — confirm against the plotted frame.
    col : color, default 'k'
        Color of both the bracket and the label.

    Returns
    -------
    The ``ax`` that was passed in.
    """
    for x1, x2, label in groups:
        # Bracket: slanted legs from y up to y+h, flat top across the group.
        ax.plot([x1-0.4, x1-0.3, x2+0.3, x2+0.4], [y, y+h, y+h, y], lw=1.5, c=col)
        ax.text((x1+x2)*.5, y+1.3*h, label, ha='center', va='bottom', color=col)

    return ax

In [None]:
# Stacked bar chart of the per-sample ATAC composition held in `df`.
fig, ax = plt.subplots(1, 1, figsize=(6, 3))
df.plot(kind="bar", stacked=True, color=colormapping, ax=ax)
ax.legend(bbox_to_anchor=(1.05, 1), frameon=False)
ax.spines[["right", "top"]].set_visible(False)
ax.set_xlabel("")
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")

# Write the value inside every segment; segments lower than 5 get an empty
# label so that thin slices stay readable.
for c in ax.containers:
    labels = []
    for v in c:
        height = v.get_height()
        labels.append(int(height) if height >= 5 else "")
    ax.bar_label(c, labels=labels, label_type="center", fmt="%0.0f", color="white")

ax = add_clinical_info(ax)
fig.savefig("figures/barplot_pat_tme_ATAC_highlevel_scaled.png", dpi=300, bbox_inches="tight")

In [None]:
# Absolute number of ATAC cells per sample and high-level annotation.
patlevel_counts = atac.obs[["sample_id","highlevel_annotation"]].groupby(by="sample_id").value_counts()
df = patlevel_counts.unstack(level=-1)

# Keep the samples in the order of the sorted clinical table; missing combinations count as 0.
df = df.loc[clinical.index.intersection(df.index)].fillna(0).astype(int)

# Colors are derived from the RNA compartments (adata), as in the earlier cells.
colorlist = sns.color_palette("colorblind", 10)
ctlist = adata.obs.highlevel_refined.unique()
colormapping = {ct: colorlist[i] for i,ct in enumerate(ctlist)}
colormapping["NA"] = "whitesmoke"

In [None]:
def add_clinical_info(ax, y=9000, h=600,
                      groups=((0, 1, "NAT"), (2, 4, "Primary"), (5, 9, "Metastatic")),
                      col='k'):
    """Draw a labeled bracket above each group of bars (ATAC absolute-count scale).

    Parameters
    ----------
    ax : matplotlib axes
        Axes holding the bar plot; bars are addressed by their integer x
        position (first bar: 0).
    y : number, default 9000
        Height, in data units, at which the bracket legs start.
    h : number, default 600
        Height of the bracket; the label is written at ``y + 1.3 * h``.
    groups : iterable of (first, last, label)
        Inclusive bar positions spanned by each bracket, and its text.
        NOTE(review): the defaults assume ten bars ordered as 2 NAT,
        3 primary and 5 metastatic samples — confirm against the plotted frame.
    col : color, default 'k'
        Color of both the bracket and the label.

    Returns
    -------
    The ``ax`` that was passed in.
    """
    for x1, x2, label in groups:
        # Bracket: slanted legs from y up to y+h, flat top across the group.
        ax.plot([x1-0.4, x1-0.3, x2+0.3, x2+0.4], [y, y+h, y+h, y], lw=1.5, c=col)
        ax.text((x1+x2)*.5, y+1.3*h, label, ha='center', va='bottom', color=col)

    return ax

In [None]:
# Stacked bar chart of absolute ATAC cell counts per sample (no segment labels).
fig, ax = plt.subplots(1,1,figsize=(6,3))
df.plot(kind = 'bar', stacked = True, color=colormapping, ax=ax,)
ax.legend(bbox_to_anchor=(1.05, 1), frameon=False)
ax.spines[['right', 'top']].set_visible(False)
ax.set_xlabel("")
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
# Uses whichever add_clinical_info definition was executed last.
ax = add_clinical_info(ax)
fig.savefig("figures/barplot_pat_tme_ATAC_highlevel_nonscaled.png", dpi=300, bbox_inches="tight")

## Subset malignant

In [None]:
# Independent copy restricted to the ATAC cells annotated "Carcinoma".
subatac = atac[atac.obs.highlevel_annotation=="Carcinoma"].copy()
# First 40 columns of the subset's LSI embedding, stored under "X_lsi_red".
subatac.obsm["X_lsi_red"] = subatac.obsm["X_lsi"][:,:40]


In [None]:
# Harmony (per sample) on the reduced LSI embedding -> neighbors -> UMAP, on
# the carcinoma subset only. The Harmony result is read back from "X_pca_harmony".
sc.external.pp.harmony_integrate(subatac, key="sample_id", basis="X_lsi_red", max_iter_harmony=20)
sc.pp.neighbors(subatac, use_rep="X_pca_harmony")
sc.tl.umap(subatac)

In [None]:
# Carcinoma-only ATAC UMAP colored by malignant state.
fig = sc.pl.umap(subatac, color=["refined_wcancer"], palette=colormapping_mal, frameon=False, ncols=1, return_fig=True)
fig.savefig("figures/malonly_cNMF_harmony_snATAC_umap.png", dpi=300, bbox_inches="tight")

In [None]:
# Carcinoma-only ATAC UMAP colored by sample.
fig = sc.pl.umap(subatac, color=["sample_id"], palette=colormapping_pat, frameon=False, ncols=1, return_fig=True)
fig.savefig("figures/malonly_sampleid_harmony_snATAC_umap.png", dpi=300, bbox_inches="tight")

# Heatmaps of marker genes/regions

## Regions

In [None]:
DAR_res_dir = pl.Path("/add/path/here")

# One marker-region table per malignant state, read from
# <DAR_res_dir>/<state>.csv with the first column as index
# (presumably differential-accessibility results, going by the variable name).
region_markers = {}
for state in ("cNMF_1", "cNMF_3", "cNMF_4"):
    region_markers[state] = pd.read_csv(DAR_res_dir / f"{state}.csv", index_col=0)

# The first 100 index entries of each table are used as that state's markers.
var_names = {
    state: table.head(100).index.to_numpy()
    for state, table in region_markers.items()
}

In [None]:
# All marker regions in state order (cNMF_1, cNMF_3, cNMF_4) as one flat
# array. Duplicates are kept on purpose: no np.unique is applied.
all_region_markers = np.concatenate(
    [var_names[state] for state in ("cNMF_1", "cNMF_3", "cNMF_4")]
)

In [None]:
# Independent copy of the carcinoma ATAC data restricted to the marker regions.
heatmapadata = subatac[:,all_region_markers].copy()

In [None]:
# Scale every region across cells (in place, scanpy defaults).
sc.pp.scale(heatmapadata)

# Cells ordered by malignant state, so that heatmap rows are grouped by state.
cell_idx = heatmapadata.obs.refined_wcancer.sort_values().index.to_numpy()

# The matrix rows are taken in `cell_idx` order, so the frame has to be
# labeled with `cell_idx` as well. Labeling it with heatmapadata.obs_names
# (the unsorted order) attached each row to the wrong barcode, which silently
# corrupted every later index-based join or .loc lookup on heatmap_df.
heatmap_df = pd.DataFrame(heatmapadata[cell_idx,:].X.copy(), index=cell_idx, columns=heatmapadata.var_names)

In [None]:
# Per-cell mean of the scaled values over each state's marker regions, then
# joined with the cell metadata on the index.
# NOTE(review): the join is label-based, so it is only meaningful when the
# index labels of heatmap_df identify the cell shown in each row.
per_state_mean = {
    state: heatmap_df[var_names[state]].mean(axis=1)
    for state in ("cNMF_1", "cNMF_3", "cNMF_4")
}
state_score = pd.concat(
    [pd.concat(per_state_mean, axis=1), heatmapadata.obs],
    axis=1,
)

In [None]:
# cNMF_1 marker-region score per annotated malignant state.
sns.boxplot(data=state_score,x="refined_wcancer",y="cNMF_1")

In [None]:
# cNMF_3 marker-region score per annotated malignant state.
sns.boxplot(data=state_score,x="refined_wcancer",y="cNMF_3")

In [None]:
# cNMF_4 marker-region score per annotated malignant state.
sns.boxplot(data=state_score,x="refined_wcancer",y="cNMF_4")

In [None]:
# Inspect the scaled values of the cells annotated cNMF_3 (label-based lookup:
# relies on heatmap_df's index labels identifying the cell in each row).
heatmap_df.loc[heatmapadata.obs[heatmapadata.obs.refined_wcancer=="cNMF_3"].index]

In [None]:
# One color per heatmap row, in `cell_idx` order, taken from each cell's malignant state.
df = heatmapadata.obs.loc[cell_idx].refined_wcancer
row_colors = [colormapping_mal[label] for label in df]

In [None]:
# Heatmap of scaled marker-region values; clustering is disabled on both axes
# so rows and columns keep the order of heatmap_df. Colors are clipped to [-1, 4].
clmap = sns.clustermap(heatmap_df, cmap="vlag", center=0, vmax=4, vmin=-1, 
                       row_cluster=False, col_cluster=False, 
                       row_colors=row_colors)
# Drop all tick marks and labels.
clmap.ax_heatmap.set_xticklabels([])
clmap.ax_heatmap.set_xticks([])
clmap.ax_heatmap.set_yticklabels([])
clmap.ax_heatmap.set_yticks([])
clmap.fig.savefig("figures/malignant/heatmap_DAR_snATAC.png", dpi=300, bbox_inches="tight")

## Genes

In [None]:
cnmf_res_dir = pl.Path("/add/path/here")

# One marker-gene table per malignant state, read from
# <cnmf_res_dir>/<state>.csv with the first column as index.
gene_markers = {}
for state in ("cNMF_1", "cNMF_3", "cNMF_4"):
    gene_markers[state] = pd.read_csv(cnmf_res_dir / f"{state}.csv", index_col=0)

# The first 100 index entries of each table are used as that state's markers.
var_names = {
    state: table.head(100).index.to_numpy()
    for state, table in gene_markers.items()
}

In [None]:
# All marker genes in state order (cNMF_1, cNMF_3, cNMF_4) as one flat array;
# the variable keeps the name used in the regions section. Duplicates are
# kept on purpose: no np.unique is applied.
all_region_markers = np.concatenate(
    [var_names[state] for state in ("cNMF_1", "cNMF_3", "cNMF_4")]
)

In [None]:
# Independent copy of the carcinoma RNA data restricted to the marker genes.
heatmapadata = subadata[:,all_region_markers].copy()

In [None]:
# Scale every gene across cells (in place, scanpy defaults).
sc.pp.scale(heatmapadata)

# Cells ordered by malignant state, so that heatmap rows are grouped by state.
cell_idx = heatmapadata.obs.refined_wcancer.sort_values().index.to_numpy()

# The matrix rows are taken in `cell_idx` order, so the frame has to be
# labeled with `cell_idx` as well. Labeling it with heatmapadata.obs_names
# (the unsorted order) attached each row to the wrong barcode.
heatmap_df = pd.DataFrame(heatmapadata[cell_idx,:].X.copy(), index=cell_idx, columns=heatmapadata.var_names)

In [None]:
# One color per heatmap row, in `cell_idx` order, taken from each cell's malignant state.
df = heatmapadata.obs.loc[cell_idx].refined_wcancer
row_colors = [colormapping_mal[label] for label in df]

In [None]:
# Heatmap of scaled marker-gene expression; clustering is disabled on both
# axes so rows and columns keep the order of heatmap_df. Colors are clipped to [-1, 3].
clmap = sns.clustermap(heatmap_df, cmap="vlag", center=0, vmax=3, vmin=-1, 
                       row_cluster=False, col_cluster=False, 
                       row_colors=row_colors)
# Drop all tick marks and labels.
clmap.ax_heatmap.set_xticklabels([])
clmap.ax_heatmap.set_xticks([])
clmap.ax_heatmap.set_yticklabels([])
clmap.ax_heatmap.set_yticks([])
clmap.fig.savefig("figures/malignant/heatmap_DGEX_snRNA.png", dpi=300, bbox_inches="tight")