In [None]:
import anndata as ad
import spatialdata as sd
from spatialdata_io import xenium
import squidpy as sq
import cellcharter as cc
import pandas as pd
import scanpy as sc
import scvi
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from lightning.pytorch import seed_everything

# Fix the random seeds (Lightning helper and the scvi-tools setting) to the same value
# so that model training further down is repeatable.
seed_everything(12345)
scvi.settings.seed = 12345

# cellcharter

## read data

In [None]:
# First run only: read the raw Xenium export and write it to zarr
import os
from pathlib import Path

# Root folder holding one sub-directory per patient (directory name contains "patient").
base_path = Path('/mnt/kkf2/Cell/AG-Saur/KKF2/Daniele/xenium_atlas_human')
data = {}
# Treatment arm per patient; patients missing from this map are labelled "Unknown".
radio = {
    'patient_1': 'No Radiotherapy',
    'patient_2': 'Radiotherapy',
    'patient_3': 'Radiotherapy',
    'patient_4': 'No Radiotherapy',
    'patient_5': 'No Radiotherapy',
    'patient_6': 'Radiotherapy',
}
# os.listdir returns entries in arbitrary order; sorting keeps the sample order
# (and everything derived from it downstream) identical between runs.
for entry in sorted(os.listdir(base_path)):
    patient_dir = base_path / entry
    # Filter before printing so stray files are not reported as samples.
    if "patient" not in entry or not patient_dir.is_dir():
        continue
    print(entry)
    print("="*20)

    # The Xenium bundle is taken to be the first entry of the patient folder.
    # The context manager closes the directory handle, and an empty folder is
    # skipped instead of raising StopIteration.
    with os.scandir(patient_dir) as patient_entries:
        first_item = next(patient_entries, None)
    if first_item is None:
        print(f"No Xenium output found in {patient_dir}, skipping")
        continue

    item_path = patient_dir / first_item.name
    zarr_path = os.path.join(item_path, "sd.zarr")
    # NOTE(review): write() is expected to refuse an existing sd.zarr - on later
    # runs use the cell that loads the zarr store directly; confirm.
    xenium(item_path).write(zarr_path)
    data[entry] = {
        "path": item_path,
        # Re-read from disk, presumably so the object is backed by the store just written.
        "sd": sd.read_zarr(zarr_path),
        "radio": radio.get(entry, "Unknown"),
    }



In [None]:
# load zarr directly

import os
from pathlib import Path

# Root folder holding one sub-directory per patient (directory name contains "patient").
base_path = Path('/mnt/kkf2/Cell/AG-Saur/KKF2/Daniele/xenium_atlas_human')
data = {}
# Treatment arm per patient; patients missing from this map are labelled "Unknown".
radio = {
    'patient_1': 'No Radiotherapy',
    'patient_2': 'Radiotherapy',
    'patient_3': 'Radiotherapy',
    'patient_4': 'No Radiotherapy',
    'patient_5': 'No Radiotherapy',
    'patient_6': 'Radiotherapy',
}
# os.listdir returns entries in arbitrary order; sorting keeps the sample order
# (and everything derived from it downstream) identical between runs.
for entry in sorted(os.listdir(base_path)):
    patient_dir = base_path / entry
    # Filter before printing so stray files are not reported as samples.
    if "patient" not in entry or not patient_dir.is_dir():
        continue
    print(entry)
    print("="*20)

    # The Xenium bundle is taken to be the first entry of the patient folder.
    # The context manager closes the directory handle, and an empty folder is
    # skipped instead of raising StopIteration.
    with os.scandir(patient_dir) as patient_entries:
        first_item = next(patient_entries, None)
    if first_item is None:
        print(f"No Xenium output found in {patient_dir}, skipping")
        continue

    item_path = patient_dir / first_item.name
    data[entry] = {
        "path": item_path,
        # Expects the sd.zarr store written by the first-time read cell.
        "sd": sd.read_zarr(os.path.join(item_path, "sd.zarr")),
        "radio": radio.get(entry, "Unknown"),
    }



In [None]:
# Build one combined AnnData from the per-sample tables.
adata_list = []
for i, sample in enumerate(data, start=1):
    sample_table = data[sample]["sd"]['table']
    # These annotations are idempotent, so they are kept on the SpatialData table itself.
    sample_table.obs['sample'] = sample
    sample_table.obs['radiotherapy'] = data[sample]["radio"]
    # Shift the coordinates on a copy: doing it on the SpatialData table would add
    # the offset again on every re-run of this cell and leak the shifted
    # coordinates into anything written from `data` later.
    shifted_table = sample_table.copy()
    shifted_table.obsm['spatial'] = shifted_table.obsm['spatial'] + (i*1000)
    adata_list.append(shifted_table)

