<a href="https://colab.research.google.com/github/shizuo-kaji/TutorialTopologicalDataAnalysis/blob/master/PersistentHomology_Interactive.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Interactive Persistent Homology for Vietoris-Rips and Cubical Complexes

This notebook is an interactive explainer of **persistent homology (PH)**. It uses a **Vietoris-Rips (VR) filtration** for point clouds and a **cubical complex (sublevel-set) filtration** for 1D signals and 2D images.

1. See how a VR complex grows as the filtration parameter \(\varepsilon\) increases.
2. Connect geometric features in a point cloud with births/deaths in **H0** (components) and **H1** (loops).
3. Read barcodes and persistence diagrams while tracking a specific filtration value.


In [None]:
!pip install -U cripser ripser

In [None]:
import itertools
from functools import lru_cache
from typing import Dict, List, Tuple

import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import LineCollection
from matplotlib.patches import Circle, Polygon
from ripser import ripser
from scipy.spatial.distance import pdist, squareform

plt.rcParams["figure.dpi"] = 120
plt.rcParams["axes.grid"] = True


## Core helpers

These functions generate point clouds, build VR simplices up to dimension 2, compute PH, and provide plotting utilities.


In [None]:
def make_point_cloud(kind: str, n_points: int, noise: float, seed: int) -> np.ndarray:
    """Generate a reproducible 2D point cloud for the VR explorer.

    Parameters
    ----------
    kind:
        One of ``"Noisy circle"``, ``"Two circles"``, ``"Figure-eight"`` or
        ``"Three clusters"``.
    n_points:
        Total number of points to return.
    noise:
        Standard deviation of Gaussian noise added to each coordinate of the
        parametric curves. Ignored for ``"Three clusters"``, whose spread is
        fixed at 0.18.
    seed:
        Seed for ``np.random.default_rng``; equal arguments give equal clouds.

    Returns
    -------
    np.ndarray
        Float array of shape ``(n_points, 2)``.

    Raises
    ------
    ValueError
        If *kind* is not one of the names above.
    """
    generator = np.random.default_rng(seed)
    full_turn = 2.0 * np.pi

    def unit_circle(count: int) -> np.ndarray:
        # `count` evenly spaced samples on the unit circle (no duplicated endpoint).
        angles = np.linspace(0.0, full_turn, count, endpoint=False)
        return np.column_stack([np.cos(angles), np.sin(angles)])

    if kind == "Noisy circle":
        cloud = unit_circle(n_points)
    elif kind == "Two circles":
        left_count = n_points // 2
        left = unit_circle(left_count) + np.array([-1.2, 0.0])
        right = 0.75 * unit_circle(n_points - left_count) + np.array([1.2, 0.0])
        cloud = np.vstack([left, right])
    elif kind == "Figure-eight":
        angles = np.linspace(0.0, full_turn, n_points, endpoint=False)
        sine = np.sin(angles)
        cloud = np.column_stack([sine, sine * np.cos(angles)])
    elif kind == "Three clusters":
        third = n_points // 3
        sizes = (third, third, n_points - 2 * third)
        centres = np.array([[-1.3, -0.7], [1.2, -0.6], [0.0, 1.2]])
        # Gaussian blobs drawn in order, consuming the generator before any jitter.
        cloud = np.vstack([
            generator.normal(loc=centre, scale=0.18, size=(size, 2))
            for centre, size in zip(centres, sizes)
        ])
    else:
        raise ValueError(f"Unknown dataset kind: {kind}")

    # The clusters already carry their own spread, so the noise slider only perturbs curves.
    if kind != "Three clusters" and noise > 0.0:
        cloud = cloud + generator.normal(scale=noise, size=cloud.shape)

    # A tiny jitter keeps points pairwise distinct even when noise is zero or tiny.
    cloud = cloud + generator.normal(scale=1e-4, size=cloud.shape)
    return cloud.astype(float)


