# markovmodus quickstart

This notebook demonstrates how to configure the simulator, generate synthetic single-cell counts, and visualise a few quick summaries of the resulting dataset.


Install the package with `pip install markovmodus`.
If you are running this notebook inside the repository, you can also use an editable install: `pip install -e ".[dev]"` (the quotes stop shells such as zsh from expanding the brackets).


> **Note:** The visualisations below use Matplotlib and Scanpy.
> If you have not installed them yet, run `pip install matplotlib scanpy` in your environment before executing the notebook.


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgb

from markovmodus import SimulationParameters, simulate_dataset

# Ten qualitative "tab10" colours, passed to both the transition-graph nodes and the
# Scanpy embeddings below (intended so a state keeps one colour across figures).
palette_tab10 = plt.get_cmap(name="tab10").colors

def _label_color(color) -> str:
    """Choose a readable text colour for a label drawn on top of ``color``.

    The fill colour is converted to RGB with ``matplotlib.colors.to_rgb`` and its
    luma is computed with the ITU-R BT.601 weights (0.299, 0.587, 0.114).
    Fills brighter than 0.6 get ``"black"`` text; darker fills get ``"white"``.
    """
    red, green, blue = to_rgb(color)
    brightness = 0.299 * red + 0.587 * green + 0.114 * blue
    if brightness > 0.6:
        return "black"
    return "white"

def plot_transition_graph(matrix: np.ndarray, title: str, *, ax: plt.Axes | None = None, palette=None):
    """Draw a transition matrix as a directed graph with the states placed on a circle.

    Each state ``i`` is drawn as a coloured node labelled with its index. Every
    positive off-diagonal entry ``matrix[i, j]`` becomes a curved arrow ``i -> j``.
    The arrow's line width grows with the entry, and the value is annotated with
    two decimals. Diagonal entries are never drawn.

    Parameters
    ----------
    matrix : np.ndarray
        Square ``(num_states, num_states)`` matrix of transition weights.
    title : str
        Axes title.
    ax : matplotlib.axes.Axes, optional
        Axes to draw into. When omitted, a new 4.5 x 4.5 inch figure is created.
    palette : sequence of colours, optional
        Node colours, cycled by state index. Defaults to ``palette_tab10``.

    Returns
    -------
    matplotlib.figure.Figure or matplotlib.axes.Axes
        The newly created figure when ``ax`` was not given, otherwise ``ax`` itself.

    Raises
    ------
    ValueError
        If ``matrix`` is not a square 2-D array.
    """
    matrix = np.asarray(matrix, dtype=float)
    if matrix.ndim != 2 or matrix.shape[0] != matrix.shape[1]:
        raise ValueError(f"transition matrix must be square, got shape {matrix.shape}")

    num_states = matrix.shape[0]
    angles = np.linspace(0, 2 * np.pi, num_states, endpoint=False)
    positions = np.stack((np.cos(angles), np.sin(angles)), axis=1)

    if palette is None:
        palette = palette_tab10

    created_fig = False
    if ax is None:
        _, ax = plt.subplots(figsize=(4.5, 4.5))
        created_fig = True

    # Scale arrow widths by the largest *off-diagonal* weight only. Self-transitions
    # are never drawn, so a large diagonal entry must not thin out every arrow.
    # Guarding on .size also avoids calling .max() on an empty array (0 or 1 state).
    off_diagonal = matrix[~np.eye(num_states, dtype=bool)]
    max_rate = off_diagonal.max() if off_diagonal.size else 0.0

    for idx, (x, y) in enumerate(positions):
        color = palette[idx % len(palette)]
        ax.scatter(x, y, s=500, c=[color], edgecolor="black", linewidth=1.5)
        ax.text(
            x,
            y,
            f"{idx}",
            ha="center",
            va="center",
            fontsize=12,
            fontweight="bold",
            color=_label_color(color),
        )

    if max_rate > 0:
        for i in range(num_states):
            for j in range(num_states):
                weight = matrix[i, j]
                if i == j or weight <= 0:
                    continue
                ax.annotate(
                    "",
                    xy=positions[j],
                    xytext=positions[i],
                    arrowprops=dict(
                        arrowstyle="->",
                        # max_rate > 0 is guaranteed inside this branch.
                        lw=0.6 + 2.4 * (weight / max_rate),
                        color="tab:gray",
                        alpha=0.7,
                        shrinkA=15,
                        shrinkB=15,
                        connectionstyle="arc3,rad=0.1",
                    ),
                )
                # The label sits 40% of the way from i to j, so the labels of
                # i -> j and j -> i land at different points.
                ax.text(
                    *(0.6 * positions[i] + 0.4 * positions[j]),
                    f"{weight:.2f}",
                    fontsize=8,
                    color="tab:gray",
                    ha="center",
                    va="center",
                )

    ax.set_title(title)
    ax.axis("off")

    if created_fig:
        ax.figure.tight_layout()
        return ax.figure
    return ax