# NOTE(review): concat is called without index_unique; if cell ids repeat across
# samples, obs_names of the combined object are not unique - confirm this is intended.
adata = ad.concat(adata_list)
print(adata)
# Drop genes and cells with fewer than 3 counts in total.
sc.pp.filter_genes(adata, min_counts=3)
sc.pp.filter_cells(adata, min_counts=3)
# Keep the raw counts; the scvi setup further down reads this layer.
adata.layers["counts"] = adata.X.copy()
sc.pp.normalize_total(adata, target_sum=1e6)
sc.pp.log1p(adata)
# One empty entry per sample - NOTE(review): presumably required by a downstream plotting/graph helper; confirm.
adata.uns['spatial_fov'] = {s: {} for s in adata.obs['sample'].unique()}
adata.obs['sample'] = pd.Categorical(adata.obs['sample'])
# Release the per-sample copies.
adata_list = []

## run cellcharter

##### seed labeling

In [None]:
import numpy as np
from scipy.sparse import issparse

def _summed_expression(adata, genes):
    """Per-cell sum of expression over ``genes``; the scalar 0 when ``genes`` is empty."""
    if not genes:
        return 0
    X = adata[:, genes].X
    if issparse(X):
        X = X.toarray()
    # asarray + ravel gives a plain 1-D array even if the matrix type returns a 2-D sum.
    return np.asarray(X.sum(axis=1)).ravel()


def get_score_vectorised(adata, gene_set):
    """Score every cell for a marker-gene signature.

    The score is (sum over positive markers - sum over negative markers) divided
    by the number of markers that are present in ``adata.var_names``.

    Parameters
    ----------
    adata
        AnnData-like object; ``var_names``, ``n_obs`` and ``adata[:, genes].X`` are used.
    gene_set
        Dict with the marker lists under "positive" and "negative". A missing
        (or None) entry is treated as an empty list.

    Returns
    -------
    numpy.ndarray
        1-D float array with one score per cell; all zeros when none of the
        markers is present.
    """
    known_genes = set(adata.var_names)
    pos = [g for g in (gene_set.get("positive") or []) if g in known_genes]
    neg = [g for g in (gene_set.get("negative") or []) if g in known_genes]
    n_markers = len(pos) + len(neg)
    if n_markers == 0:
        return np.zeros(adata.n_obs, dtype=float)

    score = _summed_expression(adata, pos) - _summed_expression(adata, neg)
    # Out-of-place true division: an in-place "/=" raises on integer count matrices.
    return score / n_markers

def get_cell_mask(adata, gene_set, n_cells=50):
    """Boolean mask marking the ``n_cells`` highest-scoring cells for a gene set.

    Parameters
    ----------
    adata
        AnnData-like object passed on to ``get_score_vectorised``; ``n_obs`` is used here.
    gene_set
        Marker dict understood by ``get_score_vectorised``.
    n_cells
        Number of top-scoring cells to select. Values above the number of cells
        select every cell; zero or negative values select none.

    Returns
    -------
    numpy.ndarray
        Boolean array of length ``adata.n_obs``.
    """
    score = np.asarray(get_score_vectorised(adata, gene_set))
    mask = np.zeros(adata.n_obs, dtype=bool)
    # Clamp the request: argpartition raises when kth is out of bounds, and an
    # unclamped 0 would make the slice [-0:] select every cell.
    n_top = min(int(n_cells), score.shape[0])
    if n_top <= 0:
        return mask
    idx = np.argpartition(score, -n_top)[-n_top:]
    mask[idx] = True
    return mask



In [None]:
# Load the reference atlas from zarr; it is used below to derive marker genes per cell type.
adata_sc = ad.read_zarr('/home/daniele/atlases/Human_Atlas_Harmonised.zarr')

In [None]:
# Genes present in both the spatial data and the reference atlas; the cell displays how many there are.
overlap_genes = np.intersect1d(adata.var_names, adata_sc.var_names)
len(overlap_genes)

In [None]:
# Restrict both objects to the shared genes. .copy() materialises the subsets
# instead of keeping views; the full atlas stays available as `adata_sc`.
overlap_genes = np.intersect1d(adata.var_names, adata_sc.var_names)
adata = adata[:, overlap_genes].copy()
adata_sc_subset = adata_sc[:, overlap_genes].copy()

