# 02 · CRT Demo (LV/RV endocardium → Trees → Scenarios → Exports → Viewer)

This notebook builds **separate Purkinje trees** on the bundled CRT demo meshes:
- `data/crtdemo/crtdemo_LVendo_heart_cut.obj`
- `data/crtdemo/crtdemo_RVendo_heart_cut.obj`

Then it runs a few **pacing scenarios** (baseline, RV-only, BiV with VV-delay), maps activation back to each **surface**, and exports everything for ParaView. The final cell renders an **inline Plotly** 3D view that reliably works in Colab.

**Outputs** (under `output/examples/02_crt_demo/`):
- LV/RV trees (`*_tree_AT.vtu`) and PMJs (`*_pmj.vtu`)
- Merged BiV tree per scenario (`biv_tree_AT_<scenario>.vtu`)
- Surface activation for LV/RV per scenario (`*_surface_AT_<scenario>.vtp`)
- Parameter snapshots (`params_lv.json`, `params_rv.json`)

In [None]:
%pip install -q --upgrade pip purkinje-uv plotly

In [None]:
from pathlib import Path
import os
import json
import numpy as np

from purkinje_uv import FractalTreeParameters, FractalTree, PurkinjeTree
import pyvista as pv  # IO + mesh ops
import plotly.graph_objects as go

# Run-time knobs; each one can be overridden through an environment variable.
SEED = int(os.environ.get("EXAMPLES_SEED", "1234"))
LITE = int(os.environ.get("EXAMPLES_LITE", "1")) != 0      # fast by default
VV_DELAY_MS = int(os.environ.get("CRT_VV_DELAY_MS", "0"))  # BiV LV delay vs RV (ms); try -40, 0, +40 later

# Input meshes live under DATA_DIR; every export of this notebook goes to OUT_DIR.
DATA_DIR = Path("data", "crtdemo")
OUT_DIR = Path("output", "examples", "02_crt_demo")
os.makedirs(OUT_DIR, exist_ok=True)

# Seed NumPy's global RNG, then echo the effective configuration.
np.random.seed(SEED)
print(f"SEED={SEED}  LITE={LITE}  VV_DELAY_MS={VV_DELAY_MS}")
print("DATA_DIR:", DATA_DIR)
print("OUT_DIR:", OUT_DIR)

In [None]:
# Ensure demo OBJ files exist (download from repo if missing)
import urllib.request

# Make sure the data folder exists before anything tries to write into it,
# then name the two endocardial meshes the rest of the notebook works on.
DATA_DIR.mkdir(parents=True, exist_ok=True)
LV_OBJ = DATA_DIR.joinpath("crtdemo_LVendo_heart_cut.obj")
RV_OBJ = DATA_DIR.joinpath("crtdemo_RVendo_heart_cut.obj")

def _try_fetch(url: str, dst: Path) -> bool:
    """Download *url* to *dst*; return True on success and False on any failure.

    The payload is staged in a sibling ``<name>.part`` file and only renamed
    onto *dst* once it has been written completely. A transfer that dies
    half-way therefore never leaves a truncated file at *dst* that a later
    ``dst.exists()`` check would mistake for a valid mesh.

    Failures are reported on stdout instead of being raised, because the
    caller walks a list of mirror URLs and simply moves on to the next one.
    """
    print("Downloading:", url)
    staging = dst.with_name(dst.name + ".part")
    try:
        with urllib.request.urlopen(url, timeout=30) as response:
            payload = response.read()
        staging.write_bytes(payload)
        staging.replace(dst)
    except Exception as e:  # deliberate best-effort: any error just means "try the next mirror"
        print("  failed:", e)
        try:
            staging.unlink()
        except OSError:
            pass  # nothing was staged (or it cannot be removed); dst is untouched either way
        return False
    return True

# Branches of the upstream repository, tried in this order for each missing file.
_CRT_DEMO_MIRRORS = (
    "https://raw.githubusercontent.com/ricardogr07/purkinje-uv/main/data/crtdemo/",
    "https://raw.githubusercontent.com/ricardogr07/purkinje-uv/master/data/crtdemo/",
    "https://raw.githubusercontent.com/ricardogr07/purkinje-uv/dev/data/crtdemo/",
)