def build_vr_simplices(points: np.ndarray, epsilon: float) -> Tuple[List[Tuple[int, int]], List[Tuple[int, int, int]]]:
    """Return edges and triangles in the VR complex at filtration value epsilon.

    A pair ``(i, j)`` with ``i < j`` is an edge when ``dist(i, j) <= epsilon``.
    A triple ``(i, j, k)`` with ``i < j < k`` is a triangle when all three of its
    pairwise distances are ``<= epsilon``. Both lists are in lexicographic order.

    Triangles are found by intersecting the neighbourhoods of the two endpoints
    of each edge. This avoids scanning every one of the C(n, 3) vertex triples in
    Python, so the cost follows the number of edges (small for small epsilon)
    rather than n**3.

    Parameters
    ----------
    points:
        Point coordinates, one row per point (anything ``pdist`` accepts).
    epsilon:
        Filtration value (distance threshold, inclusive).

    Returns
    -------
    (edges, triangles):
        Lists of index tuples into *points*, with plain Python ints.
    """
    dist = squareform(pdist(points))
    # Strict upper triangle: adjacency[i, j] is True only for i < j within epsilon.
    # This also drops the zero self-distances on the diagonal.
    adjacency = np.triu(dist <= epsilon, k=1)

    # np.nonzero walks row-major, so edges come out sorted lexicographically.
    rows, cols = np.nonzero(adjacency)
    edges: List[Tuple[int, int]] = list(zip(rows.tolist(), cols.tolist()))

    triangles: List[Tuple[int, int, int]] = []
    for i, j in edges:
        # Third vertices k > j adjacent to both i and j. Row i of the upper
        # triangle holds columns > i, so the slice [j + 1:] covers k > j for both rows.
        common = np.nonzero(adjacency[i, j + 1:] & adjacency[j, j + 1:])[0] + (j + 1)
        triangles.extend((i, j, k) for k in common.tolist())

    return edges, triangles


def betti_numbers_at_epsilon(diagrams: List[np.ndarray], epsilon: float, max_dim: int = 1) -> Dict[int, int]:
    """Count, per homology dimension, the persistence intervals alive at *epsilon*.

    An interval ``(birth, death)`` counts as alive when ``birth <= epsilon < death``,
    so a class is already dead at exactly its death value. Only dimensions
    ``0 .. min(max_dim, len(diagrams) - 1)`` appear as keys of the result.
    """
    top = min(max_dim + 1, len(diagrams))
    return {
        dim: sum(1 for birth, death in diagrams[dim] if birth <= epsilon < death)
        for dim in range(top)
    }


def _finite_cap(diagrams: List[np.ndarray], epsilon_max: float = 2.4) -> float:
    """Choose a fixed upper limit for plotting finite/infinite deaths."""
    finite = []
    for dgm in diagrams:
        for _, d in dgm:
            if np.isfinite(d):
                finite.append(float(d))
    if finite:
        cap = 1.15 * max(finite)
    else:
        cap = 1.0
    # Keep axis limits fixed while the epsilon slider moves.
    cap = max(cap, float(epsilon_max) + 0.1)
    return float(cap)


@lru_cache(maxsize=256)
def cached_scene(kind: str, n_points: int, noise: float, seed: int):
    """Return ``(points, diagrams)`` for one dataset setting, memoised by its arguments.

    *points* comes from ``make_point_cloud`` and *diagrams* is ripser's
    ``"dgms"`` list computed with ``maxdim=1`` (H0 and H1). Because the result
    depends only on the dataset arguments, moving the epsilon slider reuses the
    cached entry. Callers should pass a rounded *noise* so equal slider values
    hit the same key.
    """
    cloud = make_point_cloud(kind=kind, n_points=n_points, noise=noise, seed=seed)
    persistence = ripser(cloud, maxdim=1)
    return cloud, persistence["dgms"]


