In [1]:
# =======================
# 📦 IMPORTACIONES
# =======================

# Built-in
import os
import sys
import re
import time
import json
import random
import warnings
from typing import List, Tuple, Dict
import operator
import re
from pathlib import Path
import numpy as np
import pandas as pd


# NumPy, Pandas, Matplotlib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Sklearn
from sklearn.tree import DecisionTreeClassifier, plot_tree, export_text
from sklearn.preprocessing import LabelEncoder, StandardScaler, OrdinalEncoder, OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.metrics import (
    log_loss, accuracy_score, precision_score, recall_score,
    f1_score, confusion_matrix, roc_auc_score, pairwise_distances
)
from sklearn.exceptions import NotFittedError
from collections import defaultdict
from sklearn.metrics import classification_report, confusion_matrix


# Flower
from flwr.client import ClientApp, NumPyClient
from flwr.common import (
    Context, NDArrays, Metrics, Scalar,
    ndarrays_to_parameters, parameters_to_ndarrays
)
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.strategy import FedAvg
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F

# LORE
from lore_sa.dataset import TabularDataset
from lore_sa.bbox import sklearn_classifier_bbox
from lore_sa.encoder_decoder import ColumnTransformerEnc
from lore_sa.lore import TabularGeneticGeneratorLore
from lore_sa.surrogate.decision_tree import SuperTree
from lore_sa.rule import Expression, Rule

from lore_sa.client_utils import ClientUtilsMixin

# Otros
from pathlib import Path
from filelock import FileLock  # pip install filelock
import pandas as pd, os
from graphviz import Digraph
from tqdm import tqdm
from datetime import datetime
import cProfile, pstats, io
from flwr_datasets.partitioner import IidPartitioner, DirichletPartitioner
from lore_sa.client_utils import LabelShardPartitioner



2025-12-11 13:48:32,576	INFO util.py:154 -- Outdated packages:
  ipywidgets==7.8.1 found, needs ipywidgets>=8
Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
2025-12-11 13:48:36,214 graphviz._tools DEBUG    deprecate positional args: graphviz.backend.piping.pipe(['renderer', 'formatter', 'neato_no_op', 'quiet'])
2025-12-11 13:48:36,214 graphviz._tools DEBUG    deprecate positional args: graphviz.backend.rendering.render(['renderer', 'formatter', 'neato_no_op', 'quiet'])
2025-12-11 13:48:36,214 graphviz._tools DEBUG    deprecate positional args: graphviz.backend.unflattening.unflatten(['stagger', 'fanout', 'chain', 'encoding'])
2025-12-11 13:48:36,214 graphviz._tools DEBUG    deprecate positional args: graphviz.backend.viewing.view(['quiet'])
2025-12-11 13:48:36,231 graphviz._tools DEBUG    deprecate positional args: graphviz.quoting.quote(['is_html_string', 'is_valid_id', 'dot_keywords', 'endswith_odd_number_of_backslashes', 'escape_unescaped

In [2]:
# =======================
# ⚙️ VARIABLES GLOBALES
# =======================
# Global class labels; filled in place by load_data_general on its first call.
UNIQUE_LABELS = []
# Encoded feature names; overwritten in place by every load_data_general call.
FEATURES = []

NUM_TRAIN_ROUNDS = 2        # rounds in which the clients train the NN (and the local tree)
NUM_SERVER_ROUNDS = 3       # the last round is only used for explanations
NUM_CLIENTS = 6
SEED = 42

NON_IID = True   # set to False for the IID experiments
NON_IID_ALPHA = 0.3  # e.g. a more skewed Dirichlet; NOTE(review): not referenced in the visible code

MIN_AVAILABLE_CLIENTS = NUM_CLIENTS
fds = None  # cache of the partitioned FederatedDataset (created lazily in load_data_general)
CAT_ENCODINGS = {}
USING_DATASET = None


# ==============================================
# 🧹 Borrar TODOS los CSV individuales de clientes
# ==============================================

# Remove every per-client CSV left over from a previous run so that results
# from different executions are never mixed.
csv_dir = Path("results")
all_csvs = list(csv_dir.glob("*.csv"))

# Iterating an empty list is a no-op, so no emptiness guard is needed.
for f in all_csvs:
    try:
        f.unlink()
    except OSError as exc:
        # Best effort: keep going, but make the failure visible instead of
        # silently leaving a stale file behind.
        warnings.warn(f"Could not delete stale results file {f}: {exc}")




# =======================
# 🔧 UTILIDADES MODELO
# =======================

def get_model_parameters(tree_model, nn_model):
    """Collect the exchangeable parameters of both local models.

    Args:
        tree_model: object exposing ``get_params()`` with the keys
            ``max_depth``, ``min_samples_split`` and ``min_samples_leaf``.
        nn_model: object exposing ``state_dict()`` whose values are tensors.

    Returns:
        dict with two entries:
        ``"tree"`` - ``[max_depth, min_samples_split, min_samples_leaf]`` as
        ints, where a falsy ``max_depth`` (e.g. ``None``) is encoded as ``-1``;
        ``"nn"`` - list of NumPy arrays, one per ``state_dict`` entry, in
        ``state_dict`` order.
    """
    hyperparams = tree_model.get_params()

    # A falsy depth (None) cannot go through int(), so it is sent as -1.
    depth = int(hyperparams["max_depth"] or -1)
    split = int(hyperparams["min_samples_split"])
    leaf = int(hyperparams["min_samples_leaf"])

    weights = []
    for tensor in nn_model.state_dict().values():
        weights.append(tensor.cpu().detach().numpy())

    return {"tree": [depth, split, leaf], "nn": weights}


def set_model_params(tree_model, nn_model, params):
    """Apply tree hyper-parameters and neural-network weights to the local models.

    Args:
        tree_model: estimator exposing ``set_params`` (skipped when ``None`` or
            when it has no ``set_params`` attribute).
        nn_model: ``torch.nn.Module`` whose ``state_dict`` is updated.
        params: dict with
            ``"tree"`` - ``[max_depth, min_samples_split, min_samples_leaf]``;
            ``max_depth`` may be ``None`` or a non-positive sentinel (``-1``),
            both meaning "unlimited depth";
            ``"nn"`` - sequence of arrays matched positionally against the
            entries of ``nn_model.state_dict()``.
    """
    tree_params = params["tree"]
    nn_weights = params["nn"]

    # Only touch the tree when there is one and it supports set_params.
    if tree_model is not None and hasattr(tree_model, "set_params"):
        raw_depth = tree_params[0]
        # Callers forward get_params()["max_depth"] directly, which is None for
        # an unlimited tree; comparing None > 0 would raise TypeError.
        if raw_depth is not None and raw_depth > 0:
            max_depth = raw_depth
        else:
            max_depth = None
        tree_model.set_params(
            max_depth=max_depth,
            min_samples_split=tree_params[1],
            min_samples_leaf=tree_params[2],
        )

    # Update the network weights; entries are paired by position and any
    # state_dict entry without a matching array keeps its current value.
    state_dict = nn_model.state_dict()
    for key, val in zip(list(state_dict.keys()), nn_weights):
        state_dict[key] = torch.tensor(val)
    nn_model.load_state_dict(state_dict)


# =======================
# 📥 CARGAR DATOS
# =======================

def get_global_onehot_info(flower_dataset_name, class_col):
    """Compute the one-hot layout of the categorical features over the whole train split.

    The full ``train`` split is loaded as a single partition, preprocessed
    with the same dataset-specific rules used by ``load_data_general``, and a
    ``OneHotEncoder`` is fitted on its categorical columns so that every
    client can share one category list per feature.

    Args:
        flower_dataset_name: dataset identifier passed to ``FederatedDataset``.
        class_col: name of the target column (excluded from the features).

    Returns:
        Tuple ``(cat_features, num_features, categories_global, onehot_columns)``:
        categorical feature names, numeric feature names (int/float dtype),
        the fitted encoder's ``categories_`` (one array per categorical
        feature, in ``cat_features`` order) and the resulting one-hot column
        names as a list.
    """
    partitioner = IidPartitioner(num_partitions=1)
    fds_tmp = FederatedDataset(dataset=flower_dataset_name, partitioners={"train": partitioner})
    df = fds_tmp.load_partition(0, "train").with_format("pandas")[:]

    dataset_key = flower_dataset_name.lower()

    # Dataset-specific preprocessing. The "adult" test deliberately matches the
    # one in load_data_general (it also covers "adult_small"), so both
    # functions drop the same columns for every adult variant.
    if "adult" in dataset_key:
        drop_cols = ['fnlwgt', 'education-num', 'capital-gain', 'capital-loss']
        df.drop(columns=[col for col in drop_cols if col in df.columns], inplace=True)

    elif "churn" in dataset_key:
        drop_cols = ['customerID', 'TotalCharges']
        df.drop(columns=[col for col in drop_cols if col in df.columns], inplace=True)
        df['MonthlyCharges'] = pd.to_numeric(df['MonthlyCharges'], errors='coerce')
        df['tenure'] = pd.to_numeric(df['tenure'], errors='coerce')
        # Treat the 0/1 flag as a categorical Yes/No feature.
        df['SeniorCitizen'] = df['SeniorCitizen'].map({0: 'No', 1: 'Yes'}).astype(str)
        df.dropna(subset=['MonthlyCharges', 'tenure'], inplace=True)

    elif "breastcancer" in dataset_key:
        # The row identifier carries no predictive information.
        df.drop(columns=['id'], inplace=True, errors='ignore')

    # Low-cardinality text columns become categoricals.
    for col in df.select_dtypes(include=["object"]).columns:
        if df[col].nunique() < 50:
            df[col] = df[col].astype("category")

    cat_features = [col for col in df.select_dtypes(include="category").columns if col != class_col]
    num_features = [col for col in df.columns if df[col].dtype.kind in "fi" and col != class_col]

    ohe = OneHotEncoder(handle_unknown="ignore")
    ohe.fit(df[cat_features])
    categories_global = ohe.categories_
    onehot_columns = ohe.get_feature_names_out(cat_features).tolist()
    return cat_features, num_features, categories_global, onehot_columns



def load_data_general(flower_dataset_name: str, class_col: str, partition_id: int, num_partitions: int):
    """Load one partition of a federated dataset and encode it for the local models.

    Side effects on module globals: fills ``UNIQUE_LABELS`` on the first call,
    creates and caches the partitioned ``FederatedDataset`` in ``fds`` on the
    first call, and overwrites ``FEATURES`` in place on every call.

    Args:
        flower_dataset_name: dataset identifier passed to ``FederatedDataset``.
        class_col: name of the target column in the raw data.
        partition_id: index of the partition to load.
        num_partitions: number of partitions (only used when ``fds`` is created).

    Returns:
        An 11-tuple:
        ``(X_train, y_train, X_test, y_test, tabular_dataset, final_columns,
        label_encoder, num_transformer, numeric_features, encoder, preprocessor)``.
        The train/test arrays are the first 70% / last 30% of the partition
        rows in their existing order (no shuffling happens here).
    """
    global fds, UNIQUE_LABELS, FEATURES

    # Always fetch the global categorical layout first
    cat_features, num_features, categories_global, onehot_columns = get_global_onehot_info(flower_dataset_name, class_col)

    # 2) Define the global class list ONLY ONCE, using the WHOLE train split
    if not UNIQUE_LABELS:
        partitioner_all = IidPartitioner(num_partitions=1)
        fds_all = FederatedDataset(dataset=flower_dataset_name, partitioners={"train": partitioner_all})
        df_all = fds_all.load_partition(0, "train").with_format("pandas")[:]

        le_global = LabelEncoder()
        le_global.fit(df_all[class_col])
        UNIQUE_LABELS[:] = le_global.classes_.tolist()

    # Build the partitioned dataset once and reuse it across calls
    if fds is None:
        if NON_IID:
            partitioner = LabelShardPartitioner(
                num_partitions=num_partitions,
                partition_by=class_col,
                n_classes_per_client=1,  # or 2 for a less extreme split
                shards_per_class=3,        # try 2, 4, etc.
                shuffle=True,
                seed=SEED,
            )
        else:
            partitioner = IidPartitioner(num_partitions=num_partitions)

        fds = FederatedDataset(
            dataset=flower_dataset_name,
            partitioners={"train": partitioner},
        )

    dataset = fds.load_partition(partition_id, "train").with_format("pandas")[:]


    # Dataset-specific preprocessing
    if "adult" in flower_dataset_name.lower():
        drop_cols = ['fnlwgt', 'education-num', 'capital-gain', 'capital-loss']
        dataset.drop(columns=[col for col in drop_cols if col in dataset.columns], inplace=True)

    elif "churn" in flower_dataset_name.lower():
        drop_cols = ['customerID', 'TotalCharges']
        dataset.drop(columns=[col for col in drop_cols if col in dataset.columns], inplace=True)
        dataset['MonthlyCharges'] = pd.to_numeric(dataset['MonthlyCharges'], errors='coerce')
        dataset['tenure'] = pd.to_numeric(dataset['tenure'], errors='coerce')
        dataset['SeniorCitizen'] = dataset['SeniorCitizen'].map({0: 'No', 1: 'Yes'}).astype(str)

        dataset.dropna(subset=['MonthlyCharges', 'tenure'], inplace=True)

    elif "breastcancer" in flower_dataset_name.lower():
        # Specific preprocessing for the breast-cancer dataset
        dataset.drop(columns=['id'], inplace=True, errors='ignore')

    # Low-cardinality text columns become categoricals
    for col in dataset.select_dtypes(include=["object"]).columns:
        if dataset[col].nunique() < 50:
            dataset[col] = dataset[col].astype("category")

    # NOTE(review): class_original is never read again in this function.
    class_original = dataset[class_col].copy()
    tabular_dataset = TabularDataset(dataset.copy(), class_name=class_col)
    descriptor = tabular_dataset.descriptor

    # Make sure every categorical entry of the descriptor lists its values
    for col, info in descriptor["categorical"].items():
        if "distinct_values" not in info or not info["distinct_values"]:
            info["distinct_values"] = list(dataset[col].dropna().unique())

    # 4) HERE: NEVER fit per partition, only use the global classes
    label_encoder = LabelEncoder()
    label_encoder.classes_ = np.array(UNIQUE_LABELS)
    dataset[class_col] = label_encoder.transform(dataset[class_col])
    dataset.rename(columns={class_col: "class"}, inplace=True)
    y = dataset["class"].reset_index(drop=True).to_numpy()

    numeric_features = list(descriptor["numeric"].keys())
    categorical_features = list(descriptor["categorical"].keys())
    FEATURES[:] = numeric_features + categorical_features

    # Column positions inside X_array: numeric block first, then categorical
    numeric_indices = list(range(len(numeric_features)))
    categorical_indices = list(range(len(numeric_features), len(FEATURES)))

    X_array = dataset[FEATURES].to_numpy()

    # NOTE(review): categories_global is ordered like cat_features (from
    # get_global_onehot_info) while the columns here follow the descriptor's
    # key order — presumably the same order; confirm.
    preprocessor = ColumnTransformer([
        ("num", "passthrough", numeric_indices),
        ("cat", OneHotEncoder(sparse_output=False, handle_unknown="ignore", categories=categories_global), categorical_indices)
    ])
    X_encoded = preprocessor.fit_transform(X_array)

    # Rebuild the DataFrame
    num_out = X_encoded[:, :len(numeric_features)]
    cat_out = X_encoded[:, len(numeric_features):]
    if categorical_features:
        cat_names = preprocessor.named_transformers_["cat"].get_feature_names_out(categorical_features)
    else:
        cat_names = []

    num_names = numeric_features

    X_df = pd.DataFrame(num_out, columns=num_names)
    if len(cat_names) > 0:
        X_cat_df = pd.DataFrame(cat_out, columns=cat_names)
        X_full = pd.concat([X_df.reset_index(drop=True), X_cat_df.reset_index(drop=True)], axis=1)
        # Add any globally known one-hot column that this partition lacks
        for col in onehot_columns:
            if col not in X_cat_df.columns:
                X_full[col] = 0
    else:
        X_full = X_df

    # Select and order the output columns.
    # NOTE(review): final_columns only holds num_names + cat_names, so any
    # column added by the loop above is dropped again here.
    final_columns = num_names + list(cat_names)
    X_full = X_full[final_columns]
    FEATURES[:] = final_columns

    # Positional 70/30 split of the partition rows
    split_idx = int(0.7 * len(X_full))

    # --- Build the global descriptor ---
    # NOTE(review): if descriptor is a plain dict, .copy() is shallow, so the
    # assignment below would also change tabular_dataset.descriptor — confirm
    # this is intended.
    descriptor_global = descriptor.copy()
    for i, col in enumerate(cat_features):
        if col in descriptor_global["categorical"]:
            descriptor_global["categorical"][col]["distinct_values"] = list(categories_global[i])

    encoder = ColumnTransformerEnc(descriptor_global)



    # NOTE(review): the "num" transformer is configured as "passthrough"
    # above, so the eighth returned element is not a fitted scaler.
    return (
        X_full.iloc[:split_idx].to_numpy(), y[:split_idx],
        X_full.iloc[split_idx:].to_numpy(), y[split_idx:],
        tabular_dataset, final_columns, label_encoder,
        preprocessor.named_transformers_["num"], numeric_features, encoder, preprocessor
    )


# =======================


# Dataset selection: exactly one DATASET_NAME / CLASS_COLUMN pair is active;
# the commented pairs are the alternatives.

# The metric results are not very good on this one
# DATASET_NAME = "pablopalacios23/adult"
# CLASS_COLUMN = "class"



# DATASET_NAME = "pablopalacios23/churn"
# CLASS_COLUMN = "Churn" 



# DATASET_NAME = "pablopalacios23/HeartDisease"
# CLASS_COLUMN = "HeartDisease" 



# Active dataset
DATASET_NAME = "pablopalacios23/breastcancer"
CLASS_COLUMN = "diagnosis" 



# DATASET_NAME = "pablopalacios23/Diabetes"
# CLASS_COLUMN = "Outcome" 


 
# =======================


# load_data_general(DATASET_NAME, CLASS_COLUMN, partition_id=0, num_partitions=NUM_CLIENTS)

### HOLDOUT DEL SERVIDOR

In [3]:
# Load partition 0 and keep every returned piece as a module-level name.
(
    X_train, y_train, X_test, y_test,
    dataset, feature_names, label_encoder, scaler,
    numeric_features, encoder, preprocessor,
) = load_data_general(
    DATASET_NAME,
    CLASS_COLUMN,
    partition_id=0,
    num_partitions=NUM_CLIENTS,
)

# Print each split under its heading; feature matrices are wrapped in a
# DataFrame so pandas abbreviates long output.
for _heading, _values in (
    ("\nüì¶ X_train (primeras filas):", pd.DataFrame(X_train)),
    ("\nüéØ y_train (primeros valores):", y_train),
    ("\nüì¶ X_test (primeras filas):", pd.DataFrame(X_test)),
    ("\nüéØ y_test (primeros valores):", y_test),
):
    print(_heading)
    print(_values)

print(feature_names)


2025-12-11 13:48:36,303 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): huggingface.co:443
2025-12-11 13:48:36,549 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /datasets/pablopalacios23/breastcancer/resolve/main/README.md HTTP/11" 404 0
2025-12-11 13:48:36,867 urllib3.connectionpool DEBUG    https://huggingface.co:443 "GET /api/datasets/pablopalacios23/breastcancer HTTP/11" 200 565
2025-12-11 13:48:36,995 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /datasets/pablopalacios23/breastcancer/resolve/d21fb27c44731c56662f52e0f762dcc070083b0e/breastcancer.py HTTP/11" 404 0
2025-12-11 13:48:36,998 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): s3.amazonaws.com:443
2025-12-11 13:48:37,333 urllib3.connectionpool DEBUG    https://s3.amazonaws.com:443 "HEAD /datasets.huggingface.co/datasets/datasets/pablopalacios23/breastcancer/pablopalacios23/breastcancer.py HTTP/11" 404 0
2025-12-11 13:48:37,469 urllib3.connectionpool