for _obj in (LV_OBJ, RV_OBJ):
    # A zero-byte file is what an interrupted download can leave behind,
    # so it is treated like a missing file and fetched again.
    if _obj.is_file() and _obj.stat().st_size > 0:
        continue
    for _base in _CRT_DEMO_MIRRORS:
        if _try_fetch(_base + _obj.name, _obj):
            break

# Raise explicitly instead of asserting: asserts are stripped under `python -O`,
# which would let the notebook continue without its input meshes.
for _label, _obj in (("LV", LV_OBJ), ("RV", RV_OBJ)):
    if not (_obj.is_file() and _obj.stat().st_size > 0):
        raise FileNotFoundError(f"Missing {_label} OBJ: {_obj}")
print("LV_OBJ:", LV_OBJ)
print("RV_OBJ:", RV_OBJ)

# Load with PyVista and verify they're open surfaces
def load_surface(path: Path) -> pv.PolyData:
    """Read the mesh at *path* and return it as a cleaned, triangulated PolyData.

    Whatever ``pv.read`` returns that is not already a ``pv.PolyData`` is
    converted with ``extract_geometry()`` before cleaning.
    """
    loaded = pv.read(str(path))
    surface = loaded if isinstance(loaded, pv.PolyData) else loaded.extract_geometry()
    return surface.clean().triangulate()

# Endocardial surfaces (LV first, then RV) used by every later cell.
surf_lv, surf_rv = load_surface(LV_OBJ), load_surface(RV_OBJ)

def is_open_surface(pd: pv.PolyData) -> bool:
    """Return True when *pd* has at least one boundary or non-manifold edge."""
    edge_selection = dict(
        boundary_edges=True,
        non_manifold_edges=True,
        feature_edges=False,
        manifold_edges=False,
    )
    return pd.extract_feature_edges(**edge_selection).n_cells > 0

# One summary line per chamber: vertex count, face count, open/closed flag.
for _label, _surf in (("LV", surf_lv), ("RV", surf_rv)):
    print(f"{_label}: points={_surf.n_points} faces={_surf.n_cells} open={is_open_surface(_surf)}")

## Seeding & Parameters

We auto-pick seeds per chamber:
- **Root (`init_node_id`)**: vertex nearest to the **boundary center** (mean of the boundary/non-manifold edge points; falls back to the mesh centroid when the surface has no such edges).
- **Direction (`second_node_id`)**: vertex **farthest** from the root (good initial heading).

Two parameter presets are provided (LITE vs FULL). You can tweak them later to change coverage or density.

In [None]:
# ---------- Utilities ----------
def boundary_center(pd: pv.PolyData) -> np.ndarray:
    """Return the mean position of the boundary / non-manifold edge points of *pd*.

    If the edge extraction yields no points, the mean of all mesh points is
    returned instead.
    """
    rim = pd.extract_feature_edges(
        boundary_edges=True,
        feature_edges=False,
        manifold_edges=False,
        non_manifold_edges=True,
    )
    reference = rim if rim.n_points != 0 else pd
    return reference.points.mean(0)

def auto_seeds(pd: pv.PolyData):
    """Pick ``(root, direction)`` vertex indices on *pd*.

    ``root`` is the vertex closest to ``boundary_center(pd)``; ``direction``
    is the vertex with the largest Euclidean distance from that root.
    """
    vertices = pd.points
    center = boundary_center(pd)
    dist_to_center = np.linalg.norm(vertices - center, axis=1)
    root = int(dist_to_center.argmin())
    dist_to_root = np.linalg.norm(vertices - vertices[root], axis=1)
    return root, int(dist_to_root.argmax())

