In [1]:
# =======================
# üì¶ IMPORTACIONES
# =======================

# Built-in
import os
import sys
import re
import time
import json
import random
import warnings
from typing import List, Tuple, Dict
import operator
import re
from pathlib import Path
import numpy as np
import pandas as pd


# NumPy, Pandas, Matplotlib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Sklearn
from sklearn.tree import DecisionTreeClassifier, plot_tree, export_text
from sklearn.preprocessing import LabelEncoder, StandardScaler, OrdinalEncoder, OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.metrics import (
    log_loss, accuracy_score, precision_score, recall_score,
    f1_score, confusion_matrix, roc_auc_score, pairwise_distances
)
from sklearn.exceptions import NotFittedError
from collections import defaultdict
from sklearn.metrics import classification_report, confusion_matrix


# Flower
from flwr.client import ClientApp, NumPyClient
from flwr.common import (
    Context, NDArrays, Metrics, Scalar,
    ndarrays_to_parameters, parameters_to_ndarrays
)
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.strategy import FedAvg
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F

# LORE
from lore_sa.dataset import TabularDataset
from lore_sa.bbox import sklearn_classifier_bbox
from lore_sa.encoder_decoder import ColumnTransformerEnc
from lore_sa.lore import TabularGeneticGeneratorLore
from lore_sa.surrogate.decision_tree import SuperTree
from lore_sa.rule import Expression, Rule

from lore_sa.client_utils import ClientUtilsMixin

# Otros
from pathlib import Path
from filelock import FileLock  # pip install filelock
import pandas as pd, os
from graphviz import Digraph
from tqdm import tqdm
from datetime import datetime
import cProfile, pstats, io
from flwr_datasets.partitioner import IidPartitioner, DirichletPartitioner
import copy
from sklearn.model_selection import train_test_split
import shutil
from lore_sa.client_utils.explanation_intersection import ExplanationIntersection



2026-02-20 13:21:54,246	INFO util.py:154 -- Outdated packages:
  ipywidgets==7.8.1 found, needs ipywidgets>=8
Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
2026-02-20 13:21:59,928 graphviz._tools DEBUG    deprecate positional args: graphviz.backend.piping.pipe(['renderer', 'formatter', 'neato_no_op', 'quiet'])
2026-02-20 13:21:59,930 graphviz._tools DEBUG    deprecate positional args: graphviz.backend.rendering.render(['renderer', 'formatter', 'neato_no_op', 'quiet'])
2026-02-20 13:21:59,934 graphviz._tools DEBUG    deprecate positional args: graphviz.backend.unflattening.unflatten(['stagger', 'fanout', 'chain', 'encoding'])
2026-02-20 13:21:59,936 graphviz._tools DEBUG    deprecate positional args: graphviz.backend.viewing.view(['quiet'])
2026-02-20 13:21:59,947 graphviz._tools DEBUG    deprecate positional args: graphviz.quoting.quote(['is_html_string', 'is_valid_id', 'dot_keywords', 'endswith_odd_number_of_backslashes', 'escape_unescaped

In [2]:
# =======================
# ‚öôÔ∏è VARIABLES GLOBALES
# =======================



UNIQUE_LABELS = []
FEATURES = []

NUM_TRAIN_ROUNDS = 2        # rondas donde entrenas la NN
NUM_SERVER_ROUNDS = 3       # la √∫ltima solo para explicaciones
NUM_CLIENTS = 6
SEED = 42

NON_IID = False   # o False para los experimentos IID
NON_IID_ALPHA = 0.5  # por ejemplo, Dirichlet m√°s sesgado

MIN_AVAILABLE_CLIENTS = NUM_CLIENTS
fds = None  # Cache del FederatedDataset
CAT_ENCODINGS = {}
USING_DATASET = None

GLOBAL_TEST_IDX = None
GLOBAL_TEST_HASHES = None


# ==============================================
# üßπ Borrar TODOS los CSV individuales de clientes
# ==============================================

results_dir = Path("results")
results_dir.mkdir(exist_ok=True)

# Borra TODO lo que haya dentro (csv, pth, im√°genes, etc.)
for f in results_dir.iterdir():
    if f.is_file():
        try:
            f.unlink()
        except Exception:
            pass




# =======================
# üîß UTILIDADES MODELO
# =======================

def get_model_parameters(tree_model, nn_model):
    tree_params = [
        int(tree_model.get_params()["max_depth"] or -1),
        int(tree_model.get_params()["min_samples_split"]),
        int(tree_model.get_params()["min_samples_leaf"]),
    ]
    nn_weights = [v.cpu().detach().numpy() for v in nn_model.state_dict().values()]
    return {
        "tree": tree_params,
        "nn": nn_weights,
    }


def set_model_params(tree_model, nn_model, params):
    tree_params = params["tree"]
    nn_weights = params["nn"]

    # Solo si tree_model no es None y tiene set_params
    if tree_model is not None and hasattr(tree_model, "set_params"):
        max_depth = tree_params[0] if tree_params[0] > 0 else None
        tree_model.set_params(
            max_depth=max_depth,
            min_samples_split=tree_params[1],
            min_samples_leaf=tree_params[2],
        )

    # Actualizar pesos de la red neuronal
    state_dict = nn_model.state_dict()
    for (key, _), val in zip(state_dict.items(), nn_weights):
        state_dict[key] = torch.tensor(val)
    nn_model.load_state_dict(state_dict)





# =======================
# üì• PREPROCESADO DATASET
# =======================
def preprocess_df(df: pd.DataFrame, dataset_name: str, class_col: str) -> pd.DataFrame:
    df = df.copy()

    if "adult" in dataset_name.lower():
        df.drop(columns=['fnlwgt', 'education-num', 'capital-gain', 'capital-loss'],
                inplace=True, errors="ignore")

    elif "churn" in dataset_name.lower():
        df.drop(columns=['customerID', 'TotalCharges'],
                inplace=True, errors="ignore")
        if "MonthlyCharges" in df.columns:
            df['MonthlyCharges'] = pd.to_numeric(df['MonthlyCharges'], errors='coerce')
        if "tenure" in df.columns:
            df['tenure'] = pd.to_numeric(df['tenure'], errors='coerce')
        if "SeniorCitizen" in df.columns:
            df['SeniorCitizen'] = df['SeniorCitizen'].map({0: 'No', 1: 'Yes'}).astype(str)
        df.dropna(subset=[c for c in ["MonthlyCharges", "tenure"] if c in df.columns], inplace=True)

    elif "breastcancer" in dataset_name.lower():
        df.drop(columns=['id'], inplace=True, errors='ignore')

    # object -> category (solo baja cardinalidad)
    for col in df.select_dtypes(include=["object"]).columns:
        if col != class_col and df[col].nunique(dropna=True) < 50:
            df[col] = df[col].astype("category")

    return df


def _stable_row_hash(df: pd.DataFrame) -> np.ndarray:
    return pd.util.hash_pandas_object(df, index=False).astype("uint64").to_numpy()

# =======================
# üì• CARGAR DATOS
# =======================

def get_global_onehot_info(flower_dataset_name: str, class_col: str):
    """
    Lee TODO el pool (train con num_partitions=1) para fijar:
    - cat_features (categorical cols)
    - num_features
    - categories_global (OHE categories_ en el orden de cat_features)
    - onehot_columns (nombres finales onehot)
    """
    fds_tmp = FederatedDataset(
        dataset=flower_dataset_name,
        partitioners={"train": IidPartitioner(num_partitions=1)}
    )
    df_all = fds_tmp.load_partition(0, "train").with_format("pandas")[:]
    df_all = preprocess_df(df_all, flower_dataset_name, class_col)

    # asegurar category dtype
    for col in df_all.select_dtypes(include=["object"]).columns:
        if col != class_col and df_all[col].nunique(dropna=True) < 50:
            df_all[col] = df_all[col].astype("category")

    cat_features = [c for c in df_all.columns if df_all[c].dtype.name == "category" and c != class_col]
    num_features = [c for c in df_all.columns if df_all[c].dtype.kind in "fi" and c != class_col]

    if len(cat_features) > 0:
        ohe = OneHotEncoder(handle_unknown="ignore", sparse_output=False)
        ohe.fit(df_all[cat_features])
        categories_global = ohe.categories_
        onehot_columns = ohe.get_feature_names_out(cat_features).tolist()
    else:
        categories_global = []
        onehot_columns = []

    return cat_features, num_features, categories_global, onehot_columns, df_all




# ============================================
# üì• CARGA GENERAL + TEST GLOBAL SIN FUGA
# ============================================
def load_data_general(flower_dataset_name: str, class_col: str, partition_id: int, num_partitions: int):
    """
    Devuelve:
      X_train, y_train,
      X_test_local, y_test_local,
      X_test_global, y_test_global,
      tabular_dataset, feature_names_out, label_encoder,
      num_transformer, numeric_features,
      encoder (ColumnTransformerEnc), preprocessor
    """
    global fds, UNIQUE_LABELS, FEATURES
    global GLOBAL_TEST_IDX, GLOBAL_TEST_HASHES

    # 1) Info global OHE + df_all pool
    cat_features, num_features, categories_global, onehot_columns, df_all = get_global_onehot_info(
        flower_dataset_name, class_col
    )

    # 2) LabelEncoder global (clases estables)
    if not UNIQUE_LABELS:
        le_global = LabelEncoder()
        le_global.fit(df_all[class_col])
        UNIQUE_LABELS[:] = le_global.classes_.tolist()

    label_encoder = LabelEncoder()
    label_encoder.classes_ = np.array(UNIQUE_LABELS)

    y_all = label_encoder.transform(df_all[class_col])

    # 3) Definir TEST GLOBAL una sola vez (idx + hashes estables)
    if GLOBAL_TEST_IDX is None:
        idx = np.arange(len(df_all))
        _, GLOBAL_TEST_IDX = train_test_split(
            idx,
            test_size=0.2,
            random_state=SEED,
            stratify=y_all if len(np.unique(y_all)) > 1 else None
        )
        row_hash_all = _stable_row_hash(df_all)
        GLOBAL_TEST_HASHES = set(row_hash_all[GLOBAL_TEST_IDX].tolist())

    # 4) Crear/usar FederatedDataset particionado (por filas)
    if fds is None:
        if NON_IID:
            partitioner = DirichletPartitioner(
                num_partitions=num_partitions,
                alpha=NON_IID_ALPHA,
                partition_by=class_col,
            )
        else:
            partitioner = IidPartitioner(num_partitions=num_partitions)

        fds = FederatedDataset(
            dataset=flower_dataset_name,
            partitioners={"train": partitioner},
        )

    df_client = fds.load_partition(partition_id, "train").with_format("pandas")[:]
    df_client = preprocess_df(df_client, flower_dataset_name, class_col)

    # 5) Eliminar filas que est√©n en el TEST GLOBAL (sin fuga)
    row_hash_client = _stable_row_hash(df_client)
    keep_mask = ~np.isin(row_hash_client, np.fromiter(GLOBAL_TEST_HASHES, dtype="uint64"))
    df_client = df_client.loc[keep_mask].copy()

    # 6) TabularDataset/descriptor (cliente) para LORE
    tabular_dataset = TabularDataset(df_client.copy(), class_name=class_col)
    descriptor = tabular_dataset.descriptor

    # Asegurar distinct_values local (por si viene vac√≠o)
    for col, info in descriptor.get("categorical", {}).items():
        if "distinct_values" not in info or not info["distinct_values"]:
            info["distinct_values"] = list(df_client[col].dropna().unique())

    # 7) X/y cliente (sin onehot a√∫n)
    y = label_encoder.transform(df_client[class_col])
    X_raw = df_client.drop(columns=[class_col])

    numeric_features = list(descriptor.get("numeric", {}).keys())
    categorical_features = list(descriptor.get("categorical", {}).keys())

    # Ojo: aqu√≠ mantenemos el mismo orden que usa el descriptor del cliente
    FEATURES[:] = numeric_features + categorical_features

    num_idx = list(range(len(numeric_features)))
    cat_idx = list(range(len(numeric_features), len(FEATURES)))

    # 8) Preprocessor con categor√≠as globales (dim estable)
    transformers = [("num", "passthrough", num_idx)]
    if len(categorical_features) > 0:
        # IMPORTANT√çSIMO: categories_global est√° en el orden de cat_features (global)
        # pero aqu√≠ categorical_features puede venir en otro orden -> reordenamos categories_global
        cat_to_pos = {c: i for i, c in enumerate(cat_features)}
        cats_ordered = [categories_global[cat_to_pos[c]] for c in categorical_features]

        transformers.append((
            "cat",
            OneHotEncoder(
                sparse_output=False,
                handle_unknown="ignore",
                categories=cats_ordered
            ),
            cat_idx
        ))

    preprocessor = ColumnTransformer(transformers)

    # 9) Split local + FIT SOLO con train local
    X_train_raw, X_test_local_raw, y_train, y_test_local = train_test_split(
        X_raw[FEATURES], y,
        test_size=0.3,
        random_state=SEED,
        stratify=y if len(np.unique(y)) > 1 else None
    )

    X_train = preprocessor.fit_transform(X_train_raw.to_numpy())
    X_test_local = preprocessor.transform(X_test_local_raw.to_numpy())

    # 10) Construir test global REAL (mismo preprocessor del cliente)
    df_global = df_all.iloc[GLOBAL_TEST_IDX].copy()
    df_global = preprocess_df(df_global, flower_dataset_name, class_col)

    X_test_global = preprocessor.transform(df_global.drop(columns=[class_col])[FEATURES].to_numpy())
    y_test_global = label_encoder.transform(df_global[class_col])

    # 11) Feature names finales (num + onehot) para NN/servidor
    feature_names_out = []
    feature_names_out += list(numeric_features)
    if len(categorical_features) > 0:
        cat_names = preprocessor.named_transformers_["cat"].get_feature_names_out(categorical_features).tolist()
        feature_names_out += cat_names

    FEATURES[:] = feature_names_out  # ahora s√≠: columnas finales (onehot)

    # 12) Encoder LORE con distinct_values globales (para reglas legibles)
    descriptor_global = descriptor.copy()
    if "categorical" in descriptor_global and len(categorical_features) > 0:
        # cats_ordered ya est√° en orden de categorical_features
        for i, col in enumerate(categorical_features):
            if col in descriptor_global["categorical"]:
                descriptor_global["categorical"][col]["distinct_values"] = list(cats_ordered[i])

    encoder = ColumnTransformerEnc(descriptor_global)

    num_transformer = preprocessor.named_transformers_["num"] if "num" in preprocessor.named_transformers_ else None

    return (
        X_train, y_train,
        X_test_local, y_test_local,
        X_test_global, y_test_global,
        tabular_dataset, feature_names_out, label_encoder,
        num_transformer, numeric_features, encoder, preprocessor
    )


# =======================
# ‚úÖ DATASET
# =======================
# DATASET_NAME = "pablopalacios23/adult"
# CLASS_COLUMN = "class"

# DATASET_NAME = "pablopalacios23/churn"
# CLASS_COLUMN = "Churn"

# DATASET_NAME = "pablopalacios23/HeartDisease"
# CLASS_COLUMN = "HeartDisease"

# DATASET_NAME = "pablopalacios23/breastcancer"
# CLASS_COLUMN = "diagnosis"

DATASET_NAME = "pablopalacios23/Diabetes"
CLASS_COLUMN = "Outcome"



# load_data_general(DATASET_NAME, CLASS_COLUMN, partition_id=0, num_partitions=NUM_CLIENTS)

### HOLDOUT DEL SERVIDOR

In [3]:
(X_train, y_train,
 X_test_local, y_test_local,
 X_test_global, y_test_global,
 dataset, feature_names, label_encoder,
 scaler, numeric_features, encoder, preprocessor) = load_data_general(
    DATASET_NAME, CLASS_COLUMN, partition_id=0, num_partitions=NUM_CLIENTS
)


print("\nüì¶ TRAIN (primeras filas):")
print(pd.DataFrame(X_train))

print(y_train)

print(X_train.shape)

print("\nüß™ TEST LOCAL (primeras filas):")
print(pd.DataFrame(X_test_local))

print(y_test_local)

print("\nüåç TEST GLOBAL (primeras filas):")
print(pd.DataFrame(X_test_global))

print(y_test_global)

def overlap_check(A, B):
    hA = pd.util.hash_pandas_object(pd.DataFrame(A), index=False)
    hB = pd.util.hash_pandas_object(pd.DataFrame(B), index=False)
    return np.intersect1d(hA.values, hB.values).size


print("üîé Overlap TRAIN vs TEST GLOBAL:",
      overlap_check(X_train, X_test_global))

print("üîé Overlap TEST LOCAL vs TEST GLOBAL:",
      overlap_check(X_test_local, X_test_global))

print(feature_names)

2026-02-20 13:22:00,090 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): huggingface.co:443
2026-02-20 13:22:00,286 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /datasets/pablopalacios23/Diabetes/resolve/main/README.md HTTP/11" 404 0
2026-02-20 13:22:00,445 urllib3.connectionpool DEBUG    https://huggingface.co:443 "GET /api/datasets/pablopalacios23/Diabetes HTTP/11" 200 555
2026-02-20 13:22:00,580 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /datasets/pablopalacios23/Diabetes/resolve/90286a808fa902822a230872737f66665a681328/Diabetes.py HTTP/11" 404 0
2026-02-20 13:22:00,585 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): s3.amazonaws.com:443
2026-02-20 13:22:00,933 urllib3.connectionpool DEBUG    https://s3.amazonaws.com:443 "HEAD /datasets.huggingface.co/datasets/datasets/pablopalacios23/Diabetes/pablopalacios23/Diabetes.py HTTP/11" 404 0
2026-02-20 13:22:01,104 urllib3.connectionpool DEBUG    https://huggin


