In [None]:
# Imports first (stdlib -> third-party -> project-local), configuration afterwards, so the
# settings below are applied last and cannot be overridden by import-time side effects.
from pathlib import Path

import scanpy as sc
import cell2location
import matplotlib.pyplot as plt
from matplotlib import rcParams

from vistools import utils  # project-local helpers (this notebook uses `select_slide`)

# Resolution (dots per inch) of figures saved to disk; inline figures are drawn at dpi=100.
DPI = 300
# NOTE(review): FONTSIZE appears unused in the cells shown; 42 was noted as an
# alternative value — confirm whether it is still needed.
FONTSIZE = 20
sc.settings.set_figure_params(
    scanpy=True, dpi=100, transparent=True, vector_friendly=True, dpi_save=DPI
)
# 42 = TrueType (Type 42) font embedding in PDFs instead of Type 3; commonly set so that
# text in exported PDFs stays editable in vector-graphics editors.
rcParams["pdf.fonttype"] = 42

### Load data obtained by segmentation on cell2location counts

SpatialDE2 region segmentation run on cell type count estimates obtained with cell2location, a computational deconvolution method that uses an annotated scRNA-seq reference dataset to estimate cell type abundance in each spot.

We use the absolute amount of mRNA contributed by each cell type to each spot. Specifically, we use the 5% quantile (q05) of the posterior distribution of this parameter (mRNA counts), which represents the number of mRNA molecules confidently assigned to each cell type.

**Segmentation**: *aims to assign a cluster label to each location based on its gene expression profile and the identity of its neighboring locations, with the underlying assumption that neighboring locations likely have the same label, i.e. the segmentation should be spatially smooth.*

In [None]:
# Name of the concatenated sample set whose segmentation results are analysed below.
SAMPLE_NAME = "concat_withWu2022"
# Smoothness value of the segmentation to load (other values noted as tried: 1, 1.5, 2).
optimal_s = 1.2

In [None]:
# Results directory for this sample set: the segmentation .h5ad is read from here and
# outputs are written below it. The base location defaults to the previously hard-coded
# path but can be overridden with the CELL2LOC_SPATIALDE2_DIR environment variable.
import os

RESULTS_ROOT = Path(
    os.environ.get(
        "CELL2LOC_SPATIALDE2_DIR",
        "/data/BCI-CRC/nasrine/data/CRC/spatial/CRC_LM_VISIUM/CRC_LM_VISIUM_04_08_09_11/cell2loc_spatialde2",
    )
)
DIR2SAVE = RESULTS_ROOT / SAMPLE_NAME
DIR2SAVE.mkdir(parents=True, exist_ok=True)

In [None]:
# Figures live in <DIR2SAVE>/figures. Pointing scanpy's figdir there makes every
# `save=` argument of the sc.pl.* functions write into this folder.
FIG2SAVE = DIR2SAVE / "figures"
FIG2SAVE.mkdir(exist_ok=True, parents=True)
sc.settings.figdir = FIG2SAVE

In [None]:
# Segmentation result stored for the chosen smoothness value.
segmentation_file = DIR2SAVE / f"sp_segmentation_smoothness{optimal_s}.h5ad"
cell2loc_counts = sc.read_h5ad(segmentation_file)

In [None]:
# Overview of the loaded AnnData: n_obs x n_vars and the available obs/var/uns keys.
cell2loc_counts

In [None]:
# Number of spots contributed by each sample.
cell2loc_counts.obs["Sample"].value_counts()

In [None]:
# First rows of the per-spot metadata (all columns) rather than rendering the full table.
cell2loc_counts.obs.head()

In [None]:
# Sanity-check the columns that the downstream cells rely on: one obs column per
# cell2location factor (cell type name) and the `segmentation_labels` region column.
missing_factor_columns = sorted(
    set(cell2loc_counts.uns["mod"]["factor_names"]) - set(cell2loc_counts.obs.columns)
)
assert not missing_factor_columns, f"cell type columns missing from obs: {missing_factor_columns}"
assert "segmentation_labels" in cell2loc_counts.obs.columns, "segmentation_labels column missing from obs"

In [None]:
# NOTE(review): repeats the earlier AnnData overview cell (display only); safe to delete.
cell2loc_counts

In [None]:
# Region labels produced by the segmentation, in categorical order.
cell2loc_counts.obs["segmentation_labels"].cat.categories

#### Assign fixed colours to the segmentation labels so that each label has the same colour in every Visium sample

