In [None]:
# Colab setup: install/upgrade Flower with its simulation extra, PyTorch (pinned to 2.8.0),
# Opacus and matplotlib. Only torch is version-pinned; the other packages float to latest.
!pip install -U "flwr[simulation]" torch==2.8.0 opacus matplotlib



In [None]:
# =========================
# Federated Healthcare (Colab)
# Flower (simulation) + PyTorch
# =========================

import json, math, random, warnings
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd


from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support, roc_auc_score,
    classification_report, confusion_matrix
)
from sklearn.impute import SimpleImputer

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

import flwr as fl
from flwr.common import (
    ndarrays_to_parameters,
    parameters_to_ndarrays,
    NDArrays,
    Scalar,
)

warnings.filterwarnings("ignore")

# -----------------------------
# Config & Utilities
# -----------------------------

@dataclass
class FLConfig:
    """Hyperparameters for one federated-learning simulation run."""
    num_clients: int = 8          # number of simulated clients / data partitions
    num_rounds: int = 5           # number of federated rounds passed to the server config
    local_epochs: int = 1         # passes over the local data per client per round
    batch_size: int = 32          # mini-batch size of each client's DataLoader
    lr: float = 1e-3              # learning rate of each client's Adam optimizer
    seed: int = 42                # seeds the global RNGs and the Dirichlet partition
    dirichlet_alpha: float = 0.5  # client heterogeneity
    # NOTE(review): the three dp_* fields below are not read by any code visible
    # in this notebook; setting them currently has no effect.
    dp_on: bool = False           # optional DP
    dp_noise_multiplier: float = 1.0
    dp_max_grad_norm: float = 1.0


def set_seed(seed=42):
    """Seed Python's `random`, NumPy's global RNG and PyTorch with the same value."""
    # Same order as always: stdlib first, then NumPy, then torch.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)


def is_binary_labels(y: np.ndarray) -> bool:
    """Return True when `y` holds exactly two distinct values."""
    distinct_values = np.unique(y)
    return distinct_values.size == 2


# -----------------------------
# Load & preprocess
# -----------------------------

def engineer_features(df: pd.DataFrame) -> pd.DataFrame:
    """Derive and clean features on a copy of `df`; every step is skipped if its columns are absent.

    Steps:
      * ``length_of_stay_days``: whole days between "Date of Admission" and
        "Discharge Date" (unparseable dates become NaN).
      * "Billing Amount" is coerced to numeric (unparseable values become NaN).
      * Selected object-dtype categoricals are stripped and title-cased.

    Missing categorical values are kept missing. Converting them with
    ``astype(str)`` alone would turn them into a literal "None"/"Nan" category,
    which a downstream imputer could no longer recognise as missing.

    Args:
        df: Raw input frame. It is not modified.

    Returns:
        A new DataFrame with the engineered/cleaned columns.
    """
    df = df.copy()

    # Length of stay (days) if dates exist
    if "Date of Admission" in df.columns and "Discharge Date" in df.columns:
        adm = pd.to_datetime(df["Date of Admission"], errors="coerce")
        dis = pd.to_datetime(df["Discharge Date"], errors="coerce")
        df["length_of_stay_days"] = (dis - adm).dt.days

    # Clean billing
    if "Billing Amount" in df.columns:
        df["Billing Amount"] = pd.to_numeric(df["Billing Amount"], errors="coerce")

    # Normalize casing for certain categoricals
    for col in ["Gender", "Blood Type", "Medical Condition", "Admission Type",
                "Medication", "Insurance Provider"]:
        if col in df.columns and df[col].dtype == object:
            raw = df[col]
            cleaned = raw.astype(str).str.strip().str.title()
            # Keep the cleaned text only where a value was actually present.
            df[col] = cleaned.where(raw.notna(), raw)

    return df