def params_preset(meshfile: str, init_id: int, second_id: int, chamber: str = "LV"):
    """Build the ``FractalTreeParameters`` preset for one chamber.

    The preset is selected by the module-level ``LITE`` flag and by whether
    ``chamber`` (case-insensitive) equals ``"LV"``; any other value gets the
    RV preset.

    Args:
        meshfile: Path of the surface mesh, passed through as ``meshfile``.
        init_id: Passed through as ``init_node_id``.
        second_id: Passed through as ``second_node_id``.
        chamber: ``"LV"`` or anything else (treated as RV).

    NOTE(review): the fascicle length entries are multiplied by ``l_segment``
    while ``init_length`` and ``length`` are passed unscaled — confirm this
    is the intended unit convention.
    """
    is_lv = chamber.upper() == "LV"
    # Columns: l_segment, N_it, init_length, length, branch_angle, w,
    #          fascicle angles, fascicle length factors
    # Rows are keyed by (LITE, is_lv): LITE keeps runtime modest, FULL is heavier coverage.
    table = {
        (True, True): (0.01, 9, 0.34, 0.18, 0.24, 0.09,
                       [-0.6, -0.2, 0.2, 0.6],
                       [0.20, 0.30, 0.30, 0.20]),
        (True, False): (0.01, 8, 0.32, 0.16, 0.22, 0.09,
                        [-0.5, 0.5],
                        [0.25, 0.25]),
        (False, True): (0.008, 11, 0.36, 0.20, 0.26, 0.08,
                        [-0.7, -0.35, 0.0, 0.35, 0.7],
                        [0.25, 0.35, 0.40, 0.35, 0.25]),
        (False, False): (0.008, 10, 0.34, 0.18, 0.24, 0.08,
                         [-0.6, 0.0, 0.6],
                         [0.25, 0.30, 0.25]),
    }
    (seg, n_it, first_len, branch_len, angle, w_val,
     fasc_angles, fasc_factors) = table[(bool(LITE), is_lv)]

    return FractalTreeParameters(
        meshfile=meshfile,
        l_segment=seg,
        init_length=first_len,
        length=branch_len,
        branch_angle=angle,
        w=w_val,
        init_node_id=init_id,
        second_node_id=second_id,
        N_it=n_it,
        fascicles_angles=fasc_angles,
        fascicles_length=[factor * seg for factor in fasc_factors],
    )

def grow_tree(params: FractalTreeParameters):
    """Create a ``FractalTree`` from *params*, run ``grow_tree()`` on it and return it."""
    tree = FractalTree(params=params)
    tree.grow_tree()
    return tree

def to_arrays(ft: FractalTree):
    """Return ``(nodes, edges, pmj)`` from *ft* as NumPy arrays.

    ``nodes_xyz`` is converted to float; ``connectivity`` and ``end_nodes``
    are converted to int.
    """
    return (
        np.asarray(ft.nodes_xyz, dtype=float),
        np.asarray(ft.connectivity, dtype=int),
        np.asarray(ft.end_nodes, dtype=int),
    )

def merge_trees(nodes_a, edges_a, pmj_a, nodes_b, edges_b, pmj_b):
    """Stack tree B after tree A and return the combined ``(nodes, edges, pmj)``.

    Node indices stored in ``edges_b`` and ``pmj_b`` are shifted by the
    number of nodes in A so they keep pointing at B's nodes. No edge is
    added between the two trees.
    """
    shift = len(nodes_a)
    merged_nodes = np.vstack((nodes_a, nodes_b))
    merged_edges = np.vstack((edges_a, edges_b + shift))
    merged_pmj = np.concatenate((pmj_a, pmj_b + shift))
    return merged_nodes, merged_edges, merged_pmj

def build_P(nodes, edges, pmj) -> PurkinjeTree:
    """Construct a ``PurkinjeTree`` from node coordinates, edges and PMJ indices.

    Inputs are coerced to NumPy arrays first (float coordinates, int indices).
    """
    coords = np.asarray(nodes, dtype=float)
    links = np.asarray(edges, dtype=int)
    terminals = np.asarray(pmj, dtype=int)
    return PurkinjeTree(nodes=coords, connectivity=links, end_nodes=terminals)

def solve_activation(P: PurkinjeTree, sources):
    """Run ``P.activate_fim`` for the given stimulation sources.

    Args:
        P: Tree to activate.
        sources: list of ``(node_index, start_time_ms)`` entries.

    Returns:
        Whatever ``activate_fim`` returns when called with
        ``return_only_pmj=False``.
    """
    node_ids = [entry[0] for entry in sources]
    onsets = [float(entry[1]) for entry in sources]
    return P.activate_fim(
        x0=np.array(node_ids, dtype=int),
        x0_vals=np.array(onsets, dtype=float),
        return_only_pmj=False,
    )