def plot_vr_complex(
    ax: plt.Axes,
    points: np.ndarray,
    epsilon: float,
    edges: List[Tuple[int, int]],
    triangles: List[Tuple[int, int, int]],
    show_balls: bool,
) -> None:
    """Draw the VR complex at a fixed epsilon.

    Layers, back to front: optional balls of radius epsilon/2 around every
    point, filled triangles (2-simplices), edges (1-simplices), then the points.

    Parameters
    ----------
    ax:
        Axes to draw on; cleared first.
    points:
        Array of 2D coordinates, one row per point.
    epsilon:
        Filtration value, used for the ball radius and the title.
    edges:
        Index pairs into *points*.
    triangles:
        Index triples into *points* (may be empty to hide the 2-simplices).
    show_balls:
        Whether to draw the epsilon/2 balls.
    """
    from matplotlib.collections import PolyCollection

    ax.clear()

    if show_balls:
        # Balls of radius epsilon/2 overlap exactly when their centres are within epsilon.
        radius = epsilon / 2.0
        for x, y in points:
            ax.add_patch(Circle((x, y), radius=radius, facecolor="#a9def9", edgecolor="none", alpha=0.22, zorder=0))

    if triangles:
        # One PolyCollection instead of one Polygon patch per triangle. At large
        # epsilon the complex approaches all C(n, 3) triangles, and adding that
        # many individual patches makes each redraw extremely slow.
        verts = points[np.asarray(triangles, dtype=int)]  # (T, 3, 2)
        ax.add_collection(
            PolyCollection(verts, facecolors="#f4a261", edgecolors="none", alpha=0.35, zorder=1)
        )

    if edges:
        segments = [points[[i, j]] for i, j in edges]
        ax.add_collection(LineCollection(segments, colors="#2a9d8f", linewidths=1.25, alpha=0.95, zorder=2))

    ax.scatter(points[:, 0], points[:, 1], s=22, c="#1d3557", edgecolors="white", linewidths=0.45, zorder=3)
    ax.set_aspect("equal")
    ax.set_title(f"Vietoris-Rips complex at ε = {epsilon:.2f}")
    ax.set_xlabel("x")
    ax.set_ylabel("y")

    # Limits are set explicitly from the data so the view does not depend on the collections' autoscaling.
    pad = 0.35
    x_min, x_max = points[:, 0].min() - pad, points[:, 0].max() + pad
    y_min, y_max = points[:, 1].min() - pad, points[:, 1].max() + pad
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)


def plot_barcodes(ax: plt.Axes, diagrams: List[np.ndarray], epsilon: float) -> None:
    """Draw H0/H1 barcodes on *ax* with a dashed vertical marker at *epsilon*.

    Bars are stacked bottom-up, H0 first and then H1. Each group is sorted by
    birth, and the groups are separated by one empty row. Bars that never die
    run to the shared plotting cap and end in an arrow head.
    """
    ax.clear()
    cap = _finite_cap(diagrams)
    palette = {0: "#457b9d", 1: "#e76f51"}

    y = 0
    for dim, dgm in enumerate(diagrams[:2]):
        colour = palette.get(dim, "#333333")
        bars = dgm[np.argsort(dgm[:, 0])] if len(dgm) > 0 else dgm
        for birth, death in bars:
            infinite = not np.isfinite(death)
            end = cap if infinite else float(death)
            ax.plot([birth, end], [y, y], color=colour, linewidth=2.0)
            if infinite:
                ax.plot(end, y, marker=">", color=colour, markersize=5)
            y += 1
        y += 1  # blank separator row between homology dimensions

    ax.axvline(epsilon, linestyle="--", color="black", linewidth=1.3)
    ax.text(0.01, 0.95, "H0 = blue, H1 = orange", transform=ax.transAxes, va="top", fontsize=9)
    ax.set_xlim(0.0, cap * 1.03)
    ax.set_ylim(-1, max(3, y))
    ax.set_xlabel("Filtration value ε")
    ax.set_ylabel("Interval index")
    ax.set_title("Persistence barcodes")


