<a href="https://colab.research.google.com/github/shanikairoshi/QFL_Experiments/blob/main/Seeded_QFL_CircularVsLinear_v2_withQoS_dataset_hub_nuscenes.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# QoS-weighted QFL aggregation (Linear vs Circular) — Notebook

This notebook wraps the fixed script `pasted_qos_updated_fixed.py`.

The previous error (`AttributeError: 'FLConfig' object has no attribute 'use_qos_weights'`)
occurred because QoS fields were outside the `FLConfig` dataclass. This notebook fixes that,
so you can toggle QoS-weighted aggregation ablations.

## Ablations
- Baseline: `use_qos_weights=False`
- Fidelity only: `use_qos_weights=True, qos_alpha=1, qos_gamma=0, qos_delta=0`
- Latency only: `use_qos_weights=True, qos_alpha=0, qos_gamma=1, qos_delta=0`
- Instability only: `use_qos_weights=True, qos_alpha=0, qos_gamma=0, qos_delta=1`
- All signals: `use_qos_weights=True, qos_alpha=1, qos_gamma=1, qos_delta=1`


In [None]:
!pip install "qiskit>=1.1" qiskit-aer qiskit-algorithms "qiskit-machine-learning>=0.7" qiskit-ibm-runtime \
                 scikit-learn pandas matplotlib numpy torch torchvision



In [None]:
# (Optional) quick sanity check that FLConfig has QoS fields
# Run this cell after executing the main script cell below if you want to confirm.


In [None]:
import warnings
warnings.filterwarnings("ignore")  # suppress all warnings


import os, json, time
from pathlib import Path
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, Optional

import numpy as np
import pandas as pd

# ------------------------
# Qiskit ML (built-ins)
# ------------------------
from qiskit.circuit.library import ZZFeatureMap, RealAmplitudes
from qiskit.quantum_info import SparsePauliOp, Statevector
from qiskit.primitives import StatevectorEstimator

from qiskit_machine_learning.neural_networks import EstimatorQNN
try:
    from qiskit_machine_learning.algorithms.classifiers import NeuralNetworkClassifier
except Exception:
    from qiskit_machine_learning.algorithms import NeuralNetworkClassifier

# optional gradient
try:
    from qiskit_algorithms.gradients import ParamShiftEstimatorGradient
except Exception:
    try:
        from qiskit.algorithms.gradients import ParamShiftEstimatorGradient
    except Exception:
        ParamShiftEstimatorGradient = None

# ------------------------
# Optimizers (robust import)
# ------------------------
try:
    from qiskit_algorithms.optimizers import COBYLA
except Exception:
    try:
        from qiskit.algorithms.optimizers import COBYLA
    except Exception:
        COBYLA = None

try:
    from qiskit_algorithms.optimizers import SPSA
except Exception:
    try:
        from qiskit.algorithms.optimizers import SPSA
    except Exception:
        SPSA = None

# ------------------------
# Sklearn
# ------------------------
from sklearn import datasets
from sklearn.datasets import fetch_openml
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, log_loss


# ============================================================
# Drive helper
# ============================================================
def get_outdir_in_drive(subdir: str) -> Path:
    """
    If running in Colab, mount Drive and save under /content/drive/MyDrive/<subdir>.
    Otherwise, fallback to ~/ (HOME)/<subdir>.
    """
    try:
        from google.colab import drive  # type: ignore
        drive.mount("/content/drive", force_remount=False)
        base = Path("/content/drive/MyDrive/AutopickV2")
    except Exception:
        base = Path.home()
    out = base / subdir
    out.mkdir(parents=True, exist_ok=True)
    return out


# ============================================================
# Data loading
# ============================================================
def load_breast_cancer(
    pca_k: Optional[int] = 4,
    test_size: float = 0.30,
    seed: int = 42
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    data = datasets.load_breast_cancer()
    X, y = data.data.astype(float), data.target.astype(int)

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=seed, stratify=y
    )

    sc = StandardScaler()
    X_train = sc.fit_transform(X_train)
    X_test = sc.transform(X_test)

    if pca_k is not None:
        p = PCA(n_components=int(pca_k), random_state=seed)
        X_train = p.fit_transform(X_train)
        X_test = p.transform(X_test)

    return X_train, X_test, y_train, y_test


def load_clinical_csv(
    csv_path: str,
    test_size: float = 0.25,
    pca_k: int = 2,
    seed: int = 42
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Minimal preprocessing for BrEaST-Lesions-USG-Clinical.csv.
    """
    df = pd.read_csv(csv_path)

    selected_features = ["Age", "Shape", "Echogenicity",
                         "Posterior_features", "Calcifications", "Classification"]
    df = df[selected_features]

    df = df[
        (df["Age"] != "not available") &
        (~df["Shape"].isin(["not applicable"])) &
        (~df["Echogenicity"].isin(["not applicable"])) &
        (~df["Posterior_features"].isin(["not applicable"])) &
        (~df["Calcifications"].isin(["not applicable", "indefinable"])) &
        (df["Classification"].isin(["benign", "malignant"]))
    ].copy()

    df["Age"] = pd.to_numeric(df["Age"])

    for col in ["Shape", "Echogenicity", "Posterior_features", "Calcifications"]:
        df[col] = LabelEncoder().fit_transform(df[col])

    df["Label"] = df["Classification"].map({"benign": 0, "malignant": 1})
    df.drop(columns=["Classification"], inplace=True)

    X = df.drop(columns=["Label"]).values.astype(float)
    y = df["Label"].values.astype(int)

    sc = StandardScaler()
    X = sc.fit_transform(X)

    p = PCA(n_components=int(pca_k), random_state=seed)
    X = p.fit_transform(X)

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=seed, stratify=y
    )
    return X_train, X_test, y_train, y_test


# ============================================================
# Validation split helper
# ============================================================
def split_train_val(
    X_train: np.ndarray,
    y_train: np.ndarray,
    val_size: float = 0.15,
    seed: int = 42
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Create a validation set from the training data only.
    """
    X_tr, X_val, y_tr, y_val = train_test_split(
        X_train, y_train, test_size=val_size, random_state=seed, stratify=y_train
    )
    return X_tr, X_val, y_tr, y_val


# ============================================================
# Partition helpers
# ============================================================
def shard_iid(X: np.ndarray, y: np.ndarray, num_clients: int, seed: int = 7):
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(y))
    chunks = np.array_split(idx, num_clients)
    return [(X[c], y[c]) for c in chunks]