def load_healthcare(csv_path: str, target_col: str = "stroke",
                    test_size: float = 0.2, random_state: int = 42):
    """Load a tabular CSV, preprocess it and return train/test arrays for FL.

    Pipeline:
      1. Read the CSV and drop obvious identifier columns
         ("id", "Name", "Doctor", "Hospital") when present.
      2. Apply `engineer_features`.
      3. Coerce the target to numeric and drop rows whose target is not numeric.
      4. Split into train/test (stratified when the target has exactly two values).
      5. Fit the preprocessing on the training split only, then apply it to the
         test split: median-impute + standard-scale numeric columns,
         most-frequent-impute + one-hot encode all other columns.

    Args:
        csv_path: Path of the CSV file to read.
        target_col: Name of the label column; its values must be numeric
            (they are cast to int after coercion).
        test_size: Fraction of rows held out for the test split.
        random_state: Seed for the train/test split.

    Returns:
        ``(X_train, X_test, y_train, y_test, preprocessor)`` where the first four
        are NumPy arrays and `preprocessor` is the fitted ColumnTransformer.

    Raises:
        ValueError: If `target_col` is missing, or if no row has a numeric target.
    """
    df = pd.read_csv(csv_path)

    # If the target column is not present, raise an error
    if target_col not in df.columns:
        raise ValueError(f"Target column '{target_col}' not found. Available: {list(df.columns)}")

    # Drop obvious non-predictive identifiers if they exist
    drop_cols = [c for c in ["id", "Name", "Doctor", "Hospital"] if c in df.columns]
    df = df.drop(columns=drop_cols)

    # Feature engineering (a no-op when none of its columns are present)
    df = engineer_features(df)

    # Target: force to numeric; anything unparseable becomes NaN and is dropped
    y = pd.to_numeric(df[target_col], errors="coerce").astype(float)
    mask = y.notna()
    if not mask.any():
        # Fail here with a clear message rather than inside train_test_split
        # with an opaque "0 samples" error.
        raise ValueError(
            f"Target column '{target_col}' contains no numeric values; "
            "encode the labels as numbers before loading."
        )
    df = df.loc[mask].reset_index(drop=True)
    y = y[mask].astype(int).to_numpy()

    # Features: all columns except target
    X = df.drop(columns=[target_col])

    numeric_cols = X.select_dtypes(include=[np.number]).columns.tolist()
    categorical_cols = [c for c in X.columns if c not in numeric_cols]

    preprocessor = ColumnTransformer(
        transformers=[
            ("num", Pipeline([
                ("imputer", SimpleImputer(strategy="median")),
                ("scaler", StandardScaler())
            ]), numeric_cols),
            ("cat", Pipeline([
                ("imputer", SimpleImputer(strategy="most_frequent")),
                ("onehot", OneHotEncoder(handle_unknown="ignore", sparse_output=False))
            ]), categorical_cols),
        ]
    )

    # Stratify only for binary targets so both classes appear in each split.
    strat = y if is_binary_labels(y) else None
    X_train_raw, X_test_raw, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state, stratify=strat
    )

    # Fit on train only so no test-set statistics leak into the preprocessing.
    X_train = np.array(preprocessor.fit_transform(X_train_raw))
    X_test = np.array(preprocessor.transform(X_test_raw))
    y_train = np.array(y_train)
    y_test = np.array(y_test)

    return X_train, X_test, y_train, y_test, preprocessor



# -----------------------------
# Federated partition (non-IID)
# -----------------------------

def dirichlet_partition(X, y, num_clients: int, alpha: float = 0.5, seed: int = 42):
    """Split (X, y) into label-skewed client shards using per-class Dirichlet proportions.

    For every class, the class's sample indices are shuffled and divided among
    the clients according to one Dirichlet(alpha) draw. Clients that receive no
    samples are omitted, so fewer than `num_clients` shards may be returned.

    Returns:
        A list of ``(X_shard, y_shard)`` tuples, one per non-empty client.
    """
    rng = np.random.default_rng(seed)
    per_client = [[] for _ in range(num_clients)]

    for label in np.unique(y):
        members = np.where(y == label)[0]
        rng.shuffle(members)
        shares = rng.dirichlet(alpha=[alpha] * num_clients)
        quota = np.floor(shares * len(members)).astype(int)
        # Flooring can leave a few samples unassigned; hand them out at random.
        while quota.sum() < len(members):
            quota[rng.integers(0, num_clients)] += 1
        # Cut the shuffled indices into consecutive chunks, one per client.
        cut_points = np.cumsum(quota)[:-1]
        for bucket, chunk in zip(per_client, np.split(members, cut_points)):
            bucket.extend(chunk.tolist())

    X = np.array(X)
    y = np.array(y)
    shard_indices = (np.array(bucket, dtype=int) for bucket in per_client)
    # Only include splits with data
    return [(X[idx], y[idx]) for idx in shard_indices if len(idx) > 0]


# -----------------------------
# Flatten helpers for parameters
# -----------------------------

def flatten_ndarrays(nds: List[np.ndarray]) -> Tuple[np.ndarray, List[Tuple[int, ...]]]:
    """Concatenate all arrays into one float64 vector and return it with the original shapes."""
    shapes = []
    pieces = []
    for arr in nds:
        shapes.append(arr.shape)
        pieces.append(arr.ravel())
    return np.concatenate(pieces).astype(np.float64), shapes