def plot_persistence_diagram(ax: plt.Axes, diagrams: List[np.ndarray], epsilon: float) -> None:
    """Scatter the H0/H1 persistence diagram on *ax* with dotted guides at *epsilon*.

    The dashed diagonal marks birth == death. Infinite deaths are pinned to
    the shared plotting cap, and both axes span ``[0, 1.03 * cap]``.
    """
    ax.clear()
    cap = _finite_cap(diagrams)

    ax.plot([0, cap], [0, cap], linestyle="--", color="grey", linewidth=1.0)

    styles = {0: ("H0", "#457b9d"), 1: ("H1", "#e76f51")}
    for dim, dgm in enumerate(diagrams[:2]):
        if len(dgm) == 0:
            continue
        label, colour = styles.get(dim, (f"H{dim}", "#333333"))
        deaths = np.where(np.isfinite(dgm[:, 1]), dgm[:, 1], cap)
        ax.scatter(dgm[:, 0], deaths, s=28, alpha=0.9, label=label, c=colour)

    for guide in (ax.axhline, ax.axvline):
        guide(epsilon, linestyle=":", color="black", linewidth=1.0)

    limit = cap * 1.03
    ax.set_xlim(0.0, limit)
    ax.set_ylim(0.0, limit)
    ax.set_xlabel("Birth")
    ax.set_ylabel("Death")
    ax.set_title("Persistence diagram")
    ax.legend(loc="lower right", fontsize=8)


## Interactive PH explorer for Vietoris-Rips complexes

At each \(\varepsilon\):

1. Every pair of points at distance at most \(\varepsilon\) is joined by an edge (1-simplex).
2. Every triple of points that are pairwise joined by edges spans a filled triangle (2-simplex).
3. As \(\varepsilon\) grows, components merge (H0 intervals die), and loops appear and are later filled in by triangles (H1 intervals are born and die).

The barcode and persistence diagram summarise these events across all filtration scales, so robust topological structure stands out as long intervals.


In [None]:
def render_scene(
    kind: str,
    n_points: int,
    noise: float,
    seed: int,
    epsilon: float,
    show_balls: bool,
    show_triangles: bool,
) -> None:
    """Draw one frame of the VR explorer: complex, barcodes and persistence diagram.

    The point cloud and its diagrams come from ``cached_scene``; only the
    simplices at *epsilon* are rebuilt on each call. *noise* is rounded to
    4 decimals first so slider values that differ only by float error share a
    cache entry. With *show_triangles* off, triangles are neither drawn nor
    counted in the summary box.
    """
    rounded_noise = float(np.round(noise, 4))
    points, diagrams = cached_scene(kind, int(n_points), rounded_noise, int(seed))

    edges, triangles = build_vr_simplices(points, epsilon)
    drawn_triangles = triangles if show_triangles else []

    betti = betti_numbers_at_epsilon(diagrams, epsilon, max_dim=1)

    fig = plt.figure(figsize=(13.5, 6.0))
    grid = fig.add_gridspec(2, 2, width_ratios=[1.22, 1.0], height_ratios=[1.0, 1.0], wspace=0.3, hspace=0.3)

    ax_complex = fig.add_subplot(grid[:, 0])  # left column, spanning both rows
    ax_barcode = fig.add_subplot(grid[0, 1])
    ax_diag = fig.add_subplot(grid[1, 1])

    plot_vr_complex(ax_complex, points, epsilon, edges, drawn_triangles, show_balls=show_balls)
    plot_barcodes(ax_barcode, diagrams, epsilon)
    plot_persistence_diagram(ax_diag, diagrams, epsilon)

    summary = "   |   ".join([
        f"Betti_0 = {betti.get(0, 0)}",
        f"Betti_1 = {betti.get(1, 0)}",
        f"vertices = {len(points)}, edges = {len(edges)}, triangles = {len(drawn_triangles)}",
    ])
    box_style = {"boxstyle": "round,pad=0.3", "facecolor": "white", "alpha": 0.85, "edgecolor": "#cccccc"}
    ax_complex.text(0.02, 0.98, summary, transform=ax_complex.transAxes, va="top", fontsize=9, bbox=box_style)

    fig.suptitle("Persistent homology of a Vietoris-Rips filtration", fontsize=13)
    plt.show()