üì¶ TRAIN (primeras filas):
       0      1     2     3      4     5      6     7
0   11.0  138.0  76.0   0.0    0.0  33.2  0.420  35.0
1    9.0  152.0  78.0  34.0  171.0  34.2  0.893  33.0
2    6.0  148.0  72.0  35.0    0.0  33.6  0.627  50.0
3    0.0  104.0  64.0  37.0   64.0  33.6  0.510  22.0
4    8.0  151.0  78.0  32.0  210.0  42.9  0.516  36.0
..   ...    ...   ...   ...    ...   ...    ...   ...
67   6.0   93.0  50.0  30.0   64.0  28.7  0.356  23.0
68   4.0  116.0  72.0  12.0   87.0  22.1  0.463  37.0
69   2.0   90.0  68.0  42.0    0.0  38.2  0.503  27.0
70   0.0  141.0  84.0  26.0    0.0  32.4  0.433  22.0
71  13.0  145.0  82.0  19.0  110.0  22.2  0.245  57.0

[72 rows x 8 columns]
[0 1 1 1 1 0 1 0 0 1 0 1 1 1 0 0 0 1 0 0 0 1 1 0 1 1 0 0 0 1 0 0 0 1 0 0 1
 1 0 1 0 1 1 0 0 1 1 0 1 0 0 1 1 1 0 0 0 1 1 1 0 0 1 0 0 0 1 0 0 1 0 0]
(72, 8)

üß™ TEST LOCAL (primeras filas):
       0      1      2     3      4     5      6     7
0   11.0  135.0    0.0   0.0    0.0  52.3  0.578  40.0

# Cliente

In [4]:
# ==========================
# üåº CLIENTE FLOWER
# ==========================
import operator
import warnings
import os
import json
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import (
    log_loss, accuracy_score, precision_score, recall_score,
    f1_score, roc_auc_score
)
from sklearn.exceptions import NotFittedError

import torch
import torch.nn as nn
import torch.nn.functional as F

from flwr.client import NumPyClient
from flwr.common import Context
from flwr.common import parameters_to_ndarrays

from lore_sa.dataset import TabularDataset
from lore_sa.bbox import sklearn_classifier_bbox
from lore_sa.lore import TabularGeneticGeneratorLore
from lore_sa.rule import Expression, Rule
from lore_sa.surrogate.decision_tree import EnsembleDecisionTreeSurrogate, SuperTree
from lore_sa.encoder_decoder import ColumnTransformerEnc

from sklearn.metrics import pairwise_distances


class TorchNNWrapper:
    def __init__(self, model, num_idx, mean, scale):
        self.model = model
        self.model.eval()
        self.num_idx = np.asarray(num_idx, dtype=int)
        self.mean = np.asarray(mean, dtype=np.float32)
        self.scale = np.asarray(scale, dtype=np.float32)
        self.scale_safe = np.where(self.scale == 0, 1.0, self.scale)

    def _scale_internally(self, X):
        X = np.asarray(X, dtype=np.float32)
        Xs = X.copy()
        # soporta [n, d] o [d]
        if Xs.ndim == 1:
            Xs = Xs[None, :]
        Xs[:, self.num_idx] = (Xs[:, self.num_idx] - self.mean) / self.scale_safe
        return Xs

    def predict(self, X):
        Xs = self._scale_internally(X)
        with torch.no_grad():
            X_tensor = torch.tensor(Xs, dtype=torch.float32)
            logits = self.model(X_tensor)
            return logits.argmax(dim=1).cpu().numpy()

    def predict_proba(self, X):
        Xs = self._scale_internally(X)
        with torch.no_grad():
            X_tensor = torch.tensor(Xs, dtype=torch.float32)
            logits = self.model(X_tensor)
            probs = F.softmax(logits, dim=1)
            return probs.cpu().numpy()