def unflatten_ndarrays(flat: np.ndarray, shapes: List[Tuple[int, ...]]) -> List[np.ndarray]:
    """Slice `flat` into consecutive segments and reshape each to the matching entry of `shapes`."""
    sizes = [int(np.prod(shape)) for shape in shapes]
    stops = np.cumsum(sizes, dtype=int).tolist()
    starts = [0] + stops[:-1]
    return [
        flat[begin:end].reshape(shape)
        for begin, end, shape in zip(starts, stops, shapes)
    ]


# -----------------------------
# Model & Training helpers
# -----------------------------

class MLP(nn.Module):
    """Feed-forward net: in_dim -> 128 -> 64 -> 1 with ReLU between layers; outputs one logit per row."""

    def __init__(self, in_dim: int):
        super().__init__()
        widths = (in_dim, 128, 64)
        layers = []
        # Hidden layers, each followed by a ReLU.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        # Output layer producing a single logit.
        layers.append(nn.Linear(widths[-1], 1))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # Drop the trailing unit dimension: (batch, 1) -> (batch,)
        logits = self.net(x)
        return logits.squeeze(1)


def make_pos_weight(y: np.ndarray):
    """Return N_neg / N_pos as a float32 scalar tensor for BCEWithLogitsLoss.

    Falls back to 1.0 when either label 0 or label 1 is absent from `y`.
    """
    labels = np.asarray(y)
    n_neg = int(np.count_nonzero(labels == 0))
    n_pos = int(np.count_nonzero(labels == 1))
    ratio = n_neg / n_pos if (n_neg > 0 and n_pos > 0) else 1.0
    return torch.tensor(ratio, dtype=torch.float32)


def to_tensor_dataset(X: np.ndarray, y: np.ndarray) -> TensorDataset:
    """Wrap features and labels in a TensorDataset, copying both to float32 tensors."""
    features, targets = (torch.tensor(arr, dtype=torch.float32) for arr in (X, y))
    return TensorDataset(features, targets)


def bce_metrics(logits: np.ndarray, y_true: np.ndarray) -> Dict[str, float]:
    """Compute binary-classification metrics from raw logits.

    Logits are passed through a sigmoid and thresholded at 0.5 to get hard
    predictions; ROC AUC is computed on the probabilities.

    Args:
        logits: Raw model outputs, one per sample.
        y_true: Ground-truth 0/1 labels, same length as `logits`.

    Returns:
        Dict with plain-float "accuracy", "precision", "recall", "f1" and
        "roc_auc". "roc_auc" is NaN when the score is undefined.
    """
    probs = 1 / (1 + np.exp(-logits))
    y_pred = (probs >= 0.5).astype(int)
    acc = accuracy_score(y_true, y_pred)
    prec, rec, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average="binary", zero_division=0
    )
    try:
        auc = roc_auc_score(y_true, probs)
    except ValueError:
        # scikit-learn signals an undefined AUC (e.g. only one class in y_true)
        # with ValueError; any other exception is a real bug and should surface.
        auc = float("nan")
    # Cast to built-in floats so NumPy scalar types do not leak into the
    # metrics history or JSON output.
    return {
        "accuracy": float(acc),
        "precision": float(prec),
        "recall": float(rec),
        "f1": float(f1),
        "roc_auc": float(auc),
    }


# -----------------------------
# Flower Client
# -----------------------------