# --- Controls for the VR explorer --------------------------------------------
# Each widget feeds the render_scene parameter it is mapped to in the
# widgets.interactive_output call below.
dataset = widgets.Dropdown(
    options=["Noisy circle", "Two circles", "Figure-eight", "Three clusters"],
    value="Noisy circle",
    description="Dataset:",
)

# continuous_update=False: a slider reports its new value only when released,
# so the figure is not rebuilt for every intermediate drag position.
n_points = widgets.IntSlider(value=45, min=15, max=90, step=1, description="Points:", continuous_update=False)
noise = widgets.FloatSlider(value=0.06, min=0.0, max=0.25, step=0.01, description="Noise:", continuous_update=False)
seed = widgets.IntSlider(value=7, min=0, max=200, step=1, description="Seed:", continuous_update=False)
# NOTE: max=2.4 matches the default epsilon_max of _finite_cap, which fixes the
# barcode/diagram axis range; keep the two values in sync.
epsilon = widgets.FloatSlider(value=0.35, min=0.01, max=2.4, step=0.01, description="Epsilon:", continuous_update=False)
show_balls = widgets.Checkbox(value=True, description="Show epsilon/2 balls")
show_triangles = widgets.Checkbox(value=True, description="Fill 2-simplices")

# Lay the controls out in two rows.
ui = widgets.VBox([
    widgets.HBox([dataset, n_points, noise]),
    widgets.HBox([seed, epsilon, show_balls, show_triangles]),
])

# Re-run render_scene with the current widget values whenever one of them
# changes; the figure is captured in the returned output widget.
out = widgets.interactive_output(
    render_scene,
    {
        "kind": dataset,
        "n_points": n_points,
        "noise": noise,
        "seed": seed,
        "epsilon": epsilon,
        "show_balls": show_balls,
        "show_triangles": show_triangles,
    },
)

# NOTE(review): `display` is not imported in this cell; this presumably relies on
# the builtin provided by Jupyter/IPython and would fail as a plain script.
display(ui, out)


## Interactive PH explorer for Cubical complexes

We use a **sublevel-set filtration**: cells with value \(\le t\) are active at threshold \(t\).

- In **1D**, this is a filtered signal.
- In **2D**, this is a filtered image/height map.


In [None]:
import cripser


def _normalise01(arr: np.ndarray) -> np.ndarray:
    """Scale an array to [0, 1] for consistent threshold controls."""
    arr = np.asarray(arr, dtype=float)
    lo, hi = float(arr.min()), float(arr.max())
    if hi <= lo:
        return np.zeros_like(arr)
    return (arr - lo) / (hi - lo)


def make_cubical_array(kind: str) -> np.ndarray:
    """Build one of the small toy arrays used by the cubical PH explorer.

    Every array is min-max normalised to [0, 1] so a single threshold slider
    covers all of them.

    * ``"1D: two valleys"``: 16 hand-written samples with two local minima.
    * ``"1D: wavy signal"``: 20 samples of a sum of a sine and a cosine.
    * ``"2D: two basins"`` / ``"2D: ring valley"``: 14 x 14 height maps on [-1, 1]^2.

    Raises
    ------
    ValueError
        If *kind* is not one of the names above.
    """
    if kind == "1D: two valleys":
        samples = np.array([0.94, 0.78, 0.60, 0.38, 0.20, 0.35, 0.62, 0.86,
                            0.91, 0.66, 0.33, 0.18, 0.31, 0.58, 0.81, 0.95])
    elif kind == "1D: wavy signal":
        t = np.linspace(-1.0, 1.0, 20)
        samples = 0.50 + 0.30 * np.sin(2.7 * np.pi * t) + 0.15 * np.cos(5.5 * np.pi * t)
    elif kind in ("2D: two basins", "2D: ring valley"):
        gy, gx = np.mgrid[-1.0:1.0:14j, -1.0:1.0:14j]
        if kind == "2D: two basins":
            # Two paraboloid wells, one on each side, plus a gentle bowl.
            left = (gx + 0.48) ** 2 + (gy - 0.05) ** 2
            right = (gx - 0.52) ** 2 + (gy + 0.02) ** 2
            samples = np.minimum(left, right) + 0.08 * (gx ** 2 + gy ** 2)
        else:
            # Lowest along the circle of radius ~0.56, so the sublevel set forms a ring.
            radius = np.sqrt(gx ** 2 + gy ** 2)
            samples = (radius - 0.56) ** 2 + 0.10 * radius
    else:
        raise ValueError(f"Unknown cubical dataset: {kind}")

    return _normalise01(samples)


