# Cost estimator playground

Interactively explore how QuASAr's `quasar.cost.CostEstimator` reacts to
changes in fragment metrics. Use the dropdown to load metrics sampled from real
circuits, then tweak the sliders to see how the conversion-primitive and
backend simulation cost estimates respond.

In [None]:
from __future__ import annotations

import json
import math
from contextlib import ExitStack
from pathlib import Path
from typing import Dict, Tuple

import ipywidgets as widgets
import pandas as pd
from IPython.display import Markdown, display

from quasar.cost import Backend, CostEstimator


def _find_data_file() -> Path:
    candidates = [
        Path("conversion_scenarios.json"),
        Path("data") / "conversion_scenarios.json",
        Path("docs") / "data" / "conversion_scenarios.json",
    ]
    for candidate in candidates:
        if candidate.exists():
            return candidate
    raise FileNotFoundError("conversion_scenarios.json not found")


def _load_scenarios(path: Path) -> Tuple[Dict[str, dict], list]:
    data = json.loads(path.read_text(encoding="utf8"))
    mapping = {entry["name"]: entry for entry in data}
    return mapping, data


data_path = _find_data_file()
SCENARIO_MAP, SCENARIOS = _load_scenarios(data_path)

estimator = CostEstimator()

# The first scenario (if any) seeds the initial widget values.  Metrics are
# read with ``.get`` and fall back to the "Custom" defaults, so a scenario that
# lacks a metric no longer raises KeyError while the notebook is loading.
_first_scenario = SCENARIOS[0] if SCENARIOS else {}

scenario_dropdown = widgets.Dropdown(
    options=["Custom"] + [entry["name"] for entry in SCENARIOS],
    description="Scenario",
    value=_first_scenario.get("name", "Custom"),
)

# Bounded ipywidgets sliders silently clamp out-of-range values.  Each slider's
# upper bound is therefore widened to cover every scenario; otherwise loading a
# scenario could quietly feed the estimator a smaller metric than was sampled.
max_qubits = max((entry.get("q", 8) for entry in SCENARIOS), default=16)
q_slider = widgets.IntSlider(
    value=max(1, int(_first_scenario.get("q", 6))),
    min=1,
    max=max(4 * max_qubits, 32),
    step=1,
    description="q",
)

max_rank = max((int(entry.get("s", 1)) for entry in SCENARIOS), default=1)
s_slider = widgets.IntSlider(
    value=max(1, int(_first_scenario.get("s", 4))),
    min=1,
    max=max(max_rank, 2048),
    step=1,
    description="s",
)

chi_slider = widgets.IntSlider(
    value=32,
    min=1,
    max=256,
    step=1,
    description="χ_cap",
)

sparsity_slider = widgets.FloatSlider(
    value=float(_first_scenario.get("sparsity", 0.5)),
    min=0.0,
    max=0.999,
    step=0.001,
    readout_format=".3f",
    description="sparsity",
)

max_rotation = max(
    (float(entry.get("rotation_diversity", 0.0)) for entry in SCENARIOS),
    default=0.0,
)
rotation_slider = widgets.FloatSlider(
    value=float(_first_scenario.get("rotation_diversity", 0.0)),
    min=0.0,
    max=max(max_rotation, 64.0),
    step=0.5,
    readout_format=".1f",
    description="rot. div.",
)

# (label, value) pairs: the dropdowns show enum names and yield Backend members.
backend_options = [(backend.name, backend) for backend in Backend]

source_backend = widgets.Dropdown(
    options=backend_options,
    value=Backend.DECISION_DIAGRAM,
    description="source",
)

target_backend = widgets.Dropdown(
    options=backend_options,
    value=Backend.MPS,
    description="target",
)

# One output panel per dashboard section; update_dashboard redraws all three.
summary_out = widgets.Output()
conversion_out = widgets.Output()
simulation_out = widgets.Output()

def _scenario_data() -> dict | None:
    """Return the metrics dict for the dropdown's current selection.

    Returns ``None`` when the selection (for example "Custom") has no entry
    in ``SCENARIO_MAP``.
    """
    try:
        return SCENARIO_MAP[scenario_dropdown.value]
    except KeyError:
        return None