def nearest_map_activation(surface: pv.PolyData, nodes_xyz: np.ndarray, AT: np.ndarray, chunk=2000):
    """Copy *surface* and give every vertex the AT of its nearest tree node.

    Distances are evaluated brute-force, ``chunk`` surface vertices at a
    time, so the temporary distance matrix stays at (chunk, n_nodes).

    Returns:
        A deep copy of *surface* whose point data holds a single array,
        ``"AT"``; the input surface is not modified.
    """
    verts = surface.points
    total = verts.shape[0]
    mapped = np.empty(total, dtype=float)
    for start in range(0, total, chunk):
        stop = min(start + chunk, total)
        # (batch, 1, 3) - (1, nodes, 3) -> (batch, nodes, 3)
        delta = verts[start:stop][:, None, :] - nodes_xyz[None, :, :]
        sq_dist = np.sum(delta ** 2, axis=2)
        mapped[start:stop] = AT[np.argmin(sq_dist, axis=1)]

    result = surface.copy(deep=True)
    result.point_data.clear()
    result.point_data["AT"] = mapped
    return result

def save_surface_vtp(surf: pv.PolyData, path: Path):
    """Save *surf* to *path*, creating the parent directory first if needed."""
    folder = path.parent
    folder.mkdir(parents=True, exist_ok=True)
    surf.save(str(path))
# --------------------------------------

# Root / direction vertices for each chamber
(init_lv, second_lv), (init_rv, second_rv) = auto_seeds(surf_lv), auto_seeds(surf_rv)
print("Seeds (LV):", init_lv, second_lv)
print("Seeds (RV):", init_rv, second_rv)

# Chamber-specific growth parameters
params_lv = params_preset(meshfile=str(LV_OBJ), init_id=init_lv, second_id=second_lv, chamber="LV")
params_rv = params_preset(meshfile=str(RV_OBJ), init_id=init_rv, second_id=second_rv, chamber="RV")

# Snapshot both parameter sets next to the other outputs
params_lv.to_json_file(OUT_DIR / "params_lv.json")
params_rv.to_json_file(OUT_DIR / "params_rv.json")
print("Saved params:", OUT_DIR / "params_lv.json", "|", OUT_DIR / "params_rv.json")

In [None]:
# Grow one fractal tree per chamber (LV first, then RV) ...
ft_lv, ft_rv = grow_tree(params_lv), grow_tree(params_rv)

# ... and pull their geometry out as plain NumPy arrays
(nodes_lv, edges_lv, pmj_lv), (nodes_rv, edges_rv, pmj_rv) = to_arrays(ft_lv), to_arrays(ft_rv)

print(f"LV: nodes={nodes_lv.shape[0]} edges={edges_lv.shape[0]} pmj={pmj_lv.shape[0]}")
print(f"RV: nodes={nodes_rv.shape[0]} edges={edges_rv.shape[0]} pmj={pmj_rv.shape[0]}")

In [None]:
# Wrap each chamber's arrays in its own PurkinjeTree
P_lv, P_rv = build_P(nodes_lv, edges_lv, pmj_lv), build_P(nodes_rv, edges_rv, pmj_rv)

# Sanity check: stimulate node 0 of each tree at t=0 and report the AT range
AT_lv_baseline = solve_activation(P_lv, sources=[(0, 0.0)])
AT_rv_baseline = solve_activation(P_rv, sources=[(0, 0.0)])
print("Baseline sanity — LV AT(min/max):", float(AT_lv_baseline.min()), float(AT_lv_baseline.max()))
print("Baseline sanity — RV AT(min/max):", float(AT_rv_baseline.min()), float(AT_rv_baseline.max()))

# Output locations for the per-chamber trees and their PMJs
lv_tree_path, rv_tree_path = OUT_DIR / "lv_tree_AT.vtu", OUT_DIR / "rv_tree_AT.vtu"
lv_pmj_path, rv_pmj_path = OUT_DIR / "lv_pmj.vtu", OUT_DIR / "rv_pmj.vtu"

# Saved only after the activation call above, LV before RV
P_lv.save(str(lv_tree_path))
P_lv.save_pmjs(str(lv_pmj_path))
P_rv.save(str(rv_tree_path))
P_rv.save_pmjs(str(rv_pmj_path))

print("Wrote:")
for _written in (lv_tree_path, rv_tree_path, lv_pmj_path, rv_pmj_path):
    print(" -", _written)

## Merge Trees & Target Pacing Sites