@lru_cache(maxsize=32)
def cached_cubical_scene(kind: str, maxdim: int = 1):
    """Build the toy array for *kind* and compute its cubical persistence (cached).

    Returns
    -------
    arr:
        The normalised array from ``make_cubical_array``.
    diagrams:
        List of ``maxdim + 1`` float arrays, each with shape ``(k, 2)``, holding
        ``(birth, death)`` pairs for H0, H1, ...; intervals that never die have
        ``death == np.inf``.
    rows:
        The raw ``cripser.compute_ph`` output. Column 0 is the dimension and
        columns 1-2 are birth/death; the remaining columns presumably hold
        cell locations (``location="yes"``).
    """
    arr = make_cubical_array(kind)
    rows = cripser.compute_ph(arr, filtration="V", maxdim=maxdim, location="yes")

    if rows.size == 0:
        return arr, [np.empty((0, 2), dtype=float) for _ in range(maxdim + 1)], rows

    # In a sublevel filtration every finite death equals the value of some
    # array cell, so a death above arr.max() can only mark a class that never
    # dies. NOTE(review): some cripser versions appear to report such essential
    # classes with a huge finite value rather than inf. Mapping them to inf
    # keeps the np.isfinite checks in the plotting and cap helpers working.
    top = float(arr.max())
    dims = rows[:, 0].astype(int)

    diagrams: List[np.ndarray] = []
    for dim in range(maxdim + 1):
        # np.array(..., dtype=float) copies, so the cached raw rows are never mutated.
        block = np.array(rows[dims == dim][:, 1:3], dtype=float)
        if block.size == 0:
            block = np.empty((0, 2), dtype=float)
        else:
            block[block[:, 1] > top, 1] = np.inf
        diagrams.append(block)

    return arr, diagrams, rows


def _cubical_cap(diagrams: List[np.ndarray], threshold_max: float = 1.0) -> float:
    """Fixed plotting cap so axes stay stable when threshold slider moves."""
    finite = []
    for dgm in diagrams:
        for _, death in dgm:
            if np.isfinite(death):
                finite.append(float(death))
    if finite:
        cap = 1.15 * max(finite)
    else:
        cap = threshold_max + 0.1
    cap = max(cap, threshold_max + 0.1)
    return float(cap)


