# In [1]:
#!/usr/bin/env python3
from __future__ import annotations

import csv
import os
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

try:
    import cudaq  # type: ignore
    _CUDAQ_AVAILABLE = True
except Exception:
    cudaq = None  # type: ignore
    _CUDAQ_AVAILABLE = False
try:
    import cupy as cp  # type: ignore
    _CUPY_AVAILABLE = True
except Exception:
    cp = None  # type: ignore
    _CUPY_AVAILABLE = False


# Central run configuration. Entries are looked up by key via CONFIG[...] or
# CONFIG.get(...) throughout this file.
CONFIG = {
    # --- run mode and output locations ---
    "MODE": "phase-a",  # "phase-a" or "phase-b"
    "OUT_DIR": "out_stats",
    "PHASE_A_CSV": "out_stats/phase_a_samplesnewMTShard.csv",
    "PHASE_B_FULL_CSV": "out_stats/phase_b_fullnew.csv",
    "PHASE_B_CSV": "out_stats/phase_b_outliersnew.csv",
    # --- simulator target: preferred name, then fallbacks tried in order ---
    "TARGET": "nvidia",
    "TARGET_FALLBACKS": ["qpp", "qpp-cpu"],
    # --- GPU toggles (only honoured when CuPy imported successfully) ---
    "USE_GPU_STATS": True,
    "USE_GPU_RANDOM": True,
    "USE_GPU_ENERGY": True,
    "CUPY_USE_MEMORY_POOL": True,
    "RANDOM_UNIQUE": False,  # when False, random_stats reports "unique" as NaN
    # --- sampling bookkeeping ---
    "CHUNK_SIZE": 4000,  # shot counts above this are sampled in chunks
    "HEARTBEAT_EVERY_S": 20.0,
    # --- problem sweep ---
    "N_LIST": [16, 18, 20, 22, 24],
    "STEPS_LIST": [10],
    "DT": 0.05,
    "COST_SCALES": [0.05],
    "CD_SCALES": [0.1],
    "MIX_SCALE": 1.0,
    "SHOTS_DAQO": 1000,
    "SHOTS_DCQO": 1000,
    "SHOTS_RANDOM": 100,
    "SHOTS_MATCHED": True,  # when True, DCQO uses the DAQO shot count
    "REPEATS": 100,
    "BASE_SEED": 1234,
    "METHODS_A": ["DAQO", "DCQO", "RANDOM"],
    "METHODS_B": ["DAQO", "DCQO"],
    "DCQO_MAX_N": 20,  # larger N is written as a "skipped" row
    "DAQO_MAX_N": 24,
    # --- outlier selection ---
    "OUTLIER_PCTL": 0.04,
    "OUTLIER_MIN_GROUP": 10,
    "OUTLIER_MAX_PER_GROUP": 0,
    # --- memetic tabu search ---
    "POP_SIZE": 16,
    "N_RAND_PAD": 10,
    "PCOMB": 0.9,
    "PMUTATE_FACTOR": 1.0,
    "TARGET_E": 0,  # fallback target when no tabulated optimum is used
    "USE_EOPT_TARGET": True,
    "MTS_TIME_BUDGET_S": 2.0,
    "MTS_MAX_OUTER_ITERS": 50000,
    "TABU_MAX_ITERS": 48,
    "TABU_TENURE": 6,
    "TABU_SAMPLE_K": 0,
    "TABU_PLATEAU": 25,
    "TABU_ALLOW_WORSEN": True,
}

# Optional allocator setup, run once at import time: when CuPy was imported
# and CUPY_USE_MEMORY_POOL is truthy, install a newly created MemoryPool as
# CuPy's allocator. This is best-effort; any failure is deliberately ignored.
if _CUPY_AVAILABLE and CONFIG.get("CUPY_USE_MEMORY_POOL", True):
    try:
        cp.cuda.set_allocator(cp.cuda.MemoryPool().malloc)
    except Exception:
        pass

# Optimal LABS energies (E_opt) for open boundary conditions, keyed by the
# sequence length N (entries for N = 3..30).
# Values from Mertens open.dat (see LABS literature).
# get_target_e() looks N up here and falls back to CONFIG["TARGET_E"] when N
# is not tabulated.
EOPT_LABS = {
    3: 1, 4: 2, 5: 2, 6: 7, 7: 3, 8: 8, 9: 12, 10: 13,
    11: 5, 12: 10, 13: 6, 14: 19, 15: 15, 16: 24, 17: 32, 18: 25,
    19: 29, 20: 26, 21: 26, 22: 39, 23: 47, 24: 36, 25: 36, 26: 45,
    27: 37, 28: 50, 29: 62, 30: 59,
}


def get_target_e(n: int) -> int:
    """Return the target energy for sequence length ``n``.

    When ``CONFIG["USE_EOPT_TARGET"]`` is truthy and ``n`` has an entry in
    ``EOPT_LABS``, that tabulated value is returned. Otherwise the result is
    ``CONFIG["TARGET_E"]`` (default 0).
    """
    use_table = bool(CONFIG.get("USE_EOPT_TARGET", True))
    tabulated = EOPT_LABS.get(int(n)) if use_table else None
    if tabulated is None:
        return int(CONFIG.get("TARGET_E", 0))
    return int(tabulated)


def bs_to_bits(bs: str) -> List[int]:
    """Convert a bitstring to a list of ints.

    Surrounding whitespace is stripped; every character equal to ``"1"``
    becomes 1 and any other character becomes 0.
    """
    return [int(ch == "1") for ch in bs.strip()]


def bitstrings_to_bits_array(bitstrings: List[str], n: int) -> np.ndarray:
    """Pack equal-length ASCII bitstrings into a ``(len(bitstrings), n)`` int8 array.

    A cell is 1 where the character is ``"1"`` and 0 otherwise. An empty input
    gives an array of shape ``(0, n)``. The reshape raises ``ValueError`` if
    the total character count is not ``len(bitstrings) * n``.
    """
    width = int(n)
    if not bitstrings:
        return np.zeros((0, width), dtype=np.int8)
    raw = "".join(bitstrings).encode("ascii")
    codes = np.frombuffer(raw, dtype=np.uint8).reshape((len(bitstrings), width))
    return np.equal(codes, ord("1")).astype(np.int8)


def labs_energy(bits: List[int]) -> int:
    """Return the LABS energy of one bit sequence.

    Truthy bits map to spin +1 and falsy bits to -1. The energy is the sum
    over lags ``k = 1..n-1`` of the squared aperiodic autocorrelation
    ``C_k = sum_i s_i * s_{i+k}``.
    """
    spins = np.asarray([1 if b else -1 for b in bits], dtype=np.int16)
    length = int(spins.shape[0])
    correlations = (
        int(np.sum(spins[: length - lag] * spins[lag:]))
        for lag in range(1, length)
    )
    return int(sum(ck * ck for ck in correlations))


def labs_energy_batch_numpy(pop: Any) -> np.ndarray:
    """Compute LABS energies for a batch of bit rows with NumPy.

    ``pop`` is converted to a 2-D int8 array (one sequence per row). The
    result is an int64 vector with one energy per row; an empty ``pop`` gives
    an empty int64 vector.
    """
    if len(pop) == 0:
        return np.zeros((0,), dtype=np.int64)
    rows = np.asarray(pop, dtype=np.int8)
    count, width = rows.shape
    spins = (2 * rows - 1).astype(np.int16)
    energy = np.zeros((count,), dtype=np.int64)
    for lag in range(1, width):
        # Row-wise autocorrelation at this lag, accumulated in int64.
        corr = (spins[:, :-lag] * spins[:, lag:]).sum(axis=1, dtype=np.int64)
        energy += corr * corr
    return energy


def labs_energy_batch_cupy(pop: Any) -> np.ndarray:
    """Compute LABS energies for a batch of bit rows on the GPU via CuPy.

    ``pop`` may be a CuPy array (cast to int8 if needed) or anything NumPy can
    turn into a 2-D int8 array, which is then copied to the device. The
    energies come back as a host-side int64 NumPy vector.

    Raises:
        RuntimeError: if CuPy could not be imported.
    """
    if not _CUPY_AVAILABLE:
        raise RuntimeError("CuPy not available.")
    if len(pop) == 0:
        return np.zeros((0,), dtype=np.int64)
    if not isinstance(pop, cp.ndarray):
        device_bits = cp.asarray(np.asarray(pop, dtype=np.int8))
    elif pop.dtype != cp.int8:
        device_bits = pop.astype(cp.int8, copy=False)
    else:
        device_bits = pop
    count, width = device_bits.shape
    spins = (2 * device_bits - 1).astype(cp.int16)
    energy = cp.zeros((count,), dtype=cp.int64)
    for lag in range(1, width):
        corr = cp.sum(spins[:, : width - lag] * spins[:, lag:], axis=1, dtype=cp.int64)
        energy += corr * corr
    return cp.asnumpy(energy)


def labs_energy_batch(pop: Any, use_gpu: bool = False) -> np.ndarray:
    """Dispatch batch energy evaluation to CuPy or NumPy.

    The CuPy implementation is used only when ``use_gpu`` is truthy and CuPy
    was imported; otherwise the NumPy implementation runs.
    """
    if use_gpu and _CUPY_AVAILABLE:
        return labs_energy_batch_cupy(pop)
    return labs_energy_batch_numpy(pop)


