# Partitioning theory plots

This notebook regenerates the theoretical partitioning figures used throughout the QuASAr documentation.

In [None]:
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Put the repository root on sys.path BEFORE importing project modules.
# The root is taken to be the working directory if it contains a "quasar"
# entry, otherwise its parent (presumably the notebook is launched from a
# directory one level below the root -- verify if the layout changes).
# Doing this after the `docs.utils` import would be too late: that import
# is the one that needs the path entry.
project_root = Path.cwd().resolve()
if not (project_root / "quasar").exists():
    project_root = project_root.parent
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

from docs.utils.partitioning_analysis import (  # noqa: E402  (needs sys.path set above)
    build_clifford_fragment_curves,
    build_statevector_partition_tradeoff,
    build_statevector_vs_mps,
    export_figure,
    load_calibrated_estimator,
)

# Apply the shared seaborn look once, before any figure is created.
sns.set_theme(context="talk", style="whitegrid", palette="colorblind", font_scale=0.9)

COLOR_PALETTE = sns.color_palette("colorblind", n_colors=6)

# Single source of truth for every legend label used in this notebook:
# (label, index into COLOR_PALETTE, position in the legend).
# Labels that share a palette index and rank are aliases of the same series.
_SERIES_STYLES = (
    ("Statevector", 0, 0),
    ("Statevector only", 0, 0),
    ("Partitioned with conversions", 2, 1),
    ("Partitioned", 2, 1),
    ("Tableau", 1, 2),
    ("MPS (χ=4)", 3, 3),
)

# label -> colour used when plotting that series
COLOR_LOOKUP = {
    series_label: COLOR_PALETTE[palette_index]
    for series_label, palette_index, _ in _SERIES_STYLES
}
# label -> sort rank used when ordering legend entries
LEGEND_ORDER = {
    series_label: legend_rank
    for series_label, _, legend_rank in _SERIES_STYLES
}

def color_for(label: str):
    """Return the plot colour registered for *label*.

    Labels missing from ``COLOR_LOOKUP`` fall back to the last entry of
    ``COLOR_PALETTE``.
    """
    fallback = COLOR_PALETTE[-1]
    return COLOR_LOOKUP.get(label, fallback)


def apply_consistent_legend(ax):
    """Rebuild the legend of *ax* in the notebook-wide order, without duplicates.

    Entries are ordered by their ``LEGEND_ORDER`` rank (unknown labels get
    rank 99) and then alphabetically by label. When a label occurs more than
    once, only the first handle reported by the axes is kept. The legend is
    placed in the upper-left corner. Nothing happens if the axes has no
    legend handles.
    """
    handles, labels = ax.get_legend_handles_labels()
    if not handles:
        return

    # Keep the first handle seen for each label; later duplicates are ignored.
    first_handle = {}
    for handle, label in zip(handles, labels):
        first_handle.setdefault(label, handle)

    def rank(name):
        # Unknown labels sort after all known ones, ties broken by label text.
        return (LEGEND_ORDER.get(name, 99), name)

    ordered_labels = sorted(first_handle, key=rank)
    ordered_handles = [first_handle[name] for name in ordered_labels]
    ax.legend(ordered_handles, ordered_labels, loc="upper left")


def set_plot_theme():
    """Re-apply the seaborn theme with ``COLOR_PALETTE`` as the active palette."""
    theme_options = {
        "context": "talk",
        "style": "whitegrid",
        "palette": COLOR_PALETTE,
        "font_scale": 0.9,
    }
    sns.set_theme(**theme_options)


# Load the cost estimator once; every figure cell below reuses it.
# A falsy calibration_path is reported as "no calibration file found".
estimator, calibration_path = load_calibrated_estimator()
print(
    f"Loaded calibration coefficients from {calibration_path}"
    if calibration_path
    else "Using built-in CostEstimator defaults (no calibration file found)"
)

In [None]:
# Figure: statevector vs. tableau cost curves for a Clifford fragment.
set_plot_theme()
curves = build_clifford_fragment_curves(estimator)
fig, ax = plt.subplots(figsize=(7.0, 4.0))

# Draw both curves against the same qubit axis; the label picks the colour.
for series_key, series_label in (("statevector", "Statevector"), ("tableau", "Tableau")):
    ax.plot(
        curves["num_qubits"],
        curves[series_key],
        label=series_label,
        linewidth=2.2,
        color=color_for(series_label),
    )

# Span the sampled qubit range and leave 10% headroom above the tallest point.
ax.set_xlim(curves["num_qubits"][0], curves["num_qubits"][-1])
y_max = float(np.concatenate((curves["statevector"], curves["tableau"])).max())
ax.set_ylim(0, y_max * 1.1)

threshold = curves["threshold"]
if threshold is None:
    print("Tableau path never overtakes statevector in the sampled range.")