In [None]:
# Rank marker genes per 'Level_3' group using the 'log_norm' layer.
# NOTE(review): no key_added is given, so the 'Level_4' ranking in the next cell
# overwrites this result under the same default key - confirm this cell is still needed.
sc.tl.rank_genes_groups(adata_sc_subset, 'Level_3', layer='log_norm')

In [None]:
# Rank genes per 'Level_4' cell type on the shared-gene subset of the atlas.
sc.tl.rank_genes_groups(adata_sc_subset, 'Level_4', layer='log_norm')
ranked_names = adata_sc_subset.uns['rank_genes_groups']['names']

# Per cell type: the 5 top-ranked genes are positive markers and the 10
# bottom-ranked genes are negative markers.
gene_dict = {}
for cell in adata_sc.obs['Level_4'].unique():
    gene_dict[cell] = {
        "positive": list(ranked_names[cell][:5]),
        "negative": list(ranked_names[cell][-10:]),
    }

# Hand-curated signature. The "negative" key must exist because the scoring
# function reads gene_set["negative"]; an empty list means "no negative markers".
gene_dict['Double Positive CD4+CD8+ T Cell'] = {
    "positive": ['CD4', 'CD8A', "CD8B"],
    "negative": [],
}


In [None]:
# Start with every cell unlabelled, then stamp each cell type onto its 50
# top-scoring cells. Types are processed in gene_dict order, so a later type
# overwrites an earlier one on cells that both select.
seed = np.full(adata.shape[0], "Unknown", dtype=object)
for cell, markers in gene_dict.items():
    seed[get_cell_mask(adata, markers, n_cells=50)] = cell
adata.obs["seed_level_4"] = seed



In [None]:
# Register the raw-count layer, the per-sample batch key and the seed labels with scvi-tools.
scvi.model.SCVI.setup_anndata(
    adata, 
    layer="counts", 
    batch_key='sample',
    labels_key="seed_level_4"
)

# Train SCVI first, then initialise SCANVI from it; cells whose seed label is
# "Unknown" are declared as the unlabeled category.
scvi_model = scvi.model.SCVI(adata, n_latent=30, n_layers=2)
scvi_model.train(early_stopping=True, enable_progress_bar=True)
scanvi_model = scvi.model.SCANVI.from_scvi_model(scvi_model, unlabeled_category="Unknown")
scanvi_model.train(early_stopping=True, enable_progress_bar=True)


In [None]:
# obsm key under which the SCANVI latent representation is stored.
SCANVI_KEY = "X_scANVI"

adata.obsm[SCANVI_KEY] = scanvi_model.get_latent_representation(adata)
# Label predicted by the SCANVI model for every cell.
adata.obs["predicted"] = scanvi_model.predict(adata)

#### cellcharter clustering

In [None]:
# Build a Delaunay-based spatial neighbour graph, handled per sample via library_key='sample'.
sq.gr.spatial_neighbors(adata, library_key='sample', coord_type='generic', delaunay=True, spatial_key='spatial', percentile=99)
# NOTE(review): going by its name this CellCharter helper prunes overly long graph edges - confirm against the CellCharter docs.
cc.gr.remove_long_links(adata)

In [None]:
# Aggregate the SCANVI latent features over the spatial graph (n_layers=3) and store the result in obsm['X_cellcharter'].
cc.gr.aggregate_neighbors(adata, n_layers=3, use_rep=SCANVI_KEY, out_key='X_cellcharter', sample_key='sample')


In [None]:
# Fit CellCharter's ClusterAutoK on the aggregated features, with n_clusters=(2,10)
# as the range of candidate cluster numbers and at most 10 runs.
autok = cc.tl.ClusterAutoK(
    n_clusters=(2,10), 
    max_runs=10,
    convergence_tol=0.001
)
autok.fit(adata, use_rep='X_cellcharter')

In [None]:
# Plot the stability of the fitted ClusterAutoK models.
cc.pl.autok_stability(autok)

In [None]:
# Assign every cell to a cluster using the k=8 model.
# NOTE(review): k is hard-coded - presumably chosen from the stability plot; revisit if the fit changes.
adata.obs['cluster_cellcharter'] = autok.predict(adata, use_rep='X_cellcharter', k=8)

In [None]:
# Safety save: write the annotated object to a zarr store (path relative to the
# working directory) and read it straight back.
adata.write_zarr("adata.zarr")
adata = ad.read_zarr("adata.zarr")