üì¶ X_train (primeras filas):
        0      1      2      3        4        5         6         7       8   \
0   12.460  12.83  78.83  477.3  0.07372  0.04043  0.007173  0.011490  0.1613   
1   10.910  12.35  69.14  363.7  0.08518  0.04721  0.012360  0.013690  0.1449   
2   12.100  17.72  78.07  446.2  0.10290  0.09758  0.047830  0.033260  0.1937   
3   11.710  17.19  74.68  420.3  0.09774  0.06141  0.038090  0.032390  0.1516   
4   11.360  17.57  72.49  399.8  0.08858  0.05313  0.027830  0.021000  0.1601   
..     ...    ...    ...    ...      ...      ...       ...       ...     ...   
78  11.460  18.16  73.59  403.1  0.08853  0.07694  0.033440  0.015020  0.1411   
79   9.876  19.40  63.95  298.3  0.10050  0.09697  0.061540  0.030290  0.1945   
80  12.340  14.95  78.29  469.1  0.08682  0.04571  0.021090  0.020540  0.1571   
81  13.200  15.82  84.07  537.3  0.08511  0.05251  0.001461  0.003261  0.1632   
82  13.640  15.60  87.38  575.3  0.09423  0.06630  0.047050  0.037310  0.1717

# Cliente

In [4]:
# ==========================
# 🌼 CLIENTE FLOWER
# ==========================
import operator
import warnings
import os
import json
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import (
    log_loss, accuracy_score, precision_score, recall_score,
    f1_score, roc_auc_score
)
from sklearn.exceptions import NotFittedError

import torch
import torch.nn as nn
import torch.nn.functional as F

from flwr.client import NumPyClient
from flwr.common import Context
from flwr.common import parameters_to_ndarrays

from lore_sa.dataset import TabularDataset
from lore_sa.bbox import sklearn_classifier_bbox
from lore_sa.lore import TabularGeneticGeneratorLore
from lore_sa.rule import Expression, Rule
from lore_sa.surrogate.decision_tree import SuperTree
from lore_sa.encoder_decoder import ColumnTransformerEnc

from sklearn.metrics import pairwise_distances


class TorchNNWrapper:
    """Expose a PyTorch classifier through ``predict`` / ``predict_proba``.

    Inputs are given in raw feature space. Before every forward pass the
    columns listed in ``num_idx`` are standardised as ``(x - mean) / scale``;
    all other columns are passed through unchanged.
    """

    def __init__(self, model, num_idx, mean, scale):
        """Store the model (switched to eval mode) and the scaling statistics.

        Args:
            model: callable torch module returning one row of logits per input row.
            num_idx: positions of the columns to standardise.
            mean: per-column means, aligned with ``num_idx``.
            scale: per-column scales, aligned with ``num_idx``; zeros are
                replaced by 1.0 in ``scale_safe`` to avoid division by zero.
        """
        model.eval()
        self.model = model
        self.num_idx = np.asarray(num_idx, dtype=int)
        self.mean = np.asarray(mean, dtype=np.float32)
        self.scale = np.asarray(scale, dtype=np.float32)
        self.scale_safe = np.where(self.scale == 0, 1.0, self.scale)

    def _scale_internally(self, X):
        """Return a standardised float32 copy of ``X`` (accepts ``[n, d]`` or ``[d]``)."""
        scaled = np.asarray(X, dtype=np.float32).copy()
        if scaled.ndim == 1:
            # promote a single instance to a batch of one
            scaled = scaled[None, :]
        cols = self.num_idx
        scaled[:, cols] = (scaled[:, cols] - self.mean) / self.scale_safe
        return scaled

    def _forward_logits(self, X):
        """Scale ``X`` and run the model without tracking gradients."""
        with torch.no_grad():
            batch = torch.tensor(self._scale_internally(X), dtype=torch.float32)
            return self.model(batch)

    def predict(self, X):
        """Return the index of the highest logit for every row, as a NumPy array."""
        return self._forward_logits(X).argmax(dim=1).cpu().numpy()

    def predict_proba(self, X):
        """Return the softmax class probabilities for every row, as a NumPy array."""
        return F.softmax(self._forward_logits(X), dim=1).cpu().numpy()

class Net(nn.Module):
    """Three-layer fully connected classifier that outputs raw logits.

    Layer widths are ``input_dim -> H -> H // 2 -> output_dim`` with
    ``H = max(8, 2 * input_dim)``; ReLU follows the first two layers.
    """

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        # Width grows with the number of inputs but never drops below 8.
        width = max(8, 2 * input_dim)
        half_width = width // 2

        self.fc1 = nn.Linear(input_dim, width)
        self.fc2 = nn.Linear(width, half_width)
        self.fc3 = nn.Linear(half_width, output_dim)

    def forward(self, x):
        """Return the logits for the batch ``x`` (no softmax applied)."""
        for hidden_layer in (self.fc1, self.fc2):
            x = F.relu(hidden_layer(x))
        return self.fc3(x)
    
        

