# Multi-Feature RSP Visualization

This script creates RSP (Radar Scanning Plot) visualizations from KPMP data for several genes at once. Use it to check that the package is working correctly and to compare spatial expression patterns between genes.

In [None]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

import numpy as np

In [None]:
import spatialrsp as rsp
import logging

In [None]:
logging.basicConfig(level=logging.INFO)

The following configuration is optional; use it only if you are comfortable reading debug output. It prints debug information to the console, which can help you follow the flow of the code and diagnose any issues that arise.

The cell above sets up lighter logging instead. It records only the essential information, which is usually sufficient for most users.

```python
logging.basicConfig(
    level=logging.DEBUG, format="%(asctime)s %(name)s %(levelname)-8s %(message)s"
)
```
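
One caveat when switching between the two setups in the same session: `logging.basicConfig` does nothing if the root logger already has handlers, so a second call will not change the level. A minimal sketch, assuming Python 3.8 or later, where the `force=True` argument replaces the existing handlers:

```python
import logging

# First configuration, as in the lighter setup above.
logging.basicConfig(level=logging.INFO)

# A second plain basicConfig call would be ignored because the root logger
# already has a handler; force=True (Python 3.8+) removes it and reconfigures.
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s %(name)s %(levelname)-8s %(message)s",
    force=True,
)

logging.getLogger(__name__).debug("Debug logging is now active.")
```

Restarting the kernel before changing the logging setup has the same effect if you prefer not to use `force=True`.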

This step configures Matplotlib to use a high DPI for better resolution, and sets the default font size and grid visibility for the plots. This was partly inspired by [andrearyang/housing](https://github.com/andrearyang/housing)'s pretty plots. Thank you Andrea!


In [None]:
plt.rcParams["figure.dpi"] = 400
plt.rcParams["font.size"] = 14
plt.rcParams["axes.grid"] = True

distinct_colors = [
    "#d62728",
    "#2ca02c",
    "#ff7f0e",
    "#1f77b4",
    "#9467bd",
    "#8c564b",
    "#e377c2",
    "#bcbd22",
    "#17becf",
]

For purposes of this example, we will use the KPMP data. You can replace this with any other dataset that is compatible with the RSP visualization functions. You can download the data using the `download_kpmp` function from the `spatialrsp.utils` module.

_(We also support the Human Cell Landscape dataset, which can be downloaded using the `download_hcl` function from the same module. However, this example will focus on the KPMP dataset, and to be honest, I have not yet tested the HCL dataset with this code. If you do, please let me know if it works or if you encounter any issues!)_

In [None]:
path = rsp.utils.download_kpmp(
    variant="sn",
    # force=True,  # Uncomment to force a re-download; useful if you suspect the local copy is corrupted
)
print("Downloaded KPMP data to", path)

In [None]:
adata = rsp.io.load_data(path)

This gene list is taken from the [KPMP paper](https://doi.org/10.1038/s41586-023-05769-3) and includes genes that are known to be expressed in the TAL (Thick Ascending Limb) of the nephron. You can replace this with any other list of genes that you want to visualize.

We utilize a threshold percentile system to select the most highly expressing cells for each gene. We define a target percentile $X$ (e.g., 0.9) to retain the top $(1-X) \times 100\%$ of expressing cells as the "foreground." This approach ensures that we focus on the most relevant cells for downstream SpatialRSP analyses, while also adapting to the varying expression scales of different genes. The percentile is calculated from the non-zero expression values, and a mask is generated to select cells with expression exceeding this threshold.
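
To make the thresholding concrete, here is a minimal NumPy sketch of the idea. It is not the `spatialrsp` implementation; the helper name `nonzero_quantile_threshold` and the toy expression vector are assumptions made for illustration.

```python
import numpy as np


def nonzero_quantile_threshold(expression, percentile):
    """Return the `percentile` quantile of the non-zero expression values.

    Hypothetical helper for illustration only; not part of spatialrsp.
    """
    expression = np.asarray(expression, dtype=float)
    nonzero = expression[expression > 0]
    if nonzero.size == 0:
        raise ValueError("No cells express this gene.")
    return float(np.quantile(nonzero, percentile))


# Toy data: 10 non-expressing cells plus 10 cells with expression 1..10.
toy_expression = np.concatenate([np.zeros(10), np.arange(1, 11)])

threshold = nonzero_quantile_threshold(toy_expression, 0.9)
foreground = toy_expression > threshold  # strictly above the threshold

print(f"threshold={threshold:.2f}, foreground cells={foreground.sum()}")
```

In this toy case the threshold is 9.1, so only one of the ten expressing cells (the top 10%) ends up in the foreground. The zeros do not influence the threshold, which is the point of computing the percentile on non-zero values only.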

In [None]:
genes = ["SLC12A1", "PROM1", "DCDC2", "ITGB6", "EGF", "CDH11", "ESRRB"]  # Select genes of interest
threshold_percentile = 0.9  # Keep the top (1 - 0.9) = 10% of expressing cells

In [None]:
tal_cells = adata[adata.obs["subclass.l1"] == "TAL"].copy()
logging.info(f"TAL cells shape: {tal_cells.shape}")

Alternatively, you can filter further based on other criteria, such as the cell subtype or other metadata. For example, you can filter the cells to include only those that are classified as "C-TAL" (Cortical Thick Ascending Limb) cells:

```python
ctal_cells = tal_cells[tal_cells.obs["subclass.l2"] == "C-TAL"].copy()
logging.info(f"C-TAL cells shape: {ctal_cells.shape}")
```

However, this step is optional and depends on your specific analysis needs.

In [None]:
filtered_cells = tal_cells

In [None]:
preprocessor = rsp.Preprocessor()
preprocessor.run(
    filtered_cells,
    qc=False,
    normalize=False,
    reduction=None,
    polar=True,
)
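
The later cells read polar coordinates from `obsm["X_polar"]`, which is presumably what `polar=True` produces. How `spatialrsp` computes them is not shown here; as a rough sketch of one common approach, the hypothetical helper `to_polar` below centers a 2D embedding on its centroid (an assumption) and returns `(radius, angle)` pairs, the column order the plotting function expects when it reads `X_polar`.

```python
import numpy as np


def to_polar(coords, center=None):
    """Convert 2D points to (radius, angle) pairs around `center`.

    Hypothetical helper; angles are wrapped into [0, 2*pi).
    """
    coords = np.asarray(coords, dtype=float)
    if center is None:
        center = coords.mean(axis=0)  # assumption: use the centroid
    dx, dy = (coords - center).T
    radius = np.hypot(dx, dy)
    angle = np.mod(np.arctan2(dy, dx), 2 * np.pi)
    return np.column_stack((radius, angle))


# Four points on the unit circle around the origin.
toy_umap = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]])
polar = to_polar(toy_umap)
print(polar)
```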

This is a neat helper function I wrote to match gene names in the AnnData object. The KPMP dataset uses Ensembl (ENSG) gene IDs as the default gene names, so the gene symbols you are looking for may not match directly. You can use this function to look up the ID that corresponds to a symbol in the `feature_name` column.

In [None]:
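def match_gene_name(adata, gene):
    """Return the var index (e.g. ENSG ID) whose feature_name matches `gene`."""
    if "feature_name" in adata.var.columns:
        names = adata.var["feature_name"].astype(str)
        # Prefer an exact, case-insensitive match so that e.g. "EGF" does not
        # resolve to "EGFR"; fall back to substring matching otherwise.
        matched_genes = adata.var[names.str.upper() == gene.upper()]
        if matched_genes.empty:
            matched_genes = adata.var[names.str.contains(gene, case=False)]
        if not matched_genes.empty:
            return matched_genes.index[0]
    raise ValueError(f"Gene '{gene}' not found in the dataset.")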
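
A caveat worth knowing: `str.contains` does substring matching, so a short symbol such as `EGF` can also hit longer names. The sketch below uses a made-up `var` table (the index values are placeholders, not real Ensembl identifiers) to contrast substring and exact matching.

```python
import pandas as pd

# Toy stand-in for adata.var; the index values are placeholder IDs.
var = pd.DataFrame(
    {"feature_name": ["EGFR", "HBEGF", "EGF", "SLC12A1"]},
    index=[
        "ENSG_PLACEHOLDER_1",
        "ENSG_PLACEHOLDER_2",
        "ENSG_PLACEHOLDER_3",
        "ENSG_PLACEHOLDER_4",
    ],
)

gene = "EGF"

# Substring match: every name containing "EGF", in table order.
substring_hits = var[var["feature_name"].str.contains(gene, case=False)]

# Exact, case-insensitive match: only the gene itself.
exact_hits = var[var["feature_name"].str.upper() == gene.upper()]

print("substring:", list(substring_hits["feature_name"]))
print("exact:", list(exact_hits["feature_name"]))
```

In this toy table the first substring hit is `EGFR`, so taking the first row of a substring match would silently pick the wrong gene. It is worth checking the logged match for each gene before trusting the plots.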

In [None]:
ensgenes = {}
thresholds = {}
expressions = {}

for gene in genes:
    try:
        ensgenes[gene] = match_gene_name(filtered_cells, gene)
        logging.info(f"Matched {gene} to {ensgenes[gene]}")
    except ValueError as e:
        logging.error(e)
        continue

    expression = filtered_cells[:, ensgenes[gene]].X.toarray().flatten()
    expressions[gene] = expression

    thresholds[gene] = rsp.utils.percentile_to_threshold(
        expressions[gene], threshold_percentile
    )
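    logging.info(
        f"Threshold for {gene} at {threshold_percentile * 100}%: {thresholds[gene]}"
    )

# Keep only the genes that were matched, so later cells do not hit a KeyError.
genes = [gene for gene in genes if gene in ensgenes]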

In [None]:
print("=== Expression Data Storage Summary ===")
print(f"Stored expression data for {len(expressions)} genes:")
for gene in expressions:
    expr_array = expressions[gene]
    non_zero_cells = np.sum(expr_array > 0)
    mean_expr = np.mean(expr_array[expr_array > 0]) if non_zero_cells > 0 else 0
    max_expr = np.max(expr_array)

    print(f"{gene}:")
    print(f"  - Array shape: {expr_array.shape}")
    print(
        f"  - Non-zero cells: {non_zero_cells}/{len(expr_array)} ({non_zero_cells/len(expr_array)*100:.1f}%)"
    )
    print(f"  - Mean expression (non-zero): {mean_expr:.3f}")
    print(f"  - Max expression: {max_expr:.3f}")
    print()

In [None]:
bg_angles = rsp.utils.get_polar_angles(
    adata=filtered_cells,
    mask=None,  # Background = all cells
    polar_coord="X_polar",
)

In [None]:
fg_masks = {}
fg_angles_list = []

In [None]:
for gene in genes:
    ensgene = ensgenes[gene]
    gene_threshold = thresholds[gene]

    fg_mask = rsp.utils.extract_foreground_mask(
        adata=filtered_cells,
        feature=ensgene,
        threshold=gene_threshold,
    )
    fg_mask = fg_mask.astype(bool)

    # check if mask is a boolean array
    if not np.issubdtype(fg_mask.dtype, np.bool_):
        raise ValueError(f"Foreground mask for {gene} is not a boolean array.")
    else:
        logging.info(f"Foreground mask for {gene} is a boolean array.")
        logging.info(
            f"Selected {np.sum(fg_mask)} out of {len(fg_mask)} cells for {gene} "
            f"({np.sum(fg_mask)/len(fg_mask)*100:.1f}%)"
        )

    fg_masks[gene] = fg_mask

    fg_angles = rsp.utils.get_polar_angles(
        adata=filtered_cells,
        mask=fg_mask,
        polar_coord="X_polar",
    )
    fg_angles_list.append(fg_angles)

The following cell sets the parameters used for the RSP visualization. If you want a more detailed explanation of these parameters, please refer to the docstring of the respective function in the `spatialrsp` module.

In [None]:
scanning_window = np.pi
scanning_range = np.linspace(0, 2 * np.pi, 360)
resolution = 100
mode = "relative"

In [None]:
if mode == "absolute":
    fg_curves, exp_curves, bg_curve = rsp.compute_rsp(
        theta_fgs=fg_angles_list,
        theta_bg=bg_angles,
        scanning_window=scanning_window,
        scanning_range=scanning_range,
        resolution=resolution,
        mode=mode,
    )
elif mode == "relative":
    fg_curves, bg_curve = rsp.compute_rsp(
        theta_fgs=fg_angles_list,
        theta_bg=bg_angles,
        scanning_window=scanning_window,
        scanning_range=scanning_range,
        resolution=resolution,
        mode=mode,
    )
else:
    raise ValueError(f"Unknown mode: {mode}. Use 'absolute' or 'relative'.")

In [None]:
# Display threshold summary
print("=== Threshold Summary ===")
print(
    f"Threshold percentile: {threshold_percentile} (selecting top {(1-threshold_percentile)*100:.0f}% of cells)"
)
print()
for gene in genes:
    threshold_val = thresholds[gene]
    fg_mask = fg_masks[gene]
    selected_cells = np.sum(fg_mask)
    
    # Use stored expression data instead of re-extracting
    is_expressed = expressions[gene] > 0
    total_cells = np.sum(is_expressed)
    # Guard against genes with no expressing cells to avoid division by zero
    selected_pct = selected_cells / total_cells * 100 if total_cells > 0 else 0.0

    print(f"{gene}:")
    print(f"  - Actual threshold: {threshold_val:.3f}")
    print(
        f"  - Selected cells: {selected_cells}/{total_cells} expressing cells "
        f"({selected_pct:.1f}%)"
    )
    print()

The code below creates the RSP visualization alongside the UMAP plot. For your convenience, I have also included an option to display the UMAP plot in polar coordinates, which makes it easier to compare the UMAP plot with the RSP curves.

In [None]:
def plot_rsp_visualization(
    filtered_cells,
    genes,
    fg_masks,
    thresholds,
    fg_curves,
    bg_curve,
    scanning_range,
    mode="absolute",
    exp_curves=None,
    polar_umap=False,
):
    fig = plt.figure(figsize=(12, 6))
    gs = gridspec.GridSpec(1, 2, figure=fig, width_ratios=[5, 4])

    if polar_umap:
        ax1 = fig.add_subplot(gs[0], projection="polar")
        polar_coords = filtered_cells.obsm["X_polar"]
        umap_coords = np.column_stack((polar_coords[:, 1], polar_coords[:, 0]))
        coord_labels = ["Angle", "Radius"]
    else:
        ax1 = fig.add_subplot(gs[0])
        umap_coords = filtered_cells.obsm["X_umap"]
        coord_labels = ["UMAP1", "UMAP2"]

    colors = distinct_colors[: len(genes)]
    ax1.scatter(umap_coords[:, 0], umap_coords[:, 1], c="gray", s=1, label="Background")
    for i, gene in enumerate(genes):
        fg_mask = fg_masks[gene]
        threshold_val = thresholds[gene]

        ax1.scatter(
            umap_coords[fg_mask, 0],
            umap_coords[fg_mask, 1],
            c=colors[i],
            s=1,
            label=f"{gene} (thr: {threshold_val:.2f})",
        )
    ax1.legend(loc="upper right", fontsize=10)
    if not polar_umap:
        ax1.set_xlabel(coord_labels[0], fontsize=14)
        ax1.set_ylabel(coord_labels[1], fontsize=14)
        ax1.set_aspect("equal")

    ax1.tick_params(labelsize=12)

    ax2 = fig.add_subplot(gs[1], projection="polar")
    theta = np.asarray(scanning_range)
    n = len(fg_curves[0])

    if theta.size == n + 1 and np.isclose((theta[-1] - theta[0]) % (2 * np.pi), 0.0):
        theta = theta[:-1]
    elif theta.size == 2:
        start, end = theta
        theta = np.linspace(start, end, n, endpoint=False)
    elif theta.size != n:
        raise ValueError(f"scanning_range length {theta.size} but fg_curve length {n}")

    theta_closed = np.concatenate([theta, [theta[0]]])
    bg_closed = np.concatenate([bg_curve, [bg_curve[0]]])

    ax2.plot(
        theta_closed, bg_closed, ":", c="darkgray", label="Background", linewidth=1
    )

    for i, gene in enumerate(genes):
        fg_curve = fg_curves[i]
        fg_closed = np.concatenate([fg_curve, [fg_curve[0]]])

        ax2.plot(
            theta_closed,
            fg_closed,
            c=colors[i],
            alpha=0.8,
            label=f"{gene} (observed)",
            linewidth=2,
            linestyle="-",
        )

        if mode == "absolute" and exp_curves is not None:
            exp_curve = exp_curves[i]
            exp_closed = np.concatenate([exp_curve, [exp_curve[0]]])

            ax2.plot(
                theta_closed,
                exp_closed,
                c=colors[i],
                alpha=0.6,
                label=f"{gene} (expected)",
                linewidth=1.5,
                linestyle="--",
            )
    ax2.legend(loc="lower right", bbox_to_anchor=(1.1, -0.22), fontsize=10)
    ax2.tick_params(labelsize=12)

    plt.tight_layout()
    return fig, (ax1, ax2)


print("=== Plotting with Cartesian UMAP ===")
fig_cart, axes_cart = plot_rsp_visualization(
    filtered_cells=filtered_cells,
    genes=genes,
    fg_masks=fg_masks,
    thresholds=thresholds,
    fg_curves=fg_curves,
    bg_curve=bg_curve,
    scanning_range=scanning_range,
    mode=mode,
    exp_curves=exp_curves if mode == "absolute" else None,
    polar_umap=False,
)
plt.show()

print("=== Plotting with Polar UMAP ===")
fig_polar, axes_polar = plot_rsp_visualization(
    filtered_cells=filtered_cells,
    genes=genes,
    fg_masks=fg_masks,
    thresholds=thresholds,
    fg_curves=fg_curves,
    bg_curve=bg_curve,
    scanning_range=scanning_range,
    mode=mode,
    exp_curves=exp_curves if mode == "absolute" else None,
    polar_umap=True,
)
plt.show()
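
The plotting function closes each polar curve by appending its first point to the end, so the line wraps around without a gap. A stripped-down sketch of just that step, using a synthetic curve and the non-interactive `Agg` backend (both assumptions made for the sake of a self-contained example):

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

# Synthetic curve sampled at 8 angles in [0, 2*pi), endpoint excluded.
theta = np.linspace(0, 2 * np.pi, 8, endpoint=False)
curve = 1 + 0.5 * np.cos(theta)

# Append the first sample so the plotted line returns to its start.
theta_closed = np.concatenate([theta, [theta[0]]])
curve_closed = np.concatenate([curve, [curve[0]]])

fig, ax = plt.subplots(subplot_kw={"projection": "polar"})
ax.plot(theta_closed, curve_closed, linewidth=2)
fig.canvas.draw()  # render once to confirm the figure builds
plt.close(fig)
```

Without the appended point, the curve would stop at the last sampled angle and leave a visible gap before the first one.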