In [None]:
# Resume point: reload the saved object from the same store without re-running the steps above.
adata = ad.read_zarr("adata.zarr")

#### Niche annotation

In [None]:
# Cell-type composition of every CellCharter cluster.
df = adata.obs[['predicted', 'cluster_cellcharter']].copy()
# normalize='index' divides each row by its total, so each cluster row sums to 1.
ct = pd.crosstab(df['cluster_cellcharter'], df['predicted'], normalize='index')

plt.figure(figsize=(40,4))
sns.heatmap(
    ct,
    annot=False,
    cmap="Reds",
    linewidths=0.5,
    # The plotted values are row fractions, not counts.
    cbar_kws={'label': 'Fraction of cells in cluster'},
    # Colour scale saturates at 0.25 so that less abundant labels stay visible.
    vmax=0.25
)
plt.xlabel("predicted")
plt.ylabel("cluster_cellcharter")
plt.title("Dot-like heatmap of crosstab")
plt.show()


In [None]:
import warnings

# Manual niche name for every CellCharter cluster (keys are cluster ids as strings).
# NOTE(review): "Healty" is kept as spelled because the niche palette defined
# further down uses the same string as a key.
anno_niche = {
    "0":"Desmoplastic Area",
    "1":"Malignant Cells High Area",
    "2":"Malignant Cells High Area",
    "3":"Healty Exocrine Structure",
    "4":"Immune Infiltration Area",
    "5":"Malignant Cells High Area",
    "6":"TLS",
    "7":"Immune Infiltration Area",
    }
# Cast the cluster labels to str before mapping: the dictionary keys are strings,
# and labels stored as integers would otherwise all map to NaN.
cluster_labels = adata.obs['cluster_cellcharter'].astype(str)
adata.obs['niche_annotation'] = cluster_labels.map(anno_niche)

# Surface clusters without an annotation instead of silently carrying NaN forward.
unmapped_clusters = sorted(set(cluster_labels[adata.obs['niche_annotation'].isna()]))
if unmapped_clusters:
    warnings.warn(f"No niche annotation for clusters: {unmapped_clusters}")

In [None]:
# Store the niche annotation as a categorical column.
adata.obs['niche_annotation'] = adata.obs['niche_annotation'].astype('category')

#### Niche enrichments 

In [None]:
# Enrichment of predicted labels within each niche, with p-values. The heatmap cell
# below reads the result from adata.uns['niche_annotation_predicted_enrichment'].
cc.gr.enrichment(adata, group_key='niche_annotation', label_key='predicted', pvalues=True )

In [None]:
# Plot the niche/label enrichment with p-value annotation enabled; saving to disk is currently commented out.
cc.pl.enrichment(
    adata, 
    group_key='niche_annotation', 
    label_key='predicted', 
    figsize=(18,3), 
    fontsize=12, 
    show_pvalues=True, 
    palette='coolwarm', 
    #save = "/mnt/kkf2/Cell/AG-Saur/KKF2/Daniele/pdac_atlas_figures/supplementary/fig2/figure_spatial_enrichment_dotplot.png"
)

In [None]:
# Ordered selection of predicted labels for the enrichment heatmap:
# T-cell / CD4 labels, then CD8 labels, then B / plasma cells, then endothelial cells.
# The parentheses spell out Python's and/or precedence: the "Macro" exclusion
# applies only to the CD4 / CD8 substring match.
label_filters = [
    lambda label: "T " in label or ("CD4" in label and "Macro" not in label),
    lambda label: "T " in label or ("CD8" in label and "Macro" not in label),
    lambda label: "B " in label or "Plasma" in label,
    lambda label: "Endot" in label,
]

predicted_labels = adata.obs.predicted.unique()
cells = []
for keep in label_filters:
    for cell in predicted_labels:
        # A label can satisfy several filters (every "T " label matches the first
        # two); keep the first occurrence only so the heatmap has no duplicate columns.
        if keep(cell) and cell not in cells:
            cells.append(cell)
    


In [None]:

from matplotlib.lines import Line2D

enrichment_result = adata.uns['niche_annotation_predicted_enrichment']
# `cells` may list a label more than once; duplicate column labels would make the
# p-value re-alignment below return extra columns and break the annotation shape.
selected_labels = list(dict.fromkeys(cells))

df = enrichment_result['enrichment'].loc[:, selected_labels]
p  = enrichment_result['pvalue'].loc[:, selected_labels]

# Force numeric dtype and treat infinite values as missing.
df = df.apply(pd.to_numeric, errors='coerce')
p  = p.apply(pd.to_numeric, errors='coerce')

