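Let us now consider hierarchical cell-type relationships. For this, we use the Allen Brain datasets, which annotate cells at two nested levels of granularity: a coarse level (`labels3`) and a fine-grained level (`labels34`), where each fine-grained type belongs to exactly one coarse type.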

In [2]:
import os 
import sys

sys.path.append("../src/")

import json
import logging
import scanpy as sc
import refcm
from refcm import RefCM

from benchutils import load_adata
from benchplots import add_paper_styling


In [3]:
# Load the three Allen Brain datasets (MTG, ALM, VISp), in that order.
mtg, alm, visp = map(
    load_adata,
    ("../data/MTG.h5ad", "../data/ALM.h5ad", "../data/VISp.h5ad"),
)

Let us first retrieve the hierarchical relationships between cell types:

In [4]:
# Index the fine-grained annotation (labels34) by its coarse annotation (labels3).
labels = mtg.obs[["labels3", "labels34"]].set_index("labels3")
# Coarse types, in order of first appearance.
coarse_levels = labels.index.unique().to_list()

# coarse type -> distinct fine-grained children, in order of first appearance
hierarchy = dict(
    (coarse, labels.loc[coarse].drop_duplicates().values.flatten().tolist())
    for coarse in coarse_levels
)

print(json.dumps(hierarchy, sort_keys=True, indent=4))

{
    "Excitatory": [
        "Exc L5/6 IT 3",
        "Exc L6 CT",
        "Exc L6b",
        "Exc L4/5 IT",
        "Exc L5/6 IT 2",
        "Exc L6 IT 1",
        "Exc L6 IT 2",
        "Exc L5/6 NP",
        "Exc L2/3 IT",
        "Exc L5/6 IT 1",
        "Exc L3/5 IT",
        "Exc L5 PT"
    ],
    "Inhibitory": [
        "Sst 1",
        "Lamp5 Rosehip",
        "Vip 5",
        "Pvalb 2",
        "Sst 3",
        "Pvalb 1",
        "Vip 3",
        "Pax6",
        "Vip 1",
        "Vip Sncg",
        "Vip 4",
        "Lamp5 2",
        "Lamp5 Lhx6",
        "Sst 4",
        "Chandelier",
        "Vip 2",
        "Sst 2",
        "Sst 5",
        "Sst Chodl",
        "Lamp5 1"
    ],
    "Non-neuronal": [
        "Astrocyte",
        "Oligo"
    ]
}


We also define a quick helper function for plotting the query-vs-reference matching costs:

In [28]:
import plotly.graph_objects as go
from anndata import AnnData

def _norm_label(label) -> str:
    """Normalize a label for comparison: stringify, lowercase, strip whitespace."""
    return str(label).lower().strip()


def _parent_lookup(hierarchy: dict[str, list[str]] | None) -> dict[str, str]:
    """Map each normalized parent and child label in `hierarchy` to its normalized parent."""
    lookup: dict[str, str] = {}
    for parent, children in (hierarchy or {}).items():
        p = _norm_label(parent)
        lookup[p] = p  # a parent label counts as a member of its own group
        for child in children:
            lookup[_norm_label(child)] = p
    return lookup


def _labels_match(a, b, label_to_parent: dict[str, str]) -> bool:
    """Return True if `a` and `b` are the same label or share a known parent."""
    a, b = _norm_label(a), _norm_label(b)
    if a == b:
        return True
    parent = label_to_parent.get(a)
    return parent is not None and parent == label_to_parent.get(b)


def _query_labels(q: AnnData, q_key: str, gt_key: str | None) -> list:
    """Return one display label per query cluster, in sorted cluster-id order.

    With `gt_key`, a cluster is labeled by the most frequent `gt_key` value
    among its cells (the first value returned by `Series.mode()` on ties);
    otherwise by the stringified cluster id.
    """
    clusters = sorted(q.obs[q_key].unique())
    if gt_key is None:
        return [str(c) for c in clusters]
    labels = []
    for c in clusters:
        modes = q.obs.loc[q.obs[q_key] == c, gt_key].mode()
        # Series.mode() drops missing values, so it is empty when every
        # ground-truth value in the cluster is missing; fall back to the id.
        labels.append(modes.iloc[0] if len(modes) else str(c))
    return labels


def plot_costs(
    q: AnnData,
    q_key: str,
    gt_key: str | None = None,
    hierarchy: dict[str, list[str]] | None = None,
    show_mapping: bool = True,
    show_values: bool = False,
    show_all_labels: bool = False,
    angle_x: bool = False,
    width: float = -1,
    height: float = -1,
) -> go.Figure:
    """
    Display the matching cost matrix between query and reference.

    Rows are query clusters (sorted unique values of ``q.obs[q_key]``) and
    columns are reference labels (``q.uns["refcm"]["ref_ktl"]``). Entries
    where ``q.uns["refcm"]["mapping"]`` equals 1 are overlaid as dots.

    Parameters
    ----------
    q
        Query AnnData with refcm results in .uns["refcm"].
    q_key
        Column in q.obs containing cluster assignments.
    gt_key
        Ground truth column in q.obs. When given, each row is labeled by the
        most frequent ground-truth value in that cluster, and mapped pairs are
        colored green (match) or red (mismatch); otherwise they are blue.
    hierarchy
        Dict mapping parent -> list of children. Labels sharing a parent
        are considered equivalent for correctness checking. Comparison is
        case-insensitive and ignores surrounding whitespace.
    show_mapping
        Show dots for mapped pairs.
    show_values
        Show cost values as text annotations.
    show_all_labels
        Show all axis labels.
    angle_x
        Angle x labels at 45 deg instead of 90.
    width
        Figure width (if positive).
    height
        Figure height (if positive).

    Returns
    -------
    go.Figure
        The cost heatmap (with the mapping overlay if requested).

    Raises
    ------
    ValueError
        If ``q.uns`` has no refcm results, or if the number of query clusters
        or reference labels does not match the shape of the cost matrix.
    """
    if "refcm" not in q.uns:
        raise ValueError("No refcm results found in q.uns")

    results = q.uns["refcm"]
    costs = results["costs"]
    mapping = results["mapping"]
    ref_ktl = results["ref_ktl"]
    ref_labels = [ref_ktl[i] for i in range(len(ref_ktl))]
    q_labels = _query_labels(q, q_key, gt_key)

    n_rows, n_cols = costs.shape
    # Axis labels are attached to cost rows/columns purely by position, so a
    # count mismatch would mislabel the plot; fail loudly instead.
    # NOTE(review): assumes refcm orders cost rows by sorted cluster id — confirm.
    if (len(q_labels), len(ref_labels)) != (n_rows, n_cols):
        raise ValueError(
            f"Cost matrix has shape {costs.shape}, but found {len(q_labels)} "
            f"query clusters in q.obs[{q_key!r}] and {len(ref_labels)} "
            f"reference labels"
        )

    label_to_parent = _parent_lookup(hierarchy)

    hovertext = [
        [
            f"Query: {q_lbl}<br>"
            f"Reference: {r_lbl}<br>"
            f"cost: {costs[i, j]:.4f}"
            for j, r_lbl in enumerate(ref_labels)
        ]
        for i, q_lbl in enumerate(q_labels)
    ]

    # Heatmap on integer coordinates; tick text supplies the label names.
    fig = go.Figure(
        go.Heatmap(
            z=costs,
            x=list(range(n_cols)),
            y=list(range(n_rows)),
            colorscale="agsunset",
            colorbar=dict(title="cost"),
            hoverinfo="text",
            text=hovertext,
            xgap=0,
            ygap=0,
        )
    )

    fig.update_xaxes(
        title="Reference",
        tickmode="array",
        tickvals=list(range(n_cols)),
        ticktext=ref_labels,
        tickangle=-45 if angle_x else -90,
        range=[-0.5, n_cols - 0.5],
        constrain="domain",
    )
    fig.update_yaxes(
        title="Query",
        tickmode="array",
        tickvals=list(range(n_rows)),
        ticktext=q_labels,
        # reversed range puts the first query row at the top
        range=[n_rows - 0.5, -0.5],
        # tie y scale to x so cells are drawn square
        scaleanchor="x",
        constrain="domain",
    )

    if show_values:
        for i in range(n_rows):
            for j in range(n_cols):
                fig.add_annotation(
                    x=j, y=i,
                    text=f"{costs[i, j]:.2f}",
                    showarrow=False,
                    font=dict(color="white"),
                )

    if show_mapping:
        pairs = [
            (i, j)
            for i in range(n_rows)
            for j in range(n_cols)
            if mapping[i, j] == 1
        ]
        if gt_key is not None:
            colors = [
                "green" if _labels_match(q_labels[i], ref_labels[j], label_to_parent) else "red"
                for i, j in pairs
            ]
        else:
            colors = ["blue"] * len(pairs)

        fig.add_trace(
            go.Scatter(
                x=[j for _, j in pairs],
                y=[i for i, _ in pairs],
                mode="markers",
                marker=dict(color=colors, size=5),
                name="",
                hoverinfo="text",
                hovertext=[f"{q_labels[i]} -> {ref_labels[j]}" for i, j in pairs],
            )
        )

    if width > 0:
        fig.update_layout(width=width)
    if height > 0:
        fig.update_layout(height=height)
    if show_all_labels:
        # NOTE(review): both axes already use tickmode="array" with one tick
        # per row/column, and plotly documents dtick for "linear" tick mode,
        # so this may have no visible effect — confirm.
        fig.update_xaxes(dtick=1)
        fig.update_yaxes(dtick=1)

    return fig

We can then map across these annotation levels and datasets (here, VISp to ALM) and evaluate the performance as follows. We first map from the granular to the coarse resolution.

In [29]:
# Granular -> coarse: VISp (labels34) is the query and ALM (labels3) the reference.
# Ensure we allow the query clusters to "merge" without restriction (max_merges=-1).
rcm = RefCM(max_merges=-1)
rcm.setref(alm, "labels3")
m = rcm.annotate(visp, "labels34")

|████████████████| [100.00%] : 00:04


In [30]:
# Rows: VISp fine-grained clusters; dots colored green when the mapped ALM type is in the same hierarchy group.
plot_costs(visp, q_key="labels34", gt_key="labels34", hierarchy=hierarchy)

Comparing with the hierarchy established above, every fine-grained query cluster was indeed mapped to its correct coarse cell type!

Conversely, let us now map from coarse to granular annotations, this time using MTG as the query:

In [22]:
# Coarse -> granular: MTG (labels3) is the query and ALM (labels34) the reference.
# Ensure we allow the query clusters to "split" without restriction (max_splits=-1).
rcm = RefCM(max_splits=-1)
rcm.setref(alm, "labels34")
m = rcm.annotate(mtg, "labels3")

|████████████████| [100.00%] : 00:06


In [27]:
# Coarse MTG query vs. fine-grained ALM reference, with dataset names on the axes.
fig = plot_costs(mtg, q_key="labels3", gt_key="labels34", hierarchy=hierarchy)
fig.update_xaxes(title="Reference: ALM").update_yaxes(title="Query: MTG")

add_paper_styling(fig, lines=False)
fig.show()

In [None]:
# Save the MTG -> ALM cost plot as a PNG (rendered at 3x scale) under FIG5/,
# creating the directory if needed. Presumably this is the paper's Figure 5 panel — verify.
os.makedirs("FIG5", exist_ok=True)
fig.write_image("FIG5/hierarchy.png", scale=3)

Comparing this plot with the previous one and with the established hierarchy, we conclude that the coarse-to-granular direction also recovers the correct links!