# Partitioning trade-offs in QuASAr

This notebook explores how QuASAr's cost estimator responds to different partitioning strategies.  Synthetic fragments are used to probe backend feasibility, runtime and memory projections, and the conversion primitives required to glue heterogeneous plans together.

## Goals

* Provide an editable sandbox for fragment parameters (size, sparsity, locality, resource ceilings) and visualise backend feasibility.
* Compare monolithic execution against two- and three-segment plans that incorporate conversion costs, highlighting when partitioning improves the projected runtime or memory footprint.
* Map how gate mix, boundary widths and conversion primitives interact to make partitioning advantageous according to the model.

In [None]:
import itertools
import math
from collections.abc import Iterable, Sequence

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from docs.utils.partitioning_analysis import (
    FragmentStats,
    BoundarySpec,
    aggregate_partitioned_plan,
    evaluate_fragment_backends,
)
from quasar.cost import Backend, Cost, CostEstimator

def synthesise_fragment(
    num_qubits: int,
    depth: int,
    entangling_ratio: float,
    measurement_ratio: float = 0.0,
    *,
    is_clifford: bool = False,
    is_local: bool = False,
    frontier: int | None = None,
    frontier_scale: float | None = None,
    chi: int | Sequence[int] | None = None,
) -> FragmentStats:
    """Build a ``FragmentStats`` record from coarse circuit descriptors.

    The fragment is modelled as ``depth`` layers (at least one).  A share of
    ``entangling_ratio`` of them, rounded and clamped to ``[0, depth]``, are
    entangling layers that contribute ``max(num_qubits - 1, 1)`` two-qubit
    gates each; every remaining layer contributes ``num_qubits`` single-qubit
    gates.  The measurement count is ``round(measurement_ratio * num_qubits)``.

    When ``frontier`` is not given but ``frontier_scale`` is, the frontier is
    derived as ``max(1, round(num_qubits * frontier_scale))``.  The remaining
    keyword arguments are forwarded to ``FragmentStats`` unchanged.
    """

    total_layers = max(int(depth), 1)
    requested_layers = int(round(total_layers * float(entangling_ratio)))
    # Clamp the entangling layer count into [0, total_layers].
    layers_2q = max(0, min(total_layers, requested_layers))
    layers_1q = total_layers - layers_2q

    pairs_per_layer = max(num_qubits - 1, 1)
    measurement_count = int(round(measurement_ratio * num_qubits))

    resolved_frontier = frontier
    if resolved_frontier is None and frontier_scale is not None:
        resolved_frontier = max(1, int(round(num_qubits * frontier_scale)))

    return FragmentStats(
        num_qubits=num_qubits,
        num_1q_gates=layers_1q * num_qubits,
        num_2q_gates=layers_2q * pairs_per_layer,
        num_measurements=measurement_count,
        is_clifford=is_clifford,
        is_local=is_local,
        frontier=resolved_frontier,
        chi=chi,
    )


def _to_iterable(value) -> list:
    if isinstance(value, (list, tuple, set, range)):
        return list(value)
    if hasattr(value, "tolist"):
        return list(value)
    return [value]


def _limit_to_bytes(value: float | None) -> float | None:
    if value is None:
        return None
    return float(value) * (1024**3)


def safe_log10(values: Iterable[float], *, floor: float = -12.0) -> np.ndarray:
    """Return ``log10`` of ``values`` with a lower bound of ``floor``.

    Every entry is first raised to at least ``10 ** floor`` so that zeros and
    negative numbers map to ``floor`` instead of ``-inf``/``nan``.  NaN inputs
    stay NaN.
    """
    lower_bound = 10 ** floor
    samples = np.asarray([*values], dtype=float)
    return np.log10(np.maximum(samples, lower_bound))


def evaluate_parameter_grid(
    fragment_axes: dict,
    metric_axes: dict,
    resource_limits: dict,
    *,
    allow_tableau: bool = True,
    estimator: CostEstimator | None = None,
) -> pd.DataFrame:
    """Sweep synthetic fragment parameters and record backend selections.

    Args:
        fragment_axes: Maps ``synthesise_fragment`` keyword names to the values
            to sweep (a scalar is treated as a single-value axis).
        metric_axes: Maps metric names to the values to sweep.  The keys
            ``sparsity``, ``phase_rotation_diversity`` and
            ``amplitude_rotation_diversity`` are forwarded to
            ``evaluate_fragment_backends``.  May be empty or ``None``, in which
            case each fragment is evaluated once with no metrics.
        resource_limits: Optional ``max_memory_gb`` and ``max_time_s`` entries;
            ``None`` means no limits.
        allow_tableau: Forwarded to ``evaluate_fragment_backends``.
        estimator: Cost estimator to use; a fresh ``CostEstimator`` by default.

    Returns:
        One row per (fragment, metric) combination with the swept parameters,
        the synthesised gate counts, the selected backend with its time and
        memory (NaN when no backend was selected), and
        ``<backend>_feasible`` / ``<backend>_time`` / ``<backend>_memory``
        columns for every candidate reported in the diagnostics.
    """

    estimator = estimator or CostEstimator()
    metric_axes = metric_axes or {}
    resource_limits = resource_limits or {}
    max_memory = _limit_to_bytes(resource_limits.get("max_memory_gb"))
    max_time = resource_limits.get("max_time_s")

    fragment_keys = list(fragment_axes)
    fragment_value_lists = [_to_iterable(fragment_axes[key]) for key in fragment_keys]
    metric_keys = list(metric_axes)
    # The metric combinations do not depend on the fragment, so build them
    # once.  With no metric axes ``product()`` yields a single empty tuple,
    # which keeps one evaluation per fragment instead of silently producing
    # an empty sweep.
    metric_combinations = list(
        itertools.product(*(_to_iterable(metric_axes[key]) for key in metric_keys))
    )

    rows: list[dict] = []
    for frag_values in itertools.product(*fragment_value_lists):
        frag_params = dict(zip(fragment_keys, frag_values))
        stats = synthesise_fragment(**frag_params)
        for metric_values in metric_combinations:
            metrics = dict(zip(metric_keys, metric_values))
            backend, diag = evaluate_fragment_backends(
                stats,
                sparsity=metrics.get("sparsity"),
                phase_rotation_diversity=metrics.get("phase_rotation_diversity"),
                amplitude_rotation_diversity=metrics.get("amplitude_rotation_diversity"),
                allow_tableau=allow_tableau,
                max_memory=max_memory,
                max_time=max_time,
                estimator=estimator,
            )
            if backend is None:
                # No backend was selected for this combination.
                selected_name = None
                selected_time = math.nan
                selected_memory = math.nan
            else:
                selected_name = backend.name
                selected_time = diag["selected_cost"].time
                selected_memory = diag["selected_cost"].memory
            row = {
                **frag_params,
                **metrics,
                "num_1q_gates": stats.num_1q_gates,
                "num_2q_gates": stats.num_2q_gates,
                "num_measurements": stats.num_measurements,
                "selected_backend": selected_name,
                "selected_time": selected_time,
                "selected_memory": selected_memory,
            }
            for cand_backend, entry in diag["backends"].items():
                label = cand_backend.name.lower()
                is_mapping = isinstance(entry, dict)
                feasible = entry.get("feasible") if is_mapping else None
                cost = entry.get("cost") if is_mapping else None
                row[f"{label}_feasible"] = feasible
                row[f"{label}_time"] = cost.time if cost is not None else math.nan
                row[f"{label}_memory"] = cost.memory if cost is not None else math.nan
            rows.append(row)
    return pd.DataFrame(rows)