class TabularClient(fl.client.NumPyClient):
    """Flower NumPy client that trains and evaluates an MLP on one client's tabular shard."""

    def __init__(self, cid: int, X: np.ndarray, y: np.ndarray, input_dim: int, cfg: FLConfig):
        self.cid = cid
        self.X = X
        self.y = y
        self.cfg = cfg

        self.model = MLP(input_dim)

        # Class-weighted loss only when the shard holds both classes;
        # pos_weight=None is the plain unweighted BCE-with-logits loss.
        pos_weight = make_pos_weight(self.y) if is_binary_labels(self.y) else None
        self.criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)

    def get_parameters(self, config={}):
        """Return the model's state_dict tensors as a list of NumPy arrays, in state_dict order."""
        return [tensor.cpu().numpy() for tensor in self.model.state_dict().values()]

    def set_parameters(self, params):
        """Load `params` (same order as `get_parameters`) into the model."""
        state = self.model.state_dict()
        for name, array in zip(list(state), params):
            state[name] = torch.tensor(array)
        self.model.load_state_dict(state)

    def fit(self, params, config={}):
        """Train locally from the given global parameters.

        Returns the updated parameters, the number of local samples and an
        empty metrics dict.
        """
        self.set_parameters(params)
        n_samples = len(self.X)
        # Ensure X is not empty before creating DataLoader
        if n_samples == 0:
            print(f"Client {self.cid} has no data, skipping fit.")
            return self.get_parameters(), 0, {}

        batches = DataLoader(
            to_tensor_dataset(self.X, self.y),
            batch_size=self.cfg.batch_size,
            shuffle=True,
        )
        self.model.train()
        for _epoch in range(self.cfg.local_epochs):
            for features, targets in batches:
                self.optimizer.zero_grad()
                batch_loss = self.criterion(self.model(features), targets)
                batch_loss.backward()
                self.optimizer.step()
        return self.get_parameters(), n_samples, {}

    def evaluate(self, params, config={}):
        """Evaluate the given parameters on this client's own shard.

        Returns the unweighted BCE loss, the number of local samples and the
        metrics produced by `bce_metrics`.
        """
        self.set_parameters(params)
        n_samples = len(self.X)
        if n_samples == 0:
            print(f"Client {self.cid} has no data, skipping evaluation.")
            return 0.0, 0, {"accuracy": 0.0, "precision": 0.0, "recall": 0.0, "f1": 0.0, "roc_auc": float("nan")}

        self.model.eval()
        features = torch.tensor(self.X, dtype=torch.float32)
        with torch.no_grad():
            logits = self.model(features).cpu().numpy()
        metrics = bce_metrics(logits, self.y)

        # Reported loss is deliberately unweighted (no pos_weight).
        unweighted_bce = nn.BCEWithLogitsLoss()
        loss = unweighted_bce(
            torch.tensor(logits, dtype=torch.float32),
            torch.tensor(self.y, dtype=torch.float32),
        )
        return float(loss.item()), n_samples, metrics


# -----------------------------
# Server-side (test set) evaluation
# -----------------------------

def gen_evaluate_fn(X_test: np.ndarray, y_test: np.ndarray, input_dim: int, cfg: FLConfig):
    """Build the server-side evaluation callback for the held-out test set.

    The returned function loads the global parameters into a fresh MLP, scores
    the test set, prints the rounded metrics as JSON and returns
    ``(loss, metrics)``. `cfg` is accepted but not used by the callback.
    """
    def evaluate(server_round: int, parameters: fl.common.NDArrays, config):
        net = MLP(input_dim)
        weights = net.state_dict()
        for name, array in zip(list(weights), parameters):
            weights[name] = torch.tensor(array)
        net.load_state_dict(weights)

        net.eval()
        test_features = torch.tensor(X_test, dtype=torch.float32)
        with torch.no_grad():
            logits = net(test_features).cpu().numpy()
        metrics = bce_metrics(logits, y_test)

        # NaN is the only value not equal to itself; emit it as JSON null.
        printable = {
            name: (round(value, 4) if value == value else None)
            for name, value in metrics.items()
        }
        print(f"[Round {server_round}] test: " + json.dumps(printable))

        unweighted_bce = nn.BCEWithLogitsLoss()
        loss = unweighted_bce(
            torch.tensor(logits, dtype=torch.float32),
            torch.tensor(y_test, dtype=torch.float32),
        )
        return float(loss.item()), metrics
    return evaluate


# -----------------------------
# Custom FedAvg that logs updates for HE notebook
# -----------------------------