df = df.replace([np.inf, -np.inf], np.nan)
p  = p.replace([np.inf, -np.inf], np.nan)

# Drop niches / labels without any finite value and align the p-values to what is left.
df = df.dropna(how='all', axis=0).dropna(how='all', axis=1)
p  = p.loc[df.index, df.columns]

# "*" where p < 0.05, empty otherwise (a missing p-value compares False and stays empty).
annot = pd.DataFrame(np.where(p < 0.05, "*", ""), index=p.index, columns=p.columns)

fig, ax = plt.subplots(figsize=(30, 8))

sns.heatmap(
    df, cmap="coolwarm", center=0, vmin=-5, vmax=5,
    annot=annot, fmt="", cbar_kws={"label": "logFC"},
    square=True, ax=ax
)

ax.set_title("Enrichment Heatmap", fontsize=16)

# Text-only legend entry explaining the star annotation.
legend_elements = [
    Line2D([0], [0], marker='', color='none', label='*  p < 0.05', linestyle='None')
]

ax.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1.25, 1.0), frameon=False)

plt.tight_layout()
#fig.savefig("/mnt/kkf2/Cell/AG-Saur/KKF2/Daniele/pdac_atlas_figures/figure2/enrichment_heatmap.png", dpi=300, bbox_inches='tight')
plt.show()

#### boundaries

In [None]:
# Run CellCharter's boundaries tool on the niche annotation.
# NOTE(review): where the result is stored is not visible here - see the CellCharter docs.
cc.tl.boundaries(adata, cluster_key='niche_annotation',)

In [None]:
# Display the AnnData summary to inspect what the previous steps added.
adata

#### Re-add annotations to the per-sample SpatialData objects

In [None]:
# Copy the predicted label and the niche annotation from the combined object
# back into each sample's SpatialData table. The assignment aligns on the obs
# index, so table rows without a matching cell in `adata` end up as missing values.
for sample in adata.obs['sample'].unique():
    sample_obs = adata[adata.obs['sample'] == sample].obs
    for column in ('predicted', 'niche_annotation'):
        data[sample]['sd']['table'].obs[column] = sample_obs[column].astype('category')

#### spatial plots

In [None]:
def crop_sdata_region(sdata,
                      roi_info,
                      axes = ("x", "y")) -> sd.SpatialData:
    """Cut a rectangular region of interest out of a SpatialData object.

    Parameters
    ----------
    sdata
        Object to query.
    roi_info
        Mapping with the keys 'roi_xmin', 'roi_ymin', 'roi_xmax' and 'roi_ymax'.
    axes
        Axis names that the two corner coordinates refer to.

    Returns
    -------
    The result of a bounding-box query in the "global" coordinate system.
    """
    from spatialdata import bounding_box_query

    lower_corner = [roi_info['roi_xmin'], roi_info['roi_ymin']]
    upper_corner = [roi_info['roi_xmax'], roi_info['roi_ymax']]
    return bounding_box_query(
        sdata,
        min_coordinate=lower_corner,
        max_coordinate=upper_corner,
        axes=axes,
        target_coordinate_system="global",
    )