def run_plan(
    fragment_stats: Sequence[FragmentStats],
    fragment_metrics: Sequence[dict],
    *,
    boundaries: Sequence[BoundarySpec] | None = None,
    resource_limits: dict | None = None,
    allow_tableau: bool = True,
    estimator: CostEstimator | None = None,
) -> dict:
    """Choose backends for each fragment and aggregate plan costs.

    Args:
        fragment_stats: One ``FragmentStats`` per plan segment, in order.
        fragment_metrics: One metrics dict per fragment (``sparsity``,
            ``phase_rotation_diversity``, ``amplitude_rotation_diversity``).
        boundaries: Boundary specifications between fragments.  When given,
            the plan is aggregated with ``aggregate_partitioned_plan``;
            otherwise it is aggregated as a single-backend plan without
            conversions.
        resource_limits: Optional ``max_memory_gb`` / ``max_time_s`` limits.
        allow_tableau: Forwarded to ``evaluate_fragment_backends``.
        estimator: Cost estimator to use; a fresh ``CostEstimator`` by default.

    Returns:
        A dict with ``fragments`` (per-fragment diagnostics), ``selections``
        (``(backend, cost)`` pairs), ``total_cost`` and ``conversions``.

    Raises:
        ValueError: If the stats and metrics sequences differ in length.
        RuntimeError: If no backend is selected for some fragment.
    """

    if len(fragment_stats) != len(fragment_metrics):
        raise ValueError("metrics must align with fragment list")
    estimator = estimator or CostEstimator()
    resource_limits = resource_limits or {}
    max_memory = _limit_to_bytes(resource_limits.get("max_memory_gb"))
    max_time = resource_limits.get("max_time_s")

    selections: list[tuple[Backend, object]] = []
    diagnostics: list[dict] = []
    for stats, metrics in zip(fragment_stats, fragment_metrics):
        backend, diag = evaluate_fragment_backends(
            stats,
            sparsity=metrics.get("sparsity"),
            phase_rotation_diversity=metrics.get("phase_rotation_diversity"),
            amplitude_rotation_diversity=metrics.get("amplitude_rotation_diversity"),
            allow_tableau=allow_tableau,
            max_memory=max_memory,
            max_time=max_time,
            estimator=estimator,
        )
        if backend is None:
            raise RuntimeError("fragment infeasible under the selected limits")
        selections.append((backend, diag["selected_cost"]))
        diagnostics.append(diag)

    if boundaries:
        plan = aggregate_partitioned_plan(selections, boundaries, estimator=estimator)
        total_cost = plan["total_cost"]
        conversions = plan["conversions"]
    else:
        # ``aggregate_single_backend_plan`` is not imported at the top of the
        # notebook, so this branch used to fail with ``NameError``.
        # NOTE(review): presumably it lives next to ``aggregate_partitioned_plan``
        # in ``docs.utils.partitioning_analysis`` -- confirm the module path.
        from docs.utils.partitioning_analysis import aggregate_single_backend_plan

        total_cost = aggregate_single_backend_plan(selections)
        conversions = []
    return {
        "fragments": diagnostics,
        "selections": selections,
        "total_cost": total_cost,
        "conversions": conversions,
    }


def plan_overview(label: str, plan: dict) -> pd.DataFrame:
    """Summarise a plan as a one-row table.

    The row holds the plan label, the selected backend names joined in
    fragment order, and the ``time``, ``memory`` and ``conversion`` fields of
    the plan's total cost.
    """
    total = plan["total_cost"]
    backend_names = [entry[0].name for entry in plan["selections"]]
    summary = {
        "plan": label,
        "backends": " → ".join(backend_names),
        "total_time": total.time,
        "peak_memory": total.memory,
        "conversion_time": total.conversion,
    }
    return pd.DataFrame([summary])



def fragment_breakdown(plan: dict) -> pd.DataFrame:
    """Tabulate the per-fragment backend choice, descriptors and cost.

    For each ``(backend, cost)`` selection the matching diagnostics entry is
    consulted: descriptor fields are read from its ``stats`` object when that
    object has the attribute and from its ``metrics`` mapping otherwise.
    Diagnostics that are not dicts contribute no descriptors.
    """

    def _stat_or_metric(stats, metrics, field: str):
        # Prefer the attribute on ``stats``; fall back to the metrics mapping.
        if stats is None or not hasattr(stats, field):
            return metrics.get(field)
        return getattr(stats, field)

    records = []
    for position, selection in enumerate(plan["selections"]):
        backend, cost = selection
        diagnostic = plan["fragments"][position]
        if isinstance(diagnostic, dict):
            metrics = diagnostic.get("metrics", {})
            stats = diagnostic.get("stats")
        else:
            metrics = {}
            stats = None

        records.append(
            {
                "fragment": position,
                "backend": backend.name,
                "num_qubits": _stat_or_metric(stats, metrics, "num_qubits"),
                "num_1q_gates": _stat_or_metric(stats, metrics, "num_1q_gates"),
                "num_2q_gates": _stat_or_metric(stats, metrics, "num_2q_gates"),
                "num_gates": metrics.get("num_gates"),
                "num_measurements": _stat_or_metric(stats, metrics, "num_measurements"),
                "sparsity": metrics.get("sparsity"),
                "is_clifford": _stat_or_metric(stats, metrics, "is_clifford"),
                "time": cost.time,
                "memory": cost.memory,
            }
        )
    return pd.DataFrame(records)


def conversion_breakdown(plan: dict) -> pd.DataFrame:
    """Tabulate the conversions recorded in ``plan["conversions"]``.

    Each row lists the conversion index, the source and target backend names,
    the primitive used and the conversion cost's time and memory.  A plan
    without conversions yields an empty frame with the same columns.
    """
    conversions = plan["conversions"]
    if conversions:
        return pd.DataFrame(
            [
                {
                    "index": step["index"],
                    "source": step["source"].name,
                    "target": step["target"].name,
                    "primitive": step["primitive"],
                    "time": step["cost"].time,
                    "memory": step["cost"].memory,
                }
                for step in conversions
            ]
        )
    return pd.DataFrame(columns=["index", "source", "target", "primitive", "time", "memory"])