def plot_cubical_field(ax_val: plt.Axes, ax_mask: plt.Axes, arr: np.ndarray, threshold: float) -> None:
    """Plot scalar values and current sublevel-set mask for 1D or 2D arrays.

    Parameters
    ----------
    ax_val:
        Axes for the values. For 1D arrays this is a line plot with
        active/inactive markers and the threshold line; for 2D arrays it is an
        image with the ``value == threshold`` contour. Cleared first.
    ax_mask:
        Axes for the binary sublevel-set mask ``arr <= threshold``. Cleared first.
    arr:
        1D or 2D array. The fixed colour and y ranges assume values in [0, 1].
    threshold:
        Current filtration value t.

    Raises
    ------
    ValueError
        If *arr* is neither 1D nor 2D.
    """
    ax_val.clear()
    ax_mask.clear()

    if arr.ndim == 1:
        x = np.arange(arr.shape[0])
        active = arr <= threshold

        ax_val.plot(x, arr, color="#1d3557", linewidth=1.8)
        ax_val.scatter(x[~active], arr[~active], c="#8d99ae", s=35, label="inactive")
        ax_val.scatter(x[active], arr[active], c="#2a9d8f", s=42, label="active")
        ax_val.axhline(threshold, linestyle="--", color="black", linewidth=1.2)
        ax_val.set_title("1D scalar signal")
        ax_val.set_xlabel("Index")
        ax_val.set_ylabel("Value")
        ax_val.set_ylim(-0.05, 1.05)
        ax_val.legend(loc="upper right", fontsize=8)

        # Show the 1D mask as a one-row image.
        mask_img = active.astype(float)[None, :]
        ax_mask.imshow(mask_img, cmap="Greens", vmin=0.0, vmax=1.0, aspect="auto", origin="lower")
        ax_mask.set_title("Sublevel set (value ≤ t)")
        ax_mask.set_xlabel("Index")
        ax_mask.set_yticks([])

    elif arr.ndim == 2:
        mask = arr <= threshold

        ax_val.imshow(arr, cmap="viridis", vmin=0.0, vmax=1.0, origin="lower")
        # Only draw the level set when the threshold lies strictly inside the
        # data range. Otherwise there is no level set to show, and some
        # matplotlib versions warn and fall back to contouring at the data
        # minimum, which would draw a misleading curve (e.g. at t = 1.0).
        if float(arr.min()) < threshold < float(arr.max()):
            ax_val.contour(arr, levels=[threshold], colors="white", linewidths=1.1)
        ax_val.set_title("2D scalar field")
        ax_val.set_xlabel("x")
        ax_val.set_ylabel("y")

        ax_mask.imshow(mask.astype(float), cmap="Greens", vmin=0.0, vmax=1.0, origin="lower")
        ax_mask.set_title("Sublevel set (value ≤ t)")
        ax_mask.set_xlabel("x")
        ax_mask.set_ylabel("y")

    else:
        raise ValueError("Only 1D and 2D arrays are supported in this demo.")


def plot_cubical_barcodes(ax: plt.Axes, diagrams: List[np.ndarray], threshold: float) -> None:
    """Draw cubical H0/H1 barcodes on *ax* with a dashed marker at *threshold*.

    Bars are stacked bottom-up, H0 first then H1. Each group is sorted by
    birth and followed by one blank row. Bars that never die run to the fixed
    cap from ``_cubical_cap`` and end in an arrow head.
    """
    ax.clear()
    cap = _cubical_cap(diagrams, threshold_max=1.0)
    palette = {0: "#457b9d", 1: "#e76f51"}

    y = 0
    for dim, dgm in enumerate(diagrams[:2]):
        colour = palette.get(dim, "#333333")
        bars = dgm[np.argsort(dgm[:, 0])] if len(dgm) > 0 else dgm
        for birth, death in bars:
            infinite = not np.isfinite(death)
            end = cap if infinite else float(death)
            ax.plot([birth, end], [y, y], color=colour, linewidth=2.0)
            if infinite:
                ax.plot(end, y, marker=">", color=colour, markersize=5)
            y += 1
        y += 1  # blank separator row between homology dimensions

    ax.axvline(threshold, linestyle="--", color="black", linewidth=1.3)
    ax.text(0.01, 0.95, "H0 = blue, H1 = orange", transform=ax.transAxes, va="top", fontsize=9)
    ax.set_xlim(0.0, cap * 1.03)
    ax.set_ylim(-1, max(3, y))
    ax.set_xlabel("Filtration threshold t")
    ax.set_ylabel("Interval index")
    ax.set_title("Cubical persistence barcodes")