We merge LV+RV into a single **disconnected** tree for BiV scenarios (FIM supports multiple sources).  
We then pick:
- **RV apex**: extreme along the first principal component (PC1) of the RV surface.
- **LV lateral**: extreme along the second principal component (PC2) of the LV surface.

Finally, we locate the **nearest tree node** to each target point.

In [None]:
# Concatenate LV and RV into one tree; RV indices are shifted, no edge joins the two
nodes_biv, edges_biv, pmj_biv = merge_trees(
    nodes_a=nodes_lv, edges_a=edges_lv, pmj_a=pmj_lv,
    nodes_b=nodes_rv, edges_b=edges_rv, pmj_b=pmj_rv,
)
P_biv = build_P(nodes=nodes_biv, edges=edges_biv, pmj=pmj_biv)

# Helpers for PCA-based targets
def pca_axes(pts: np.ndarray):
    """Return the principal axes of *pts* as the rows of a matrix.

    The points are mean-centred and decomposed with a reduced SVD; the
    returned array is the ``Vt`` factor, so row 0 is the first principal axis.
    The sign of each axis is whatever the SVD happens to produce.
    """
    centered = pts - pts.mean(axis=0, keepdims=True)
    # Index 2 of the (U, S, Vt) result: rows of Vt are the principal axes
    return np.linalg.svd(centered, full_matrices=False)[2]

def extreme_point_along(pc: np.ndarray, pts: np.ndarray, which: str = "min"):
    """Return the point of *pts* with the extreme projection onto axis *pc*.

    ``which="min"`` selects the smallest projection; any other value selects
    the largest.
    """
    scores = pts @ pc
    pick = np.argmin if which == "min" else np.argmax
    return pts[pick(scores)]

def nearest_node_idx(nodes: np.ndarray, point: np.ndarray) -> int:
    """Return the row index of the node in *nodes* closest to *point*."""
    offsets = nodes - point[None, :]
    sq_dist = np.sum(offsets ** 2, axis=1)
    return int(np.argmin(sq_dist))

# Principal axes of each endocardial surface
Vt_rv = pca_axes(surf_rv.points)
Vt_lv = pca_axes(surf_lv.points)

# RV apex: surface point with the smallest projection on the RV's PC1,
# then the closest node of the RV tree (index in RV numbering)
rv_apex_pt = extreme_point_along(pc=Vt_rv[0], pts=surf_rv.points, which="min")
rv_apex_idx_rv = nearest_node_idx(nodes=nodes_rv, point=rv_apex_pt)

# LV lateral: surface point with the largest projection on the LV's PC2,
# then the closest node of the LV tree (index in LV numbering)
lv_lat_pt = extreme_point_along(pc=Vt_lv[1], pts=surf_lv.points, which="max")
lv_lat_idx_lv = nearest_node_idx(nodes=nodes_lv, point=lv_lat_pt)

# Translate to merged-tree numbering: LV nodes come first, RV nodes follow
offset_rv = len(nodes_lv)
rv_apex_idx_biv = offset_rv + rv_apex_idx_rv
lv_lat_idx_biv = lv_lat_idx_lv

print("Lead targets → RV apex (rv idx):", rv_apex_idx_rv, " | LV lateral (lv idx):", lv_lat_idx_lv)
print("Merged indices  → RV apex:", rv_apex_idx_biv, " | LV lateral:", lv_lat_idx_biv)

In [None]:
# Stimulation sources per scenario, as (merged-tree node index, start time):
# - baseline: node 0 of each chamber tree at t=0 (the RV's node 0 sits right after the LV nodes)
# - rv_only:  RV apex at t=0
# - biv:      RV apex at t=0 plus LV lateral at VV_DELAY_MS / 1000
# NOTE(review): VV_DELAY_MS is divided by 1000 here while solve_activation documents
# start times in ms — adjust units if needed.
src_baseline = [(0, 0.0), (nodes_lv.shape[0], 0.0)]
src_rv = [(rv_apex_idx_biv, 0.0)]
src_biv = [(rv_apex_idx_biv, 0.0), (lv_lat_idx_biv, float(VV_DELAY_MS) / 1000.0)]

# Solve the three scenarios on the merged tree, in this order
scenarios = {
    label: solve_activation(P_biv, stim)
    for label, stim in (("baseline", src_baseline), ("rv_only", src_rv), ("biv", src_biv))
}

