<a href="https://colab.research.google.com/github/shanikairoshi/QFL_Experiments/blob/main/Autopick_Clinical_QFL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install qiskit qiskit-machine-learning scikit-learn

Collecting qiskit
  Downloading qiskit-2.1.2-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting qiskit-machine-learning
  Downloading qiskit_machine_learning-0.8.3-py3-none-any.whl.metadata (13 kB)
Collecting rustworkx>=0.15.0 (from qiskit)
  Downloading rustworkx-0.17.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Collecting stevedore>=3.0.0 (from qiskit)
  Downloading stevedore-5.5.0-py3-none-any.whl.metadata (2.2 kB)
Collecting qiskit
  Downloading qiskit-1.4.4-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting symengine<0.14,>=0.11 (from qiskit)
  Downloading symengine-0.13.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.2 kB)
Downloading qiskit_machine_learning-0.8.3-py3-none-any.whl (231 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m231.9/231.9 kB[0m [31m9.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading qiskit-1.4.4-cp39-abi3-manylinux_

In [None]:
!pip install qiskit-algorithms

Collecting qiskit-algorithms
  Downloading qiskit_algorithms-0.4.0-py3-none-any.whl.metadata (4.7 kB)
Downloading qiskit_algorithms-0.4.0-py3-none-any.whl (327 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m327.8/327.8 kB[0m [31m5.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: qiskit-algorithms
Successfully installed qiskit-algorithms-0.4.0


In [None]:
# qfl_qiskit_builtin.py
# Quantum Federated Learning with in-built Qiskit ML (EstimatorQNN + NeuralNetworkClassifier).
# Saves per-round train/test accuracy, loss, and time to CSV in Drive (or HOME fallback).

from __future__ import annotations
import os, json, time
from pathlib import Path
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, Optional

import numpy as np
import pandas as pd

# ------------------------
# Qiskit ML (built-ins)
# ------------------------
from qiskit.circuit.library import ZZFeatureMap, RealAmplitudes
from qiskit.quantum_info import SparsePauliOp, Statevector
from qiskit.primitives import StatevectorEstimator

from qiskit_machine_learning.neural_networks import EstimatorQNN
try:
    # Newer path
    from qiskit_machine_learning.algorithms.classifiers import NeuralNetworkClassifier
except Exception:
    # Older path
    from qiskit_machine_learning.algorithms import NeuralNetworkClassifier

# optional gradient (silence "creating a gradient function" warning if available)
try:
    from qiskit_algorithms.gradients import ParamShiftEstimatorGradient
except Exception:
    try:
        from qiskit.algorithms.gradients import ParamShiftEstimatorGradient
    except Exception:
        ParamShiftEstimatorGradient = None

# ------------------------
# Optimizer (robust import)
# ------------------------
try:
    from qiskit_algorithms.optimizers import COBYLA
except Exception:
    try:
        from qiskit.algorithms.optimizers import COBYLA
    except Exception:
        COBYLA = None
try:
    from qiskit_algorithms.optimizers import SPSA
except Exception:
    try:
        from qiskit.algorithms.optimizers import SPSA
    except Exception:
        SPSA = None

# ------------------------
# Sklearn
# ------------------------
from sklearn import datasets
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, log_loss


# ============================================================
# Drive helper
# ============================================================
def get_outdir_in_drive(subdir: str) -> Path:
    """
    Resolve and create the output directory for a run.

    Tries to import ``google.colab`` and mount Drive at /content/drive; on
    success the base directory is /content/drive/MyDrive. If the import or the
    mount raises, the base directory is the user's home directory instead.

    Returns ``<base>/<subdir>``, created (with parents) if it does not exist.
    """
    try:
        from google.colab import drive as colab_drive  # type: ignore
        colab_drive.mount("/content/drive", force_remount=False)
    except Exception:
        # Not on Colab (or the mount failed): fall back to HOME.
        root = Path.home()
    else:
        root = Path("/content/drive/MyDrive")

    target = root / subdir
    target.mkdir(parents=True, exist_ok=True)
    return target


# ============================================================
# Data loading
# ============================================================
def load_breast_cancer(pca_k: Optional[int] = 4, test_size: float = 0.30, seed: int = 42
                       ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Load sklearn's breast-cancer dataset as (X_train, X_test, y_train, y_test).

    The data is split first (stratified on the label, seeded by ``seed``);
    the scaler and the optional PCA are then fit on the training split only
    and applied to the test split.

    pca_k: number of PCA components to keep, or None to skip PCA.
    """
    bunch = datasets.load_breast_cancer()
    features = bunch.data.astype(float)
    labels = bunch.target.astype(int)

    X_train, X_test, y_train, y_test = train_test_split(
        features, labels, test_size=test_size, random_state=seed, stratify=labels
    )

    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)

    if pca_k is None:
        return X_train, X_test, y_train, y_test

    reducer = PCA(n_components=int(pca_k), random_state=seed)
    X_train = reducer.fit_transform(X_train)
    X_test = reducer.transform(X_test)
    return X_train, X_test, y_train, y_test


def load_clinical_csv(csv_path: str,
                      test_size: float = 0.45,
                      pca_k: int = 2,
                      seed: int = 42
                      ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Minimal preprocessing for BrEaST-Lesions-USG-Clinical.csv.

    Steps: keep six columns, drop rows with placeholder values, label-encode
    the categorical columns, map benign/malignant to 0/1, split into
    train/test (stratified, seeded), then standardize and PCA-reduce to
    ``pca_k`` dimensions (2 by default, i.e. 2 qubits).

    The scaler and PCA are fit on the TRAINING split only and then applied to
    the test split, so no test-set statistics leak into the features. This
    matches load_breast_cancer. The train/test membership is unchanged
    relative to splitting after the transforms, because the split depends only
    on the labels, the sample count and ``seed``.

    Returns (X_train, X_test, y_train, y_test).
    Raises FileNotFoundError (from pandas) if ``csv_path`` does not exist.
    """
    df = pd.read_csv(csv_path)

    selected_features = ["Age", "Shape", "Echogenicity",
                         "Posterior_features", "Calcifications", "Classification"]
    df = df[selected_features]

    # keep clean rows
    df = df[
        (df["Age"] != "not available") &
        (~df["Shape"].isin(["not applicable"])) &
        (~df["Echogenicity"].isin(["not applicable"])) &
        (~df["Posterior_features"].isin(["not applicable"])) &
        (~df["Calcifications"].isin(["not applicable", "indefinable"])) &
        (df["Classification"].isin(["benign", "malignant"]))
    ].copy()

    df["Age"] = pd.to_numeric(df["Age"])

    # Integer-encode each categorical column independently.
    for col in ["Shape", "Echogenicity", "Posterior_features", "Calcifications"]:
        df[col] = LabelEncoder().fit_transform(df[col])

    y = df["Classification"].map({"benign": 0, "malignant": 1}).values.astype(int)
    X = df.drop(columns=["Classification"]).values.astype(float)

    # Split BEFORE fitting any transform to avoid train/test leakage.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=seed, stratify=y
    )

    sc = StandardScaler()
    X_train = sc.fit_transform(X_train)
    X_test = sc.transform(X_test)

    p = PCA(n_components=int(pca_k), random_state=seed)
    X_train = p.fit_transform(X_train)
    X_test = p.transform(X_test)

    return X_train, X_test, y_train, y_test


# ============================================================
# Partition helpers (IID + simple label skew)
# ============================================================
def shard_iid(X: np.ndarray, y: np.ndarray, num_clients: int, seed: int = 7):
    """
    Shuffle the sample indices with a generator seeded by ``seed`` and cut
    them into ``num_clients`` consecutive, nearly equal pieces
    (np.array_split semantics). Returns a list of (X_shard, y_shard) tuples.
    """
    order = np.random.default_rng(seed).permutation(len(y))
    shards = []
    for part in np.array_split(order, num_clients):
        shards.append((X[part], y[part]))
    return shards


def shard_label_skew(X: np.ndarray, y: np.ndarray,
                     num_clients: int, min_per_client: int = 60, seed: int = 7):
    """
    Build ``num_clients`` label-skewed shards for a binary problem.

    Even-numbered clients get label 0 as the majority class, odd-numbered
    clients get label 1. The majority share is int(0.8 * min_per_client); the
    minority gets the remainder (at least 1). Samples are drawn with
    replacement from each class, then each shard is shuffled.

    Returns a list of (X_shard, y_shard) tuples.
    """
    rng = np.random.default_rng(seed)

    zero_mask, one_mask = (y == 0), (y == 1)
    X0, y0 = X[zero_mask], y[zero_mask]
    X1, y1 = X[one_mask], y[one_mask]

    # These counts do not depend on the client, so compute them once.
    k_major = int(min_per_client * 0.8)
    k_minor = max(1, min_per_client - k_major)

    shards = []
    for cid in range(num_clients):
        # (#label-0 draws, #label-1 draws) for this client
        n_zero, n_one = (k_major, k_minor) if cid % 2 == 0 else (k_minor, k_major)

        # Draw order (class 0 first, then class 1) is kept so the seeded
        # random stream is consumed exactly as before.
        pick0 = rng.integers(0, len(X0), size=n_zero)
        pick1 = rng.integers(0, len(X1), size=n_one)

        Xi = np.vstack([X0[pick0], X1[pick1]])
        yi = np.hstack([y0[pick0], y1[pick1]])

        order = rng.permutation(len(yi))
        shards.append((Xi[order], yi[order]))
    return shards


# ============================================================
# QNN builder (in-built classes only)
# ============================================================
def build_estimator_qnn(num_features: int,
                        fm_reps: int = 1,
                        an_reps: int = 2,
                        x_scale: float = np.pi):
    """
    Build an EstimatorQNN from a ZZFeatureMap followed by a RealAmplitudes
    ansatz (linear entanglement), one qubit per feature.

    Observable: the Pauli string "Z" + "I" * (num_features - 1). Qiskit Pauli
    labels are little-endian (the rightmost character acts on qubit 0), so Z
    acts on the highest-index qubit (qubit num_features - 1), not on qubit 0.

    Input parameters are taken in the circuit's own parameter order, which
    sorts ParameterVector elements by index. Sorting by the name string
    instead would place "x[10]" before "x[2]" and bind features to the wrong
    rotation angles once num_features exceeds 10.

    x_scale is not applied here; it is returned unchanged so callers can
    multiply classical inputs by it before feeding the feature map.

    Returns (qnn, x_scale).
    """
    fm = ZZFeatureMap(feature_dimension=num_features, reps=fm_reps)
    an = RealAmplitudes(num_qubits=num_features, reps=an_reps, entanglement="linear")
    obs = SparsePauliOp.from_list([("Z" + "I" * (num_features - 1), 1.0)])
    est = StatevectorEstimator()

    # Optional explicit gradient; fall back to None if the class is missing
    # or cannot be constructed with this estimator.
    grad = None
    if ParamShiftEstimatorGradient is not None:
        try:
            grad = ParamShiftEstimatorGradient(est)
        except Exception:
            grad = None

    qnn = EstimatorQNN(
        circuit=fm.compose(an),
        observables=obs,
        input_params=list(fm.parameters),   # index-ordered: x[0], x[1], ..., x[n-1]
        weight_params=list(an.parameters),
        estimator=est,
        gradient=grad,  # None is OK; Qiskit will auto-create a gradient if needed
    )
    return qnn, x_scale


def make_classifier(num_features: int,
                    initial_point: Optional[np.ndarray],
                    maxiter: int = 60,
                    optimizer: str = "SPSA",
                    opt_kwargs: Optional[dict] = None):
    """
    Create a NeuralNetworkClassifier around a freshly built EstimatorQNN.

    optimizer: "SPSA" or "COBYLA" (case-insensitive).
    opt_kwargs: extra keyword arguments for the optimizer constructor. For
        SPSA, a TypeError from the constructor triggers a retry with
        ``maxiter`` only (the extras are dropped).

    Raises RuntimeError if the requested optimizer class could not be
    imported, ValueError for any other optimizer name.
    """
    qnn, _ = build_estimator_qnn(num_features)
    extra = dict(opt_kwargs) if opt_kwargs is not None else {}
    kind = optimizer.upper()

    if kind == "SPSA":
        if SPSA is None:
            raise RuntimeError("SPSA not available; install/update qiskit-algorithms.")
        try:
            opt = SPSA(maxiter=maxiter, **extra)
        except TypeError:
            # This SPSA version rejects one of the extras: use maxiter alone.
            opt = SPSA(maxiter=maxiter)
    elif kind == "COBYLA":
        if COBYLA is None:
            raise RuntimeError("COBYLA not available; install/update qiskit-algorithms.")
        # With an empty ``extra`` this is the same call as COBYLA(maxiter=maxiter).
        opt = COBYLA(maxiter=maxiter, **extra)
    else:
        raise ValueError(f"Unknown optimizer '{optimizer}'. Use 'SPSA' or 'COBYLA'.")

    return NeuralNetworkClassifier(
        neural_network=qnn,
        optimizer=opt,
        initial_point=initial_point,
        loss="cross_entropy",
        one_hot=False,  # labels are passed as a 1-D array
    )


def extract_weights(clf: NeuralNetworkClassifier) -> np.ndarray:
    """
    Return a float copy of the trained weights held by ``clf``.

    Lookup order: ``clf.fit_result_`` then ``clf.fit_result`` (using the
    result's ``.x`` attribute, or its ``"x"`` key when it is a dict), and
    finally ``clf.weights_``. Raises RuntimeError when none of these yields
    a weight vector.
    """
    missing = object()

    def _as_vector(values) -> np.ndarray:
        return np.asarray(values, float).copy()

    # Optimizer-result containers
    for name in ("fit_result_", "fit_result"):
        result = getattr(clf, name, missing)
        if result is missing:
            continue
        if hasattr(result, "x"):
            return _as_vector(result.x)
        if isinstance(result, dict) and "x" in result:
            return _as_vector(result["x"])

    # Direct weights attribute
    direct = getattr(clf, "weights_", missing)
    if direct is not missing:
        return _as_vector(direct)

    raise RuntimeError("Could not extract weights from classifier; check qiskit-machine-learning version.")


# ============================================================
# Aggregation helpers
# ============================================================
def _wrap_to_pi(theta: np.ndarray) -> np.ndarray:
    return (theta + np.pi) % (2 * np.pi) - np.pi

def agg_fedavg(local_ws: List[np.ndarray], shard_sizes: List[int]) -> np.ndarray:
    """
    Unweighted element-wise mean of the client weight vectors.
    ``shard_sizes`` is accepted so all aggregators share a signature; it is not used.
    """
    stacked = np.stack(local_ws, axis=0)  # (K, D)
    return stacked.mean(axis=0)

def agg_fedavg_weighted(local_ws: List[np.ndarray], shard_sizes: List[int]) -> np.ndarray:
    """
    Element-wise mean of the client weight vectors, each client weighted by
    its shard size (sizes are normalized to sum to 1).
    """
    coeffs = np.asarray(shard_sizes, float)
    coeffs = coeffs / coeffs.sum()
    stacked = np.stack(local_ws, axis=0)  # (K, D)
    return np.sum(coeffs[:, np.newaxis] * stacked, axis=0)

def agg_circular_weighted(local_ws: List[np.ndarray], shard_sizes: List[int]) -> np.ndarray:
    """
    Shard-size-weighted circular mean, computed per coordinate.

    Each client vector is wrapped with _wrap_to_pi, the weighted sums of the
    sines and cosines are formed, and the mean direction is
    arctan2(sum_sin, sum_cos).
    """
    weights = np.asarray(shard_sizes, float)
    weights = weights / weights.sum()
    column_w = weights[:, None]                                        # (K, 1)

    angles = np.stack([_wrap_to_pi(w) for w in local_ws], axis=0)     # (K, D)
    sin_sum = np.sum(column_w * np.sin(angles), axis=0)
    cos_sum = np.sum(column_w * np.cos(angles), axis=0)
    return np.arctan2(sin_sum, cos_sum)

# Registry mapping an FLConfig.agg_mode name to its aggregation function.
# Every function takes (local_ws, shard_sizes) and returns the new global vector.
# "auto_pick" has no entry here; run_federated_qnn_builtins handles it separately.
AGG_MAP = {
    "fedavg": agg_fedavg,
    "fedavg_weighted": agg_fedavg_weighted,
    "circular_weighted": agg_circular_weighted,
}


# ============================================================
# Evaluation with QNN directly (no classifier needed)
# ============================================================
def qnn_predict_proba(qnn: EstimatorQNN, X: np.ndarray, w: np.ndarray, x_scale: float = np.pi) -> np.ndarray:
    """
    Turn the QNN's expectation values into two-class probabilities.

    Inputs are cast to float and multiplied by ``x_scale`` before
    ``qnn.forward(inputs, w)`` is called. For an expectation value e the
    class-1 probability is (1 - e) / 2 and the class-0 probability is its
    complement, so e = +1 maps to class 0 and e = -1 to class 1.

    Returns an array of shape (N, 2) with columns [P(class 0), P(class 1)].
    """
    scaled_inputs = x_scale * np.asarray(X, float)
    expectation = np.asarray(qnn.forward(scaled_inputs, w), float).reshape(-1)
    prob_one = (1.0 - expectation) / 2.0
    prob_zero = 1.0 - prob_one
    return np.column_stack((prob_zero, prob_one))


# ============================================================
# Federated loop (built-ins only)
# ============================================================
@dataclass
class FLConfig:
    """
    Settings for one federated run (consumed by run_federated_qnn_builtins).

    ``opt_kwargs`` defaults to None, meaning "no extra optimizer arguments";
    it is therefore typed Optional[dict] rather than dict.
    """
    num_clients: int = 10            # number of client shards
    rounds: int = 100                # number of federated rounds
    seed: int = 2025                 # seeds sharding and the initial global weights
    partition: str = "label_skew"   # "iid" or "label_skew"
    agg_mode: str = "circular_weighted"  # "fedavg" | "fedavg_weighted" | "circular_weighted" | "auto_pick"
    maxiter_local: int = 60          # optimizer maxiter for each client
    x_scale: float = np.pi           # multiply classical inputs before FM
    optimizer: str = "COBYLA"        # "SPSA" | "COBYLA"
    opt_kwargs: Optional[dict] = None  # optional optimizer hyperparameters (None = none)

# ---------- Angle diagnostics helpers ----------

def _unwrap_to_ref(A: np.ndarray) -> np.ndarray:
    """
    Unwrap angles client-wise to be closest to the first client's angles (reference).
    A: shape (K, D) with values in (-pi, pi].
    Returns an unwrapped copy in R: each row i adjusted by +/- 2π so A[i]-A[0] ∈ (-π, π].
    """
    K, D = A.shape
    out = A.copy()
    ref = out[0]
    for i in range(1, K):
        delta = out[i] - ref
        out[i] = ref + ((delta + np.pi) % (2*np.pi) - np.pi)
    return out

def _min_covering_arc_length(angles_1d: np.ndarray) -> float:
    """
    Minimal arc length (in [0, 2π]) that covers all given angles on the circle.
    = 2π - max_gap, where max_gap is the largest gap between consecutive sorted angles (including wrap gap).
    """
    a = np.sort(angles_1d)
    gaps = np.diff(a, append=a[0] + 2*np.pi)
    max_gap = np.max(gaps)
    return float(2*np.pi - max_gap)

def _geodesic_sse(angles: np.ndarray, center: np.ndarray, W: np.ndarray) -> float:
    """
    Weighted sum of squared circular distances between each client's angles
    and a center vector.

    angles: (K, D); center: (D,); W: (K,) client weights.
    The per-coordinate difference is wrapped with _wrap_to_pi, squared and
    summed over coordinates for each client; the per-client totals are then
    combined with the weights W.
    """
    residual = _wrap_to_pi(angles - center)            # (K, D)
    per_client = np.sum(residual ** 2, axis=1)         # (K,)
    weighted_total = np.sum(W * per_client)
    return float(weighted_total)

def angle_diagnostics(local_ws: List[np.ndarray], shard_sizes: List[int]) -> dict:
    """
    Per-round circular statistics of the client weight vectors.

    local_ws: K weight vectors of length D, treated as angles.
    shard_sizes: K client sizes, normalized here into weights that sum to 1.

    Returned dict:
      R_mean, R_min  - mean / minimum over coordinates of the weighted
                       resultant length sqrt(C^2 + S^2)
      straddle_frac  - fraction of coordinates whose minimal covering arc is
                       longer than pi
      sse_geo_gap    - geodesic SSE around the linear mean minus geodesic SSE
                       around the circular mean
      mu_circ        - circular mean per coordinate, arctan2(S, C)
      mu_lin         - weighted mean of the angles unwrapped relative to the
                       first client, wrapped back with _wrap_to_pi
    """
    angles = np.stack(local_ws, axis=0)                 # (K, D)
    _n_clients, _n_dims = angles.shape                  # enforces a 2-D stack

    weights = np.asarray(shard_sizes, float)
    weights = weights / weights.sum()
    column_w = weights[:, None]                         # (K, 1) for broadcasting

    # Weighted resultant vector per coordinate
    cos_sum = np.sum(column_w * np.cos(angles), axis=0)
    sin_sum = np.sum(column_w * np.sin(angles), axis=0)
    resultant = np.sqrt(cos_sum**2 + sin_sum**2)

    # Two candidate centers: circular mean and unwrap-then-average mean
    mu_circ = np.arctan2(sin_sum, cos_sum)
    unwrapped = _unwrap_to_ref(angles)
    mu_lin = _wrap_to_pi(np.sum(column_w * unwrapped, axis=0))

    # Geodesic error of each center; a positive gap favours the circular mean
    sse_circ = _geodesic_sse(angles, mu_circ, weights)
    sse_lin = _geodesic_sse(angles, mu_lin, weights)

    # Per-coordinate spread: does the coordinate fit inside a half circle?
    arc_lengths = np.array([_min_covering_arc_length(column) for column in angles.T])
    fits_half_circle = arc_lengths <= np.pi

    return {
        "R_mean": float(np.mean(resultant)),
        "R_min": float(np.min(resultant)),
        "straddle_frac": float(1.0 - np.mean(fits_half_circle)),
        "sse_geo_gap": float(sse_lin - sse_circ),
        "mu_circ": mu_circ,
        "mu_lin":  mu_lin,
    }

def run_federated_qnn_builtins(X_train, y_train, X_test, y_test, cfg: FLConfig
                               ) -> Dict[str, Any]:
    """
    Run the federated training loop and collect per-round / per-client metrics.

    Each round: every client trains a fresh classifier starting from the
    current global weights, the local weights are aggregated into a new
    global vector, and the global vector is evaluated on the full train and
    test sets.

    cfg.agg_mode selects the aggregator from AGG_MAP, or "auto_pick", which
    computes both the weighted-linear and the weighted-circular aggregate and
    keeps the one with the lower loss on X_test.
    cfg.partition == "iid" uses shard_iid; any other value uses
    shard_label_skew.

    Returns a dict with:
      "w_global"    - final global weight vector
      "rows"        - one dict per round (metrics, diagnostics, "picked")
      "client_rows" - one dict per (round, client) with local metrics
      "config"      - cfg.__dict__

    Raises ValueError for an unknown agg_mode and RuntimeError if a client
    returns a weight vector of unexpected length.
    """
    rng = np.random.default_rng(cfg.seed)

    # accept auto_pick as special mode
    if cfg.agg_mode != "auto_pick" and cfg.agg_mode not in AGG_MAP:
        raise ValueError(
            f"Unknown agg_mode={cfg.agg_mode}. Choose from {list(AGG_MAP.keys())} + ['auto_pick']"
        )
    aggregator = AGG_MAP.get(cfg.agg_mode, None)  # None when 'auto_pick'

    # client shards
    if cfg.partition == "iid":
        shards = shard_iid(X_train, y_train, cfg.num_clients, seed=cfg.seed)
    else:
        shards = shard_label_skew(
            X_train, y_train, cfg.num_clients,
            min_per_client=max(40, len(X_train)//cfg.num_clients),
            seed=cfg.seed
        )

    num_features = X_train.shape[1]

    # reference QNN for evaluation
    qnn_ref, _ = build_estimator_qnn(num_features, fm_reps=1, an_reps=2, x_scale=cfg.x_scale)
    D = len(qnn_ref.weight_params)

    # initialize global weights
    w_global = 0.10 * rng.standard_normal(D)

    # histories
    rounds_csv_rows: List[Dict[str, Any]] = []
    client_csv_rows: List[Dict[str, Any]] = []

    print(f"=== QFL using '{cfg.agg_mode}' | partition='{cfg.partition}' | clients={cfg.num_clients} rounds={cfg.rounds} ===")
    for r in range(1, cfg.rounds + 1):
        t0 = time.perf_counter()

        local_ws, shard_sizes = [], []

        # iterate clients
        for cid, (Xi, yi) in enumerate(shards):
            clf = make_classifier(
                num_features,
                initial_point=w_global,
                maxiter=cfg.maxiter_local,
                optimizer=cfg.optimizer,
                opt_kwargs=cfg.opt_kwargs,
            )

            # IMPORTANT: scale inputs for training consistently with evaluation
            t_c0 = time.perf_counter()
            clf.fit(Xi * cfg.x_scale, yi)   # one_hot=False in make_classifier, so 1-D labels are correct
            dt_c = time.perf_counter() - t_c0

            w_local = extract_weights(clf)
            if len(w_local) != D:
                raise RuntimeError(f"[Round {r}] Client {cid}: weight dim mismatch ({len(w_local)} != {D})")
            local_ws.append(w_local)
            shard_sizes.append(len(yi))

            # --- client eval (optional, for analysis) ---
            probs_tr_loc = qnn_predict_proba(qnn_ref, Xi,      w_local, x_scale=cfg.x_scale)
            probs_te_loc = qnn_predict_proba(qnn_ref, X_test,  w_local, x_scale=cfg.x_scale)
            yhat_tr_loc = np.argmax(probs_tr_loc, axis=1)
            yhat_te_loc = np.argmax(probs_te_loc, axis=1)

            client_csv_rows.append({
                "round": r,
                "client_id": cid,
                "shard_size": int(len(yi)),
                "train_acc_local": float(accuracy_score(yi, yhat_tr_loc)),
                "test_acc_local":  float(accuracy_score(y_test, yhat_te_loc)),
                "train_loss_local": float(log_loss(yi,     probs_tr_loc, labels=[0,1])),
                "test_loss_local":  float(log_loss(y_test, probs_te_loc, labels=[0,1])),
                "time_sec_local":   float(dt_c),
            })

        # ==== DIAGNOSTICS (compute before aggregation) ====
        # The optimizers return unconstrained reals, while the diagnostics
        # are documented for angles kept in one 2π window. Wrap copies for the
        # diagnostics only; aggregation below still uses the raw vectors.
        diag = angle_diagnostics([_wrap_to_pi(w) for w in local_ws], shard_sizes)
        # Pretty console print
        print(f"    [diag] R̄={diag['R_mean']:.3f} | Rmin={diag['R_min']:.3f} | "
              f"straddle_frac={diag['straddle_frac']:.2f} | ΔSSE_geo={diag['sse_geo_gap']:.4f}")

        # -------- aggregate to new global --------
        test_loss_lin = None
        test_loss_circ = None
        picked = None

        if cfg.agg_mode == "auto_pick":
            # try both
            w_lin  = agg_fedavg_weighted(local_ws, shard_sizes)
            w_circ = agg_circular_weighted(local_ws, shard_sizes)

            # NOTE(review): the choice is made on X_test, so the reported
            # test metrics of an auto_pick run are selected on the test set
            # and are optimistically biased. A held-out validation split
            # would remove that bias.
            probs_lin  = qnn_predict_proba(qnn_ref, X_test, w_lin,  x_scale=cfg.x_scale)
            probs_circ = qnn_predict_proba(qnn_ref, X_test, w_circ, x_scale=cfg.x_scale)
            test_loss_lin  = log_loss(y_test, probs_lin,  labels=[0, 1])
            test_loss_circ = log_loss(y_test, probs_circ, labels=[0, 1])

            # pick lower-loss aggregate (ties go to circular)
            use_circ = test_loss_circ <= test_loss_lin
            w_global = w_circ if use_circ else w_lin
            picked = "circular" if use_circ else "linear"
            # print the decision (once per round)
            print(f"    ↳ auto_pick chose: {picked} "
              f"(loss_lin={test_loss_lin:.4f}, loss_circ={test_loss_circ:.4f})")
        else:
            w_global = aggregator(local_ws, shard_sizes)
            # print which fixed aggregator we used
            print(f"    ↳ agg used: {cfg.agg_mode}")

        # -------- global evaluation --------
        probs_tr = qnn_predict_proba(qnn_ref, X_train, w_global, x_scale=cfg.x_scale)
        probs_te = qnn_predict_proba(qnn_ref, X_test,  w_global, x_scale=cfg.x_scale)
        ytr_pred = np.argmax(probs_tr, axis=1)
        yte_pred = np.argmax(probs_te, axis=1)

        train_acc = accuracy_score(y_train, ytr_pred)
        test_acc  = accuracy_score(y_test,  yte_pred)
        train_loss = log_loss(y_train, probs_tr, labels=[0,1])
        test_loss  = log_loss(y_test,  probs_te, labels=[0,1])

        dt = time.perf_counter() - t0
        # suffix for the round summary line
        if cfg.agg_mode == "auto_pick":
            suffix = f" | picked={picked}"
        else:
            suffix = f" | agg={cfg.agg_mode}"

        print(f"[Round {r:02d}] TrainAcc={train_acc:.3f} | TestAcc={test_acc:.3f} | "
            f"TrainLoss={train_loss:.4f} | TestLoss={test_loss:.4f} | Time={dt:.2f}s{suffix}")

        row = {
            "round": r,
            "train_acc": float(train_acc),
            "test_acc":  float(test_acc),
            "train_loss": float(train_loss),
            "test_loss":  float(test_loss),
            "time_sec": float(dt),

            # diagnostics
            "resultant_mean": float(diag["R_mean"]),
            "resultant_min":  float(diag["R_min"]),
            "straddle_frac":  float(diag["straddle_frac"]),
            "delta_SSE_geo":  float(diag["sse_geo_gap"]),
        }
        if cfg.agg_mode == "auto_pick":
            # The decision was already printed above; only record it here.
            row.update({
                "test_loss_lin":  None if test_loss_lin  is None else float(test_loss_lin),
                "test_loss_circ": None if test_loss_circ is None else float(test_loss_circ),
                "picked": picked,
            })
        else:
            row["picked"] = cfg.agg_mode
        rounds_csv_rows.append(row)

    return {
        "w_global": w_global,
        "rows": rounds_csv_rows,
        "client_rows": client_csv_rows,
        "config": cfg.__dict__,
    }


# ============================================================
# Save CSV + metadata
# ============================================================
def save_rounds_csv(rows: List[Dict[str, Any]], path: Path):
    """
    Write per-round metric rows to ``path`` as CSV (no index column).

    The six base columns come first, in a fixed order; any additional keys
    found in the rows follow in their original order. The parent directory is
    created if needed.

    An empty ``rows`` list (e.g. a run with zero rounds) produces a
    header-only CSV with the base columns instead of raising KeyError.
    For non-empty rows, a missing base column still raises KeyError.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    base = ["round", "train_acc", "test_acc", "train_loss", "test_loss", "time_sec"]

    if not rows:
        # DataFrame([]) has no columns, so selecting the base columns would fail.
        pd.DataFrame(columns=base).to_csv(path, index=False)
        return

    df = pd.DataFrame(rows)  # accept any extra columns
    extras = [c for c in df.columns if c not in base]
    df = df[base + extras]
    df.to_csv(path, index=False)

def save_meta(meta: Dict[str, Any], path: Path):
    """Serialize ``meta`` as JSON (2-space indent) to ``path``, creating the parent directory first."""
    out_dir = path.parent
    out_dir.mkdir(parents=True, exist_ok=True)
    with path.open("w") as handle:
        json.dump(meta, handle, indent=2)

def save_client_rounds_csv(rows: List[Dict[str, Any]], path: Path):
    """
    Write per-client, per-round metric rows to ``path`` as CSV (no index).

    Only the eight fixed columns below are written, in this order; the parent
    directory is created if needed.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    columns = [
        "round", "client_id", "shard_size",
        "train_acc_local", "test_acc_local",
        "train_loss_local", "test_loss_local",
        "time_sec_local",
    ]
    pd.DataFrame(rows, columns=columns).to_csv(path, index=False)


# ============================================================
# Main (run both datasets)
# ============================================================
# Driver: runs the federated experiment on the sklearn breast-cancer data and
# then on the clinical CSV, saving round metrics, client metrics, run metadata
# and the final global weights for each dataset.
if __name__ == "__main__":
    # Seeds NumPy's legacy global RNG; the helpers above use their own
    # seeded default_rng generators.
    np.random.seed(42)

    # ---------------- CONFIG ----------------
    # Choose aggregator: "fedavg", "fedavg_weighted", "circular_weighted", or "auto_pick"
    AGGREGATOR = "auto_pick"   # compare circular vs weighted-fedavg each round and pick best
    NUM_CLIENTS = 10
    ROUNDS = 50
    MAXITER_LOCAL = 60
    X_SCALE = np.pi

    # Clinical CSV path (edit if needed)
    CLINICAL_CSV = "/content/drive/MyDrive/data/BrEaST-Lesions-USG-Clinical.csv"

    # Drive output folders (both are created up front, before any training)
    OUTDIR_BC = get_outdir_in_drive(f"QFL_folder/ZZ_breast_cancer_{AGGREGATOR}")
    OUTDIR_CL = get_outdir_in_drive(f"QFL_folder/ZZ_clinical_{AGGREGATOR}")

    # ------------- Breast Cancer (sklearn) -------------
    # pca_k=5 gives 5 features; the QNN builder uses one qubit per feature.
    Xtr, Xte, ytr, yte = load_breast_cancer(pca_k=5, test_size=0.30, seed=42)
    cfg_bc = FLConfig(num_clients=NUM_CLIENTS, rounds=ROUNDS, seed=2025,
                      partition="label_skew", agg_mode=AGGREGATOR,
                      maxiter_local=MAXITER_LOCAL, x_scale=X_SCALE,
                      optimizer="SPSA", opt_kwargs={"learning_rate": 0.05, "perturbation": 0.1})
    print("\n=== Breast Cancer (sklearn) ===")
    out_bc = run_federated_qnn_builtins(Xtr, ytr, Xte, yte, cfg_bc)
    csv_bc = OUTDIR_BC / f"round_metrics_bc_clients{NUM_CLIENTS}_rounds{ROUNDS}.csv"
    save_rounds_csv(out_bc["rows"], csv_bc)

    # (optional) per-client CSV
    csv_bc_clients = OUTDIR_BC / f"round_client_metrics_bc_clients{NUM_CLIENTS}_rounds{ROUNDS}.csv"
    save_client_rounds_csv(out_bc["client_rows"], csv_bc_clients)

    save_meta({"dataset": "sklearn_breast_cancer", **out_bc["config"]}, OUTDIR_BC / "run_metadata.json")
    np.save(OUTDIR_BC / "w_global.npy", out_bc["w_global"])
    print(f"Saved breast-cancer CSV to: {csv_bc}")

    # ------------- Clinical CSV -------------
    # Only a missing CSV is tolerated (the run is skipped with a warning);
    # any other error raised inside this block propagates.
    try:
        Xtr, Xte, ytr, yte = load_clinical_csv(CLINICAL_CSV, test_size=0.45, pca_k=5, seed=42)
        cfg_cl = FLConfig(num_clients=NUM_CLIENTS, rounds=ROUNDS, seed=2025,
                          partition="label_skew", agg_mode=AGGREGATOR,
                          maxiter_local=MAXITER_LOCAL, x_scale=X_SCALE,
                          optimizer="COBYLA")  # COBYLA is often stable for small circuits
        print("\n=== Clinical CSV (PCA→5D) ===")
        out_cl = run_federated_qnn_builtins(Xtr, ytr, Xte, yte, cfg_cl)
        csv_cl = OUTDIR_CL / f"round_metrics_clinical_clients{NUM_CLIENTS}_rounds{ROUNDS}.csv"
        save_rounds_csv(out_cl["rows"], csv_cl)

        csv_cl_clients = OUTDIR_CL / f"round_client_metrics_clinical_clients{NUM_CLIENTS}_rounds{ROUNDS}.csv"
        save_client_rounds_csv(out_cl["client_rows"], csv_cl_clients)

        save_meta({"dataset": "clinical_ultrasound_csv", "csv_path": CLINICAL_CSV, **out_cl["config"]},
                  OUTDIR_CL / "run_metadata.json")
        np.save(OUTDIR_CL / "w_global.npy", out_cl["w_global"])
        print(f"Saved clinical CSV to: {csv_cl}")
    except FileNotFoundError:
        print(f"[WARN] Clinical CSV not found at: {CLINICAL_CSV}. Skipping that run.")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).

=== Breast Cancer (sklearn) ===
=== QFL using 'auto_pick' | partition='label_skew' | clients=10 rounds=50 ===
    ↳ auto_pick chose: circular (loss_lin=0.7308, loss_circ=0.7301)
[Round 01] TrainAcc=0.497 | TestAcc=0.456 | TrainLoss=0.7145 | TestLoss=0.7278 | Time=252.19s | picked=circular
    ↳ auto_pick chose: linear (loss_lin=0.7167, loss_circ=0.7254)
[Round 02] TrainAcc=0.487 | TestAcc=0.462 | TrainLoss=0.6998 | TestLoss=0.7192 | Time=247.59s | picked=linear
    ↳ auto_pick chose: linear (loss_lin=0.6994, loss_circ=0.7053)
[Round 03] TrainAcc=0.500 | TestAcc=0.579 | TrainLoss=0.7147 | TestLoss=0.6988 | Time=244.61s | picked=linear
    ↳ auto_pick chose: linear (loss_lin=0.7059, loss_circ=0.7089)
[Round 04] TrainAcc=0.528 | TestA

In [None]:
# === QFL Round Metrics Plotter (columns: round, train_acc, test_acc, train_loss, test_loss, time_sec) ===
# Instructions:
#  1) Mount Drive in Colab.
#  2) Set BC_PATH and CL_PATH below.
#  3) Run to generate and save figures.

import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path

# ------------------------------------------------------------
# 1) Configure paths
# ------------------------------------------------------------
# NOTE(review): these paths point at "<dataset>_circular_weighted" folders and
# "rounds20" files, whereas the training cell in this notebook writes to
# "QFL_folder/ZZ_breast_cancer_{AGGREGATOR}" / "QFL_folder/ZZ_clinical_{AGGREGATOR}"
# with "rounds{ROUNDS}" in the file name. Confirm they reference the run you
# intend to plot.
BC_PATH = "/content/drive/MyDrive/QFL_folder/breast_cancer_circular_weighted/round_metrics_bc_clients10_rounds20.csv"
CL_PATH = "/content/drive/MyDrive/QFL_folder/clinical_circular_weighted/round_metrics_clinical_clients10_rounds20.csv"

# ------------------------------------------------------------
# 2) Loader for your schema (robust to TSV/CSV and stray spaces)
# ------------------------------------------------------------
def load_metrics(csv_path: str) -> pd.DataFrame:
    """
    Load a round-metrics file and return it with standardized column names.

    Expected input columns: round, train_acc, test_acc, train_loss, test_loss
    and, optionally, time_sec. Header names are stripped of surrounding
    whitespace; the delimiter is sniffed, so CSV and TSV both work.

    Returns a DataFrame with columns Round, TrainAcc, TestAcc, TrainLoss,
    TestLoss (+ TimeSec when present), all coerced to numeric, with rows that
    lack any of the five required values dropped, sorted by Round.
    TimeSec is coerced too, but an unparsable TimeSec becomes NaN and does
    not cause the row to be dropped.

    Raises FileNotFoundError if the file is missing and ValueError if a
    required column is absent.
    """
    p = Path(csv_path)
    if not p.exists():
        raise FileNotFoundError(f"CSV not found: {csv_path}")

    # sep=None lets pandas sniff the delimiter; that requires the python engine.
    df = pd.read_csv(p, sep=None, engine="python")
    df.columns = [c.strip() for c in df.columns]

    # Minimal schema check
    required = {"round", "train_acc", "test_acc", "train_loss", "test_loss"}
    if not required.issubset(set(df.columns)):
        missing = sorted(required - set(df.columns))
        raise ValueError(f"Missing columns in {csv_path}: {missing}")

    rename_map = {
        "round": "Round",
        "train_acc": "TrainAcc",
        "test_acc": "TestAcc",
        "train_loss": "TrainLoss",
        "test_loss": "TestLoss",
    }
    if "time_sec" in df.columns:
        rename_map["time_sec"] = "TimeSec"

    df = df.rename(columns=rename_map)

    core = ["Round", "TrainAcc", "TestAcc", "TrainLoss", "TestLoss"]
    keep = core + (["TimeSec"] if "TimeSec" in df.columns else [])

    # Coerce every returned column to numeric, including the optional TimeSec.
    for k in keep:
        df[k] = pd.to_numeric(df[k], errors="coerce")

    # Drop rows missing a required value and sort by round (just in case)
    df = (
        df.dropna(subset=core)
          .sort_values("Round")
          .reset_index(drop=True)
    )
    return df[keep]

# ------------------------------------------------------------
# 3) Plotting utilities
# ------------------------------------------------------------
def plot_acc_loss(df: pd.DataFrame, title_prefix: str, ax_acc, ax_loss):
    """
    Draw train/test accuracy on `ax_acc` and train/test loss on `ax_loss`.

    `df` must provide Round, TrainAcc, TestAcc, TrainLoss and TestLoss columns.
    Each axis gets two curves against Round, a title of the form
    "<title_prefix> – <metric>", axis labels, a light grid and a legend.
    Nothing is returned; the axes are modified in place.
    """
    rounds = df["Round"]
    # (target axis, metric name shown in title / y-label, columns drawn in order)
    panels = (
        (ax_acc, "Accuracy", ("TrainAcc", "TestAcc")),
        (ax_loss, "Loss", ("TrainLoss", "TestLoss")),
    )
    for axis, metric_name, columns in panels:
        for column in columns:
            axis.plot(rounds, df[column], label=column)
        axis.set_title(f"{title_prefix} – {metric_name}")
        axis.set_xlabel("Federated Round")
        axis.set_ylabel(metric_name)
        axis.grid(True, alpha=0.3)
        axis.legend()

# ------------------------------------------------------------
# 4) Load, plot 2×2 panel, save
# ------------------------------------------------------------
# Per-round metrics for both experiments (BC_PATH / CL_PATH are defined outside this section).
bc_df, cl_df = load_metrics(BC_PATH), load_metrics(CL_PATH)

# One row per dataset: accuracy in the left column, loss in the right column.
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(14, 10))
plot_acc_loss(
    df=bc_df,
    title_prefix="Breast Cancer (sklearn)",
    ax_acc=axes[0, 0],
    ax_loss=axes[0, 1],
)
plot_acc_loss(
    df=cl_df,
    title_prefix="Clinical (PCA→5D)",
    ax_acc=axes[1, 0],
    ax_loss=axes[1, 1],
)
plt.tight_layout()

# Save first, then display, then report where the file went.
OUT_PANEL = "/content/drive/MyDrive/QFL_folder/qfl_round_metrics_panel.png"
plt.savefig(OUT_PANEL, dpi=200, bbox_inches="tight")
plt.show()
print(f"Saved 2×2 panel to: {OUT_PANEL}")

# ------------------------------------------------------------
# 5) Optional overlay comparisons (toggle as needed)
# ------------------------------------------------------------
def moving_average(s, k=3):
    """Return the centered rolling mean of `s` over a window of `k` points.

    min_periods=1 means the ends of the series are averaged over however many
    points fall inside the window, so the result has no leading/trailing NaN.
    """
    centered_window = s.rolling(window=k, center=True, min_periods=1)
    return centered_window.mean()

SMOOTH = False  # set True to smooth every overlay curve
K = 3           # moving-average window used when SMOOTH is on


def maybe(x):
    """Return `x` smoothed with a K-point moving average if SMOOTH is set, else `x` untouched."""
    if not SMOOTH:
        return x
    return moving_average(x, K)

def _draw_overlay(ax, metric, metric_name):
    """Overlay the BC and Clinical train/test curves for one metric on `ax`.

    `metric` is the column suffix ("Acc" or "Loss"); `metric_name` is the text
    used for the y-axis label and the panel title. Every series goes through
    `maybe`, so smoothing follows the SMOOTH / K toggles.
    """
    # Draw order (BC train, BC test, Clinical train, Clinical test) fixes the colour cycle.
    for tag, frame in (("BC", bc_df), ("Clinical", cl_df)):
        for split in ("Train", "Test"):
            column = split + metric
            ax.plot(frame["Round"], maybe(frame[column]), label=f"{tag} {column}")
    ax.set_title(f"{metric_name} Overlay")
    ax.set_xlabel("Federated Round")
    ax.set_ylabel(metric_name)
    ax.grid(True, alpha=0.3)
    ax.legend()


fig2, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
_draw_overlay(ax1, "Acc", "Accuracy")   # left: accuracy overlay
_draw_overlay(ax2, "Loss", "Loss")      # right: loss overlay

plt.tight_layout()
OUT_OVERLAY = "/content/drive/MyDrive/QFL_folder/qfl_round_metrics_overlay.png"
plt.savefig(OUT_OVERLAY, dpi=200, bbox_inches="tight")
plt.show()
print(f"Saved overlay figure to: {OUT_OVERLAY}")


In [None]:
# === Compare COBYLA (solid) vs SPSA (dotted) on BC & Clinical ===

import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path

# -----------------------------
# 1) Paths (edit if necessary)
# -----------------------------
# Runs labelled COBYLA: round-level metrics under the plain `*_circular_weighted` directories.
BC_COBYLA_PATH = "/content/drive/MyDrive/QFL_folder/breast_cancer_circular_weighted/round_metrics_bc_clients10_rounds20.csv"
CL_COBYLA_PATH = "/content/drive/MyDrive/QFL_folder/clinical_circular_weighted/round_metrics_clinical_clients10_rounds20.csv"

# Runs labelled SPSA: same file names, under the `ZZ_SPSA_*_circular_weighted` directories.
BC_SPSA_PATH   = "/content/drive/MyDrive/QFL_folder/ZZ_SPSA_breast_cancer_circular_weighted/round_metrics_bc_clients10_rounds20.csv"
CL_SPSA_PATH   = "/content/drive/MyDrive/QFL_folder/ZZ_SPSA_clinical_circular_weighted/round_metrics_clinical_clients10_rounds20.csv"

# -----------------------------
# 2) Minimal robust CSV loader
# -----------------------------
def load_metrics(path: str) -> pd.DataFrame:
    """Load a round-level metrics CSV/TSV into a standardized numeric frame.

    The delimiter is sniffed and header names are stripped of surrounding
    whitespace. The round column may be called "round", "Round",
    "Federated Round" or "und"; if none of those is present, rounds are
    numbered 1..n in file order. Metric columns train_acc, test_acc,
    train_loss and test_loss are matched case-insensitively.

    Returns a frame with exactly Round, TrainAcc, TestAcc, TrainLoss, TestLoss,
    all numeric, with incomplete rows dropped, sorted by Round and re-indexed
    from 0.

    Raises:
        KeyError: a metric column is absent. This is the exception type the
            bare column lookup used to raise; it now names the file and every
            missing column instead of only the first one.

    NOTE(review): an earlier cell defines a stricter `load_metrics` (it raises
    ValueError on missing columns and can return TimeSec). Running this cell
    rebinds the name to this more lenient version.
    """
    df = pd.read_csv(path, sep=None, engine="python")
    df.columns = [c.strip() for c in df.columns]

    # Normalize round column: the first candidate present wins.
    for cand in ("round", "Round", "Federated Round", "und"):
        if cand in df.columns:
            df = df.rename(columns={cand: "Round"})
            break
    if "Round" not in df.columns:
        df["Round"] = range(1, len(df) + 1)

    # Normalize metric columns (case-insensitive); for each metric only the
    # first matching column is renamed.
    metric_names = {
        "train_acc": "TrainAcc",
        "test_acc": "TestAcc",
        "train_loss": "TrainLoss",
        "test_loss": "TestLoss",
    }
    rename_map = {}
    for src, dst in metric_names.items():
        match = next((c for c in df.columns if c.lower() == src), None)
        if match is not None:
            rename_map[match] = dst
    df = df.rename(columns=rename_map)

    keep = ["Round", "TrainAcc", "TestAcc", "TrainLoss", "TestLoss"]
    missing = [k for k in keep if k not in df.columns]
    if missing:
        raise KeyError(f"Missing columns in {path}: {missing}")

    for k in keep:
        df[k] = pd.to_numeric(df[k], errors="coerce")
    return df[keep].dropna().sort_values("Round").reset_index(drop=True)

# -----------------------------
# 3) Load all four datasets
# -----------------------------
# Loaded in the order COBYLA(BC), COBYLA(Clinical), SPSA(BC), SPSA(Clinical).
bc_cob, cl_cob, bc_sps, cl_sps = map(
    load_metrics,
    (BC_COBYLA_PATH, CL_COBYLA_PATH, BC_SPSA_PATH, CL_SPSA_PATH),
)

# -----------------------------
# 4) Plot: solid=COBYLA, dotted=SPSA
# -----------------------------
def _plot_optimizer_panel(ax, cobyla_df, spsa_df, metric, dataset):
    """Draw one COBYLA-vs-SPSA panel on `ax`.

    Plots the train and test curves of `metric` against Round for both frames:
    COBYLA as solid lines, SPSA as dotted lines.

    Args:
        ax: axes to draw on.
        cobyla_df, spsa_df: frames with Round plus Train<metric>/Test<metric> columns.
        metric: column suffix, "Acc" or "Loss".
        dataset: dataset name used in the panel title.
    """
    metric_name = {"Acc": "Accuracy", "Loss": "Loss"}[metric]
    # Draw order (COBYLA train, COBYLA test, SPSA train, SPSA test) fixes the colour cycle.
    for optimizer, frame, style in (("COBYLA", cobyla_df, "-"), ("SPSA", spsa_df, ":")):
        for split in ("Train", "Test"):
            column = f"{split}{metric}"
            ax.plot(frame["Round"], frame[column], linestyle=style, label=f"{optimizer} {column}")
    ax.set_title(f"{dataset} — {metric_name}")
    ax.set_xlabel("Round")
    ax.set_ylabel(metric_name)
    ax.grid(True, alpha=0.3)
    ax.legend()


fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# One entry per panel: top row Breast Cancer, bottom row Clinical;
# left column accuracy, right column loss.
for ax, cobyla_df, spsa_df, metric, dataset in (
    (axes[0, 0], bc_cob, bc_sps, "Acc",  "Breast Cancer"),  # (A)
    (axes[0, 1], bc_cob, bc_sps, "Loss", "Breast Cancer"),  # (B)
    (axes[1, 0], cl_cob, cl_sps, "Acc",  "Clinical"),       # (C)
    (axes[1, 1], cl_cob, cl_sps, "Loss", "Clinical"),       # (D)
):
    _plot_optimizer_panel(ax, cobyla_df, spsa_df, metric, dataset)

plt.tight_layout()
OUT = "/content/drive/MyDrive/QFL_folder/qfl_compare_cobyla_vs_spsa_bc_clinical.png"
plt.savefig(OUT, dpi=200, bbox_inches="tight")
plt.show()
print(f"Saved figure to: {OUT}")


In [None]:
# === Compare Aggregators: circular_weighted (solid) vs fedavg_weighted (dashed) ===
# - Uses round-level metrics where available.
# - For BC fedavg_weighted, if only client-level CSV is found, it plots client-mean approximations.

import os, glob
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path

# -----------------------------
# 1) Paths you provided
# -----------------------------
# circular_weighted (round-level)
# NOTE(review): these two files are the same ones the COBYLA-vs-SPSA cell labels
# as COBYLA runs, while both fedavg files below live under `ZZ_SPSA_*` directories.
# Confirm the aggregator comparison is not also comparing optimizers.
BC_CIRC_PATH = "/content/drive/MyDrive/QFL_folder/breast_cancer_circular_weighted/round_metrics_bc_clients10_rounds20.csv"
CL_CIRC_PATH = "/content/drive/MyDrive/QFL_folder/clinical_circular_weighted/round_metrics_clinical_clients10_rounds20.csv"

# fedavg_weighted
# - Breast Cancer: only a client-level file (per-client rows) is known here
BC_FEDAVG_CLIENT_PATH = "/content/drive/MyDrive/QFL_folder/ZZ_SPSA_breast_cancer_fedavg_weighted/round_client_metrics_bc_clients10_rounds20.csv"
# Directory of that client-level file; it is searched later for a round-level CSV.
BC_FEDAVG_DIR = str(Path(BC_FEDAVG_CLIENT_PATH).parent)

# - Clinical: a round-level file is available directly
CL_FEDAVG_ROUND_PATH = "/content/drive/MyDrive/QFL_folder/ZZ_SPSA_clinical_fedavg_weighted/round_metrics_clinical_clients10_rounds20.csv"

# -----------------------------
# 2) Robust CSV loaders
# -----------------------------
def load_round_metrics(path: str) -> pd.DataFrame:
    """Load round-level CSV -> columns: Round, TrainAcc, TestAcc, TrainLoss, TestLoss.

    The delimiter is sniffed and header names are stripped of surrounding
    whitespace. The round column may be called "round", "Round",
    "Federated Round" or "und"; if none of those is present, rounds are
    numbered 1..n in file order. Metric columns train_acc, test_acc,
    train_loss and test_loss are matched case-insensitively.

    All returned columns are numeric; incomplete rows are dropped and the rest
    are sorted by Round and re-indexed from 0.

    Raises:
        KeyError: a metric column is absent. This is the exception type the
            bare column lookup used to raise; it now names the file and every
            missing column instead of only the first one.
    """
    df = pd.read_csv(path, sep=None, engine="python")
    df.columns = [c.strip() for c in df.columns]

    # Normalize round column: the first candidate present wins.
    for cand in ("round", "Round", "Federated Round", "und"):
        if cand in df.columns:
            df = df.rename(columns={cand: "Round"})
            break
    if "Round" not in df.columns:
        df["Round"] = range(1, len(df) + 1)

    # Normalize metrics (case-insensitive); for each metric only the first
    # matching column is renamed.
    metric_names = {
        "train_acc": "TrainAcc",
        "test_acc": "TestAcc",
        "train_loss": "TrainLoss",
        "test_loss": "TestLoss",
    }
    rename_map = {}
    for src, dst in metric_names.items():
        match = next((c for c in df.columns if c.lower() == src), None)
        if match is not None:
            rename_map[match] = dst
    df = df.rename(columns=rename_map)

    keep = ["Round", "TrainAcc", "TestAcc", "TrainLoss", "TestLoss"]
    missing = [k for k in keep if k not in df.columns]
    if missing:
        raise KeyError(f"Missing columns in {path}: {missing}")

    for k in keep:
        df[k] = pd.to_numeric(df[k], errors="coerce")
    return df[keep].dropna().sort_values("Round").reset_index(drop=True)

def load_client_metrics(path: str) -> pd.DataFrame:
    """
    Collapse a client-level metrics CSV into one row per round.

    Reads round, train_acc_local, test_acc_local, train_loss_local and
    test_loss_local (exact name preferred, otherwise a case-insensitive match),
    coerces them to numeric, discards rows with any missing value, and averages
    each metric over the rows of a round (an unweighted mean).

    Returns columns Round, TrainAcc, TestAcc, TrainLoss, TestLoss sorted by Round.
    Raises ValueError if any of those columns cannot be found in the file.
    """
    df = pd.read_csv(path, sep=None, engine="python")
    df.columns = [c.strip() for c in df.columns]

    # Source header -> standardized name.
    wanted = {
        "round": "Round",
        "client_id": "client_id",
        "train_acc_local": "TrainAcc",
        "test_acc_local": "TestAcc",
        "train_loss_local": "TrainLoss",
        "test_loss_local": "TestLoss",
    }
    renames = {}
    for source, target in wanted.items():
        if source in df.columns:
            renames[source] = target
            continue
        # No exact header: fall back to the first one that matches ignoring case.
        for header in df.columns:
            if header.lower() == source.lower():
                renames[header] = target
                break
    df = df.rename(columns=renames)

    metric_cols = ["TrainAcc", "TestAcc", "TrainLoss", "TestLoss"]
    needed = ["Round"] + metric_cols
    absent = sorted(set(needed) - set(df.columns))
    if absent:
        raise ValueError(f"Client CSV missing required columns: {absent}")

    for name in needed:
        df[name] = pd.to_numeric(df[name], errors="coerce")

    complete_rows = df.dropna(subset=needed)
    per_round = complete_rows.groupby("Round", as_index=False)[metric_cols].mean()
    return per_round.sort_values("Round").reset_index(drop=True)

def find_round_metrics_in_dir(dir_path: str, pattern="round_metrics_bc_clients*_rounds*.csv"):
    """Return the most recently modified file in `dir_path` matching `pattern`, or None if none match."""
    candidates = glob.glob(os.path.join(dir_path, pattern))
    # max() keeps the first of several equally-new files, as a stable descending sort would.
    return max(candidates, key=os.path.getmtime) if candidates else None

# -----------------------------
# 3) Resolve fedavg(BC) round-level or fallback to client-mean
# -----------------------------
BC_FEDAVG_ROUND_PATH = find_round_metrics_in_dir(BC_FEDAVG_DIR)
bc_fedavg_from_client_mean = False

if not BC_FEDAVG_ROUND_PATH or not Path(BC_FEDAVG_ROUND_PATH).exists():
    # No round-level file to read: average the client-level rows per round instead,
    # and remember that so the legend can flag the curves as approximations.
    bc_fedavg_df = load_client_metrics(BC_FEDAVG_CLIENT_PATH)
    bc_fedavg_from_client_mean = True
    print("[INFO] BC fedavg_weighted round-level metrics not found; using client-mean approximation.")
else:
    # A round-level file sits next to the client-level one; use it directly.
    bc_fedavg_df = load_round_metrics(BC_FEDAVG_ROUND_PATH)

# -----------------------------
# 4) Load the other three datasets
# -----------------------------
# Loaded in the order circular(BC), circular(Clinical), fedavg(Clinical).
bc_circ_df, cl_circ_df, cl_fedavg_df = map(
    load_round_metrics,
    (BC_CIRC_PATH, CL_CIRC_PATH, CL_FEDAVG_ROUND_PATH),
)

# -----------------------------
# 5) Plot: circular(solid) vs fedavg(dashed)
# -----------------------------
def _plot_aggregator_panel(ax, circular_df, fedavg_df, metric, dataset, fedavg_suffix=""):
    """Draw one circular-vs-fedavg panel on `ax`.

    Plots the train and test curves of `metric` against Round for both frames:
    circular as solid lines, fedavg as dashed lines.

    Args:
        ax: axes to draw on.
        circular_df, fedavg_df: frames with Round plus Train<metric>/Test<metric> columns.
        metric: column suffix, "Acc" or "Loss".
        dataset: dataset name used in the panel title.
        fedavg_suffix: text appended to the fedavg legend labels only.
    """
    metric_name = {"Acc": "Accuracy", "Loss": "Loss"}[metric]
    series = (
        ("circular", circular_df, "-", ""),
        ("fedavg", fedavg_df, "--", fedavg_suffix),
    )
    # Draw order (circular train, circular test, fedavg train, fedavg test) fixes the colour cycle.
    for aggregator, frame, style, suffix in series:
        for split in ("Train", "Test"):
            column = f"{split}{metric}"
            ax.plot(
                frame["Round"],
                frame[column],
                linestyle=style,
                label=f"{aggregator} {column}{suffix}",
            )
    ax.set_title(f"{dataset} — {metric_name}")
    ax.set_xlabel("Round")
    ax.set_ylabel(metric_name)
    ax.grid(True, alpha=0.3)
    ax.legend()


fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Marks the BC fedavg curves in the legend when they are client-mean approximations.
label_suffix = " (client-mean approx)" if bc_fedavg_from_client_mean else ""

# One entry per panel: top row Breast Cancer, bottom row Clinical;
# left column accuracy, right column loss.
for ax, circular_df, fedavg_df, metric, dataset, suffix in (
    (axes[0, 0], bc_circ_df, bc_fedavg_df, "Acc",  "Breast Cancer", label_suffix),  # (A)
    (axes[0, 1], bc_circ_df, bc_fedavg_df, "Loss", "Breast Cancer", label_suffix),  # (B)
    (axes[1, 0], cl_circ_df, cl_fedavg_df, "Acc",  "Clinical", ""),                 # (C)
    (axes[1, 1], cl_circ_df, cl_fedavg_df, "Loss", "Clinical", ""),                 # (D)
):
    _plot_aggregator_panel(ax, circular_df, fedavg_df, metric, dataset, suffix)

plt.tight_layout()
OUT = "/content/drive/MyDrive/QFL_folder/qfl_compare_circular_vs_fedavg_bc_clinical.png"
plt.savefig(OUT, dpi=200, bbox_inches="tight")
plt.show()
print(f"Saved comparison figure to: {OUT}")