def plot_cubical_diagram(ax: plt.Axes, diagrams: List[np.ndarray], threshold: float) -> None:
    """Scatter the cubical H0/H1 persistence diagram with dotted guides at *threshold*.

    The dashed diagonal marks birth == death. Infinite deaths are pinned to
    the cap from ``_cubical_cap``, and both axes span ``[0, 1.03 * cap]``.
    """
    ax.clear()
    cap = _cubical_cap(diagrams, threshold_max=1.0)

    ax.plot([0, cap], [0, cap], linestyle="--", color="grey", linewidth=1.0)
    styles = {0: ("H0", "#457b9d"), 1: ("H1", "#e76f51")}

    for dim, dgm in enumerate(diagrams[:2]):
        if len(dgm) == 0:
            continue
        label, colour = styles.get(dim, (f"H{dim}", "#333333"))
        deaths = np.where(np.isfinite(dgm[:, 1]), dgm[:, 1], cap)
        ax.scatter(dgm[:, 0], deaths, s=28, alpha=0.9, label=label, c=colour)

    for guide in (ax.axvline, ax.axhline):
        guide(threshold, linestyle=":", color="black", linewidth=1.0)

    limit = cap * 1.03
    ax.set_xlim(0.0, limit)
    ax.set_ylim(0.0, limit)
    ax.set_xlabel("Birth")
    ax.set_ylabel("Death")
    ax.set_title("Cubical persistence diagram")
    ax.legend(loc="lower right", fontsize=8)


In [None]:
def render_cubical_scene(kind: str, threshold: float) -> None:
    """Draw one frame of the cubical explorer for array *kind* at *threshold*.

    Left column: the values (top) and the sublevel-set mask (bottom), with a
    box giving Betti numbers and the active-cell count. Right column: the
    barcodes (top) and the persistence diagram (bottom).
    """
    arr, diagrams, _raw = cached_cubical_scene(kind, maxdim=1)
    betti = betti_numbers_at_epsilon(diagrams, threshold, max_dim=1)
    n_active = int(np.count_nonzero(arr <= threshold))

    fig = plt.figure(figsize=(13.5, 6.4))
    grid = fig.add_gridspec(2, 2, width_ratios=[1.22, 1.0], height_ratios=[1.0, 1.0], wspace=0.3, hspace=0.3)

    ax_val, ax_mask, ax_bar, ax_diag = (
        fig.add_subplot(grid[r, c]) for r, c in ((0, 0), (1, 0), (0, 1), (1, 1))
    )

    plot_cubical_field(ax_val, ax_mask, arr, threshold)
    plot_cubical_barcodes(ax_bar, diagrams, threshold)
    plot_cubical_diagram(ax_diag, diagrams, threshold)

    summary = "   |   ".join([
        f"Betti_0 = {betti.get(0, 0)}",
        f"Betti_1 = {betti.get(1, 0)}",
        f"active cells = {n_active}/{arr.size}",
    ])
    box_style = {"boxstyle": "round,pad=0.3", "facecolor": "white", "alpha": 0.85, "edgecolor": "#cccccc"}
    ax_mask.text(0.02, 0.98, summary, transform=ax_mask.transAxes, va="top", fontsize=9, bbox=box_style)

    fig.suptitle(f"Cubical complex PH (sublevel filtration): {kind}", fontsize=13)
    plt.show()


# --- Controls for the cubical explorer ---------------------------------------
# The dropdown options must match the names handled by make_cubical_array.
cubical_dataset = widgets.Dropdown(
    options=["1D: two valleys", "1D: wavy signal", "2D: two basins", "2D: ring valley"],
    value="2D: ring valley",
    description="Array:",
)

# Range [0, 1] because make_cubical_array normalises every array to [0, 1].
# continuous_update=False: the value is reported only when the slider is released.
threshold = widgets.FloatSlider(
    value=0.35,
    min=0.0,
    max=1.0,
    step=0.01,
    description="Threshold:",
    continuous_update=False,
)

cubical_ui = widgets.VBox([widgets.HBox([cubical_dataset, threshold])])

# Re-run render_cubical_scene whenever the array or the threshold changes.
cubical_out = widgets.interactive_output(
    render_cubical_scene,
    {
        "kind": cubical_dataset,
        "threshold": threshold,
    },
)

# NOTE(review): `display` is not imported in this cell; this presumably relies on
# the builtin provided by Jupyter/IPython.
display(cubical_ui, cubical_out)