for name, AT in scenarios.items():
    print(f"{name}: AT min/max = {float(AT.min()):.4f} / {float(AT.max()):.4f}")

In [None]:
# Map activations back to each surface and export
def export_surface_AT(name: str, AT_biv: np.ndarray):
    """Map one scenario's merged-tree AT onto the LV and RV surfaces and save both.

    ``AT_biv`` is split at the LV node count: the leading part belongs to the
    LV tree, the remainder to the RV tree. Each part is transferred to its
    surface by nearest-node lookup and written to
    ``OUT_DIR/<lv|rv>_surface_AT_<name>.vtp``.
    """
    lv_count = nodes_lv.shape[0]
    mapped_lv = nearest_map_activation(surf_lv, nodes_lv, AT_biv[:lv_count])
    mapped_rv = nearest_map_activation(surf_rv, nodes_rv, AT_biv[lv_count:])

    targets = (
        (mapped_lv, OUT_DIR / f"lv_surface_AT_{name}.vtp"),
        (mapped_rv, OUT_DIR / f"rv_surface_AT_{name}.vtp"),
    )
    for mapped, target in targets:
        save_surface_vtp(mapped, target)
    print("Saved surfaces:", targets[0][1].name, "|", targets[1][1].name)

# One LV/RV surface pair per scenario
for scen, AT in scenarios.items():
    export_surface_AT(name=scen, AT_biv=AT)

In [None]:
# Save merged tree per scenario (useful for ParaView)
def save_tree_with_AT(nodes, edges, AT, name: str):
    """Write the tree geometry plus one scenario's activation times to a ``.vtu``.

    The previous implementation built a fresh ``PurkinjeTree`` and saved it
    without ever touching ``AT``, so every scenario file came out the same
    and carried none of that scenario's activation. The tree is now written
    as an unstructured grid of line cells with ``AT`` attached as a point
    array, under the same name the surface exports use (``"AT"``).

    Args:
        nodes: (n_nodes, 3) node coordinates.
        edges: (n_edges, 2) node-index pairs.
        AT: One activation value per node.
        name: Scenario label used in the file name ``biv_tree_AT_<name>.vtu``.

    Raises:
        ValueError: If ``AT`` does not hold exactly one value per node.
    """
    coords = np.asarray(nodes, dtype=float)
    segments = np.asarray(edges, dtype=int).reshape(-1, 2)
    at = np.asarray(AT, dtype=float).ravel()
    if at.shape[0] != coords.shape[0]:
        raise ValueError(
            f"AT has {at.shape[0]} values but the tree has {coords.shape[0]} nodes"
        )

    # VTK cell layout: each line cell is stored as [2, start_node, end_node]
    n_segments = segments.shape[0]
    cells = np.hstack([np.full((n_segments, 1), 2, dtype=segments.dtype), segments]).ravel()
    cell_types = np.full(n_segments, 3, dtype=np.uint8)  # 3 is VTK_LINE

    grid = pv.UnstructuredGrid(cells, cell_types, coords)
    grid.point_data["AT"] = at

    path = OUT_DIR / f"biv_tree_AT_{name}.vtu"
    grid.save(str(path))
    print("Saved:", path.name)

# One merged-tree file per scenario
for scen, AT in scenarios.items():
    save_tree_with_AT(nodes=nodes_biv, edges=edges_biv, AT=AT, name=scen)

## Plotly Viewer (LV/RV surfaces + tree + PMJs)

Interactive and reliable in Colab. Use the dropdown to switch scenarios.

In [None]:
# Build Plotly figure with a dropdown to switch scenarios
def mesh3d_from_surface_with_AT(surface: pv.PolyData):
    """Build a semi-transparent Plotly ``Mesh3d`` trace from *surface*.

    The face array is read as rows of four entries with the first one
    dropped. If the surface has an ``"AT"`` point array, it colours the mesh
    and a colour bar titled "AT" is shown; otherwise no colour bar is shown.
    """
    xyz = surface.points
    # Each face row is [count, i, j, k]; the leading count is dropped
    tri = surface.faces.reshape(-1, 4)[:, 1:]
    has_at = "AT" in surface.point_data
    at_values = surface.point_data["AT"] if has_at else None

    return go.Mesh3d(
        x=xyz[:, 0], y=xyz[:, 1], z=xyz[:, 2],
        i=tri[:, 0], j=tri[:, 1], k=tri[:, 2],
        name="surface",
        opacity=0.25,
        lighting=dict(ambient=0.6, diffuse=0.6),
        intensity=at_values,
        colorscale="Viridis",
        showscale=has_at,
        colorbar=dict(title="AT") if has_at else None,
    )

