# Progression comparison

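This notebook contrasts two staged progressions over six compartments
(Q1, Q2, A, M1, M2 and D) that follow the path
`Q1 → Q2 → A → M1 → M2 → A → D`. The second configuration adds a
backward hop from A back to Q2. In the transition matrices defined below,
both versions exit into D from A only; M2 returns to A rather than
reaching D directly.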


Install the package with `pip install markovmodus`.
If you are running this notebook inside the repository you can also use an editable install: `pip install -e .[dev]`.


> **Note:** The visualisations below use Matplotlib and Scanpy. Install the optional dependencies with `pip install matplotlib scanpy` if you want to execute every cell.


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgb, to_hex

from markovmodus import SimulationParameters, simulate_dataset

# Shared colour scheme: one evenly spaced viridis colour per latent state.
colormap = plt.get_cmap("viridis")
state_labels = ["Q1", "Q2", "A", "M1", "M2", "D"]
state_lookup = dict(enumerate(state_labels))
# The max() guard keeps the division valid even for a single-entry label list.
_palette_span = max(1, len(state_labels) - 1)
state_palette = [
    to_hex(colormap(position / _palette_span))
    for position in range(len(state_labels))
]
state_color_map = dict(zip(state_labels, state_palette))

def _label_color(color) -> str:
    """Return ``"black"`` or ``"white"``, whichever reads better on ``color``.

    ``color`` is anything Matplotlib's ``to_rgb`` accepts. Backgrounds whose
    weighted brightness exceeds 0.6 get black text; darker ones get white.
    """
    red, green, blue = to_rgb(color)
    # Weighted sum using the ITU-R BT.601 luma coefficients.
    brightness = 0.299 * red + 0.587 * green + 0.114 * blue
    if brightness > 0.6:
        return "black"
    return "white"

def plot_transition_graph(matrix: np.ndarray, title: str, *, ax: plt.Axes | None = None, palette=None, state_labels=None):
    """Draw a transition matrix as a directed graph with states on a circle.

    Parameters
    ----------
    matrix
        Square array; a positive entry ``[i, j]`` with ``i != j`` is drawn as
        an arrow from state ``i`` to state ``j``. Diagonal and non-positive
        entries are not drawn.
    title
        Title placed above the graph.
    ax
        Axes to draw into. When omitted, a new 4.5 x 4.5 inch figure is created.
    palette
        Either a callable mapping a value in ``[0, 1]`` to a colour (such as a
        Matplotlib colormap) or a sequence of colours, which is cycled if it
        is shorter than the number of states. Defaults to the module-level
        ``colormap``.
    state_labels
        Optional node names. States beyond the end of the sequence (or all
        states, when omitted) are labelled with their integer index.

    Returns
    -------
    The newly created figure when ``ax`` was omitted, otherwise ``ax`` itself.
    """
    matrix = np.asarray(matrix)
    num_states = matrix.shape[0]
    angles = np.linspace(0, 2 * np.pi, num_states, endpoint=False)
    positions = np.stack((np.cos(angles), np.sin(angles)), axis=1)

    if palette is None:
        palette = colormap

    created_fig = False
    if ax is None:
        fig, ax = plt.subplots(figsize=(4.5, 4.5))
        created_fig = True

    for idx, (x, y) in enumerate(positions):
        if callable(palette):
            color = palette(idx / max(1, num_states - 1))
        else:
            color = palette[idx % len(palette)]
        ax.scatter(x, y, s=500, c=[color], edgecolor="black", linewidth=1.5)
        label = state_labels[idx] if state_labels is not None and idx < len(state_labels) else str(idx)
        ax.text(
            x,
            y,
            label,
            ha="center",
            va="center",
            fontsize=12,
            fontweight="bold",
            color=_label_color(color),
        )

    # Only off-diagonal, positive entries become arrows.
    edges = [
        (i, j, matrix[i, j])
        for i in range(num_states)
        for j in range(num_states)
        if i != j and matrix[i, j] > 0
    ]
    # Normalise widths by the strongest *drawn* edge so a large diagonal entry
    # cannot flatten every arrow; ``default`` covers matrices without edges.
    max_rate = max((weight for _, _, weight in edges), default=0.0)

    for i, j, weight in edges:
        ax.annotate(
            "",
            xy=positions[j],
            xytext=positions[i],
            arrowprops=dict(
                arrowstyle="->",
                lw=0.6 + 2.4 * (weight / max_rate),
                color="tab:gray",
                alpha=0.7,
                shrinkA=15,
                shrinkB=15,
                connectionstyle="arc3,rad=0.1",
            ),
        )
        # Place the label 40% of the way from source to target so the labels
        # of i -> j and j -> i do not land on the same spot.
        ax.text(
            *(0.6 * positions[i] + 0.4 * positions[j]),
            f"{weight:.2f}",
            fontsize=8,
            color="tab:gray",
            ha="center",
            va="center",
        )

    ax.set_title(title)
    ax.axis("off")

    if created_fig:
        ax.figure.tight_layout()
        return ax.figure
    return ax


In [None]:
# Row i holds the transitions leaving state i; rows and columns both follow
# the order Q1, Q2, A, M1, M2, D.
base_progression_transition = np.array(
    [
        [0.0, 0.02, 0.0, 0.0, 0.0, 0.0],  # Q1 -> Q2 (rare)
        [0.0, 0.0, 0.10, 0.0, 0.0, 0.0],  # Q2 -> A
        [0.0, 0.0, 0.0, 0.05, 0.0, 0.05],  # A -> M1, D
        [0.0, 0.0, 0.0, 0.0, 0.05, 0.0],  # M1 -> M2
        [0.0, 0.0, 0.05, 0.0, 0.0, 0.0],  # M2 -> A
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],  # D (terminal)
    ]
)

# Same matrix plus the A -> Q2 feedback edge.
progression_with_transition_to_Q_transition = base_progression_transition.copy()
progression_with_transition_to_Q_transition[2, 1] = 0.05

# Settings common to both runs; only the matrix and the seed differ.
_shared_settings = dict(
    num_states=6,
    num_genes=250,
    num_cells=2000,
    t_final=30.0,
    dt=1.0,
    markers_per_state=70,
)

base_progression_params = SimulationParameters(
    **_shared_settings,
    transition_matrix=base_progression_transition,
    rng_seed=81,
)

progression_with_transition_to_Q_params = SimulationParameters(
    **_shared_settings,
    transition_matrix=progression_with_transition_to_Q_transition,
    rng_seed=82,
)


In [None]:
# Run both simulations, requesting the annotated object and the table at once.
base_progression_adata, base_progression_df = simulate_dataset(
    base_progression_params, output="both"
)
progression_with_transition_to_Q_adata, progression_with_transition_to_Q_df = simulate_dataset(
    progression_with_transition_to_Q_params, output="both"
)

# Attach a readable, ordered categorical label next to the integer state.
for adata in (base_progression_adata, progression_with_transition_to_Q_adata):
    adata.obs["state_label"] = pd.Categorical(
        [state_labels[state_index] for state_index in adata.obs["state"]],
        categories=state_labels,
        ordered=True,
    )

for frame in (base_progression_df, progression_with_transition_to_Q_df):
    frame["state_label"] = pd.Categorical(
        frame["state"].map(state_lookup),
        categories=state_labels,
        ordered=True,
    )


In [None]:
# Preview the first rows of the base progression table.
base_progression_df.head()


In [None]:
# Preview the first rows of the table for the progression with the A -> Q2 hop.
progression_with_transition_to_Q_df.head()


In [None]:
# Count cells per state in each dataset; states absent from a dataset (or from
# both) are filled in as zero so every label gets a bar.
state_order = pd.Index(range(len(state_labels)))
occupancy = (
    pd.DataFrame(
        {
            "Base progression": base_progression_df["state"].value_counts().sort_index(),
            "Progression with A→Q2": progression_with_transition_to_Q_df["state"].value_counts().sort_index(),
        }
    )
    .fillna(0)
    .reindex(state_order, fill_value=0)
    .astype(int)
)
# After the reindex the rows are exactly 0..len(state_labels)-1, in order.
occupancy.index = list(state_labels)

fig, ax = plt.subplots(figsize=(7, 4))
occupancy.plot(kind="bar", ax=ax)
ax.set(
    xlabel="Latent state",
    ylabel="Number of cells",
    title="State occupancy comparison",
)
fig.tight_layout()


### Marker allocation overview

The next plots compare the intended marker profiles with the mean expression recovered
from simulated counts for each progression.


In [None]:
datasets = [
    ("Base progression", base_progression_adata),
    ("Progression with A→Q2", progression_with_transition_to_Q_adata),
]

# Target expression per state, read from each dataset's ``uns["state_expression"]``.
truth_matrices = {
    title: np.array(dataset.uns["state_expression"]) for title, dataset in datasets
}
# Rows of a target matrix are states; take the count from the first dataset.
num_states = next(iter(truth_matrices.values())).shape[0]

# One row per dataset with two heatmaps each, plus a colour-bar axis on the right.
fig = plt.figure(figsize=(14, 8))
gs = fig.add_gridspec(
    nrows=len(datasets),
    ncols=2,
    width_ratios=(20, 20),
    wspace=0.2,
    hspace=0.25,
)
axes = [
    fig.add_subplot(gs[grid_row, grid_col])
    for grid_row in range(len(datasets))
    for grid_col in range(2)
]
cbar_ax = fig.add_axes([0.92, 0.18, 0.02, 0.64])

def build_order(state_expression: np.ndarray) -> list[int]:
    """Return a gene (column) ordering that groups genes by their peak states.

    For each gene, the "active" states are those whose expression is within
    tolerance of the column maximum. Genes that are flat across states have
    no active states. Genes are sorted by active-state combination (sorted
    lexicographically), then by descending expression in the first active
    state, then by index. Flat genes come last in index order.
    """
    num_genes = state_expression.shape[1]

    def peak_states(column: np.ndarray) -> tuple:
        # A (near-)constant column marks no state at all.
        peak = np.max(column)
        if peak - np.min(column) < 1e-6:
            return ()
        return tuple(np.flatnonzero(np.isclose(column, peak, atol=1e-6)))

    active_sets = [peak_states(state_expression[:, gene]) for gene in range(num_genes)]

    # dict.fromkeys drops repeats while keeping first-seen order.
    combinations = sorted(
        dict.fromkeys(active for active in active_sets if active),
        key=lambda s: (s[0], tuple(s[1:]), len(s)),
    )
    order_index = {subset: rank for rank, subset in enumerate(combinations)}
    baseline_bucket = len(combinations)  # sorts after every real combination

    def sort_key(gene_index: int) -> tuple:
        active = active_sets[gene_index]
        if active:
            # Negated so that stronger markers come first within a group.
            return (
                order_index[active],
                -state_expression[active[0], gene_index],
                gene_index,
            )
        return (baseline_bucket, gene_index)

    return sorted(range(num_genes), key=sort_key)

def mean_spliced_by_state(adata) -> np.ndarray:
    """Average the ``"spliced"`` layer over the cells of each state.

    Returns one row per state index in ``range(num_states)`` (``num_states``
    is read from module scope). A state with no cells yields a row of zeros.
    """
    spliced = np.asarray(adata.layers["spliced"])
    states = adata.obs["state"].astype(int).to_numpy()

    def state_mean(state: int) -> np.ndarray:
        members = states == state
        if members.any():
            return spliced[members].mean(axis=0)
        return np.zeros(spliced.shape[1])

    return np.vstack([state_mean(state) for state in range(num_states)])

orders = {name: build_order(matrix) for name, matrix in truth_matrices.items()}
means = {name: mean_spliced_by_state(dataset) for name, dataset in datasets}
state_labels_subset = [state_labels[state] for state in range(num_states)]

# A single upper limit keeps all four heatmaps on the same colour scale.
vmax = max(
    np.max(matrix) for matrix in [*truth_matrices.values(), *means.values()]
)


def _show_expression(target_ax, values, heading):
    """Draw one state-by-gene heatmap on the shared colour scale."""
    image = target_ax.imshow(
        values,
        aspect="auto",
        cmap="viridis",
        interpolation="nearest",
        vmin=0,
        vmax=vmax,
    )
    target_ax.set_title(heading)
    target_ax.set_xticks([])
    target_ax.set_yticks(range(num_states))
    return image


# ``axes`` is laid out row by row: (target, simulation) for each dataset.
for row, (name, _) in enumerate(datasets):
    truth_ax, sim_ax = axes[2 * row], axes[2 * row + 1]
    order = orders[name]

    im = _show_expression(
        truth_ax, truth_matrices[name][:, order], f"{name} target expression"
    )
    truth_ax.set_ylabel("State")
    truth_ax.set_yticklabels(state_labels_subset)

    _show_expression(sim_ax, means[name][:, order], f"{name} simulation means")
    # The right-hand panel shares the row's state axis, so hide its labels.
    sim_ax.set_yticklabels([])

fig.colorbar(im, cax=cbar_ax, label="Expression level")
plt.tight_layout(rect=[0.02, 0.12, 0.9, 0.95])


In [None]:
from collections import Counter

def _peak_states(column) -> tuple:
    """Return the state indices sharing a column's maximum, or ``()`` if flat."""
    max_val = np.max(column)
    if max_val - np.min(column) < 1e-6:
        return ()
    return tuple(
        int(state) for state in np.flatnonzero(np.isclose(column, max_val, atol=1e-6))
    )


# For each dataset, count genes whose maximum is shared by two or more states.
pair_counts = {
    label: Counter(
        active for active in map(_peak_states, matrix.T) if len(active) >= 2
    )
    for label, matrix in truth_matrices.items()
}

for label, counter in pair_counts.items():
    print(f"{label} pairwise/shared markers:")
    if not counter:
        print("  (no shared markers)")
    for key, value in sorted(counter.items()):
        named = tuple(state_labels[i] for i in key)
        print(f"  states {named}: {value} genes")
    print()


## Scanpy workflow

> **Note:** The remaining cells require `scanpy` and its dependencies.


In [None]:
import scanpy as sc

# Configure Scanpy's figure defaults (100 dpi).
sc.settings.set_figure_params(dpi=100)


In [None]:
def preprocess(adata, label: str, *, n_hvg: int = 150):
    """Run the Scanpy preprocessing and embedding pipeline on a copy of ``adata``.

    Steps, in order: total-count normalisation (target sum 1e4), ``log1p``,
    selection of ``n_hvg`` highly variable genes (``"seurat"`` flavour),
    scaling clipped at 10, PCA, a 15-neighbour graph and UMAP.

    Parameters
    ----------
    adata
        Input dataset; it is copied first, so the caller's object is untouched.
    label
        Currently unused; kept so existing call sites keep working.
    n_hvg
        Number of highly variable genes to retain.

    Returns
    -------
    The processed copy, restricted to the selected genes, with ``obs["state"]``
    (and ``obs["state_label"]`` when present) cast to categorical.
    """
    adata = adata.copy()
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)
    sc.pp.highly_variable_genes(adata, flavor="seurat", n_top_genes=n_hvg)
    adata = adata[:, adata.var["highly_variable"]].copy()
    sc.pp.scale(adata, max_value=10)
    # PCA cannot return more components than the data has dimensions, so cap
    # the usual 30 for small ``n_hvg`` (or very few cells) instead of failing.
    n_comps = max(1, min(30, min(adata.n_obs, adata.n_vars) - 1))
    sc.tl.pca(adata, n_comps=n_comps)
    sc.pp.neighbors(adata, n_neighbors=15)
    sc.tl.umap(adata)
    adata.obs["state"] = adata.obs["state"].astype("category")
    if "state_label" in adata.obs:
        adata.obs["state_label"] = adata.obs["state_label"].astype("category")
    return adata


In [None]:
# Preprocess both datasets; preprocess() works on a copy, so the raw objects stay as they were.
base_progression_processed = preprocess(base_progression_adata, "Base progression")
progression_with_transition_to_Q_processed = preprocess(progression_with_transition_to_Q_adata, "Progression with A→Q2")


In [None]:
rows = [
    ("Base progression", base_progression_params, base_progression_processed),
    ("Progression with A→Q2", progression_with_transition_to_Q_params, progression_with_transition_to_Q_processed),
]

# One figure row per configuration: transition graph, PCA, UMAP.
fig, axes = plt.subplots(len(rows), 3, figsize=(18, 4 * len(rows)))

# (column, Scanpy plotting function, name used in the panel title)
embedding_panels = [
    (1, sc.pl.pca, "PCA"),
    (2, sc.pl.umap, "UMAP"),
]

for row_idx, (label, params, processed) in enumerate(rows):
    plot_transition_graph(
        params.resolve_transition_matrix(),
        f"{label} transition graph",
        ax=axes[row_idx, 0],
        state_labels=state_labels,
    )

    for column, draw_embedding, method_name in embedding_panels:
        panel_ax = axes[row_idx, column]
        draw_embedding(
            processed,
            color="state_label",
            show=False,
            ax=panel_ax,
            legend_loc="right margin",
            palette=state_color_map,
        )
        panel_ax.set_title(f"{label} – {method_name}")

# Drop the axis labels from every panel.
for ax in axes.flat:
    ax.set(xlabel="", ylabel="")

fig.tight_layout()