def _gate_counts(q: int) -> Tuple[int, int, int, int]:
    """Return ``(num_1q, num_2q, num_meas, num_gates)`` for the estimators.

    With a scenario selected, the counts come straight from the scenario
    (missing keys count as 0) and *q* is ignored.  Otherwise a rough
    heuristic is used: 4 single-qubit gates per qubit, 2 two-qubit gates for
    each of the ``q - 1`` qubit links (presumably a linear chain), no
    measurements, and the total as the sum of the two.
    """
    scenario = _scenario_data()
    if not scenario:
        one_qubit = 4 * max(q, 1)
        two_qubit = 2 * max(q - 1, 0)
        return one_qubit, two_qubit, 0, one_qubit + two_qubit
    keys = ("num_1q", "num_2q", "num_meas", "num_gates")
    return tuple(int(scenario.get(key, 0)) for key in keys)


def _rotation_breakdown() -> Tuple[float, float]:
    """Return ``(phase, amplitude)`` rotation diversity.

    With a scenario selected, both components are read from the scenario
    (missing keys count as ``0.0``).  Otherwise the rotation slider's total
    is split evenly between the two.

    NOTE(review): with a scenario selected the rotation slider is not
    consulted here, so the decision-diagram estimate (which receives these
    components) does not react to that slider, while the other backends
    (which receive the slider total) do.  Confirm whether that is intended.
    """
    scenario = _scenario_data()
    if not scenario:
        half = rotation_slider.value / 2.0
        return half, half
    phase = float(scenario.get("phase_rotation_diversity", 0.0))
    amplitude = float(scenario.get("amplitude_rotation_diversity", 0.0))
    return phase, amplitude


def _update_from_scenario(change):
    """Copy the selected scenario's metrics into the sliders, then redraw.

    Registered as the ``value`` observer of ``scenario_dropdown`` and also
    called once at start-up.  ``change`` (the traitlets change payload) is
    unused because the selection is re-read via ``_scenario_data``.  When no
    scenario is selected (e.g. "Custom"), the sliders keep their values and
    only the dashboard is refreshed.  A metric missing from the scenario
    leaves its slider unchanged.
    """
    data = _scenario_data()
    if not data:
        update_dashboard()
        return
    # ``hold_trait_notifications`` is a per-instance traitlets method (ipywidgets
    # has no module-level ``widgets.hold_trait_notifications``), so one hold is
    # entered per slider.  All four values are assigned before any hold is
    # released, so observers firing on release see the complete scenario
    # rather than a half-applied one.
    with ExitStack() as stack:
        for slider in (q_slider, s_slider, sparsity_slider, rotation_slider):
            stack.enter_context(slider.hold_trait_notifications())
        q_slider.value = int(data.get("q", q_slider.value))
        s_slider.value = int(data.get("s", s_slider.value))
        sparsity_slider.value = float(data.get("sparsity", sparsity_slider.value))
        rotation_slider.value = float(
            data.get("rotation_diversity", rotation_slider.value)
        )
    # Slider observers only fire for values that actually changed.  Redraw
    # explicitly so a scenario switch that moves no slider is still reflected
    # (gate counts and long-range metrics are read from the scenario itself).
    update_dashboard()