class LoggingFedAvg(fl.server.strategy.FedAvg):
    """FedAvg that additionally dumps per-client update vectors for one chosen round.

    On round `log_round`, after the usual aggregation, each client's flattened
    delta Δw_i = local_i - global_after_round is written to `log_path` as a
    float64 array of shape ``[num_clients, D]``. The layer shapes needed to
    unflatten a row are written next to it as ``<log_path stem>_shapes.npy``
    (an object array of tuples; load it with ``allow_pickle=True``). The files
    are intended for later use in a separate HE notebook.

    Args:
        log_round: Round whose client updates are logged.
        log_path: Destination of the updates array.
        **kwargs: Forwarded unchanged to `FedAvg`.
    """
    def __init__(self, log_round: int = 1, log_path: str = "round1_updates.npy", **kwargs):
        super().__init__(**kwargs)
        self.log_round = log_round
        self.log_path = log_path
        self._shapes_cache = None
        self._logged = False

    def aggregate_fit(
        self,
        rnd: int,
        results: List[Tuple[fl.server.client_proxy.ClientProxy, fl.server.client_proxy.FitRes]],
        failures: List[BaseException],
    ) -> Tuple[Optional[fl.common.Parameters], Dict[str, Scalar]]:
        """Aggregate like FedAvg, then log client deltas once on the chosen round.

        Returns whatever `FedAvg.aggregate_fit` returns; the aggregated
        parameters may be None when the parent declines to aggregate.
        """
        # 1) Let FedAvg do the usual aggregation
        aggregated_params, metrics = super().aggregate_fit(rnd, results, failures)

        # 2) Log once, for the chosen round. Skip when the parent produced no
        #    aggregate, because the deltas are defined relative to it.
        should_log = (
            not self._logged
            and rnd == self.log_round
            and bool(results)
            and aggregated_params is not None
        )
        if should_log:
            print(f"📥 Logging client updates from round {rnd} to '{self.log_path}'")

            # Global AFTER aggregation for this round
            global_nd = parameters_to_ndarrays(aggregated_params)
            flat_global, shapes = flatten_ndarrays(global_nd)
            self._shapes_cache = shapes

            updates = []
            for _, fitres in results:
                local_nd = parameters_to_ndarrays(fitres.parameters)  # client's local model
                flat_local, _ = flatten_ndarrays(local_nd)
                updates.append(flat_local - flat_global)  # Δw_i = local_i - global_after

            updates = np.stack(updates, axis=0)  # shape: [num_clients, D]

            # Derive the shapes path from the stem so it can never collide with
            # the updates file, even when log_path lacks the ".npy" suffix
            # (np.save appends it automatically).
            if self.log_path.endswith(".npy"):
                stem = self.log_path[:-len(".npy")]
            else:
                stem = self.log_path
            shapes_path = stem + "_shapes.npy"

            # Fill element-wise so the result is always a 1-D array of tuples;
            # np.array(shapes, dtype=object) would become 2-D whenever every
            # layer happens to have the same rank.
            shapes_arr = np.empty(len(shapes), dtype=object)
            for pos, shape in enumerate(shapes):
                shapes_arr[pos] = shape

            np.save(self.log_path, updates)
            np.save(shapes_path, shapes_arr)
            print(f"✅ Saved shape {updates.shape} to '{self.log_path}'")

            self._logged = True

        return aggregated_params, metrics



# -----------------------------
# Orchestration
# -----------------------------

def run_federated(csv_path: str, target_col: str = "Test Results", cfg: Optional[FLConfig] = None):
    """Run the full federated simulation: load data, partition it, train with FedAvg.

    Args:
        csv_path: CSV file handed to `load_healthcare`.
        target_col: Name of the label column in the CSV.
        cfg: Run configuration. A fresh `FLConfig()` is created per call when
            omitted, so calls never share one mutable default instance.

    Returns:
        ``(hist, X_train, X_test, y_test)`` where `hist` is the object returned
        by `fl.simulation.start_simulation`.

    Side effects:
        The strategy writes the round-1 client updates to "round1_updates.npy"
        (plus a companion shapes file) in the working directory.
    """
    if cfg is None:
        cfg = FLConfig()
    set_seed(cfg.seed)

    # Load & preprocess
    X_train, X_test, y_train, y_test, preproc = load_healthcare(csv_path, target_col)
    input_dim = X_train.shape[1]

    # Non-IID client splits. Empty partitions are dropped by the partitioner,
    # so this list may be shorter than cfg.num_clients.
    client_splits = dirichlet_partition(
        X_train, y_train, cfg.num_clients,
        alpha=cfg.dirichlet_alpha, seed=cfg.seed
    )

    def client_fn(cid: str):
        """Build the client for simulation id `cid` (a stringified index)."""
        i = int(cid)
        if i < len(client_splits):
            Xc, yc = client_splits[i]
            return TabularClient(i, Xc, yc, input_dim, cfg)
        # Fewer non-empty partitions than clients: hand back a client with an
        # empty shard, which skips training/evaluation and reports 0 examples.
        print(f"Client {cid} requested, but no data available. Returning a dummy client.")
        return TabularClient(i, np.array([]).reshape(0, input_dim), np.array([]), input_dim, cfg)

    strategy = LoggingFedAvg(
        log_round=1,                     # log round-1 updates for HE notebook
        log_path="round1_updates.npy",
        evaluate_fn=gen_evaluate_fn(X_test, y_test, input_dim, cfg),
        fraction_fit=1.0,
        fraction_evaluate=1.0,
        # Every round requires all cfg.num_clients simulated clients; clients
        # without data still participate through the empty-shard client above.
        min_fit_clients=cfg.num_clients,
        min_evaluate_clients=cfg.num_clients,
        min_available_clients=cfg.num_clients,
    )

    hist = fl.simulation.start_simulation(
        client_fn=client_fn,
        num_clients=cfg.num_clients,
        config=fl.server.ServerConfig(num_rounds=cfg.num_rounds),
        strategy=strategy,
    )

    print("\n✅ Training completed successfully.")
    return hist, X_train, X_test, y_test