# ============================================================
# OpenML loaders (CPU-friendly) for biomarker / tabular datasets
# ============================================================
def load_openml_binary(
    data_id: int,
    *,
    pos_label: Optional[str] = None,
    pca_k: Optional[int] = 4,
    test_size: float = 0.30,
    seed: int = 42,
    max_samples: Optional[int] = 600,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Fetch a dataset from OpenML and return (X_train, X_test, y_train, y_test) for binary classification.
    - If the dataset is binary, we map the *second* class (sorted) to 1 by default.
    - If multi-class, we keep only the top-2 most frequent classes (binary reduction).
    - max_samples caps the total samples (before train/test split) to keep CPU runtime practical.
    """
    X, y = fetch_openml(data_id=data_id, as_frame=True, return_X_y=True)
    X_df = X.copy() if hasattr(X, "copy") else pd.DataFrame(X)
    y_s = pd.Series(y).astype(str)

    # Numeric matrix + median imputation
    X_df = X_df.apply(pd.to_numeric, errors="coerce")
    X_df = X_df.fillna(X_df.median(numeric_only=True))

    # Reduce labels to binary (if needed)
    classes = list(pd.unique(y_s))
    if pos_label is not None:
        y_bin = (y_s == str(pos_label)).astype(int)
    else:
        if len(classes) == 2:
            cls_sorted = sorted(map(str, classes))
            y_bin = (y_s == cls_sorted[1]).astype(int)
        else:
            top2 = y_s.value_counts().index[:2].tolist()
            mask = y_s.isin(top2)
            X_df = X_df.loc[mask].reset_index(drop=True)
            y_s = y_s.loc[mask].reset_index(drop=True)
            y_bin = (y_s == top2[1]).astype(int)

    # Optional subsample BEFORE split (stratified)
    if max_samples is not None and len(X_df) > max_samples:
        X_df, _, y_bin, _ = train_test_split(
            X_df, y_bin, train_size=max_samples, random_state=seed, stratify=y_bin
        )

    X_train, X_test, y_train, y_test = train_test_split(
        X_df.to_numpy(dtype=float),
        np.asarray(y_bin, dtype=int),
        test_size=test_size,
        random_state=seed,
        stratify=y_bin
    )

    sc = StandardScaler()
    X_train = sc.fit_transform(X_train)
    X_test = sc.transform(X_test)

    if pca_k is not None:
        pca = PCA(n_components=pca_k, random_state=seed, svd_solver="randomized")
        X_train = pca.fit_transform(X_train)
        X_test = pca.transform(X_test)

    return X_train, X_test, y_train, y_test

def load_biomarker_leukemia(
    *,
    pca_k: Optional[int] = 4,
    test_size: float = 0.30,
    seed: int = 42,
    max_samples: Optional[int] = 400,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """OpenML id=1104 (leukemia; ALL vs AML style binary)."""
    return load_openml_binary(
        1104, pca_k=pca_k, test_size=test_size, seed=seed, max_samples=max_samples
    )

def load_biomarker_arcene(
    *,
    pca_k: Optional[int] = 4,
    test_size: float = 0.30,
    seed: int = 42,
    max_samples: Optional[int] = 600,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """OpenML id=1458 (ARCENE; cancer vs normal; high-dimensional)."""
    return load_openml_binary(
        1458, pca_k=pca_k, test_size=test_size, seed=seed, max_samples=max_samples
    )


def build_nuscenes_turning_dataset(
    dataroot: str,
    version: str = "v1.0-mini",
    category_prefix: str = "vehicle",
    history_sec: float = 2.0,
    horizon_sec: float = 2.0,
    yaw_thresh_deg: float = 15.0,
    max_samples: Optional[int] = 400,
    seed: int = 42,
) -> Tuple[np.ndarray, np.ndarray]:
    """Build a small CPU-friendly motion dataset from nuScenes annotations.

    We construct fixed-length windows from annotated object tracks (default: any 'vehicle*').
    Binary label: whether the agent will 'turn' within the horizon (yaw change > threshold).

    Requires:
      pip install nuscenes-devkit
      and nuScenes data downloaded under `dataroot` with `v1.0-mini/` available.
    """
    try:
        from nuscenes.nuscenes import NuScenes
        from pyquaternion import Quaternion
    except Exception as e:
        raise ImportError(
            "nuScenes support requires `nuscenes-devkit` (and its deps). "
            "Install with: pip install nuscenes-devkit"
        ) from e

    dt = 0.5  # nuScenes keyframes are at 2Hz (every 0.5s) in the standard release.
    H = max(1, int(round(history_sec / dt)))
    F = max(1, int(round(horizon_sec / dt)))
    yaw_thresh = np.deg2rad(float(yaw_thresh_deg))

    def wrap_pi(a: np.ndarray) -> np.ndarray:
        return (a + np.pi) % (2 * np.pi) - np.pi

    nusc = NuScenes(version=version, dataroot=dataroot, verbose=False)
    rng = np.random.default_rng(seed)

    X_list: List[List[float]] = []
    y_list: List[int] = []

    for inst in nusc.instance:
        # Build the annotation chain for this instance.
        ann_token = inst.get("first_annotation_token", "")
        if not ann_token:
            continue

        track = []
        while ann_token:
            ann = nusc.get("sample_annotation", ann_token)
            if str(ann.get("category_name", "")).startswith(category_prefix):
                track.append(ann)
            ann_token = ann.get("next", "") or ""

        if len(track) < (H + F + 2):
            continue

        xs = np.asarray([a["translation"][0] for a in track], dtype=float)
        ys = np.asarray([a["translation"][1] for a in track], dtype=float)

        # yaw from quaternion [w, x, y, z]
        yaws = np.asarray([Quaternion(a["rotation"]).yaw_pitch_roll[0] for a in track], dtype=float)

        # speed estimate (finite difference on positions)
        dx = np.diff(xs)
        dy = np.diff(ys)
        speeds = np.sqrt(dx**2 + dy**2) / dt
        speeds = np.concatenate([[speeds[0]], speeds])  # align to track length

        for t in range(H, len(track) - F - 1):
            dxH = xs[t] - xs[t - H]
            dyH = ys[t] - ys[t - H]

            v_hist = speeds[t - H:t]
            v_mean = float(np.mean(v_hist))
            v_last = float(speeds[t])
            v_prev = float(speeds[t - 1]) if t - 1 >= 0 else v_last
            a_last = (v_last - v_prev) / dt

            yaw_now = float(yaws[t])
            yaw_past = float(yaws[t - H])
            yaw_rate = float(wrap_pi(np.array([yaw_now - yaw_past]))[0] / (H * dt))

            yaw_future = float(yaws[t + F])
            turn = int(abs(wrap_pi(np.array([yaw_future - yaw_now]))[0]) > yaw_thresh)

            # Feature vector (compact + stable):
            # displacement, speed stats, accel, yaw_rate, heading embedding
            X_list.append([
                float(dxH), float(dyH),
                v_mean, v_last,
                float(np.std(v_hist, ddof=0)),
                float(a_last),
                float(yaw_rate),
                float(np.sin(yaw_now)), float(np.cos(yaw_now)),
            ])
            y_list.append(turn)

            if max_samples is not None and len(y_list) >= int(max_samples):
                break

        if max_samples is not None and len(y_list) >= int(max_samples):
            break

    if len(y_list) < 20:
        raise RuntimeError(
            "nuScenes motion extraction produced too few windows. "
            "Check that `dataroot` points to a valid nuScenes root containing the requested version "
            f"(e.g., {dataroot}/{version})."
        )

    X = np.asarray(X_list, dtype=float)
    y = np.asarray(y_list, dtype=int)

    idx = rng.permutation(len(y))
    return X[idx], y[idx]


def load_motion_nuscenes_mini(
    *,
    dataroot: str,
    version: str = "v1.0-mini",
    category_prefix: str = "vehicle",
    pca_k: Optional[int] = 4,
    test_size: float = 0.30,
    seed: int = 42,
    max_samples: Optional[int] = 400,
    history_sec: float = 2.0,
    horizon_sec: float = 2.0,
    yaw_thresh_deg: float = 15.0,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """nuScenes v1.0-mini → CPU-friendly motion windows (binary turning label)."""
    X, y = build_nuscenes_turning_dataset(
        dataroot=dataroot,
        version=version,
        category_prefix=category_prefix,
        history_sec=history_sec,
        horizon_sec=horizon_sec,
        yaw_thresh_deg=yaw_thresh_deg,
        max_samples=max_samples,
        seed=seed,
    )

    X_train, X_test, y_train, y_test = train_test_split(
        X, y,
        test_size=test_size,
        random_state=seed,
        stratify=y
    )

    sc = StandardScaler()
    X_train = sc.fit_transform(X_train)
    X_test = sc.transform(X_test)

    if pca_k is not None:
        pca = PCA(n_components=pca_k, random_state=seed, svd_solver="randomized")
        X_train = pca.fit_transform(X_train)
        X_test = pca.transform(X_test)

    return X_train, X_test, y_train, y_test


def shard_label_skew(
    X: np.ndarray, y: np.ndarray,
    num_clients: int, min_per_client: int = 60, seed: int = 7
):
    """
    Each client gets mostly a single label (binary); draws with replacement as needed.
    """
    rng = np.random.default_rng(seed)
    X0, y0 = X[y == 0], y[y == 0]
    X1, y1 = X[y == 1], y[y == 1]
    shards = []
    for cid in range(num_clients):
        maj = 0 if (cid % 2 == 0) else 1
        k_major = int(min_per_client * 0.8)
        k_minor = max(1, min_per_client - k_major)

        if maj == 0:
            idx0 = rng.integers(0, len(X0), size=k_major)
            idx1 = rng.integers(0, len(X1), size=k_minor)
            Xi = np.vstack([X0[idx0], X1[idx1]])
            yi = np.hstack([y0[idx0], y1[idx1]])
        else:
            idx0 = rng.integers(0, len(X0), size=k_minor)
            idx1 = rng.integers(0, len(X1), size=k_major)
            Xi = np.vstack([X0[idx0], X1[idx1]])
            yi = np.hstack([y0[idx0], y1[idx1]])

        p = rng.permutation(len(yi))
        shards.append((Xi[p], yi[p]))

    return shards


def shard_dirichlet(
    X: np.ndarray, y: np.ndarray,
    num_clients: int, alpha: float = 0.3, seed: int = 7,
    min_samples_per_class_per_client: int = 1
):
    """
    Dirichlet label distribution partition.
    Smaller alpha => stronger non-IID.
    Ensures each client has at least min_samples_per_class_per_client for each class.
    """
    rng = np.random.default_rng(seed)
    classes = np.unique(y)
    if len(classes) != 2:
        raise ValueError("shard_dirichlet expects binary classification labels (0 and 1).")

    client_shards_data = [[] for _ in range(num_clients)]
    client_shards_labels = [[] for _ in range(num_clients)]

    for c in classes:
        idx_c = np.where(y == c)[0]
        rng.shuffle(idx_c)

        if len(idx_c) < num_clients * min_samples_per_class_per_client:
            raise ValueError(f"Not enough samples for class {c} to satisfy "
                             f"min_samples_per_class_per_client={min_samples_per_class_per_client} "
                             f"for {num_clients} clients.")

        # 1. Distribute minimum required samples for current class 'c'
        remaining_idx_c = list(idx_c)
        for i in range(num_clients):
            client_shards_data[i].extend(X[remaining_idx_c[:min_samples_per_class_per_client]].tolist())
            client_shards_labels[i].extend(y[remaining_idx_c[:min_samples_per_class_per_client]].tolist())
            remaining_idx_c = remaining_idx_c[min_samples_per_class_per_client:]

        # 2. Distribute the rest using Dirichlet
        props = rng.dirichlet([alpha] * num_clients)
        counts = (props * len(remaining_idx_c)).astype(int)

        # Fix rounding drift
        while counts.sum() < len(remaining_idx_c):
            counts[rng.integers(0, num_clients)] += 1
        while counts.sum() > len(remaining_idx_c):
            j = rng.integers(0, num_clients)
            if counts[j] > 0:
                counts[j] -= 1

        start = 0
        for i in range(num_clients):
            take = counts[i]
            if take > 0:
                client_shards_data[i].extend(X[remaining_idx_c[start:start+take]].tolist())
                client_shards_labels[i].extend(y[remaining_idx_c[start:start+take]].tolist())
            start += take

    shards = []
    for i in range(num_clients):
        Xi = np.array(client_shards_data[i])
        yi = np.array(client_shards_labels[i])
        p = rng.permutation(len(yi))
        shards.append((Xi[p], yi[p]))

    return shards


# ============================================================
# QNN builder (built-ins only)
# ============================================================
def build_estimator_qnn(
    num_features: int,
    fm_reps: int = 1,
    an_reps: int = 2
):
    fm = ZZFeatureMap(feature_dimension=num_features, reps=fm_reps)
    an = RealAmplitudes(num_qubits=num_features, reps=an_reps, entanglement="linear")

    obs = SparsePauliOp.from_list([("Z" + "I" * (num_features - 1), 1.0)])
    est = StatevectorEstimator()

    grad = None
    if ParamShiftEstimatorGradient is not None:
        try:
            grad = ParamShiftEstimatorGradient(est)
        except Exception:
            grad = None

    qnn = EstimatorQNN(
        circuit=fm.compose(an),
        observables=obs,
        input_params=sorted(list(fm.parameters), key=lambda p: p.name),
        weight_params=list(an.parameters),
        estimator=est,
        gradient=grad,
    )
    return qnn


def make_classifier(
    num_features: int,
    initial_point: Optional[np.ndarray],
    maxiter: int = 10,
    optimizer: str = "SPSA",
    opt_kwargs: Optional[dict] = None
):
    qnn = build_estimator_qnn(num_features)
    opt_kwargs = {} if opt_kwargs is None else dict(opt_kwargs)

    if optimizer.upper() == "SPSA":
        if SPSA is None:
            raise RuntimeError("SPSA not available; install/update qiskit-algorithms.")
        try:
            opt = SPSA(maxiter=maxiter, **opt_kwargs)
        except TypeError:
            opt = SPSA(maxiter=maxiter)
    elif optimizer.upper() == "COBYLA":
        if COBYLA is None:
            raise RuntimeError("COBYLA not available; install/update qiskit-algorithms.")
        opt = COBYLA(maxiter=maxiter, **opt_kwargs) if opt_kwargs else COBYLA(maxiter=maxiter)
    else:
        raise ValueError(f"Unknown optimizer '{optimizer}'.")

    clf = NeuralNetworkClassifier(
        neural_network=qnn,
        optimizer=opt,
        initial_point=initial_point,
        loss="cross_entropy",
        one_hot=False,
    )
    return clf


def extract_weights(clf: NeuralNetworkClassifier) -> np.ndarray:
    for attr in ("fit_result_", "fit_result"):
        if hasattr(clf, attr):
            res = getattr(clf, attr)
            if hasattr(res, "x"):
                return np.asarray(res.x, float).copy()
            if isinstance(res, dict) and "x" in res:
                return np.asarray(res["x"], float).copy()

    if hasattr(clf, "weights_"):
        return np.asarray(getattr(clf, "weights_"), float).copy()

    raise RuntimeError("Could not extract weights; check qiskit-machine-learning version.")


# ============================================================
# Circular helpers
# ============================================================
def wrap_to_pi(theta: np.ndarray) -> np.ndarray:
    return (theta + np.pi) % (2 * np.pi) - np.pi




# ============================================================
# QoS signals for QFL aggregation (fidelity, latency, instability)
# ============================================================
def angular_instability_sigma2(w_local: np.ndarray, w_global: np.ndarray) -> float:
    """A simple angular instability proxy: mean squared wrapped update magnitude."""
    diff = wrap_to_pi(np.asarray(w_local, float) - np.asarray(w_global, float))
    return float(np.mean(diff * diff))

def _bind_qnn_params(qnn_ref: EstimatorQNN, x_scaled: np.ndarray, w: np.ndarray) -> Dict:
    """Build a parameter binding dict for qnn_ref.circuit."""
    bind = {}
    for p, val in zip(qnn_ref.input_params, np.asarray(x_scaled, float).reshape(-1)):
        bind[p] = float(val)
    for p, val in zip(qnn_ref.weight_params, np.asarray(w, float).reshape(-1)):
        bind[p] = float(val)
    return bind

def avg_state_fidelity(
    qnn_ref: EstimatorQNN,
    X_ref: np.ndarray,
    w_a: np.ndarray,
    w_b: np.ndarray,
    x_scale: float = np.pi,
) -> float:
    """Average pure-state fidelity between |psi(x; w_a)> and |psi(x; w_b)> over a small reference set.

    Notes:
      - This uses Statevector simulation; on hardware you would typically estimate fidelity via tomography/shadow
        methods or proxy channel metrics.
      - In this codebase, it is primarily a *relative* QoS indicator for ablations.
    """
    if X_ref is None or len(X_ref) == 0:
        return 1.0

    circ = qnn_ref.circuit
    vals = []
    for x in np.asarray(X_ref, float):
        xa = x * x_scale
        psi_a = Statevector.from_instruction(circ.assign_parameters(_bind_qnn_params(qnn_ref, xa, w_a), inplace=False))
        psi_b = Statevector.from_instruction(circ.assign_parameters(_bind_qnn_params(qnn_ref, xa, w_b), inplace=False))
        ov = np.vdot(psi_a.data, psi_b.data)
        vals.append(float(np.abs(ov) ** 2))
    return float(np.mean(vals))

def compute_qos_scores(
    local_ws: List[np.ndarray],
    w_global: np.ndarray,
    local_times: List[float],
    qnn_ref: EstimatorQNN,
    X_ref: Optional[np.ndarray],
    cfg: "FLConfig",
) -> Dict[str, np.ndarray]:
    """Compute per-client QoS scores q_i,t from fidelity F, latency tau, and instability sigma^2.

    q_i,t = F_i,t^alpha / ((tau_i,t + eps)^gamma * (sigma_i,t^2 + eps)^delta)
    """
    K = len(local_ws)
    tau = np.asarray(local_times, float).reshape(K)
    sig2 = np.asarray([angular_instability_sigma2(wi, w_global) for wi in local_ws], float)
    F = np.asarray([avg_state_fidelity(qnn_ref, X_ref, w_global, wi, x_scale=cfg.x_scale) for wi in local_ws], float)

    q = (F ** float(cfg.qos_alpha)) / (
        (tau + float(cfg.qos_eps)) ** float(cfg.qos_gamma) *
        (sig2 + float(cfg.qos_eps)) ** float(cfg.qos_delta)
    )
    q = np.clip(q, float(cfg.qos_clip_min), float(cfg.qos_clip_max))
    return {"F": F, "tau": tau, "sig2": sig2, "q": q}

# ============================================================
# Aggregation helpers
# ============================================================
def agg_fedavg(local_ws: List[np.ndarray], shard_sizes: List[int]) -> np.ndarray:
    return np.mean(np.stack(local_ws, axis=0), axis=0)


def agg_fedavg_weighted(local_ws: List[np.ndarray], shard_sizes: List[int]) -> np.ndarray:
    W = np.asarray(shard_sizes, float)
    W = W / (W.sum() + 1e-12)
    A = np.stack(local_ws, axis=0)
    return (W[:, None] * A).sum(axis=0)


def agg_circular_weighted(local_ws: List[np.ndarray], shard_sizes: List[int]) -> np.ndarray:
    """
    Review simplification:
      Treat all ansatz weights as angular parameters.
      For RealAmplitudes, weights correspond to rotation angles.
    """
    W = np.asarray(shard_sizes, float)
    W = W / (W.sum() + 1e-12)
    Theta = np.stack([wrap_to_pi(w) for w in local_ws], axis=0)
    S = np.sum(W[:, None] * np.sin(Theta), axis=0)
    C = np.sum(W[:, None] * np.cos(Theta), axis=0)
    return np.arctan2(S, C)


AGG_MAP = {
    "fedavg": agg_fedavg,
    "fedavg_weighted": agg_fedavg_weighted,
    "linear_weighted": agg_fedavg_weighted,      # alias for clarity in plots/papers
    "circular_weighted": agg_circular_weighted,
}


# ============================================================
# QNN evaluation
# ============================================================
def qnn_predict_proba(qnn: EstimatorQNN, X: np.ndarray, w: np.ndarray, x_scale: float = np.pi) -> np.ndarray:
    Xs = np.asarray(X, float) * x_scale
    exp = qnn.forward(Xs, w)
    exp = np.asarray(exp, float).reshape(-1)
    p1 = (1.0 - exp) / 2.0
    return np.vstack([1.0 - p1, p1]).T


def eval_candidate(qnn_ref: EstimatorQNN, X_eval, y_eval, w, x_scale):
    probs = qnn_predict_proba(qnn_ref, X_eval, w, x_scale=x_scale)
    yhat = np.argmax(probs, axis=1)
    acc = float(accuracy_score(y_eval, yhat))
    loss = float(log_loss(y_eval, probs, labels=[0, 1]))
    return acc, loss


# ============================================================
# Geometry diagnostics + risk + rho
# ============================================================
def _unwrap_to_ref(A: np.ndarray) -> np.ndarray:
    K, D = A.shape
    out = A.copy()
    ref = out[0]
    for i in range(1, K):
        delta = out[i] - ref
        out[i] = ref + ((delta + np.pi) % (2*np.pi) - np.pi)
    return out


def _min_covering_arc_length(angles_1d: np.ndarray) -> float:
    a = np.sort(angles_1d)
    gaps = np.diff(a, append=a[0] + 2*np.pi)
    max_gap = np.max(gaps)
    return float(2*np.pi - max_gap)


def _geodesic_sse(angles: np.ndarray, center: np.ndarray, W: np.ndarray) -> float:
    diff = wrap_to_pi(angles - center)
    sse_per_client = np.sum(diff**2, axis=1)
    return float(np.sum(W * sse_per_client))


def angle_diagnostics(local_ws: List[np.ndarray], shard_sizes: List[int]) -> dict:
    A_raw = np.stack(local_ws, axis=0)
    A = wrap_to_pi(A_raw)

    W = np.asarray(shard_sizes, float)
    W = W / (W.sum() + 1e-12)

    C = np.sum(W[:, None] * np.cos(A), axis=0)
    S = np.sum(W[:, None] * np.sin(A), axis=0)
    R = np.sqrt(C**2 + S**2)

    R_mean = float(np.mean(R))
    R_min  = float(np.min(R))

    mu_circ = np.arctan2(S, C)

    A_unwrap = _unwrap_to_ref(A)
    mu_lin_unwrapped = np.sum(W[:, None] * A_unwrap, axis=0)
    mu_lin = wrap_to_pi(mu_lin_unwrapped)

    sse_circ = _geodesic_sse(A, mu_circ, W)
    sse_lin  = _geodesic_sse(A, mu_lin,  W)
    sse_gap  = float(sse_lin - sse_circ)

    cover_lengths = np.array([_min_covering_arc_length(A[:, j]) for j in range(A.shape[1])])
    straddle_frac = float(np.mean(cover_lengths > np.pi))

    gap = np.abs(wrap_to_pi(mu_lin - mu_circ))
    disagreement_rad = float(np.mean(gap))
    disagreement_deg = float(np.rad2deg(disagreement_rad))

    return {
        "R_mean": R_mean,
        "R_min": R_min,
        "straddle_frac": straddle_frac,
        "sse_geo_gap": sse_gap,
        "disagreement_rad": disagreement_rad,
        "disagreement_deg": disagreement_deg,
        "mu_circ": mu_circ,
        "mu_lin":  mu_lin,
    }


def geometry_risk_from_diag(diag: dict, a: float = 0.50, b: float = 0.30, c: float = 0.20) -> float:
    R_mean = float(diag.get("R_mean", 0.0))
    straddle = float(diag.get("straddle_frac", 0.0))
    gap = float(diag.get("sse_geo_gap", 0.0))

    gap_pos = max(0.0, gap)
    gap_term = gap_pos / (gap_pos + 1.0)

    risk = a * (1.0 - R_mean) + b * straddle + c * gap_term
    return float(np.clip(risk, 0.0, 1.0))


def rho_from_risk(risk: float, lam: float = 8.0, center: float = 0.5) -> float:
    return float(1.0 / (1.0 + np.exp(-lam * (risk - center))))


# ============================================================
# Streaming CSV helpers
# ============================================================
def append_row_csv(row: Dict[str, Any], csv_path: Path):
    csv_path.parent.mkdir(parents=True, exist_ok=True)
    df = pd.DataFrame([row])
    header = not csv_path.exists()
    df.to_csv(csv_path, mode="a", header=header, index=False)


def append_rows_csv(rows: List[Dict[str, Any]], csv_path: Path):
    if not rows:
        return
    csv_path.parent.mkdir(parents=True, exist_ok=True)
    df = pd.DataFrame(rows)
    header = not csv_path.exists()
    df.to_csv(csv_path, mode="a", header=header, index=False)


# ============================================================
# Federated loop config
# ============================================================
@dataclass
class FLConfig:
    num_clients: int = 10
    rounds: int = 50
    seed: int = 2025

    # partition: "iid", "label_skew", "dirichlet"
    partition: str = "label_skew"
    dirichlet_alpha: float = 0.3

    # agg: "linear_weighted", "circular_weighted", "auto_pick"
    agg_mode: str = "circular_weighted"

    maxiter_local: int = 10
    x_scale: float = np.pi
    optimizer: str = "SPSA"
    opt_kwargs: Optional[dict] = None

    # Validation split inside training
    val_size: float = 0.15
    val_seed: int = 42

    # AutoPick decision policy
    autopick_policy: str = "loss"  # "loss" or "geometry"
    rho_threshold: float = 0.5

    # Regime amplifiers
    local_init_noise: float = 0.0  # try 0.05–0.2

    angle_stress: bool = False
    angle_stress_shift: float = 2.8  # radians (~160°)

    # New parameter to ensure clients have both classes in Dirichlet partitioning
    min_samples_per_class_per_client: int = 1

    # Streaming save options
    stream_save: bool = True
    save_clients_each_round: bool = False



    # QoS-aware aggregation (for ablations): weights from fidelity, latency, and instability
    use_qos_weights: bool = True
    qos_alpha: float = 1.0     # fidelity exponent
    qos_gamma: float = 1.0     # latency exponent
    qos_delta: float = 1.0     # instability exponent
    qos_eps: float = 1e-6
    qos_ref_samples: int = 5   # number of reference inputs for fidelity (kept small for speed)
    qos_combine_with_data: bool = True  # multiply QoS score by shard_size weights
    qos_clip_min: float = 1e-6
    qos_clip_max: float = 1e6

# ============================================================
# Federated training loop
# ============================================================
def run_federated_qnn_builtins(
    X_train, y_train, X_test, y_test,
    cfg: FLConfig,
    rounds_csv_path: Optional[Path] = None,
    clients_csv_path: Optional[Path] = None,
) -> Dict[str, Any]:

    rng = np.random.default_rng(cfg.seed)

    if cfg.agg_mode != "auto_pick" and cfg.agg_mode not in AGG_MAP:
        raise ValueError(
            f"Unknown agg_mode={cfg.agg_mode}. Choose from {list(AGG_MAP.keys())} + ['auto_pick']"
        )
    aggregator = AGG_MAP.get(cfg.agg_mode, None)

    # ---- Create validation set from training ----
    X_tr, X_val, y_tr, y_val = split_train_val(
        X_train, y_train, val_size=cfg.val_size, seed=cfg.val_seed
    )


    # ---- Reference inputs for fidelity (fixed across rounds for comparability) ----
    X_ref = None
    if cfg.use_qos_weights:
        m_ref = int(min(max(1, cfg.qos_ref_samples), len(X_val)))
        ref_idx = rng.choice(len(X_val), size=m_ref, replace=False)
        X_ref = X_val[ref_idx]

    # ---- Client shards are built from TRAIN-ONLY subset ----
    if cfg.partition == "iid":
        shards = shard_iid(X_tr, y_tr, cfg.num_clients, seed=cfg.seed)
    elif cfg.partition == "dirichlet":
        shards = shard_dirichlet(
            X_tr, y_tr, cfg.num_clients,
            alpha=cfg.dirichlet_alpha, seed=cfg.seed,
            min_samples_per_class_per_client=cfg.min_samples_per_class_per_client
        )
    else:
        shards = shard_label_skew(
            X_tr, y_tr, cfg.num_clients,
            min_per_client=max(40, len(X_tr)//cfg.num_clients),
            seed=cfg.seed
        )

    num_features = X_train.shape[1]

    qnn_ref = build_estimator_qnn(num_features, fm_reps=1, an_reps=2)
    D = len(qnn_ref.weight_params)

    # Initialize global weights
    w_global = 0.10 * rng.standard_normal(D)

    rounds_rows: List[Dict[str, Any]] = []
    client_rows_all: List[Dict[str, Any]] = []

    print(
        f"=== QFL | agg='{cfg.agg_mode}' | policy='{cfg.autopick_policy}' | "
        f"partition='{cfg.partition}' (α={cfg.dirichlet_alpha}) | "
        f"clients={cfg.num_clients} rounds={cfg.rounds} | "
        f"init_noise={cfg.local_init_noise} | stress={cfg.angle_stress} ==="
    )

    for r in range(1, cfg.rounds + 1):
        t0 = time.perf_counter()

        local_ws: List[np.ndarray] = []
        shard_sizes: List[int] = []
        local_times: List[float] = []

        client_rows_this_round: List[Dict[str, Any]] = []

        # ---------------------------
        # Local training per client
        # ---------------------------
        for cid, (Xi, yi) in enumerate(shards):

            # client-specific init noise to encourage divergence
            w0 = w_global.copy()
            if cfg.local_init_noise > 0:
                w0 = w0 + cfg.local_init_noise * rng.standard_normal(len(w0))

            clf = make_classifier(
                num_features=num_features,
                initial_point=w0,
                maxiter=cfg.maxiter_local,
                optimizer=cfg.optimizer,
                opt_kwargs=cfg.opt_kwargs,
            )

            t_c0 = time.perf_counter()
            clf.fit(Xi * cfg.x_scale, yi)
            dt_c = time.perf_counter() - t_c0
            local_times.append(float(dt_c))

            w_local = extract_weights(clf)
            if len(w_local) != D:
                raise RuntimeError(f"[Round {r}] Client {cid}: weight dim mismatch ({len(w_local)} != {D})")

            # controlled geometry stress-test (mechanism demo)
            if cfg.angle_stress:
                direction = +1.0 if (cid % 2 == 0) else -1.0
                w_local = wrap_to_pi(w_local + direction * cfg.angle_stress_shift)

            local_ws.append(w_local)
            shard_sizes.append(len(yi))

            # --- optional client eval ---
            probs_tr_loc = qnn_predict_proba(qnn_ref, Xi,     w_local, x_scale=cfg.x_scale)
            probs_te_loc = qnn_predict_proba(qnn_ref, X_test, w_local, x_scale=cfg.x_scale)
            yhat_tr_loc = np.argmax(probs_tr_loc, axis=1)
            yhat_te_loc = np.argmax(probs_te_loc, axis=1)

            c_row = {
                "round": r,
                "client_id": cid,
                "shard_size": int(len(yi)),
                "train_acc_local": float(accuracy_score(yi, yhat_tr_loc)),
                "test_acc_local":  float(accuracy_score(y_test, yhat_te_loc)),
                "train_loss_local": float(log_loss(yi,     probs_tr_loc, labels=[0, 1])),
                "test_loss_local":  float(log_loss(y_test, probs_te_loc, labels=[0, 1])),
                "time_sec_local":   float(dt_c),
            }
            client_rows_this_round.append(c_row)


        # ---------------------------
        # QoS weighting (optional)
        # ---------------------------
        weights_for_agg = np.asarray(shard_sizes, float)

        qos = None
        if cfg.use_qos_weights:
            qos = compute_qos_scores(
                local_ws=local_ws,
                w_global=w_global,
                local_times=local_times,
                qnn_ref=qnn_ref,
                X_ref=X_ref,
                cfg=cfg,
            )
            if cfg.qos_combine_with_data:
                weights_for_agg = weights_for_agg * qos["q"]
            else:
                weights_for_agg = qos["q"].copy()

            # normalize client weights for logging
            wnorm = weights_for_agg / (weights_for_agg.sum() + 1e-12)
            for i, crow in enumerate(client_rows_this_round):
                crow.update({
                    "fidelity": float(qos["F"][i]),
                    "instability_sigma2": float(qos["sig2"][i]),
                    "qos_score": float(qos["q"][i]),
                    "agg_weight": float(wnorm[i]),
                })

        # ---------------------------
        # Diagnostics (before aggregation)
        # ---------------------------
        diag = angle_diagnostics(local_ws, weights_for_agg.tolist())
        r_t = geometry_risk_from_diag(diag)
        rho_t = rho_from_risk(r_t)

        print(
            f"    [diag] R̄={diag['R_mean']:.3f} | Rmin={diag['R_min']:.3f} | "
            f"straddle_frac={diag['straddle_frac']:.2f} | ΔSSE_geo={diag['sse_geo_gap']:.4f} | "
            f"geom_err={diag['disagreement_deg']:.2f}° | r_t={r_t:.3f} | ρ_t={rho_t:.3f}"
        )

        # ---------------------------
        # Aggregate to new global
        # ---------------------------
        picked = None

        # Candidate stats (for logging)
        val_acc_lin = val_loss_lin = None
        val_acc_circ = val_loss_circ = None
        test_acc_lin = test_loss_lin = None
        test_acc_circ = test_loss_circ = None

        if cfg.agg_mode == "auto_pick":
            # candidates
            w_lin  = agg_fedavg_weighted(local_ws, weights_for_agg.tolist())
            w_circ = agg_circular_weighted(local_ws, weights_for_agg.tolist())

            # evaluate on validation for selection logging
            val_acc_lin,  val_loss_lin  = eval_candidate(qnn_ref, X_val, y_val, w_lin,  cfg.x_scale)
            val_acc_circ, val_loss_circ = eval_candidate(qnn_ref, X_val, y_val, w_circ, cfg.x_scale)

            # also evaluate on test for reporting
            test_acc_lin,  test_loss_lin  = eval_candidate(qnn_ref, X_test, y_test, w_lin,  cfg.x_scale)
            test_acc_circ, test_loss_circ = eval_candidate(qnn_ref, X_test, y_test, w_circ, cfg.x_scale)

            policy = str(cfg.autopick_policy).lower().strip()

            if policy == "loss":
                use_circ = (val_loss_circ <= val_loss_lin)
                rule_used = "loss(val)"
            elif policy == "geometry":
                use_circ = (rho_t >= cfg.rho_threshold)
                rule_used = "geometry(ρ)"
            else:
                raise ValueError("autopick_policy must be 'loss' or 'geometry'")

            w_global = w_circ if use_circ else w_lin
            picked = "circular" if use_circ else "linear"

            print(
                f"    ↳ AutoPick-{rule_used} chose: {picked} | "
                f"val_loss_lin={val_loss_lin:.4f}, val_loss_circ={val_loss_circ:.4f} | "
                f"ρ_t={rho_t:.3f}"
            )
        else:
            w_global = aggregator(local_ws, weights_for_agg.tolist())  # type: ignore
            picked = cfg.agg_mode
            print(f"    ↳ agg used: {cfg.agg_mode}")

        # ---------------------------
        # Global evaluation
        # ---------------------------
        probs_tr = qnn_predict_proba(qnn_ref, X_tr,   w_global, x_scale=cfg.x_scale)
        probs_va = qnn_predict_proba(qnn_ref, X_val,  w_global, x_scale=cfg.x_scale)
        probs_te = qnn_predict_proba(qnn_ref, X_test, w_global, x_scale=cfg.x_scale)

        ytr_pred = np.argmax(probs_tr, axis=1)
        yva_pred = np.argmax(probs_va, axis=1)
        yte_pred = np.argmax(probs_te, axis=1)

        train_acc  = float(accuracy_score(y_tr,   ytr_pred))
        val_acc    = float(accuracy_score(y_val,  yva_pred))
        test_acc   = float(accuracy_score(y_test, yte_pred))

        train_loss = float(log_loss(y_tr,   probs_tr, labels=[0, 1]))
        val_loss   = float(log_loss(y_val,  probs_va, labels=[0, 1]))
        test_loss  = float(log_loss(y_test, probs_te, labels=[0, 1]))

        dt = time.perf_counter() - t0

        print(
            f"[Round {r:02d}] TrainAcc={train_acc:.3f} | ValAcc={val_acc:.3f} | TestAcc={test_acc:.3f} | "
            f"TrainLoss={train_loss:.4f} | ValLoss={val_loss:.4f} | TestLoss={test_loss:.4f} | "
            f"Time={dt:.2f}s | picked={picked}"
        )

        # ---------------------------
        # Per-round row
        # ---------------------------
        row = {
            "round": r,
            "train_acc": train_acc,
            "val_acc": val_acc,
            "test_acc": test_acc,
            "train_loss": train_loss,
            "val_loss": val_loss,
            "test_loss": test_loss,
            "time_sec": float(dt),

            # QoS summaries (when enabled)
            "qos_enabled": bool(cfg.use_qos_weights),
            "qos_alpha": float(cfg.qos_alpha),
            "qos_gamma": float(cfg.qos_gamma),
            "qos_delta": float(cfg.qos_delta),
            "mean_fidelity": None,
            "mean_latency": None,
            "mean_instability": None,

            "agg_mode": cfg.agg_mode,
            "autopick_policy": cfg.autopick_policy if cfg.agg_mode == "auto_pick" else "fixed",
            "picked": picked,

            # partition + regime knobs for traceability
            "partition": cfg.partition,
            "dirichlet_alpha": float(cfg.dirichlet_alpha),
            "local_init_noise": float(cfg.local_init_noise),
            "angle_stress": bool(cfg.angle_stress),
            "angle_stress_shift": float(cfg.angle_stress_shift),

            # geometry diagnostics
            "resultant_mean": float(diag["R_mean"]),
            "resultant_min":  float(diag["R_min"]),
            "straddle_frac":  float(diag["straddle_frac"]),
            "delta_SSE_geo":  float(diag["sse_geo_gap"]),
            "geom_err_deg":   float(diag["disagreement_deg"]),
            "geom_err_rad":   float(diag["disagreement_rad"]),
            "risk_t": float(r_t),
            "rho_t":  float(rho_t),

            # selection threshold (for traceability)
            "rho_threshold": float(cfg.rho_threshold),
            "val_size": float(cfg.val_size),
            "min_samples_per_class_per_client": float(cfg.min_samples_per_class_per_client)
        }

        # candidate logs (useful for plotting why choices differ)
        if cfg.agg_mode == "auto_pick":
            row.update({
                "val_loss_lin":  None if val_loss_lin  is None else float(val_loss_lin),
                "val_loss_circ": None if val_loss_circ is None else float(val_loss_circ),
                "val_acc_lin":   None if val_acc_lin   is None else float(val_acc_lin),
                "val_acc_circ":  None if val_acc_circ  is None else float(val_acc_circ),

                "test_loss_lin":  None if test_loss_lin  is None else float(test_loss_lin),
                "test_loss_circ": None if test_loss_circ is None else float(test_loss_circ),
                "test_acc_lin":   None if test_acc_lin   is None else float(test_acc_lin),
                "test_acc_circ":  None if test_acc_circ  is None else float(test_acc_circ),
            })


        # Fill QoS summary statistics
        if cfg.use_qos_weights and qos is not None:
            row["mean_fidelity"] = float(np.mean(qos["F"])) # Removed one ')'
            row["mean_latency"] = float(np.mean(qos["tau"])) # Removed one ')'
            row["mean_instability"] = float(np.mean(qos["sig2"])) # Removed one ')'

        rounds_rows.append(row)

        # ---------------------------
        # Streaming save
        # ---------------------------
        if cfg.stream_save and rounds_csv_path is not None:
            append_row_csv(row, rounds_csv_path)

        if cfg.save_clients_each_round and clients_csv_path is not None:
            append_rows_csv(client_rows_this_round, clients_csv_path)

        client_rows_all.extend(client_rows_this_round)

    return {
        "w_global": w_global,
        "rows": rounds_rows,
        "client_rows": client_rows_all,
        "config": cfg.__dict__,
    }


# ============================================================
# Save metadata (once)
# ============================================================
def save_meta(meta: Dict[str, Any], path: Path):
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w") as f:
        json.dump(meta, f, indent=2)


# ============================================================
# Main (configuration-driven)
# ============================================================
if __name__ == "__main__":
    import random
    np.random.seed(42)

    # ---------------- USER CHOICES ----------------
    FAST_CPU_MODE = True

    # Multi-seed protocol (5 is a standard minimum; use 10 for stronger evidence)
    SEEDS = [0, 1, 2, 3, 4]


    # Choose dataset:
    DATASET_CHOICE = "clinical"
    # options: "breast_cancer", "clinical", "both", "biomarker_leukemia", "biomarker_arcene", "motion_nuscenes_mini", or "openml:<id>"

    # Feature dimension (= number of qubits). 4 is a good CPU default.
    PCA_K = 4
    # Cap samples to keep Qiskit simulation practical (set None to disable).
    MAX_SAMPLES = 400

    # nuScenes (requires downloaded data + nuscenes-devkit). Used when DATASET_CHOICE == "motion_nuscenes_mini"
    NUSCENES_ROOT = os.environ.get("NUSCENES_ROOT", "./nuscenes")
    NUSCENES_VERSION = "v1.0-mini"
    NUSCENES_CATEGORY_PREFIX = "vehicle"


    # Aggregators you want to compare
    AGG_MODES_TO_RUN = [
        #"linear_weighted",
        "circular_weighted",
        #"auto_pick",
    ]

    # AutoPick policies to run (only used when agg_mode == "auto_pick")
    AUTOPICK_POLICIES = ["loss", "geometry"]

    # Partition regimes
    PARTITIONS_TO_RUN = [
        #("label_skew", None),
         #Dirichlet sweeps: include multiple alphas to show regime shifts
        #("dirichlet", 5.0),
        #("dirichlet", 0.5),
        ("dirichlet", 0.1),
    ]

    # Regime amplifiers
    LOCAL_INIT_NOISE_LIST = [0.1]  # 0.0 baseline, 0.1 encourages divergence
    ANGLE_STRESS_LIST = [False]  # True = controlled mechanism demo

    NUM_CLIENTS = 10
    ROUNDS = 200
    MAXITER_LOCAL = 20
    X_SCALE = np.pi

    if FAST_CPU_MODE:
        # Practical CPU defaults (safe to expand later)
        NUM_CLIENTS = 10
        ROUNDS = 200
        MAXITER_LOCAL = 20
        X_SCALE = np.pi  # keep angles in [-pi, pi]

    # Validation split inside training
    VAL_SIZE = 0.15
    VAL_SEED = 42

    # Rho threshold for geometry policy
    RHO_THRESHOLD = 0.5

    # Clinical CSV path
    CLINICAL_CSV = "/content/drive/MyDrive/data/BrEaST-Lesions-USG-Clinical.csv"

    # Optimizer defaults
    OPT_BREAST = ("SPSA", {"learning_rate": 0.05, "perturbation": 0.1})
    OPT_CLIN   = ("COBYLA", None)

    # Stream-save options
    STREAM_SAVE = True
    SAVE_CLIENTS_EACH_ROUND = False
    CLEAN_OLD_CSVS = True

    def set_all_seeds(seed: int):
        random.seed(seed)
        np.random.seed(seed)
        try:
            import torch
            torch.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
        except Exception:
            pass

    # Keep dataset split fixed for fair method comparison (recommended).
    # If want split to vary across seeds, set DATA_SPLIT_SEED = seed inside the seed loop.
    DATA_SPLIT_SEED = 42


    # ---------------- INTERNAL HELPERS ----------------
    def outdir_for(dataset_key: str, agg_mode: str, policy: str,
               partition: str, alpha: Optional[float],
               init_noise: float, stress: bool, seed: int) -> Path:

      alpha_tag = "" if alpha is None else f"_a{alpha}"
      stress_tag = "_stress" if stress else ""
      noise_tag = f"_n{init_noise}"
      seed_tag = f"_seed{seed}"

      name = (
        f"QFL_folder/newQosMotivation/withQOS{dataset_key}_"
        f"{agg_mode}_"
        f"{policy}_"
        f"{partition}{alpha_tag}{noise_tag}{stress_tag}{seed_tag}"
      )
      return get_outdir_in_drive(name)


    def rounds_csv_path(outdir: Path, dataset_key: str) -> Path:
        return outdir / f"round_metrics_{dataset_key}_clients{NUM_CLIENTS}_rounds{ROUNDS}.csv"

    def clients_csv_path(outdir: Path, dataset_key: str) -> Path:
        return outdir / f"client_metrics_{dataset_key}_clients{NUM_CLIENTS}_rounds{ROUNDS}.csv"

    def maybe_clean(*paths: Path):
        if not CLEAN_OLD_CSVS:
            return
        for p in paths:
            if p.exists():
                p.unlink()

    def run_one(dataset_key: str,
            X_train, X_test, y_train, y_test,
            agg_mode: str,
            policy: str,
            partition: str,
            alpha: Optional[float],
            init_noise: float,
            stress: bool,
            optimizer_name: str,
            opt_kwargs: Optional[dict],
            seed: int):

        outdir = outdir_for(dataset_key, agg_mode, policy, partition, alpha, init_noise, stress, seed)
        r_csv = rounds_csv_path(outdir, dataset_key)
        c_csv = clients_csv_path(outdir, dataset_key)

        maybe_clean(r_csv, c_csv)


        cfg = FLConfig(
            num_clients=NUM_CLIENTS,
            rounds=ROUNDS,
            seed=seed,          # <<<< changed
            val_seed=42,        # keep fixed for comparable validation split (optional)


            partition=partition,
            dirichlet_alpha=float(alpha) if alpha is not None else 0.3,

            agg_mode=agg_mode,
            maxiter_local=MAXITER_LOCAL,
            x_scale=X_SCALE,
            optimizer=optimizer_name,
            opt_kwargs=opt_kwargs,

            val_size=VAL_SIZE,


            autopick_policy=policy,
            rho_threshold=RHO_THRESHOLD,

            local_init_noise=init_noise,
            angle_stress=stress,
            angle_stress_shift=2.8,

            min_samples_per_class_per_client=1, # Ensure this is set

            stream_save=STREAM_SAVE,
            save_clients_each_round=SAVE_CLIENTS_EACH_ROUND,
        )

        print(f"\n=== RUN | {dataset_key} | agg={agg_mode} | policy={policy} | "
              f"partition={partition} | alpha={alpha} | init_noise={init_noise} | stress={stress} ===")

        out = run_federated_qnn_builtins(
            X_train, y_train, X_test, y_test,
            cfg,
            rounds_csv_path=r_csv,
            clients_csv_path=c_csv if SAVE_CLIENTS_EACH_ROUND else None
        )

        save_meta({"dataset": dataset_key, **out["config"]}, outdir / "run_metadata.json")
        np.save(outdir / "w_global.npy", out["w_global"])

        print(f"Saved rounds CSV to: {r_csv}")
        if SAVE_CLIENTS_EACH_ROUND:
            print(f"Saved client CSV to: {c_csv}")

        return out

    # ---------------- DATASET RUNNERS ----------------
    def run_breast():
        X_train, X_test, y_train, y_test = load_breast_cancer(pca_k=PCA_K, test_size=0.30, seed=42)
        opt_name, opt_kwargs = OPT_BREAST

        for partition, alpha in PARTITIONS_TO_RUN:
            for init_noise in LOCAL_INIT_NOISE_LIST:
                for stress in ANGLE_STRESS_LIST:
                    for agg_mode in AGG_MODES_TO_RUN:
                        if agg_mode == "auto_pick":
                            for pol in AUTOPICK_POLICIES:
                                run_one("breast_cancer", X_train, X_test, y_train, y_test,
                                        agg_mode, pol, partition, alpha, init_noise, stress,
                                        opt_name, opt_kwargs, seed=42) # Added seed=42
                        else:
                            # policy label for fixed modes
                            run_one("breast_cancer", X_train, X_test, y_train, y_test,
                                    agg_mode, "fixed", partition, alpha, init_noise, stress,
                                    opt_name, opt_kwargs, seed=42) # Added seed=42

    def run_clinical():
        X_train, X_test, y_train, y_test = load_clinical_csv(
            CLINICAL_CSV, test_size=0.45, pca_k=PCA_K, seed=DATA_SPLIT_SEED
        )
        opt_name, opt_kwargs = OPT_CLIN

        for seed in SEEDS:
            set_all_seeds(seed)

            for partition, alpha in PARTITIONS_TO_RUN:
                for init_noise in LOCAL_INIT_NOISE_LIST:
                    for stress in ANGLE_STRESS_LIST:
                        for agg_mode in AGG_MODES_TO_RUN:
                            if agg_mode == "auto_pick":
                                for pol in AUTOPICK_POLICIES:
                                    run_one("clinical", X_train, X_test, y_train, y_test,
                                            agg_mode, pol, partition, alpha, init_noise, stress,
                                            opt_name, opt_kwargs, seed)
                            else:
                                run_one("clinical", X_train, X_test, y_train, y_test,
                                        agg_mode, "fixed", partition, alpha, init_noise, stress,
                                        opt_name, opt_kwargs, seed)


    # ---------------- EXECUTE CHOSEN ----------------

    def run_biomarker_leukemia():
        X_train, X_test, y_train, y_test = load_biomarker_leukemia(
            pca_k=PCA_K, test_size=0.30, seed=42, max_samples=MAX_SAMPLES
        )
        opt_name, opt_kwargs = OPT_BREAST  # SPSA is typically robust for small QNNs
        for partition, alpha in PARTITIONS_TO_RUN:
            for init_noise in LOCAL_INIT_NOISE_LIST:
                for stress in ANGLE_STRESS_LIST:
                    for agg_mode in AGG_MODES_TO_RUN:
                        if agg_mode == "auto_pick":
                            for pol in AUTOPICK_POLICIES:
                                run_one("biomarker_leukemia", X_train, X_test, y_train, y_test,
                                        agg_mode, pol, partition, alpha, init_noise, stress,
                                        opt_name, opt_kwargs, seed=42) # Added seed=42
                        else:
                            run_one("biomarker_leukemia", X_train, X_test, y_train, y_test,
                                    agg_mode, "fixed", partition, alpha, init_noise, stress,
                                    opt_name, opt_kwargs, seed=42) # Added seed=42

    def run_biomarker_arcene():
        X_train, X_test, y_train, y_test = load_biomarker_arcene(
            pca_k=PCA_K, test_size=0.30, seed=42, max_samples=MAX_SAMPLES
        )
        opt_name, opt_kwargs = OPT_BREAST
        for partition, alpha in PARTITIONS_TO_RUN:
            for init_noise in LOCAL_INIT_NOISE_LIST:
                for stress in ANGLE_STRESS_LIST:
                    for agg_mode in AGG_MODES_TO_RUN:
                        if agg_mode == "auto_pick":
                            for pol in AUTOPICK_POLICIES:
                                run_one("biomarker_arcene", X_train, X_test, y_train, y_test,
                                        agg_mode, pol, partition, alpha, init_noise, stress,
                                        opt_name, opt_kwargs, seed=42) # Added seed=42
                        else:
                            run_one("biomarker_arcene", X_train, X_test, y_train, y_test,
                                    agg_mode, "fixed", partition, alpha, init_noise, stress,
                                    opt_name, opt_kwargs, seed=42) # Added seed=42

    def run_motion_nuscenes_mini():
        X_train, X_test, y_train, y_test = load_motion_nuscenes_mini(
            dataroot=NUSCENES_ROOT,
            version=NUSCENES_VERSION,
            category_prefix=NUSCENES_CATEGORY_PREFIX,
            pca_k=PCA_K,
            test_size=0.30,
            seed=42,
            max_samples=MAX_SAMPLES,
        )
        opt_name, opt_kwargs = OPT_BREAST
        for partition, alpha in PARTITIONS_TO_RUN:
            for init_noise in LOCAL_INIT_NOISE_LIST:
                for stress in ANGLE_STRESS_LIST:
                    for agg_mode in AGG_MODES_TO_RUN:
                        if agg_mode == "auto_pick":
                            for pol in AUTOPICK_POLICIES:
                                run_one("motion_nuscenes_mini", X_train, X_test, y_train, y_test,
                                        agg_mode, pol, partition, alpha, init_noise, stress,
                                        opt_name, opt_kwargs, seed=42) # Added seed=42
                        else:
                            run_one("motion_nuscenes_mini", X_train, X_test, y_train, y_test,
                                    agg_mode, "fixed", partition, alpha, init_noise, stress,
                                    opt_name, opt_kwargs, seed=42) # Added seed=42

    def run_openml_id(openml_id: int):
        X_train, X_test, y_train, y_test = load_openml_binary(
            openml_id, pca_k=PCA_K, test_size=0.30, seed=42, max_samples=MAX_SAMPLES
        )
        opt_name, opt_kwargs = OPT_BREAST
        tag = f"openml_{openml_id}"
        for partition, alpha in PARTITIONS_TO_RUN:
            for init_noise in LOCAL_INIT_NOISE_LIST:
                for stress in ANGLE_STRESS_LIST:
                    for agg_mode in AGG_MODES_TO_RUN:
                        if agg_mode == "auto_pick":
                            for pol in AUTOPICK_POLICIES:
                                run_one(tag, X_train, X_test, y_train, y_test,
                                        agg_mode, pol, partition, alpha, init_noise, stress,
                                        opt_name, opt_kwargs, seed=42) # Added seed=42
                        else:
                            run_one(tag, X_train, X_test, y_train, y_test,
                                    agg_mode, "fixed", partition, alpha, init_noise, stress,
                                    opt_name, opt_kwargs, seed=42) # Added seed=42
    if DATASET_CHOICE == "breast_cancer":
        run_breast()
    elif DATASET_CHOICE == "clinical":
        try:
            run_clinical()
        except FileNotFoundError:
            print(f"[WARN] Clinical CSV not found at: {CLINICAL_CSV}.")
    elif DATASET_CHOICE == "both":
        run_breast()
        try:
            run_clinical()
        except FileNotFoundError:
            print(f"[WARN] Clinical CSV not found at: {CLINICAL_CSV}.")
    elif DATASET_CHOICE == "biomarker_leukemia":
        run_biomarker_leukemia()
    elif DATASET_CHOICE == "biomarker_arcene":
        run_biomarker_arcene()
    elif DATASET_CHOICE == "motion_nuscenes_mini":
        run_motion_nuscenes_mini()
    elif isinstance(DATASET_CHOICE, str) and DATASET_CHOICE.startswith("openml:"):
        openml_id = int(DATASET_CHOICE.split(":", 1)[1])
        run_openml_id(openml_id)
    else:
        raise ValueError(
            "DATASET_CHOICE must be: 'breast_cancer', 'clinical', 'both', "
            "'biomarker_leukemia', 'biomarker_arcene', 'motion_nuscenes_mini', or 'openml:<id>'."
        )

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).

=== RUN | clinical | agg=circular_weighted | policy=fixed | partition=dirichlet | alpha=0.1 | init_noise=0.1 | stress=False ===
=== QFL | agg='circular_weighted' | policy='fixed' | partition='dirichlet' (α=0.1) | clients=10 rounds=200 | init_noise=0.1 | stress=False ===
    [diag] R̄=0.857 | Rmin=0.777 | straddle_frac=0.00 | ΔSSE_geo=-0.0121 | geom_err=1.36° | r_t=0.072 | ρ_t=0.031
    ↳ agg used: circular_weighted
[Round 01] TrainAcc=0.427 | ValAcc=0.389 | TestAcc=0.526 | TrainLoss=0.7354 | ValLoss=0.6930 | TestLoss=0.6956 | Time=20.90s | picked=circular_weighted
    [diag] R̄=0.895 | Rmin=0.769 | straddle_frac=0.00 | ΔSSE_geo=-0.0070 | geom_err=0.97° | r_t=0.052 | ρ_t=0.027
    ↳ agg used: circular_weighted
[Round 02] TrainAcc=0.458 | ValAcc=0.389 | TestAcc=0.537 | TrainLoss=0.7224 | ValLoss=0.6938 | TestLoss=0.6911 | Time=12.41s | picked=circular_weighted

In [8]:
import json
import numpy as np
import pandas as pd
from pathlib import Path

# Adjust if your Drive base differs
BASE = Path("/content/drive/MyDrive/AutopickV2/QFL_folder/newQosMotivation/Seeds")

TAIL_K = 20
SAFE_Q = 0.25

ACC = "val_acc"
TEST = "test_acc"
RISK = "risk_t"
GEOM = "geom_err_deg"

def auc_norm(df, col):
    x = df["round"].to_numpy(float)
    y = df[col].to_numpy(float)
    m = np.isfinite(x) & np.isfinite(y)
    x, y = x[m], y[m]
    if len(x) < 2:
        return np.nan
    idx = np.argsort(x)
    x, y = x[idx], y[idx]
    auc = np.trapz(y, x)
    denom = x.max() - x.min()
    return float(auc / denom) if denom > 0 else np.nan

def method_label(meta):
    # Preferred: use meta flags so it works even if filenames change
    if bool(meta.get("use_qos_weights", False)):
        return "QoS"
    mode = meta.get("agg_mode", "")
    if mode == "linear_weighted":
        return "Linear"
    if mode == "circular_weighted":
        return "Circular"
    if mode == "auto_pick":
        return f"AutoPick:{meta.get('autopick_policy','?')}"
    return str(mode)

# ---- load all runs ----
runs = []
for meta_path in BASE.rglob("run_metadata.json"):
    outdir = meta_path.parent
    meta = json.loads(meta_path.read_text())

    # find the round-metrics csv saved by rounds_csv_path(...)
    csvs = sorted(outdir.glob("round_metrics_*_clients*_rounds*.csv"))
    if not csvs:
        continue
    df = pd.read_csv(csvs[0]).sort_values("round")
    df = df[df["round"] >= 1].reset_index(drop=True)

    runs.append((meta, df))

if not runs:
    raise RuntimeError(f"No runs found under: {BASE}")

# ---- compute pooled safe threshold per (dataset, partition, alpha, noise, stress) ----
pool_rows = []
for meta, df in runs:
    if RISK in df.columns:
        pool_rows.append({
            "dataset": meta.get("dataset"),
            "partition": meta.get("partition"),
            "alpha": meta.get("dirichlet_alpha", None),
            "noise": meta.get("local_init_noise", None),
            "stress": meta.get("angle_stress", None),
            "risk": df[RISK].to_numpy(float)
        })

pool_expanded = []
for r in pool_rows:
    for v in r["risk"]:
        if np.isfinite(v):
            pool_expanded.append({k: r[k] for k in ["dataset","partition","alpha","noise","stress"]} | {"risk": v})
pool_expanded = pd.DataFrame(pool_expanded)

thr_map = (
    pool_expanded
    .groupby(["dataset","partition","alpha","noise","stress"])["risk"]
    .quantile(SAFE_Q)
    .to_dict()
)

# ---- per-seed summaries ----
rows = []
for meta, df in runs:
    dataset  = meta.get("dataset")
    part     = meta.get("partition")
    alpha    = meta.get("dirichlet_alpha", None)
    noise    = meta.get("local_init_noise", None)
    stress   = meta.get("angle_stress", None)
    seed     = meta.get("seed")
    method   = method_label(meta)

    key = (dataset, part, alpha, noise, stress)
    thr = float(thr_map.get(key, np.nan))

    row = {
        "dataset": dataset,
        "partition": part,
        "alpha": alpha,
        "noise": noise,
        "stress": stress,
        "seed": seed,
        "method": method,
        "rounds": len(df),
        "risk_thr_q": thr,
    }

    if ACC in df.columns:
        row["val_best"] = float(df[ACC].max())
        row["val_tail"] = float(df[ACC].tail(TAIL_K).mean())
        row["val_auc_norm"] = auc_norm(df, ACC)

    if TEST in df.columns:
        row["test_best"] = float(df[TEST].max())
        row["test_tail"] = float(df[TEST].tail(TAIL_K).mean())
        row["test_auc_norm"] = auc_norm(df, TEST)

    if RISK in df.columns:
        row["risk_mean"] = float(df[RISK].mean())
        row["risk_max"]  = float(df[RISK].max())
        row["risk_auc_norm"] = auc_norm(df, RISK)
        row[f"safe_frac_q{SAFE_Q}"] = float((df[RISK] < thr).mean()) if np.isfinite(thr) else np.nan

    if GEOM in df.columns:
        row["geom_mean"] = float(df[GEOM].mean())
        row["geom_max"]  = float(df[GEOM].max())
        row["geom_auc_norm"] = auc_norm(df, GEOM)

    rows.append(row)

per_seed = pd.DataFrame(rows)

# ---- aggregate across seeds per scenario + method ----
def mean_std_worst(g, col):
    return pd.Series({
        f"{col}_mean": g[col].mean(),
        f"{col}_std":  g[col].std(ddof=1),
        f"{col}_worst": g[col].min(),   # worst-case robustness (min across seeds)
    })

agg_cols = ["val_best","val_tail","test_best","test_tail","risk_mean","geom_mean",f"safe_frac_q{SAFE_Q}"]

out = []
group_keys = ["dataset","partition","alpha","noise","stress","method"]
for keys, g in per_seed.groupby(group_keys):
    row = dict(zip(group_keys, keys))
    row["n_seeds"] = g["seed"].nunique()
    for c in agg_cols:
        if c in g.columns:
            row.update(mean_std_worst(g, c).to_dict())
    out.append(row)

summary = pd.DataFrame(out).sort_values(["dataset","partition","alpha","noise","stress","method"])

display(per_seed)
display(summary)

per_seed.to_csv("multiseed_per_seed.csv", index=False)
summary.to_csv("multiseed_summary.csv", index=False)

print("Saved: multiseed_per_seed.csv, multiseed_summary.csv")
print("\nLaTeX:\n")
print(summary.round(4).to_latex(index=False))


Unnamed: 0,dataset,partition,alpha,noise,stress,seed,method,rounds,risk_thr_q,val_best,...,test_best,test_tail,test_auc_norm,risk_mean,risk_max,risk_auc_norm,safe_frac_q0.25,geom_mean,geom_max,geom_auc_norm
0,clinical,dirichlet,0.1,0.1,False,0,QoS,200,0.027677,0.722222,...,0.610526,0.490526,0.488892,0.042073,0.122033,0.042036,0.19,0.760379,2.224961,0.759615
1,clinical,dirichlet,0.1,0.1,False,1,QoS,200,0.027677,0.722222,...,0.652632,0.478947,0.477017,0.037908,0.081755,0.037858,0.245,0.722543,1.780151,0.722091
2,clinical,dirichlet,0.1,0.1,False,2,QoS,200,0.027677,0.777778,...,0.621053,0.541053,0.499497,0.041188,0.232786,0.041075,0.2,0.822161,12.376022,0.820083
3,clinical,dirichlet,0.1,0.1,False,3,QoS,200,0.027677,0.888889,...,0.6,0.488421,0.49188,0.030447,0.083407,0.030326,0.415,0.571136,2.039784,0.568629
4,clinical,dirichlet,0.1,0.1,False,4,QoS,200,0.027677,0.777778,...,0.578947,0.476316,0.46247,0.040341,0.110587,0.040297,0.2,0.765645,3.076758,0.764786


Unnamed: 0,dataset,partition,alpha,noise,stress,method,n_seeds,val_best_mean,val_best_std,val_best_worst,...,test_tail_worst,risk_mean_mean,risk_mean_std,risk_mean_worst,geom_mean_mean,geom_mean_std,geom_mean_worst,safe_frac_q0.25_mean,safe_frac_q0.25_std,safe_frac_q0.25_worst
0,clinical,dirichlet,0.1,0.1,False,QoS,5,0.777778,0.068041,0.722222,...,0.476316,0.038391,0.004704,0.030447,0.728373,0.094825,0.571136,0.25,0.09467,0.19


Saved: multiseed_per_seed.csv, multiseed_summary.csv

LaTeX:

\begin{tabular}{llrrrlrrrrrrrrrrrrrrrrrrrrrr}
\toprule
dataset & partition & alpha & noise & stress & method & n_seeds & val_best_mean & val_best_std & val_best_worst & val_tail_mean & val_tail_std & val_tail_worst & test_best_mean & test_best_std & test_best_worst & test_tail_mean & test_tail_std & test_tail_worst & risk_mean_mean & risk_mean_std & risk_mean_worst & geom_mean_mean & geom_mean_std & geom_mean_worst & safe_frac_q0.25_mean & safe_frac_q0.25_std & safe_frac_q0.25_worst \\
\midrule
clinical & dirichlet & 0.100000 & 0.100000 & False & QoS & 5 & 0.777800 & 0.068000 & 0.722200 & 0.518300 & 0.027300 & 0.475000 & 0.612600 & 0.027200 & 0.578900 & 0.495100 & 0.026400 & 0.476300 & 0.038400 & 0.004700 & 0.030400 & 0.728400 & 0.094800 & 0.571100 & 0.250000 & 0.094700 & 0.190000 \\
\bottomrule
\end{tabular}



In [5]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# After running the main cell:
# - set DATASET_CHOICE etc. in the script configuration section (near the bottom),
# - then rerun the last "EXECUTE CHOSEN" block, or simply rerun the main cell.
#
# Sanity check:
try:
    _ = FLConfig().use_qos_weights
    print("OK: FLConfig has QoS fields.")
except Exception as e:
    print("QoS fields missing:", e)