In [None]:
# Dense configuration: no explicit transition matrix; every transition uses
# default_transition_rate.
dense_params = SimulationParameters(
    num_states=5,
    num_genes=250,
    num_cells=2000,
    t_final=30.0,
    dt=1.0,
    markers_per_state=70,
    default_transition_rate=0.06,
    rng_seed=42,
)

dense_adata, dense_df = simulate_dataset(dense_params, output="both")
# Use the same "state_<i>" labels as the sparse run so legends and tables match across
# configurations. Cast to int so a float-typed state column does not render as "state_0.0".
dense_adata.obs["state_label"] = [f"state_{int(i)}" for i in dense_adata.obs["state"]]
dense_df["state_label"] = dense_df["state"].map(lambda i: f"state_{int(i)}")
dense_adata



In [None]:
# Preview the first rows of the dense simulation's tabular output.
dense_df.head(n=5)


In [None]:
# Sparse configuration: interlocking cycles 0 -> 1 -> 2 -> 0 and 0 -> 3 -> 4 -> 0.
# Entry [i, j] is the rate of the transition i -> j; all other transitions are disabled.
sparse_transition = np.array([
    [0.0, 0.12, 0.0, 0.08, 0.0], # 0 -> 1, 3
    [0.0, 0.0, 0.12, 0.0, 0.0], # 1 -> 2
    [0.12, 0.0, 0.0, 0.0, 0.0], # 2 -> 0
    [0.0, 0.0, 0.0, 0.0, 0.08], # 3 -> 4
    [0.08, 0.0, 0.0, 0.0, 0.0], # 4 -> 0
])

sparse_params = SimulationParameters(
    num_states=5,
    num_genes=250,
    num_cells=2000,
    t_final=30.0,
    dt=1.0,
    markers_per_state=70,
    transition_matrix=sparse_transition,
    rng_seed=52,
)

sparse_adata, sparse_df = simulate_dataset(sparse_params, output="both")
# Cast to int exactly as the DataFrame labels do, so both label columns agree even if
# the state column is float-typed.
sparse_adata.obs["state_label"] = [f"state_{int(i)}" for i in sparse_adata.obs["state"]]
sparse_df["state_label"] = sparse_df["state"].map(lambda i: f"state_{int(i)}")
sparse_adata



In [None]:
# Preview the first rows of the sparse simulation's tabular output.
sparse_df.head(n=5)


In [None]:
fig, axes = plt.subplots(1, 2, figsize=(10, 4.5))
plots = [
    ("Dense transition graph", dense_params),
    ("Sparse transition graph", sparse_params),
]

# One panel per configuration, drawn from its resolved transition matrix.
for panel_idx, (title, params) in enumerate(plots):
    ax = axes[panel_idx]
    rates = params.resolve_transition_matrix()
    plot_transition_graph(rates, title, ax=ax, palette=palette_tab10)

fig.tight_layout()


In [None]:
# Number of cells per latent state for each configuration. Casting to int first keeps
# a float-typed state column aligned with the integer state order used below.
occupancy = pd.DataFrame({
    "Dense": dense_df["state"].astype(int).value_counts().sort_index(),
    "Sparse": sparse_df["state"].astype(int).value_counts().sort_index(),
}).fillna(0)

# Cover every state of either configuration, so a state that was never visited still
# appears as a zero bar instead of being dropped.
state_order = pd.Index(range(max(dense_params.num_states, sparse_params.num_states)))
occupancy = occupancy.reindex(state_order, fill_value=0).astype(int)
occupancy.index = [f"state_{i}" for i in occupancy.index]

fig, ax = plt.subplots(figsize=(6, 4))
occupancy.plot(kind="bar", ax=ax)
ax.set_xlabel("Latent state")
ax.set_ylabel("Number of cells")
ax.set_title("State occupancy comparison")
fig.tight_layout()


### Marker allocation overview

Each latent state is assigned a set of high-expression marker genes in the simulator's ground-truth `state_expression` matrix (stored in `adata.uns["state_expression"]`).
The heatmaps below compare these intended marker profiles with the mean spliced expression recovered from the simulated counts, for both the dense and the sparse configuration.


In [None]:
datasets = [
    ("Dense", dense_adata),
    ("Sparse", sparse_adata),
]