In [None]:
def plot_roi_highlight_celltype(
    sdata,
    obs_column,
    target=None,
    roi_info=None,
    palette=None,
    save=None,
    ax=None,
    figsize=None,
):
    """Plot cells coloured by an obs column, optionally highlighting selected categories.

    Parameters
    ----------
    sdata
        SpatialData object with a "table" table and the shape elements
        "cell_circles" and "cell_boundaries".
    obs_column
        Column of the table's obs used for colouring.
    target
        Category (str) or iterable of categories to highlight; every other cell
        is shown as "Other". With ``None`` all categories are drawn as they are.
    roi_info
        Optional region of interest passed to ``crop_sdata_region``; with
        ``None`` the whole object is plotted.
    palette
        Optional dict mapping a category to a colour. Targets missing from it
        use "#FF5733"; "Other" uses the dict's "Other" entry when present and
        "#D3D3D3D6" otherwise.
    save
        Currently unused; callers save the figure themselves after the call.
    ax
        Optional matplotlib axes to draw on.
    figsize
        Figure size; in highlight mode it defaults to (8, 8).

    Returns
    -------
    None
    """
    plot_data = crop_sdata_region(sdata, roi_info) if roi_info is not None else sdata
    table = plot_data.tables["table"]

    if target is None:
        plot_data.pl.render_shapes(
            "cell_circles",
            color=obs_column,
            outline=False,
        ).pl.show(
            na_in_legend=False,
            frameon=False,
            ax=ax,
            # Honour the requested size in this branch as well.
            figsize=figsize,
        )
        return

    targets = [target] if isinstance(target, str) else list(target)

    temp = "_temp_highlight"
    # Build the highlight labels on a local series: without an ROI `table` is the
    # caller's own table, and its obs column must not gain an "Other" category.
    col = table.obs[obs_column]
    if not isinstance(col.dtype, pd.CategoricalDtype):
        col = col.astype("category")
    if "Other" not in col.cat.categories:
        col = col.cat.add_categories(["Other"])

    user_palette = dict(palette) if palette is not None else {}
    colours = {t: user_palette.get(t, "#FF5733") for t in targets}
    # Respect an explicitly supplied "Other" colour instead of always overriding it.
    colours["Other"] = user_palette.get("Other", "#D3D3D3D6")

    groups = targets + ["Other"]

    table.obs[temp] = col.where(col.isin(targets), other="Other").astype("category")
    try:
        plot_data.pl.render_shapes(
            "cell_boundaries",
        ).pl.render_shapes(
            "cell_circles",
            color=temp,
            palette=[colours[g] for g in groups],
            groups=groups,
            outline_alpha=0.1,
            scale = 0.5
        ).pl.show(
            na_in_legend=False,
            frameon=False,
            ax=ax,
            figsize=figsize if figsize is not None else (8,8),
        )
    finally:
        # Remove the helper column even when plotting fails.
        del table.obs[temp]


In [None]:
# Rectangular regions of interest. The keys are the ones crop_sdata_region reads;
# it queries them in the "global" coordinate system.
roi_info_patient2 = pd.Series(
    dict(roi_xmin=10000, roi_xmax=18000, roi_ymin=12000, roi_ymax=18000)
)
roi_info_patient4 = pd.Series(
    dict(roi_xmin=5000, roi_xmax=20000, roi_ymin=8000, roi_ymax=15000)
)


In [None]:
# Highlight palette for the double-positive T-cell plots.
# NOTE(review): plot_roi_highlight_celltype sets palette["Other"] itself, so the
# "Other" entries of palette_t and palette_n do not reach the plot as written.
palette_t = {
    "Double Positive CD4+CD8+ T Cell": "#e31a1c",
    "Other": "#959595",
}
# Highlight palette for the TLS niche plot.
palette_n = {
    "TLS": "#223271",
    "Other": "#959595",
}
# One colour per niche name; the keys match the values of anno_niche
# (including the spelling "Healty").
palette_niche = {
    "TLS": "#4C90C0",                       
    "Malignant Cells High Area": "#E16A86", 
    "Immune Infiltration Area": "#8CC084",  
    "Desmoplastic Area": "#000000",         
    "Healty Exocrine Structure": "#F2A65A", 
}


In [None]:
# Highlight double-positive T cells inside the patient-4 ROI, then save the current figure at 300 dpi.
plot_roi_highlight_celltype(
    sdata = data['patient_4']['sd'],
    roi_info=roi_info_patient4,
    obs_column='predicted',
    target='Double Positive CD4+CD8+ T Cell',
    figsize=(20,20),
    palette=palette_t
)
plt.savefig("/mnt/kkf2/Cell/AG-Saur/KKF2/Daniele/pdac_atlas_figures/figure2/figure_spatial_double_positive_T_cells_patient_4.png", dpi=300, bbox_inches='tight')


In [None]:
# Same highlight on the whole patient-4 slide (roi_info=None means no cropping), saved as a supplementary figure.
plot_roi_highlight_celltype(
    sdata = data['patient_4']['sd'],
    roi_info=None,
    obs_column='predicted',
    target='Double Positive CD4+CD8+ T Cell',
    figsize=(20,20),
    palette=palette_t
)
plt.savefig("/mnt/kkf2/Cell/AG-Saur/KKF2/Daniele/pdac_atlas_figures/supplementary/fig2/figure_spatial_double_positive_T_cells_patient_4_whole_slide.png", dpi=300, bbox_inches='tight')


In [None]:
# Highlight the TLS niche inside the patient-4 ROI (displayed only, not saved).
plot_roi_highlight_celltype(
    sdata = data['patient_4']['sd'],
    roi_info=roi_info_patient4,
    obs_column='niche_annotation',
    target='TLS',
    figsize=(20,20),
    palette=palette_n
)