In [None]:
# Fixed palette: the i-th colour is used for the i-th category of `segmentation_labels`
# (labels 0-9), so a given region label keeps the same colour in every sample's plot.
SEGMENTATION_COLORS = [
    "#fdb462",  # 0
    "#FCCDE5",  # 1
    "#FB8072",  # 2
    "#e78ac3",  # 3
    "#67A9CF",  # 4
    "#feed8b",  # 5
    "#1f78b4",  # 6
    "#91CF60",  # 7
    "#b3ffff",  # 8
    "#33a02c",  # 9
]
n_segmentation_labels = len(cell2loc_counts.obs["segmentation_labels"].cat.categories)
# scanpy may substitute a default palette, and anndata may drop the colour list on
# subsetting, when the `<key>_colors` length differs from the number of categories.
# Either would silently break cross-sample colour consistency, so fail loudly instead.
assert len(SEGMENTATION_COLORS) == n_segmentation_labels, (
    f"{len(SEGMENTATION_COLORS)} colours for {n_segmentation_labels} segmentation labels"
)
cell2loc_counts.uns["segmentation_labels_colors"] = SEGMENTATION_COLORS

In [None]:
# One spatial plot per Visium sample, coloured by segmentation region.
# Iterate in sorted order: the iteration order of a set of strings depends on hash
# randomisation, so iterating the raw set would shuffle the plots between sessions.
for samp in sorted(set(cell2loc_counts.obs["Sample"])):
    # presumably returns the AnnData subset for this sample (vistools helper) — TODO confirm
    slide = utils.select_slide(cell2loc_counts, s=samp, batch_key="Sample")
    sc.pl.spatial(
        slide,
        color="segmentation_labels",
        show=True,
        alpha_img=0.3,
        size=1.3,
        title="",
        save=f"{samp}_regions_s{optimal_s}.pdf",
    )

In [None]:
# Distinct sample names present in the object.
set(cell2loc_counts.obs["Sample"])

In [None]:
# Number of spots assigned to each segmentation region.
cell2loc_counts.obs["segmentation_labels"].value_counts()

### Compute the average cell type abundance in each region cluster

The figure shows which cell type groups are enriched in each region of the Visium ST slides. Dot size and colour give the average abundance (estimated mRNA counts) of each cell type annotated in the scRNA-seq analysis, within each region obtained by SpatialDE2 segmentation. The values are normalised by each cell type's maximum across regions.

In [None]:
# Average abundance of each cell type within each segmentation region.
# get_cluster_averages_df returns per-cluster averages of the columns of X; here the
# columns are cell types (cell2location factors) rather than genes.
from cell2location.cluster_averages.cluster_averages import get_cluster_averages_df
from cell2location.plt.plot_heatmap import clustermap

# Only the obs columns named after the cell2location factors (one per cell type).
celltype_columns = list(cell2loc_counts.uns["mod"]["factor_names"])
ct_average = get_cluster_averages_df(
    X=cell2loc_counts.obs[celltype_columns],
    cluster_col=cell2loc_counts.obs["segmentation_labels"],
)

# Divide every row by its own maximum across columns (same result as the former
# `(ct_average.T / ct_average.max(1)).T`; no transpose is needed for plotting).
# In cell2location's implementation, rows are cell types (the columns of X) and columns
# are regions, so each cell type's largest regional average becomes 1. This is a
# per-cell-type scaling, NOT the proportion of cell types within a region.
# NOTE(review): confirm the row/column orientation for the installed cell2location version.
ct_average_norm = ct_average.div(ct_average.max(axis=1), axis=0)

In [None]:
# Dot plot of the normalised per-region cell type abundance: rows are kept in their
# original order, columns are clustered.
clustermap(
    ct_average_norm,
    cluster_rows=False,
    cluster_cols=True,
    fun_type="dotplot",
    figure_size=[20, 15],
    array_size=None,
    cmap="GnBu",
)
dotplot_path = FIG2SAVE / f"celltype_mRNAcount_region_smoothness{optimal_s}.pdf"
plt.savefig(dotplot_path, dpi=DPI, bbox_inches="tight")
plt.show()

### Plot the transposed matrix, which keeps the colour bar smaller

In [None]:
# Same dot plot drawn on the transposed matrix (keeps the colour bar smaller).
# This figure is saved with an explicit white face colour.
clustermap(
    ct_average_norm.T,
    cluster_rows=False,
    cluster_cols=True,
    fun_type="dotplot",
    figure_size=[17, 4.5],
    array_size=None,
    equal=True,
    cmap="GnBu",
)
transposed_path = FIG2SAVE / f"celltype_mRNAcount_region_smoothness{optimal_s}_transposed.pdf"
plt.savefig(transposed_path, dpi=DPI, facecolor="white", bbox_inches="tight")
plt.show()

In [None]:
# Display the directory that the figures of this notebook were saved to.
FIG2SAVE