# Ground-truth per-state expression profiles, read from each dataset's `.uns`.
truth_matrices = {label: np.array(ad.uns["state_expression"]) for label, ad in datasets}
# Number of states, taken from the first dataset's truth matrix (rows = states).
num_states = truth_matrices[datasets[0][0]].shape[0]

# One row per dataset: left column = target profile, right column = simulated means.
fig = plt.figure(figsize=(14, 8))
gs = fig.add_gridspec(len(datasets), 2, width_ratios=(20, 20), wspace=0.2, hspace=0.25)
axes = [fig.add_subplot(gs[idx // 2, idx % 2]) for idx in range(2 * len(datasets))]
# Dedicated axes for the shared colour bar, placed to the right of the grid.
cbar_ax = fig.add_axes([0.92, 0.18, 0.02, 0.64])

def build_order(state_expression: np.ndarray) -> list[int]:
    """Return a permutation of gene (column) indices that groups genes by their top states.

    A gene's "top set" is the tuple of state (row) indices whose value lies within
    ``atol=1e-6`` of that column's maximum. A gene whose column is flat
    (max - min < 1e-6) has no top set. Genes are ordered by:

    1. top set, with the distinct sets sorted lexicographically as index tuples;
    2. within one set, decreasing expression in the set's first state;
    3. the original column index, as a final tie-breaker.

    Genes without a top set come last, in their original order.
    """
    n_genes = state_expression.shape[1]

    def top_states(column: np.ndarray) -> tuple:
        peak = np.max(column)
        if peak - np.min(column) < 1e-6:
            return ()
        return tuple(np.flatnonzero(np.isclose(column, peak, atol=1e-6)))

    gene_tops = [top_states(state_expression[:, g]) for g in range(n_genes)]

    # Distinct non-empty top sets in lexicographic order; the position is the group rank.
    group_rank = {
        group: rank for rank, group in enumerate(sorted({tops for tops in gene_tops if tops}))
    }
    flat_rank = len(group_rank)

    def gene_key(g: int) -> tuple:
        tops = gene_tops[g]
        if not tops:
            return (flat_rank, g)
        return (group_rank[tops], -state_expression[tops[0], g], g)

    return sorted(range(n_genes), key=gene_key)

def mean_spliced_by_state(adata) -> np.ndarray:
    spliced = np.asarray(adata.layers["spliced"])
    states = adata.obs["state"].astype(int).to_numpy()
    means = []
    for state in range(num_states):
        mask = states == state
        means.append(spliced[mask].mean(axis=0) if mask.any() else np.zeros(spliced.shape[1]))
    return np.vstack(means)

# Gene order is derived from the ground truth, and the same order is applied to the
# simulated means so the two columns line up gene-for-gene.
orders = {name: build_order(truth_matrices[name]) for name, _ in datasets}
means = {name: mean_spliced_by_state(adata) for name, adata in datasets}

# One shared colour scale for every panel, so target and simulated values are comparable.
vmax = max(
    np.max(matrix)
    for matrix in list(truth_matrices.values()) + list(means.values())
)
heatmap_kwargs = dict(
    aspect="auto",
    cmap="viridis",
    interpolation="nearest",
    vmin=0,
    vmax=vmax,
)
state_tick_labels = [f"state_{i}" for i in range(num_states)]

# `axes` is row-major: (target, simulation) for each dataset in turn.
axes_iter = iter(axes)
for row_idx, (name, _) in enumerate(datasets):
    order = orders[name]
    truth_ax = next(axes_iter)
    sim_ax = next(axes_iter)

    im = truth_ax.imshow(truth_matrices[name][:, order], **heatmap_kwargs)
    truth_ax.set_title(f"{name} target expression")
    truth_ax.set_ylabel("State")
    truth_ax.set_xticks([])
    truth_ax.set_yticks(range(num_states))
    truth_ax.set_yticklabels(state_tick_labels)

    sim_ax.imshow(means[name][:, order], **heatmap_kwargs)
    sim_ax.set_title(f"{name} simulation means")
    sim_ax.set_xticks([])
    sim_ax.set_yticks(range(num_states))
    sim_ax.set_yticklabels([])

    # Individual gene indices are meaningless after reordering, so label only the
    # bottom row with what the x axis represents.
    if row_idx == len(datasets) - 1:
        for bottom_ax in (truth_ax, sim_ax):
            bottom_ax.set_xlabel("Genes (grouped by marker set)")

fig.colorbar(im, cax=cbar_ax, label="Expression level")
fig.tight_layout(rect=[0.02, 0.12, 0.9, 0.95])


In [None]:
from collections import Counter


def _shared_top_state_sets(expression: np.ndarray):
    """Yield, for each non-flat gene, the tuple of states tied for its maximum, when 2+ tie."""
    for column in expression.T:
        peak = np.max(column)
        if peak - np.min(column) < 1e-6:
            continue
        tied = tuple(int(s) for s in np.flatnonzero(np.isclose(column, peak, atol=1e-6)))
        if len(tied) >= 2:
            yield tied


# For each configuration, count how many genes are shared markers of each state combination.
pair_counts = {
    label: Counter(_shared_top_state_sets(matrix))
    for label, matrix in truth_matrices.items()
}

for label, counter in pair_counts.items():
    report = [f"{label} pairwise/shared markers:"]
    if not counter:
        report.append("  (no shared markers)")
    else:
        for key, value in sorted(counter.items()):
            named = tuple(f"state_{i}" for i in key)
            report.append(f"  states {named}: {value} genes")
    print("\n".join(report))
    print()


## Scanpy workflow

> **Note:** The following steps require `scanpy` and its dependencies. Install with `pip install scanpy` if needed.


In [None]:
import scanpy as sc

# Configure Scanpy's figure parameters; dpi=100 sets the rendering resolution of the plots below.
sc.settings.set_figure_params(dpi=100)


In [None]:
def preprocess(adata, label: str, *, n_hvg: int = 150, n_pcs: int = 30, n_neighbors: int = 15):
    """Run a standard Scanpy pipeline on a copy of ``adata`` and return the processed copy.

    Steps:
    1. normalise to 1e4 counts per cell, then log1p;
    2. select Seurat-flavour highly variable genes and subset to them;
    3. scale (clipped at 10);
    4. PCA, kNN graph, UMAP.

    The input object is not modified. The ``state`` column (and ``state_label``, when
    present) is converted to categorical for plotting.

    Parameters
    ----------
    adata : AnnData
        Simulated dataset. Assumes ``X`` holds raw counts (TODO confirm for markovmodus output).
    label : str
        Currently unused; kept so existing calls remain valid.
    n_hvg : int
        Number of highly variable genes to keep (capped at the number of genes).
    n_pcs : int
        Requested number of principal components. It is capped below
        ``min(n_cells, n_genes)``, because PCA cannot compute more.
    n_neighbors : int
        Neighbourhood size for the kNN graph.
    """
    adata = adata.copy()
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)
    sc.pp.highly_variable_genes(adata, flavor="seurat", n_top_genes=min(n_hvg, adata.n_vars))
    adata = adata[:, adata.var["highly_variable"]].copy()
    sc.pp.scale(adata, max_value=10)
    # Clamp so small HVG sets or tiny datasets do not make PCA fail.
    n_comps = min(n_pcs, adata.n_obs - 1, adata.n_vars - 1)
    sc.tl.pca(adata, n_comps=n_comps)
    sc.pp.neighbors(adata, n_neighbors=n_neighbors)
    sc.tl.umap(adata)
    adata.obs["state"] = adata.obs["state"].astype("category")
    if "state_label" in adata.obs:
        adata.obs["state_label"] = adata.obs["state_label"].astype("category")
    return adata



In [None]:
# Run the identical Scanpy pipeline on both simulations.
dense_processed, sparse_processed = (
    preprocess(raw_adata, run_label)
    for raw_adata, run_label in (
        (dense_adata, "Dense simulation"),
        (sparse_adata, "Sparse simulation"),
    )
)


In [None]:
rows = [
    ("Dense", dense_params, dense_processed),
    ("Sparse", sparse_params, sparse_processed),
]
# Scanpy embedding plotters paired with the suffix used in their panel titles.
embeddings = [(sc.pl.pca, "PCA"), (sc.pl.umap, "UMAP")]

fig, axes = plt.subplots(len(rows), 3, figsize=(18, 4 * len(rows)))

for row_axes, (label, params, processed) in zip(axes, rows):
    graph_ax, *embedding_axes = row_axes

    # Column 0: the configured transition graph for this simulation.
    plot_transition_graph(
        params.resolve_transition_matrix(),
        f"{label} transition graph",
        ax=graph_ax,
        palette=palette_tab10,
    )

    # Columns 1-2: PCA and UMAP embeddings coloured by the state label.
    for embed_ax, (plot_embedding, method) in zip(embedding_axes, embeddings):
        plot_embedding(
            processed,
            color="state_label",
            show=False,
            ax=embed_ax,
            palette=palette_tab10,
            legend_loc="right margin",
        )
        embed_ax.set_title(f"{label} – {method}")

# Clear every axis label in the grid, including any set by the Scanpy plotters.
for ax in axes.flat:
    ax.set_xlabel("")
    ax.set_ylabel("")

fig.tight_layout()