# -----------------------------
# Run
# -----------------------------

# Run the simulation on the stroke CSV (expected in the working directory) with
# "stroke" as the label column. The config values passed here equal the FLConfig
# defaults; they are spelled out to make the run settings visible.
# Module-level names bound here (hist, X_train, X_test, y_test) are reused by later cells.
hist, X_train, X_test, y_test = run_federated(
    "healthcare-dataset-stroke-data.csv",
    target_col="stroke",
    cfg=FLConfig(num_clients=8, num_rounds=5, local_epochs=1, batch_size=32, lr=1e-3),
)

# NOTE(review): run_federated already prints its own completion message, so this
# line produces a second, near-identical one.
print("âœ… Training complete.")

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

	Instead, use the `flwr run` CLI command to start a local simulation in your Flower app, as shown for example below:

		$ flwr new  # Create a new Flower app from a template

		$ flwr run  # Run the Flower app in Simulation Mode

	Using `start_simulation()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
  return datetime.utcnow().replace(tzinfo=utc)
	Instead, use the `flwr run` CLI command to start a local simulation in your Flower app, as shown for example below:

		$ flwr new  # Create a new Flower app from a template

		$ flwr run  # Run the Flower app in Simulation Mode

	Using `start_simulation()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
[92mINFO [0m:      Starting Flower simulation, config: num_rounds=5, no round_timeout
  return datetime.utcnow().replace(tzinfo=utc)
2025-12-09 01:31:05,377	INFO

[Round 0] test: {"accuracy": 0.0489, "precision": 0.0489, "recall": 1.0, "f1": 0.0933, "roc_auc": 0.6514}


[36m(ClientAppActor pid=17263)[0m 
[36m(ClientAppActor pid=17263)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=17263)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=17263)[0m         
[36m(ClientAppActor pid=17263)[0m 
[36m(ClientAppActor pid=17263)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=17263)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=17263)[0m         
[36m(ClientAppActor pid=17263)[0m 
[36m(ClientAppActor pid=17263)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=17263)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=17263)[0m         
[36m(ClientAppActor pid=17263)[0m 
[36m(ClientAppActor pid=17263)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=17263)[0m             entirely in future v

ðŸ“¥ Logging client updates from round 1 to 'round1_updates.npy'
âœ… Saved shape (8, 11137) to 'round1_updates.npy'
[Round 1] test: {"accuracy": 0.729, "precision": 0.1178, "recall": 0.7, "f1": 0.2017, "roc_auc": 0.804}


[36m(ClientAppActor pid=17264)[0m 
[36m(ClientAppActor pid=17264)[0m         
[36m(ClientAppActor pid=17263)[0m 
[36m(ClientAppActor pid=17263)[0m         
[36m(ClientAppActor pid=17263)[0m 
[36m(ClientAppActor pid=17263)[0m         
[36m(ClientAppActor pid=17264)[0m 
[36m(ClientAppActor pid=17264)[0m         
[36m(ClientAppActor pid=17263)[0m 
[36m(ClientAppActor pid=17263)[0m         
[92mINFO [0m:      aggregate_evaluate: received 8 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 8 clients (out of 8)
[36m(ClientAppActor pid=17264)[0m 
[36m(ClientAppActor pid=17264)[0m         
[36m(ClientAppActor pid=17264)[0m             This is a deprecated feature. It will be removed[32m [repeated 8x across cluster][0m
[36m(ClientAppActor pid=17264)[0m             entirely in future versions of Flower.[32m [repeated 8x across cluster][0m
[36m(ClientAppActor pid=17263)[0m 
[36m(Cli

[Round 2] test: {"accuracy": 0.7495, "precision": 0.1241, "recall": 0.68, "f1": 0.2099, "roc_auc": 0.8135}


[36m(ClientAppActor pid=17263)[0m 
[36m(ClientAppActor pid=17263)[0m         
[36m(ClientAppActor pid=17264)[0m 
[36m(ClientAppActor pid=17264)[0m         
[36m(ClientAppActor pid=17264)[0m 
[36m(ClientAppActor pid=17264)[0m         
[36m(ClientAppActor pid=17263)[0m 
[36m(ClientAppActor pid=17263)[0m         
[36m(ClientAppActor pid=17263)[0m 
[36m(ClientAppActor pid=17263)[0m         
[36m(ClientAppActor pid=17264)[0m 
[36m(ClientAppActor pid=17264)[0m         
[36m(ClientAppActor pid=17264)[0m 
[36m(ClientAppActor pid=17264)[0m         
[36m(ClientAppActor pid=17263)[0m 
[36m(ClientAppActor pid=17263)[0m         
[92mINFO [0m:      aggregate_evaluate: received 8 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 8 clients (out of 8)
[36m(ClientAppActor pid=17264)[0m 
[36m(ClientAppActor pid=17264)[0m         
[36m(ClientAppActor pid=17263)[0m 
[36m(ClientAppActor pi

[Round 3] test: {"accuracy": 0.7838, "precision": 0.1423, "recall": 0.68, "f1": 0.2353, "roc_auc": 0.8216}


[36m(ClientAppActor pid=17264)[0m 
[36m(ClientAppActor pid=17264)[0m         
[36m(ClientAppActor pid=17263)[0m 
[36m(ClientAppActor pid=17263)[0m         
[36m(ClientAppActor pid=17264)[0m 
[36m(ClientAppActor pid=17264)[0m         
[36m(ClientAppActor pid=17264)[0m 
[36m(ClientAppActor pid=17264)[0m         
[36m(ClientAppActor pid=17263)[0m 
[36m(ClientAppActor pid=17263)[0m         
[92mINFO [0m:      aggregate_evaluate: received 8 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 4]
[92mINFO [0m:      configure_fit: strategy sampled 8 clients (out of 8)
[36m(ClientAppActor pid=17264)[0m 
[36m(ClientAppActor pid=17264)[0m         
[36m(ClientAppActor pid=17263)[0m 
[36m(ClientAppActor pid=17263)[0m         
[36m(ClientAppActor pid=17263)[0m 
[36m(ClientAppActor pid=17263)[0m         
[36m(ClientAppActor pid=17264)[0m 
[36m(ClientAppActor pid=17264)[0m         
[36m(ClientAppActor pid=17263)[0m 
[36m(ClientAppActor pi

[Round 4] test: {"accuracy": 0.726, "precision": 0.1266, "recall": 0.78, "f1": 0.2179, "roc_auc": 0.8228}


[36m(ClientAppActor pid=17264)[0m 
[36m(ClientAppActor pid=17264)[0m         
[36m(ClientAppActor pid=17263)[0m 
[36m(ClientAppActor pid=17263)[0m         
[36m(ClientAppActor pid=17264)[0m 
[36m(ClientAppActor pid=17264)[0m         
[36m(ClientAppActor pid=17263)[0m 
[36m(ClientAppActor pid=17263)[0m         
[36m(ClientAppActor pid=17264)[0m 
[36m(ClientAppActor pid=17264)[0m         
[36m(ClientAppActor pid=17263)[0m 
[36m(ClientAppActor pid=17263)[0m         
[36m(ClientAppActor pid=17264)[0m 
[36m(ClientAppActor pid=17264)[0m         
[92mINFO [0m:      aggregate_evaluate: received 8 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 5]
[92mINFO [0m:      configure_fit: strategy sampled 8 clients (out of 8)
[36m(ClientAppActor pid=17263)[0m 
[36m(ClientAppActor pid=17263)[0m         
[36m(ClientAppActor pid=17263)[0m 
[36m(ClientAppActor pid=17263)[0m         
[36m(ClientAppActor pid=17264)[0m 
[36m(ClientAppActor pi

[Round 5] test: {"accuracy": 0.7573, "precision": 0.136, "recall": 0.74, "f1": 0.2298, "roc_auc": 0.8218}


[36m(ClientAppActor pid=17264)[0m 
[36m(ClientAppActor pid=17264)[0m         
[36m(ClientAppActor pid=17263)[0m 
[36m(ClientAppActor pid=17263)[0m         
[36m(ClientAppActor pid=17264)[0m 
[36m(ClientAppActor pid=17264)[0m         
[36m(ClientAppActor pid=17263)[0m 
[36m(ClientAppActor pid=17263)[0m         
[36m(ClientAppActor pid=17264)[0m 
[36m(ClientAppActor pid=17264)[0m         
[36m(ClientAppActor pid=17264)[0m 
[36m(ClientAppActor pid=17264)[0m         
[36m(ClientAppActor pid=17263)[0m 
[36m(ClientAppActor pid=17263)[0m         
[92mINFO [0m:      aggregate_evaluate: received 8 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [SUMMARY]
[92mINFO [0m:      Run finished 5 round(s) in 10.89s
[92mINFO [0m:      	History (loss, distributed):
[92mINFO [0m:      		round 1: 0.6249032109946654
[92mINFO [0m:      		round 2: 0.530844393162578
[92mINFO [0m:      		round 3: 0.43990577413990306
[92mINFO [0m:      		round 4: 0.47526


âœ… Training completed successfully.
âœ… Training complete.


  return datetime.utcnow().replace(tzinfo=utc)


In [None]:
# Show the server-side (centralized test-set) metrics per round as returned by
# the evaluation callback; requires `hist` from the training cell above.
display(hist.metrics_centralized)

{'accuracy': [(0, 0.04892367906066536),
  (1, 0.7289628180039139),
  (2, 0.7495107632093934),
  (3, 0.7837573385518591),
  (4, 0.726027397260274),
  (5, 0.7573385518590998)],
 'precision': [(0, 0.04892367906066536),
  (1, 0.11784511784511785),
  (2, 0.12408759124087591),
  (3, 0.14225941422594143),
  (4, 0.1266233766233766),
  (5, 0.13602941176470587)],
 'recall': [(0, 1.0), (1, 0.7), (2, 0.68), (3, 0.68), (4, 0.78), (5, 0.74)],
 'f1': [(0, 0.09328358208955224),
  (1, 0.2017291066282421),
  (2, 0.20987654320987653),
  (3, 0.23529411764705882),
  (4, 0.21787709497206703),
  (5, 0.22981366459627328)],
 'roc_auc': [(0, np.float64(0.6514197530864198)),
  (1, np.float64(0.8039711934156378)),
  (2, np.float64(0.8135185185185185)),
  (3, np.float64(0.8216460905349794)),
  (4, np.float64(0.822798353909465)),
  (5, np.float64(0.8218312757201646))]}

In [6]:
# Purpose of this cell: record the round-1 client updates written by the FL run
# in a mock ledger, pickle the ledger, and copy all artifacts to Google Drive.
# It only works inside Google Colab with Drive access.

# 1. Mount Drive
from google.colab import drive
drive.mount('/content/drive')

# 2. Load updates saved by FL
# The absolute path assumes the FL cell ran with /content as working directory
# (the file was saved under the relative name "round1_updates.npy").
import numpy as np
updates = np.load("/content/round1_updates.npy")
print("Loaded updates from FL:", updates.shape)

# 3. Blockchain import
# mock_ledger is a user-provided module expected in the Drive root; it is not
# part of this notebook.
import sys
sys.path.append('/content/drive/MyDrive')
from mock_ledger import MockBlockchain

ledger = MockBlockchain()

# Store each client update
# Each row is one client's flattened update. tobytes() keeps only the raw buffer,
# so a reader must already know the dtype (float64 as written by the FL cell) and
# the vector length. The literal 1 is presumably the round number, matching the
# round that was logged - confirm against mock_ledger.
for i in range(updates.shape[0]):
    payload = updates[i].tobytes()
    ledger.submit_update(1, f"client_{i}", payload)

# 4. Save ledger locally
# NOTE(review): the pickle can only be loaded where mock_ledger is importable,
# and pickle files must never be loaded from untrusted sources.
import pickle
with open("ledger.pkl", "wb") as f:
    pickle.dump(ledger, f)

print("Saved ledger to local file 'ledger.pkl'")

# Copy all files into Drive (paths are relative to the current working directory)
!cp round1_updates.npy /content/drive/MyDrive/
!cp round1_updates_shapes.npy /content/drive/MyDrive/
!cp ledger.pkl /content/drive/MyDrive/

print(" All files copied into Google Drive")


Mounted at /content/drive
Loaded updates from FL: (8, 11137)
[Ledger] Stored update: round=1, client=client_0, hash=c6d9945c...
[Ledger] Stored update: round=1, client=client_1, hash=66e8a4bb...
[Ledger] Stored update: round=1, client=client_2, hash=4df5c7ca...
[Ledger] Stored update: round=1, client=client_3, hash=d349ae90...
[Ledger] Stored update: round=1, client=client_4, hash=d2a6d698...
[Ledger] Stored update: round=1, client=client_5, hash=88c04381...
[Ledger] Stored update: round=1, client=client_6, hash=b4813bb3...
[Ledger] Stored update: round=1, client=client_7, hash=14e32337...
Saved ledger to local file 'ledger.pkl'
 All files copied into Google Drive