@dataclass(frozen=True)
class LabsCostTerms:
    """Expanded LABS cost as a constant plus two-index and four-index terms.

    Instances are built by ``labs_cost_terms``. The dataclass is frozen, but
    the list fields themselves remain mutable objects.
    """

    # Accumulated constant contribution.
    const: float
    # Two-index terms as (i, j, weight) with i < j.
    pairs: List[Tuple[int, int, float]]
    # Four-index terms as (a, b, c, d, weight) with indices in ascending order.
    quads: List[Tuple[int, int, int, int, float]]


@dataclass(frozen=True)
class CostArrays:
    """Column-wise ("struct of lists") form of the cost terms.

    Built by ``terms_to_cost_arrays``; the parallel lists are passed as
    separate arguments to the sampling kernels. Entry ``t`` of each
    ``pairs_*`` list describes one two-index term, and entry ``t`` of each
    ``quads_*`` list describes one four-index term.
    """

    # Two-index terms: first index, second index, weight.
    pairs_i: List[int]
    pairs_j: List[int]
    pairs_w: List[float]
    # Four-index terms: the four indices, then the weight.
    quads_a: List[int]
    quads_b: List[int]
    quads_c: List[int]
    quads_d: List[int]
    quads_w: List[float]

    @property
    def num_pairs(self) -> int:
        """Number of two-index terms (length of ``pairs_i``)."""
        return len(self.pairs_i)

    @property
    def num_quads(self) -> int:
        """Number of four-index terms (length of ``quads_a``)."""
        return len(self.quads_a)


# Module-level memoisation tables, filled lazily by the get_* helpers below.
# Keyed by N: expanded cost terms, their column-wise arrays, and the
# counterdiabatic index lists.
_TERMS_CACHE: Dict[int, LabsCostTerms] = {}
_COST_CACHE: Dict[int, CostArrays] = {}
_CD_LISTS_CACHE: Dict[int, Tuple[List[int], List[int], List[int], List[int], List[int], List[int]]] = {}
# Keyed by (steps, dt): (gammas, betas, total time) and the theta schedule.
_SCHEDULE_CACHE: Dict[Tuple[int, float], Tuple[List[float], List[float], float]] = {}
_THETA_CACHE: Dict[Tuple[int, float], List[float]] = {}
# Last target name successfully applied by _set_target (None until then).
_TARGET_LAST: Optional[str] = None


def labs_cost_terms(n: int) -> LabsCostTerms:
    """Expand the LABS energy ``sum_k C_k**2`` into constant, 2- and 4-index terms.

    For every lag ``k`` and every ordered pair ``(i, j)`` the product
    ``s_i s_{i+k} s_j s_{j+k}`` is reduced using ``s**2 == 1``: indices that
    occur an even number of times cancel. What remains is a constant, a
    two-index term or a four-index term. Ordered pairs are enumerated, so
    ``(i, j)`` and ``(j, i)`` each contribute weight 1.0 to the same key.

    Returns:
        A ``LabsCostTerms`` whose ``pairs`` / ``quads`` follow first-seen order.

    Raises:
        RuntimeError: if a product reduces to an unexpected number of indices.
    """
    size = int(n)
    offset = 0.0
    two_body: Dict[Tuple[int, int], float] = {}
    four_body: Dict[Tuple[int, int, int, int], float] = {}
    for lag in range(1, size):
        span = size - lag
        for left in range(span):
            for right in range(span):
                if left == right:
                    offset += 1.0
                    continue
                # Toggle membership so only indices seen an odd number of
                # times are left at the end.
                odd = set()
                for site in (left, left + lag, right, right + lag):
                    odd ^= {site}
                survivors = sorted(odd)
                arity = len(survivors)
                if arity == 0:
                    offset += 1.0
                elif arity == 2:
                    key2 = (survivors[0], survivors[1])
                    two_body[key2] = two_body.get(key2, 0.0) + 1.0
                elif arity == 4:
                    key4 = tuple(survivors)  # type: ignore
                    four_body[key4] = four_body.get(key4, 0.0) + 1.0
                else:
                    raise RuntimeError("Unexpected remaining indices length.")
    pairs = [(key[0], key[1], float(weight)) for key, weight in two_body.items()]
    quads = [
        (key[0], key[1], key[2], key[3], float(weight))
        for key, weight in four_body.items()
    ]
    return LabsCostTerms(const=float(offset), pairs=pairs, quads=quads)


def terms_to_cost_arrays(terms: LabsCostTerms) -> CostArrays:
    """Transpose the term tuples in ``terms`` into parallel column lists.

    Indices are coerced to ``int`` and weights to ``float``; the order of the
    terms is preserved. The constant part of ``terms`` is not carried over.
    """
    pairs_i, pairs_j, pairs_w = [], [], []
    for i, j, weight in terms.pairs:
        pairs_i.append(int(i))
        pairs_j.append(int(j))
        pairs_w.append(float(weight))

    quads_a, quads_b, quads_c, quads_d, quads_w = [], [], [], [], []
    for a, b, c, d, weight in terms.quads:
        quads_a.append(int(a))
        quads_b.append(int(b))
        quads_c.append(int(c))
        quads_d.append(int(d))
        quads_w.append(float(weight))

    return CostArrays(
        pairs_i=pairs_i,
        pairs_j=pairs_j,
        pairs_w=pairs_w,
        quads_a=quads_a,
        quads_b=quads_b,
        quads_c=quads_c,
        quads_d=quads_d,
        quads_w=quads_w,
    )