In [None]:
# Cell-type composition of every CellCharter cluster (tighter colour scale than the first version of this plot).
df = adata.obs[['predicted', 'cluster_cellcharter']].copy()
# normalize='index' divides each row by its total, so each cluster row sums to 1.
ct = pd.crosstab(df['cluster_cellcharter'], df['predicted'], normalize='index')

plt.figure(figsize=(40,4))
sns.heatmap(
    ct,
    annot=False,
    cmap="Reds",
    linewidths=0.5,
    # The plotted values are row fractions, not counts.
    cbar_kws={'label': 'Fraction of cells in cluster'},
    # Colour scale saturates at 0.2.
    vmax=0.2
)
plt.xlabel("predicted")
plt.ylabel("cluster_cellcharter")
plt.title("Dot-like heatmap of crosstab")
plt.show()


## radiotherapy

In [None]:
# Differential neighbourhood enrichment of niches between the radiotherapy
# conditions, with samples as libraries and p-values from 100 permutations (15 jobs).
cc.gr.diff_nhood_enrichment(
    adata,
    cluster_key='niche_annotation',
    condition_key='radiotherapy',
    library_key='sample',
    pvalues=True,
    n_jobs=15,
    n_perms=100
)


In [None]:
# Plot the differential enrichment for 'Radiotherapy' vs 'No Radiotherapy'
# with a significance level of 0.05; the figure is written to the `save` path.
cc.pl.diff_nhood_enrichment(
    adata,
    cluster_key='niche_annotation',
    condition_key='radiotherapy',
    condition_groups=['Radiotherapy', 'No Radiotherapy'],
    annotate=True,
    figsize=(3,3),
    significance=0.05,
    fontsize=5,
    cmap = 'coolwarm',
    save = "/mnt/kkf2/Cell/AG-Saur/KKF2/Daniele/pdac_atlas_figures/figure4/figure_heatmap_diff_enrichment_niche_radiotherapy.png"
)

In [None]:
# Make sure the niche annotation is categorical before the per-condition analyses below.
adata.obs.niche_annotation = adata.obs.niche_annotation.astype('category')

In [None]:
# Niche neighbourhood enrichment restricted to the radiotherapy samples.
received_radiotherapy = adata.obs['radiotherapy'] == 'Radiotherapy'
adata_radio = adata[received_radiotherapy]

cc.gr.nhood_enrichment(adata_radio, cluster_key='niche_annotation')

radio_plot_options = dict(
    cluster_key='niche_annotation',
    annotate=True,
    vmin=-1,
    vmax=1,
    figsize=(3,3),
    fontsize=5,
)
cc.pl.nhood_enrichment(adata_radio, **radio_plot_options)

In [None]:
# Niche neighbourhood enrichment computed and plotted separately for the two
# conditions, with identical plot settings so the panels are comparable.
adata_radio = adata[adata.obs['radiotherapy'] == 'Radiotherapy']
adata_no_radio = adata[adata.obs['radiotherapy'] == 'No Radiotherapy']

nhood_plot_options = dict(
    cluster_key='niche_annotation',
    annotate=True,
    vmin=-1,
    vmax=1,
    figsize=(3,3),
    fontsize=5,
)

for condition_adata in (adata_radio, adata_no_radio):
    cc.gr.nhood_enrichment(condition_adata, cluster_key='niche_annotation')
    cc.pl.nhood_enrichment(condition_adata, **nhood_plot_options)

In [None]:
# Proportion plot restricted to cells whose predicted label contains "Endo",
# grouped by predicted label and split by radiotherapy status.
cc.pl.proportion(
    adata[adata.obs.predicted.str.contains("Endo")],
    group_key='predicted',
    label_key='radiotherapy',
    figsize=(4,3),
    rotation_xlabel=90
)

In [None]:
def _predicted_fractions(condition):
    """Share of every predicted label among the cells of one radiotherapy condition."""
    condition_cells = adata[adata.obs['radiotherapy'] == condition]
    return condition_cells.obs.predicted.value_counts() / condition_cells.n_obs


def _endothelial_only(fractions):
    """Keep the entries whose label contains "Endo"."""
    return fractions[fractions.index.str.contains("Endo")]


radio_endo = _endothelial_only(_predicted_fractions('Radiotherapy'))
radio_no_endo = _endothelial_only(_predicted_fractions('No Radiotherapy'))