def update_dashboard(*_):
    """Recompute every estimate from the current widget state and redraw.

    Registered as the ``value`` observer of each control (the change payload
    is absorbed by ``*_``) and called directly after a scenario is applied.
    Three panels are refreshed:

    * ``summary_out``: the metrics fed to the estimator;
    * ``conversion_out``: ``estimator.conversion_candidates`` for the
      selected source/target backend pair;
    * ``simulation_out``: statevector, tableau, MPS and decision-diagram
      simulation costs.

    Each panel's estimator calls run inside that panel's ``with`` block.
    Under IPython, an exception is therefore rendered as a traceback in the
    panel it belongs to (``ipywidgets.Output`` captures it) instead of
    leaving that panel blank after ``clear_output``.
    """
    cost_format = {"Time": "{:.3e}", "Memory": "{:.3e}"}

    summary_out.clear_output()
    conversion_out.clear_output()
    simulation_out.clear_output()

    q = q_slider.value
    # Cap the Schmidt rank at 2**q, the most a q-qubit boundary can carry.
    rank = min(s_slider.value, 2 ** q)
    chi_cap = chi_slider.value
    sparsity = sparsity_slider.value
    total_rotation = rotation_slider.value
    phase_rot, amp_rot = _rotation_breakdown()
    data = _scenario_data()
    num_1q, num_2q, num_meas, num_gates = _gate_counts(q)
    # log2(rank) is the largest entanglement entropy a state of this Schmidt
    # rank can have; it serves as the entropy estimate.
    ent_entropy = math.log2(rank) if rank > 0 else 0.0
    # Long-range metrics are only available from a loaded scenario.
    long_range_fraction = float(data.get("long_range_fraction", 0.0)) if data else 0.0
    long_range_extent = float(data.get("long_range_extent", 0.0)) if data else 0.0
    two_qubit_ratio = num_2q / max(num_gates, 1)

    with summary_out:
        rows = [
            ("q (boundary qubits)", q),
            ("s (Schmidt rank)", rank),
            ("χ_cap (staging)", chi_cap),
            ("sparsity", f"{sparsity:.3f}"),
            ("rotation diversity", f"{total_rotation:.1f}"),
        ]
        if data:
            rows.append(("scenario backend", data.get("backend", "-")))
        display(Markdown("### Selected metrics"))
        display(pd.DataFrame(rows, columns=["Metric", "Value"]))

    with conversion_out:
        display(Markdown("### Conversion primitive estimates"))
        conv_details = estimator.conversion_candidates(
            source_backend.value,
            target_backend.value,
            num_qubits=q,
            rank=rank,
            frontier=q,
            chi_cap=chi_cap,
            compressed_terms=rank,
        )
        conv_rows = [
            {
                "Primitive": primitive,
                "Time": detail.cost.time,
                "Memory": detail.cost.memory,
                "Window": detail.window if detail.window is not None else "-",
                "χ_cap": detail.chi_cap if detail.chi_cap is not None else "-",
                "Stages": detail.stages if detail.stages is not None else "-",
            }
            for primitive, detail in conv_details.items()
        ]
        if conv_rows:
            conv_df = pd.DataFrame(conv_rows).set_index("Primitive")
            display(conv_df.style.format(cost_format))
        else:
            # An empty frame has no "Primitive" column, so set_index would raise.
            display(Markdown("_No conversion primitives available for this backend pair._"))

    with simulation_out:
        costs = {
            "STATEVECTOR": estimator.statevector(
                q,
                num_1q,
                num_2q,
                num_meas,
                sparsity=sparsity,
                rotation_diversity=total_rotation,
                entanglement_entropy=ent_entropy,
            ),
            "TABLEAU": estimator.tableau(
                q,
                num_gates,
                num_meas=num_meas,
                # No depth metric is available; the gate count bounds it.
                depth=num_gates,
                rotation_diversity=total_rotation,
            ),
            "MPS": estimator.mps(
                q,
                num_1q,
                num_2q,
                # Bond dimension limited by both the Schmidt rank and χ_cap.
                chi=max(1, min(rank, chi_cap)),
                entanglement_entropy=ent_entropy,
                sparsity=sparsity,
                rotation_diversity=total_rotation,
                long_range_fraction=long_range_fraction,
                long_range_extent=long_range_extent,
            ),
            "DECISION_DIAGRAM": estimator.decision_diagram(
                num_gates,
                q,
                sparsity=sparsity,
                phase_rotation_diversity=phase_rot,
                amplitude_rotation_diversity=amp_rot,
                entanglement_entropy=ent_entropy,
                two_qubit_ratio=two_qubit_ratio,
            ),
        }
        sim_df = pd.DataFrame(
            [
                {"Backend": backend, "Time": cost.time, "Memory": cost.memory}
                for backend, cost in costs.items()
            ]
        ).set_index("Backend")
        display(Markdown("### Simulation cost estimates"))
        display(sim_df.style.format(cost_format))


# Picking a different scenario re-seeds the sliders (and then redraws).
scenario_dropdown.observe(_update_from_scenario, names="value")

# Any manual tweak of a metric or backend redraws the dashboard directly.
_dashboard_inputs = (
    q_slider,
    s_slider,
    chi_slider,
    sparsity_slider,
    rotation_slider,
    source_backend,
    target_backend,
)
for control in _dashboard_inputs:
    control.observe(update_dashboard, names="value")

# Initial render for whatever scenario the dropdown starts on.
_update_from_scenario({"new": scenario_dropdown.value})

# Layout: controls stacked on the left, output panels stacked on the right.
_metric_row = widgets.HBox([q_slider, s_slider, chi_slider])
_circuit_row = widgets.HBox([sparsity_slider, rotation_slider])
_backend_row = widgets.HBox([source_backend, target_backend])
controls = widgets.VBox([scenario_dropdown, _metric_row, _circuit_row, _backend_row])

outputs = widgets.VBox([summary_out, conversion_out, simulation_out])

ui = widgets.HBox([controls, outputs])
display(ui)