def cd_interactions_lists(n: int) -> Tuple[List[int], List[int], List[int], List[int], List[int], List[int]]:
    """Enumerate the two- and four-index sets used by the DCQO kernel.

    Returns six parallel lists ``(g2_i, g2_j, g4_a, g4_b, g4_c, g4_d)``:

    * two-index entries ``(i0, i0 + k)`` for ``i0`` in ``0..n-3`` and
      ``k`` in ``1..(n - i0 - 1) // 2``;
    * four-index entries ``(i0, i0 + t, i0 + k, i0 + k + t)`` for ``i0`` in
      ``0..n-4``, ``t`` in ``1..(n - i0 - 2) // 2`` and ``k`` in
      ``t+1..n - i0 - t - 1``.
    """
    size = int(n)
    two_body = [
        (start, start + gap)
        for start in range(size - 2)
        for gap in range(1, (size - start - 1) // 2 + 1)
    ]
    four_body = [
        (start, start + t, start + k, start + k + t)
        for start in range(size - 3)
        for t in range(1, (size - start - 2) // 2 + 1)
        for k in range(t + 1, size - start - t)
    ]
    g2_i = [entry[0] for entry in two_body]
    g2_j = [entry[1] for entry in two_body]
    g4_a = [entry[0] for entry in four_body]
    g4_b = [entry[1] for entry in four_body]
    g4_c = [entry[2] for entry in four_body]
    g4_d = [entry[3] for entry in four_body]
    return g2_i, g2_j, g4_a, g4_b, g4_c, g4_d


def linear_schedule_fixed_dt(steps: int, dt: float) -> Tuple[List[float], List[float], float]:
    """Build a linear ramp schedule sampled at the midpoint of each step.

    With ``T = steps * dt`` and ``s_p = (p + 0.5) * dt / T``, step ``p`` gets
    ``gamma_p = dt * s_p`` and ``beta_p = dt * (1 - s_p)``.

    Returns:
        ``(gammas, betas, T)``. With ``steps == 0`` both lists are empty.
    """
    count = int(steps)
    step = float(dt)
    total = count * step
    fractions = [((p + 0.5) * step) / total for p in range(count)]
    gammas: List[float] = [step * s for s in fractions]
    betas: List[float] = [step * (1.0 - s) for s in fractions]
    return gammas, betas, total


def get_cost_arrays(n: int) -> CostArrays:
    """Return the memoised ``CostArrays`` for sequence length ``n``.

    On a miss the expanded terms are fetched from (or stored in)
    ``_TERMS_CACHE`` and their array form is stored in ``_COST_CACHE``.
    """
    size = int(n)
    if size not in _COST_CACHE:
        if size not in _TERMS_CACHE:
            _TERMS_CACHE[size] = labs_cost_terms(size)
        _COST_CACHE[size] = terms_to_cost_arrays(_TERMS_CACHE[size])
    return _COST_CACHE[size]


def get_cd_lists(n: int) -> Tuple[List[int], List[int], List[int], List[int], List[int], List[int]]:
    """Return the memoised output of ``cd_interactions_lists`` for ``n``."""
    size = int(n)
    if size not in _CD_LISTS_CACHE:
        _CD_LISTS_CACHE[size] = cd_interactions_lists(size)
    return _CD_LISTS_CACHE[size]


def get_schedule(steps: int, dt: float) -> Tuple[List[float], List[float], float]:
    """Return the memoised ``(gammas, betas, total_time)`` for ``(steps, dt)``.

    The cache key is ``(int(steps), float(dt))``. The cached tuple is returned
    itself, not a copy.
    """
    cache_key = (int(steps), float(dt))
    if cache_key not in _SCHEDULE_CACHE:
        _SCHEDULE_CACHE[cache_key] = linear_schedule_fixed_dt(steps, dt)
    return _SCHEDULE_CACHE[cache_key]


def get_thetas(steps: int, dt: float, ttot: Optional[float] = None) -> List[float]:
    """Return the per-step theta schedule ``dt * s * (1 - s)``.

    ``s`` is the midpoint fraction ``(p + 0.5) * dt / ttot`` of step ``p``.

    Args:
        steps: number of steps.
        dt: step length.
        ttot: total evolution time. When omitted it is taken from
            ``get_schedule(steps, dt)``, i.e. ``steps * dt``.

    Returns:
        A list of ``steps`` floats. For the canonical ``ttot`` the list is
        memoised in ``_THETA_CACHE`` and the cached object itself is returned.

    The cache is keyed by ``(steps, dt)`` only, so it is consulted and filled
    only when ``ttot`` is the canonical ``steps * dt``. Previously a call with
    a different explicit ``ttot`` could read a stale entry or poison the
    cache for later callers.
    """
    steps_i = int(steps)
    dt_f = float(dt)
    key = (steps_i, dt_f)
    # Same expression get_schedule uses for its total, so this comparison is
    # exact on the normal call path.
    canonical = ttot is None or float(ttot) == steps_i * dt_f
    if canonical:
        cached = _THETA_CACHE.get(key)
        if cached is not None:
            return cached
    if ttot is None:
        _, _, ttot = get_schedule(steps, dt)
    total = float(ttot)
    thetas: List[float] = []
    for p in range(steps_i):
        t_mid = (p + 0.5) * dt_f
        s = t_mid / total
        thetas.append(dt_f * float(s * (1.0 - s)))
    if canonical:
        _THETA_CACHE[key] = thetas
    return thetas


def _set_target(target: Optional[str]) -> None:
    """Switch the cudaq target, skipping the call if it is already active.

    Does nothing when cudaq is unavailable, when ``target`` is ``None``, or
    when ``target`` equals the last target applied through this function.
    A failing ``cudaq.set_target`` is swallowed and ``_TARGET_LAST`` is left
    unchanged.
    """
    global _TARGET_LAST
    if target is None or not _CUDAQ_AVAILABLE:
        return
    if target == _TARGET_LAST:
        return
    try:
        cudaq.set_target(target)
    except Exception:
        return
    else:
        _TARGET_LAST = target


def choose_target(preferred: str, fallbacks: List[str]) -> str:
    """Pick the first target name that ``cudaq.set_target`` accepts.

    ``preferred`` is tried first, then ``fallbacks`` in order; empty names are
    skipped. If cudaq is unavailable or no candidate works, ``preferred`` is
    returned unchanged.
    """
    if not _CUDAQ_AVAILABLE:
        return preferred
    candidates = [preferred, *fallbacks]
    for name in candidates:
        if not name:
            continue
        try:
            cudaq.set_target(name)
        except Exception:
            continue
        return name
    return preferred


def _merge_counts(dst: Dict[str, int], src: Dict[str, int]) -> None:
    for bs, ct in src.items():
        dst[bs] = dst.get(bs, 0) + int(ct)


def sample_chunked(sample_fn, shots: int, chunk_size: int, **kwargs) -> Tuple[Dict[str, int], float]:
    shots = int(shots)
    if shots <= int(chunk_size):
        return sample_fn(shots=shots, **kwargs)
    combined: Dict[str, int] = {}
    tval = 0.0
    remaining = shots
    while remaining > 0:
        chunk = min(remaining, int(chunk_size))
        counts, tval = sample_fn(shots=chunk, **kwargs)
        _merge_counts(combined, counts)
        remaining -= chunk
    return combined, float(tval)

# Kernel objects keyed by qubit count n, so each kernel is defined only once
# per n (see _get_daqo_kernel_rt / _get_dcqo_kernel_rt).
_DAQO_RT_KERNEL_CACHE: Dict[int, Any] = {}
_DCQO_RT_KERNEL_CACHE: Dict[int, Any] = {}


def _get_daqo_kernel_rt(n: int):
    """Return the DAQO sampling kernel for ``n`` qubits, defining it on first use.

    The kernel is cached per ``n`` in ``_DAQO_RT_KERNEL_CACHE``. The qubit
    count is captured from the enclosing scope; the cost terms, schedule and
    scale factors are runtime arguments of the kernel.
    """
    n = int(n)
    ker = _DAQO_RT_KERNEL_CACHE.get(n)
    if ker is not None:
        return ker

    @cudaq.kernel
    def daqo_rt(
        pairs_i: list[int], pairs_j: list[int], pairs_w: list[float],
        quads_a: list[int], quads_b: list[int], quads_c: list[int], quads_d: list[int], quads_w: list[float],
        gammas: list[float], betas: list[float],
        steps: int, num_pairs: int, num_quads: int,
        cost_scale: float, mix_scale: float
    ):
        q = cudaq.qvector(n)
        # Hadamard on every qubit.
        h(q)
        for p in range(steps):
            gamma = gammas[p]
            beta = betas[p]
            # Mixer rotation, first of two per step (repeated after the cost layer).
            for qi in range(n):
                rx(beta * mix_scale, q[qi])
            # Two-index cost terms: CNOT, RZ on the target, CNOT.
            for t in range(num_pairs):
                ii = pairs_i[t]; jj = pairs_j[t]; ww = pairs_w[t]
                x.ctrl(q[ii], q[jj])
                rz(2.0 * gamma * ww * cost_scale, q[jj])
                x.ctrl(q[ii], q[jj])
            # Four-index cost terms: three CNOTs onto q[qd], RZ, then the
            # same CNOTs in reverse order.
            for t in range(num_quads):
                qa = quads_a[t]; qb = quads_b[t]; qc = quads_c[t]; qd = quads_d[t]; ww = quads_w[t]
                x.ctrl(q[qa], q[qd]); x.ctrl(q[qb], q[qd]); x.ctrl(q[qc], q[qd])
                rz(2.0 * gamma * ww * cost_scale, q[qd])
                x.ctrl(q[qc], q[qd]); x.ctrl(q[qb], q[qd]); x.ctrl(q[qa], q[qd])
            # Second mixer rotation of the step.
            for qi in range(n):
                rx(beta * mix_scale, q[qi])
        mz(q)

    _DAQO_RT_KERNEL_CACHE[n] = daqo_rt
    return daqo_rt


def _get_dcqo_kernel_rt(n: int):
    """Return the DCQO sampling kernel for ``n`` qubits, defining it on first use.

    The kernel is cached per ``n`` in ``_DCQO_RT_KERNEL_CACHE``. Per step it
    applies a mixer layer, the cost layer (two- and four-index terms), and
    then additional layers driven by ``thetas`` over the ``cd_*`` index lists.
    """
    n = int(n)
    ker = _DCQO_RT_KERNEL_CACHE.get(n)
    if ker is not None:
        return ker

    @cudaq.kernel
    def dcqo_rt(
        cd_pi: list[int], cd_pj: list[int],
        cd_qa: list[int], cd_qb: list[int], cd_qc: list[int], cd_qd: list[int],
        cost_pi: list[int], cost_pj: list[int], cost_pw: list[float],
        cost_qa: list[int], cost_qb: list[int], cost_qc: list[int], cost_qd: list[int], cost_qw: list[float],
        gammas: list[float], betas: list[float], thetas: list[float],
        steps: int,
        num_cd_pairs: int, num_cd_quads: int,
        num_cost_pairs: int, num_cost_quads: int,
        cost_scale: float, cd_scale: float, mix_scale: float
    ):
        # pi / 2, used for the basis-change rotations below.
        pio2 = 1.5707963267948966
        q = cudaq.qvector(n)
        h(q)
        for p in range(steps):
            # Scale the schedule values once per step.
            beta = betas[p] * mix_scale
            gamma = gammas[p] * cost_scale
            theta = thetas[p] * cd_scale
            # Mixer layer: a single rx(2 * beta) per qubit.
            for qi in range(n):
                rx(2.0 * beta, q[qi])
            # Two-index cost terms: CNOT, RZ on the target, CNOT.
            for t in range(num_cost_pairs):
                ii = cost_pi[t]; jj = cost_pj[t]; ww = cost_pw[t]
                phi = gamma * ww
                x.ctrl(q[ii], q[jj])
                rz(2.0 * phi, q[jj])
                x.ctrl(q[ii], q[jj])
            # Four-index cost terms: CNOT ladder onto q[qd], RZ, ladder reversed.
            for t in range(num_cost_quads):
                qa = cost_qa[t]; qb = cost_qb[t]; qc = cost_qc[t]; qd = cost_qd[t]; ww = cost_qw[t]
                phi = gamma * ww
                x.ctrl(q[qa], q[qd]); x.ctrl(q[qb], q[qd]); x.ctrl(q[qc], q[qd])
                rz(2.0 * phi, q[qd])
                x.ctrl(q[qc], q[qd]); x.ctrl(q[qb], q[qd]); x.ctrl(q[qa], q[qd])
            # Two-index theta terms: the same CNOT/RZ/CNOT block is applied
            # twice, each time wrapped in rx(+pi/2) ... rx(-pi/2) on a
            # different one of the two qubits.
            # NOTE(review): presumably realises the Y-type counterdiabatic
            # terms via a basis change -- confirm operator and sign conventions.
            for t in range(num_cd_pairs):
                ii = cd_pi[t]; jj = cd_pj[t]
                rx(pio2, q[ii])
                x.ctrl(q[ii], q[jj]); rz(4.0 * theta, q[jj]); x.ctrl(q[ii], q[jj])
                rx(-pio2, q[ii])
                rx(pio2, q[jj])
                x.ctrl(q[ii], q[jj]); rz(4.0 * theta, q[jj]); x.ctrl(q[ii], q[jj])
                rx(-pio2, q[jj])
            # Four-index theta terms: the CNOT-ladder/RZ block is applied four
            # times, with the rx(+-pi/2) wrapper moved to qa, qb, qc and qd in turn.
            for t in range(num_cd_quads):
                qa = cd_qa[t]; qb = cd_qb[t]; qc = cd_qc[t]; qd = cd_qd[t]
                rx(pio2, q[qa])
                x.ctrl(q[qa], q[qd]); x.ctrl(q[qb], q[qd]); x.ctrl(q[qc], q[qd])
                rz(8.0 * theta, q[qd])
                x.ctrl(q[qc], q[qd]); x.ctrl(q[qb], q[qd]); x.ctrl(q[qa], q[qd])
                rx(-pio2, q[qa])
                rx(pio2, q[qb])
                x.ctrl(q[qa], q[qd]); x.ctrl(q[qb], q[qd]); x.ctrl(q[qc], q[qd])
                rz(8.0 * theta, q[qd])
                x.ctrl(q[qc], q[qd]); x.ctrl(q[qb], q[qd]); x.ctrl(q[qa], q[qd])
                rx(-pio2, q[qb])
                rx(pio2, q[qc])
                x.ctrl(q[qa], q[qd]); x.ctrl(q[qb], q[qd]); x.ctrl(q[qc], q[qd])
                rz(8.0 * theta, q[qd])
                x.ctrl(q[qc], q[qd]); x.ctrl(q[qb], q[qd]); x.ctrl(q[qa], q[qd])
                rx(-pio2, q[qc])
                rx(pio2, q[qd])
                x.ctrl(q[qa], q[qd]); x.ctrl(q[qb], q[qd]); x.ctrl(q[qc], q[qd])
                rz(8.0 * theta, q[qd])
                x.ctrl(q[qc], q[qd]); x.ctrl(q[qb], q[qd]); x.ctrl(q[qa], q[qd])
                rx(-pio2, q[qd])
        mz(q)

    _DCQO_RT_KERNEL_CACHE[n] = dcqo_rt
    return dcqo_rt


def _sample_daqo_counts_once(n: int, shots: int, steps: int, dt: float, cost_scale: float, mix_scale: float, target: str) -> Tuple[Dict[str, int], float]:
    """Sample the DAQO kernel once with ``shots`` shots.

    Selects ``target``, fetches the cached cost arrays, schedule and kernel,
    and runs ``cudaq.sample``.

    Returns:
        ``(counts, total_time)``: the sample result's items as a dict, and the
        schedule's total time.

    Raises:
        RuntimeError: if cudaq could not be imported.
    """
    if not _CUDAQ_AVAILABLE:
        raise RuntimeError("cudaq not available.")
    _set_target(target)
    terms = get_cost_arrays(n)
    gammas, betas, total_time = get_schedule(steps, dt)
    kernel = _get_daqo_kernel_rt(n)
    # Positional arguments in the exact order of the kernel signature.
    kernel_args = (
        terms.pairs_i, terms.pairs_j, terms.pairs_w,
        terms.quads_a, terms.quads_b, terms.quads_c, terms.quads_d, terms.quads_w,
        gammas, betas,
        int(steps), int(terms.num_pairs), int(terms.num_quads),
        float(cost_scale), float(mix_scale),
    )
    result = cudaq.sample(kernel, *kernel_args, shots_count=int(shots))
    return dict(result.items()), float(total_time)


def _sample_dcqo_counts_once(n: int, shots: int, steps: int, dt: float, cost_scale: float, cd_scale: float, mix_scale: float, target: str) -> Tuple[Dict[str, int], float]:
    """Sample the DCQO kernel once with ``shots`` shots.

    Selects ``target``, fetches the cached cost arrays, index lists, schedule,
    thetas and kernel, and runs ``cudaq.sample``.

    Returns:
        ``(counts, total_time)``: the sample result's items as a dict, and the
        schedule's total time.

    Raises:
        RuntimeError: if cudaq could not be imported.
    """
    if not _CUDAQ_AVAILABLE:
        raise RuntimeError("cudaq not available.")
    _set_target(target)
    terms = get_cost_arrays(n)
    cd_lists = get_cd_lists(n)
    gammas, betas, total_time = get_schedule(steps, dt)
    thetas = get_thetas(steps, dt, total_time)
    kernel = _get_dcqo_kernel_rt(n)
    # Positional arguments in the exact order of the kernel signature.
    # cd_lists[0] holds the two-index list and cd_lists[2] the four-index list
    # whose lengths give the term counts.
    kernel_args = (
        *cd_lists,
        terms.pairs_i, terms.pairs_j, terms.pairs_w,
        terms.quads_a, terms.quads_b, terms.quads_c, terms.quads_d, terms.quads_w,
        gammas, betas, thetas,
        int(steps),
        int(len(cd_lists[0])), int(len(cd_lists[2])),
        int(terms.num_pairs), int(terms.num_quads),
        float(cost_scale), float(cd_scale), float(mix_scale),
    )
    result = cudaq.sample(kernel, *kernel_args, shots_count=int(shots))
    return dict(result.items()), float(total_time)


def sample_daqo_counts(n: int, shots: int, steps: int, dt: float, cost_scale: float, mix_scale: float, target: str) -> Tuple[Dict[str, int], float]:
    """Sample DAQO counts, chunking when ``shots`` exceeds ``CONFIG["CHUNK_SIZE"]``.

    A missing or zero chunk size, or a shot count within one chunk, results in
    a single direct sampling call.
    """
    chunk = int(CONFIG.get("CHUNK_SIZE", 0) or 0)
    if not chunk or int(shots) <= chunk:
        return _sample_daqo_counts_once(
            int(n), int(shots), int(steps), float(dt),
            float(cost_scale), float(mix_scale), target,
        )
    sampler_kwargs = {
        "n": int(n),
        "steps": int(steps),
        "dt": float(dt),
        "cost_scale": float(cost_scale),
        "mix_scale": float(mix_scale),
        "target": target,
    }
    return sample_chunked(
        _sample_daqo_counts_once,
        shots=int(shots),
        chunk_size=chunk,
        **sampler_kwargs,
    )


def sample_dcqo_counts(n: int, shots: int, steps: int, dt: float, cost_scale: float, cd_scale: float, mix_scale: float, target: str) -> Tuple[Dict[str, int], float]:
    """Sample DCQO counts, chunking when ``shots`` exceeds ``CONFIG["CHUNK_SIZE"]``.

    A missing or zero chunk size, or a shot count within one chunk, results in
    a single direct sampling call.
    """
    chunk = int(CONFIG.get("CHUNK_SIZE", 0) or 0)
    if not chunk or int(shots) <= chunk:
        return _sample_dcqo_counts_once(
            int(n), int(shots), int(steps), float(dt),
            float(cost_scale), float(cd_scale), float(mix_scale), target,
        )
    sampler_kwargs = {
        "n": int(n),
        "steps": int(steps),
        "dt": float(dt),
        "cost_scale": float(cost_scale),
        "cd_scale": float(cd_scale),
        "mix_scale": float(mix_scale),
        "target": target,
    }
    return sample_chunked(
        _sample_dcqo_counts_once,
        shots=int(shots),
        chunk_size=chunk,
        **sampler_kwargs,
    )


def weighted_quantile(values: np.ndarray, weights: np.ndarray, q: float) -> float:
    """Return the weighted ``q``-quantile of ``values``.

    The values are sorted, the normalised cumulative weight is formed, and the
    first value whose cumulative weight reaches ``q`` is returned (no
    interpolation).

    Args:
        values: sample values (any array-like).
        weights: non-negative weights, one per value.
        q: quantile level, normally in ``[0, 1]``.

    Returns:
        The selected value as a float. The result is NaN when there are no
        values or when the total weight is not positive. Previously a zero
        total weight caused a division by zero and returned the minimum.
    """
    v_all = np.asarray(values, dtype=float)
    w_all = np.asarray(weights, dtype=float)
    if v_all.size == 0:
        return float("nan")
    total = float(np.sum(w_all))
    if not total > 0.0:  # also catches a NaN total
        return float("nan")
    order = np.argsort(v_all)
    v = v_all[order]
    w = w_all[order]
    cdf = np.cumsum(w) / total
    idx = int(np.searchsorted(cdf, q, side="left"))
    # Rounding can leave cdf[-1] slightly below 1.0; clamp to a valid index.
    idx = min(max(idx, 0), len(v) - 1)
    return float(v[idx])


def counts_stats(counts: Dict[str, int], n: int) -> Dict[str, float]:
    """Summarise the energies of sampled bitstrings.

    Every distinct bitstring's LABS energy is computed (on the GPU when
    ``CONFIG["USE_GPU_STATS"]`` allows), then the minimum and the
    count-weighted 10/50/90 % quantiles are reported.

    Returns:
        A dict with keys ``unique``, ``bestE``, ``q10``, ``q50``, ``q90``.
        For empty ``counts``, ``unique`` is 0.0 and the rest are NaN.
    """
    nan = float("nan")
    if not counts:
        return {"unique": 0.0, "bestE": nan, "q10": nan, "q50": nan, "q90": nan}
    bitstrings = list(counts)
    bit_rows = bitstrings_to_bits_array(bitstrings, int(n))
    on_gpu = bool(CONFIG.get("USE_GPU_STATS", False))
    energies = labs_energy_batch(bit_rows, use_gpu=on_gpu).astype(float)
    shot_weights = np.fromiter(
        (counts[b] for b in bitstrings), dtype=float, count=len(bitstrings)
    )
    lowest = float(np.min(energies)) if energies.size else nan
    summary: Dict[str, float] = {
        "unique": float(len(counts)),
        "bestE": float(lowest),
    }
    for label, level in (("q10", 0.10), ("q50", 0.50), ("q90", 0.90)):
        summary[label] = weighted_quantile(energies, shot_weights, level)
    return summary


def random_stats(n: int, shots: int, rng: np.random.Generator) -> Dict[str, float]:
    """Energy statistics of ``shots`` uniformly random length-``n`` bit sequences.

    Args:
        n: sequence length.
        shots: number of random sequences; values <= 0 yield an all-NaN
            summary with ``unique`` 0.0.
        rng: NumPy generator that determines the result. On the CPU path it
            draws the bits directly. On the GPU path it supplies the seed for
            a private CuPy ``RandomState``. Previously the GPU path used
            CuPy's global, unseeded generator, so results could not be
            reproduced from the caller's seed.

    Returns:
        A dict with keys ``unique``, ``bestE``, ``q10``, ``q50``, ``q90``.
        ``unique`` is NaN unless ``CONFIG["RANDOM_UNIQUE"]`` is truthy.
    """
    nan = float("nan")
    shots = int(shots)
    if shots <= 0:
        return {"unique": 0.0, "bestE": nan, "q10": nan, "q50": nan, "q90": nan}
    want_unique = bool(CONFIG.get("RANDOM_UNIQUE", False))
    use_gpu = bool(CONFIG.get("USE_GPU_RANDOM", False)) and _CUPY_AVAILABLE
    if use_gpu:
        # Derive the device-side seed from the caller's generator so the run
        # stays reproducible, without touching CuPy's global random state.
        gpu_seed = int(rng.integers(0, 2**32))
        gpu_rng = cp.random.RandomState(seed=gpu_seed)
        bits_d = gpu_rng.randint(0, 2, size=(shots, int(n)), dtype=cp.int8)
        e = labs_energy_batch_cupy(bits_d).astype(float)
        unique = float(cp.unique(bits_d, axis=0).shape[0]) if want_unique else nan
    else:
        bits = rng.integers(0, 2, size=(shots, int(n)), dtype=np.int8)
        e = labs_energy_batch_numpy(bits).astype(float)
        unique = float(np.unique(bits, axis=0).shape[0]) if want_unique else nan
    w = np.ones_like(e)  # every random draw carries the same weight
    return {
        "unique": unique,
        "bestE": float(np.min(e)),
        "q10": weighted_quantile(e, w, 0.10),
        "q50": weighted_quantile(e, w, 0.50),
        "q90": weighted_quantile(e, w, 0.90),
    }

@dataclass
class TabuConfig:
    """Settings read by ``tabu_search``."""

    # Upper bound on iterations of the main loop.
    max_iters: int = 80
    # Added to the current iteration index to set how long a flipped position stays tabu.
    tenure: int = 6
    # Whether a tabu move may still be taken when it beats the best energy seen so far.
    aspiration: bool = True
    # If 0 < sample_k < n, only a random subset of this many positions is evaluated per iteration.
    sample_k: int = 0
    # If > 0, stop after this many consecutive iterations without a new best energy.
    plateau_limit: int = 25
    # When False, a non-worsening flip is looked for before a worsening one is applied.
    allow_worsen: bool = True


@dataclass
class MTSConfig:
    """Settings read by ``memetic_tabu_search``."""

    # NOTE(review): not read by the search code visible in this part of the file -- confirm usage.
    population_size: int = 64
    # Probability of building the child by crossover instead of copying one member.
    pcomb: float = 0.9
    # Per-bit flip probability handed to mutate().
    pmutate: float = 0.1
    # The search stops once the best energy is <= this value.
    target_e: int = 0
    # Upper bound on outer-loop iterations.
    max_outer_iters: int = 1500
    # Settings for the inner tabu_search call.
    tabu: TabuConfig = field(default_factory=TabuConfig)


def mutate(bits: List[int], pmutate: float, rng: np.random.Generator) -> List[int]:
    """Return a copy of ``bits`` with each position flipped with probability ``pmutate``.

    With ``pmutate <= 0`` or empty ``bits`` a plain copy is returned and
    ``rng`` is not consulted. Otherwise one uniform draw per position decides
    whether that bit is XOR-ed with 1.
    """
    if pmutate <= 0.0 or not bits:
        return bits[:]
    genome = np.asarray(bits, dtype=np.int8)
    flips = rng.random(genome.shape[0]) < float(pmutate)
    # XOR with the 0/1 mask; positions where the mask is 0 are left unchanged.
    np.bitwise_xor(genome, flips.astype(np.int8), out=genome)
    return genome.tolist()


def crossover_half(p1: List[int], p2: List[int]) -> List[int]:
    """One-point crossover at the midpoint of ``p1``.

    The child takes the first ``len(p1) // 2`` entries of ``p1`` followed by
    the entries of ``p2`` from that same index onward.
    """
    midpoint = len(p1) // 2
    head = p1[:midpoint]
    tail = p2[midpoint:]
    return head + tail


def _bits_to_spins(bits: List[int]) -> np.ndarray:
    b = np.asarray(bits, dtype=np.int8)
    return (2 * b - 1).astype(np.int8)


def _corrs_from_spins(s: np.ndarray) -> np.ndarray:
    n = int(s.shape[0])
    c = np.zeros((n,), dtype=np.int32)
    for k in range(1, n):
        c[k] = int(np.sum(s[: n - k] * s[k:], dtype=np.int64))
    return c


def _deltaE_for_flip(s: np.ndarray, c: np.ndarray, i: int) -> int:
    n = int(s.shape[0])
    si = int(s[i])
    de = 0
    for k in range(1, n):
        acc = 0
        j = i + k
        if j < n:
            acc += int(s[j])
        j = i - k
        if j >= 0:
            acc += int(s[j])
        dc = -2 * si * acc
        ck = int(c[k])
        de += 2 * ck * dc + dc * dc
    return int(de)


def _apply_flip_update_c(s: np.ndarray, c: np.ndarray, i: int) -> None:
    n = int(s.shape[0])
    si_old = int(s[i])
    for k in range(1, n):
        acc = 0
        j = i + k
        if j < n:
            acc += int(s[j])
        j = i - k
        if j >= 0:
            acc += int(s[j])
        dc = -2 * si_old * acc
        c[k] = np.int32(int(c[k]) + dc)
    s[i] = np.int8(-si_old)


def tabu_search(start_bits: List[int], cfg: TabuConfig, rng: Optional[np.random.Generator] = None) -> Tuple[List[int], int]:
    """Single-bit-flip tabu search on the LABS energy.

    Each iteration evaluates the energy after flipping each candidate
    position (all positions, or a random subset of ``cfg.sample_k``), takes
    the best admissible flip, and marks that position tabu for the following
    iterations.

    Args:
        start_bits: starting sequence; it is copied, never modified.
        cfg: search settings (see ``TabuConfig``).
        rng: generator used only for candidate sub-sampling; a fresh default
            generator is created when omitted.

    Returns:
        ``(best_bits, best_energy)`` over all visited states. An empty input
        returns ``([], 0)``.

    Tabu rule: a tabu position is skipped unless aspiration is enabled AND
    the flip would beat the best energy seen so far. Previously, with
    ``cfg.aspiration == False``, tabu moves were never skipped, so the tabu
    list had no effect.
    """
    if rng is None:
        rng = np.random.default_rng()
    bits = start_bits[:]
    n = len(bits)
    if n == 0:
        return bits, 0
    s = _bits_to_spins(bits)
    c = _corrs_from_spins(s)
    e = int(np.sum(c[1:].astype(np.int64) ** 2))
    best_bits = bits[:]
    best_e = e
    # tabu_until[i] is the first iteration at which position i is free again.
    tabu_until = np.zeros((n,), dtype=np.int64)
    no_improve = 0
    sample_k = int(cfg.sample_k)
    tenure = int(cfg.tenure)
    plateau_limit = int(cfg.plateau_limit)
    for it in range(int(cfg.max_iters)):
        if 0 < sample_k < n:
            cand = rng.choice(n, size=sample_k, replace=False)
        else:
            cand = range(n)
        best_move = None
        best_move_e = None
        for i in cand:
            i = int(i)
            e_new = e + _deltaE_for_flip(s, c, i)
            is_tabu = it < int(tabu_until[i])
            # Aspiration: a tabu move is admissible only if it sets a new best.
            if is_tabu and not (cfg.aspiration and e_new < best_e):
                continue
            if best_move_e is None or e_new < best_move_e:
                best_move = i
                best_move_e = int(e_new)
        if best_move is None:
            # Every candidate was tabu: fall back to the best flip overall,
            # ignoring tabu status.
            for i in range(n):
                e_new = e + _deltaE_for_flip(s, c, i)
                if best_move_e is None or e_new < best_move_e:
                    best_move = i
                    best_move_e = int(e_new)
        assert best_move is not None and best_move_e is not None
        if (not cfg.allow_worsen) and best_move_e > e:
            # Prefer the first non-worsening flip (tabu status ignored).
            # NOTE(review): if none exists the worsening move is still
            # applied, as before -- confirm whether allow_worsen=False should
            # stop the search here instead.
            for i in range(n):
                e_new = e + _deltaE_for_flip(s, c, i)
                if e_new <= e:
                    best_move = i
                    best_move_e = int(e_new)
                    break
        bits[int(best_move)] ^= 1
        _apply_flip_update_c(s, c, int(best_move))
        e = int(best_move_e)
        tabu_until[int(best_move)] = it + tenure
        if e < best_e:
            best_e = e
            best_bits = bits[:]
            no_improve = 0
        else:
            no_improve += 1
            if plateau_limit > 0 and no_improve >= plateau_limit:
                break
    return best_bits, int(best_e)


def memetic_tabu_search(population: List[List[int]], cfg: MTSConfig, rng: Optional[np.random.Generator] = None, time_budget_s: Optional[float] = None) -> Tuple[List[int], int, Dict[str, float]]:
    """Memetic search: recombine/mutate population members, refine with tabu search.

    Each outer iteration builds a child (midpoint crossover of two distinct
    members with probability ``cfg.pcomb``, otherwise a copy of one member),
    mutates it, refines it with ``tabu_search`` and writes it over a randomly
    chosen population slot.

    Args:
        population: non-empty list of bit lists. It is modified in place,
            because slots are overwritten with children.
        cfg: search settings (see ``MTSConfig``).
        rng: random generator; a fresh default generator is created if omitted.
        time_budget_s: optional wall-clock limit, checked at the start of each
            iteration.

    Returns:
        ``(best_bits, best_energy, stats)``. ``stats`` has the keys
        ``outer_iters``, ``elapsed_s``, ``best_e``, ``t_hit_s`` (NaN if the
        target was never reached) and ``success`` (0.0 or 1.0).

    Raises:
        ValueError: if ``population`` is empty.

    Fixes relative to the previous version:
        * A population of size 1 no longer crashes in
          ``rng.integers(0, 0)``; crossover is skipped because it needs two
          distinct parents.
        * ``outer_iters`` now counts completed iterations. It used to be
          ``loop_index + 1``, which over-counted by one whenever the loop
          ended via ``break`` (e.g. it reported 1 when the initial population
          already met the target).
    """
    if rng is None:
        rng = np.random.default_rng()
    k = len(population)
    if k == 0:
        raise ValueError("Population is empty.")
    t0 = time.perf_counter()
    energies = labs_energy_batch(population, use_gpu=bool(CONFIG.get("USE_GPU_ENERGY", False))).astype(np.int64)
    best_idx = int(np.argmin(energies))
    best_bits = population[best_idx][:]
    best_e = int(energies[best_idx])
    target_e = int(cfg.target_e)
    t_hit = float("nan")
    success = 0.0
    if best_e <= target_e:
        t_hit = 0.0
        success = 1.0
    completed = 0
    for _ in range(int(cfg.max_outer_iters)):
        if success or best_e <= target_e:
            break
        if time_budget_s is not None and (time.perf_counter() - t0) >= time_budget_s:
            break
        # Always draw, so the random stream matches the previous behaviour on
        # every path that used to work.
        want_crossover = rng.random() < float(cfg.pcomb)
        if want_crossover and k > 1:
            i1 = int(rng.integers(0, k))
            # Draw from k - 1 slots and skip over i1 to get a distinct parent.
            i2 = int(rng.integers(0, k - 1))
            if i2 >= i1:
                i2 += 1
            child = crossover_half(population[i1], population[i2])
        else:
            child = population[int(rng.integers(0, k))][:]
        child = mutate(child, float(cfg.pmutate), rng)
        child, child_e = tabu_search(child, cfg.tabu, rng=rng)
        if child_e < best_e:
            best_e = int(child_e)
            best_bits = child[:]
            if best_e <= target_e:
                t_hit = time.perf_counter() - t0
                success = 1.0
        # Replace a uniformly random member, regardless of its energy.
        r = int(rng.integers(0, k))
        population[r] = child
        energies[r] = int(child_e)
        completed += 1
    elapsed = time.perf_counter() - t0
    stats = {
        "outer_iters": float(completed),
        "elapsed_s": float(elapsed),
        "best_e": float(best_e),
        "t_hit_s": float(t_hit),
        "success": float(success),
    }
    return best_bits, int(best_e), stats


def build_population_from_counts(counts: Dict[str, int], n: int, k: int, rng: np.random.Generator, n_rand_pad: int = 10) -> List[List[int]]:
    """Build a population of exactly ``k`` bit lists from sampled counts.

    Up to ``k - n_rand_pad`` members come from the samples: the first half
    are the lowest-energy bitstrings (ties broken by higher count), the rest
    are the most frequent bitstrings not already chosen. Remaining slots are
    filled with uniformly random sequences drawn from ``rng``.

    Args:
        counts: bitstring -> count. When empty, all ``k`` members are random.
        n: sequence length.
        k: population size.
        rng: generator for the random padding.
        n_rand_pad: number of slots reserved for random members.

    Returns:
        A list of ``k`` lists of 0/1 ints.

    Fix: the selection loops used to append before checking their quota, so a
    quota of 0 still added one member per loop and the result could exceed
    ``k`` (e.g. ``k=1`` gave 2 members). Quotas are now checked before each
    append, ``k`` is normalised to int, and the sample-derived share is
    clamped to ``[0, k]``.
    """
    n = int(n)
    k = int(k)
    if not counts:
        return [rng.integers(0, 2, size=(n,), dtype=np.int8).tolist() for _ in range(k)]
    keys = list(counts.keys())
    weights = np.fromiter((counts[kb] for kb in keys), dtype=np.int64, count=len(keys))
    bits = bitstrings_to_bits_array(keys, n)
    energies = labs_energy_batch(bits, use_gpu=bool(CONFIG.get("USE_GPU_STATS", False))).astype(np.int64)
    # lexsort uses the last key as primary: energy ascending, then count descending.
    order_e = np.lexsort((-weights, energies))
    order_c = np.argsort(-weights)

    pop: List[List[int]] = []
    seen = set()
    take_total = min(k, max(0, k - int(n_rand_pad)))
    take_e = take_total // 2

    def _take(order: np.ndarray, limit: int) -> None:
        # Append unseen rows in the given order until pop holds `limit` members.
        for idx in order:
            if len(pop) >= limit:
                break
            row = tuple(bits[idx].tolist())
            if row in seen:
                continue
            pop.append(list(row))
            seen.add(row)

    _take(order_e, take_e)
    _take(order_c, take_total)

    while len(pop) < k:
        pop.append(rng.integers(0, 2, size=(n,), dtype=np.int8).tolist())
    return pop

def run_phase_a() -> None:
    out_dir = CONFIG["OUT_DIR"]
    os.makedirs(out_dir, exist_ok=True)
    out_csv = CONFIG["PHASE_A_CSV"]
    if not _CUDAQ_AVAILABLE:
        raise SystemExit("CUDA-Q not available.")
    target = choose_target(CONFIG["TARGET"], list(CONFIG["TARGET_FALLBACKS"]))
    print(f"[phase-a] target={target}")

    heartbeat_every = float(CONFIG.get("HEARTBEAT_EVERY_S", 0.0) or 0.0)
    start_time = time.perf_counter()
    last_heartbeat = start_time

    def heartbeat(note: str) -> None:
        nonlocal last_heartbeat
        if not heartbeat_every:
            return
        now = time.perf_counter()
        if (now - last_heartbeat) < heartbeat_every:
            return
        elapsed = now - start_time
        print(f"[phase-a] t={elapsed:.1f}s {note}")
        last_heartbeat = now

    fieldnames = [
        "method", "rep", "seed",
        "N", "steps", "dt", "T",
        "cost_scale", "cd_scale", "mix_scale",
        "shots", "sample_elapsed_s", "shots_per_s",
        "unique", "bestE_sample", "q10", "q50", "q90",
        "status", "error",
    ]
    with open(out_csv, "w", newline="") as f:
        w = csv.DictWriter(f, fieldnames=fieldnames)
        w.writeheader()
        reps = int(CONFIG["REPEATS"])
        base_seed = int(CONFIG["BASE_SEED"])
        for n in CONFIG["N_LIST"]:
            for steps in CONFIG["STEPS_LIST"]:
                for cost_scale in CONFIG["COST_SCALES"]:
                    for cd_scale in CONFIG["CD_SCALES"]:
                        shots_daqo = int(CONFIG["SHOTS_DAQO"])
                        shots_dcqo = shots_daqo if bool(CONFIG.get("SHOTS_MATCHED", True)) else int(CONFIG["SHOTS_DCQO"])
                        print(f"[phase-a] case N={n} steps={steps} cost={cost_scale} cd={cd_scale}")
                        for rep in range(reps):
                            seed = base_seed + 100000 * int(n) + 1000 * int(steps) + 17 * int(rep)
                            rng = np.random.default_rng(seed)
                            ttot = float(steps) * float(CONFIG["DT"])

                            if "DAQO" in CONFIG["METHODS_A"]:
                                if n > int(CONFIG["DAQO_MAX_N"]):
                                    w.writerow({"method": "DAQO", "rep": rep, "seed": seed, "N": n, "steps": steps, "dt": CONFIG["DT"], "T": ttot,
                                                "cost_scale": cost_scale, "cd_scale": 0.0, "mix_scale": CONFIG["MIX_SCALE"], "shots": shots_daqo,
                                                "status": "skipped", "error": "DAQO_MAX_N"})
                                else:
                                    t0 = time.perf_counter()
                                    try:
                                        counts, _ = sample_daqo_counts(int(n), shots_daqo, int(steps), float(CONFIG["DT"]),
                                                                       float(cost_scale), float(CONFIG["MIX_SCALE"]), target)
                                        dt = time.perf_counter() - t0
                                        st = counts_stats(counts, int(n))
                                        w.writerow({"method": "DAQO", "rep": rep, "seed": seed, "N": n, "steps": steps, "dt": CONFIG["DT"], "T": ttot,
                                                    "cost_scale": cost_scale, "cd_scale": 0.0, "mix_scale": CONFIG["MIX_SCALE"], "shots": shots_daqo,
                                                    "sample_elapsed_s": dt, "shots_per_s": float(shots_daqo) / max(dt, 1e-12),
                                                    "unique": st["unique"], "bestE_sample": st["bestE"], "q10": st["q10"], "q50": st["q50"], "q90": st["q90"],
                                                    "status": "ok", "error": ""})
                                    except Exception as exc:
                                        w.writerow({"method": "DAQO", "rep": rep, "seed": seed, "N": n, "steps": steps, "dt": CONFIG["DT"], "T": ttot,
                                                    "cost_scale": cost_scale, "cd_scale": 0.0, "mix_scale": CONFIG["MIX_SCALE"], "shots": shots_daqo,
                                                    "status": "error", "error": str(exc)})

                            if "DCQO" in CONFIG["METHODS_A"]:
                                if n > int(CONFIG["DCQO_MAX_N"]):
                                    w.writerow({"method": "DCQO", "rep": rep, "seed": seed, "N": n, "steps": steps, "dt": CONFIG["DT"], "T": ttot,
                                                "cost_scale": cost_scale, "cd_scale": cd_scale, "mix_scale": CONFIG["MIX_SCALE"], "shots": shots_dcqo,
                                                "status": "skipped", "error": "DCQO_MAX_N"})
                                else:
                                    t0 = time.perf_counter()
                                    try:
                                        counts, _ = sample_dcqo_counts(int(n), shots_dcqo, int(steps), float(CONFIG["DT"]),
                                                                       float(cost_scale), float(cd_scale), float(CONFIG["MIX_SCALE"]), target)
                                        dt = time.perf_counter() - t0
                                        st = counts_stats(counts, int(n))
                                        w.writerow({"method": "DCQO", "rep": rep, "seed": seed, "N": n, "steps": steps, "dt": CONFIG["DT"], "T": ttot,
                                                    "cost_scale": cost_scale, "cd_scale": cd_scale, "mix_scale": CONFIG["MIX_SCALE"], "shots": shots_dcqo,
                                                    "sample_elapsed_s": dt, "shots_per_s": float(shots_dcqo) / max(dt, 1e-12),
                                                    "unique": st["unique"], "bestE_sample": st["bestE"], "q10": st["q10"], "q50": st["q50"], "q90": st["q90"],
                                                    "status": "ok", "error": ""})
                                    except Exception as exc:
                                        w.writerow({"method": "DCQO", "rep": rep, "seed": seed, "N": n, "steps": steps, "dt": CONFIG["DT"], "T": ttot,
                                                    "cost_scale": cost_scale, "cd_scale": cd_scale, "mix_scale": CONFIG["MIX_SCALE"], "shots": shots_dcqo,
                                                    "status": "error", "error": str(exc)})

                            if "RANDOM" in CONFIG["METHODS_A"]:
                                t0 = time.perf_counter()
                                st = random_stats(int(n), int(CONFIG["SHOTS_RANDOM"]), rng)
                                dt = time.perf_counter() - t0
                                w.writerow({"method": "RANDOM", "rep": rep, "seed": seed, "N": n, "steps": steps, "dt": CONFIG["DT"], "T": ttot,
                                            "cost_scale": cost_scale, "cd_scale": 0.0, "mix_scale": CONFIG["MIX_SCALE"], "shots": CONFIG["SHOTS_RANDOM"],
                                            "sample_elapsed_s": dt, "shots_per_s": float(CONFIG["SHOTS_RANDOM"]) / max(dt, 1e-12),
                                            "unique": st["unique"], "bestE_sample": st["bestE"], "q10": st["q10"], "q50": st["q50"], "q90": st["q90"],
                                            "status": "ok", "error": ""})
                            f.flush()
                            heartbeat(f"rep={rep} N={n} steps={steps} cost={cost_scale} cd={cd_scale}")
    print(f"[phase-a] wrote {out_csv}")


def _to_float(v: Any) -> float:
    """Convert *v* to ``float``, yielding NaN when the conversion fails.

    Any exception raised by ``float(v)`` (for example on ``None``, an empty
    string, or non-numeric text) is absorbed and NaN is returned instead.
    """
    try:
        parsed = float(v)
    except Exception:
        # Best-effort parse: an unparseable cell becomes NaN rather than an error.
        parsed = float("nan")
    return parsed


def run_phase_b() -> None:
    """Re-sample each successful phase-A row, run MTS on it, and flag fast outliers.

    Steps:
      1. Read ``CONFIG["PHASE_A_CSV"]`` and keep rows whose ``status`` is
         ``"ok"`` and whose ``method`` is listed in ``CONFIG["METHODS_B"]``.
      2. For every kept DAQO/DCQO row, call the matching sampler again, build
         a population from the counts and run ``memetic_tabu_search`` under
         ``CONFIG["MTS_TIME_BUDGET_S"]``.  Rows with any other method are
         skipped.
      3. For each ``(N, steps, cost_scale)`` group, compute ``tau`` as the
         ``CONFIG["OUTLIER_PCTL"]`` quantile of ``tts_hit_s`` over successful
         DCQO runs (groups smaller than ``CONFIG["OUTLIER_MIN_GROUP"]`` get no
         threshold).  Any successful row, of any method, with
         ``tts_hit_s <= tau`` is marked ``fast_outlier = 1``.
      4. Write every result to ``CONFIG["PHASE_B_FULL_CSV"]`` and the flagged
         rows (optionally capped per group by
         ``CONFIG["OUTLIER_MAX_PER_GROUP"]``) to ``CONFIG["PHASE_B_CSV"]``.

    Results are only written after all rows have been processed.

    Raises:
        SystemExit: if CUDA-Q could not be imported.
    """
    out_dir = CONFIG["OUT_DIR"]
    os.makedirs(out_dir, exist_ok=True)
    in_csv = CONFIG["PHASE_A_CSV"]
    out_full_csv = CONFIG["PHASE_B_FULL_CSV"]
    out_csv = CONFIG["PHASE_B_CSV"]
    if not _CUDAQ_AVAILABLE:
        raise SystemExit("CUDA-Q not available.")
    target = choose_target(CONFIG["TARGET"], list(CONFIG["TARGET_FALLBACKS"]))
    print(f"[phase-b] target={target}")

    # A heartbeat interval of 0 (or a missing/None setting) disables progress output.
    heartbeat_every = float(CONFIG.get("HEARTBEAT_EVERY_S", 0.0) or 0.0)
    start_time = time.perf_counter()
    last_heartbeat = start_time

    def heartbeat(note: str) -> None:
        """Print a progress line at most once per ``heartbeat_every`` seconds."""
        nonlocal last_heartbeat
        if not heartbeat_every:
            return
        now = time.perf_counter()
        if (now - last_heartbeat) < heartbeat_every:
            return
        elapsed = now - start_time
        print(f"[phase-b] t={elapsed:.1f}s {note}")
        last_heartbeat = now

    def group_key(r: Dict[str, Any]) -> Tuple[int, int, float]:
        """Return the (N, steps, cost_scale) key used for thresholding and capping."""
        return (int(r["N"]), int(r["steps"]), float(r["cost_scale"]))

    # Load the phase-A rows that are eligible for re-sampling, coercing the
    # CSV strings to the numeric types used below.
    rows: List[Dict[str, Any]] = []
    with open(in_csv, "r", newline="") as f:
        r = csv.DictReader(f)
        for row in r:
            if row.get("status") != "ok":
                continue
            if row.get("method") not in CONFIG["METHODS_B"]:
                continue
            row["sample_elapsed_s"] = _to_float(row.get("sample_elapsed_s"))
            row["shots"] = int(float(row.get("shots", 0)))
            row["N"] = int(float(row.get("N", 0)))
            row["steps"] = int(float(row.get("steps", 0)))
            row["cost_scale"] = _to_float(row.get("cost_scale"))
            row["cd_scale"] = _to_float(row.get("cd_scale"))
            row["seed"] = int(float(row.get("seed", 0)))
            rows.append(row)

    fieldnames = [
        "method", "rep", "seed",
        "N", "steps", "dt", "T",
        "cost_scale", "cd_scale", "mix_scale",
        "shots", "sample_elapsed_s_a", "sample_elapsed_s_b",
        "total_elapsed_s", "tts_s",
        "success", "tts_hit_s", "mts_t_hit_s",
        "target_e",
        "bestE_sample", "q10", "q50", "q90",
        "K", "seed_bestE", "mts_bestE", "mts_elapsed_s", "mts_outer_iters",
        "tau_dcqo_s", "fast_outlier",
    ]
    results: List[Dict[str, Any]] = []
    for row in rows:
        n = int(row["N"])
        steps = int(row["steps"])
        cost_scale = float(row["cost_scale"])
        cd_scale = float(row["cd_scale"])
        seed = int(row["seed"])
        shots = int(row["shots"])
        ttot = float(steps) * float(CONFIG["DT"])
        rng = np.random.default_rng(seed)
        counts: Dict[str, int] = {}
        st: Dict[str, float] = {}
        samp_dt = float("nan")
        if row["method"] == "DAQO":
            t0 = time.perf_counter()
            counts, _ = sample_daqo_counts(n, shots, steps, float(CONFIG["DT"]), cost_scale, float(CONFIG["MIX_SCALE"]), target)
            samp_dt = time.perf_counter() - t0
            st = counts_stats(counts, n)
        elif row["method"] == "DCQO":
            t0 = time.perf_counter()
            counts, _ = sample_dcqo_counts(n, shots, steps, float(CONFIG["DT"]), cost_scale, cd_scale, float(CONFIG["MIX_SCALE"]), target)
            samp_dt = time.perf_counter() - t0
            st = counts_stats(counts, n)
        else:
            # METHODS_B may list methods that have no sampler here; ignore them.
            continue

        pop = build_population_from_counts(counts, n, int(CONFIG["POP_SIZE"]), rng, int(CONFIG["N_RAND_PAD"]))
        seed_E = int(np.min(labs_energy_batch(pop, use_gpu=bool(CONFIG.get("USE_GPU_ENERGY", False)))))
        target_e = get_target_e(n)
        tabu = TabuConfig(
            max_iters=int(CONFIG["TABU_MAX_ITERS"]),
            tenure=int(CONFIG["TABU_TENURE"]),
            sample_k=int(CONFIG["TABU_SAMPLE_K"]),
            plateau_limit=int(CONFIG["TABU_PLATEAU"]),
            allow_worsen=bool(CONFIG["TABU_ALLOW_WORSEN"]),
        )
        cfg = MTSConfig(
            population_size=int(CONFIG["POP_SIZE"]),
            pcomb=float(CONFIG["PCOMB"]),
            pmutate=float(CONFIG["PMUTATE_FACTOR"]) / float(n),
            target_e=int(target_e),
            max_outer_iters=int(CONFIG["MTS_MAX_OUTER_ITERS"]),
            tabu=tabu,
        )
        t1 = time.perf_counter()
        _, bestE, mts_stats = memetic_tabu_search(pop, cfg, rng=rng, time_budget_s=float(CONFIG["MTS_TIME_BUDGET_S"]))
        mts_dt = time.perf_counter() - t1

        # Treat a missing/empty stats object as "no stats" so every lookup
        # below is guarded the same way.
        mts_stats = mts_stats or {}
        mts_t_hit_s = float(mts_stats.get("t_hit_s", float("nan")))
        success = 1 if float(mts_stats.get("success", 0.0)) >= 0.5 else 0
        # Time-to-hit is only defined for successful runs with a finite hit time.
        if success and np.isfinite(mts_t_hit_s):
            tts_hit_s = float(samp_dt) + float(mts_t_hit_s)
        else:
            tts_hit_s = float("nan")
        total_elapsed_s = float(samp_dt) + float(mts_dt)

        # int(nan) raises ValueError, so an absent or non-finite iteration
        # count is written as an empty cell instead of being converted.
        outer_iters_raw = mts_stats.get("outer_iters")
        mts_outer_iters: Any = ""
        if outer_iters_raw is not None and np.isfinite(float(outer_iters_raw)):
            mts_outer_iters = int(outer_iters_raw)

        results.append({
            "method": row["method"], "rep": row.get("rep"), "seed": seed,
            "N": n, "steps": steps, "dt": CONFIG["DT"], "T": ttot,
            "cost_scale": cost_scale, "cd_scale": cd_scale, "mix_scale": CONFIG["MIX_SCALE"],
            "shots": shots,
            "sample_elapsed_s_a": row.get("sample_elapsed_s"),
            "sample_elapsed_s_b": samp_dt,
            "total_elapsed_s": total_elapsed_s,
            # tts_s carries the same value as total_elapsed_s.
            "tts_s": total_elapsed_s,
            "success": success,
            "tts_hit_s": tts_hit_s,
            "mts_t_hit_s": mts_t_hit_s,
            "target_e": target_e,
            "bestE_sample": st.get("bestE"), "q10": st.get("q10"), "q50": st.get("q50"), "q90": st.get("q90"),
            "K": CONFIG["POP_SIZE"],
            "seed_bestE": seed_E,
            "mts_bestE": bestE,
            "mts_elapsed_s": mts_dt,
            "mts_outer_iters": mts_outer_iters,
            # Filled in once the per-group thresholds are known.
            "tau_dcqo_s": "",
            "fast_outlier": 0,
        })
        heartbeat(f"method={row['method']} N={n} steps={steps} cost={cost_scale} cd={cd_scale}")

    # Compute DCQO threshold per (N, steps, cost_scale); apply to all methods.
    dcqo_times: Dict[Tuple[int, int, float], List[float]] = {}
    for row in results:
        if row["method"] != "DCQO":
            continue
        if not row.get("success"):
            continue
        t_hit = row.get("tts_hit_s")
        if t_hit is None or not np.isfinite(float(t_hit)):
            continue
        dcqo_times.setdefault(group_key(row), []).append(float(t_hit))

    min_group = int(CONFIG["OUTLIER_MIN_GROUP"])
    pctl = float(CONFIG["OUTLIER_PCTL"])
    tau_by_key: Dict[Tuple[int, int, float], float] = {}
    for key, times in dcqo_times.items():
        # Too few successful DCQO runs: leave this group without a threshold.
        if len(times) < min_group:
            continue
        tau_by_key[key] = float(np.quantile(np.asarray(times, dtype=float), pctl))

    selected: List[Dict[str, Any]] = []
    for row in results:
        tau = tau_by_key.get(group_key(row))
        row["tau_dcqo_s"] = tau if tau is not None else ""
        # A NaN tts_hit_s compares False against tau, so it is never flagged.
        is_fast = 1 if (tau is not None and row.get("success") and float(row.get("tts_hit_s", float("nan"))) <= tau) else 0
        row["fast_outlier"] = is_fast
        if is_fast:
            selected.append(row)

    # Optional cap per group (set OUTLIER_MAX_PER_GROUP=0 to disable)
    cap = int(CONFIG.get("OUTLIER_MAX_PER_GROUP", 0) or 0)
    if cap > 0:
        capped: List[Dict[str, Any]] = []
        by_key: Dict[Tuple[int, int, float], List[Dict[str, Any]]] = {}
        for row in selected:
            by_key.setdefault(group_key(row), []).append(row)
        for items in by_key.values():
            # Keep the `cap` fastest hits of each group.
            items_sorted = sorted(items, key=lambda r: float(r.get("tts_hit_s", float("inf"))))
            capped.extend(items_sorted[:cap])
        selected = capped

    with open(out_full_csv, "w", newline="") as f:
        w = csv.DictWriter(f, fieldnames=fieldnames)
        w.writeheader()
        w.writerows(results)

    with open(out_csv, "w", newline="") as f:
        w = csv.DictWriter(f, fieldnames=fieldnames)
        w.writeheader()
        w.writerows(selected)
    print(f"[phase-b] wrote {out_full_csv} (all {len(results)} runs)")
    print(f"[phase-b] wrote {out_csv} (selected {len(selected)} runs)")


def main() -> None:
    """Run the phase selected by ``CONFIG["MODE"]``.

    The mode string is matched after stripping surrounding whitespace and
    lower-casing it.

    Raises:
        ValueError: if the mode is neither ``"phase-a"`` nor ``"phase-b"``.
    """
    selected_mode = str(CONFIG["MODE"]).strip().lower()
    runners = {
        "phase-a": run_phase_a,
        "phase-b": run_phase_b,
    }
    runner = runners.get(selected_mode)
    if runner is None:
        raise ValueError(f"Unknown MODE={CONFIG['MODE']}")
    runner()


# Script entry point: run the configured phase only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()


[phase-b] target=nvidia
[phase-b] t=20.2s method=DCQO N=16 steps=10 cost=0.05 cd=0.1
[phase-b] t=40.2s method=DAQO N=16 steps=10 cost=0.05 cd=0.0
[phase-b] t=62.7s method=DCQO N=16 steps=10 cost=0.05 cd=0.1
[phase-b] t=83.2s method=DCQO N=16 steps=10 cost=0.05 cd=0.1
[phase-b] t=103.7s method=DCQO N=16 steps=10 cost=0.05 cd=0.1
[phase-b] t=124.4s method=DCQO N=16 steps=10 cost=0.05 cd=0.1
[phase-b] t=144.9s method=DCQO N=16 steps=10 cost=0.05 cd=0.1
[phase-b] t=165.5s method=DCQO N=16 steps=10 cost=0.05 cd=0.1
[phase-b] t=186.2s method=DCQO N=16 steps=10 cost=0.05 cd=0.1
[phase-b] t=206.9s method=DCQO N=16 steps=10 cost=0.05 cd=0.1
[phase-b] t=227.7s method=DCQO N=16 steps=10 cost=0.05 cd=0.1
[phase-b] t=248.6s method=DCQO N=16 steps=10 cost=0.05 cd=0.1
[phase-b] t=269.2s method=DCQO N=16 steps=10 cost=0.05 cd=0.1
[phase-b] t=289.8s method=DCQO N=16 steps=10 cost=0.05 cd=0.1
[phase-b] t=310.5s method=DCQO N=16 steps=10 cost=0.05 cd=0.1
[phase-b] t=331.3s method=DCQO N=16 steps=10 cost=