<a href="https://colab.research.google.com/github/shanikairoshi/QFL_Experiments/blob/main/autopick_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install "qiskit>=1.1" qiskit-aer qiskit-algorithms "qiskit-machine-learning>=0.7" qiskit-ibm-runtime \
                 scikit-learn pandas matplotlib numpy torch torchvision

Collecting qiskit>=1.1
  Downloading qiskit-2.2.3-cp39-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (12 kB)
Collecting qiskit-aer
  Downloading qiskit_aer-0.17.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.3 kB)
Collecting qiskit-algorithms
  Downloading qiskit_algorithms-0.4.0-py3-none-any.whl.metadata (4.7 kB)
Collecting qiskit-machine-learning>=0.7
  Downloading qiskit_machine_learning-0.8.4-py3-none-any.whl.metadata (13 kB)
Collecting qiskit-ibm-runtime
  Downloading qiskit_ibm_runtime-0.43.1-py3-none-any.whl.metadata (21 kB)
Collecting rustworkx>=0.15.0 (from qiskit>=1.1)
  Downloading rustworkx-0.17.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Collecting stevedore>=3.0.0 (from qiskit>=1.1)
  Downloading stevedore-5.6.0-py3-none-any.whl.metadata (2.3 kB)
Collecting qiskit>=1.1
  Downloading qiskit-1.4.5-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting scipy>=1.5 (from qisk

In [6]:
# qfl_qiskit_builtin_autopick_val_stream_regimes.py
# QFL using in-built Qiskit ML (EstimatorQNN + NeuralNetworkClassifier).
#
# Features:
#   1) Validation split from training set.
#   2) Aggregation modes:
#        - "linear_weighted" (alias of fedavg_weighted)
#        - "circular_weighted"
#        - "auto_pick" (choose linear vs circular)
#   3) AutoPick policy switch:
#        - "loss": choose using VALIDATION loss
#        - "geometry": choose using rho_t threshold
#   4) Stream/append CSV saving each round.
#   5) Regime controls to show "linear is not always true":
#        - partition="dirichlet" with dirichlet_alpha sweep
#        - local_init_noise to encourage client divergence
#        - angle_stress for controlled geometry sensitivity demo
#
# Review note:
#   rho_t is an interpretable geometry-guided preference indicator.
#   It does not redefine FAAA-β; instead, it motivates when a blending
#   mechanism is expected to be beneficial.

from __future__ import annotations
import os, json, time
from pathlib import Path
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, Optional

import numpy as np
import pandas as pd

# ------------------------
# Qiskit ML (built-ins)
# ------------------------
from qiskit.circuit.library import ZZFeatureMap, RealAmplitudes
from qiskit.quantum_info import SparsePauliOp
from qiskit.primitives import StatevectorEstimator

from qiskit_machine_learning.neural_networks import EstimatorQNN
try:
    from qiskit_machine_learning.algorithms.classifiers import NeuralNetworkClassifier
except Exception:
    from qiskit_machine_learning.algorithms import NeuralNetworkClassifier

# optional gradient
try:
    from qiskit_algorithms.gradients import ParamShiftEstimatorGradient
except Exception:
    try:
        from qiskit.algorithms.gradients import ParamShiftEstimatorGradient
    except Exception:
        ParamShiftEstimatorGradient = None

# ------------------------
# Optimizers (robust import)
# ------------------------
try:
    from qiskit_algorithms.optimizers import COBYLA
except Exception:
    try:
        from qiskit.algorithms.optimizers import COBYLA
    except Exception:
        COBYLA = None

try:
    from qiskit_algorithms.optimizers import SPSA
except Exception:
    try:
        from qiskit.algorithms.optimizers import SPSA
    except Exception:
        SPSA = None

# ------------------------
# Sklearn
# ------------------------
from sklearn import datasets
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, log_loss


# ============================================================
# Drive helper
# ============================================================
def get_outdir_in_drive(subdir: str) -> Path:
    """
    If running in Colab, mount Drive and save under /content/drive/MyDrive/<subdir>.
    Otherwise, fallback to ~/ (HOME)/<subdir>.
    """
    try:
        from google.colab import drive  # type: ignore
        drive.mount("/content/drive", force_remount=False)
        base = Path("/content/drive/MyDrive/AutopickV2")
    except Exception:
        base = Path.home()
    out = base / subdir
    out.mkdir(parents=True, exist_ok=True)
    return out


# ============================================================
# Data loading
# ============================================================
def load_breast_cancer(
    pca_k: Optional[int] = 4,
    test_size: float = 0.30,
    seed: int = 42
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    data = datasets.load_breast_cancer()
    X, y = data.data.astype(float), data.target.astype(int)

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=seed, stratify=y
    )

    sc = StandardScaler()
    X_train = sc.fit_transform(X_train)
    X_test = sc.transform(X_test)

    if pca_k is not None:
        p = PCA(n_components=int(pca_k), random_state=seed)
        X_train = p.fit_transform(X_train)
        X_test = p.transform(X_test)

    return X_train, X_test, y_train, y_test


def load_clinical_csv(
    csv_path: str,
    test_size: float = 0.45,
    pca_k: int = 2,
    seed: int = 42
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Minimal preprocessing for BrEaST-Lesions-USG-Clinical.csv.
    """
    df = pd.read_csv(csv_path)

    selected_features = ["Age", "Shape", "Echogenicity",
                         "Posterior_features", "Calcifications", "Classification"]
    df = df[selected_features]

    df = df[
        (df["Age"] != "not available") &
        (~df["Shape"].isin(["not applicable"])) &
        (~df["Echogenicity"].isin(["not applicable"])) &
        (~df["Posterior_features"].isin(["not applicable"])) &
        (~df["Calcifications"].isin(["not applicable", "indefinable"])) &
        (df["Classification"].isin(["benign", "malignant"]))
    ].copy()

    df["Age"] = pd.to_numeric(df["Age"])

    for col in ["Shape", "Echogenicity", "Posterior_features", "Calcifications"]:
        df[col] = LabelEncoder().fit_transform(df[col])

    df["Label"] = df["Classification"].map({"benign": 0, "malignant": 1})
    df.drop(columns=["Classification"], inplace=True)

    X = df.drop(columns=["Label"]).values.astype(float)
    y = df["Label"].values.astype(int)

    sc = StandardScaler()
    X = sc.fit_transform(X)

    p = PCA(n_components=int(pca_k), random_state=seed)
    X = p.fit_transform(X)

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=seed, stratify=y
    )
    return X_train, X_test, y_train, y_test


# ============================================================
# Validation split helper
# ============================================================
def split_train_val(
    X_train: np.ndarray,
    y_train: np.ndarray,
    val_size: float = 0.15,
    seed: int = 42
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Create a validation set from the training data only.
    """
    X_tr, X_val, y_tr, y_val = train_test_split(
        X_train, y_train, test_size=val_size, random_state=seed, stratify=y_train
    )
    return X_tr, X_val, y_tr, y_val


# ============================================================
# Partition helpers
# ============================================================
def shard_iid(X: np.ndarray, y: np.ndarray, num_clients: int, seed: int = 7):
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(y))
    chunks = np.array_split(idx, num_clients)
    return [(X[c], y[c]) for c in chunks]


def shard_label_skew(
    X: np.ndarray, y: np.ndarray,
    num_clients: int, min_per_client: int = 60, seed: int = 7
):
    """
    Each client gets mostly a single label (binary); draws with replacement as needed.
    """
    rng = np.random.default_rng(seed)
    X0, y0 = X[y == 0], y[y == 0]
    X1, y1 = X[y == 1], y[y == 1]
    shards = []
    for cid in range(num_clients):
        maj = 0 if (cid % 2 == 0) else 1
        k_major = int(min_per_client * 0.8)
        k_minor = max(1, min_per_client - k_major)

        if maj == 0:
            idx0 = rng.integers(0, len(X0), size=k_major)
            idx1 = rng.integers(0, len(X1), size=k_minor)
            Xi = np.vstack([X0[idx0], X1[idx1]])
            yi = np.hstack([y0[idx0], y1[idx1]])
        else:
            idx0 = rng.integers(0, len(X0), size=k_minor)
            idx1 = rng.integers(0, len(X1), size=k_major)
            Xi = np.vstack([X0[idx0], X1[idx1]])
            yi = np.hstack([y0[idx0], y1[idx1]])

        p = rng.permutation(len(yi))
        shards.append((Xi[p], yi[p]))

    return shards


def shard_dirichlet(
    X: np.ndarray, y: np.ndarray,
    num_clients: int, alpha: float = 0.3, seed: int = 7
):
    """
    Dirichlet label distribution partition.
    Smaller alpha => stronger non-IID.
    """
    rng = np.random.default_rng(seed)
    classes = np.unique(y)
    idx_by_class = {c: np.where(y == c)[0] for c in classes}

    client_indices = [[] for _ in range(num_clients)]

    for c in classes:
        idx_c = idx_by_class[c].copy()
        rng.shuffle(idx_c)

        props = rng.dirichlet([alpha] * num_clients)
        counts = (props * len(idx_c)).astype(int)

        # fix rounding drift
        while counts.sum() < len(idx_c):
            counts[rng.integers(0, num_clients)] += 1
        while counts.sum() > len(idx_c):
            j = rng.integers(0, num_clients)
            if counts[j] > 0:
                counts[j] -= 1

        start = 0
        for i in range(num_clients):
            take = counts[i]
            if take > 0:
                client_indices[i].extend(idx_c[start:start+take].tolist())
            start += take

    shards = []
    for i in range(num_clients):
        idx = np.array(client_indices[i], dtype=int)
        rng.shuffle(idx)
        shards.append((X[idx], y[idx]))

    return shards


# ============================================================
# QNN builder (built-ins only)
# ============================================================
def build_estimator_qnn(
    num_features: int,
    fm_reps: int = 1,
    an_reps: int = 2
):
    fm = ZZFeatureMap(feature_dimension=num_features, reps=fm_reps)
    an = RealAmplitudes(num_qubits=num_features, reps=an_reps, entanglement="linear")

    obs = SparsePauliOp.from_list([("Z" + "I" * (num_features - 1), 1.0)])
    est = StatevectorEstimator()

    grad = None
    if ParamShiftEstimatorGradient is not None:
        try:
            grad = ParamShiftEstimatorGradient(est)
        except Exception:
            grad = None

    qnn = EstimatorQNN(
        circuit=fm.compose(an),
        observables=obs,
        input_params=sorted(list(fm.parameters), key=lambda p: p.name),
        weight_params=list(an.parameters),
        estimator=est,
        gradient=grad,
    )
    return qnn


def make_classifier(
    num_features: int,
    initial_point: Optional[np.ndarray],
    maxiter: int = 60,
    optimizer: str = "SPSA",
    opt_kwargs: Optional[dict] = None
):
    qnn = build_estimator_qnn(num_features)
    opt_kwargs = {} if opt_kwargs is None else dict(opt_kwargs)

    if optimizer.upper() == "SPSA":
        if SPSA is None:
            raise RuntimeError("SPSA not available; install/update qiskit-algorithms.")
        try:
            opt = SPSA(maxiter=maxiter, **opt_kwargs)
        except TypeError:
            opt = SPSA(maxiter=maxiter)
    elif optimizer.upper() == "COBYLA":
        if COBYLA is None:
            raise RuntimeError("COBYLA not available; install/update qiskit-algorithms.")
        opt = COBYLA(maxiter=maxiter, **opt_kwargs) if opt_kwargs else COBYLA(maxiter=maxiter)
    else:
        raise ValueError(f"Unknown optimizer '{optimizer}'.")

    clf = NeuralNetworkClassifier(
        neural_network=qnn,
        optimizer=opt,
        initial_point=initial_point,
        loss="cross_entropy",
        one_hot=False,
    )
    return clf


def extract_weights(clf: NeuralNetworkClassifier) -> np.ndarray:
    for attr in ("fit_result_", "fit_result"):
        if hasattr(clf, attr):
            res = getattr(clf, attr)
            if hasattr(res, "x"):
                return np.asarray(res.x, float).copy()
            if isinstance(res, dict) and "x" in res:
                return np.asarray(res["x"], float).copy()

    if hasattr(clf, "weights_"):
        return np.asarray(getattr(clf, "weights_"), float).copy()

    raise RuntimeError("Could not extract weights; check qiskit-machine-learning version.")


# ============================================================
# Circular helpers
# ============================================================
def wrap_to_pi(theta: np.ndarray) -> np.ndarray:
    return (theta + np.pi) % (2 * np.pi) - np.pi


# ============================================================
# Aggregation helpers
# ============================================================
def agg_fedavg(local_ws: List[np.ndarray], shard_sizes: List[int]) -> np.ndarray:
    return np.mean(np.stack(local_ws, axis=0), axis=0)


def agg_fedavg_weighted(local_ws: List[np.ndarray], shard_sizes: List[int]) -> np.ndarray:
    W = np.asarray(shard_sizes, float)
    W = W / (W.sum() + 1e-12)
    A = np.stack(local_ws, axis=0)
    return (W[:, None] * A).sum(axis=0)


def agg_circular_weighted(local_ws: List[np.ndarray], shard_sizes: List[int]) -> np.ndarray:
    """
    Review simplification:
      Treat all ansatz weights as angular parameters.
      For RealAmplitudes, weights correspond to rotation angles.
    """
    W = np.asarray(shard_sizes, float)
    W = W / (W.sum() + 1e-12)
    Theta = np.stack([wrap_to_pi(w) for w in local_ws], axis=0)
    S = np.sum(W[:, None] * np.sin(Theta), axis=0)
    C = np.sum(W[:, None] * np.cos(Theta), axis=0)
    return np.arctan2(S, C)


AGG_MAP = {
    "fedavg": agg_fedavg,
    "fedavg_weighted": agg_fedavg_weighted,
    "linear_weighted": agg_fedavg_weighted,      # alias for clarity in plots/papers
    "circular_weighted": agg_circular_weighted,
}


# ============================================================
# QNN evaluation
# ============================================================
def qnn_predict_proba(qnn: EstimatorQNN, X: np.ndarray, w: np.ndarray, x_scale: float = np.pi) -> np.ndarray:
    Xs = np.asarray(X, float) * x_scale
    exp = qnn.forward(Xs, w)
    exp = np.asarray(exp, float).reshape(-1)
    p1 = (1.0 - exp) / 2.0
    return np.vstack([1.0 - p1, p1]).T


def eval_candidate(qnn_ref: EstimatorQNN, X_eval, y_eval, w, x_scale):
    probs = qnn_predict_proba(qnn_ref, X_eval, w, x_scale=x_scale)
    yhat = np.argmax(probs, axis=1)
    acc = float(accuracy_score(y_eval, yhat))
    loss = float(log_loss(y_eval, probs, labels=[0, 1]))
    return acc, loss


# ============================================================
# Geometry diagnostics + risk + rho
# ============================================================
def _unwrap_to_ref(A: np.ndarray) -> np.ndarray:
    K, D = A.shape
    out = A.copy()
    ref = out[0]
    for i in range(1, K):
        delta = out[i] - ref
        out[i] = ref + ((delta + np.pi) % (2*np.pi) - np.pi)
    return out


def _min_covering_arc_length(angles_1d: np.ndarray) -> float:
    a = np.sort(angles_1d)
    gaps = np.diff(a, append=a[0] + 2*np.pi)
    max_gap = np.max(gaps)
    return float(2*np.pi - max_gap)


def _geodesic_sse(angles: np.ndarray, center: np.ndarray, W: np.ndarray) -> float:
    diff = wrap_to_pi(angles - center)
    sse_per_client = np.sum(diff**2, axis=1)
    return float(np.sum(W * sse_per_client))


def angle_diagnostics(local_ws: List[np.ndarray], shard_sizes: List[int]) -> dict:
    A_raw = np.stack(local_ws, axis=0)
    A = wrap_to_pi(A_raw)

    W = np.asarray(shard_sizes, float)
    W = W / (W.sum() + 1e-12)

    C = np.sum(W[:, None] * np.cos(A), axis=0)
    S = np.sum(W[:, None] * np.sin(A), axis=0)
    R = np.sqrt(C**2 + S**2)

    R_mean = float(np.mean(R))
    R_min  = float(np.min(R))

    mu_circ = np.arctan2(S, C)

    A_unwrap = _unwrap_to_ref(A)
    mu_lin_unwrapped = np.sum(W[:, None] * A_unwrap, axis=0)
    mu_lin = wrap_to_pi(mu_lin_unwrapped)

    sse_circ = _geodesic_sse(A, mu_circ, W)
    sse_lin  = _geodesic_sse(A, mu_lin,  W)
    sse_gap  = float(sse_lin - sse_circ)

    cover_lengths = np.array([_min_covering_arc_length(A[:, j]) for j in range(A.shape[1])])
    straddle_frac = float(np.mean(cover_lengths > np.pi))

    gap = np.abs(wrap_to_pi(mu_lin - mu_circ))
    disagreement_rad = float(np.mean(gap))
    disagreement_deg = float(np.rad2deg(disagreement_rad))

    return {
        "R_mean": R_mean,
        "R_min": R_min,
        "straddle_frac": straddle_frac,
        "sse_geo_gap": sse_gap,
        "disagreement_rad": disagreement_rad,
        "disagreement_deg": disagreement_deg,
        "mu_circ": mu_circ,
        "mu_lin":  mu_lin,
    }


def geometry_risk_from_diag(diag: dict, a: float = 0.50, b: float = 0.30, c: float = 0.20) -> float:
    R_mean = float(diag.get("R_mean", 0.0))
    straddle = float(diag.get("straddle_frac", 0.0))
    gap = float(diag.get("sse_geo_gap", 0.0))

    gap_pos = max(0.0, gap)
    gap_term = gap_pos / (gap_pos + 1.0)

    risk = a * (1.0 - R_mean) + b * straddle + c * gap_term
    return float(np.clip(risk, 0.0, 1.0))


def rho_from_risk(risk: float, lam: float = 8.0, center: float = 0.5) -> float:
    return float(1.0 / (1.0 + np.exp(-lam * (risk - center))))


# ============================================================
# Streaming CSV helpers
# ============================================================
def append_row_csv(row: Dict[str, Any], csv_path: Path):
    csv_path.parent.mkdir(parents=True, exist_ok=True)
    df = pd.DataFrame([row])
    header = not csv_path.exists()
    df.to_csv(csv_path, mode="a", header=header, index=False)


def append_rows_csv(rows: List[Dict[str, Any]], csv_path: Path):
    if not rows:
        return
    csv_path.parent.mkdir(parents=True, exist_ok=True)
    df = pd.DataFrame(rows)
    header = not csv_path.exists()
    df.to_csv(csv_path, mode="a", header=header, index=False)


# ============================================================
# Federated loop config
# ============================================================
@dataclass
class FLConfig:
    num_clients: int = 10
    rounds: int = 50
    seed: int = 2025

    # partition: "iid", "label_skew", "dirichlet"
    partition: str = "label_skew"
    dirichlet_alpha: float = 0.3

    # agg: "linear_weighted", "circular_weighted", "auto_pick"
    agg_mode: str = "circular_weighted"

    maxiter_local: int = 60
    x_scale: float = np.pi
    optimizer: str = "SPSA"
    opt_kwargs: Optional[dict] = None

    # Validation split inside training
    val_size: float = 0.15
    val_seed: int = 42

    # AutoPick decision policy
    autopick_policy: str = "loss"  # "loss" or "geometry"
    rho_threshold: float = 0.5

    # Regime amplifiers
    local_init_noise: float = 0.0  # try 0.05–0.2

    angle_stress: bool = False
    angle_stress_shift: float = 2.8  # radians (~160°)

    # Streaming save options
    stream_save: bool = True
    save_clients_each_round: bool = False


# ============================================================
# Federated training loop
# ============================================================
def run_federated_qnn_builtins(
    X_train, y_train, X_test, y_test,
    cfg: FLConfig,
    rounds_csv_path: Optional[Path] = None,
    clients_csv_path: Optional[Path] = None,
) -> Dict[str, Any]:

    rng = np.random.default_rng(cfg.seed)

    if cfg.agg_mode != "auto_pick" and cfg.agg_mode not in AGG_MAP:
        raise ValueError(
            f"Unknown agg_mode={cfg.agg_mode}. Choose from {list(AGG_MAP.keys())} + ['auto_pick']"
        )
    aggregator = AGG_MAP.get(cfg.agg_mode, None)

    # ---- Create validation set from training ----
    X_tr, X_val, y_tr, y_val = split_train_val(
        X_train, y_train, val_size=cfg.val_size, seed=cfg.val_seed
    )

    # ---- Client shards are built from TRAIN-ONLY subset ----
    if cfg.partition == "iid":
        shards = shard_iid(X_tr, y_tr, cfg.num_clients, seed=cfg.seed)
    elif cfg.partition == "dirichlet":
        shards = shard_dirichlet(
            X_tr, y_tr, cfg.num_clients,
            alpha=cfg.dirichlet_alpha, seed=cfg.seed
        )
    else:
        shards = shard_label_skew(
            X_tr, y_tr, cfg.num_clients,
            min_per_client=max(40, len(X_tr)//cfg.num_clients),
            seed=cfg.seed
        )

    num_features = X_train.shape[1]

    qnn_ref = build_estimator_qnn(num_features, fm_reps=1, an_reps=2)
    D = len(qnn_ref.weight_params)

    # Initialize global weights
    w_global = 0.10 * rng.standard_normal(D)

    rounds_rows: List[Dict[str, Any]] = []
    client_rows_all: List[Dict[str, Any]] = []

    print(
        f"=== QFL | agg='{cfg.agg_mode}' | policy='{cfg.autopick_policy}' | "
        f"partition='{cfg.partition}' (α={cfg.dirichlet_alpha}) | "
        f"clients={cfg.num_clients} rounds={cfg.rounds} | "
        f"init_noise={cfg.local_init_noise} | stress={cfg.angle_stress} ==="
    )

    for r in range(1, cfg.rounds + 1):
        t0 = time.perf_counter()

        local_ws: List[np.ndarray] = []
        shard_sizes: List[int] = []

        client_rows_this_round: List[Dict[str, Any]] = []

        # ---------------------------
        # Local training per client
        # ---------------------------
        for cid, (Xi, yi) in enumerate(shards):

            # client-specific init noise to encourage divergence
            w0 = w_global.copy()
            if cfg.local_init_noise > 0:
                w0 = w0 + cfg.local_init_noise * rng.standard_normal(len(w0))

            clf = make_classifier(
                num_features=num_features,
                initial_point=w0,
                maxiter=cfg.maxiter_local,
                optimizer=cfg.optimizer,
                opt_kwargs=cfg.opt_kwargs,
            )

            t_c0 = time.perf_counter()
            clf.fit(Xi * cfg.x_scale, yi)
            dt_c = time.perf_counter() - t_c0

            w_local = extract_weights(clf)
            if len(w_local) != D:
                raise RuntimeError(f"[Round {r}] Client {cid}: weight dim mismatch ({len(w_local)} != {D})")

            # controlled geometry stress-test (mechanism demo)
            if cfg.angle_stress:
                direction = +1.0 if (cid % 2 == 0) else -1.0
                w_local = wrap_to_pi(w_local + direction * cfg.angle_stress_shift)

            local_ws.append(w_local)
            shard_sizes.append(len(yi))

            # --- optional client eval ---
            probs_tr_loc = qnn_predict_proba(qnn_ref, Xi,     w_local, x_scale=cfg.x_scale)
            probs_te_loc = qnn_predict_proba(qnn_ref, X_test, w_local, x_scale=cfg.x_scale)
            yhat_tr_loc = np.argmax(probs_tr_loc, axis=1)
            yhat_te_loc = np.argmax(probs_te_loc, axis=1)

            c_row = {
                "round": r,
                "client_id": cid,
                "shard_size": int(len(yi)),
                "train_acc_local": float(accuracy_score(yi, yhat_tr_loc)),
                "test_acc_local":  float(accuracy_score(y_test, yhat_te_loc)),
                "train_loss_local": float(log_loss(yi,     probs_tr_loc, labels=[0, 1])),
                "test_loss_local":  float(log_loss(y_test, probs_te_loc, labels=[0, 1])),
                "time_sec_local":   float(dt_c),
            }
            client_rows_this_round.append(c_row)

        # ---------------------------
        # Diagnostics (before aggregation)
        # ---------------------------
        diag = angle_diagnostics(local_ws, shard_sizes)
        r_t = geometry_risk_from_diag(diag)
        rho_t = rho_from_risk(r_t)

        print(
            f"    [diag] R̄={diag['R_mean']:.3f} | Rmin={diag['R_min']:.3f} | "
            f"straddle_frac={diag['straddle_frac']:.2f} | ΔSSE_geo={diag['sse_geo_gap']:.4f} | "
            f"geom_err={diag['disagreement_deg']:.2f}° | r_t={r_t:.3f} | ρ_t={rho_t:.3f}"
        )

        # ---------------------------
        # Aggregate to new global
        # ---------------------------
        picked = None

        # Candidate stats (for logging)
        val_acc_lin = val_loss_lin = None
        val_acc_circ = val_loss_circ = None
        test_acc_lin = test_loss_lin = None
        test_acc_circ = test_loss_circ = None

        if cfg.agg_mode == "auto_pick":
            # candidates
            w_lin  = agg_fedavg_weighted(local_ws, shard_sizes)
            w_circ = agg_circular_weighted(local_ws, shard_sizes)

            # evaluate on validation for selection logging
            val_acc_lin,  val_loss_lin  = eval_candidate(qnn_ref, X_val, y_val, w_lin,  cfg.x_scale)
            val_acc_circ, val_loss_circ = eval_candidate(qnn_ref, X_val, y_val, w_circ, cfg.x_scale)

            # also evaluate on test for reporting
            test_acc_lin,  test_loss_lin  = eval_candidate(qnn_ref, X_test, y_test, w_lin,  cfg.x_scale)
            test_acc_circ, test_loss_circ = eval_candidate(qnn_ref, X_test, y_test, w_circ, cfg.x_scale)

            policy = str(cfg.autopick_policy).lower().strip()

            if policy == "loss":
                use_circ = (val_loss_circ <= val_loss_lin)
                rule_used = "loss(val)"
            elif policy == "geometry":
                use_circ = (rho_t >= cfg.rho_threshold)
                rule_used = "geometry(ρ)"
            else:
                raise ValueError("autopick_policy must be 'loss' or 'geometry'")

            w_global = w_circ if use_circ else w_lin
            picked = "circular" if use_circ else "linear"

            print(
                f"    ↳ AutoPick-{rule_used} chose: {picked} | "
                f"val_loss_lin={val_loss_lin:.4f}, val_loss_circ={val_loss_circ:.4f} | "
                f"ρ_t={rho_t:.3f}"
            )
        else:
            w_global = aggregator(local_ws, shard_sizes)  # type: ignore
            picked = cfg.agg_mode
            print(f"    ↳ agg used: {cfg.agg_mode}")

        # ---------------------------
        # Global evaluation
        # ---------------------------
        probs_tr = qnn_predict_proba(qnn_ref, X_tr,   w_global, x_scale=cfg.x_scale)
        probs_va = qnn_predict_proba(qnn_ref, X_val,  w_global, x_scale=cfg.x_scale)
        probs_te = qnn_predict_proba(qnn_ref, X_test, w_global, x_scale=cfg.x_scale)

        ytr_pred = np.argmax(probs_tr, axis=1)
        yva_pred = np.argmax(probs_va, axis=1)
        yte_pred = np.argmax(probs_te, axis=1)

        train_acc  = float(accuracy_score(y_tr,   ytr_pred))
        val_acc    = float(accuracy_score(y_val,  yva_pred))
        test_acc   = float(accuracy_score(y_test, yte_pred))

        train_loss = float(log_loss(y_tr,   probs_tr, labels=[0, 1]))
        val_loss   = float(log_loss(y_val,  probs_va, labels=[0, 1]))
        test_loss  = float(log_loss(y_test, probs_te, labels=[0, 1]))

        dt = time.perf_counter() - t0

        print(
            f"[Round {r:02d}] TrainAcc={train_acc:.3f} | ValAcc={val_acc:.3f} | TestAcc={test_acc:.3f} | "
            f"TrainLoss={train_loss:.4f} | ValLoss={val_loss:.4f} | TestLoss={test_loss:.4f} | "
            f"Time={dt:.2f}s | picked={picked}"
        )

        # ---------------------------
        # Per-round row
        # ---------------------------
        row = {
            "round": r,
            "train_acc": train_acc,
            "val_acc": val_acc,
            "test_acc": test_acc,
            "train_loss": train_loss,
            "val_loss": val_loss,
            "test_loss": test_loss,
            "time_sec": float(dt),

            "agg_mode": cfg.agg_mode,
            "autopick_policy": cfg.autopick_policy if cfg.agg_mode == "auto_pick" else "fixed",
            "picked": picked,

            # partition + regime knobs for traceability
            "partition": cfg.partition,
            "dirichlet_alpha": float(cfg.dirichlet_alpha),
            "local_init_noise": float(cfg.local_init_noise),
            "angle_stress": bool(cfg.angle_stress),
            "angle_stress_shift": float(cfg.angle_stress_shift),

            # geometry diagnostics
            "resultant_mean": float(diag["R_mean"]),
            "resultant_min":  float(diag["R_min"]),
            "straddle_frac":  float(diag["straddle_frac"]),
            "delta_SSE_geo":  float(diag["sse_geo_gap"]),
            "geom_err_deg":   float(diag["disagreement_deg"]),
            "geom_err_rad":   float(diag["disagreement_rad"]),
            "risk_t": float(r_t),
            "rho_t":  float(rho_t),

            # selection threshold (for traceability)
            "rho_threshold": float(cfg.rho_threshold),
            "val_size": float(cfg.val_size),
        }

        # candidate logs (useful for plotting why choices differ)
        if cfg.agg_mode == "auto_pick":
            row.update({
                "val_loss_lin":  None if val_loss_lin  is None else float(val_loss_lin),
                "val_loss_circ": None if val_loss_circ is None else float(val_loss_circ),
                "val_acc_lin":   None if val_acc_lin   is None else float(val_acc_lin),
                "val_acc_circ":  None if val_acc_circ  is None else float(val_acc_circ),

                "test_loss_lin":  None if test_loss_lin  is None else float(test_loss_lin),
                "test_loss_circ": None if test_loss_circ is None else float(test_loss_circ),
                "test_acc_lin":   None if test_acc_lin   is None else float(test_acc_lin),
                "test_acc_circ":  None if test_acc_circ  is None else float(test_acc_circ),
            })

        rounds_rows.append(row)

        # ---------------------------
        # Streaming save
        # ---------------------------
        if cfg.stream_save and rounds_csv_path is not None:
            append_row_csv(row, rounds_csv_path)

        if cfg.save_clients_each_round and clients_csv_path is not None:
            append_rows_csv(client_rows_this_round, clients_csv_path)

        client_rows_all.extend(client_rows_this_round)

    return {
        "w_global": w_global,
        "rows": rounds_rows,
        "client_rows": client_rows_all,
        "config": cfg.__dict__,
    }


# ============================================================
# Save metadata (once)
# ============================================================
def save_meta(meta: Dict[str, Any], path: Path):
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w") as f:
        json.dump(meta, f, indent=2)


# ============================================================
# Main (configuration-driven)
# ============================================================
if __name__ == "__main__":
    np.random.seed(42)

    # ---------------- USER CHOICES ----------------
    DATASET_CHOICE = "clinical"
    # options: "breast_cancer", "clinical", "both"

    # Aggregators you want to compare
    AGG_MODES_TO_RUN = [
        "linear_weighted",
        "circular_weighted",
        "auto_pick",
    ]

    # AutoPick policies to run (only used when agg_mode == "auto_pick")
    AUTOPICK_POLICIES = ["loss", "geometry"]

    # Partition regimes
    PARTITIONS_TO_RUN = [
        #("label_skew", None),
        # Dirichlet sweeps: include multiple alphas to show regime shifts
        ("dirichlet", 5.0),
        ("dirichlet", 0.5),
        ("dirichlet", 0.1),
    ]

    # Regime amplifiers
    LOCAL_INIT_NOISE_LIST = [0.0, 0.1]  # 0.0 baseline, 0.1 encourages divergence
    ANGLE_STRESS_LIST = [False, True]  # True = controlled mechanism demo

    NUM_CLIENTS = 10
    ROUNDS = 20
    MAXITER_LOCAL = 60
    X_SCALE = np.pi

    # Validation split inside training
    VAL_SIZE = 0.15
    VAL_SEED = 42

    # Rho threshold for geometry policy
    RHO_THRESHOLD = 0.5

    # Clinical CSV path
    CLINICAL_CSV = "/content/drive/MyDrive/data/BrEaST-Lesions-USG-Clinical.csv"

    # Optimizer defaults
    OPT_BREAST = ("SPSA", {"learning_rate": 0.05, "perturbation": 0.1})
    OPT_CLIN   = ("COBYLA", None)

    # Stream-save options
    STREAM_SAVE = True
    SAVE_CLIENTS_EACH_ROUND = False
    CLEAN_OLD_CSVS = True

    # ---------------- INTERNAL HELPERS ----------------
    def outdir_for(dataset_key: str, agg_mode: str, policy: str,
                   partition: str, alpha: Optional[float],
                   init_noise: float, stress: bool) -> Path:

        alpha_tag = "" if alpha is None else f"_a{alpha}"
        stress_tag = "_stress" if stress else ""
        noise_tag = f"_n{init_noise}"

        name = (
            f"QFL_folder/ZZ_{dataset_key}_"
            f"{agg_mode}_"
            f"{policy}_"
            f"{partition}{alpha_tag}{noise_tag}{stress_tag}"
        )
        return get_outdir_in_drive(name)

    def rounds_csv_path(outdir: Path, dataset_key: str) -> Path:
        return outdir / f"round_metrics_{dataset_key}_clients{NUM_CLIENTS}_rounds{ROUNDS}.csv"

    def clients_csv_path(outdir: Path, dataset_key: str) -> Path:
        return outdir / f"client_metrics_{dataset_key}_clients{NUM_CLIENTS}_rounds{ROUNDS}.csv"

    def maybe_clean(*paths: Path):
        if not CLEAN_OLD_CSVS:
            return
        for p in paths:
            if p.exists():
                p.unlink()

    def run_one(dataset_key: str,
                X_train, X_test, y_train, y_test,
                agg_mode: str,
                policy: str,
                partition: str,
                alpha: Optional[float],
                init_noise: float,
                stress: bool,
                optimizer_name: str,
                opt_kwargs: Optional[dict]):

        outdir = outdir_for(dataset_key, agg_mode, policy, partition, alpha, init_noise, stress)
        r_csv = rounds_csv_path(outdir, dataset_key)
        c_csv = clients_csv_path(outdir, dataset_key)

        maybe_clean(r_csv, c_csv)

        cfg = FLConfig(
            num_clients=NUM_CLIENTS,
            rounds=ROUNDS,
            seed=2025,

            partition=partition,
            dirichlet_alpha=float(alpha) if alpha is not None else 0.3,

            agg_mode=agg_mode,
            maxiter_local=MAXITER_LOCAL,
            x_scale=X_SCALE,
            optimizer=optimizer_name,
            opt_kwargs=opt_kwargs,

            val_size=VAL_SIZE,
            val_seed=VAL_SEED,

            autopick_policy=policy,
            rho_threshold=RHO_THRESHOLD,

            local_init_noise=init_noise,
            angle_stress=stress,
            angle_stress_shift=2.8,

            stream_save=STREAM_SAVE,
            save_clients_each_round=SAVE_CLIENTS_EACH_ROUND,
        )

        print(f"\n=== RUN | {dataset_key} | agg={agg_mode} | policy={policy} | "
              f"partition={partition} | alpha={alpha} | init_noise={init_noise} | stress={stress} ===")

        out = run_federated_qnn_builtins(
            X_train, y_train, X_test, y_test,
            cfg,
            rounds_csv_path=r_csv,
            clients_csv_path=c_csv if SAVE_CLIENTS_EACH_ROUND else None
        )

        save_meta({"dataset": dataset_key, **out["config"]}, outdir / "run_metadata.json")
        np.save(outdir / "w_global.npy", out["w_global"])

        print(f"Saved rounds CSV to: {r_csv}")
        if SAVE_CLIENTS_EACH_ROUND:
            print(f"Saved client CSV to: {c_csv}")

        return out

    # ---------------- DATASET RUNNERS ----------------
    def run_breast():
        X_train, X_test, y_train, y_test = load_breast_cancer(pca_k=5, test_size=0.30, seed=42)
        opt_name, opt_kwargs = OPT_BREAST

        for partition, alpha in PARTITIONS_TO_RUN:
            for init_noise in LOCAL_INIT_NOISE_LIST:
                for stress in ANGLE_STRESS_LIST:
                    for agg_mode in AGG_MODES_TO_RUN:
                        if agg_mode == "auto_pick":
                            for pol in AUTOPICK_POLICIES:
                                run_one("breast_cancer", X_train, X_test, y_train, y_test,
                                        agg_mode, pol, partition, alpha, init_noise, stress,
                                        opt_name, opt_kwargs)
                        else:
                            # policy label for fixed modes
                            run_one("breast_cancer", X_train, X_test, y_train, y_test,
                                    agg_mode, "fixed", partition, alpha, init_noise, stress,
                                    opt_name, opt_kwargs)

    def run_clinical():
        X_train, X_test, y_train, y_test = load_clinical_csv(
            CLINICAL_CSV, test_size=0.45, pca_k=5, seed=42
        )
        opt_name, opt_kwargs = OPT_CLIN

        for partition, alpha in PARTITIONS_TO_RUN:
            for init_noise in LOCAL_INIT_NOISE_LIST:
                for stress in ANGLE_STRESS_LIST:
                    for agg_mode in AGG_MODES_TO_RUN:
                        if agg_mode == "auto_pick":
                            for pol in AUTOPICK_POLICIES:
                                run_one("clinical", X_train, X_test, y_train, y_test,
                                        agg_mode, pol, partition, alpha, init_noise, stress,
                                        opt_name, opt_kwargs)
                        else:
                            run_one("clinical", X_train, X_test, y_train, y_test,
                                    agg_mode, "fixed", partition, alpha, init_noise, stress,
                                    opt_name, opt_kwargs)

    # ---------------- EXECUTE CHOSEN ----------------
    if DATASET_CHOICE == "breast_cancer":
        run_breast()
    elif DATASET_CHOICE == "clinical":
        try:
            run_clinical()
        except FileNotFoundError:
            print(f"[WARN] Clinical CSV not found at: {CLINICAL_CSV}.")
    elif DATASET_CHOICE == "both":
        run_breast()
        try:
            run_clinical()
        except FileNotFoundError:
            print(f"[WARN] Clinical CSV not found at: {CLINICAL_CSV}.")
    else:
        raise ValueError("DATASET_CHOICE must be: 'breast_cancer', 'clinical', or 'both'.")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).

=== RUN | clinical | agg=linear_weighted | policy=fixed | partition=label_skew | alpha=None | init_noise=0.0 | stress=False ===
=== QFL | agg='linear_weighted' | policy='fixed' | partition='label_skew' (α=0.3) | clients=10 rounds=20 | init_noise=0.0 | stress=False ===
    [diag] R̄=0.821 | Rmin=0.661 | straddle_frac=0.00 | ΔSSE_geo=-0.0080 | geom_err=1.07° | r_t=0.089 | ρ_t=0.036
    ↳ agg used: linear_weighted
[Round 01] TrainAcc=0.438 | ValAcc=0.444 | TestAcc=0.600 | TrainLoss=0.7059 | ValLoss=0.7016 | TestLoss=0.6909 | Time=142.82s | picked=linear_weighted
    [diag] R̄=0.908 | Rmin=0.832 | straddle_frac=0.00 | ΔSSE_geo=-0.0032 | geom_err=0.67° | r_t=0.046 | ρ_t=0.026
    ↳ agg used: linear_weighted
[Round 02] TrainAcc=0.427 | ValAcc=0.389 | TestAcc=0.568 | TrainLoss=0.7161 | ValLoss=0.7221 | TestLoss=0.6856 | Time=145.66s | picked=linear_weighted
    [di

QiskitMachineLearningError: 'The target values appear to be multi-classified. The neural network output shape is only suitable for binary classification.'