def evaluate_partition_advantage(
    gate_mixes: Sequence[float],
    boundary_qubits: Sequence[int],
    ranks: Sequence[int],
    *,
    total_qubits: int = 34,
    total_depth: int = 72,
    local_threshold: float = 0.32,
    estimator: CostEstimator | None = None,
) -> pd.DataFrame:
    """Map when a two-fragment plan beats a monolithic execution.

    For every gate mix a monolithic fragment (``total_qubits`` x
    ``total_depth``) is compared with a two-segment plan whose first segment
    takes 55% of the depth.  When ``gate_mix <= local_threshold`` the first
    segment is a local fragment of at most 18 qubits; otherwise it spans all
    qubits.  Each (boundary, rank) pair defines the boundary used to aggregate
    the two segments, including conversion cost.

    Args:
        gate_mixes: Entangling ratios to evaluate.
        boundary_qubits: Boundary widths to evaluate.
        ranks: Boundary ranks to evaluate.
        total_qubits: Width of the monolithic circuit.
        total_depth: Depth of the monolithic circuit.
        local_threshold: Largest gate mix for which the first segment is local.
        estimator: Cost estimator to use; a fresh ``CostEstimator`` by default.

    Returns:
        One row per (gate_mix, boundary, rank) combination, ordered like
        ``itertools.product(gate_mixes, boundary_qubits, ranks)``.  If no
        backend is selected for a fragment, its backend name is ``None``, the
        affected times and ``speedup`` are NaN and ``partition_wins`` is
        ``False``.
    """

    estimator = estimator or CostEstimator()
    rows: list[dict] = []
    depth_a = int(round(total_depth * 0.55))
    depth_b = total_depth - depth_a
    # Materialise once so the pairs can be replayed for every gate mix.
    boundary_rank_pairs = list(itertools.product(boundary_qubits, ranks))

    for gate_mix in gate_mixes:
        # Fragment synthesis and backend selection depend only on the gate
        # mix, so evaluate them once rather than per (boundary, rank) pair.
        monolithic = synthesise_fragment(total_qubits, total_depth, gate_mix, is_local=False)
        mono_backend, mono_diag = evaluate_fragment_backends(
            monolithic,
            sparsity=max(0.35, 1.0 - gate_mix * 1.4),
            estimator=estimator,
        )
        mono_feasible = mono_backend is not None
        mono_time = mono_diag["selected_cost"].time if mono_feasible else math.nan

        if gate_mix <= local_threshold:
            frag_a = synthesise_fragment(
                # The local segment can never be wider than the whole circuit.
                min(18, total_qubits),
                depth_a,
                max(0.08, gate_mix * 0.55),
                is_local=True,
                frontier_scale=0.24,
                chi=64,
            )
            sparsity_a = min(0.95, 0.8 + (0.26 - gate_mix) * 1.5)
        else:
            frag_a = synthesise_fragment(
                total_qubits,
                depth_a,
                gate_mix * 0.9,
                is_local=False,
            )
            sparsity_a = max(0.35, 1.0 - gate_mix * 1.1)

        frag_b = synthesise_fragment(total_qubits, depth_b, gate_mix, is_local=False)
        sparsity_b = max(0.35, 1.0 - gate_mix * 1.3)

        sel_a, diag_a = evaluate_fragment_backends(
            frag_a,
            sparsity=sparsity_a,
            estimator=estimator,
        )
        sel_b, diag_b = evaluate_fragment_backends(
            frag_b,
            sparsity=sparsity_b,
            estimator=estimator,
        )
        partition_feasible = sel_a is not None and sel_b is not None

        for boundary, rank in boundary_rank_pairs:
            if partition_feasible:
                boundary_spec = BoundarySpec(
                    num_qubits=boundary,
                    rank=rank,
                    frontier=max(boundary, 12),
                    window=min(10, boundary // 2 + 2),
                    window_1q_gates=boundary * 4,
                    window_2q_gates=boundary * 2,
                )
                plan = aggregate_partitioned_plan(
                    [(sel_a, diag_a["selected_cost"]), (sel_b, diag_b["selected_cost"])],
                    [boundary_spec],
                    estimator=estimator,
                )
                partition_time = plan["total_cost"].time
                primitive = (
                    plan["conversions"][0]["primitive"] if plan["conversions"] else "None"
                )
            else:
                # Without a backend for both segments there is nothing to
                # aggregate; record the combination as infeasible.
                partition_time = math.nan
                primitive = "None"

            comparable = partition_feasible and mono_feasible
            if not comparable:
                speedup = math.nan
            elif partition_time:
                speedup = mono_time / partition_time
            else:
                speedup = float("inf")

            rows.append(
                {
                    "gate_mix": gate_mix,
                    "boundary_qubits": boundary,
                    "rank": rank,
                    "fragment_a_backend": sel_a.name if sel_a is not None else None,
                    "fragment_b_backend": sel_b.name if sel_b is not None else None,
                    "monolithic_backend": mono_backend.name if mono_feasible else None,
                    "partition_time": partition_time,
                    "monolithic_time": mono_time,
                    "speedup": speedup,
                    "primitive": primitive,
                    "partition_wins": comparable and partition_time < mono_time,
                }
            )
    return pd.DataFrame(rows)

## Parameter grid (edit me)

The dictionaries below define the parameter sweep for the feasibility study.  Update the lists to explore different fragment sizes, sparsities or resource ceilings.

In [None]:

# Axes swept over the synthetic fragments.  Each key is passed as a keyword
# argument to ``synthesise_fragment``; each value lists the settings to try.
fragment_axes = {
    "num_qubits": [18, 24, 30],
    "depth": [48, 64],
    "entangling_ratio": [0.12, 0.18, 0.28, 0.38],
    "measurement_ratio": [0.0],
    "is_clifford": [False, True],
    "is_local": [False, True],
    "frontier": [None],
    "frontier_scale": [0.25],
    "chi": [None],
}
# Circuit metrics forwarded to ``evaluate_fragment_backends`` for every
# fragment; the sweep evaluates each combination of these values.
metric_axes = {
    "sparsity": np.linspace(0.55, 0.9, 4),
    "phase_rotation_diversity": [6],
    "amplitude_rotation_diversity": [8],
}
# Resource ceilings: ``max_memory_gb`` is multiplied by 1024**3 before use and
# ``max_time_s`` is passed through as ``max_time``; ``None`` leaves it unset.
resource_limits = {"max_memory_gb": 64, "max_time_s": None}
ALLOW_TABLEAU = True
# Shared estimator instance passed to the sweep below.
ESTIMATOR = CostEstimator()


In [None]:
# Run the full sweep: one row per combination of fragment and metric values.
grid_results = evaluate_parameter_grid(
    fragment_axes,
    metric_axes,
    resource_limits,
    allow_tableau=ALLOW_TABLEAU,
    estimator=ESTIMATOR,
)
# Log-scaled copies of the selected cost; non-positive values are clamped by
# ``safe_log10`` and NaN (no backend selected) stays NaN.
grid_results["log_selected_time"] = safe_log10(grid_results["selected_time"])
grid_results["log_selected_memory"] = safe_log10(grid_results["selected_memory"])
# Last expression of the cell: shows the first rows as the cell output.
grid_results.head()


In [None]:
import matplotlib as mpl
from matplotlib.colors import ListedColormap

# Human-readable captions for the supported per-cell summaries.
SUMMARY_LABELS = {
    "mode_backend": "Most common backend",
    "majority_backend": "Backend majority (>50%)",
    "mean_runtime": "Mean log10 runtime",
}
CELL_SUMMARY_MODE = "mode_backend"  # Choices: "mode_backend", "majority_backend", "mean_runtime"
if CELL_SUMMARY_MODE not in SUMMARY_LABELS:
    raise ValueError(
        f"CELL_SUMMARY_MODE must be one of {', '.join(SUMMARY_LABELS)}, got {CELL_SUMMARY_MODE!r}"
    )

# Colours per backend name; backends missing here draw from ``fallback_colors``.
color_lookup = dict(
    STATEVECTOR="#4477aa",
    DECISION_DIAGRAM="#66c2a5",
    TABLEAU="#8dd3c7",
    MPS="#ffa600",
    TN="#aa3377",
    HYBRID="#ccbb44",
    CUSTOM="#7f7f7f",
    MIXED="#bdbdbd",
)
fallback_colors = ["#a1d99b", "#984ea3", "#f781bf", "#999999"]
# Order in which known backends appear on the colour bar.
preferred_backend_order = [
    "STATEVECTOR",
    "DECISION_DIAGRAM",
    "TABLEAU",
    "MPS",
    "TN",
    "HYBRID",
    "CUSTOM",
]

# Restrict the sweep to the first configured depth ...
selected_depth = _to_iterable(fragment_axes["depth"])[0]
subset = grid_results.loc[grid_results["depth"].eq(selected_depth)].copy()
# ... and to the first configured value of each rotation-diversity metric,
# whenever that column exists and holds at least one non-missing value.
if "phase_rotation_diversity" in subset.columns and subset["phase_rotation_diversity"].notna().any():
    subset = subset.loc[
        subset["phase_rotation_diversity"].eq(
            _to_iterable(metric_axes["phase_rotation_diversity"])[0]
        )
    ]
if (
    "amplitude_rotation_diversity" in subset.columns
    and subset["amplitude_rotation_diversity"].notna().any()
):
    subset = subset.loc[
        subset["amplitude_rotation_diversity"].eq(
            _to_iterable(metric_axes["amplitude_rotation_diversity"])[0]
        )
    ]

# Swept axis names that survive as columns, in alphabetical order.
sweep_columns = [
    col
    for col in sorted(set(fragment_axes).union(metric_axes))
    if col in subset.columns
]


def _mode_or_first(series: pd.Series):
    cleaned = series.dropna()
    if cleaned.empty:
        return np.nan
    modes = cleaned.mode()
    if not modes.empty:
        return modes.iloc[0]
    return cleaned.iloc[0]


def _log_value(value: float) -> float:
    """Return the clamped ``log10`` of one number; NaN for ``None``/non-finite."""
    if value is not None and np.isfinite(value):
        return float(safe_log10([value])[0])
    return np.nan


# Collapse rows that share every swept parameter: the backend becomes the most
# common one, costs are averaged and gate counts are taken from the first row.
if not sweep_columns:
    aggregated = subset.copy()
else:
    aggregated = subset.groupby(sweep_columns, dropna=False).agg(
        **{
            "selected_backend": pd.NamedAgg(column="selected_backend", aggfunc=_mode_or_first),
            "selected_time": pd.NamedAgg(column="selected_time", aggfunc="mean"),
            "selected_memory": pd.NamedAgg(column="selected_memory", aggfunc="mean"),
            "log_selected_time": pd.NamedAgg(column="log_selected_time", aggfunc="mean"),
            "log_selected_memory": pd.NamedAgg(column="log_selected_memory", aggfunc="mean"),
            "num_1q_gates": pd.NamedAgg(column="num_1q_gates", aggfunc="first"),
            "num_2q_gates": pd.NamedAgg(column="num_2q_gates", aggfunc="first"),
            "num_measurements": pd.NamedAgg(column="num_measurements", aggfunc="first"),
        }
    ).reset_index()


def _summarise_variants(group: pd.DataFrame) -> pd.Series:
    """Condense the variants of one heatmap cell into summary statistics.

    Reports the most common backend, the strict-majority backend (``"MIXED"``
    when no backend exceeds half of the variants), the leading backend's
    share, mean/median runtime and their ``log10``, the population standard
    deviation of the per-variant log runtimes, the number of variants with a
    backend, the distinct backends (most frequent first) and the share of
    variants won by ``TABLEAU``.
    """
    backend_counts = group["selected_backend"].dropna().astype(str).value_counts()
    num_variants = int(backend_counts.sum())
    if num_variants == 0:
        mode_backend = np.nan
        majority_backend = np.nan
        majority_fraction = np.nan
        backend_labels = tuple()
        tableau_fraction = np.nan
    else:
        mode_backend = backend_counts.index[0]
        leading_count = int(backend_counts.iloc[0])
        # A backend only counts as the majority with strictly more than half.
        majority_backend = mode_backend if leading_count > num_variants / 2 else "MIXED"
        majority_fraction = leading_count / num_variants
        backend_labels = tuple(backend_counts.index.tolist())
        tableau_fraction = float(backend_counts.get("TABLEAU", 0)) / num_variants

    runtimes = group["selected_time"].dropna().astype(float)
    mean_runtime = np.nan if runtimes.empty else runtimes.mean()
    median_runtime = np.nan if runtimes.empty else runtimes.median()

    log_runtimes = group["log_selected_time"].dropna().astype(float)
    if log_runtimes.size == 0:
        runtime_std = np.nan
    elif log_runtimes.size == 1:
        runtime_std = 0.0
    else:
        runtime_std = log_runtimes.std(ddof=0)

    summary = {
        "mode_backend": mode_backend,
        "majority_backend": majority_backend,
        "majority_fraction": majority_fraction,
        "mean_runtime": mean_runtime,
        "median_runtime": median_runtime,
        # ``_log_value`` already maps non-finite input to NaN.
        "mean_log_runtime": _log_value(mean_runtime),
        "median_log_runtime": _log_value(median_runtime),
        "runtime_std": runtime_std,
        "num_variants": num_variants,
        "unique_backends": len(backend_counts),
        "backend_labels": backend_labels,
        "tableau_fraction": tableau_fraction,
    }
    return pd.Series(summary)


# Build one summary row per heatmap cell (locality x fragment size x sparsity).
# An empty sweep still gets a frame with the expected columns so the plotting
# code below can run and show "No data".
if aggregated.empty:
    cell_summary = pd.DataFrame(
        columns=[
            "is_local",
            "num_qubits",
            "sparsity",
            "mode_backend",
            "majority_backend",
            "majority_fraction",
            "mean_runtime",
            "median_runtime",
            "mean_log_runtime",
            "median_log_runtime",
            "runtime_std",
            "num_variants",
            "unique_backends",
        ]
    )
else:
    cell_summary = (
        aggregated.groupby(["is_local", "num_qubits", "sparsity"], dropna=False)
        .apply(_summarise_variants)
        .reset_index()
    )

# Column of ``cell_summary`` plotted for each summary mode.
summary_column_map = {
    "mode_backend": "mode_backend",
    "majority_backend": "majority_backend",
    "mean_runtime": "mean_log_runtime",
}
summary_column = summary_column_map[CELL_SUMMARY_MODE]
# True when cells show a categorical backend rather than a runtime value.
backend_mode = CELL_SUMMARY_MODE in {"mode_backend", "majority_backend"}

# Shared colour limits for the runtime heatmaps, taken over all finite cells.
runtime_values = cell_summary.get("mean_log_runtime")
if runtime_values is not None:
    runtime_values = runtime_values.to_numpy(dtype=float)
    runtime_values = runtime_values[np.isfinite(runtime_values)]
else:
    runtime_values = np.array([])
time_vmin = runtime_values.min() if runtime_values.size else None
time_vmax = runtime_values.max() if runtime_values.size else None

if backend_mode:
    backend_values = cell_summary.get(summary_column, pd.Series(dtype=object)).dropna().astype(str)
    backend_values = backend_values[backend_values.ne("")]
    # ``dict.fromkeys`` de-duplicates while keeping first-seen order.
    observed_backends = list(dict.fromkeys(backend_values))
    # Known backends first (in preferred order), then unknown ones, and
    # "MIXED" always last.
    backend_order = [b for b in preferred_backend_order if b in observed_backends]
    backend_order.extend([b for b in observed_backends if b not in backend_order and b != "MIXED"])
    if "MIXED" in observed_backends:
        backend_order.append("MIXED")
    if not backend_order:
        # Keep the colour map non-empty even when nothing was observed.
        backend_order = ["STATEVECTOR"]
    colors = []
    fallback_pool = list(fallback_colors)
    for backend in backend_order:
        color = color_lookup.get(backend)
        if color is None:
            # Unknown backend: take the next fallback colour, then grey.
            color = fallback_pool.pop(0) if fallback_pool else "#cccccc"
        colors.append(color)
    backend_cmap = ListedColormap(colors)
    # Bin edges at -0.5, 0.5, ... so integer index k maps to colour k.
    backend_norm = mpl.colors.BoundaryNorm(
        np.arange(len(backend_order) + 1) - 0.5, backend_cmap.N
    )
    backend_to_idx = {name: idx for idx, name in enumerate(backend_order)}
else:
    # Runtime mode needs no categorical colour map.
    backend_order = []
    backend_cmap = None
    backend_norm = None
    backend_to_idx = None


def _heatmap_payload(summary: pd.DataFrame, value_column: str, *, index_lookup: dict[str, int] | None = None, label_column: str | None = None):
    if summary.empty:
        if label_column:
            return [], [], np.empty((0, 0)), []
        return [], [], np.empty((0, 0))
    qubits = sorted(summary["num_qubits"].unique())
    sparsities = sorted(summary["sparsity"].unique())
    data = np.full((len(sparsities), len(qubits)), np.nan)
    labels = [[tuple() for _ in qubits] for _ in sparsities] if label_column else None
    for i, sparsity in enumerate(sparsities):
        for j, nq in enumerate(qubits):
            row = summary[
                (summary["num_qubits"] == nq) & (summary["sparsity"] == sparsity)
            ]
            if row.empty:
                continue
            entry = row.iloc[0]
            value = entry[value_column]
            if label_column:
                raw_labels = entry.get(label_column)
                if isinstance(raw_labels, float) and np.isnan(raw_labels):
                    label_value = tuple()
                elif raw_labels is None:
                    label_value = tuple()
                elif isinstance(raw_labels, (list, tuple, set)):
                    label_value = tuple(raw_labels)
                else:
                    label_value = (str(raw_labels),)
                labels[i][j] = label_value
            if index_lookup is None:
                data[i, j] = value
            else:
                if pd.isna(value):
                    continue
                key = str(value)
                if key in index_lookup:
                    data[i, j] = index_lookup[key]
    if labels is None:
        return qubits, sparsities, data
    return qubits, sparsities, data, labels


# Panel titles keyed by the ``is_local`` flag of the fragments.
titles = {False: "Distributed fragments", True: "Local fragments"}

# Figure 1: the per-cell summary chosen via CELL_SUMMARY_MODE, one panel per
# locality setting.
fig_backend, backend_axes = plt.subplots(1, 2, figsize=(12, 4.5), sharex=True, sharey=True)
# (is_local, num_qubits, sparsity) -> tuple of backends seen in that cell.
cell_backend_labels = {}
backend_axes = np.atleast_1d(backend_axes)
backend_im = None

for ax_idx, is_local in enumerate([False, True]):
    frame = cell_summary[cell_summary["is_local"] == is_local]
    if backend_mode:
        payload = _heatmap_payload(
            frame, summary_column, index_lookup=backend_to_idx, label_column="backend_labels"
        )
        qubits, sparsities, matrix, label_grid = payload
    else:
        payload = _heatmap_payload(frame, summary_column, index_lookup=backend_to_idx)
        qubits, sparsities, matrix = payload
        label_grid = None
    ax = backend_axes[ax_idx]
    if not qubits or not sparsities:
        ax.axis("off")
        ax.text(0.5, 0.5, "No data", ha="center", va="center")
        continue
    # Cells without a value are masked so they are left unpainted.
    masked = np.ma.masked_invalid(matrix)
    if backend_mode:
        backend_im = ax.imshow(
            masked,
            cmap=backend_cmap,
            norm=backend_norm,
            aspect="auto",
            origin="lower",
        )
    else:
        backend_im = ax.imshow(
            masked,
            cmap="viridis",
            aspect="auto",
            origin="lower",
            vmin=time_vmin,
            vmax=time_vmax,
        )
    if backend_mode and label_grid:
        for i, sparsity in enumerate(sparsities):
            for j, nq in enumerate(qubits):
                cell_backend_labels[(is_local, nq, sparsity)] = label_grid[i][j]
    ax.set_title(titles[is_local], fontsize=12, fontweight="bold")
    ax.set_xticks(range(len(qubits)))
    ax.set_xticklabels(qubits)
    ax.set_yticks(range(len(sparsities)))
    ax.set_yticklabels([f"{val:.2f}" for val in sparsities])
    if ax_idx == 0:
        ax.set_ylabel("sparsity")
    else:
        # The y axis is shared, so the tick formatter is shared as well:
        # ``set_yticklabels([])`` here would blank the left panel's labels
        # too.  Hide the labels on this panel only.
        ax.tick_params(labelleft=False)
    ax.set_xlabel("fragment qubits")

if backend_im is not None:
    if backend_mode:
        # Use a standalone mappable so the colour bar covers every backend in
        # ``backend_order``, not just those drawn in the last panel.
        cbar_backend = fig_backend.colorbar(
            mpl.cm.ScalarMappable(cmap=backend_cmap, norm=backend_norm),
            ax=backend_axes,
            orientation="horizontal",
            pad=0.15,
        )
        cbar_backend.set_label(SUMMARY_LABELS[CELL_SUMMARY_MODE])
        cbar_backend.set_ticks(range(len(backend_order)))
        cbar_backend.set_ticklabels(backend_order)
    else:
        cbar_backend = fig_backend.colorbar(
            backend_im,
            ax=backend_axes,
            orientation="horizontal",
            pad=0.15,
        )
        cbar_backend.set_label(SUMMARY_LABELS[CELL_SUMMARY_MODE])
fig_backend.suptitle(
    f"{SUMMARY_LABELS[CELL_SUMMARY_MODE]} across the parameter sweep", fontsize=14
)
fig_backend.tight_layout(rect=[0, 0, 1, 0.94])

# Figure 2: mean log10 runtime per cell, one panel per locality setting.  Both
# panels share the colour limits ``time_vmin`` / ``time_vmax``.
fig_runtime, runtime_axes = plt.subplots(1, 2, figsize=(12, 4.5), sharex=True, sharey=True)
runtime_axes = np.atleast_1d(runtime_axes)
runtime_im = None

for ax_idx, is_local in enumerate([False, True]):
    frame = cell_summary[cell_summary["is_local"] == is_local]
    qubits, sparsities, log_time = _heatmap_payload(frame, "mean_log_runtime")
    ax = runtime_axes[ax_idx]
    if not qubits or not sparsities:
        ax.axis("off")
        ax.text(0.5, 0.5, "No data", ha="center", va="center")
        continue
    # Cells without a runtime are masked so they are left unpainted.
    masked_runtime = np.ma.masked_invalid(log_time)
    runtime_im = ax.imshow(
        masked_runtime,
        cmap="viridis",
        aspect="auto",
        origin="lower",
        vmin=time_vmin,
        vmax=time_vmax,
    )
    ax.set_title(titles[is_local], fontsize=12, fontweight="bold")
    ax.set_xticks(range(len(qubits)))
    ax.set_xticklabels(qubits)
    ax.set_yticks(range(len(sparsities)))
    ax.set_yticklabels([f"{val:.2f}" for val in sparsities])
    if ax_idx == 0:
        ax.set_ylabel("sparsity")
    else:
        # The y axis is shared, so the tick formatter is shared as well:
        # ``set_yticklabels([])`` here would blank the left panel's labels
        # too.  Hide the labels on this panel only.
        ax.tick_params(labelleft=False)
    ax.set_xlabel("fragment qubits")

if runtime_im is not None:
    cbar_runtime = fig_runtime.colorbar(
        runtime_im,
        ax=runtime_axes,
        orientation="horizontal",
        pad=0.15,
    )
    cbar_runtime.set_label("Mean log10 runtime estimate")
fig_runtime.suptitle(
    "Mean log10 runtime across the parameter sweep", fontsize=14
)
fig_runtime.tight_layout(rect=[0, 0, 1, 0.94])

plt.show()

# Table of the aggregated parameter combinations won by the Tableau backend.
tableau_variants = aggregated[aggregated["selected_backend"] == "TABLEAU"].copy()
if not tableau_variants.empty:
    print("Tableau-winning parameter combinations (aggregated view):")
    if cell_backend_labels:
        # Attach the backends observed in the heatmap cell each row belongs to.
        tableau_variants["cell_backends"] = tableau_variants.apply(
            lambda row: cell_backend_labels.get(
                (row["is_local"], row["num_qubits"], row["sparsity"]), tuple()
            ),
            axis=1,
        )
        tableau_variants["cell_backends"] = tableau_variants["cell_backends"].apply(
            lambda labels: " → ".join(labels) if labels else ""
        )
    tableau_variants["log10_time"] = safe_log10(tableau_variants["selected_time"])
    tableau_variants["log10_memory"] = safe_log10(tableau_variants["selected_memory"])
    # ``cell_backends`` only exists when the heatmaps ran in a backend mode,
    # and sweep axes may have been removed from the editable grid; keep only
    # the columns that are present instead of failing with ``KeyError``.
    tableau_columns = [
        column
        for column in [
            "is_local",
            "is_clifford",
            "num_qubits",
            "depth",
            "entangling_ratio",
            "sparsity",
            "num_1q_gates",
            "num_2q_gates",
            "num_measurements",
            "selected_backend",
            "selected_time",
            "log10_time",
            "selected_memory",
            "log10_memory",
            "cell_backends",
        ]
        if column in tableau_variants.columns
    ]
    tableau_sort_keys = [
        key
        for key in ["is_local", "is_clifford", "num_qubits", "entangling_ratio", "sparsity"]
        if key in tableau_columns
    ]
    tableau_table = tableau_variants[tableau_columns]
    if tableau_sort_keys:
        tableau_table = tableau_table.sort_values(tableau_sort_keys)
    display(tableau_table.reset_index(drop=True))
else:
    print("Tableau did not win for any sampled configuration in this sweep.")


### Backend feasibility maps

The heatmaps aggregate the swept parameter combinations before drawing a cell. The plotting cell first restricts the sweep to the first configured depth and the first value of each rotation-diversity metric; all remaining fragment variants that share the same size, sparsity and locality are then grouped across entangling ratio, the Clifford flag and the other sweep dimensions. The `CELL_SUMMARY_MODE` toggle controls the summary shown per cell: "mode_backend" displays the most common backend, "majority_backend" only labels a backend when it wins a strict majority (otherwise the cell is marked as `MIXED`), and "mean_runtime" renders the log10 of the mean runtime of the aggregated variants. Switching between these options highlights when the omitted dimensions drive different backend choices or materially change the runtime estimates.

A summary table beneath the figures highlights the parameter combinations where the Tableau simulator prevails, including the gate counts contributed by each fragment configuration.

In [None]:
import matplotlib as mpl
from matplotlib.colors import ListedColormap

CELL_SUMMARY_MODE = "mode_backend"  # Choices: "mode_backend", "majority_backend", "mean_runtime"
# Human-readable captions for the supported per-cell summaries.
SUMMARY_LABELS = {
    "mode_backend": "Most common backend",
    "majority_backend": "Backend majority (>50%)",
    "mean_runtime": "Mean log10 runtime",
}
if CELL_SUMMARY_MODE not in SUMMARY_LABELS:
    raise ValueError(
        f"CELL_SUMMARY_MODE must be one of {', '.join(SUMMARY_LABELS)}, got {CELL_SUMMARY_MODE!r}"
    )

# Order in which known backends appear on the colour bar.  TABLEAU is listed
# (with the same position and colour as in the other plotting cell of this
# notebook) so Tableau cells are coloured consistently instead of falling
# back to ``fallback_colors``.
preferred_backend_order = [
    "STATEVECTOR",
    "DECISION_DIAGRAM",
    "TABLEAU",
    "MPS",
    "TN",
    "HYBRID",
    "CUSTOM",
]
color_lookup = {
    "STATEVECTOR": "#4477aa",
    "DECISION_DIAGRAM": "#66c2a5",
    "TABLEAU": "#8dd3c7",
    "MPS": "#ffa600",
    "TN": "#aa3377",
    "HYBRID": "#ccbb44",
    "CUSTOM": "#7f7f7f",
    "MIXED": "#bdbdbd",
}
# Colours handed out, in order, to backends missing from ``color_lookup``.
fallback_colors = ["#a1d99b", "#984ea3", "#f781bf", "#999999"]

# Restrict the sweep to the first configured depth ...
selected_depth = _to_iterable(fragment_axes["depth"])[0]
subset = grid_results[grid_results["depth"] == selected_depth].copy()
# ... and to the first configured value of each rotation-diversity metric,
# whenever that column exists and is not entirely missing.
if "phase_rotation_diversity" in subset and not subset["phase_rotation_diversity"].isna().all():
    subset = subset[
        subset["phase_rotation_diversity"]
        == _to_iterable(metric_axes["phase_rotation_diversity"])[0]
    ]
if "amplitude_rotation_diversity" in subset and not subset["amplitude_rotation_diversity"].isna().all():
    subset = subset[
        subset["amplitude_rotation_diversity"]
        == _to_iterable(metric_axes["amplitude_rotation_diversity"])[0]
    ]

# Swept axis names that survive as columns, in alphabetical order.
sweep_columns = sorted({*fragment_axes.keys(), *metric_axes.keys()})
sweep_columns = [col for col in sweep_columns if col in subset.columns]


def _mode_or_first(series: pd.Series):
    cleaned = series.dropna()
    if cleaned.empty:
        return np.nan
    modes = cleaned.mode()
    if not modes.empty:
        return modes.iloc[0]
    return cleaned.iloc[0]


def _log_value(value: float) -> float:
    """Return the clamped ``log10`` of one number; NaN for ``None``/non-finite."""
    if value is not None and np.isfinite(value):
        return float(safe_log10([value])[0])
    return np.nan


# Collapse rows that share every swept parameter: the backend becomes the most
# common one and the runtime columns are averaged.
if not sweep_columns:
    aggregated = subset.copy()
else:
    aggregated = subset.groupby(sweep_columns, dropna=False).agg(
        **{
            "selected_backend": pd.NamedAgg(column="selected_backend", aggfunc=_mode_or_first),
            "selected_time": pd.NamedAgg(column="selected_time", aggfunc="mean"),
            "log_selected_time": pd.NamedAgg(column="log_selected_time", aggfunc="mean"),
        }
    ).reset_index()


def _summarise_variants(group: pd.DataFrame) -> pd.Series:
    backend_series = group["selected_backend"].dropna().astype(str)
    counts = backend_series.value_counts()
    total = int(counts.sum())
    if total:
        top_backend = counts.index[0]
        top_count = int(counts.iloc[0])
        majority_backend = top_backend if top_count > total / 2 else "MIXED"
        majority_fraction = top_count / total
    else:
        top_backend = np.nan
        majority_backend = np.nan
        majority_fraction = np.nan
    runtime_series = group["selected_time"].dropna().astype(float)
    if runtime_series.empty:
        mean_runtime = np.nan
        median_runtime = np.nan
    else:
        mean_runtime = runtime_series.mean()
        median_runtime = runtime_series.median()
    mean_log_runtime = _log_value(mean_runtime) if np.isfinite(mean_runtime) else np.nan
    median_log_runtime = _log_value(median_runtime) if np.isfinite(median_runtime) else np.nan
    log_series = group["log_selected_time"].dropna().astype(float)
    if log_series.size > 1:
        runtime_std = log_series.std(ddof=0)
    elif log_series.size == 1:
        runtime_std = 0.0
    else:
        runtime_std = np.nan
    return pd.Series(
        {
            "mode_backend": top_backend if total else np.nan,
            "majority_backend": majority_backend if total else np.nan,
            "majority_fraction": majority_fraction,
            "mean_runtime": mean_runtime,
            "median_runtime": median_runtime,
            "mean_log_runtime": mean_log_runtime,
            "median_log_runtime": median_log_runtime,
            "runtime_std": runtime_std,
            "num_variants": total,
            "unique_backends": len(counts),
        }
    )


# One summary row per (is_local, num_qubits, sparsity) heatmap cell.  With
# no aggregated data, fall back to an empty frame that still carries the
# expected columns so the plotting code below can index it.
if not aggregated.empty:
    cells = aggregated.groupby(["is_local", "num_qubits", "sparsity"], dropna=False)
    cell_summary = cells.apply(_summarise_variants).reset_index()
else:
    cell_summary = pd.DataFrame(
        columns=[
            # grouping keys
            "is_local",
            "num_qubits",
            "sparsity",
            # statistics produced by _summarise_variants
            "mode_backend",
            "majority_backend",
            "majority_fraction",
            "mean_runtime",
            "median_runtime",
            "mean_log_runtime",
            "median_log_runtime",
            "runtime_std",
            "num_variants",
            "unique_backends",
        ]
    )

# Which cell_summary column backs each CELL_SUMMARY_MODE; the runtime mode
# is plotted on a log10 scale.
summary_column_map = dict(
    mode_backend="mode_backend",
    majority_backend="majority_backend",
    mean_runtime="mean_log_runtime",
)
summary_column = summary_column_map[CELL_SUMMARY_MODE]
# True when the first heatmap shows categorical backend labels.
backend_mode = CELL_SUMMARY_MODE in ("mode_backend", "majority_backend")

# Shared colour limits for the runtime heatmaps: min/max over the finite
# mean log runtimes, or None (matplotlib autoscale) when there are none.
if "mean_log_runtime" in cell_summary:
    runtime_values = cell_summary["mean_log_runtime"].to_numpy(dtype=float)
    runtime_values = runtime_values[np.isfinite(runtime_values)]
else:
    runtime_values = np.array([])
if runtime_values.size:
    time_vmin = runtime_values.min()
    time_vmax = runtime_values.max()
else:
    time_vmin = None
    time_vmax = None

# Build the categorical colour map for the backend heatmap.  In runtime
# mode none of this is needed, so the placeholders are left empty.
if not backend_mode:
    backend_order = []
    backend_cmap = None
    backend_norm = None
    backend_to_idx = None
else:
    backend_values = cell_summary.get(summary_column, pd.Series(dtype=object)).dropna().astype(str)
    # "TABLEAU" and empty labels are excluded from the legend.
    backend_values = backend_values[~backend_values.isin(["TABLEAU", ""])]
    observed_backends = list(dict.fromkeys(backend_values))

    # Known backends first (in preferred order), then any other observed
    # labels in order of appearance, with "MIXED" always last.
    known_first = [b for b in preferred_backend_order if b in observed_backends]
    unexpected = [b for b in observed_backends if b not in known_first and b != "MIXED"]
    backend_order = known_first + unexpected
    if "MIXED" in observed_backends:
        backend_order.append("MIXED")
    if not backend_order:
        backend_order = ["STATEVECTOR"]

    # Fixed colour where one is defined, otherwise the next fallback colour,
    # and light grey once the fallbacks run out.
    colors = []
    fallback_pool = list(fallback_colors)
    for backend in backend_order:
        if backend in color_lookup:
            colors.append(color_lookup[backend])
        elif fallback_pool:
            colors.append(fallback_pool.pop(0))
        else:
            colors.append("#cccccc")

    backend_cmap = ListedColormap(colors)
    # Integer category k is mapped to the bin [k - 0.5, k + 0.5).
    category_edges = np.arange(len(backend_order) + 1) - 0.5
    backend_norm = mpl.colors.BoundaryNorm(category_edges, backend_cmap.N)
    backend_to_idx = dict(zip(backend_order, range(len(backend_order))))


def _heatmap_payload(summary: pd.DataFrame, value_column: str, *, index_lookup: dict[str, int] | None = None):
    if summary.empty:
        return [], [], np.empty((0, 0))
    qubits = sorted(summary["num_qubits"].unique())
    sparsities = sorted(summary["sparsity"].unique())
    data = np.full((len(sparsities), len(qubits)), np.nan)
    for i, sparsity in enumerate(sparsities):
        for j, nq in enumerate(qubits):
            row = summary[
                (summary["num_qubits"] == nq) & (summary["sparsity"] == sparsity)
            ]
            if row.empty:
                continue
            value = row.iloc[0][value_column]
            if index_lookup is None:
                data[i, j] = value
            else:
                if pd.isna(value):
                    continue
                key = str(value)
                if key in index_lookup:
                    data[i, j] = index_lookup[key]
    return qubits, sparsities, data


# Panel titles keyed by the ``is_local`` flag of each heatmap.
titles = {False: "Distributed fragments", True: "Local fragments"}

# First figure: one heatmap per locality setting showing the column chosen
# by CELL_SUMMARY_MODE (categorical backend labels or mean log10 runtime).
fig_backend, backend_axes = plt.subplots(1, 2, figsize=(12, 4.5), sharex=True, sharey=True)
backend_axes = np.atleast_1d(backend_axes)
backend_im = None

for ax_idx, is_local in enumerate([False, True]):
    frame = cell_summary[cell_summary["is_local"] == is_local]
    # In backend mode the labels are encoded as integer categories through
    # backend_to_idx; in runtime mode backend_to_idx is None and the raw
    # numeric values are used.
    qubits, sparsities, matrix = _heatmap_payload(
        frame, summary_column, index_lookup=backend_to_idx
    )
    ax = backend_axes[ax_idx]
    if not qubits or not sparsities:
        ax.axis("off")
        # Axes-fraction coordinates keep the note centred: the data limits
        # are shared with the other panel and are not necessarily [0, 1].
        ax.text(0.5, 0.5, "No data", ha="center", va="center", transform=ax.transAxes)
        continue
    # NaN cells (no data / unmapped backend) are masked so they stay blank.
    masked = np.ma.masked_invalid(matrix)
    if backend_mode:
        backend_im = ax.imshow(
            masked,
            cmap=backend_cmap,
            norm=backend_norm,
            aspect="auto",
            origin="lower",
        )
    else:
        backend_im = ax.imshow(
            masked,
            cmap="viridis",
            aspect="auto",
            origin="lower",
            vmin=time_vmin,
            vmax=time_vmax,
        )
    ax.set_title(titles[is_local], fontsize=12, fontweight="bold")
    ax.set_xticks(range(len(qubits)))
    ax.set_xticklabels(qubits)
    ax.set_yticks(range(len(sparsities)))
    ax.set_yticklabels([f"{val:.2f}" for val in sparsities])
    if ax_idx == 0:
        ax.set_ylabel("sparsity")
    else:
        # Hide the labels on this panel only.  Calling set_yticklabels([])
        # here would replace the formatter that sharey=True shares between
        # both panels and blank the left panel's labels as well.
        ax.tick_params(labelleft=False)
    ax.set_xlabel("fragment qubits")

if backend_im is not None:
    if backend_mode:
        # Build the legend from the full categorical map rather than from a
        # single panel's image, then label each category by backend name.
        cbar_backend = fig_backend.colorbar(
            mpl.cm.ScalarMappable(cmap=backend_cmap, norm=backend_norm),
            ax=backend_axes,
            orientation="horizontal",
            pad=0.15,
        )
        cbar_backend.set_label(SUMMARY_LABELS[CELL_SUMMARY_MODE])
        cbar_backend.set_ticks(range(len(backend_order)))
        cbar_backend.set_ticklabels(backend_order)
    else:
        # Both panels share time_vmin/time_vmax, so either image is valid.
        cbar_backend = fig_backend.colorbar(
            backend_im,
            ax=backend_axes,
            orientation="horizontal",
            pad=0.15,
        )
        cbar_backend.set_label(SUMMARY_LABELS[CELL_SUMMARY_MODE])
fig_backend.suptitle(
    f"{SUMMARY_LABELS[CELL_SUMMARY_MODE]} across the parameter sweep", fontsize=14
)
# Leave headroom at the top of the figure for the suptitle.
fig_backend.tight_layout(rect=[0, 0, 1, 0.94])

# Second figure: mean log10 runtime per cell, one heatmap per locality
# setting, drawn with shared colour limits so the panels are comparable.
fig_runtime, runtime_axes = plt.subplots(1, 2, figsize=(12, 4.5), sharex=True, sharey=True)
runtime_axes = np.atleast_1d(runtime_axes)
runtime_im = None

for ax_idx, is_local in enumerate([False, True]):
    frame = cell_summary[cell_summary["is_local"] == is_local]
    qubits, sparsities, log_time = _heatmap_payload(frame, "mean_log_runtime")
    ax = runtime_axes[ax_idx]
    if not qubits or not sparsities:
        ax.axis("off")
        # Axes-fraction coordinates keep the note centred: the data limits
        # are shared with the other panel and are not necessarily [0, 1].
        ax.text(0.5, 0.5, "No data", ha="center", va="center", transform=ax.transAxes)
        continue
    # Cells without a finite runtime are masked so they stay blank.
    masked_runtime = np.ma.masked_invalid(log_time)
    runtime_im = ax.imshow(
        masked_runtime,
        cmap="viridis",
        aspect="auto",
        origin="lower",
        vmin=time_vmin,
        vmax=time_vmax,
    )
    ax.set_title(titles[is_local], fontsize=12, fontweight="bold")
    ax.set_xticks(range(len(qubits)))
    ax.set_xticklabels(qubits)
    ax.set_yticks(range(len(sparsities)))
    ax.set_yticklabels([f"{val:.2f}" for val in sparsities])
    if ax_idx == 0:
        ax.set_ylabel("sparsity")
    else:
        # Hide the labels on this panel only.  Calling set_yticklabels([])
        # here would replace the formatter that sharey=True shares between
        # both panels and blank the left panel's labels as well.
        ax.tick_params(labelleft=False)
    ax.set_xlabel("fragment qubits")

if runtime_im is not None:
    # Both panels share time_vmin/time_vmax, so either image is valid here.
    cbar_runtime = fig_runtime.colorbar(
        runtime_im,
        ax=runtime_axes,
        orientation="horizontal",
        pad=0.15,
    )
    cbar_runtime.set_label("Mean log10 runtime estimate")
fig_runtime.suptitle(
    "Mean log10 runtime across the parameter sweep", fontsize=14
)
# Leave headroom at the top of the figure for the suptitle.
fig_runtime.tight_layout(rect=[0, 0, 1, 0.94])

plt.show()


### Case study: two segments with a decision diagram prefix

A sparse, locally entangling prefix can often be simulated efficiently with the decision-diagram backend before converting into a dense statevector section.

In [None]:
# Baseline: a single 34-qubit fragment (depth 70, entangling ratio 0.35).
resource_case1 = dict(max_memory_gb=64)
monolithic_stats = synthesise_fragment(34, 70, 0.35, is_local=False)
monolithic_plan = run_plan(
    [monolithic_stats],
    [{"sparsity": 0.62}],
    resource_limits=resource_case1,
    estimator=ESTIMATOR,
)

# Partitioned alternative: a local 18-qubit prefix followed by a 34-qubit
# non-local fragment, joined by one conversion boundary.
fragment_a = synthesise_fragment(18, 44, 0.18, is_local=True, frontier_scale=0.25, chi=48)
fragment_b = synthesise_fragment(34, 32, 0.32, is_local=False)
boundary_ab = BoundarySpec(
    num_qubits=14,
    rank=48,
    frontier=20,
    window=8,
    window_1q_gates=60,
    window_2q_gates=16,
)
partition_plan = run_plan(
    [fragment_a, fragment_b],
    [{"sparsity": 0.88}, {"sparsity": 0.58}],
    boundaries=[boundary_ab],
    resource_limits=resource_case1,
    estimator=ESTIMATOR,
)

# Side-by-side overview; row 0 is the monolithic baseline, so each ratio is
# baseline / plan (values above 1 favour the plan in that row).
plan_frames = [
    plan_overview("Monolithic statevector", monolithic_plan),
    plan_overview("DD → statevector", partition_plan),
]
overview = pd.concat(plan_frames, ignore_index=True)
for ratio_column, metric in (
    ("time_speedup_vs_monolithic", "total_time"),
    ("memory_ratio_vs_monolithic", "peak_memory"),
):
    overview[ratio_column] = overview.loc[0, metric] / overview[metric]
display(overview)

display(fragment_breakdown(partition_plan))
conversion_df = conversion_breakdown(partition_plan)
if not conversion_df.empty:
    display(conversion_df)

# Bar charts of log10 total time and log10 peak memory for both plans.
fig, axes = plt.subplots(1, 2, figsize=(10, 3))
for ax, metric, axis_label in zip(
    axes,
    ("total_time", "peak_memory"),
    ("log10 total time", "log10 peak memory"),
):
    ax.bar(overview["plan"], safe_log10(overview[metric]))
    ax.set_ylabel(axis_label)
fig.suptitle("Two-segment plan compared to the monolithic execution")
fig.tight_layout()
plt.show()


### Case study: three segments with conversions

A Clifford initialisation, a local MPS window and a dense finale illustrate how multiple conversions accumulate while still reducing projected runtime.

In [None]:
# Baseline: a single 48-qubit fragment (depth 90, entangling ratio 0.32).
resource_case2 = dict(max_memory_gb=128)
mono_stats = synthesise_fragment(48, 90, 0.32, is_local=False)
mono_plan = run_plan(
    [mono_stats],
    [{"sparsity": 0.58}],
    resource_limits=resource_case2,
    estimator=ESTIMATOR,
)

# Three fragments: a Clifford prefix, a local 32-qubit window and a
# non-local 48-qubit finale.
frag1 = synthesise_fragment(48, 18, 0.0, is_clifford=True, is_local=False)
frag2 = synthesise_fragment(32, 50, 0.18, is_local=True, frontier_scale=0.2, chi=48)
frag3 = synthesise_fragment(48, 36, 0.34, is_local=False)

# One conversion boundary between each pair of consecutive fragments.
boundary_12 = BoundarySpec(
    num_qubits=12,
    rank=32,
    frontier=24,
    window=6,
    window_1q_gates=40,
    window_2q_gates=12,
)
boundary_23 = BoundarySpec(
    num_qubits=18,
    rank=64,
    frontier=30,
    window=8,
    window_1q_gates=60,
    window_2q_gates=20,
)
three_plan = run_plan(
    [frag1, frag2, frag3],
    [{"sparsity": 0.95}, {"sparsity": 0.88}, {"sparsity": 0.55}],
    boundaries=[boundary_12, boundary_23],
    resource_limits=resource_case2,
    estimator=ESTIMATOR,
)

# Side-by-side overview; row 0 is the monolithic baseline, so each ratio is
# baseline / plan (values above 1 favour the plan in that row).
plan_frames_three = [
    plan_overview("Monolithic statevector", mono_plan),
    plan_overview("Tableau → MPS → statevector", three_plan),
]
overview_three = pd.concat(plan_frames_three, ignore_index=True)
for ratio_column, metric in (
    ("time_speedup_vs_monolithic", "total_time"),
    ("memory_ratio_vs_monolithic", "peak_memory"),
):
    overview_three[ratio_column] = overview_three.loc[0, metric] / overview_three[metric]
display(overview_three)

display(fragment_breakdown(three_plan))
conv_three = conversion_breakdown(three_plan)
if not conv_three.empty:
    display(conv_three)

# Bar charts of log10 total time and log10 peak memory for both plans.
fig, axes = plt.subplots(1, 2, figsize=(10, 3))
for ax, metric, axis_label in zip(
    axes,
    ("total_time", "peak_memory"),
    ("log10 total time", "log10 peak memory"),
):
    ax.bar(overview_three["plan"], safe_log10(overview_three[metric]))
    ax.set_ylabel(axis_label)
fig.suptitle("Three-segment plan compared to the monolithic execution")
fig.tight_layout()
plt.show()


### Tableau snapshot for the Clifford fragment

The three-segment case study begins with a Clifford initialisation that is scheduled on the Tableau simulator. The table below compares its gate counts and projected cost with the subsequent fragments in the plan.

In [None]:
# One table row per segment of the three-segment plan: the fragment's gate
# statistics next to the backend and cost chosen for it.
segment_labels = ["Clifford prefix", "Local tensor window", "Dense conclusion"]
fragment_stats_seq = [frag1, frag2, frag3]
# Fragment attributes copied verbatim; None when an attribute is missing.
count_fields = ("num_qubits", "num_1q_gates", "num_2q_gates", "num_measurements")

segment_rows = []
# Each entry of three_plan["selections"] unpacks into a (backend, cost) pair.
for label, stats, (backend, cost) in zip(
    segment_labels, fragment_stats_seq, three_plan["selections"]
):
    row = {
        "segment": label,
        "backend": backend.name,
        "is_clifford": getattr(stats, "is_clifford", False),
    }
    row.update({field: getattr(stats, field, None) for field in count_fields})
    row["time"] = cost.time
    row["memory"] = cost.memory
    segment_rows.append(row)

clifford_profile = pd.DataFrame(segment_rows)
for metric in ("time", "memory"):
    clifford_profile[f"log10_{metric}"] = safe_log10(clifford_profile[metric])
display(clifford_profile)

### Feature map for partition advantage

The following sweep varies the gate mix (fraction of entangling layers), conversion boundary width and Schmidt-rank cap.  It highlights where the model predicts a win for the two-fragment plan and which conversion primitive is selected.

In [None]:
# Sweep axes: entangling-layer fraction, boundary width in qubits and the
# rank values passed to evaluate_partition_advantage.
sweep_settings = dict(
    gate_mixes=[0.22, 0.28, 0.34, 0.4],
    boundary_qubits=[8, 12, 16, 20, 24, 28],
    ranks=[16, 32, 64],
)
partition_df = evaluate_partition_advantage(**sweep_settings, estimator=ESTIMATOR)
partition_df.head()


In [None]:
# Split the sweep into configurations where partitioning wins and loses.
wins_mask = partition_df["partition_wins"]
advantage = partition_df[wins_mask].copy()
losses = partition_df[~wins_mask].copy()

# Per gate mix and conversion primitive: the smallest winning boundary and
# the best speedup observed among the winning configurations.
winning_groups = advantage.groupby(["gate_mix", "primitive"])
summary = winning_groups.agg(
    min_boundary=("boundary_qubits", "min"),
    max_speedup=("speedup", "max"),
)
summary = summary.reset_index().sort_values(["gate_mix", "min_boundary"])
summary


In [None]:
# Per gate mix and conversion primitive: the range of losing boundary
# widths and the lowest speedup among the losing configurations.
if losses.empty:
    print("Partitioning won for every sampled configuration.")
else:
    losing_groups = losses.groupby(["gate_mix", "primitive"])
    loss_summary = losing_groups.agg(
        min_boundary=("boundary_qubits", "min"),
        max_boundary=("boundary_qubits", "max"),
        worst_speedup=("speedup", "min"),
    )
    loss_summary = loss_summary.reset_index().sort_values(["gate_mix", "min_boundary"])
    display(loss_summary)


In [None]:
# One heatmap of log10 speedup per rank value (gate mix x boundary width).
rank_values = sorted(partition_df["rank"].unique())
fig, axes = plt.subplots(1, len(rank_values), figsize=(15, 3), sharey=True)
# plt.subplots returns a bare Axes for a single panel; normalise to an
# array so that both the loop and axes[0] below work for any rank count.
axes = np.atleast_1d(axes)

# Mean speedup per (gate_mix, boundary_qubits) cell for every rank.
pivots = [
    partition_df[partition_df["rank"] == rank].pivot_table(
        index="gate_mix", columns="boundary_qubits", values="speedup", aggfunc="mean"
    )
    for rank in rank_values
]
log_speedups = [np.log10(pivot.values) for pivot in pivots]

# Use one colour range for all panels: the single colorbar below is built
# from the last image only, so per-panel autoscaling would make it wrong
# for every other panel.
finite_logs = np.concatenate([values[np.isfinite(values)] for values in log_speedups])
speedup_vmin = finite_logs.min() if finite_logs.size else None
speedup_vmax = finite_logs.max() if finite_logs.size else None

for ax, rank, pivot, log_speedup in zip(axes, rank_values, pivots, log_speedups):
    im = ax.imshow(
        log_speedup,
        cmap="viridis",
        aspect="auto",
        origin="lower",
        vmin=speedup_vmin,
        vmax=speedup_vmax,
    )
    ax.set_title(f"rank ≤ {rank}")
    ax.set_xticks(range(len(pivot.columns)), pivot.columns)
    ax.set_yticks(range(len(pivot.index)), [f"{v:.2f}" for v in pivot.index])
    ax.set_xlabel("boundary qubits")
axes[0].set_ylabel("gate mix (entangling fraction)")
fig.colorbar(im, ax=axes, orientation="horizontal", fraction=0.04, pad=0.1, label="log10 speedup")
plt.suptitle("Speedup heatmap across boundary sizes and gate mixes")
plt.tight_layout()
plt.show()


The remaining tables characterise losing configurations and conversion choices.  When the initial fragment is forced onto the dense statevector backend, the conversions disappear and the partition offers no benefit, underscoring the importance of sparsity or locality for heterogeneous plans.