else:
    # Locate the threshold on the sampled grid to anchor the annotation arrow.
    idx = int(np.where(curves["num_qubits"] == threshold)[0][0])
    y_val = curves["tableau"][idx]
    ax.axvline(threshold, color="black", linestyle="--", linewidth=1.2)
    # Keep the annotation text from starting beyond the right edge of the axis.
    x_text = min(curves["num_qubits"][-1], threshold + 1.5)
    ax.annotate(
        f"Tableau cheaper ≥ {threshold} qubits",
        xy=(threshold, y_val),
        xytext=(x_text, y_val * 1.05),
        arrowprops={"arrowstyle": "->", "linewidth": 1.0},
        fontsize=10,
    )
    print(f"Tableau becomes cheaper from {threshold} qubits onwards.")

ax.set(
    xlabel="Active qubits",
    ylabel="Estimated runtime (arb. units)",
    title="Clifford fragment crossover",
)
apply_consistent_legend(ax)
fig.tight_layout()
export_figure(fig, "clifford_crossover")
plt.show()

In [None]:
# Figure: pure statevector cost vs. partitioned execution including conversions.
set_plot_theme()
tradeoff = build_statevector_partition_tradeoff(estimator)
fig, ax = plt.subplots(figsize=(7.5, 4.2))

# Draw both curves against the same qubit axis; the label picks the colour.
for series_key, series_label in (
    ("statevector", "Statevector only"),
    ("partitioned", "Partitioned with conversions"),
):
    ax.plot(
        tradeoff["num_qubits"],
        tradeoff[series_key],
        label=series_label,
        linewidth=2.2,
        color=color_for(series_label),
    )

# Span the sampled qubit range and leave 12% headroom above the tallest point.
ax.set_xlim(tradeoff["num_qubits"][0], tradeoff["num_qubits"][-1])
y_max = float(np.concatenate((tradeoff["statevector"], tradeoff["partitioned"])).max())
ax.set_ylim(0, y_max * 1.12)

threshold = tradeoff["threshold"]
if threshold is None:
    print("Partitioned execution is never cheaper in this range.")
else:
    # Locate the threshold on the sampled grid, then read the boundary and
    # rank recorded at that grid point for the annotation text.
    idx = int(np.where(tradeoff["num_qubits"] == threshold)[0][0])
    boundary = int(tradeoff["boundary"][idx])
    rank = int(tradeoff["rank"][idx])
    y_val = tradeoff["partitioned"][idx]
    ax.axvline(threshold, color="black", linestyle="--", linewidth=1.2)
    # Keep the annotation text from starting beyond the right edge of the axis.
    x_text = min(tradeoff["num_qubits"][-1], threshold + 2)
    ax.annotate(
        f"Switch at q={boundary} (rank≤{rank})",
        xy=(threshold, y_val),
        xytext=(x_text, y_val * 1.05),
        arrowprops={"arrowstyle": "->", "linewidth": 1.0},
        fontsize=10,
    )
    print(f"Partitioned execution becomes cheaper from {threshold} qubits onwards.")

ax.set(
    xlabel="Active qubits",
    ylabel="Estimated runtime (arb. units)",
    title="Statevector vs. tableau with conversions",
)
apply_consistent_legend(ax)
fig.tight_layout()
export_figure(fig, "statevector_tableau_partition")
plt.show()

In [None]:
# Figure: statevector cost vs. the MPS curve labelled χ=4.
set_plot_theme()
mps_curves = build_statevector_vs_mps(estimator)
fig, ax = plt.subplots(figsize=(7.0, 4.0))

# Draw both curves against the same qubit axis; the label picks the colour.
for series_key, series_label in (("statevector", "Statevector"), ("mps", "MPS (χ=4)")):
    ax.plot(
        mps_curves["num_qubits"],
        mps_curves[series_key],
        label=series_label,
        linewidth=2.2,
        color=color_for(series_label),
    )

# Span the sampled qubit range and leave 12% headroom above the tallest point.
ax.set_xlim(mps_curves["num_qubits"][0], mps_curves["num_qubits"][-1])
y_max = float(np.concatenate((mps_curves["statevector"], mps_curves["mps"])).max())
ax.set_ylim(0, y_max * 1.12)

threshold = mps_curves["threshold"]
if threshold is None:
    print("MPS path never overtakes statevector in the sampled range.")
else:
    # Locate the threshold on the sampled grid to anchor the annotation arrow.
    idx = int(np.where(mps_curves["num_qubits"] == threshold)[0][0])
    y_val = mps_curves["mps"][idx]
    ax.axvline(threshold, color="black", linestyle="--", linewidth=1.2)
    # Keep the annotation text from starting beyond the right edge of the axis.
    x_text = min(mps_curves["num_qubits"][-1], threshold + 1.5)
    ax.annotate(
        f"MPS cheaper ≥ {threshold} qubits",
        xy=(threshold, y_val),
        xytext=(x_text, y_val * 1.05),
        arrowprops={"arrowstyle": "->", "linewidth": 1.0},
        fontsize=10,
    )
    print(f"MPS simulation becomes cheaper from {threshold} qubits onwards.")

ax.set(
    xlabel="Active qubits",
    ylabel="Estimated runtime (arb. units)",
    title="Statevector vs. χ=4 MPS",
)
apply_consistent_legend(ax)
fig.tight_layout()
export_figure(fig, "statevector_vs_mps")
plt.show()