class FlowerClient(NumPyClient, ClientUtilsMixin):
    def __init__(self, tree_model, nn_model, X_train, y_train, X_test, y_test, X_train_nn, X_test_nn, scaler_nn_mean, scaler_nn_scale, num_idx, dataset, client_id, feature_names, label_encoder, scaler, numeric_features, encoder, preprocessor):
        """Store the models, data splits and preprocessing helpers of one client.

        ``scaler_nn_mean`` / ``scaler_nn_scale`` are kept as float32 arrays
        (zero scales replaced by 1.0), ``num_idx`` as an int array, and the
        labels are additionally stored as int64 copies for the neural network.
        """
        # Identity
        self.client_id = client_id

        # Models
        self.tree_model = tree_model
        self.nn_model = nn_model

        # Data in raw feature space, used by the decision tree
        self.X_train, self.y_train = X_train, y_train
        self.X_test, self.y_test = X_test, y_test

        # Data for the neural network plus int64 label copies
        self.X_train_nn, self.X_test_nn = X_train_nn, X_test_nn
        self.y_train_nn = y_train.astype(np.int64)
        self.y_test_nn = y_test.astype(np.int64)

        # Scaling statistics of the network inputs; a zero scale becomes 1.0
        raw_scale = np.asarray(scaler_nn_scale, np.float32)
        self.scaler_nn_mean = np.asarray(scaler_nn_mean, dtype=np.float32)
        self.scaler_nn_scale = np.where(raw_scale == 0, 1.0, raw_scale)
        self.num_idx = np.asarray(num_idx, dtype=int)

        # Dataset metadata and encoders
        self.dataset = dataset
        self.feature_names = feature_names
        self.numeric_features = numeric_features
        self.label_encoder = label_encoder
        self.unique_labels = label_encoder.classes_.tolist()
        self.scaler = scaler
        self.encoder = encoder
        self.preprocessor = preprocessor

        # No SuperTree has been received from the server yet
        self.received_supertree = None

    def _train_nn(self, epochs=10, lr=1e-3):
        """Train ``self.nn_model`` on the client's NN training data.

        Runs ``epochs`` full-batch Adam steps (learning rate ``lr``) that
        minimise the cross-entropy between the logits for ``self.X_train_nn``
        and the labels ``self.y_train_nn``. A fresh optimizer is created on
        every call.
        """
        net = self.nn_model
        net.train()

        optimizer = torch.optim.Adam(net.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()

        features = torch.tensor(self.X_train_nn, dtype=torch.float32)
        targets = torch.tensor(self.y_train_nn, dtype=torch.long)

        for _epoch in range(epochs):
            # one gradient step over the whole training set
            optimizer.zero_grad()
            criterion(net(features), targets).backward()
            optimizer.step()

        print(f"[CLIENTE {self.client_id}] ‚úÖ Red neuronal entrenada")



    def fit(self, parameters, config):
        """Load the global NN weights and, in training rounds, train both local models.

        Args:
            parameters: list of arrays with the global neural-network weights.
            config: round configuration; ``server_round`` (default 1) decides
                whether this round trains (``<= NUM_TRAIN_ROUNDS``).

        Returns:
            ``(nn_weights, num_train_examples, {})`` where ``nn_weights`` are
            the local network weights after this round.
        """
        tree_hparams = self.tree_model.get_params()
        set_model_params(self.tree_model, self.nn_model, {
            "tree": [
                # max_depth is None for an unlimited tree; send the same -1
                # sentinel that get_model_parameters uses so the receiver
                # never compares None with an int.
                int(tree_hparams["max_depth"] or -1),
                tree_hparams["min_samples_split"],
                tree_hparams["min_samples_leaf"],
            ],
            "nn": parameters,
        })

        round_number = config.get("server_round", 1)
        if round_number <= NUM_TRAIN_ROUNDS:
            # The warning filter only applies inside the context manager, so
            # the training calls must run within it to be silenced.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                self.tree_model.fit(self.X_train, self.y_train)
                self._train_nn()

        nn_weights = get_model_parameters(self.tree_model, self.nn_model)["nn"]
        return nn_weights, len(self.X_train), {}
    


    def evaluate(self, parameters, config):
        """Share the local tree (training rounds) or run explanations (final round).

        Args:
            parameters: list of arrays with the global neural-network weights.
            config: round configuration. Keys read here: ``server_round``
                (default 1), ``explain_only`` (default ``False``) and, when a
                SuperTree is broadcast, ``supertree``, ``global_mapping`` and
                ``feature_names`` as JSON strings.

        Returns:
            ``(0.0, num_test_examples, metrics)``. In training rounds
            ``metrics`` carries the serialized local tree and its metadata
            under keys suffixed with the client id; in the explain-only round
            it is empty.
        """
        tree_hparams = self.tree_model.get_params()
        set_model_params(self.tree_model, self.nn_model, {
            "tree": [
                # max_depth is None for an unlimited tree; send the same -1
                # sentinel that get_model_parameters uses so the receiver
                # never compares None with an int.
                int(tree_hparams["max_depth"] or -1),
                tree_hparams["min_samples_split"],
                tree_hparams["min_samples_leaf"],
            ],
            "nn": parameters,
        })

        round_number = config.get("server_round", 1)
        explain_only = config.get("explain_only", False)

        if "supertree" in config:
            try:
                print("Recibiendo supertree....")
                supertree_dict = json.loads(config["supertree"])
                received_root = SuperTree.convert_SuperNode_to_Node(SuperTree.SuperNode.from_dict(supertree_dict))
                global_mapping = json.loads(config["global_mapping"])
                feature_names = json.loads(config["feature_names"])
            except Exception as e:
                print(f"[CLIENTE {self.client_id}] ‚ùå Error al recibir SuperTree: {e}")
            else:
                # Commit all three pieces together so a partial failure never
                # leaves a SuperTree without its matching mapping and names.
                self.received_supertree = received_root
                self.global_mapping = global_mapping
                self.feature_names = feature_names

        # CASE 1: training rounds - fit the local tree and send it to the server
        if not explain_only:

            self.tree_model.fit(self.X_train, self.y_train)

            supertree = SuperTree()
            root_node = supertree.rec_buildTree(self.tree_model, list(range(self.X_train.shape[1])), len(self.unique_labels))

            root_node = supertree.prune_redundant_leaves_local(root_node)

            self._save_local_tree(root_node, round_number, FEATURES, self.numeric_features,
                                scaler=None, unique_labels=UNIQUE_LABELS, encoder=self.encoder)
            tree_json = json.dumps([root_node.to_dict()])

            # Test instances are not explained during training rounds.
            return 0.0, len(self.X_test), {
                f"tree_ensemble_{self.client_id}": tree_json,
                f"encoded_feature_names_{self.client_id}": json.dumps(FEATURES),
                f"numeric_features_{self.client_id}": json.dumps(self.numeric_features),
                f"unique_labels_{self.client_id}": json.dumps(self.unique_labels),
                f"distinct_values_{self.client_id}": json.dumps(self.encoder.dataset_descriptor["categorical"])
            }

        # CASE 2: final round - explanations only, using the final SuperTree
        else:
            print(f"[CLIENTE {self.client_id}] üîç Ronda final: solo explicaciones")

            # No tree_ensemble_* metrics are sent in this round, but the local
            # tree is still (re)fitted before the explanation step runs.
            self.tree_model.fit(self.X_train, self.y_train)

            # Explain with the received SuperTree + LORE + merged tree.
            if self.received_supertree is not None:
                # To explain every test instance instead, call
                # self.explain_all_test_instances(config)
                self.explain_all_test_instances(config, only_idx=0)

            # Placeholder metrics; the explanation results are not returned here.
            return 0.0, len(self.X_test), {}
        


    
    
    def _explain_one_instance(self, num_row, config, save_trees=False):
        from sklearn.metrics import accuracy_score
        import numpy as np
        

        # Wrapper que escala SOLO para la NN (espacio NN)
        nn_wrapper = TorchNNWrapper(
            self.nn_model,
            num_idx=self.num_idx,
            mean=self.scaler_nn_mean,
            scale=self.scaler_nn_scale,
        )

        # 1. Visualizar instancia escalada y decodificada usando el encoder/preprocessor ORIGINAL
        
        decoded = self.decode_onehot_instance(
            self.X_test[num_row],
            self.numeric_features,
            self.encoder,
            None,                 # <-- sin scaler (en crudo)
            self.feature_names
        )



        # Aseg√∫rate de que X_test[num_row] es un numpy array del shape correcto (1, n_features)
        row = np.asarray(self.X_test[num_row], dtype=np.float32)
        probs = nn_wrapper.predict_proba(row[None, :])
        pred_class_idx = int(probs.argmax(axis=1)[0])
        pred_class = self.label_encoder.inverse_transform([pred_class_idx])[0]

        # 2. Construir DataFrame para LORE (si es necesario, solo para TabularDataset)

        local_df = pd.DataFrame(self.X_train, columns=self.feature_names).astype(np.float32)
        local_df["class"] = self.label_encoder.inverse_transform(self.y_train_nn)
        local_tabular_dataset = TabularDataset(local_df, class_name="class")


        # Explicabilidad local y la vecindad es generada del train (local_tabular_dataset)
        bbox = sklearn_classifier_bbox.sklearnBBox(nn_wrapper)
        lore_vecindad = TabularGeneticGeneratorLore(bbox, local_tabular_dataset)


        # Explicaci√≥n LORE
        x_instance = pd.Series(self.X_test[num_row], index=self.feature_names)
        round_number = config.get("server_round", 1)

        # print("Instancia a explicar (decodificada):")
        # print(x_instance)

        # t0 = time.perf_counter()
        explanation = lore_vecindad.explain_instance(x_instance, merge=True, num_classes=len(UNIQUE_LABELS), feature_names= self.feature_names, categorical_features=list(self.global_mapping.keys()), global_mapping=self.global_mapping, UNIQUE_LABELS=UNIQUE_LABELS,
                                                    client_id=self.client_id, round_number=round_number)

        lore_tree = explanation["merged_tree"]
        
        if save_trees:
            self.save_lore_tree_image(lore_tree.root,round_number,self.feature_names,self.numeric_features,UNIQUE_LABELS,self.encoder,folder="LoreTree")

        merged_tree = SuperTree()
        merged_tree.mergeDecisionTrees(
            roots=[lore_tree.root, self.received_supertree],
            num_classes=len(self.unique_labels),
            feature_names=self.feature_names,
            categorical_features=list(self.global_mapping.keys()), 
            global_mapping=self.global_mapping
        )
        
        merged_tree.prune_redundant_leaves_full()


        if save_trees:
            self.save_mergedTree_plot(root_node=merged_tree.root,round_number=round_number,feature_names=self.feature_names,class_names=self.unique_labels,numeric_features=self.numeric_features,scaler=None, global_mapping=self.global_mapping,folder="MergedTree")

        
        
        
        
        # ========================================================================================================================================================================================================
        # CREACI√ìN DE √ÅRBOL LOCAL SOBRE Z ETIQUETADO POR EL √ÅRBOL LOCAL y el SUPERTREE, PARA LUEGO UNIRLO AL LORE TREE
        # ========================================================================================================================================================================================================


        Z = explanation["neighborhood_Z"] # instancias del vecindario sint√©tico generado alrededor del punto a explicar.
        y_bb = explanation["neighborhood_Yb"] # predicciones del modelo BBOX (red neuronal) sobre Z (el vecindario).

        y_surrogate_preds = explanation["surrogate_preds"]  # predicciones del modelo interpretable (arbol - LORE Tree) sobre Z (el vecindario).




        # Convertir Z en DataFrame legible
        dfZ = pd.DataFrame(Z, columns=self.feature_names)
                
        y_local_Z = self.tree_model.predict(Z)
        y_local_supertree_Z = self.received_supertree.predict(Z)



        # √Årboles entrenados sobre la vecindad Z etiquetada por el √°rbol local y el SuperTree
        local_local_clf = DecisionTreeClassifier(
            max_depth=self.tree_model.get_params()["max_depth"],
            min_samples_split=self.tree_model.get_params()["min_samples_split"],
            min_samples_leaf=self.tree_model.get_params()["min_samples_leaf"],
            random_state=42,
        )

        local_supertree_clf = DecisionTreeClassifier(
            max_depth=self.tree_model.get_params()["max_depth"],
            min_samples_split=self.tree_model.get_params()["min_samples_split"],
            min_samples_leaf=self.tree_model.get_params()["min_samples_leaf"],
            random_state=42,
        )

        local_local_clf.fit(Z, y_local_Z)
        local_supertree_clf.fit(Z, y_local_supertree_Z)


        

        # Pasarlo a SuperTree.Node para poder hacer merge
        st_local_local = SuperTree()
        root_local_local = st_local_local.rec_buildTree(
            local_local_clf,
            list(range(Z.shape[1])),
            len(self.unique_labels),
        )

        st_local_supertree = SuperTree()
        root_local_supertree = st_local_supertree.rec_buildTree(
            local_supertree_clf,
            list(range(Z.shape[1])),
            len(self.unique_labels),
        )



        root_local_local = st_local_local.prune_redundant_leaves_local(root_local_local)
        root_local_supertree = st_local_supertree.prune_redundant_leaves_local(root_local_supertree)






        local_local_tree = SuperTree()
        local_local_tree.mergeDecisionTrees(
            roots=[lore_tree.root, root_local_local],
            num_classes=len(self.unique_labels),
            feature_names=self.feature_names,
            categorical_features=list(self.global_mapping.keys()),
            global_mapping=self.global_mapping,
        )
        


        local_super_tree = SuperTree()
        local_super_tree.mergeDecisionTrees(
            roots=[lore_tree.root, root_local_local, root_local_supertree],
            num_classes=len(self.unique_labels),
            feature_names=self.feature_names,
            categorical_features=list(self.global_mapping.keys()),
            global_mapping=self.global_mapping,
        )

        local_local_tree.prune_redundant_leaves_full()
        local_super_tree.prune_redundant_leaves_full()




        def _render_tree(root):
            """Render one tree as text with the rendering options shared by every tree."""
            return self.tree_to_str(
                root,
                self.feature_names,
                numeric_features=self.numeric_features,
                scaler=None,
                global_mapping=self.global_mapping,
                unique_labels=self.unique_labels,
            )

        def _rules_for_predicted_class(tree_text):
            """Extract from a rendered tree the rules whose target is ``pred_class``."""
            return self.extract_rules_from_str(tree_text, target_class_label=pred_class)

        # Textual form of each tree (note: received_supertree is passed directly, not via .root).
        tree_str = _render_tree(merged_tree.root)
        lore_tree_str = _render_tree(lore_tree.root)
        supertree_str = _render_tree(self.received_supertree)
        local_local_tree_str = _render_tree(local_local_tree.root)
        local_super_tree_str = _render_tree(local_super_tree.root)
        tree_str_localZ = _render_tree(root_local_local)
        tree_str_superZ = _render_tree(root_local_supertree)

        # Rules of the predicted class, one list per tree.
        rules = _rules_for_predicted_class(tree_str)
        rules_lore = _rules_for_predicted_class(lore_tree_str)
        rules_supertree = _rules_for_predicted_class(supertree_str)
        rules_local_local = _rules_for_predicted_class(local_local_tree_str)
        rules_local_super = _rules_for_predicted_class(local_super_tree_str)
        rules_localZ = _rules_for_predicted_class(tree_str_localZ)
        rules_superZ = _rules_for_predicted_class(tree_str_superZ)




    
        def cumple_regla(instancia, regla):
            """Return True when ``instancia`` satisfies every condition in ``regla``.

            Parameters
            ----------
            instancia : mapping-like
                Object indexed by variable name (``instancia[var]``).
            regla : iterable of str
                Conditions of one rule. Each condition is one of:
                  * an interval, ``var OP v1 ‚àß OP v2`` (both bounds must hold);
                  * a numeric comparison (``‚â§``, ``‚â•``, ``>=``, ``>``, ``<=``, ``<``);
                  * a categorical test (``‚â†``, ``!=``, ``=``) whose value may be
                    wrapped in double quotes.
                Non-ASCII operator glyphs are matched exactly as spelled in this file.
                A condition containing none of these operators is ignored.

            Returns
            -------
            bool
                False as soon as one condition is violated, True otherwise
                (so an empty rule is satisfied by any instance).

            Raises
            ------
            ValueError
                If a numeric threshold cannot be converted with ``float``.
            Lookup errors raised by ``instancia[var]`` propagate unchanged.
            """
            # Comparison used for each bound of an interval condition. A lookup table
            # replaces the previous eval() on formatted strings.
            comparar = {
                ">": operator.gt,
                ">=": operator.ge,
                "‚â•": operator.ge,
                "<": operator.lt,
                "<=": operator.le,
                "‚â§": operator.le,
            }
            # (token, violation test) pairs, tried in this order. Each test is the
            # NEGATION of the condition so NaN values keep their previous outcome
            # (a NaN never counts as a violation). "<=" must be tried before "<",
            # otherwise the "=" ends up inside the threshold and float() fails.
            umbrales = (
                ("‚â§", operator.gt),
                ("‚â•", operator.lt),
                (">=", operator.lt),
                (">", operator.le),
                ("<=", operator.gt),
                ("<", operator.ge),
            )
            # Categorical tests, again expressed as "is the condition violated?".
            # "!=" must precede "=" for the same reason as above.
            categoricas = (
                ("‚â†", operator.eq),
                ("!=", operator.eq),
                ("=", operator.ne),
            )
            # variable, operator1, value1, operator2, value2
            patron_intervalo = r'(.+?)([><]=?|‚â§|‚â•)\s*([-\d\.]+)\s*‚àß\s*([><]=?|‚â§|‚â•)\s*([-\d\.]+)'

            for cond in regla:
                if "‚àß" in cond:
                    m = re.match(patron_intervalo, cond)
                    if m:
                        var = m.group(1).strip()
                        limite_1 = float(m.group(3))
                        limite_2 = float(m.group(5))
                        v = instancia[var]
                        if not (
                            comparar[m.group(2)](v, limite_1)
                            and comparar[m.group(4)](v, limite_2)
                        ):
                            return False
                        continue  # interval handled; go to the next condition
                    # Unparseable interval: fall through to the single-operator checks.

                for token, violada in umbrales:
                    if token in cond:
                        # Split on the first occurrence only.
                        var, _, bruto = cond.partition(token)
                        limite = float(bruto.strip())
                        if violada(instancia[var.strip()], limite):
                            return False
                        break
                else:
                    # No numeric operator present: try the categorical ones.
                    for token, violada in categoricas:
                        if token in cond:
                            var, _, bruto = cond.partition(token)
                            valor = bruto.strip().replace('"', "")
                            if violada(instancia[var.strip()], valor):
                                return False
                            break
            return True

        def _first_covering_rule(candidatas):
            """Return the first rule in ``candidatas`` that ``decoded`` satisfies, or None."""
            return next(
                (candidata for candidata in candidatas if cumple_regla(decoded, candidata)),
                None,
            )

        # Factual rule of each tree: the first rule of the predicted class that
        # covers the decoded instance being explained.
        regla_factual = _first_covering_rule(rules)
        regla_factual_lore = _first_covering_rule(rules_lore)
        regla_factual_supertree = _first_covering_rule(rules_supertree)
        regla_factual_local_local = _first_covering_rule(rules_local_local)
        regla_factual_local_super = _first_covering_rule(rules_local_super)
        regla_factual_localZ = _first_covering_rule(rules_localZ)
        regla_factual_superZ = _first_covering_rule(rules_superZ)


        

        def _simplest_cf_per_class(tree_text):
            """Pick one counterfactual rule per non-predicted class from a rendered tree.

            For every label other than ``pred_class`` the rules targeting that label are
            extracted and the one with the fewest conditions is kept; labels with no
            rules are left out of the returned dict.
            """
            elegidas = {}
            for etiqueta in self.unique_labels:
                if etiqueta != pred_class:
                    candidatas = self.extract_rules_from_str(tree_text, target_class_label=etiqueta)
                    if candidatas:
                        elegidas[etiqueta] = min(candidatas, key=len)
            return elegidas

        # One counterfactual per class different from the predicted one, for each tree.
        cf_rules_por_clase = _simplest_cf_per_class(tree_str)
        cf_rules_LORE_por_clase = _simplest_cf_per_class(lore_tree_str)
        cf_rules_Supertree_por_clase = _simplest_cf_per_class(supertree_str)
        cf_rules_local_local_por_clase = _simplest_cf_per_class(local_local_tree_str)
        cf_rules_local_super_por_clase = _simplest_cf_per_class(local_super_tree_str)
        cf_rules_localZ_por_clase = _simplest_cf_per_class(tree_str_localZ)
        cf_rules_superZ_por_clase = _simplest_cf_per_class(tree_str_superZ)

        
        

        
        




        

        # ========================================================================================================================================================================================================
        # LOCAL EXPLAINABILITY METRICS (neighbourhood Z)
        #
        # Silhouette-style score for the explained instance x:
        #   a = mean distance from x to the neighbours whose y_bb label equals pred_class_idx (Z+)
        #   b = mean distance from x to the remaining neighbours (Z-)
        #   silhouette = (b - a) / max(a, b), or 0.0 when a + b is not positive
        # ========================================================================================================================================================================================================
        mask_same_class = (y_bb == pred_class_idx)
        mask_diff_class = (y_bb != pred_class_idx)

        Z_plus = dfZ[mask_same_class]
        Z_minus = dfZ[mask_diff_class]

        # NOTE(review): x is taken from X_test while Z_plus/Z_minus come from dfZ;
        # this assumes both live in the same feature space — confirm.
        x = self.X_test[num_row]

        # An empty side contributes a distance of 0.0, so with an empty Z_plus and a
        # non-empty Z_minus the score below evaluates to 1.0.
        a = pairwise_distances([x], Z_plus).mean() if len(Z_plus) > 0 else 0.0
        b = pairwise_distances([x], Z_minus).mean() if len(Z_minus) > 0 else 0.0

        silhouette = 0.0
        if (a + b) > 0:
            silhouette = (b - a) / max(a, b)





        
        # ==================================================================================================================
        # M√âTRICAS DE COMPARATIVA DE PRECISION DE LOS √ÅRBOLES CON EL TEST (QUE TAN BUENOS SON LOS ARBOLES QUE HEMOS GENERADO)
        # ==================================================================================================================

        y_true = self.y_test

        def _weighted_prf_on_test(y_hat):
            """Weighted precision, recall and F1 of ``y_hat`` against ``y_true``."""
            return (
                precision_score(y_true, y_hat, average="weighted"),
                recall_score(y_true, y_hat, average="weighted"),
                f1_score(y_true, y_hat, average="weighted"),
            )

        # Predictions of every tree on the test set.
        y_pred_lore = lore_tree.root.predict(self.X_test)
        y_pred_merged = merged_tree.root.predict(self.X_test)
        y_pred_super = self.received_supertree.predict(self.X_test)
        y_pred_local_local = local_local_tree.root.predict(self.X_test)
        y_pred_local_super = local_super_tree.root.predict(self.X_test)
        y_pred_localZ_tree = root_local_local.predict(self.X_test)
        y_pred_superZ_tree = root_local_supertree.predict(self.X_test)

        # Accuracy.
        acc_lore = accuracy_score(y_true, y_pred_lore)
        acc_merged = accuracy_score(y_true, y_pred_merged)
        acc_super = accuracy_score(y_true, y_pred_super)
        acc_local_local = accuracy_score(y_true, y_pred_local_local)
        acc_local_super = accuracy_score(y_true, y_pred_local_super)
        acc_localZ_tree_TEST = accuracy_score(y_true, y_pred_localZ_tree)
        acc_superZ_tree_TEST = accuracy_score(y_true, y_pred_superZ_tree)

        # Weighted precision / recall / F1.
        prec_lore, rec_lore, f1_lore = _weighted_prf_on_test(y_pred_lore)
        prec_merged, rec_merged, f1_merged = _weighted_prf_on_test(y_pred_merged)
        prec_super, rec_super, f1_super = _weighted_prf_on_test(y_pred_super)
        prec_local_local, rec_local_local, f1_local_local = _weighted_prf_on_test(y_pred_local_local)
        prec_local_super, rec_local_super, f1_local_super = _weighted_prf_on_test(y_pred_local_super)
        prec_localZ_tree_TEST, rec_localZ_tree_TEST, f1_localZ_tree_TEST = _weighted_prf_on_test(
            y_pred_localZ_tree
        )
        prec_superZ_tree_TEST, rec_superZ_tree_TEST, f1_superZ_tree_TEST = _weighted_prf_on_test(
            y_pred_superZ_tree
        )

        # ===================================================
        # M√âTRICAS EN LA VECINDAD Z (Qu√© tal imitan al BBOX)
        # ===================================================

        
        # Predicciones de los √°rboles sobre el vecindario Z
        def _weighted_prf_vs_bbox(y_hat):
            """Weighted precision, recall and F1 of ``y_hat`` against the y_bb labels of Z."""
            return (
                precision_score(y_bb, y_hat, average="weighted"),
                recall_score(y_bb, y_hat, average="weighted"),
                f1_score(y_bb, y_hat, average="weighted"),
            )

        y_pred_lore_Z = lore_tree.root.predict(Z)
        y_pred_merged_Z = merged_tree.root.predict(Z)
        y_pred_super_Z = self.received_supertree.predict(Z)
        y_pred_local_local_Z = local_local_tree.root.predict(Z)
        y_pred_local_super_Z = local_super_tree.root.predict(Z)

        # NOTE(review): the "local only" predictions come from self.tree_model here, while
        # the test-set counterpart is computed with root_local_local — confirm this is intended.
        y_pred_localZ_Z = self.tree_model.predict(Z)
        y_pred_superZ_Z = root_local_supertree.predict(Z)

        # y_bb plays the role of ground truth inside the neighbourhood.
        acc_lore_Z = accuracy_score(y_bb, y_pred_lore_Z)
        acc_merged_Z = accuracy_score(y_bb, y_pred_merged_Z)
        acc_super_Z = accuracy_score(y_bb, y_pred_super_Z)
        acc_local_local_Z = accuracy_score(y_bb, y_pred_local_local_Z)
        acc_local_super_Z = accuracy_score(y_bb, y_pred_local_super_Z)
        acc_local_Z = accuracy_score(y_bb, y_pred_localZ_Z)
        acc_superZ_Z = accuracy_score(y_bb, y_pred_superZ_Z)

        prec_lore_Z, rec_lore_Z, f1_lore_Z = _weighted_prf_vs_bbox(y_pred_lore_Z)
        prec_merged_Z, rec_merged_Z, f1_merged_Z = _weighted_prf_vs_bbox(y_pred_merged_Z)
        prec_super_Z, rec_super_Z, f1_super_Z = _weighted_prf_vs_bbox(y_pred_super_Z)
        prec_local_local_Z, rec_local_local_Z, f1_local_local_Z = _weighted_prf_vs_bbox(
            y_pred_local_local_Z
        )
        prec_local_super_Z, rec_local_super_Z, f1_local_super_Z = _weighted_prf_vs_bbox(
            y_pred_local_super_Z
        )
        prec_local_Z, rec_local_Z, f1_local_Z = _weighted_prf_vs_bbox(y_pred_localZ_Z)
        prec_superZ_Z, rec_superZ_Z, f1_superZ_Z = _weighted_prf_vs_bbox(y_pred_superZ_Z)


    




        # ============================================================================================================
        # üåê COVERAGE / FIDELITY / SUPPORT (REGLAS FACTUALES, VERSI√ìN GLOBAL)
        # ============================================================================================================
        # Aqu√≠ trabajamos en TODO el conjunto de test X_test.
        #
        # Objetivo: para cada regla factual (del LORE tree, Merged tree y Supertree) calcular:
        #
        #   - coverage_factual:
        #       De todas las instancias del test, ¬øcu√°ntas cumplen la regla?
        #
        #   - fidelity_factual (Q_fidelity del paper, Eq. (3)):
        #       entre las instancias que cumplen la regla, qu√© proporci√≥n el BBOX las clasifica como la clase c.
        #
        #           Q_fidelity(E) = | { x ‚àà cover(E) : M(x) = c } | / | cover(E) |
        #
        #   - support_factual:
        #       Tama√±o del cover(E), es decir, cu√°ntas instancias de X_test satisfacen la regla.
        #
        # Nota:
        #   - X_test_decoded: X_test decodificado a espacio legible (variables originales).
        #   - y_bb_test: predicciones del BBOX (red neuronal) sobre TODO X_test.
        #   - c_idx: √≠ndice de la clase predicha por el BBOX para la instancia concreta que estamos explicando.
        #            Es la clase asociada a la regla factual.
        # ============================================================================================================

        # --- Black-box predictions over the WHOLE test set (M(x)) ---
        y_bb_test = nn_wrapper.predict(self.X_test)


        # Simplified ('loose' mode) version of each per-class counterfactual rule dict.
        cf_rules_por_clase_simplify = self._simplify_rules_by_class(cf_rules_por_clase, mode='loose')
        cf_rules_LORE_por_clase_simplify = self._simplify_rules_by_class(cf_rules_LORE_por_clase, mode='loose')
        cf_rules_Supertree_por_clase_simplify = self._simplify_rules_by_class(cf_rules_Supertree_por_clase, mode='loose')
        cf_rules_local_local_por_clase_simplify = self._simplify_rules_by_class(cf_rules_local_local_por_clase, mode='loose')
        cf_rules_local_super_por_clase_simplify = self._simplify_rules_by_class(cf_rules_local_super_por_clase, mode='loose')
        cf_rules_localZ_por_clase_simplify = self._simplify_rules_by_class(cf_rules_localZ_por_clase, mode='loose')
        cf_rules_superZ_por_clase_simplify = self._simplify_rules_by_class(cf_rules_superZ_por_clase, mode='loose')
        

        # Simplified factual rules start as None and are only replaced further down
        # when the corresponding factual rule is truthy.
        regla_factual_simplify = None
        regla_factual_LORE_simplify = None
        regla_factual_Supertree_simplify = None
        regla_factual_local_local_simplify = None
        regla_factual_local_super_simplify = None
        regla_factual_localZ_simplify = None
        regla_factual_superZ_simplify = None




        # --- Decode the WHOLE X_test, row by row, with decode_onehot_instance ---
        Xtest_df = pd.DataFrame(self.X_test, columns=self.feature_names)
        Xtest_decoded = Xtest_df.apply(
            lambda r: self.decode_onehot_instance(
                r.values, self.numeric_features, self.encoder, self.scaler, self.feature_names
            ),
            axis=1
        )

        c_idx = pred_class_idx                       # class (c) associated with the factual rule

        def coverage_fidelity_support(regla):
            """Coverage and fidelity of ``regla`` over the decoded test set.

            Returns ``(coverage, fidelity)``:
              * coverage: fraction of rows of ``Xtest_decoded`` that satisfy the rule;
              * fidelity: among the covered rows, fraction whose ``y_bb_test`` entry
                equals ``c_idx`` (0.0 when nothing is covered).
            A falsy rule yields ``(0.0, 0.0)``. Despite the name, the support count
            itself is not returned.
            """
            if not regla:
                return 0.0, 0.0

            covered = Xtest_decoded.apply(
                lambda fila: cumple_regla(fila, regla), axis=1
            ).values
            n_covered = int(covered.sum())
            total = len(Xtest_decoded)

            coverage = (n_covered / total) if total > 0 else 0.0
            fidelity = float((y_bb_test[covered] == c_idx).mean()) if n_covered > 0 else 0.0
            return coverage, fidelity
        

        # ---- Regla factual de cada √°rbol ----
        # ---- Coverage / fidelity of the factual rule of each tree (test set) ----
        coverage_merged, fidelity_merged = coverage_fidelity_support(regla_factual)
        coverage_lore, fidelity_lore = coverage_fidelity_support(regla_factual_lore)
        coverage_supertree, fidelity_super = coverage_fidelity_support(regla_factual_supertree)
        coverage_local_local, fidelity_local_local = coverage_fidelity_support(regla_factual_local_local)
        coverage_local_super, fidelity_local_super = coverage_fidelity_support(regla_factual_local_super)
        coverage_localZ, fidelity_localZ = coverage_fidelity_support(regla_factual_localZ)
        coverage_superZ, fidelity_superZ = coverage_fidelity_support(regla_factual_superZ)

        def _loose_or_keep(regla, actual):
            """Simplify a truthy rule in 'loose' mode; otherwise keep the current value."""
            return self._simplify_rule(regla, mode='loose') if regla else actual

        # ---- Simplified versions, used only for printing / complexity ----
        regla_factual_simplify = _loose_or_keep(regla_factual, regla_factual_simplify)
        regla_factual_LORE_simplify = _loose_or_keep(regla_factual_lore, regla_factual_LORE_simplify)
        regla_factual_Supertree_simplify = _loose_or_keep(
            regla_factual_supertree, regla_factual_Supertree_simplify
        )
        regla_factual_local_local_simplify = _loose_or_keep(
            regla_factual_local_local, regla_factual_local_local_simplify
        )
        regla_factual_local_super_simplify = _loose_or_keep(
            regla_factual_local_super, regla_factual_local_super_simplify
        )
        regla_factual_localZ_simplify = _loose_or_keep(regla_factual_localZ, regla_factual_localZ_simplify)
        regla_factual_superZ_simplify = _loose_or_keep(regla_factual_superZ, regla_factual_superZ_simplify)







        # =======================
        # Versi√≥n vecindario (Z)
        # =======================


        # Decode every row of the neighbourhood dfZ with decode_onehot_instance,
        # using the same arguments as the X_test decoding.
        dfZ_decoded = dfZ.apply(
            lambda r: self.decode_onehot_instance(
                r.values,
                self.numeric_features,
                self.encoder,
                self.scaler,          # same scaler argument as used for X_test
                self.feature_names
            ),
            axis=1
        )

        def coverage_fidelity_support_neigh(regla):
            """Coverage and fidelity of ``regla`` over the decoded neighbourhood.

            Returns ``(coverage, fidelity)``:
              * coverage: fraction of rows of ``dfZ_decoded`` that satisfy the rule;
              * fidelity: among the covered rows, fraction whose ``y_bb`` entry equals
                ``c_idx`` (0.0 when nothing is covered).
            A falsy rule yields ``(0.0, 0.0)``. The support count is not returned.
            """
            if not regla:
                return 0.0, 0.0

            covered_Z = dfZ_decoded.apply(
                lambda vecino: cumple_regla(vecino, regla), axis=1
            ).values
            n_covered_Z = int(covered_Z.sum())
            size_Z = len(dfZ_decoded)

            coverage_Z = (n_covered_Z / size_Z) if size_Z > 0 else 0.0
            fidelity_Z = float((y_bb[covered_Z] == c_idx).mean()) if n_covered_Z > 0 else 0.0
            return coverage_Z, fidelity_Z
        
        # ---- Versi√≥n vecindario (LOCAL) ----
        # Coverage / fidelity of each tree's factual rule, measured on the neighbourhood Z.
        coverage_merged_Z,          fidelity_merged_Z= coverage_fidelity_support_neigh(regla_factual)
        coverage_lore_Z,            fidelity_lore_Z= coverage_fidelity_support_neigh(regla_factual_lore)
        coverage_super_Z,           fidelity_super_Z= coverage_fidelity_support_neigh(regla_factual_supertree)
        coverage_local_local_Z,     fidelity_local_local_Z= coverage_fidelity_support_neigh(regla_factual_local_local)
        coverage_local_super_Z,     fidelity_local_super_Z= coverage_fidelity_support_neigh(regla_factual_local_super)
        coverage_localZ_Z,          fidelity_localZ_Z= coverage_fidelity_support_neigh(regla_factual_localZ)
        coverage_superZ_Z,          fidelity_superZ_Z= coverage_fidelity_support_neigh(regla_factual_superZ)




            
        # ======================================== HIT ============================================================================
        # The factual rules were searched only among the rules extracted for pred_class
        # (the class the black box predicts for the explained instance).
        #
        # If one of those rules covers the instance, a factual rule exists and it already
        # belongs to the black-box class, so HIT = 1.
        #
        # If no rule of that class covers the instance, the factual rule is None and HIT = 0.
        # ===========================================================================================================================

        hit_merged = int(regla_factual is not None)              # merged tree
        hit_lore   = int(regla_factual_lore is not None)         # LORE tree
        hit_supertree = int(regla_factual_supertree is not None) # supertree
        hit_lore_local_local = int(regla_factual_local_local is not None) # local_local tree
        hit_lore_local_super = int(regla_factual_local_super is not None) # local_super tree
        hit_localZ = int(regla_factual_localZ is not None)       # localZ tree
        hit_superZ = int(regla_factual_superZ is not None)       # superZ tree


        # ============================================================================================
        # Metricas de los √°rboles
        # ============================================================================================

        def _tree_shape(root):
            """Return (depth in edges, node count, leaf count) for the tree at ``root``."""
            return self.tree_depth_edges(root), self.count_nodes(root), self.count_leaves(root)

        depth_merged_edges, nodes_merged, leaves_merged = _tree_shape(merged_tree.root)
        depth_lore_edges, nodes_lore, leaves_lore = _tree_shape(lore_tree.root)
        depth_supertree_edges, nodes_supertree, leaves_supertree = _tree_shape(self.received_supertree)
        (
            depth_lore_edges_local_local,
            nodes_lore_local_local,
            leaves_lore_local_local,
        ) = _tree_shape(local_local_tree.root)
        (
            depth_lore_edges_local_super,
            nodes_lore_local_super,
            leaves_lore_local_super,
        ) = _tree_shape(local_super_tree.root)
        depth_localZ_edges, nodes_localZ, leaves_localZ = _tree_shape(root_local_local)
        depth_superZ_edges, nodes_superZ, leaves_superZ = _tree_shape(root_local_supertree)

        # ============================================================================================================================================
        # Complejidad de las reglas (n√∫mero de condiciones)
        # ============================================================================================================================================

        def rule_complexity(regla):
            """Number of conditions in ``regla``; 0 when the rule is missing or empty."""
            if not regla:
                return 0
            return len(regla)
        
        def _cf_complexities(reglas_por_clase):
            """Map each class label to the complexity of its counterfactual rule."""
            return {
                etiqueta: rule_complexity(regla_cf)
                for etiqueta, regla_cf in reglas_por_clase.items()
            }

        comp_factual_merged_simpl = rule_complexity(regla_factual_simplify)
        comp_cf_merged_simpl = _cf_complexities(cf_rules_por_clase_simplify)

        comp_factual_lore_simpl = rule_complexity(regla_factual_LORE_simplify)
        comp_cf_lore_simpl = _cf_complexities(cf_rules_LORE_por_clase_simplify)

        comp_factual_supertree_simpl = rule_complexity(regla_factual_Supertree_simplify)
        comp_cf_supertree_simpl = _cf_complexities(cf_rules_Supertree_por_clase_simplify)

        comp_factual_local_local_simpl = rule_complexity(regla_factual_local_local_simplify)
        comp_cf_local_local_simpl = _cf_complexities(cf_rules_local_local_por_clase_simplify)

        comp_factual_local_super_simpl = rule_complexity(regla_factual_local_super_simplify)
        comp_cf_local_super_simpl = _cf_complexities(cf_rules_local_super_por_clase_simplify)

        comp_factual_localZ_simpl = rule_complexity(regla_factual_localZ_simplify)
        comp_cf_localZ_simpl = _cf_complexities(cf_rules_localZ_por_clase_simplify)

        comp_factual_superZ_simpl = rule_complexity(regla_factual_superZ_simplify)
        comp_cf_superZ_simpl = _cf_complexities(cf_rules_superZ_por_clase_simplify)






        # ============================================================================================================================================

        # Coverage: ¬øQu√© proporci√≥n del dataset satisface este contrafactual? 
        # Es decir, cu√°ntas instancias ‚Äúquedan explicadas‚Äù por este contrafactual.

        # Support: ¬øCu√°ntas instancias satisfacen este contrafactual?

        # ============================================================================================================================================


        def compute_cf_coverage(rules_dict):
            """Score each counterfactual rule over the decoded test set.

            ``rules_dict`` maps a class label to its counterfactual rule. Returns two
            dicts keyed by the same labels, ``(coverage, precision)``:
              * coverage: fraction of rows of ``Xtest_decoded`` satisfying the rule;
              * precision: among the covered rows, fraction whose ``y_bb_test`` entry
                equals the label-encoded class of the rule.
            Both are 0.0 for a falsy rule; precision is 0.0 when nothing is covered.
            The support count is computed but not returned.
            """
            coverage_by_class = {}
            precision_by_class = {}
            total = len(Xtest_decoded)

            for clase, regla_cf in rules_dict.items():
                cobertura, precision = 0.0, 0.0

                if regla_cf:
                    # Rows of the test set that satisfy the counterfactual rule.
                    covered = Xtest_decoded.apply(
                        lambda fila: cumple_regla(fila, regla_cf), axis=1
                    ).values
                    n_covered = int(covered.sum())

                    if total > 0:
                        cobertura = n_covered / total
                    if n_covered > 0:
                        clase_idx = self.label_encoder.transform([clase])[0]
                        precision = float((y_bb_test[covered] == clase_idx).mean())

                coverage_by_class[clase] = cobertura
                precision_by_class[clase] = precision

            return coverage_by_class, precision_by_class


        # --- Merged / LORE / Supertree ---
        coverage_cf_merged,        precision_cf_merged    = compute_cf_coverage(cf_rules_por_clase_simplify)
        coverage_cf_lore,          precision_cf_lore      = compute_cf_coverage(cf_rules_LORE_por_clase_simplify)
        coverage_cf_supertree,     precision_cf_supertree = compute_cf_coverage(cf_rules_Supertree_por_clase_simplify)
        coverage_cf_local_local,   precision_cf_local_local = compute_cf_coverage(cf_rules_local_local_por_clase_simplify)
        coverage_cf_local_super,   precision_cf_local_super = compute_cf_coverage(cf_rules_local_super_por_clase_simplify)
        coverage_cf_localZ,        precision_cf_localZ    = compute_cf_coverage(cf_rules_localZ_por_clase_simplify)
        coverage_cf_superZ,        precision_cf_superZ    = compute_cf_coverage(cf_rules_superZ_por_clase_simplify)



        # ============================================================================
        # Contrafactuales: coverage / support tambi√©n en el VECINDARIO (Z)
        # ============================================================================

        def compute_cf_coverage_neigh(rules_dict):
            """Score each counterfactual rule over the decoded neighbourhood.

            ``rules_dict`` maps a class label to its counterfactual rule. Returns two
            dicts keyed by the same labels, ``(coverage, precision)``:
              * coverage: fraction of rows of ``dfZ_decoded`` satisfying the rule;
              * precision: among the covered rows, fraction whose ``y_bb`` entry equals
                the label-encoded class of the rule.
            Both are 0.0 for a falsy rule; precision is 0.0 when nothing is covered.
            The support count is computed but not returned.
            """
            coverage_by_class_Z = {}
            precision_by_class_Z = {}
            size_Z = len(dfZ_decoded)

            for clase, regla_cf in rules_dict.items():
                cobertura_Z, precision_Z = 0.0, 0.0

                if regla_cf:
                    # Neighbourhood rows that satisfy the counterfactual rule.
                    covered_Z = dfZ_decoded.apply(
                        lambda vecino: cumple_regla(vecino, regla_cf), axis=1
                    ).values
                    n_covered_Z = int(covered_Z.sum())

                    if size_Z > 0:
                        cobertura_Z = n_covered_Z / size_Z
                    if n_covered_Z > 0:
                        clase_idx = self.label_encoder.transform([clase])[0]
                        precision_Z = float((y_bb[covered_Z] == clase_idx).mean())

                coverage_by_class_Z[clase] = cobertura_Z
                precision_by_class_Z[clase] = precision_Z

            return coverage_by_class_Z, precision_by_class_Z
        
        # --- Merged / LORE / Supertree (LOCAL, vecindario Z) ---
        coverage_cf_merged_Z,        precision_cf_merged_Z    = compute_cf_coverage_neigh(cf_rules_por_clase_simplify)
        coverage_cf_lore_Z,          precision_cf_lore_Z      = compute_cf_coverage_neigh(cf_rules_LORE_por_clase_simplify)
        coverage_cf_supertree_Z,     precision_cf_supertree_Z = compute_cf_coverage_neigh(cf_rules_Supertree_por_clase_simplify)
        coverage_cf_local_local_Z,   precision_cf_local_local_Z = compute_cf_coverage_neigh(cf_rules_local_local_por_clase_simplify)
        coverage_cf_local_super_Z,   precision_cf_local_super_Z = compute_cf_coverage_neigh(cf_rules_local_super_por_clase_simplify)
        coverage_cf_localZ_Z,        precision_cf_localZ_Z = compute_cf_coverage_neigh(cf_rules_localZ_por_clase_simplify)
        coverage_cf_superZ_Z,        precision_cf_superZ_Z = compute_cf_coverage_neigh(cf_rules_superZ_por_clase_simplify)


        # ================= CSV por cliente =================
        row = {
            "round": int(round_number),
            "dataset": DATASET_NAME,
            "client_id": int(self.client_id),
            "bbox_pred_class": str(pred_class),

            # Vecindario
            "silhouette": float(silhouette),

            # ================= M√©tricas de como de buenos son los √°rboles =================
            "acc_lore": float(acc_lore),
            "acc_merged": float(acc_merged),
            "acc_super": float(acc_super),
            "acc_local_local": float(acc_local_local),
            "acc_local_super": float(acc_local_super),
            "acc_localZ_tree_TEST": float(acc_localZ_tree_TEST),
            "acc_superZ_tree_TEST": float(acc_superZ_tree_TEST),

            "prec_lore": float(prec_lore),
            "prec_merged": float(prec_merged),
            "prec_super": float(prec_super),
            "prec_local_local": float(prec_local_local),
            "prec_local_super": float(prec_local_super),
            "prec_localZ_tree_TEST": float(prec_localZ_tree_TEST),
            "prec_superZ_tree_TEST": float(prec_superZ_tree_TEST),

            "rec_lore": float(rec_lore),
            "rec_merged": float(rec_merged),
            "rec_super": float(rec_super),
            "rec_local_local": float(rec_local_local),
            "rec_local_super": float(rec_local_super),
            "rec_localZ_tree_TEST": float(rec_localZ_tree_TEST),
            "rec_superZ_tree_TEST": float(rec_superZ_tree_TEST),

            "f1_lore": float(f1_lore),
            "f1_merged": float(f1_merged),
            "f1_super": float(f1_super),
            "f1_local_local": float(f1_local_local),
            "f1_local_super": float(f1_local_super),
            "f1_localZ_tree_TEST": float(f1_localZ_tree_TEST),
            "f1_superZ_tree_TEST": float(f1_superZ_tree_TEST),

            # ======= Calidad √°rboles en la vecindad Z =======
            "acc_lore_Z": float(acc_lore_Z),
            "acc_merged_Z": float(acc_merged_Z),
            "acc_super_Z": float(acc_super_Z),
            "acc_local_local_Z": float(acc_local_local_Z),
            "acc_local_super_Z": float(acc_local_super_Z),
            "acc_localZ_Z": float(acc_local_Z),
            "acc_superZ_Z": float(acc_superZ_Z),

            "prec_lore_Z": float(prec_lore_Z),
            "prec_merged_Z": float(prec_merged_Z),
            "prec_super_Z": float(prec_super_Z),
            "prec_local_local_Z": float(prec_local_local_Z),
            "prec_local_super_Z": float(prec_local_super_Z),
            "prec_localZ_Z": float(prec_local_Z),
            "prec_superZ_Z": float(prec_superZ_Z),

            "rec_lore_Z": float(rec_lore_Z),
            "rec_merged_Z": float(rec_merged_Z),
            "rec_super_Z": float(rec_super_Z),
            "rec_local_local_Z": float(rec_local_local_Z),
            "rec_local_super_Z": float(rec_local_super_Z),
            "rec_localZ_Z": float(rec_local_Z),
            "rec_superZ_Z": float(rec_superZ_Z),

            "f1_lore_Z": float(f1_lore_Z),
            "f1_merged_Z": float(f1_merged_Z),
            "f1_super_Z": float(f1_super_Z),
            "f1_local_local_Z": float(f1_local_local_Z),
            "f1_local_super_Z": float(f1_local_super_Z),
            "f1_localZ_Z": float(f1_local_Z),
            "f1_superZ_Z": float(f1_superZ_Z),

            # ================= M√©tricas de explicabilidad del factual ==================
            "coverage_factual_merged": self._to_float(coverage_merged),
            "fidelity_factual_merged": self._to_float(fidelity_merged),
            "hit_factual_merged": int(hit_merged),
            "complexity_factual_merged": int(comp_factual_merged_simpl),

            "coverage_factual_lore": self._to_float(coverage_lore),
            "fidelity_factual_lore": self._to_float(fidelity_lore),
            "hit_factual_lore": int(hit_lore),
            "complexity_factual_lore": int(comp_factual_lore_simpl),

            "coverage_factual_super": self._to_float(coverage_supertree),
            "fidelity_factual_super": self._to_float(fidelity_super),
            "hit_factual_super": int(hit_supertree),
            "complexity_factual_super": int(comp_factual_supertree_simpl),

            "coverage_factual_local_local": self._to_float(coverage_local_local),
            "fidelity_factual_local_local": self._to_float(fidelity_local_local),
            "hit_factual_local_local": int(hit_lore_local_local),
            "complexity_factual_local_local": int(comp_factual_local_local_simpl),

            "coverage_factual_local_super": self._to_float(coverage_local_super),
            "fidelity_factual_local_super": self._to_float(fidelity_local_super),
            "hit_factual_local_super": int(hit_lore_local_super),
            "complexity_factual_local_super": int(comp_factual_local_super_simpl),

            "coverage_factual_localZ": self._to_float(coverage_localZ),
            "fidelity_factual_localZ": self._to_float(fidelity_localZ),
            "hit_factual_localZ": int(hit_localZ),
            "complexity_factual_localZ": int(comp_factual_localZ_simpl),

            "coverage_factual_superZ": self._to_float(coverage_superZ),
            "fidelity_factual_superZ": self._to_float(fidelity_superZ),
            "hit_factual_superZ": int(hit_superZ),
            "complexity_factual_superZ": int(comp_factual_superZ_simpl),

            # ======= Factual LOCAL (vecindad) =======
            "coverage_factual_merged_Z": self._to_float(coverage_merged_Z),
            "fidelity_factual_merged_Z": self._to_float(fidelity_merged_Z),

            "coverage_factual_lore_Z": self._to_float(coverage_lore_Z),
            "fidelity_factual_lore_Z": self._to_float(fidelity_lore_Z),

            "coverage_factual_super_Z": self._to_float(coverage_super_Z),
            "fidelity_factual_super_Z": self._to_float(fidelity_super_Z),

            "coverage_factual_local_local_Z": self._to_float(coverage_local_local_Z),
            "fidelity_factual_local_local_Z": self._to_float(fidelity_local_local_Z),

            "coverage_factual_local_super_Z": self._to_float(coverage_local_super_Z),
            "fidelity_factual_local_super_Z": self._to_float(fidelity_local_super_Z),

            "coverage_factual_localZ_Z": self._to_float(coverage_localZ_Z),
            "fidelity_factual_localZ_Z": self._to_float(fidelity_localZ_Z),

            "coverage_factual_superZ_Z": self._to_float(coverage_superZ_Z),
            "fidelity_factual_superZ_Z": self._to_float(fidelity_superZ_Z),

            # ================= Estructura =================
            "depth_edges_merged": int(depth_merged_edges),
            "nodes_merged": int(nodes_merged),
            "leaves_merged": int(leaves_merged),

            "depth_edges_lore": int(depth_lore_edges),
            "nodes_lore": int(nodes_lore),
            "leaves_lore": int(leaves_lore),

            "depth_edges_super": int(depth_supertree_edges),
            "nodes_super": int(nodes_supertree),
            "leaves_super": int(leaves_supertree),

            "depth_edges_local_local": int(depth_lore_edges_local_local),
            "nodes_local_local": int(nodes_lore_local_local),
            "leaves_local_local": int(leaves_lore_local_local),

            "depth_edges_local_super": int(depth_lore_edges_local_super),
            "nodes_local_super": int(nodes_lore_local_super),
            "leaves_local_super": int(leaves_lore_local_super),
            "depth_edges_localZ": int(depth_localZ_edges),
            "nodes_localZ": int(nodes_localZ),
            "leaves_localZ": int(leaves_localZ),

            "depth_edges_superZ": int(depth_superZ_edges),
            "nodes_superZ": int(nodes_superZ),
            "leaves_superZ": int(leaves_superZ),
        }


        

        # (Opcional) M√©tricas contrafactuales por clase en columnas ‚Äúanchas‚Äù
        for cl in self.unique_labels:
            row[f"cf_cov_merged_{cl}_TEST"]  = self._to_float(coverage_cf_merged.get(cl,0))
            row[f"cf_comp_merged_{cl}_TEST"] = int(comp_cf_merged_simpl.get(cl,0))
            row[f"cf_prec_merged_{cl}_TEST"] = self._to_float(precision_cf_merged.get(cl,0))

        for cl in self.unique_labels:
            row[f"cf_cov_lore_{cl}_TEST"]  = self._to_float(coverage_cf_lore.get(cl,0))
            row[f"cf_comp_lore_{cl}_TEST"] = int(comp_cf_lore_simpl.get(cl,0))
            row[f"cf_prec_lore_{cl}_TEST"] = self._to_float(precision_cf_lore.get(cl,0))

        for cl in self.unique_labels:
            row[f"cf_cov_super_{cl}_TEST"]  = self._to_float(coverage_cf_supertree.get(cl,0))
            row[f"cf_comp_super_{cl}_TEST"] = int(comp_cf_supertree_simpl.get(cl,0))
            row[f"cf_prec_super_{cl}_TEST"] = self._to_float(precision_cf_supertree.get(cl,0))

        for cl in self.unique_labels:
            row[f"cf_cov_local_local_{cl}_TEST"]  = self._to_float(coverage_cf_local_local.get(cl,0))
            row[f"cf_comp_local_local_{cl}_TEST"] = int(comp_cf_local_local_simpl.get(cl,0))
            row[f"cf_prec_local_local_{cl}_TEST"] = self._to_float(precision_cf_local_local.get(cl,0))

        for cl in self.unique_labels:
            row[f"cf_cov_local_super_{cl}_TEST"]  = self._to_float(coverage_cf_local_super.get(cl,0))
            row[f"cf_comp_local_super_{cl}_TEST"] = int(comp_cf_local_super_simpl.get(cl,0))
            row[f"cf_prec_local_super_{cl}_TEST"] = self._to_float(precision_cf_local_super.get(cl,0))

        for cl in self.unique_labels:
            row[f"cf_cov_localZ_{cl}_TEST"]  = self._to_float(coverage_cf_localZ.get(cl,0))
            row[f"cf_comp_localZ_{cl}_TEST"] = int(comp_cf_localZ_simpl.get(cl,0))
            row[f"cf_prec_localZ_{cl}_TEST"] = self._to_float(precision_cf_localZ.get(cl,0))

        for cl in self.unique_labels:
            row[f"cf_cov_superZ_{cl}_TEST"]  = self._to_float(coverage_cf_superZ.get(cl,0))
            row[f"cf_comp_superZ_{cl}_TEST"] = int(comp_cf_superZ_simpl.get(cl,0))
            row[f"cf_prec_superZ_{cl}_TEST"] = self._to_float(precision_cf_superZ.get(cl,0))







        # --- Contrafactuales en VECINDARIO (Z) ---
        for cl in self.unique_labels:
            row[f"cf_cov_merged_{cl}_Z"]   = self._to_float(coverage_cf_merged_Z.get(cl,0))
            row[f"cf_prec_merged_{cl}_Z"]  = self._to_float(precision_cf_merged_Z.get(cl,0))

        for cl in self.unique_labels:
            row[f"cf_cov_lore_{cl}_Z"]   = self._to_float(coverage_cf_lore_Z.get(cl,0))
            row[f"cf_prec_lore_{cl}_Z"]  = self._to_float(precision_cf_lore_Z.get(cl,0))

        for cl in self.unique_labels:
            row[f"cf_cov_super_{cl}_Z"]   = self._to_float(coverage_cf_supertree_Z.get(cl,0))
            row[f"cf_prec_super_{cl}_Z"]  = self._to_float(precision_cf_supertree_Z.get(cl,0))

        for cl in self.unique_labels:
            row[f"cf_cov_local_local_{cl}_Z"]   = self._to_float(coverage_cf_local_local_Z.get(cl,0))
            row[f"cf_prec_local_local_{cl}_Z"]  = self._to_float(precision_cf_local_local_Z.get(cl,0))

        for cl in self.unique_labels:
            row[f"cf_cov_local_super_{cl}_Z"]   = self._to_float(coverage_cf_local_super_Z.get(cl,0))
            row[f"cf_prec_local_super_{cl}_Z"]  = self._to_float(precision_cf_local_super_Z.get(cl,0))

        for cl in self.unique_labels:
            row[f"cf_cov_localZ_{cl}_Z"]   = self._to_float(coverage_cf_localZ_Z.get(cl,0))
            row[f"cf_prec_localZ_{cl}_Z"]  = self._to_float(precision_cf_localZ_Z.get(cl,0))

        for cl in self.unique_labels:
            row[f"cf_cov_superZ_{cl}_Z"]   = self._to_float(coverage_cf_superZ_Z.get(cl,0))
            row[f"cf_prec_superZ_{cl}_Z"]  = self._to_float(precision_cf_superZ_Z.get(cl,0))

        # Guardar
        self._append_client_csv(row, filename="Balanced")

        return row

    # ======================================================================
    # Bucle sobre todo el test
    # ======================================================================
    def explain_all_test_instances(self, config, only_idx=None):
        """Explain test instances one at a time and write the mean-metrics CSV.

        Args:
            config: Forwarded unchanged to ``_explain_one_instance``.
            only_idx: ``None`` to explain every row of ``self.X_test``; otherwise
                the single test index to explain (trees are saved in that case).

        Returns:
            The DataFrame read back from
            ``results/metrics_Balanced_cliente_<client_id>.csv``.
        """
        whole_test = only_idx is None
        if whole_test:
            indices = range(len(self.X_test))
            desc_text = f"Cliente {self.client_id} explicando test completo"
        else:
            indices = [only_idx]
            desc_text = f"Cliente {self.client_id} explicando instancia {only_idx}"

        # Trees are only persisted when one specific instance was requested.
        save_trees_flag = not whole_test

        results = []
        for i in tqdm(indices, desc=desc_text):
            try:
                results.append(
                    self._explain_one_instance(i, config, save_trees=save_trees_flag)
                )
            except Exception as e:
                # Best effort: report the failing instance and go on with the next.
                print(f"[Cliente {self.client_id}] ‚ö†Ô∏è Error en instancia {i}: {e}")

        # The metrics are taken from the incremental CSV on disk, not from
        # `results` (presumably the file `_explain_one_instance` appends to --
        # verify against `_append_client_csv`).
        df = pd.read_csv(f"results/metrics_Balanced_cliente_{self.client_id}.csv")

        # One row per numeric metric, single column "mean", index column "metric".
        mean_path = f"results/metrics_cliente_{self.client_id}_balanced_mean.csv"
        df.mean(numeric_only=True).to_frame(name="mean").to_csv(
            mean_path, index_label="metric"
        )

        return df


        

            


def client_fn(context: Context):
    """Build the Flower client for the partition assigned to this node.

    Loads the partition with ``load_data_general``, standardises the leading
    ``len(numeric_features)`` columns for the neural network, and returns a
    ``FlowerClient`` (converted with ``to_client()``) holding a decision tree
    and a ``Net`` model.

    Args:
        context: Flower context; ``node_config`` must provide ``"partition-id"``
            and ``"num-partitions"``.
    """
    node_cfg = context.node_config
    partition_id = node_cfg["partition-id"]
    # Read (so a missing key still fails here) but not used below: the data
    # loader is given the notebook-level NUM_CLIENTS instead.
    num_partitions = node_cfg["num-partitions"]

    (
        X_train, y_train,
        X_test, y_test,
        dataset, feature_names,
        label_encoder, scaler,
        numeric_features, encoder, preprocessor,
    ) = load_data_general(
        flower_dataset_name=DATASET_NAME,
        class_col=CLASS_COLUMN,
        partition_id=partition_id,
        num_partitions=NUM_CLIENTS,
    )

    tree_model = DecisionTreeClassifier(max_depth=3, min_samples_split=2, random_state=42)

    # Column positions that get standardised for the network
    # (assumes the numeric columns come first in X -- TODO confirm).
    num_idx = list(range(len(numeric_features)))
    scaler_nn = StandardScaler().fit(X_train[:, num_idx])

    def scale_for_nn(X):
        """Return a float32 copy of X with the numeric columns standardised."""
        scaled = X.copy().astype(np.float32)
        scaled[:, num_idx] = scaler_nn.transform(scaled[:, num_idx])
        return scaled

    X_train_nn, X_test_nn = (scale_for_nn(split) for split in (X_train, X_test))

    # The output layer is sized from the global label list, not from the labels
    # seen in this partition, so every client builds the same architecture.
    nn_model = Net(X_train.shape[1], len(UNIQUE_LABELS))

    client_kwargs = dict(
        tree_model=tree_model,
        nn_model=nn_model,
        X_train=X_train,
        y_train=y_train,
        X_test=X_test,
        y_test=y_test,
        X_train_nn=X_train_nn,
        X_test_nn=X_test_nn,
        dataset=dataset,
        client_id=partition_id + 1,  # client ids are 1-based
        feature_names=feature_names,
        label_encoder=label_encoder,
        scaler=scaler,
        numeric_features=numeric_features,
        encoder=encoder,
        preprocessor=preprocessor,
        scaler_nn_mean=scaler_nn.mean_,
        scaler_nn_scale=scaler_nn.scale_,
        num_idx=num_idx,
    )
    return FlowerClient(**client_kwargs).to_client()

# Flower client application; `client_fn` is the factory it calls to build clients.
client_app = ClientApp(client_fn=client_fn)


# Servidor

In [5]:
# ============================
# 📦 IMPORTACIONES NECESARIAS
# ============================
import os
import time
import json
import numpy as np
from typing import List, Tuple, Dict
from sklearn.tree import DecisionTreeClassifier

from flwr.common import Context, Metrics, Scalar, ndarrays_to_parameters
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.strategy import FedAvg

from graphviz import Digraph
from lore_sa.surrogate.decision_tree import SuperTree

import torch
import torch.nn as nn
import torch.nn.functional as F


# ============================
# GLOBAL CONFIGURATION
# ============================
# Disabled overrides kept from the original cell:
# MIN_AVAILABLE_CLIENTS = 4
# NUM_SERVER_ROUNDS = 2

# Filled dynamically; `server_fn` falls back to placeholder values when they
# are still empty.
FEATURES = []
UNIQUE_LABELS = []
# Written by the evaluate-aggregation hook installed in `server_fn` and sent
# to the clients by `_inject_round`.
LATEST_SUPERTREE_JSON = None
GLOBAL_MAPPING_JSON = None
FEATURE_NAMES_JSON = None
# Not referenced anywhere else in this cell.
GLOBAL_SCALER_JSON = None


# ============================
# 🧠 UTILIDADES MODELO
# ============================
def create_model(input_dim, output_dim):
    """Instantiate the notebook's ``Net`` with the given input/output sizes.

    Args:
        input_dim: First positional argument passed to ``Net``.
        output_dim: Second positional argument passed to ``Net``.

    Returns:
        A new ``Net`` instance.
    """
    # Net is resolved at call time from the running notebook (`__main__`);
    # needed when Net is defined in this same notebook.
    from __main__ import Net
    return Net(input_dim, output_dim)


def get_model_parameters(tree_model, nn_model):
    """Collect the parameters exchanged for the tree and the neural network.

    Args:
        tree_model: Accepted for interface symmetry; not read.
        nn_model: Module whose ``state_dict()`` tensors are exported.

    Returns:
        Dict with ``"tree"`` (the fixed list ``[-1, 2, 1]``; its meaning is not
        evident from this function -- TODO confirm) and ``"nn"`` (one NumPy
        array per ``state_dict`` entry, in ``state_dict`` order).
    """
    nn_weights = []
    for tensor in nn_model.state_dict().values():
        # Move to CPU and detach from autograd before converting to NumPy.
        nn_weights.append(tensor.cpu().detach().numpy())

    return dict(tree=[-1, 2, 1], nn=nn_weights)

def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Dict[str, Scalar]:
    """Average the numeric client metrics, weighted by number of examples.

    Args:
        metrics: ``(num_examples, metrics_dict)`` pairs, one per client.

    Returns:
        ``{metric_name: weighted mean}`` for every key that carried an ``int``
        or ``float`` value on at least one client. Other values (JSON strings,
        bytes, ...) are ignored. ``bool`` is a subclass of ``int``, so boolean
        metrics are averaged as 0/1. A key whose accumulated weight is 0 (every
        reporting client had ``num_examples == 0``) is omitted instead of
        raising ``ZeroDivisionError``.
    """
    weighted_sums: Dict[str, float] = {}
    total_weights: Dict[str, int] = {}

    for num_examples, client_metrics in metrics:
        for key, value in client_metrics.items():
            if not isinstance(value, (float, int)):
                continue  # non-numeric payload, nothing to average
            weighted_sums[key] = weighted_sums.get(key, 0.0) + num_examples * float(value)
            total_weights[key] = total_weights.get(key, 0) + num_examples

    return {
        key: weighted_sums[key] / total_weights[key]
        for key in weighted_sums
        if total_weights[key] != 0  # a zero total weight has no defined mean
    }

# ============================
# 🚀 SERVIDOR FLOWER
# ============================

def server_fn(context: Context) -> ServerAppComponents:
    """Build the Flower server components (FedAvg strategy + round config).

    The strategy starts from the weights of a freshly created ``Net``. Its
    ``configure_fit`` / ``configure_evaluate`` are wrapped with
    ``_inject_round``, and its ``aggregate_evaluate`` is replaced by a hook that
    merges the decision trees reported by the clients into a SuperTree.

    Args:
        context: Flower context (not read in this function).

    Returns:
        ``ServerAppComponents`` with the patched strategy and a ``ServerConfig``
        running ``NUM_SERVER_ROUNDS`` rounds.
    """
    global FEATURES, UNIQUE_LABELS

    # The model dimensions below come from FEATURES / UNIQUE_LABELS.
    if not FEATURES or not UNIQUE_LABELS:
        # The return value is discarded; presumably the loader fills the
        # FEATURES / UNIQUE_LABELS globals as a side effect -- TODO confirm.
        load_data_general(DATASET_NAME, CLASS_COLUMN, partition_id=0, num_partitions=NUM_CLIENTS)


    FEATURES = FEATURES or ["feat_0", "feat_1"]  # fallback in case nothing was loaded before
    UNIQUE_LABELS = UNIQUE_LABELS or ["Class_0", "Class_1"]


    model = create_model(len(FEATURES), len(UNIQUE_LABELS))
    initial_params = ndarrays_to_parameters(get_model_parameters(None, model)["nn"])

    strategy = FedAvg(
        min_available_clients=MIN_AVAILABLE_CLIENTS,
        fit_metrics_aggregation_fn=weighted_average,
        evaluate_metrics_aggregation_fn=weighted_average,
        initial_parameters=initial_params,
    )

    # Monkey-patch the strategy instance. The original aggregate_evaluate must
    # be captured before it is replaced further down.
    strategy.configure_fit = _inject_round(strategy.configure_fit)
    strategy.configure_evaluate = _inject_round(strategy.configure_evaluate)
    original_aggregate = strategy.aggregate_evaluate

    def custom_aggregate_evaluate(server_round, results, failures):
        """Run the stock aggregation, then merge the clients' trees.

        Reads these keys from each ``evaluate_res.metrics`` (``<cid>`` is the
        text after the last underscore of the key):
        ``distinct_values_<cid>``, ``tree_ensemble_<cid>``,
        ``encoded_feature_names_<cid>``, ``numeric_features_<cid>`` and
        ``unique_labels_<cid>``; all are decoded with ``json.loads``.

        Returns whatever the original ``aggregate_evaluate`` returned. Any
        exception raised while building the SuperTree is printed and swallowed.
        """
        global LATEST_SUPERTREE_JSON, GLOBAL_MAPPING_JSON, FEATURE_NAMES_JSON
        aggregated_metrics = original_aggregate(server_round, results, failures)

        # ============================
        # Rounds after the training rounds (the original note calls this the
        # final round): do not merge anything.
        # ============================
        if server_round > NUM_TRAIN_ROUNDS:
            return aggregated_metrics

        try:
            print(f"\n[SERVIDOR] üå≤ Generando SuperTree - Ronda {server_round}")
            from collections import defaultdict

            tree_nodes = []
            all_distincts = defaultdict(set)
            # NOTE(review): filled below but never read afterwards.
            client_encoders = {}

            feature_names = None
            numeric_features = None
            class_names = None

            # 1) collect the categorical value sets reported by every client
            for (_, evaluate_res) in results:
                metrics = evaluate_res.metrics
                # distinct_values_* entries feed the global mapping
                for k, v in metrics.items():
                    if k.startswith("distinct_values_"):
                        cid = k.split("_")[-1]
                        enc = json.loads(v)
                        client_encoders[cid] = enc
                        for feat, d in enc.items():
                            all_distincts[feat].update(d["distinct_values"])

            # Union of the values seen across clients, sorted per feature.
            global_mapping = {feat: sorted(list(vals)) for feat, vals in all_distincts.items()}

            # 2) collect the trees and the remaining metadata per client
            for (_, evaluate_res) in results:
                metrics = evaluate_res.metrics
                for k, v in metrics.items():
                    if k.startswith("tree_ensemble_"):
                        cid = k.split("_")[-1]
                        trees_list = json.loads(v)

                        # Read once, from the first client that provides them
                        # (the original note says they are equal across clients).
                        if feature_names is None and f"encoded_feature_names_{cid}" in metrics:
                            feature_names = json.loads(metrics[f"encoded_feature_names_{cid}"])
                        if numeric_features is None and f"numeric_features_{cid}" in metrics:
                            numeric_features = json.loads(metrics[f"numeric_features_{cid}"])
                        if class_names is None and f"unique_labels_{cid}" in metrics:
                            class_names = json.loads(metrics[f"unique_labels_{cid}"])

                        for tdict in trees_list:
                            root = SuperTree.Node.from_dict(tdict)
                            tree_nodes.append(root)

            # No client sent trees: nothing to merge this round.
            if not tree_nodes:
                return aggregated_metrics

            # 3) merge
            # If no client sent unique_labels_*, class_names is still None and
            # len() raises; that is caught by the except below.
            st = SuperTree()
            st.mergeDecisionTrees(
                roots=tree_nodes,
                num_classes=len(class_names),
                feature_names=feature_names,
                categorical_features=list(global_mapping.keys()),
                global_mapping=global_mapping,
            )

            # Debug printing of the unpruned / pruned tree is disabled.
            st.prune_redundant_leaves_full()

            # Optional extra step, currently disabled:
            # st.merge_equal_class_leaves()

            # 4) save / publish
            save_supertree_plot(
                root_node=st.root,
                round_number=server_round,
                feature_names=feature_names,
                class_names=class_names,
                numeric_features=numeric_features,
                global_mapping=global_mapping,   # no scaler
            )

            # Published through module globals; _inject_round reads them when
            # configuring the next round.
            LATEST_SUPERTREE_JSON = json.dumps(st.root.to_dict())
            GLOBAL_MAPPING_JSON = json.dumps(global_mapping)
            FEATURE_NAMES_JSON = json.dumps(feature_names)

        except Exception as e:
            # Best effort: a failed merge must not abort the federated round.
            print(f"[SERVIDOR] ‚ùå Error en SuperTree: {e}")

        return aggregated_metrics


    strategy.aggregate_evaluate = custom_aggregate_evaluate
    return ServerAppComponents(strategy=strategy, config=ServerConfig(num_rounds=NUM_SERVER_ROUNDS))

# ============================
# 🧩 FUNCIONES AUXILIARES
# ============================
def _inject_round(original_fn):
    """Wrap a strategy ``configure_*`` method so it also fills each config.

    The returned wrapper calls ``original_fn`` and then, for every
    ``(client, instruction)`` pair it produced, writes into
    ``instruction.config``:

    * ``"server_round"``: always;
    * ``"supertree"``, ``"global_mapping"``, ``"feature_names"``: whenever a
      SuperTree JSON is available in ``LATEST_SUPERTREE_JSON``;
    * ``"explain_only" = True``: when the round equals ``NUM_SERVER_ROUNDS``.

    The (mutated) instruction list is returned unchanged otherwise.
    """
    def wrapper(server_round, parameters, client_manager):
        instructions = original_fn(server_round, parameters, client_manager)

        for _client, instruction in instructions:
            cfg = instruction.config
            cfg["server_round"] = server_round

            # Always ship the most recent SuperTree, if there is one yet.
            if LATEST_SUPERTREE_JSON:
                cfg.update(
                    supertree=LATEST_SUPERTREE_JSON,
                    global_mapping=GLOBAL_MAPPING_JSON,
                    feature_names=FEATURE_NAMES_JSON,
                )

            # Last round: clients run in explanation-only mode.
            if server_round == NUM_SERVER_ROUNDS:
                cfg["explain_only"] = True

        return instructions

    return wrapper



def print_supertree_legible_fusionado(
    node,
    feature_names,
    class_names,
    numeric_features,
    scaler,  # kept for backward compatibility; never read
    global_mapping,
    depth=0
):
    """Print a (merged) SuperTree to stdout, one condition per line.

    Each level is indented with ``"|   "``. The branch taken for an inner node
    depends on its feature name, checked in this order: numeric feature,
    one-hot column (name contains ``=`` or ``_``), key of ``global_mapping``,
    unknown.

    Args:
        node: Tree node exposing ``is_leaf`` and ``labels`` (leaves) or
            ``feat``, ``intervals`` and ``children`` (inner nodes); ``None``
            is reported and skipped.
        feature_names: Sequence indexed by ``node.feat``.
        class_names: Sequence indexed by the argmax of ``node.labels``.
        numeric_features: Names treated as numeric; ``None`` means none.
        scaler: Unused.
        global_mapping: ``feature -> list of category values`` (may be falsy).
        depth: Current indentation level.

    A one-hot node whose child count is not 2 is reported with an ``[ERROR]``
    line and then printed anyway: children beyond the second get a generic
    label instead of raising ``IndexError``.

    NOTE(review): for mapped categorical nodes this printer labels child 0 as
    ``!=`` and child 1 as ``==``, while ``save_supertree_plot`` labels them the
    other way round -- confirm which order SuperTree actually uses.
    """
    import numpy as np
    indent = "|   " * depth
    if node is None:
        print(f"{indent}[Nodo None]")
        return

    if getattr(node, "is_leaf", False):
        class_idx = int(np.argmax(node.labels))
        print(f"{indent}class: {class_names[class_idx]} (pred: {node.labels})")
        return

    if numeric_features is None:
        numeric_features = ()  # avoid `in None` below

    def descend(child):
        """Print a child subtree one level deeper."""
        print_supertree_legible_fusionado(
            child, feature_names, class_names, numeric_features,
            scaler, global_mapping, depth + 1
        )

    feat_idx = node.feat
    feat_name = feature_names[feat_idx]
    intervals = node.intervals
    children = node.children

    # ====== NUMERIC ======
    if feat_name in numeric_features:
        # bounds[i], bounds[i + 1] delimit child i; pad with +inf when the node
        # stores fewer cut points than children.
        bounds = [-np.inf] + list(intervals)
        while len(bounds) < len(children) + 1:
            bounds.append(np.inf)

        for i, child in enumerate(children):
            left_real = bounds[i]
            right_real = bounds[i + 1]

            if i == 0:
                cond = f"{feat_name} ‚â§ {right_real:.2f}"
            elif i == len(children) - 1:
                cond = f"{feat_name} > {left_real:.2f}"
            else:
                cond = f"{feat_name} ‚àà ({left_real:.2f}, {right_real:.2f}]"
            print(f"{indent}{cond}")
            descend(child)

    # ====== ONE-HOT CATEGORICAL ======
    elif "=" in feat_name or "_" in feat_name:
        # Supports 'var=value' or 'var_value'
        if "=" in feat_name:
            var, val = feat_name.split("=", 1)
        else:
            var, val = feat_name.split("_", 1)
        var = var.strip()
        val = val.strip()

        if len(children) != 2:
            print(f"[ERROR] Nodo OneHot {feat_name} tiene {len(children)} hijos, esperado 2.")

        # First !=, then ==
        conds = [
            f'{var} != "{val}"',
            f'{var} == "{val}"'
        ]
        for i, child in enumerate(children):
            if i < len(conds):
                cond = conds[i]
            else:
                # Malformed node (already reported above): keep printing.
                cond = f'{var} ? "{val}" [hijo {i}]'
            print(f"{indent}{cond}")
            descend(child)

    # ====== MAPPED (ORDINAL) CATEGORICAL ======
    elif global_mapping and feat_name in global_mapping:
        vals_cat = global_mapping[feat_name]
        # First !=, then ==
        for i, child in enumerate(children):
            try:
                if i < len(intervals):
                    val_idx = intervals[i]
                else:
                    val_idx = int(getattr(node, "thresh", 0))
                val = vals_cat[val_idx] if val_idx < len(vals_cat) else f"desconocido({val_idx})"
            except Exception as e:
                print(f"[DEPURACI√ìN] Error interpretando categ√≥rica: {e}")
                val = "?"
            cond = f'{feat_name} != "{val}"' if i == 0 else f'{feat_name} == "{val}"'
            print(f"{indent}{cond}")
            descend(child)

    # ====== UNKNOWN TYPE ======
    else:
        print(f"{indent}{feat_name} [tipo desconocido]")
        print(f"    [DEPURACI√ìN] Nombres de features: {feature_names}")
        print(f"    [DEPURACI√ìN] Nombres num√©ricas: {numeric_features}")
        print(f"    [DEPURACI√ìN] global_mapping: {list(global_mapping.keys()) if global_mapping else None}")
        print(f"    [DEPURACI√ìN] children: {len(children)}")
        for child in children:
            descend(child)



def save_supertree_plot(
    root_node,
    round_number,
    feature_names,
    class_names,
    numeric_features,
    global_mapping,
    folder="Supertree",
):
    """Render a SuperTree as a Graphviz PDF and return the PDF path.

    The file is written to ``Ronda_<round_number>/<folder>/
    supertree_ronda_<round_number>.pdf`` (the folder is created if needed).

    Feature types are resolved in the same order as
    ``print_supertree_legible_fusionado``: numeric first, then one-hot
    (name contains ``_``), then keys of ``global_mapping``. Checking numeric
    first keeps a numeric feature whose name contains an underscore from being
    drawn as a one-hot split with a truncated label. Every child of a node is
    drawn, so multi-way numeric splits of a merged tree are not cut to two.

    Args:
        root_node: Root node (``is_leaf``, ``labels``, ``feat``, ``intervals``
            / ``thresh``, ``children``).
        round_number: Round used in the output path.
        feature_names: Sequence indexed by ``node.feat``.
        class_names: Sequence indexed by the argmax of ``node.labels``.
        numeric_features: Names treated as numeric; ``None`` means none.
        global_mapping: ``feature -> list of category values``; ``None`` means
            empty.
        folder: Sub-folder inside ``Ronda_<round_number>``.

    Returns:
        The path of the rendered PDF as a string.

    NOTE(review): mapped categorical nodes are labelled ``=`` for child 0 and
    the "not equal" sign for child 1 (unchanged), the opposite of
    ``print_supertree_legible_fusionado`` -- confirm the real child order.
    """
    from graphviz import Digraph
    import numpy as np
    import os

    numeric_names = numeric_features if numeric_features is not None else ()
    mapping = global_mapping if global_mapping is not None else {}

    dot = Digraph()
    node_id = [0]  # mutable counter shared with the recursive helper

    def cut_points(node):
        """Return the node's cut points as a list (possibly empty)."""
        cuts = getattr(node, "intervals", None)
        return list(cuts) if cuts is not None else []

    def numeric_edge_labels(node):
        """Return one edge label per child of a numeric split node."""
        cuts = cut_points(node) or [node.thresh]
        n_children = len(node.children)
        bounds = [-np.inf] + cuts
        while len(bounds) < n_children + 1:
            bounds.append(np.inf)

        labels = []
        for i in range(n_children):
            if i == 0:
                labels.append(f"‚â§ {bounds[1]:.2f}")
            elif i == n_children - 1:
                labels.append(f"> {bounds[i]:.2f}")
            else:
                labels.append(f"({bounds[i]:.2f}, {bounds[i + 1]:.2f}]")
        return labels

    def add_node(node, parent=None, edge_label=""):
        curr = str(node_id[0])
        node_id[0] += 1

        if node.is_leaf:
            class_index = int(np.argmax(node.labels))
            dot.node(curr, f"class: {class_names[class_index]}\n{node.labels}")
            if parent is not None:
                dot.edge(parent, curr, label=edge_label)
            return

        fname = feature_names[node.feat]
        is_numeric = fname in numeric_names
        is_onehot = (not is_numeric) and "_" in fname

        # One-hot columns show only the variable name; the value goes on the edges.
        dot.node(curr, fname.split("_", 1)[0] if is_onehot else fname)
        if parent is not None:
            dot.edge(parent, curr, label=edge_label)

        if is_numeric:
            labels = numeric_edge_labels(node)
        elif is_onehot:
            val = fname.split("_", 1)[1].strip()
            labels = [f'‚â† "{val}"', f'= "{val}"']
        elif fname in mapping:
            cuts = cut_points(node)
            val = mapping[fname][cuts[0]] if cuts else "?"
            labels = [f'= "{val}"', f'‚â† "{val}"']
        else:
            labels = []

        # Children without a known label (unknown type or malformed node) get "?".
        for idx, child in enumerate(node.children):
            add_node(child, curr, labels[idx] if idx < len(labels) else "?")

    folder_path = f"Ronda_{round_number}/{folder}"
    os.makedirs(folder_path, exist_ok=True)
    filename = f"{folder_path}/supertree_ronda_{round_number}"
    add_node(root_node)
    dot.render(filename, format="pdf", cleanup=True)
    return f"{filename}.pdf"




# ============================
# INITIALISE SERVER APP
# ============================
# Flower server application; `server_fn` builds its strategy and round config.
server_app = ServerApp(server_fn=server_fn)



In [6]:
from flwr.simulation import run_simulation
import logging
import warnings
import ray
import cProfile
import pstats

# Hide DeprecationWarning noise emitted while the simulation runs.
warnings.filterwarnings("ignore", category=DeprecationWarning)

# Raise chatty third-party loggers to WARNING. The "flwr" logger is not
# touched here, so Flower's own progress messages keep their default level.
for _logger_name in ("matplotlib", "filelock", "ray", "graphviz", "urllib3", "fsspec"):
    logging.getLogger(_logger_name).setLevel(logging.WARNING)
# Root logger as well (switch to ERROR to hide even more).
logging.getLogger().setLevel(logging.WARNING)

# Stop any previous Ray session, then start a new one in local mode
# (per the original note: no multiprocessing, a single main process).
ray.shutdown()
ray.init(local_mode=True)

backend_config = dict(num_cpus=1)

# Launch the federated simulation with one supernode per client.
run_simulation(
    server_app=server_app,
    client_app=client_app,
    num_supernodes=NUM_CLIENTS,
    backend_config=backend_config,
)

2025-12-11 13:48:47,353	INFO worker.py:1832 -- Started a local Ray instance. View the dashboard at [1m[32mhttp://127.0.0.1:8265 [39m[22m
2025-12-11 13:48:51,770 flwr         DEBUG    Asyncio event loop already running.
:job_id:01000000
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor


:job_id:01000000
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor


[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=3, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Using initial global parameters provided by strategy
[92mINFO [0m:      Starting evaluation of initial global parameters
[92mINFO [0m:      Evaluation returned no results (`None`)
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: strategy sampled 6 clients (out of 6)


[CLIENTE 2] ‚úÖ Red neuronal entrenada
[CLIENTE 4] ‚úÖ Red neuronal entrenada
[CLIENTE 3] ‚úÖ Red neuronal entrenada
[CLIENTE 1] ‚úÖ Red neuronal entrenada
[CLIENTE 6] ‚úÖ Red neuronal entrenada


[92mINFO [0m:      aggregate_fit: received 6 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 6 clients (out of 6)


[CLIENTE 5] ‚úÖ Red neuronal entrenada


[92mINFO [0m:      aggregate_evaluate: received 6 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 6 clients (out of 6)



[SERVIDOR] üå≤ Generando SuperTree - Ronda 1


[92mINFO [0m:      aggregate_fit: received 6 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 6 clients (out of 6)


[CLIENTE 2] ‚úÖ Red neuronal entrenada
[CLIENTE 1] ‚úÖ Red neuronal entrenada
[CLIENTE 5] ‚úÖ Red neuronal entrenada
[CLIENTE 4] ‚úÖ Red neuronal entrenada
[CLIENTE 6] ‚úÖ Red neuronal entrenada
[CLIENTE 3] ‚úÖ Red neuronal entrenada
Recibiendo supertree....
Recibiendo supertree....
Recibiendo supertree....
Recibiendo supertree....
Recibiendo supertree....
Recibiendo supertree....


[92mINFO [0m:      aggregate_evaluate: received 6 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 6 clients (out of 6)



[SERVIDOR] üå≤ Generando SuperTree - Ronda 2


[92mINFO [0m:      aggregate_fit: received 6 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 6 clients (out of 6)


Recibiendo supertree....
[CLIENTE 2] üîç Ronda final: solo explicaciones


Cliente 2 explicando instancia 0:   0%|          | 0/1 [00:00<?, ?it/s]

Recibiendo supertree....
[CLIENTE 3] üîç Ronda final: solo explicaciones




Recibiendo supertree....
[CLIENTE 5] üîç Ronda final: solo explicaciones



[A

Recibiendo supertree....
[CLIENTE 1] üîç Ronda final: solo explicaciones




[A[A

Recibiendo supertree....
[CLIENTE 4] üîç Ronda final: solo explicaciones





[A[A[A

Recibiendo supertree....
[CLIENTE 6] üîç Ronda final: solo explicaciones






[A[A[A[A

Cliente 1 explicando instancia 0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1/1 [00:14<00:00, 14.25s/it]
Cliente 2 explicando instancia 0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1/1 [00:15<00:00, 15.37s/it]
Cliente 3 explicando instancia 0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1/1 [00:15<00:00, 15.61s/it]



Cliente 4 explicando instancia 0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1/1 [00:15<00:00, 15.67s/it]

Cliente 5 explicando instancia 0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1/1 [00:16<00:00, 16.75s/it]




Cliente 6 explicando instancia 0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1/1 [00:10<00:00, 10.72s/it]
[92mINFO [0m:      aggregate_evaluate: received 6 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [SUMMARY]
[92mINFO [0m:      Run finished 3 round(s) in 40.14s
[92mINFO [0m:      	History (loss, distributed):
[92mINFO [0m:      		round 1: 0.0
[92mINFO [0m:      		round 2: 0.0
[92mINFO [0m:      		round 3: 0.0
[92mINFO [0m:      


### BALANCED METRICS

In [7]:
# ==========================================
# 📊 Global average computed from the per-client *balanced_mean* files
# ==========================================
from pathlib import Path
import pandas as pd
import numpy as np
import re

csv_dir = Path("results")

files = sorted(csv_dir.glob("metrics_cliente_*_balanced_mean.csv"))
if not files:
    # Fail early with an actionable message instead of an IndexError on dfs[0] below.
    raise FileNotFoundError(
        f"No 'metrics_cliente_*_balanced_mean.csv' files found in {csv_dir.resolve()}"
    )

print("Voy a usar estos ficheros:")
for f in files:
    print("  -", f.name)

# Read each per-client file and rename its first two columns to
# [metric, value_<file stem>] so every client contributes its own value column.
dfs = []
for f in files:
    df = pd.read_csv(f)          # the header row (",0") is used for the column names
    # df.columns is typically something like ["Unnamed: 0", "0"]
    df = df.rename(columns={df.columns[0]: "metric", df.columns[1]: f"value_{f.stem}"})
    dfs.append(df)


# Join on the metric name; the outer join keeps metrics that only some clients report.
merged = dfs[0]
for df in dfs[1:]:
    merged = merged.merge(df, on="metric", how="outer")

# Macro-average across clients; a client that lacks a metric is skipped (skipna).
value_cols = [c for c in merged.columns if c.startswith("value_")]
merged["mean"] = merged[value_cols].mean(axis=1, skipna=True)

means_df = merged[["metric", "mean"]].copy()

# ==========================================
# Collapse per-class (e.g. B/M) CF metrics into one row each (TEST and Z)
# ==========================================

# Middle part of the CF metric names, i.e. what sits between the metric kind
# and the class label in "cf_<kind>_<scope>_<class>_<suffix>".
cf_scopes = ("merged", "lore", "super", "local_local", "local_super", "localZ", "superZ")

# Metric kinds collapsed for each suffix. This reproduces the hand-written
# table it replaces: "comp" is collapsed for TEST only.
cf_kinds_by_suffix = {"TEST": ("cov", "comp", "prec"), "Z": ("cov", "prec")}

# collapsed name -> regex matching its per-class variants; the class label is
# any run of non-underscore characters between the scope and the suffix.
collapse_patterns = {
    f"cf_{kind}_{scope}_{suffix}": rf"^cf_{kind}_{scope}_[^_]+_{suffix}$"
    for suffix, kinds in cf_kinds_by_suffix.items()
    for scope in cf_scopes
    for kind in kinds
}

rows_new = []
for new_name, pattern in collapse_patterns.items():
    # na=False: a missing metric name counts as "no match" rather than putting
    # a NaN into the boolean mask (which pandas refuses to index with).
    mask = means_df["metric"].str.match(pattern, na=False)
    if not mask.any():
        continue
    # Average of the per-class means for this metric.
    rows_new.append({"metric": new_name, "mean": means_df.loc[mask, "mean"].mean()})

# Append the collapsed rows.
if rows_new:
    means_df = pd.concat([means_df, pd.DataFrame(rows_new)], ignore_index=True)

# Drop the class-specific CF metrics (class labels such as Yes, No, 0, 1, B, M),
# for both TEST and Z. The collapsed names have no class segment, so they stay.
# NOTE(review): this also drops per-class "cf_comp_*_Z" rows, for which no
# collapsed row is produced above - confirm that losing them is intended.
pattern_drop = (
    r"^cf_(cov|prec|comp)_"
    rf"({'|'.join(cf_scopes)})_"
    r"[^_]+_(TEST|Z)$"
)

means_df = means_df[~means_df["metric"].str.match(pattern_drop, na=False)].reset_index(drop=True)


# ==========================================
# Save the final result
# ==========================================

# Show the first 30 aggregated metrics in the notebook output.
display(means_df.iloc[:30])

# Write the complete table into the same folder the per-client files came from.
out_path = csv_dir.joinpath("metrics_Balanced_global.csv")
means_df.to_csv(out_path, encoding="utf-8", index=False)
print(f"\n‚úÖ Promedios globales guardados en: {out_path}")


Voy a usar estos ficheros:
  - metrics_cliente_1_balanced_mean.csv
  - metrics_cliente_2_balanced_mean.csv
  - metrics_cliente_3_balanced_mean.csv
  - metrics_cliente_4_balanced_mean.csv
  - metrics_cliente_5_balanced_mean.csv
  - metrics_cliente_6_balanced_mean.csv


Unnamed: 0,metric,mean
0,acc_localZ_Z,0.519097
1,acc_localZ_tree_TEST,1.0
2,acc_local_local,0.433682
3,acc_local_local_Z,0.977431
4,acc_local_super,0.433682
5,acc_local_super_Z,0.977431
6,acc_lore,0.433682
7,acc_lore_Z,0.977431
8,acc_merged,0.433682
9,acc_merged_Z,0.977431



‚úÖ Promedios globales guardados en: results\metrics_Balanced_global.csv