In [None]:
# Render the niche annotation of every patient, grouped by treatment arm.
# Each tuple holds: value stored in data[...]['radio'], banner text, wording of the per-patient message.
treatment_groups = (
    ("Radiotherapy", "Radiotherapy Group", "radiotherapy"),
    ("No Radiotherapy", "No Radiotherapy Group", "no radiotherapy"),
)
for radio_status, banner, wording in treatment_groups:
    print("="*15, banner, "="*15)
    for patient, record in data.items():
        if record['radio'] != radio_status:
            continue
        print(f"Rendering patient: {patient} in {wording} group")
        record['sd'].pl.render_shapes("cell_circles", color="niche_annotation").pl.show()


In [None]:
# Refresh the per-sample tables with the current labels before they are written
# to disk. The assignment aligns on the obs index.
annotation_columns = ('predicted', 'niche_annotation')
for sample in adata.obs['sample'].unique():
    sample_obs = adata[adata.obs['sample'] == sample].obs
    for column in annotation_columns:
        data[sample]['sd']['table'].obs[column] = sample_obs[column].astype('category')

In [None]:
# write the annotated SpatialData objects back to disk as zarr

import os
from pathlib import Path

# Write every annotated SpatialData object next to the bundle it was loaded from.
# The bundle path stored in data[...]["path"] at load time is reused instead of
# re-scanning the patient folders, so the output always lands beside the data
# that is actually held in memory.
# NOTE(review): write() is expected to refuse an existing sd_niche_annotated.zarr - confirm before re-running.
for entry, record in data.items():
    print(entry)
    print("="*20)
    record["sd"].write(os.path.join(record["path"], "sd_niche_annotated.zarr"))



In [None]:
# Add QC metrics in place; percent_top requests the share of counts in the top 50 / 100 / 150 genes.
sc.pp.calculate_qc_metrics(adata, inplace=True, percent_top=(50,100,150))

In [None]:
# Display the AnnData summary to check the newly added QC columns.
adata

In [None]:
# Transcript counts against cell area, coloured by sample.
fig, ax = plt.subplots(figsize=(6, 5))

sc.pl.scatter(
    adata,
    x='transcript_counts',
    y='cell_area',
    color='sample',
    ax=ax,
    show=False,
    size=10

)

# Zoom in on the lower ranges of both axes.
ax.set_xlim(0, 800)
ax.set_ylim(0, 400)

plt.show()


In [None]:
# Distribution of transcript counts per cell, x-axis limited to 0-500.
sns.histplot(adata.obs['transcript_counts'], bins = 200)
plt.xlim(0,500)

In [None]:
# Distribution of cell area, x-axis limited to 0-200.
sns.histplot(adata.obs['cell_area'], bins = 200)
plt.xlim(0,200)

In [None]:
# Display the per-cell metadata table.
adata.obs

In [None]:
# Share of counts in the top 50 genes against transcript counts, coloured by radiotherapy status (default axes).
sc.pl.scatter(adata, x='pct_counts_in_top_50_genes', y='transcript_counts', color='radiotherapy')

In [None]:
# Same scatter as in the previous cell, drawn on explicit axes so the view can be restricted.
fig, ax = plt.subplots(figsize=(6, 5))

sc.pl.scatter(
    adata, 
    x='pct_counts_in_top_50_genes', 
    y='transcript_counts', 
    color='radiotherapy',
    ax=ax,
    show=False,
    size=10
)

# Restrict the view to 30-100 % on x and 0-800 transcripts on y.
ax.set_xlim(30, 100)
ax.set_ylim(0, 800)

plt.show()


In [None]:
# Distribution of the top-50-gene share per radiotherapy group.
sc.pl.violin(adata, keys='pct_counts_in_top_50_genes', groupby='radiotherapy')

In [None]:
# Cells in which the top 50 genes account for at least 50 % of the counts (a subset of `adata`, not copied).
adata_cut = adata[adata.obs['pct_counts_in_top_50_genes'] >= 50]

In [None]:
# Top-50-gene share per sample for the subset selected above.
sns.boxplot(data=adata_cut.obs, x='sample', y='pct_counts_in_top_50_genes')


In [None]:
# Display the per-cell metadata of the reference atlas.
adata_sc.obs

In [None]:
# Add QC metrics to the reference atlas in place, using scanpy's default percent_top settings.
sc.pp.calculate_qc_metrics(adata_sc, inplace = True)

In [None]:
# Top-50-gene share per radiotherapy group in the reference atlas.
# NOTE(review): assumes adata_sc.obs has a 'radiotherapy' column - nothing in this notebook creates it; confirm.
sns.boxplot(data=adata_sc.obs, x='radiotherapy', y='pct_counts_in_top_50_genes')