class Net(nn.Module):
    def __init__(self, input_dim: int, output_dim: int):
        super(Net, self).__init__()
        hidden_dim = max(8, input_dim * 2)  # algo proporcional

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim // 2)
        self.fc3 = nn.Linear(hidden_dim // 2, output_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
    
        

class FlowerClient(NumPyClient, ClientUtilsMixin):
    def __init__(self, tree_model, nn_model, X_train, y_train, X_test, y_test, X_test_global, y_test_global, scaler_nn_mean, scaler_nn_scale, num_idx, dataset, client_id, feature_names, label_encoder, scaler, numeric_features, encoder, preprocessor):
        self.tree_model = tree_model
        self.nn_model = nn_model
        self.nn_model_local = copy.deepcopy(nn_model)
        self.nn_model_global = nn_model
        self.X_train = X_train
        self.y_train = y_train
        self.X_test = X_test
        self.y_test = y_test
        self.X_test_global = X_test_global
        self.y_test_global = y_test_global
        self.scaler_nn_mean = np.asarray(scaler_nn_mean, dtype=np.float32)
        self.scaler_nn_scale = np.where(np.asarray(scaler_nn_scale, np.float32)==0, 1.0, np.asarray(scaler_nn_scale, np.float32))
        self.num_idx = np.asarray(num_idx, dtype=int)
        self.dataset = dataset
        self.client_id = client_id
        os.makedirs("results", exist_ok=True)
        self.local_ckpt = f"results/bb_local_client_{self.client_id}.pth"
        self.local_trained = False
        self.feature_names = feature_names
        self.label_encoder = label_encoder
        self.scaler = scaler
        self.numeric_features = numeric_features
        self.encoder = encoder
        self.unique_labels = label_encoder.classes_.tolist()
        self.y_train_nn = y_train.astype(np.int64)
        self.y_test_nn = y_test.astype(np.int64)
        self.received_supertree = None
        self.preprocessor = preprocessor

    def _train_nn(self, model, epochs=10, lr=1e-3):
        model.train()
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        loss_fn = nn.CrossEntropyLoss()

        # ‚úÖ Escalado igual que en TorchNNWrapper
        X = np.asarray(self.X_train, dtype=np.float32).copy()
        scale_safe = np.where(self.scaler_nn_scale == 0, 1.0, self.scaler_nn_scale)
        X[:, self.num_idx] = (X[:, self.num_idx] - self.scaler_nn_mean) / scale_safe

        X_tensor = torch.tensor(X, dtype=torch.float32)
        y_tensor = torch.tensor(self.y_train_nn, dtype=torch.long)

        for _ in range(epochs):
            optimizer.zero_grad()
            logits = model(X_tensor)
            loss = loss_fn(logits, y_tensor)
            loss.backward()
            optimizer.step()



    def fit(self, parameters, config):
        # 1Ô∏è‚É£ Cargar pesos GLOBAL recibidos del servidor
        set_model_params(
            self.tree_model,
            self.nn_model_global,
            {"tree": [
                self.tree_model.get_params()["max_depth"],
                self.tree_model.get_params()["min_samples_split"],
                self.tree_model.get_params()["min_samples_leaf"],
            ], "nn": parameters}
        )

        round_number = int(config.get("server_round", 1))

        # 2Ô∏è‚É£ Baseline LOCAL: cargar si existe, si no entrenar 1 vez y guardar
        if not self.local_trained:
            if os.path.exists(self.local_ckpt):
                state = torch.load(self.local_ckpt, map_location="cpu")
                self.nn_model_local.load_state_dict(state)
                self.nn_model_local.eval()
                bb_local_tmp = TorchNNWrapper(self.nn_model_local, self.num_idx, self.scaler_nn_mean, self.scaler_nn_scale)
                with torch.no_grad():
                    acc_train_load = accuracy_score(self.y_train_nn, bb_local_tmp.predict(self.X_train))
                # print(f"[CLIENTE {self.client_id}] üì¶ LOCAL TRAIN acc tras cargar ckpt:", acc_train_load)
                self.local_trained = True
            else:
                # baseline parte del global recibido (ronda 1 t√≠picamente)
                self.nn_model_local = copy.deepcopy(self.nn_model_global)
                self._train_nn(self.nn_model_local, epochs=80, lr=1e-3)
                self.nn_model_local.eval()
                bb_local_tmp = TorchNNWrapper(self.nn_model_local, self.num_idx, self.scaler_nn_mean, self.scaler_nn_scale)
                with torch.no_grad():
                    acc_train_now = accuracy_score(self.y_train_nn, bb_local_tmp.predict(self.X_train))
                # print(f"[CLIENTE {self.client_id}] ‚úÖ LOCAL TRAIN acc justo tras entrenar:", acc_train_now)
                torch.save(self.nn_model_local.state_dict(), self.local_ckpt)
                self.local_trained = True
                print(f"[CLIENTE {self.client_id}] ‚úÖ LOCAL baseline entrenado y guardado")

        # 3Ô∏è‚É£ Entrenar GLOBAL (federado) en ESTA ronda con datos del cliente (FedAvg)
        self._train_nn(self.nn_model_global, epochs=10, lr=1e-3)
        # print(f"[CLIENTE {self.client_id}] üåç GLOBAL entrenado (ronda {round_number})")

        # 4Ô∏è‚É£ √Årbol local (si aplica)
        if round_number <= NUM_TRAIN_ROUNDS:
            self.tree_model.fit(self.X_train, self.y_train)

        # 5Ô∏è‚É£ Enviar al servidor los pesos del GLOBAL entrenado
        nn_weights = get_model_parameters(self.tree_model, self.nn_model_global)["nn"]
        return nn_weights, len(self.X_train), {}
    


    def evaluate(self, parameters, config):
        # 0) Set global params received from server (tree hyperparams + NN weights)
        set_model_params(
            self.tree_model,
            self.nn_model_global,
            {"tree": [
                self.tree_model.get_params()["max_depth"],
                self.tree_model.get_params()["min_samples_split"],
                self.tree_model.get_params()["min_samples_leaf"],
            ], "nn": parameters}
        )

        round_number = int(config.get("server_round", 1))
        explain_only = bool(config.get("explain_only", False))

        # ‚úÖ FIX CLAVE: en Ray/Flower el actor puede ‚Äúrenacer‚Äù en evaluate()
        # y perder el nn_model_local entrenado. En la ronda final lo recargamos SIEMPRE.
        if explain_only:
            if not os.path.exists(self.local_ckpt):
                raise RuntimeError(
                    f"[CLIENTE {self.client_id}] ‚ùå No existe ckpt local para explicar: {self.local_ckpt}"
                )
            state = torch.load(self.local_ckpt, map_location="cpu")
            self.nn_model_local.load_state_dict(state)
            self.nn_model_local.eval()
            self.local_trained = True
            print(f"[CLIENTE {self.client_id}] üì¶ LOCAL baseline recargado en evaluate()")

        # (Opcional pero recomendable) asegurar eval mode del global al explicar
        if explain_only:
            self.nn_model_global.eval()

        # Recibir SuperTree + mappings si vienen
        if "supertree" in config:
            try:
                print("Recibiendo supertree....")
                supertree_dict = json.loads(config["supertree"])
                self.received_supertree = SuperTree.convert_SuperNode_to_Node(
                    SuperTree.SuperNode.from_dict(supertree_dict)
                )
                self.global_mapping = json.loads(config["global_mapping"])
                self.feature_names = json.loads(config["feature_names"])
            except Exception as e:
                print(f"[CLIENTE {self.client_id}] ‚ùå Error al recibir SuperTree: {e}")

        # üîπ CASO 1: rondas de entrenamiento (1..NUM_TRAIN_ROUNDS)
        if not explain_only:
            self.tree_model.fit(self.X_train, self.y_train)

            supertree = SuperTree()
            root_node = supertree.rec_buildTree(
                self.tree_model,
                list(range(self.X_train.shape[1])),
                len(self.unique_labels)
            )
            root_node = supertree.prune_redundant_leaves_local(root_node)

            self._save_local_tree(
                root_node,
                round_number,
                FEATURES,
                self.numeric_features,
                scaler=None,
                unique_labels=UNIQUE_LABELS,
                encoder=self.encoder
            )
            tree_json = json.dumps([root_node.to_dict()])

            return 0.0, len(self.X_test), {
                f"tree_ensemble_{self.client_id}": tree_json,
                f"encoded_feature_names_{self.client_id}": json.dumps(FEATURES),
                f"numeric_features_{self.client_id}": json.dumps(self.numeric_features),
                f"unique_labels_{self.client_id}": json.dumps(self.unique_labels),
                f"distinct_values_{self.client_id}": json.dumps(self.encoder.dataset_descriptor["categorical"]),
            }

        # üîπ CASO 2: ronda final (solo explicaci√≥n con Supertree final)
        print(f"[CLIENTE {self.client_id}] üîç Ronda final: solo explicaciones")

        # Si quieres √°rbol local para m√©tricas comparativas (no afecta a la NN)
        self.tree_model.fit(self.X_train, self.y_train)
        y_pred_tree_local = self.tree_model.predict(self.X_test)

        self.local_metrics = {
            "acc_local_tree": accuracy_score(self.y_test, y_pred_tree_local),
            "prec_local_tree": precision_score(self.y_test, y_pred_tree_local, average="weighted", zero_division=0),
            "rec_local_tree": recall_score(self.y_test, y_pred_tree_local, average="weighted", zero_division=0),
            "f1_local_tree": f1_score(self.y_test, y_pred_tree_local, average="weighted", zero_division=0),
        }

        # Usamos el SuperTree final recibido + LORE (y lo que toque)
        if self.received_supertree is not None:
            self.explain_all_test_instances(config)
            # self.explain_all_test_instances(config, only_idx=0)

        return 0.0, len(self.X_test), {}
        


    
    
    def _explain_one_instance(self, num_row, config, save_trees=False):
        from sklearn.metrics import accuracy_score
        import numpy as np
        
        self.nn_model_local.eval()
        self.nn_model_global.eval()


        # Wrapper que escala SOLO para la NN (espacio NN)
        bb_local = TorchNNWrapper(
            model=self.nn_model_local,
            num_idx=self.num_idx,
            mean=self.scaler_nn_mean,
            scale=self.scaler_nn_scale
        )

        bb_global = TorchNNWrapper(
            model=self.nn_model_global,
            num_idx=self.num_idx,
            mean=self.scaler_nn_mean,
            scale=self.scaler_nn_scale
        )

        # 1. Visualizar instancia escalada y decodificada usando el encoder/preprocessor ORIGINAL
        
        decoded = self.decode_onehot_instance(
            self.X_test[num_row],
            self.numeric_features,
            self.encoder,
            None,                 # <-- sin scaler (en crudo)
            self.feature_names
        )



        # Aseg√∫rate de que X_test[num_row] es un numpy array del shape correcto (1, n_features)
        row = np.asarray(self.X_test[num_row], dtype=np.float32)

        probs_local = bb_local.predict_proba(row[None, :])
        pred_class_idx_local = int(probs_local.argmax(axis=1)[0])
        pred_class_local = self.label_encoder.inverse_transform([pred_class_idx_local])[0]

        probs_global = bb_global.predict_proba(row[None, :])
        pred_class_idx_global = int(probs_global.argmax(axis=1)[0])
        pred_class_global = self.label_encoder.inverse_transform([pred_class_idx_global])[0]

        # 2. Construir DataFrame para LORE (si es necesario, solo para TabularDataset)

        local_df = pd.DataFrame(self.X_train, columns=self.feature_names).astype(np.float32)
        local_df["class"] = self.label_encoder.inverse_transform(self.y_train_nn)
        local_tabular_dataset = TabularDataset(local_df, class_name="class")

        # Explicaci√≥n LORE
        x_instance = pd.Series(self.X_test[num_row], index=self.feature_names)
        round_number = config.get("server_round", 1)

        # print("Instancia a explicar para el cliente:", self.client_id)
        # print(x_instance)
        # print("ü§ñ NN local pred:", pred_class_local)
        # print("üåç NN global pred:", pred_class_global)
        # print("\n")



        # =========================
        # Vecindad GLOBAL
        # =========================
        bbox_global_for_Z = sklearn_classifier_bbox.sklearnBBox(bb_global)
        lore_vecindad_global  = TabularGeneticGeneratorLore(bbox_global_for_Z, local_tabular_dataset)




        explanation_global  = lore_vecindad_global.explain_instance(x_instance, merge=True, num_classes=len(UNIQUE_LABELS), feature_names= self.feature_names, categorical_features=list(self.global_mapping.keys()), global_mapping=self.global_mapping, UNIQUE_LABELS=UNIQUE_LABELS,
                                                    client_id=self.client_id, round_number=round_number)
        

        lore_tree_global = explanation_global["merged_tree"]
        Z_global = explanation_global["neighborhood_Z"]
        y_bb_global = explanation_global["neighborhood_Yb"]
        dfZ_global = pd.DataFrame(Z_global, columns=self.feature_names)


        
        if save_trees:
            self.save_lore_tree_image(lore_tree_global.root,round_number,self.feature_names,self.numeric_features,UNIQUE_LABELS,self.encoder,folder="lore_tree_global")
            

        
        # =========================
        # Vecindad LOCAL (NUEVA)
        # =========================
        bbox_local = sklearn_classifier_bbox.sklearnBBox(bb_local)
        lore_vecindad_local = TabularGeneticGeneratorLore(bbox_local, local_tabular_dataset)

        explanation_local = lore_vecindad_local.explain_instance(x_instance,merge=True,num_classes=len(UNIQUE_LABELS),feature_names=self.feature_names,categorical_features=list(self.global_mapping.keys()),global_mapping=self.global_mapping,UNIQUE_LABELS=UNIQUE_LABELS,
                                                                 client_id=self.client_id,round_number=round_number)

        lore_tree_local = explanation_local["merged_tree"]
        Z_local = explanation_local["neighborhood_Z"]
        y_bb_local = explanation_local["neighborhood_Yb"]
        dfZ_local = pd.DataFrame(Z_local, columns=self.feature_names)


        if save_trees:
            self.save_lore_tree_image(lore_tree_local.root,round_number,self.feature_names,self.numeric_features,UNIQUE_LABELS,self.encoder,folder="lore_tree_local")

        # ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
        # 
        # # MERGED TREE

        # merged_tree = SuperTree()
        # merged_tree.mergeDecisionTrees(
        #     roots=[lore_tree.root, self.received_supertree],
        #     num_classes=len(self.unique_labels),
        #     feature_names=self.feature_names,
        #     categorical_features=list(self.global_mapping.keys()), 
        #     global_mapping=self.global_mapping
        # )
        
        # merged_tree.prune_redundant_leaves_full()


        # if save_trees:
        #     self.save_mergedTree_plot(root_node=merged_tree.root,round_number=round_number,feature_names=self.feature_names,class_names=self.unique_labels,numeric_features=self.numeric_features,scaler=None, global_mapping=self.global_mapping,folder="MergedTree")

        # ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------




        # # tree_str = self.tree_to_str(merged_tree.root,self.feature_names,numeric_features=self.numeric_features,scaler=None, global_mapping=self.global_mapping,unique_labels=self.unique_labels)
        lore_tree_local_str = self.tree_to_str(lore_tree_local.root, self.feature_names, numeric_features=self.numeric_features,scaler=None,global_mapping=self.global_mapping,unique_labels=self.unique_labels)
        lore_tree_global_str = self.tree_to_str(lore_tree_global.root, self.feature_names, numeric_features=self.numeric_features,scaler=None,global_mapping=self.global_mapping,unique_labels=self.unique_labels)
        supertree_str = self.tree_to_str(self.received_supertree, self.feature_names, numeric_features=self.numeric_features,scaler=None,global_mapping=self.global_mapping,unique_labels=self.unique_labels)





        
        # # rules = self.extract_rules_from_str(tree_str, target_class_label=pred_class)
        rules_lore_local = self.extract_rules_from_str(lore_tree_local_str, target_class_label=pred_class_local)
        rules_lore_global = self.extract_rules_from_str(lore_tree_global_str, target_class_label=pred_class_global)
        rules_supertree_global = self.extract_rules_from_str(supertree_str, target_class_label=pred_class_global)







    
        def cumple_regla(instancia, regla):
            import re
            import numpy as np

            def norm_op(op: str) -> str:
                return op.replace("‚â§", "<=").replace("‚â•", ">=")

            def onehot_value(var: str):
                # busca columnas tipo "var_*" y devuelve el sufijo del m√°ximo
                prefix = var + "_"
                cols = [k for k in instancia.keys() if k.startswith(prefix)]
                if not cols:
                    return None
                best = max(cols, key=lambda c: float(instancia.get(c, 0.0)))
                return best.split(prefix, 1)[1]

            for cond in regla:
                cond = cond.strip()

                # Intervalo: 'age > 44.33 ‚àß ‚â§ 48.50'
                if "‚àß" in cond:
                    m = re.match(r'(.+?)([><]=?|‚â§|‚â•)\s*([-\d\.]+)\s*‚àß\s*([><]=?|‚â§|‚â•)\s*([-\d\.]+)', cond)
                    if m:
                        var = m.group(1).strip()
                        op1, val1 = norm_op(m.group(2)), float(m.group(3))
                        op2, val2 = norm_op(m.group(4)), float(m.group(5))
                        v = float(instancia[var])
                        if not (eval(f"v {op1} {val1}") and eval(f"v {op2} {val2}")):
                            return False
                        continue

                # Num√©ricas
                if "‚â§" in cond:
                    var, val = cond.split("‚â§")
                    var = var.strip()
                    val = float(val.strip())
                    if float(instancia[var]) > val:
                        return False

                elif ">=" in cond or "‚â•" in cond:
                    var, val = cond.replace("‚â•", ">=").split(">=")
                    var = var.strip()
                    val = float(val.strip())
                    if float(instancia[var]) < val:
                        return False

                elif ">" in cond:
                    var, val = cond.split(">")
                    var = var.strip()
                    val = float(val.strip())
                    if float(instancia[var]) <= val:
                        return False

                elif "<" in cond:
                    var, val = cond.split("<")
                    var = var.strip()
                    val = float(val.strip())
                    if float(instancia[var]) >= val:
                        return False

                # Categ√≥ricas
                elif "‚â†" in cond:
                    var, val = cond.split("‚â†")
                    var = var.strip()
                    val = val.strip().replace('"', "")

                    # Caso 1: instancia decodificada (instancia[var] = "Up")
                    if var in instancia and isinstance(instancia[var], str):
                        if instancia[var] == val:
                            return False
                        continue

                    # Caso 2: one-hot directa (col = "STSlope_Up")
                    col = f"{var}_{val}"
                    if col in instancia:
                        if float(instancia[col]) >= 0.5:
                            return False
                        continue

                    # Caso 3: one-hot por argmax (si hay varias var_*)
                    oh = onehot_value(var)
                    if oh is not None and oh == val:
                        return False

                elif "=" in cond:
                    var, val = cond.split("=")
                    var = var.strip()
                    val = val.strip().replace('"', "")

                    # Caso 1: instancia decodificada
                    if var in instancia and isinstance(instancia[var], str):
                        if instancia[var] != val:
                            return False
                        continue

                    # Caso 2: one-hot directa
                    col = f"{var}_{val}"
                    if col in instancia:
                        if float(instancia[col]) < 0.5:
                            return False
                        continue

                    # Caso 3: one-hot por argmax
                    oh = onehot_value(var)
                    if oh is None or oh != val:
                        return False

                else:
                    # Si llega una condici√≥n rara, mejor fallar seguro
                    return False

            return True



        
        regla_factual_lore_local = None
        for r in rules_lore_local:
            if cumple_regla(decoded, r):
                regla_factual_lore_local = r
                break


        regla_factual_lore_global = None
        for r in rules_lore_global:
            if cumple_regla(decoded, r):
                regla_factual_lore_global = r
                break

        # ‚úÖ factual = primera regla que cumple (ya lo estabas haciendo arriba)
        rules_factual_local = [regla_factual_lore_local] if regla_factual_lore_local is not None else []
        rules_factual_global = [regla_factual_lore_global] if regla_factual_lore_global is not None else []





        

        # # # Extraer 1 contrafactual por cada clase distinta a la predicha


        # Extraer 1 contrafactual tipo LORE por cada clase distinta a la predicha
        cf_rules_LORE_local_por_clase = {}
        for clase in self.unique_labels:
            if clase != pred_class_local:
                rules_clase = self.extract_rules_from_str(lore_tree_local_str, target_class_label=clase)
                if rules_clase:
                    
                    # cf_rules_LORE_local_por_clase[clase] = min(rules_clase, key=len) # Elige la m√°s sencilla (menos condiciones)
                    cf_rules_LORE_local_por_clase[clase] = rules_clase  # ‚úÖ todos



        cf_rules_LORE_global_por_clase = {}
        for clase in self.unique_labels:
            if clase != pred_class_global:
                rules_clase = self.extract_rules_from_str(lore_tree_global_str, target_class_label=clase)
                if rules_clase:
                    
                    # cf_rules_LORE_global_por_clase[clase] = min(rules_clase, key=len) # Elige la m√°s sencilla (menos condiciones)
                    cf_rules_LORE_global_por_clase[clase] = rules_clase  # ‚úÖ todos



        # print("client_id:", self.client_id)

        # print("rules_factual_local:", rules_factual_local)
        # print("rules_factual_global:", rules_factual_global)

        # print("cf_rules_LORE_local_por_clase:", cf_rules_LORE_local_por_clase)
        # print("cf_rules_LORE_global_por_clase:", cf_rules_LORE_global_por_clase)


        # --- Factual robusto ---
        if not rules_factual_local:
            # intenta con el primero que encontraste con break (si lo guardaste)
            if regla_factual_lore_local is not None:
                rules_factual_local = [regla_factual_lore_local]

        if not rules_factual_global:
            if regla_factual_lore_global is not None:
                rules_factual_global = [regla_factual_lore_global]



        # rows = dfZ.to_dict(orient="records")
        # idx_ok = next((i for i,r in enumerate(rows) if cumple_regla(r, factual_local)), None)
        # print("primer idx en Z que cumple factual_local:", idx_ok)


        dfZ_eval_global = dfZ_global
        dfZ_eval_local  = dfZ_local


        def mask_regla_en_Z(dfZ_eval, regla):
            m = np.zeros(len(dfZ_eval), dtype=bool)
            for i, row in enumerate(dfZ_eval.to_dict(orient="records")):
                m[i] = cumple_regla(row, regla)
            return m

        has_factual = bool(rules_factual_local) and bool(rules_factual_global)


        if not has_factual:
            print(f"[CLIENTE {self.client_id}] ‚ö†Ô∏è Sin factual para instancia {num_row}")
            jaccard_cov_global = covL_g = covG_g = covInter_g = covUnion_g = np.nan
            jaccard_cov_local = covL_l = covG_l = covInter_l = covUnion_l = np.nan


        else:

            ### Global
            mL_g = mask_regla_en_Z(dfZ_eval_global, rules_factual_local[0])
            mG_g = mask_regla_en_Z(dfZ_eval_global, rules_factual_global[0])

            inter_g = np.logical_and(mL_g, mG_g).sum()
            union_g = np.logical_or(mL_g, mG_g).sum()

            jaccard_cov_global = 0.0 if union_g == 0 else inter_g / union_g
            covL_g = mL_g.mean()
            covG_g = mG_g.mean()
            covInter_g = np.logical_and(mL_g, mG_g).mean()
            covUnion_g = np.logical_or(mL_g, mG_g).mean()

            ### Local
            mL_l = mask_regla_en_Z(dfZ_eval_local, rules_factual_local[0])
            mG_l = mask_regla_en_Z(dfZ_eval_local, rules_factual_global[0])

            inter_l = np.logical_and(mL_l, mG_l).sum()
            union_l = np.logical_or(mL_l, mG_l).sum()

            jaccard_cov_local = 0.0 if union_l == 0 else inter_l / union_l
            covL_l = mL_l.mean()
            covG_l = mG_l.mean()
            covInter_l = np.logical_and(mL_l, mG_l).mean()
            covUnion_l = np.logical_or(mL_l, mG_l).mean()



        # print(f"[CLIENTE {self.client_id}]")
        # print(f"Jaccard factual_local vs factual_global: {jaccard_cov:.4f} (inter={inter}, union={union})")
        # print(f"Coverage factual_local: {covL:.4f}, factual_global: {covG:.4f}, intersecci√≥n: {covInter:.4f}, uni√≥n: {covUnion:.4f}")







        
        

        
        

        

        # ========================================================================================================================================================================================================
        # üìè M√âTRICAS DE EXPLICABILIDAD LOCAL (vecindario Z)
        # 
        # Silhouette:  Distancia media entre x y las instancias de su misma clase en el vecindario (Z+)
        # ========================================================================================================================================================================================================

        mask_same_class_local = (y_bb_local == pred_class_idx_local)
        mask_diff_class_local = (y_bb_local != pred_class_idx_local)

        Z_plus_local = dfZ_local[mask_same_class_local]
        Z_minus_local = dfZ_local[mask_diff_class_local]

        x = self.X_test[num_row]

        a = pairwise_distances([x], Z_plus_local).mean() if len(Z_plus_local) > 0 else 0.0
        b = pairwise_distances([x], Z_minus_local).mean() if len(Z_minus_local) > 0 else 0.0

        silhouette_local = 0.0
        if (a + b) > 0:
            silhouette_local = (b - a) / max(a, b)






        mask_same_class_global = (y_bb_global == pred_class_idx_global)
        mask_diff_class_global = (y_bb_global != pred_class_idx_global)

        Z_plus_global = dfZ_global[mask_same_class_global]
        Z_minus_global = dfZ_global[mask_diff_class_global]

        a_global = pairwise_distances([x], Z_plus_global).mean() if len(Z_plus_global) > 0 else 0.0
        b_global = pairwise_distances([x], Z_minus_global).mean() if len(Z_minus_global) > 0 else 0.0

        silhouette_global = 0.0
        if (a_global + b_global) > 0:
            silhouette_global = (b_global - a_global) / max(a_global, b_global)




        


        
        
        # ==================================================================================================================
        # M√âTRICAS DE COMPARATIVA DE PRECISION DE LOS √ÅRBOLES CON EL TEST (QUE TAN BUENOS SON LOS ARBOLES QUE HEMOS GENERADO)
        # ==================================================================================================================

        y_true = self.y_test

        # Predicciones
        y_pred_supertree = self.received_supertree.predict(self.X_test)

        Xg = self.X_test_global
        if Xg.ndim == 1:
            Xg = Xg.reshape(1, -1)

        y_pred_superTree_globalTest = self.received_supertree.predict(self.X_test_global)
        y_pred_localTree_globalTest = self.tree_model.predict(self.X_test_global)
        

        # Accuracy, precision, Recall, F1
        acc_supertree = accuracy_score(y_true, y_pred_supertree)

        prec_supertree       = precision_score(y_true, y_pred_supertree, average="weighted")
        rec_super  = recall_score(y_true, y_pred_supertree, average="weighted")
        f1_super   = f1_score(y_true, y_pred_supertree, average="weighted")

        acc_super_globalTest = accuracy_score(self.y_test_global, y_pred_superTree_globalTest)
        prec_super_globalTest = precision_score(self.y_test_global, y_pred_superTree_globalTest, average="weighted")
        rec_super_globalTest = recall_score(self.y_test_global, y_pred_superTree_globalTest, average="weighted")
        f1_super_globalTest = f1_score(self.y_test_global, y_pred_superTree_globalTest, average="weighted")

        acc_localTree_globalTest = accuracy_score(self.y_test_global, y_pred_localTree_globalTest)
        prec_localTree_globalTest = precision_score(self.y_test_global, y_pred_localTree_globalTest, average="weighted")
        rec_localTree_globalTest = recall_score(self.y_test_global, y_pred_localTree_globalTest, average="weighted")
        f1_localTree_globalTest = f1_score(self.y_test_global, y_pred_localTree_globalTest, average="weighted")







        # ================= CSV por cliente =================
        row = {
            "round": int(round_number),
            "dataset": DATASET_NAME,
            "client_id": int(self.client_id),
            "bbox_pred_class_global": str(pred_class_global),
            "bbox_pred_class_local": str(pred_class_local),

            # Vecindario
            "silhouette_global": float(silhouette_global),
            "silhouette_local": float(silhouette_local),

            # ================= M√©tricas de como de buenos son los √°rboles =================
            "acc_superTree_localTest": float(acc_supertree),
            "prec_superTree_localTest": float(prec_supertree),
            "rec_superTree_localTest": float(rec_super),
            "f1_superTree_localTest": float(f1_super),

            "acc_superTree_globalTest": float(acc_super_globalTest),
            "prec_superTree_globalTest": float(prec_super_globalTest),
            "rec_superTree_globalTest": float(rec_super_globalTest),
            "f1_superTree_globalTest": float(f1_super_globalTest),

            # üîπ m√©tricas LOCALES (guardadas antes)
            "acc_localTree_localTest": self.local_metrics["acc_local_tree"],
            "prec_localTree_localTest": self.local_metrics["prec_local_tree"],
            "rec_localTree_localTest": self.local_metrics["rec_local_tree"],
            "f1_localTree_localTest": self.local_metrics["f1_local_tree"],

            "acc_localTree_globalTest": float(acc_localTree_globalTest),
            "prec_localTree_globalTest": float(prec_localTree_globalTest),
            "rec_localTree_globalTest": float(rec_localTree_globalTest),
            "f1_localTree_globalTest": float(f1_localTree_globalTest),

            # Jaccard en vecindad GLOBAL
            "jaccard_cov_globalZ": float(jaccard_cov_global),
            "covL_globalZ": float(covL_g),
            "covG_globalZ": float(covG_g),
            "covInter_globalZ": float(covInter_g),
            "covUnion_globalZ": float(covUnion_g),

            # Jaccard en vecindad LOCAL
            "jaccard_cov_localZ": float(jaccard_cov_local),
            "covL_localZ": float(covL_l),
            "covG_localZ": float(covG_l),
            "covInter_localZ": float(covInter_l),
            "covUnion_localZ": float(covUnion_l),    
        }


        

        # (Opcional) M√©tricas contrafactuales por clase en columnas ‚Äúanchas‚Äù

        # for cl in self.unique_labels:
        #     row[f"cf_cov_lore_{cl}_TEST"]  = self._to_float(coverage_cf_lore.get(cl,0))
        #     row[f"cf_comp_lore_{cl}_TEST"] = int(comp_cf_lore_simpl.get(cl,0))
        #     row[f"cf_prec_lore_{cl}_TEST"] = self._to_float(precision_cf_lore.get(cl,0))


        # Guardar
        self._append_client_csv(row, filename="Balanced")

        return row

    # ======================================================================
    # Bucle sobre todo el test
    # ======================================================================
    def explain_all_test_instances(self, config, only_idx=None):
        results = []

        # Si only_idx es None ‚Üí explicamos TODO el test
        # Si only_idx es un entero ‚Üí explicamos solo esa instancia
        if only_idx is None:
            indices = range(len(self.X_test))
            desc_text = f"Cliente {self.client_id} explicando test completo"
            save_trees_flag = False      

        else:
            indices = [only_idx]
            desc_text = f"Cliente {self.client_id} explicando instancia {only_idx}"
            save_trees_flag = True


        for i in tqdm(indices, desc=desc_text):
            try:

                row = self._explain_one_instance(i, config, save_trees=save_trees_flag)
                results.append(row)

            except Exception as e:
                print(f"[Cliente {self.client_id}] ‚ö†Ô∏è Error en instancia {i}: {e}")
                continue


        full_path = f"results/metrics_Balanced_cliente_{self.client_id}.csv"
        df = pd.read_csv(full_path)

        mean_metrics = df.mean(numeric_only=True)
        count_metrics = df.count(numeric_only=True)  # no-NaN

        mean_df = pd.DataFrame({"mean": mean_metrics, "count": count_metrics})

        # ratio: % instancias donde se pudo calcular jaccard (no NaN)
        mean_df.loc["ratio_has_factual_globalZ", ["mean", "count"]] = [
            df["jaccard_cov_globalZ"].notna().mean(),
            int(df["jaccard_cov_globalZ"].notna().sum()),
        ]

        mean_df.loc["ratio_has_factual_localZ", ["mean", "count"]] = [
            df["jaccard_cov_localZ"].notna().mean(),
            int(df["jaccard_cov_localZ"].notna().sum()),
        ]

        mean_df.to_csv(
            f"results/metrics_cliente_{self.client_id}_balanced_mean.csv",
            index_label="metric"
        )

        return df


        

            


def client_fn(context: Context):
    partition_id = context.node_config["partition-id"]
    num_partitions = context.node_config["num-partitions"]

    dataset_name = DATASET_NAME 
    class_col = CLASS_COLUMN 


    (X_train, y_train,
     X_test_local, y_test_local,
     X_test_global, y_test_global,
     dataset, feature_names, label_encoder,
     scaler, numeric_features, encoder, preprocessor) = load_data_general(flower_dataset_name=dataset_name,class_col=class_col,partition_id=partition_id,num_partitions=NUM_CLIENTS)
    
    tree_model = DecisionTreeClassifier(max_depth=3, min_samples_split=2, random_state=42)

    num_idx = list(range(len(numeric_features)))

    scaler_nn = StandardScaler().fit(X_train[:, num_idx])

    # ‚úÖ SIEMPRE MISMO N√öMERO DE CLASES GLOBAL
    n_clases_global = len(UNIQUE_LABELS)  # o len(label_encoder.classes_)

    input_dim = X_train.shape[1]
    output_dim = n_clases_global

    nn_model = Net(input_dim, output_dim)
    return FlowerClient(tree_model=tree_model, 
                        nn_model=nn_model,
                        X_train=X_train,
                        y_train=y_train,
                        X_test=X_test_local,
                        y_test=y_test_local,
                        X_test_global=X_test_global,
                        y_test_global=y_test_global,
                        dataset=dataset,
                        client_id=partition_id + 1,
                        feature_names=feature_names,
                        label_encoder=label_encoder,
                        scaler=scaler,
                        numeric_features=numeric_features,
                        encoder=encoder,
                        preprocessor=preprocessor,         
                        scaler_nn_mean=scaler_nn.mean_,  
                        scaler_nn_scale=scaler_nn.scale_,
                        num_idx=num_idx).to_client()

client_app = ClientApp(client_fn=client_fn)


# Servidor

In [5]:
# ============================
# üì¶ IMPORTACIONES NECESARIAS
# ============================
import os
import time
import json
import numpy as np
from typing import List, Tuple, Dict
from sklearn.tree import DecisionTreeClassifier

from flwr.common import Context, Metrics, Scalar, ndarrays_to_parameters
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.strategy import FedAvg

from graphviz import Digraph
from lore_sa.surrogate.decision_tree import SuperTree

import torch
import torch.nn as nn
import torch.nn.functional as F


# ============================
# ‚öôÔ∏è CONFIGURACI√ìN GLOBAL
# ============================
# MIN_AVAILABLE_CLIENTS = 4
# NUM_SERVER_ROUNDS = 2

FEATURES = []  # se rellenan din√°micamente
UNIQUE_LABELS = []
LATEST_SUPERTREE_JSON = None
GLOBAL_MAPPING_JSON = None
FEATURE_NAMES_JSON = None
GLOBAL_SCALER_JSON = None


# ============================
# üß† UTILIDADES MODELO
# ============================
def create_model(input_dim, output_dim):
    from __main__ import Net  # necesario si Net est√° en misma libreta
    return Net(input_dim, output_dim)


def get_model_parameters(tree_model, nn_model):
    tree_params = [-1, 2, 1]
    nn_weights = [v.cpu().detach().numpy() for v in nn_model.state_dict().values()]
    return {
        "tree": tree_params,
        "nn": nn_weights,
    }

def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Dict[str, Scalar]:
    sums: Dict[str, float] = {}
    counts: Dict[str, int] = {}

    for n, met in metrics:
        for k, v in met.items():
            if isinstance(v, (float, int)):
                sums[k] = sums.get(k, 0.0) + n * float(v)
                counts[k] = counts.get(k, 0) + n

    return {k: sums[k] / counts[k] for k in sums}

# ============================
# üöÄ SERVIDOR FLOWER
# ============================

def server_fn(context: Context) -> ServerAppComponents:
    global FEATURES, UNIQUE_LABELS

    # Justo antes de llamar a create_model
    if not FEATURES or not UNIQUE_LABELS:
        
        load_data_general(DATASET_NAME, CLASS_COLUMN, partition_id=0, num_partitions=NUM_CLIENTS)


    FEATURES = FEATURES or ["feat_0", "feat_1"]  # fallback por si no se carg√≥ antes
    UNIQUE_LABELS = UNIQUE_LABELS or ["Class_0", "Class_1"]


    model = create_model(len(FEATURES), len(UNIQUE_LABELS))
    initial_params = ndarrays_to_parameters(get_model_parameters(None, model)["nn"])

    strategy = FedAvg(
        min_available_clients=MIN_AVAILABLE_CLIENTS,
        fit_metrics_aggregation_fn=weighted_average,
        evaluate_metrics_aggregation_fn=weighted_average,
        initial_parameters=initial_params,
    )

    strategy.configure_fit = _inject_round(strategy.configure_fit)
    strategy.configure_evaluate = _inject_round(strategy.configure_evaluate)
    original_aggregate = strategy.aggregate_evaluate

    def custom_aggregate_evaluate(server_round, results, failures):
        global LATEST_SUPERTREE_JSON, GLOBAL_MAPPING_JSON, FEATURE_NAMES_JSON
        aggregated_metrics = original_aggregate(server_round, results, failures)

        # ============================
        # üîπ Ronda final: NO fusionar nada
        # ============================
        if server_round > NUM_TRAIN_ROUNDS:
            return aggregated_metrics

        try:
            print(f"\n[SERVIDOR] üå≤ Generando SuperTree - Ronda {server_round}")
            from collections import defaultdict

            tree_nodes = []
            all_distincts = defaultdict(set)
            client_encoders = {}

            feature_names = None
            numeric_features = None
            class_names = None

            # 1) recolectar mapeos categ√≥ricos y metadatos
            for (_, evaluate_res) in results:
                metrics = evaluate_res.metrics
                # distinct_values_* para global_mapping
                for k, v in metrics.items():
                    if k.startswith("distinct_values_"):
                        cid = k.split("_")[-1]
                        enc = json.loads(v)
                        client_encoders[cid] = enc
                        for feat, d in enc.items():
                            all_distincts[feat].update(d["distinct_values"])

            global_mapping = {feat: sorted(list(vals)) for feat, vals in all_distincts.items()}

            # 2) recolectar √°rboles y dem√°s metadatos por cliente
            for (_, evaluate_res) in results:
                metrics = evaluate_res.metrics
                for k, v in metrics.items():
                    if k.startswith("tree_ensemble_"):
                        cid = k.split("_")[-1]
                        trees_list = json.loads(v)

                        # lee estos una sola vez (son iguales por cliente)
                        if feature_names is None and f"encoded_feature_names_{cid}" in metrics:
                            feature_names = json.loads(metrics[f"encoded_feature_names_{cid}"])
                        if numeric_features is None and f"numeric_features_{cid}" in metrics:
                            numeric_features = json.loads(metrics[f"numeric_features_{cid}"])
                        if class_names is None and f"unique_labels_{cid}" in metrics:
                            class_names = json.loads(metrics[f"unique_labels_{cid}"])

                        for tdict in trees_list:
                            root = SuperTree.Node.from_dict(tdict)
                            tree_nodes.append(root)

            if not tree_nodes:
                return aggregated_metrics

            # 3) fusionar
            st = SuperTree()
            st.mergeDecisionTrees(
                roots=tree_nodes,
                num_classes=len(class_names),
                feature_names=feature_names,
                categorical_features=list(global_mapping.keys()),
                global_mapping=global_mapping,
            )

            # print("\n[SERVIDOR] SuperTree unpruned:")
            # print(st)
            # print("\n")

            # print("\n[SERVIDOR] SuperTree prune_redundant_leaves_full:")
            st.prune_redundant_leaves_full()
            # print(st)
            # print("\n")


            # print("\n[SERVIDOR] SuperTree merge_equal_class_leaves:")
            # st.merge_equal_class_leaves()

            # print(supertree)
            # print("\n")
            
            # print("\n")

            # 4) guardar/emitir
            save_supertree_plot(
                root_node=st.root,
                round_number=server_round,
                feature_names=feature_names,
                class_names=class_names,
                numeric_features=numeric_features,
                global_mapping=global_mapping,   # sin scaler
            )

            LATEST_SUPERTREE_JSON = json.dumps(st.root.to_dict())
            GLOBAL_MAPPING_JSON = json.dumps(global_mapping)
            FEATURE_NAMES_JSON = json.dumps(feature_names)

        except Exception as e:
            print(f"[SERVIDOR] ‚ùå Error en SuperTree: {e}")

        return aggregated_metrics


    strategy.aggregate_evaluate = custom_aggregate_evaluate
    return ServerAppComponents(strategy=strategy, config=ServerConfig(num_rounds=NUM_SERVER_ROUNDS))

# ============================
# üß© FUNCIONES AUXILIARES
# ============================
def _inject_round(original_fn):
    def wrapper(server_round, parameters, client_manager):
        global LATEST_SUPERTREE_JSON, GLOBAL_MAPPING_JSON, FEATURE_NAMES_JSON
        instructions = original_fn(server_round, parameters, client_manager)
        for _, ins in instructions:
            ins.config["server_round"] = server_round

            # Siempre mandamos el √∫ltimo SuperTree disponible
            if LATEST_SUPERTREE_JSON:
                ins.config["supertree"] = LATEST_SUPERTREE_JSON
                ins.config["global_mapping"] = GLOBAL_MAPPING_JSON
                ins.config["feature_names"] = FEATURE_NAMES_JSON

            # Ronda final: modo solo explicaci√≥n
            if server_round == NUM_SERVER_ROUNDS:
                ins.config["explain_only"] = True
        return instructions
    return wrapper



def print_supertree_legible_fusionado(
    node,
    feature_names,
    class_names,
    numeric_features,
    scaler,  # dict con mean y std
    global_mapping,
    depth=0
):
    import numpy as np
    indent = "|   " * depth
    if node is None:
        print(f"{indent}[Nodo None]")
        return

    if getattr(node, "is_leaf", False):
        class_idx = int(np.argmax(node.labels))
        print(f"{indent}class: {class_names[class_idx]} (pred: {node.labels})")
        return

    feat_idx = node.feat
    feat_name = feature_names[feat_idx]
    intervals = node.intervals
    children = node.children

    # ====== NUM√âRICA ======
    if feat_name in numeric_features:
        bounds = [-np.inf] + list(intervals)
        while len(bounds) < len(children) + 1:
            bounds.append(np.inf)

        for i, child in enumerate(children):
            left = bounds[i]
            right = bounds[i + 1]
            left_real  = left
            right_real = right

            if i == 0:
                cond = f"{feat_name} ‚â§ {right_real:.2f}"
            elif i == len(children) - 1:
                cond = f"{feat_name} > {left_real:.2f}"
            else:
                cond = f"{feat_name} ‚àà ({left_real:.2f}, {right_real:.2f}]"
            print(f"{indent}{cond}")
            print_supertree_legible_fusionado(
                child, feature_names, class_names, numeric_features,
                scaler=None,  # ya no se usa
                global_mapping=global_mapping, depth=depth + 1
            )

    # ====== CATEG√ìRICA ONEHOT ======
    elif "=" in feat_name or "_" in feat_name:
        # Soporta 'var=valor' o 'var_valor'
        if "=" in feat_name:
            var, val = feat_name.split("=", 1)
        else:
            var, val = feat_name.split("_", 1)
        var = var.strip()
        val = val.strip()

        if len(children) != 2:
            print(f"[ERROR] Nodo OneHot {feat_name} tiene {len(children)} hijos, esperado 2.")

        # Primero !=, luego ==
        conds = [
            f'{var} != "{val}"',
            f'{var} == "{val}"'
        ]
        for i, child in enumerate(children):
            print(f"{indent}{conds[i]}")
            print_supertree_legible_fusionado(
                child, feature_names, class_names, numeric_features, scaler, global_mapping, depth + 1
            )

    # ====== CATEG√ìRICA ORDINAL ======
    elif global_mapping and feat_name in global_mapping:
        vals_cat = global_mapping[feat_name]
        # Primero !=, luego ==
        for i, child in enumerate(children):
            try:
                val_idx = node.intervals[i] if hasattr(node, "intervals") and i < len(node.intervals) else int(getattr(node, "thresh", 0))
                val = vals_cat[val_idx] if val_idx < len(vals_cat) else f"desconocido({val_idx})"
            except Exception as e:
                print(f"[DEPURACI√ìN] Error interpretando categ√≥rica: {e}")
                val = "?"
            cond = f'{feat_name} != "{val}"' if i == 0 else f'{feat_name} == "{val}"'
            print(f"{indent}{cond}")
            print_supertree_legible_fusionado(
                child, feature_names, class_names, numeric_features, scaler, global_mapping, depth + 1
            )

    # ====== TIPO DESCONOCIDO ======
    else:
        print(f"{indent}{feat_name} [tipo desconocido]")
        print(f"    [DEPURACI√ìN] Nombres de features: {feature_names}")
        print(f"    [DEPURACI√ìN] Nombres num√©ricas: {numeric_features}")
        print(f"    [DEPURACI√ìN] global_mapping: {list(global_mapping.keys()) if global_mapping else None}")
        print(f"    [DEPURACI√ìN] children: {len(children)}")
        for child in children:
            print_supertree_legible_fusionado(
                child, feature_names, class_names, numeric_features, scaler, global_mapping, depth + 1
            )



def save_supertree_plot(
    root_node,
    round_number,
    feature_names,
    class_names,
    numeric_features,
    global_mapping,
    folder="Supertree",
):
    from graphviz import Digraph
    import numpy as np
    import os

    dot = Digraph()
    node_id = [0]

    def add_node(node, parent=None, edge_label=""):
        curr = str(node_id[0]); node_id[0] += 1

        # etiqueta
        if node.is_leaf:
            class_index = int(np.argmax(node.labels))
            label = f"class: {class_names[class_index]}\n{node.labels}"
        else:
            fname = feature_names[node.feat]
            label = fname.split("_", 1)[0] if "_" in fname else fname

        dot.node(curr, label)
        if parent: dot.edge(parent, curr, label=edge_label)

        if not node.is_leaf:
            fname = feature_names[node.feat]
            # OneHot
            if "_" in fname:
                _, val = fname.split("_", 1)
                add_node(node.children[0], curr, f'‚â† "{val.strip()}"')
                add_node(node.children[1], curr, f'= "{val.strip()}"')
            # Num√©rica
            elif fname in numeric_features:
                thr = node.intervals[0] if node.intervals else node.thresh
                add_node(node.children[0], curr, f"‚â§ {thr:.2f}")
                add_node(node.children[1], curr, f"> {thr:.2f}")
            # Categ√≥rica ordinal
            elif fname in global_mapping:
                vals = global_mapping[fname]
                val = vals[node.intervals[0]] if node.intervals else "?"
                add_node(node.children[0], curr, f'= "{val}"')
                add_node(node.children[1], curr, f'‚â† "{val}"')
            else:
                for ch in node.children:
                    add_node(ch, curr, "?")

    folder_path = f"Ronda_{round_number}/{folder}"
    os.makedirs(folder_path, exist_ok=True)
    filename = f"{folder_path}/supertree_ronda_{round_number}"
    add_node(root_node)
    dot.render(filename, format="pdf", cleanup=True)
    return f"{filename}.pdf"




# ============================
# üîß INICIALIZAR SERVER APP
# ============================
server_app = ServerApp(server_fn=server_fn)



In [6]:
from flwr.simulation import run_simulation
import logging
import warnings
import ray
import cProfile
import pstats

warnings.filterwarnings("ignore", category=DeprecationWarning)


logging.getLogger('matplotlib').setLevel(logging.WARNING)
logging.getLogger("filelock").setLevel(logging.WARNING)
logging.getLogger("ray").setLevel(logging.WARNING)
logging.getLogger('graphviz').setLevel(logging.WARNING)
logging.getLogger().setLevel(logging.WARNING)  # O ERROR para ocultar a√∫n m√°s
logging.getLogger("urllib3").setLevel(logging.WARNING)
logging.getLogger("fsspec").setLevel(logging.WARNING)
# logging.getLogger("flwr").setLevel(logging.WARNING)




ray.shutdown()  # Apagar cualquier sesi√≥n previa de Ray
ray.init(local_mode=True)  # Desactiva multiprocessing, usa un solo proceso principal

backend_config = {"num_cpus": 1}

run_simulation(
    server_app=server_app,
    client_app=client_app,
    num_supernodes=NUM_CLIENTS,
    backend_config=backend_config,
)

2026-02-20 13:22:10,235	INFO worker.py:1832 -- Started a local Ray instance. View the dashboard at [1m[32mhttp://127.0.0.1:8265 [39m[22m
2026-02-20 13:22:15,325 flwr         DEBUG    Asyncio event loop already running.
:job_id:01000000
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor


:job_id:01000000
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor


[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=3, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Using initial global parameters provided by strategy
[92mINFO [0m:      Starting evaluation of initial global parameters
[92mINFO [0m:      Evaluation returned no results (`None`)
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: strategy sampled 6 clients (out of 6)
[92mINFO [0m:      aggregate_fit: received 6 results and 0 failures


[CLIENTE 3] ‚úÖ LOCAL baseline entrenado y guardado
[CLIENTE 5] ‚úÖ LOCAL baseline entrenado y guardado
[CLIENTE 6] ‚úÖ LOCAL baseline entrenado y guardado
[CLIENTE 1] ‚úÖ LOCAL baseline entrenado y guardado
[CLIENTE 4] ‚úÖ LOCAL baseline entrenado y guardado
[CLIENTE 2] ‚úÖ LOCAL baseline entrenado y guardado


[92mINFO [0m:      configure_evaluate: strategy sampled 6 clients (out of 6)
[92mINFO [0m:      aggregate_evaluate: received 6 results and 0 failures



[SERVIDOR] üå≤ Generando SuperTree - Ronda 1


[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 6 clients (out of 6)
[92mINFO [0m:      aggregate_fit: received 6 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 6 clients (out of 6)


Recibiendo supertree....
Recibiendo supertree....
Recibiendo supertree....
Recibiendo supertree....


[92mINFO [0m:      aggregate_evaluate: received 6 results and 0 failures


Recibiendo supertree....
Recibiendo supertree....

[SERVIDOR] üå≤ Generando SuperTree - Ronda 2


[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 6 clients (out of 6)
[92mINFO [0m:      aggregate_fit: received 6 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 6 clients (out of 6)


[CLIENTE 2] üì¶ LOCAL baseline recargado en evaluate()
Recibiendo supertree....
[CLIENTE 2] üîç Ronda final: solo explicaciones
[CLIENTE 4] üì¶ LOCAL baseline recargado en evaluate()
Recibiendo supertree....
[CLIENTE 4] üîç Ronda final: solo explicaciones


Cliente 2 explicando test completo:   0%|          | 0/32 [00:00<?, ?it/s]

[CLIENTE 5] üì¶ LOCAL baseline recargado en evaluate()
Recibiendo supertree....
[CLIENTE 5] üîç Ronda final: solo explicaciones



[A

[CLIENTE 3] üì¶ LOCAL baseline recargado en evaluate()
Recibiendo supertree....
[CLIENTE 3] üîç Ronda final: solo explicaciones




[A[A

[CLIENTE 6] üì¶ LOCAL baseline recargado en evaluate()
Recibiendo supertree....
[CLIENTE 6] üîç Ronda final: solo explicaciones





[A[A[A

[CLIENTE 1] üì¶ LOCAL baseline recargado en evaluate()
Recibiendo supertree....
[CLIENTE 1] üîç Ronda final: solo explicaciones






[A[A[A[A

[A[A


[A[A[A
Cliente 2 explicando test completo:   3%|‚ñé         | 1/32 [00:56<29:04, 56.28s/it]



[A[A[A[A


[A[A[A

[A[A
[A

[A[A


Cliente 2 explicando test completo:   6%|‚ñã         | 2/32 [01:53<28:20, 56.70s/it]
[A



[A[A[A[A

[A[A


[A[A[A
[A



Cliente 2 explicando test completo:   9%|‚ñâ         | 3/32 [02:52<28:02, 58.00s/it]

[A[A


[A[A[A
[A



[A[A[A[A

Cliente 2 explicando test completo:  12%|‚ñà‚ñé        | 4/32 [03:38<24:47, 53.11s/it]


[A[A[A
[A



[A[A[A[A

[A[A


Cliente 2 explicando test completo:  16%|‚ñà‚ñå        | 5/32 [04:34<24:21, 54.15s/it]
[A

[A[A



[A[A[A[A
[A


[A[A[A

Cliente 2 explicando test completo:  19%|‚ñà‚ñâ        | 6/32 [05:32<24:00, 55.39s/it]
[A



[A[A[A[A


[A[A[A

[A[A
Cliente 2 explicando test completo:  22%|‚ñà‚ñà‚ñè       | 7/32 [06:29<23:18, 55.94s/it]



[A[A[A[A


[A[A[A

[A[A
[A


[A[A[A

Cliente 2 explicando test completo:  25%

[CLIENTE 1] ‚ö†Ô∏è Sin factual para instancia 19






[A[A[A[A

[A[A


Cliente 2 explicando test completo:  59%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñâ    | 19/32 [16:55<11:26, 52.83s/it]

[A[A
[A



[A[A[A[A


[A[A[A

[A[A
Cliente 2 explicando test completo:  62%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñé   | 20/32 [17:48<10:35, 52.94s/it]



[A[A[A[A


[A[A[A

[A[A
Cliente 2 explicando test completo:  66%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå   | 21/32 [18:36<09:27, 51.59s/it]


[A[A[A



[A[A[A[A

Cliente 3 explicando test completo: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 31/31 [18:49<00:00, 36.45s/it]

[A


Cliente 4 explicando test completo: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [19:26<00:00, 40.21s/it]
Cliente 2 explicando test completo:  69%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñâ   | 22/32 [19:27<08:33, 51.35s/it]



[A[A[A[A
[A



[A[A[A[A


[A[A[A
Cliente 5 explicando test completo: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 32/32 [19:59<00:00, 37.50s/it]
Cliente 2 explicando test completo:  72%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè  | 23/32 [20:00<06:51, 45.72s/it]





### BALANCED METRICS

In [7]:
# ==========================================
# üìä Promedio global a partir de los *balanced_mean*
# - Lee results/metrics_cliente_*_balanced_mean.csv con columnas: metric, mean, count
# - Calcula media GLOBAL ponderada por count (no macro-media)
# - Colapsa m√©tricas por clase (TEST y Z) tambi√©n ponderando por count
# - Guarda en: experimentos_FeatureSkew/<DATASET>/<SKEW_FEATS>/{NCLIENTS}_Clients_Mean_global.csv
# ==========================================
from pathlib import Path
import pandas as pd
import re

# -------------------------------------------------
# ‚úÖ AJUSTA ESTO en tu notebook/entorno:
# - DATASET_NAME (ya lo tienes)
# - SKEW_FEATURE (str o lista de str, ej. ["Age","Sex"])
# -------------------------------------------------

def _slug_feat(name: str) -> str:
    """Normaliza nombres para usarlos como carpeta."""
    name = str(name).strip().replace(" ", "_")
    name = re.sub(r"[^A-Za-z0-9_\-]", "", name)
    return name

def _skew_tag(skew_feature) -> str:
    """Convierte skew_feature (str/list/tuple/set) en 'Age_Sex'."""
    if isinstance(skew_feature, (list, tuple, set)):
        parts = [_slug_feat(x) for x in skew_feature]
    else:
        parts = [_slug_feat(skew_feature)]
    parts = [p for p in parts if p]
    return "_".join(parts) if parts else "UnknownSkew"

def _weighted_mean(subdf: pd.DataFrame, mean_col: str = "mean", count_col: str = "count"):
    """Media ponderada ignorando NaNs."""
    m = subdf[mean_col]
    c = subdf[count_col]
    mask = m.notna() & c.notna() & (c > 0)
    if mask.sum() == 0:
        return float("nan")
    return float((m[mask] * c[mask]).sum() / c[mask].sum())

# -------------------------------------------------
# üì• Cargar ficheros balanced_mean
# -------------------------------------------------
csv_dir = Path("results")
files = sorted(csv_dir.glob("metrics_cliente_*_balanced_mean.csv"))

if not files:
    raise FileNotFoundError("No encuentro ficheros metrics_cliente_*_balanced_mean.csv en results/")

print("Voy a usar estos ficheros:")
for f in files:
    print("  -", f.name)

n_clients = len(files)
DATASET_NAME = DATASET_NAME.split("/")[-1]

# -------------------------------------------------
# üìÅ Carpeta destino
# -------------------------------------------------
out_dir = Path("experimentos_balanced") / DATASET_NAME
out_dir.mkdir(parents=True, exist_ok=True)

# -------------------------------------------------
# üìÑ Leer cada fichero y acumular en formato largo:
# metric | mean | count | client_file
# -------------------------------------------------
dfs_long = []
for f in files:
    df = pd.read_csv(f)

    # Normaliza nombres por si vinieran raros
    df = df.rename(columns={
        df.columns[0]: "metric",
        df.columns[1]: "mean",
    })
    if len(df.columns) >= 3:
        df = df.rename(columns={df.columns[2]: "count"})
    else:
        # Si no hay count, asumimos 1 (macro-media)
        df["count"] = 1.0

    df["client_file"] = f.stem
    dfs_long.append(df[["metric", "mean", "count", "client_file"]])

all_long = pd.concat(dfs_long, ignore_index=True)

# -------------------------------------------------
# ‚úÖ Media global ponderada por count, por m√©trica
# -------------------------------------------------
means_df = (
    all_long
    .groupby("metric", as_index=False)
    .apply(lambda g: pd.Series({
        "mean": _weighted_mean(g, "mean", "count"),
        "count": float(g["count"].dropna().sum())
    }))
    .reset_index(drop=True)
)

# ==========================================
# üìâ Colapsar clases (TEST y Z) (ponderado)
# - crea nuevas m√©tricas "colapsadas"
# - luego elimina las m√©tricas por-clase originales
# ==========================================
collapse_patterns = {
    # --------- TEST ----------
    "cf_cov_merged_TEST":   r"^cf_cov_merged_[^_]+_TEST$",
    "cf_comp_merged_TEST":  r"^cf_comp_merged_[^_]+_TEST$",
    "cf_prec_merged_TEST":  r"^cf_prec_merged_[^_]+_TEST$",

    "cf_cov_lore_TEST":     r"^cf_cov_lore_[^_]+_TEST$",
    "cf_comp_lore_TEST":    r"^cf_comp_lore_[^_]+_TEST$",
    "cf_prec_lore_TEST":    r"^cf_prec_lore_[^_]+_TEST$",

    "cf_cov_super_TEST":    r"^cf_cov_super_[^_]+_TEST$",
    "cf_comp_super_TEST":   r"^cf_comp_super_[^_]+_TEST$",
    "cf_prec_super_TEST":   r"^cf_prec_super_[^_]+_TEST$",

    "cf_cov_local_local_TEST":  r"^cf_cov_local_local_[^_]+_TEST$",
    "cf_comp_local_local_TEST": r"^cf_comp_local_local_[^_]+_TEST$",
    "cf_prec_local_local_TEST": r"^cf_prec_local_local_[^_]+_TEST$",

    "cf_cov_local_super_TEST":  r"^cf_cov_local_super_[^_]+_TEST$",
    "cf_comp_local_super_TEST": r"^cf_comp_local_super_[^_]+_TEST$",
    "cf_prec_local_super_TEST": r"^cf_prec_local_super_[^_]+_TEST$",

    "cf_cov_localZ_TEST":   r"^cf_cov_localZ_[^_]+_TEST$",
    "cf_comp_localZ_TEST":  r"^cf_comp_localZ_[^_]+_TEST$",
    "cf_prec_localZ_TEST":  r"^cf_prec_localZ_[^_]+_TEST$",

    "cf_cov_superZ_TEST":   r"^cf_cov_superZ_[^_]+_TEST$",
    "cf_comp_superZ_TEST":  r"^cf_comp_superZ_[^_]+_TEST$",
    "cf_prec_superZ_TEST":  r"^cf_prec_superZ_[^_]+_TEST$",

    # --------- Z ----------
    "cf_cov_merged_Z":   r"^cf_cov_merged_[^_]+_Z$",
    "cf_prec_merged_Z":  r"^cf_prec_merged_[^_]+_Z$",

    "cf_cov_lore_Z":     r"^cf_cov_lore_[^_]+_Z$",
    "cf_prec_lore_Z":    r"^cf_prec_lore_[^_]+_Z$",

    "cf_cov_super_Z":    r"^cf_cov_super_[^_]+_Z$",
    "cf_prec_super_Z":   r"^cf_prec_super_[^_]+_Z$",

    "cf_cov_local_local_Z":  r"^cf_cov_local_local_[^_]+_Z$",
    "cf_prec_local_local_Z": r"^cf_prec_local_local_[^_]+_Z$",

    "cf_cov_local_super_Z":  r"^cf_cov_local_super_[^_]+_Z$",
    "cf_prec_local_super_Z": r"^cf_prec_local_super_[^_]+_Z$",

    "cf_cov_localZ_Z":   r"^cf_cov_localZ_[^_]+_Z$",
    "cf_prec_localZ_Z":  r"^cf_prec_localZ_[^_]+_Z$",

    "cf_cov_superZ_Z":   r"^cf_cov_superZ_[^_]+_Z$",
    "cf_prec_superZ_Z":  r"^cf_prec_superZ_[^_]+_Z$",
}

rows_new = []
for new_name, pattern in collapse_patterns.items():
    sub = means_df[means_df["metric"].str.match(pattern)]
    if len(sub) == 0:
        continue
    rows_new.append({
        "metric": new_name,
        "mean": _weighted_mean(sub, "mean", "count"),
        "count": float(sub["count"].dropna().sum())
    })

if rows_new:
    means_df = pd.concat([means_df, pd.DataFrame(rows_new)], ignore_index=True)

# Elimina las m√©tricas por-clase originales (las que acabas de colapsar)
pattern_drop = (
    r"^cf_(cov|prec|comp)_"
    r"(merged|lore|super|local_local|local_super|localZ|superZ)_"
    r"[^_]+_(TEST|Z)$"
)
means_df = means_df[~means_df["metric"].str.match(pattern_drop)].reset_index(drop=True)

# -------------------------------------------------
# üíæ Guardar resultado final (solo metric,mean)
# -------------------------------------------------
out_path = out_dir / f"{n_clients}_Clients_Mean_global.csv"
means_df[["metric", "mean"]].to_csv(out_path, index=False, encoding="utf-8")
print(f"\n‚úÖ Promedios globales (ponderados por count) guardados en: {out_path}")


Voy a usar estos ficheros:
  - metrics_cliente_1_balanced_mean.csv
  - metrics_cliente_2_balanced_mean.csv
  - metrics_cliente_3_balanced_mean.csv
  - metrics_cliente_4_balanced_mean.csv
  - metrics_cliente_5_balanced_mean.csv
  - metrics_cliente_6_balanced_mean.csv

‚úÖ Promedios globales (ponderados por count) guardados en: experimentos_balanced\Diabetes\6_Clients_Mean_global.csv