def line_segments_from_edges(points: np.ndarray, edges: np.ndarray, name="tree", width=3):
    """Return one ``Scatter3d`` line trace that draws every edge as its own segment.

    For each edge the two endpoint coordinates are appended followed by
    ``None``, which separates consecutive segments in the coordinate lists.
    """
    xs, ys, zs = [], [], []
    for u, v in edges:
        head, tail = points[u], points[v]
        xs.extend((head[0], tail[0], None))
        ys.extend((head[1], tail[1], None))
        zs.extend((head[2], tail[2], None))
    return go.Scatter3d(x=xs, y=ys, z=zs, mode="lines", line=dict(width=width), name=name)

def scatter_points(points: np.ndarray, name="PMJs", size=4, color="crimson"):
    """Return a ``Scatter3d`` marker trace for *points* (square markers).

    An empty input yields an empty marker trace that still carries *name*.
    """
    if points.size:
        marker_style = dict(size=size, color=color, symbol="square")
        return go.Scatter3d(
            x=points[:, 0], y=points[:, 1], z=points[:, 2],
            mode="markers", marker=marker_style, name=name,
        )
    return go.Scatter3d(x=[], y=[], z=[], mode="markers", name=name)

# PMJ coordinates per chamber (an empty (0, 3) array when a tree has no PMJs)
pmj_coords_lv = np.empty((0, 3)) if pmj_lv.size == 0 else nodes_lv[pmj_lv]
pmj_coords_rv = np.empty((0, 3)) if pmj_rv.size == 0 else nodes_rv[pmj_rv]

# For every scenario, split the merged AT at the LV node count and map each
# half onto its own surface
n_lv = nodes_lv.shape[0]
surfaces_by_scen = {}
for scen, AT in scenarios.items():
    surfaces_by_scen[scen] = (
        nearest_map_activation(surf_lv, nodes_lv, AT[:n_lv]),
        nearest_map_activation(surf_rv, nodes_rv, AT[n_lv:]),
    )

scen_list = list(surfaces_by_scen)

# Every scenario contributes six traces, always in this order:
# LV surface, RV surface, LV tree, RV tree, LV PMJs, RV PMJs.
traces_all, vis_masks = [], []
for _pos, scen in enumerate(scen_list):
    lv_surf, rv_surf = surfaces_by_scen[scen]
    traces_all.extend([
        mesh3d_from_surface_with_AT(lv_surf),
        mesh3d_from_surface_with_AT(rv_surf),
        # Tree and PMJ geometry is the same for all scenarios
        line_segments_from_edges(nodes_lv, edges_lv, name="LV tree"),
        line_segments_from_edges(nodes_rv, edges_rv, name="RV tree"),
        scatter_points(pmj_coords_lv, name="LV PMJs"),
        scatter_points(pmj_coords_rv, name="RV PMJs"),
    ])
    # Visibility mask: True only for this scenario's block of six traces
    _first = _pos * 6
    vis_masks.append([_first <= _k < _first + 6 for _k in range(6 * len(scen_list))])

fig = go.Figure(data=traces_all)

# Start with only the first scenario's traces visible
if vis_masks:
    for _trace, _show in zip(fig.data, vis_masks[0]):
        _trace.visible = _show

# One dropdown entry per scenario; each applies that scenario's visibility mask
buttons = [
    dict(
        method="update",
        label=_label,
        args=[{"visible": _mask}],
        args2=[{"visible": vis_masks[0]}],
    )
    for _label, _mask in zip(scen_list, vis_masks)
]

_scenario_menu = dict(type="dropdown", x=0.01, y=0.99, buttons=buttons, showactive=True)
fig.update_layout(
    updatemenus=[_scenario_menu],
    scene=dict(aspectmode="data"),
    margin=dict(l=0, r=0, t=30, b=0),
    title=f"CRT Demo — scenarios: {', '.join(scen_list)}",
)

fig.show()