In [1]:
# =======================
# 📦 IMPORTACIONES
# =======================

# Built-in
import os
import sys
import re
import time
import json
import random
import warnings
from typing import List, Tuple, Dict
import operator

# NumPy, Pandas, Matplotlib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Sklearn
from sklearn.tree import DecisionTreeClassifier, plot_tree, export_text
from sklearn.preprocessing import LabelEncoder, StandardScaler, OrdinalEncoder, OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.metrics import (
    log_loss, accuracy_score, precision_score, recall_score,
    f1_score, confusion_matrix, roc_auc_score, pairwise_distances
)
from sklearn.exceptions import NotFittedError
from collections import defaultdict

# Flower
from flwr.client import ClientApp, NumPyClient
from flwr.common import (
    Context, NDArrays, Metrics, Scalar,
    ndarrays_to_parameters, parameters_to_ndarrays
)
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.strategy import FedAvg
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F

# LORE
from lore_sa.dataset import TabularDataset
from lore_sa.bbox import sklearn_classifier_bbox
from lore_sa.encoder_decoder import ColumnTransformerEnc
from lore_sa.lore import TabularGeneticGeneratorLore
from lore_sa.surrogate.decision_tree import SuperTree
from lore_sa.rule import Expression, Rule

# Otros
from graphviz import Digraph


2025-07-16 12:24:20,621	INFO util.py:154 -- Outdated packages:
  ipywidgets==7.8.1 found, needs ipywidgets>=8
Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
2025-07-16 12:24:24,188 graphviz._tools DEBUG    deprecate positional args: graphviz.backend.piping.pipe(['renderer', 'formatter', 'neato_no_op', 'quiet'])
2025-07-16 12:24:24,196 graphviz._tools DEBUG    deprecate positional args: graphviz.backend.rendering.render(['renderer', 'formatter', 'neato_no_op', 'quiet'])
2025-07-16 12:24:24,200 graphviz._tools DEBUG    deprecate positional args: graphviz.backend.unflattening.unflatten(['stagger', 'fanout', 'chain', 'encoding'])
2025-07-16 12:24:24,203 graphviz._tools DEBUG    deprecate positional args: graphviz.backend.viewing.view(['quiet'])
2025-07-16 12:24:24,209 graphviz._tools DEBUG    deprecate positional args: graphviz.quoting.quote(['is_html_string', 'is_valid_id', 'dot_keywords', 'endswith_odd_number_of_backslashes', 'escape_unescaped

In [2]:
# =======================
# ⚙️ VARIABLES GLOBALES
# =======================
UNIQUE_LABELS = []
FEATURES = []
NUM_SERVER_ROUNDS = 2
NUM_CLIENTS = 4
MIN_AVAILABLE_CLIENTS = NUM_CLIENTS
fds = None  # Cache del FederatedDataset
CAT_ENCODINGS = {}
USING_DATASET = None





class Net(nn.Module):
    def __init__(self, input_dim: int, output_dim: int):
        super(Net, self).__init__()
        hidden_dim = max(8, input_dim * 2)  # algo proporcional

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim // 2)
        self.fc3 = nn.Linear(hidden_dim // 2, output_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
    

class TorchNNWrapper:
    def __init__(self, model):
        self.model = model
        self.model.eval()

    def predict(self, X):
        X = np.array(X, dtype=np.float32)
        with torch.no_grad():
            X_tensor = torch.tensor(X, dtype=torch.float32)
            outputs = self.model(X_tensor)
            return outputs.argmax(dim=1).numpy()

    def predict_proba(self, X):
        X = np.array(X, dtype=np.float32)
        with torch.no_grad():
            X_tensor = torch.tensor(X, dtype=torch.float32)
            outputs = self.model(X_tensor)
            probs = F.softmax(outputs, dim=1)
            return probs.numpy()

# =======================
# 🔧 UTILIDADES MODELO
# =======================

def get_model_parameters(tree_model, nn_model):
    tree_params = [
        int(tree_model.get_params()["max_depth"] or -1),
        int(tree_model.get_params()["min_samples_split"]),
        int(tree_model.get_params()["min_samples_leaf"]),
    ]
    nn_weights = [v.cpu().detach().numpy() for v in nn_model.state_dict().values()]
    return {
        "tree": tree_params,
        "nn": nn_weights,
    }


def set_model_params(tree_model, nn_model, params):
    tree_params = params["tree"]
    nn_weights = params["nn"]

    # Solo si tree_model no es None y tiene set_params
    if tree_model is not None and hasattr(tree_model, "set_params"):
        max_depth = tree_params[0] if tree_params[0] > 0 else None
        tree_model.set_params(
            max_depth=max_depth,
            min_samples_split=tree_params[1],
            min_samples_leaf=tree_params[2],
        )

    # Actualizar pesos de la red neuronal
    state_dict = nn_model.state_dict()
    for (key, _), val in zip(state_dict.items(), nn_weights):
        state_dict[key] = torch.tensor(val)
    nn_model.load_state_dict(state_dict)


# =======================
# 📥 CARGAR DATOS
# =======================

def get_global_onehot_info(flower_dataset_name, class_col):
    partitioner = IidPartitioner(num_partitions=1)
    fds_tmp = FederatedDataset(dataset=flower_dataset_name, partitioners={"train": partitioner})
    df = fds_tmp.load_partition(0, "train").with_format("pandas")[:]

    # Preprocesado estándar
    if "adult_small" in flower_dataset_name.lower():
        drop_cols = ['fnlwgt', 'education-num', 'capital-gain', 'capital-loss']
        df.drop(columns=[col for col in drop_cols if col in df.columns], inplace=True)
        df = df[~df["workclass"].isin([" ?"])]
        df = df[~df["occupation"].isin([" ?"])]
    elif "churn" in flower_dataset_name.lower():
        drop_cols = ['customerID', 'TotalCharges']
        df.drop(columns=[col for col in drop_cols if col in df.columns], inplace=True)
        df['MonthlyCharges'] = pd.to_numeric(df['MonthlyCharges'], errors='coerce')
        df['tenure'] = pd.to_numeric(df['tenure'], errors='coerce')
        df.dropna(subset=['MonthlyCharges', 'tenure'], inplace=True)

    for col in df.select_dtypes(include=["object"]).columns:
        if df[col].nunique() < 50:
            df[col] = df[col].astype("category")

    cat_features = [col for col in df.select_dtypes(include="category").columns if col != class_col]
    num_features = [col for col in df.columns if df[col].dtype.kind in "fi" and col != class_col]

    ohe = OneHotEncoder(handle_unknown="ignore")
    ohe.fit(df[cat_features])
    categories_global = ohe.categories_
    onehot_columns = ohe.get_feature_names_out(cat_features).tolist()
    return cat_features, num_features, categories_global, onehot_columns



def load_data_general(flower_dataset_name: str, class_col: str, partition_id: int, num_partitions: int):
    global fds, UNIQUE_LABELS, FEATURES

    # Saca info global siempre al principio
    cat_features, num_features, categories_global, onehot_columns = get_global_onehot_info(flower_dataset_name, class_col)

    if fds is None:
        partitioner = IidPartitioner(num_partitions=num_partitions)
        fds = FederatedDataset(dataset=flower_dataset_name, partitioners={"train": partitioner})

    dataset = fds.load_partition(partition_id, "train").with_format("pandas")[:]

    # Preprocesado específico por dataset
    if "adult_small" in flower_dataset_name.lower():
        drop_cols = ['fnlwgt', 'education-num', 'capital-gain', 'capital-loss']
        dataset.drop(columns=[col for col in drop_cols if col in dataset.columns], inplace=True)
        dataset = dataset[~dataset["workclass"].isin([" ?"])]
        dataset = dataset[~dataset["occupation"].isin([" ?"])]
    elif "churn" in flower_dataset_name.lower():
        drop_cols = ['customerID', 'TotalCharges']
        dataset.drop(columns=[col for col in drop_cols if col in dataset.columns], inplace=True)
        dataset['MonthlyCharges'] = pd.to_numeric(dataset['MonthlyCharges'], errors='coerce')
        dataset['tenure'] = pd.to_numeric(dataset['tenure'], errors='coerce')
        dataset.dropna(subset=['MonthlyCharges', 'tenure'], inplace=True)

    for col in dataset.select_dtypes(include=["object"]).columns:
        if dataset[col].nunique() < 50:
            dataset[col] = dataset[col].astype("category")

    class_original = dataset[class_col].copy()
    tabular_dataset = TabularDataset(dataset.copy(), class_name=class_col)
    descriptor = tabular_dataset.descriptor

    for col, info in descriptor["categorical"].items():
        if "distinct_values" not in info:
            info["distinct_values"] = list(dataset[col].dropna().unique())

    label_encoder = LabelEncoder()
    label_encoder.fit(dataset[class_col])
    if not UNIQUE_LABELS:
        UNIQUE_LABELS[:] = label_encoder.classes_.tolist()
    label_encoder.classes_ = np.array(UNIQUE_LABELS)
    dataset[class_col] = label_encoder.transform(dataset[class_col])
    dataset.rename(columns={class_col: "class"}, inplace=True)
    y = dataset["class"].reset_index(drop=True).to_numpy()

    numeric_features = list(descriptor["numeric"].keys())
    categorical_features = list(descriptor["categorical"].keys())
    FEATURES[:] = numeric_features + categorical_features

    numeric_indices = list(range(len(numeric_features)))
    categorical_indices = list(range(len(numeric_features), len(FEATURES)))

    X_array = dataset[FEATURES].to_numpy()

    preprocessor = ColumnTransformer([
        ("num", StandardScaler(), numeric_indices),
        ("cat", OneHotEncoder(sparse_output=False, handle_unknown="ignore", categories=categories_global), categorical_indices)
    ])
    X_encoded = preprocessor.fit_transform(X_array)

    # Reconstrucción del DataFrame
    num_out = X_encoded[:, :len(numeric_features)]
    cat_out = X_encoded[:, len(numeric_features):]
    if categorical_features:
        cat_names = preprocessor.named_transformers_["cat"].get_feature_names_out(categorical_features)
    else:
        cat_names = []

    num_names = numeric_features

    X_df = pd.DataFrame(num_out, columns=num_names)
    if len(cat_names) > 0:
        X_cat_df = pd.DataFrame(cat_out, columns=cat_names)
        X_full = pd.concat([X_df.reset_index(drop=True), X_cat_df.reset_index(drop=True)], axis=1)
        for col in onehot_columns:
            if col not in X_cat_df.columns:
                X_full[col] = 0
    else:
        X_full = X_df

    # Rellenar columnas onehot que falten y ordenar
    final_columns = num_names + list(cat_names)
    X_full = X_full[final_columns]
    FEATURES[:] = final_columns

    split_idx = int(0.8 * len(X_full))

        # --- ¡Construye el descriptor global! ---
    descriptor_global = descriptor.copy()
    for i, col in enumerate(cat_features):
        if col in descriptor_global["categorical"]:
            descriptor_global["categorical"][col]["distinct_values"] = list(categories_global[i])

    encoder = ColumnTransformerEnc(descriptor_global)

    return (
        X_full.iloc[:split_idx].to_numpy(), y[:split_idx],
        X_full.iloc[split_idx:].to_numpy(), y[split_idx:],
        tabular_dataset, final_columns, label_encoder,
        preprocessor.named_transformers_["num"], numeric_features, encoder, preprocessor
    )

# =======================



DATASET_NAME = "pablopalacios23/adult_small"
CLASS_COLUMN = "class"


# DATASET_NAME = "pablopalacios23/Iris"
# CLASS_COLUMN = "target"


# DEMASIADO GRANDE EL DATASET :/
# DATASET_NAME = "pablopalacios23/churn"
# CLASS_COLUMN = "Churn" 
 

# =======================


# load_data_general(DATASET_NAME, CLASS_COLUMN, partition_id=0, num_partitions=NUM_CLIENTS)

In [3]:
X_train, y_train, X_test, y_test, dataset, feature_names, label_encoder, scaler, numeric_features, encoder, preprocessor = load_data_general(
    DATASET_NAME, CLASS_COLUMN, partition_id=1, num_partitions=NUM_CLIENTS
)

# Mostrar 5 primeros valores
print("\n📦 X_train (primeras filas):")
print(pd.DataFrame(X_train))

print("\n🎯 y_train (primeros valores):")
print(y_train)

print("\n📦 X_test (primeras filas):")
print(pd.DataFrame(X_test))

print("\n🎯 y_test (primeros valores):")
print(y_test)

# print(encoder)


2025-07-16 12:24:24,288 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): huggingface.co:443
2025-07-16 12:24:24,566 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /datasets/pablopalacios23/adult_small/resolve/main/README.md HTTP/11" 404 0
2025-07-16 12:24:24,747 urllib3.connectionpool DEBUG    https://huggingface.co:443 "GET /api/datasets/pablopalacios23/adult_small HTTP/11" 200 612
2025-07-16 12:24:24,873 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /datasets/pablopalacios23/adult_small/resolve/475f19aed5f80dea1d48deab705f11928fe27493/adult_small.py HTTP/11" 404 0
2025-07-16 12:24:24,873 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): s3.amazonaws.com:443
2025-07-16 12:24:25,257 urllib3.connectionpool DEBUG    https://s3.amazonaws.com:443 "HEAD /datasets.huggingface.co/datasets/datasets/pablopalacios23/adult_small/pablopalacios23/adult_small.py HTTP/11" 404 0
2025-07-16 12:24:25,399 urllib3.connectionpool DEBUG


📦 X_train (primeras filas):
         0         1    2    3    4    5    6    7    8    9   ...   34   35  \
0  0.800327  1.345708  0.0  0.0  0.0  0.0  1.0  0.0  0.0  0.0  ...  0.0  0.0   
1 -0.457330  0.681161  0.0  0.0  1.0  0.0  0.0  0.0  0.0  1.0  ...  1.0  0.0   
2  0.342997 -2.375756  0.0  0.0  1.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  1.0   
3 -0.114332  0.016614  0.0  0.0  1.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0   
4 -1.028992  0.016614  0.0  0.0  1.0  0.0  0.0  1.0  0.0  0.0  ...  1.0  0.0   
5  1.143324  0.016614  0.0  0.0  1.0  0.0  0.0  1.0  0.0  0.0  ...  0.0  0.0   

    36   37   38   39   40   41   42   43  
0  0.0  0.0  1.0  0.0  1.0  0.0  0.0  1.0  
1  0.0  0.0  1.0  1.0  0.0  0.0  0.0  1.0  
2  0.0  0.0  1.0  1.0  0.0  0.0  0.0  1.0  
3  0.0  0.0  1.0  1.0  0.0  0.0  0.0  1.0  
4  0.0  1.0  0.0  1.0  0.0  0.0  0.0  1.0  
5  0.0  1.0  0.0  0.0  1.0  0.0  0.0  1.0  

[6 rows x 44 columns]

🎯 y_train (primeros valores):
[1 0 1 1 0 0]

📦 X_test (primeras filas):
      

# Cliente

In [4]:
# ==========================
# 🌼 CLIENTE FLOWER
# ==========================
import operator
import warnings
import os
import json
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import (
    log_loss, accuracy_score, precision_score, recall_score,
    f1_score, roc_auc_score
)
from sklearn.exceptions import NotFittedError

import torch
import torch.nn as nn
import torch.nn.functional as F

from flwr.client import NumPyClient
from flwr.common import Context
from flwr.common import parameters_to_ndarrays

from lore_sa.dataset import TabularDataset
from lore_sa.bbox import sklearn_classifier_bbox
from lore_sa.lore import TabularGeneticGeneratorLore
from lore_sa.rule import Expression, Rule
from lore_sa.surrogate.decision_tree import SuperTree
from lore_sa.encoder_decoder import ColumnTransformerEnc

from sklearn.metrics import pairwise_distances

from graphviz import Digraph

class TorchNNWrapper:
    def __init__(self, model):
        self.model = model
        self.model.eval()

    def predict(self, X):
        X = np.array(X, dtype=np.float32)
        with torch.no_grad():
            X_tensor = torch.tensor(X, dtype=torch.float32)
            outputs = self.model(X_tensor)
            return outputs.argmax(dim=1).numpy()

    def predict_proba(self, X):
        X = np.array(X, dtype=np.float32)
        with torch.no_grad():
            X_tensor = torch.tensor(X, dtype=torch.float32)
            outputs = self.model(X_tensor)
            probs = F.softmax(outputs, dim=1)
            return probs.numpy()
        

class FlowerClient(NumPyClient):
    def __init__(self, tree_model, nn_model, X_train, y_train, X_test, y_test, dataset, client_id, feature_names, label_encoder, scaler, numeric_features, encoder, preprocessor):
        self.tree_model = tree_model
        self.nn_model = nn_model
        self.X_train = X_train
        self.y_train = y_train
        self.X_test = X_test
        self.y_test = y_test
        self.dataset = dataset
        self.client_id = client_id
        self.feature_names = feature_names
        self.label_encoder = label_encoder
        self.scaler = scaler
        self.numeric_features = numeric_features
        self.encoder = encoder
        self.unique_labels = label_encoder.classes_.tolist()
        self.y_train_nn = y_train.astype(np.int64)
        self.y_test_nn = y_test.astype(np.int64)
        self.received_supertree = None
        self.preprocessor = preprocessor

    def _train_nn(self, epochs=10, lr=0.01):
        self.nn_model.train()
        optimizer = torch.optim.Adam(self.nn_model.parameters(), lr=lr)
        loss_fn = nn.CrossEntropyLoss()
        X_tensor = torch.tensor(self.X_train, dtype=torch.float32)
        y_tensor = torch.tensor(self.y_train_nn, dtype=torch.long)

        for _ in range(epochs):
            optimizer.zero_grad()
            outputs = self.nn_model(X_tensor)
            loss = loss_fn(outputs, y_tensor)
            loss.backward()
            optimizer.step()
        print(f"[CLIENTE {self.client_id}] ✅ Red neuronal entrenada")

    def decode_onehot_instance(self, X_row, numeric_features, encoder, scaler, feature_names):
        """
        X_row: array 1D (ya escalada y codificada con OneHotEncoder)
        numeric_features: lista de nombres de variables numéricas originales
        encoder: tu ColumnTransformerEnc, con .dataset_descriptor
        scaler: StandardScaler ajustado a las numéricas
        feature_names: nombres de TODAS las columnas tras preprocesar (el orden de X_row)
        """
        import numpy as np
        import pandas as pd

        data = {}

        # 1. Variables numéricas: desescalar y añadir
        for i, col in enumerate(numeric_features):
            mean = scaler.mean_[i]
            std = scaler.scale_[i]
            val = X_row[i] * std + mean
            data[col] = val

        # 2. Variables categóricas: busca los 1s en columnas OneHot
        cat_map = encoder.dataset_descriptor["categorical"]
        cat_cols = list(cat_map.keys())
        start = len(numeric_features)

        for col in cat_cols:
            # Todas las columnas onehot de esa variable
            prefix = col + "_"
            candidates = [fname for fname in feature_names if fname.startswith(prefix)]
            found = False
            for cname in candidates:
                idx = feature_names.index(cname)
                if X_row[idx] == 1:
                    valor = cname[len(prefix):].strip()
                    data[col] = valor
                    found = True
                    break
            if not found:
                data[col] = None

        return pd.Series(data)



    def fit(self, parameters, config):
        set_model_params(self.tree_model, self.nn_model, {"tree": [-1, 2, 1], "nn": parameters})
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")


            self.tree_model.fit(self.X_train, self.y_train)
            self._train_nn()


        nn_weights = get_model_parameters(self.tree_model, self.nn_model)["nn"]
        return nn_weights, len(self.X_train), {}
    


    def evaluate(self, parameters, config):


        set_model_params(self.tree_model, self.nn_model, {"tree": [-1, 2, 1], "nn": parameters})

        if "supertree" in config:
            try:
                print("Recibiendo supertree....")
                supertree_dict = json.loads(config["supertree"])
                
                # print("supertree_dict")
                # print("supertree_dict:", supertree_dict)
                # print("type:", type(supertree_dict))
                # print("dir(supertree_dict):", dir(supertree_dict))
                # print("\n")

                self.received_supertree = SuperTree.convert_SuperNode_to_Node(SuperTree.SuperNode.from_dict(supertree_dict))
                self.global_mapping = json.loads(config["global_mapping"])
                self.feature_names = json.loads(config["feature_names"])
                self.global_scaler = json.loads(config["global_scaler"])

            except Exception as e:
                print(f"[CLIENTE {self.client_id}] ❌ Error al recibir SuperTree: {e}")

        try:
            _ = self.tree_model.predict(self.X_test)
        except NotFittedError:
            self.tree_model.fit(self.X_train, self.y_train)

        round_number = config.get("server_round", 1)
        
        supertree = SuperTree()
        root_node = supertree.rec_buildTree(self.tree_model, list(range(self.X_train.shape[1])), len(self.unique_labels))

        # print(f"[CLIENTE {self.client_id}]")
        # print(export_text(self.tree_model, feature_names=FEATURES))
        # print("root_node:", root_node)
        # print("type:", type(root_node))
        # print(dir(root_node))
        # print("\n")
        # print("FEATURES:", FEATURES)

        
        self._save_local_tree(root_node, round_number, FEATURES, self.numeric_features, self.scaler, UNIQUE_LABELS, self.encoder)
        tree_json = json.dumps([root_node.to_dict()])

        if self.received_supertree is not None:
            self._explain_local_and_global(config)

        return 0.0, len(self.X_test), {
            f"tree_ensemble_{self.client_id}": tree_json,
            f"scaler_mean_{self.client_id}": json.dumps(self.scaler.mean_.tolist()),
            f"scaler_std_{self.client_id}": json.dumps(self.scaler.scale_.tolist()),
            f"encoded_feature_names_{self.client_id}": json.dumps(FEATURES),
            f"numeric_features_{self.client_id}": json.dumps(self.numeric_features),
            f"unique_labels_{self.client_id}": json.dumps(self.unique_labels),
            f"encoder_descriptor_{self.client_id}": json.dumps(self.encoder.dataset_descriptor),
            f"distinct_values_{self.client_id}": json.dumps(self.encoder.dataset_descriptor["categorical"])
        }
    
    def _explain_local_and_global(self, config):
        from sklearn.metrics import accuracy_score
        import numpy as np
    
        num_row = 0

        # 1. Visualizar instancia escalada y decodificada usando el encoder/preprocessor ORIGINAL
    
        
        decoded = self.decode_onehot_instance(
            self.X_test[num_row],
            self.numeric_features,
            self.encoder,
            self.scaler,
            self.feature_names
        )

        # print(f"\n[CLIENTE {self.client_id}] 🧪 Instancia a explicar (decodificada):")
        # print(decoded)
        # print(f"[CLIENTE {self.client_id}] 🧪 Clase real: {self.label_encoder.inverse_transform([self.y_test_nn[num_row]])[0]}")

        # Asegúrate de que X_test[num_row] es un numpy array del shape correcto (1, n_features)
        x_tensor = torch.tensor(self.X_test[num_row], dtype=torch.float32).unsqueeze(0)  # shape: [1, n_features]

        with torch.no_grad():
            logits = self.nn_model(x_tensor)   # shape: [1, n_classes]
            probs = torch.softmax(logits, dim=1).numpy()
            pred_class_idx = int(probs.argmax(axis=1)[0])

        # Si tienes un label_encoder:
        pred_class = self.label_encoder.inverse_transform([pred_class_idx])[0]


        # 2. Construir DataFrame para LORE (si es necesario, solo para TabularDataset)

        # Ahora crea el TabularDataset legible
        local_df = pd.DataFrame(self.X_train, columns=self.feature_names).astype(np.float32)
        local_df["class"] = self.label_encoder.inverse_transform(self.y_train_nn)
        local_tabular_dataset = TabularDataset(local_df, class_name="class")    

        # Explicabilidad local y la vecindad es generada del train (local_tabular_dataset)
        nn_wrapper = TorchNNWrapper(self.nn_model)
        bbox = sklearn_classifier_bbox.sklearnBBox(nn_wrapper)
        lore_vecindad = TabularGeneticGeneratorLore(bbox, local_tabular_dataset)

        
        # Explicación LORE
        x_instance = pd.Series(self.X_test[num_row], index=self.feature_names)
        
        explanation = lore_vecindad.explain_instance(x_instance, merge=True, num_classes=len(UNIQUE_LABELS), feature_names= self.feature_names, categorical_features=list(self.global_mapping.keys()), global_mapping=self.global_mapping, UNIQUE_LABELS=UNIQUE_LABELS)
        lore_tree = explanation["merged_tree"]
        
        # self.print_tree_readable(node=lore_tree.root,feature_names=self.feature_names,class_names=UNIQUE_LABELS,  numeric_features=self.numeric_features,scaler=self.scaler,encoder=self.encoder)
        # print('\n')

        round_number = config.get("server_round", 1)
        
        self.save_lore_tree_image(lore_tree.root, round_number, self.feature_names, self.numeric_features, self.scaler, UNIQUE_LABELS, self.encoder, folder="LoreTree")


        merged_tree = SuperTree()
        merged_tree.mergeDecisionTrees(
            roots=[lore_tree.root, self.received_supertree],
            num_classes=len(self.unique_labels),
            feature_names=self.feature_names,
            categorical_features=list(self.global_mapping.keys()), 
            global_mapping=self.global_mapping
        )

        merged_tree.prune_redundant_leaves_full()

        merged_tree.merge_equal_class_leaves()

        self.save_supertree_plot(root_node=merged_tree.root,round_number=round_number,feature_names=self.feature_names,class_names=self.unique_labels,numeric_features=self.numeric_features,scaler=self.scaler,global_mapping=self.global_mapping,folder="MergedTree")
        
        tree_str = self.tree_to_str(merged_tree.root, self.feature_names, numeric_features=self.numeric_features, scaler=self.scaler, global_mapping=self.global_mapping, unique_labels=self.unique_labels)

        rules = self.extract_rules_from_str(tree_str, target_class_label=pred_class)


        print(f"\n[CLIENTE {self.client_id}] 🧪 Instancia a explicar (decodificada):")
        print(decoded)

        
        print(f"🧪 Clase real: {self.label_encoder.inverse_transform([self.y_test_nn[num_row]])[0]}")
        print("\n")

        print(f"pred_class: {repr(pred_class)}")
    
        def cumple_regla(instancia, regla):
            for cond in regla:
                if "∧" in cond:
                    # Maneja condiciones tipo intervalo: 'age > 44.33 ∧ ≤ 48.50'
                    import re
                    # Busca: variable, operador1, valor1, operador2, valor2
                    m = re.match(r'(.+?)([><]=?|≤|≥)\s*([-\d\.]+)\s*∧\s*([><]=?|≤|≥)\s*([-\d\.]+)', cond)
                    if m:
                        var = m.group(1).strip()
                        op1, val1 = m.group(2), float(m.group(3))
                        op2, val2 = m.group(4), float(m.group(5))
                        v = instancia[var]
                        # Evalúa las dos condiciones del intervalo
                        if not (
                            eval(f"v {op1.replace('≤','<=').replace('≥','>=')} {val1}") and
                            eval(f"v {op2.replace('≤','<=').replace('≥','>=')} {val2}")
                        ):
                            return False
                        continue  # sigue al siguiente cond
                # ... resto de tu código tal cual ...
                if "≤" in cond:
                    var, val = cond.split("≤")
                    var = var.strip()
                    val = float(val.strip())
                    if instancia[var] > val:
                        return False
                elif ">=" in cond or "≥" in cond:
                    var, val = cond.replace("≥", ">=").split(">=")
                    var = var.strip()
                    val = float(val.strip())
                    if instancia[var] < val:
                        return False
                elif ">" in cond:
                    var, val = cond.split(">")
                    var = var.strip()
                    val = float(val.strip())
                    if instancia[var] <= val:
                        return False
                elif "<" in cond:
                    var, val = cond.split("<")
                    var = var.strip()
                    val = float(val.strip())
                    if instancia[var] >= val:
                        return False
                elif "≠" in cond:
                    var, val = cond.split("≠")
                    var = var.strip()
                    val = val.strip().replace('"', "")
                    if instancia[var] == val:
                        return False
                elif "=" in cond:
                    var, val = cond.split("=")
                    var = var.strip()
                    val = val.strip().replace('"', "")
                    if instancia[var] != val:
                        return False
            return True

        # Buscar la regla factual (la que cubre la instancia)
        regla_factual = None
        for regla in rules:
            if cumple_regla(decoded, regla):
                regla_factual = regla
                break

        if regla_factual:
            print("Regla factual encontrada:", regla_factual)
            print("\n")
        else:
            print("Ninguna regla cubre la instancia. No hay explicación factual disponible para esta predicción.")
            print("\n")

        # Extraer 1 contrafactual por cada clase distinta a la predicha
        cf_rules_por_clase = {}
        for clase in self.unique_labels:
            if clase != pred_class:
                rules_clase = self.extract_rules_from_str(tree_str, target_class_label=clase)
                if rules_clase:
                    # Elige la más sencilla (menos condiciones)
                    cf_rules_por_clase[clase] = min(rules_clase, key=len)

        print("cf_rules_por_clase:", cf_rules_por_clase)
        print("\n")

        # ========================================
        # 📏 MÉTRICAS DE EXPLICACIÓN tipo LORE 
        # ========================================

        Z = explanation["neighborhood_Z"] # instancias del vecindario sintético generado alrededor del punto a explicar.
        y_surrogate = explanation["neighborhood_Yb"] # predicciones del modelo interpretable (arbol) sobre Z.

        # print("y_surrogate")
        # print(y_surrogate)

        y_nn = nn_wrapper.predict(Z)

        # Convertir Z en DataFrame legible
        dfZ = pd.DataFrame(Z, columns=self.feature_names)


        # ==============================================================================================
        # Silhouette:  Distancia media entre x y las instancias de su misma clase en el vecindario (Z+)
        # ==============================================================================================

        mask_same_class = (y_nn == pred_class_idx)
        mask_diff_class = (y_nn != pred_class_idx)

        Z_plus = dfZ[mask_same_class]
        Z_minus = dfZ[mask_diff_class]

        x = self.X_test[num_row]

        a = pairwise_distances([x], Z_plus).mean() if len(Z_plus) > 0 else 0.0

        b = pairwise_distances([x], Z_minus).mean() if len(Z_minus) > 0 else 0.0

        silhouette = 0.0
        if (a + b) > 0:
            silhouette = (b - a) / max(a, b)

        print(f"Silhouette: {silhouette:.3f}")




        # ===========================================================================================================================================================
        # Fidelity: Porcentaje de veces que el modelo interpretable (LORE tree) predice lo mismo que el modelo original (Red neuronal) en el vecindario generado.

        # Un valor alto de fidelity significa que el árbol está imitando bien a la red neuronal para esa instancia.
        # ===========================================================================================================================================================

        fidelity = accuracy_score(y_nn, y_surrogate)

        print(f"Fidelity: {fidelity:.3f}")




        # ==============================================================================================================================================================================================================
        # Coverage: mide cuántas instancias del vecindario 𝑍 (generado alrededor de la instancia a explicar) cumplen la regla factual 𝑝. Es decir, calcula la proporción de instancias en las que la regla es aplicable.

        # Precisión: proporción de las instancias del vecindario que cumplen la regla factual y que, además, el modelo black-box (tu red neuronal) predice la clase de la regla factual.
        # ==============================================================================================================================================================================================================

        # Decodifica cada fila del vecindario a un formato legible
        dfZ_decoded = dfZ.apply(lambda row: self.decode_onehot_instance(
            row.values, self.numeric_features, self.encoder, self.scaler, self.feature_names
        ), axis=1)
        

        if regla_factual:
            cumplen_regla = dfZ_decoded.apply(lambda row: cumple_regla(row, regla_factual), axis=1)
            coverage = cumplen_regla.mean()
            covered_target_match = (y_nn[cumplen_regla.values] == pred_class_idx)
            if cumplen_regla.sum() > 0:
                precision = covered_target_match.sum() / cumplen_regla.sum()
            else:
                precision = 0.0
        else:
            coverage = 0.0
            precision = 0.0

        print(f"Coverage: {coverage:.3f}")
        print(f"Precision: {precision:.3f}")





    

    def extract_rules_from_str(self, tree_str, target_class_label, exclude=False):
        target_class_label = target_class_label.strip()
        lines = tree_str.strip().split("\n")
        path = []
        rules = []

        def recurse(idx, indent_level):
            seen = set()
            while idx < len(lines):
                line = lines[idx]
                current_indent = len(line) - len(line.lstrip())
                if current_indent < indent_level:
                    return idx
                if "⮕" in line:
                    import re
                    m = re.search(r'class = "([^"]+)"', line)
                    leaf_class = m.group(1).strip() if m else None
                    condition = (leaf_class == target_class_label)
                    if exclude:
                        condition = not condition  # cambia la lógica
                    if condition:
                        cleaned = []
                        for cond in path:
                            if cond not in seen:
                                cleaned.append(cond)
                                seen.add(cond)
                        rules.append(cleaned)
                    return idx + 1
                elif "if" in line:
                    condition = line.strip()[3:]
                    path.append(condition)
                    idx = recurse(idx + 1, current_indent + 2)
                    path.pop()
                else:
                    idx += 1
            return idx

        recurse(0, 0)
        return rules
    


    def decode_onehot_instance(self, X_row, numeric_features, encoder, scaler, feature_names):
        import numpy as np
        import pandas as pd

        x_named = pd.Series(X_row, index=feature_names)
        data = {}

        # Numéricas
        for i, col in enumerate(numeric_features):
            if col in x_named:
                val = x_named[col]
                idx = numeric_features.index(col)
                mean = scaler.mean_[idx]
                std = scaler.scale_[idx]
                data[col] = val * std + mean
            else:
                data[col] = None

        # Categóricas
        cat_map = encoder.dataset_descriptor["categorical"]
        for cat in cat_map:
            onehot_names = [c for c in feature_names if c.startswith(cat + "_")]
            val_found = None
            for c in onehot_names:
                if c in x_named and x_named[c] == 1:
                    val_found = c[len(cat) + 1 :]
                    break
            if val_found is not None:
                data[cat] = val_found.strip()
            else:
                data[cat] = None  # O "?"

        return pd.Series(data)
    
    def decode_Xtrain_to_df(self, X_test, numeric_features, encoder, scaler, feature_names):
        # Lista de diccionarios para cada fila
        decoded_rows = []
        for x in X_test:
            decoded = self.decode_onehot_instance(x, numeric_features, encoder, scaler, feature_names)
            decoded_rows.append(decoded)
        df = pd.DataFrame(decoded_rows)
        return df
    



    
    
    def print_tree_readable(self, node, feature_names, class_names, numeric_features, scaler, encoder, depth=0):
        indent = "|   " * depth

        if node.is_leaf:
            class_idx = int(np.argmax(node.labels))
            print(f"{indent}|--- class: {class_names[class_idx]}")
            return

        feat_name = feature_names[node.feat]

        # --- CASO NUMÉRICA ---
        if feat_name in numeric_features:
            idx = numeric_features.index(feat_name)
            threshold = node.thresh * scaler.scale_[idx] + scaler.mean_[idx]
            print(f"{indent}|--- {feat_name} <= {threshold:.2f}")
            self.print_tree_readable(node._left_child, feature_names, class_names, numeric_features, scaler, encoder, depth + 1)
            print(f"{indent}|--- {feat_name} > {threshold:.2f}")
            self.print_tree_readable(node._right_child, feature_names, class_names, numeric_features, scaler, encoder, depth + 1)
            return

        # --- CASO CATEGÓRICA ONE-HOT ---
        if "=" in feat_name:
            # Ejemplo: occupation= Adm-clerical
            var, valor = feat_name.split("=")
            var = var.strip()
            valor = valor.strip()
            # Si threshold == 0.5, OneHot típico: <= 0.5 (no es ese valor), > 0.5 (es ese valor)
            if node.thresh == 0.5:
                print(f"{indent}|--- {var} == \"{valor}\"")
                self.print_tree_readable(node._right_child, feature_names, class_names, numeric_features, scaler, encoder, depth + 1)
                print(f"{indent}|--- {var} != \"{valor}\"")
                self.print_tree_readable(node._left_child, feature_names, class_names, numeric_features, scaler, encoder, depth + 1)
            else:
                # Por si hay rarezas (poco frecuente)
                print(f"{indent}|--- {feat_name} <= {node.thresh:.2f}")
                self.print_tree_readable(node._left_child, feature_names, class_names, numeric_features, scaler, encoder, depth + 1)
                print(f"{indent}|--- {feat_name} > {node.thresh:.2f}")
                self.print_tree_readable(node._right_child, feature_names, class_names, numeric_features, scaler, encoder, depth + 1)
            return

        # --- SI NO ENCAJA ---
        print(f"{indent}|--- {feat_name} <= {node.thresh:.2f}")
        self.print_tree_readable(node._left_child, feature_names, class_names, numeric_features, scaler, encoder, depth + 1)
        print(f"{indent}|--- {feat_name} > {node.thresh:.2f}")
        self.print_tree_readable(node._right_child, feature_names, class_names, numeric_features, scaler, encoder, depth + 1)






    def tree_to_str(self, node, feature_names, numeric_features=None, scaler=None, global_mapping=None, unique_labels=None, depth=0):
        indent = "  " * depth
        result = ""

        if node.is_leaf:
            class_idx = int(np.argmax(node.labels))
            class_label = unique_labels[class_idx] if unique_labels is not None else str(class_idx)
            result += f'{indent}⮕ Leaf: class = "{class_label.strip()}" | {node.labels}\n'
        else:
            fname = feature_names[node.feat]

            # --- Split OneHot ---
            if "_" in fname:
                var, val = fname.split("_", 1)
                var = var.strip()
                val = val.strip()
                for i, child in enumerate(node.children):
                    cond = f'{var} {"≠" if i == 0 else "="} "{val}"'
                    result += f"{indent}if {cond}\n"
                    result += self.tree_to_str(child, feature_names, numeric_features, scaler, global_mapping, unique_labels, depth + 1)

            # --- Split categórico ordinal ---
            elif global_mapping and fname in global_mapping:
                vals_cat = global_mapping[fname]
                for i, child in enumerate(node.children):
                    val_idx = node.intervals[i] if hasattr(node, "intervals") else int(node.thresh)
                    val = vals_cat[val_idx] if val_idx < len(vals_cat) else f"desconocido({val_idx})"
                    cond = f'{fname} {"≠" if i == 0 else "="} "{val}"'
                    result += f"{indent}if {cond}\n"
                    result += self.tree_to_str(child, feature_names, numeric_features, scaler, global_mapping, unique_labels, depth + 1)

            # --- Split numérico ---
            elif numeric_features and fname in numeric_features:
                idx = numeric_features.index(fname)
                mean = scaler.mean_[idx]
                std = scaler.scale_[idx]
                bounds = [-np.inf] + list(node.intervals)
                for i, child in enumerate(node.children):
                    left = bounds[i]
                    right = bounds[i+1]
                    left_real = left * std + mean if np.isfinite(left) else -np.inf
                    right_real = right * std + mean if np.isfinite(right) else np.inf
                    if i == 0:
                        cond = f"{fname} ≤ {right_real:.2f}"
                    elif i == len(node.children)-1:
                        cond = f"{fname} > {left_real:.2f}"
                    else:
                        cond = f"{fname} ∈ ({left_real:.2f}, {right_real:.2f}]"
                    result += f"{indent}if {cond}\n"
                    result += self.tree_to_str(child, feature_names, numeric_features, scaler, global_mapping, unique_labels, depth + 1)
            else:
                # Por si acaso, caso no detectado
                for child in node.children:
                    result += f"{indent}if {fname} ?\n"
                    result += self.tree_to_str(child, feature_names, numeric_features, scaler, global_mapping, unique_labels, depth + 1)
        return result



    def save_supertree_plot(self, root_node, round_number, feature_names, class_names, numeric_features, scaler, global_mapping, folder="Supertree"):
        dot = Digraph()
        node_id = [0]

        def add_node(node, parent=None, edge_label=""):
            curr = str(node_id[0])
            node_id[0] += 1

            # Etiqueta del nodo
            if node.is_leaf:
                class_index = int(np.argmax(node.labels))
                class_label = class_names[class_index]
                label = f"class: {class_label}\n{node.labels}"
            else:
                fname = feature_names[node.feat]
                # Si es OneHot: "sex_ Male"
                if "_" in fname:
                    var, val = fname.split("_", 1)
                    label = var.strip()
                else:
                    label = fname

            dot.node(curr, label)
            if parent:
                dot.edge(parent, curr, label=edge_label)

            # Nodos hijos
            if not node.is_leaf:
                fname = feature_names[node.feat]
                # --- Caso OneHotEncoder ---
                if "_" in fname:
                    var, val = fname.split("_", 1)
                    var = var.strip()
                    val = val.strip()
                    left_label = f'≠ "{val}"'    # <= 0.5
                    right_label = f'= "{val}"'   # > 0.5
                    add_node(node.children[0], curr, left_label)
                    add_node(node.children[1], curr, right_label)

                # --- Caso numérica ---
                elif fname in numeric_features:
                    idx = numeric_features.index(fname)
                    mean = scaler.mean_[idx]
                    std = scaler.scale_[idx]
                    bounds = [-np.inf] + list(node.intervals)
                    for i, child in enumerate(node.children):
                        left = bounds[i]
                        right = bounds[i+1]
                        left_real = left * std + mean if np.isfinite(left) else -np.inf
                        right_real = right * std + mean if np.isfinite(right) else np.inf
                        if i == 0:
                            cond = f"≤ {right_real:.2f}"
                        elif i == len(node.children)-1:
                            cond = f"> {left_real:.2f}"
                        else:
                            cond = f"∈ ({left_real:.2f}, {right_real:.2f}]"
                        add_node(child, curr, cond)

                # --- Caso categórica ordinal ---
                elif fname in global_mapping:
                    vals_cat = global_mapping[fname]
                    for i, child in enumerate(node.children):
                        val = vals_cat[node.intervals[i] if i < len(node.intervals) else -1]
                        edge = f'= "{val}"' if i == 0 else f'≠ "{val}"'
                        add_node(child, curr, edge)
                else:
                    # Desconocido
                    for child in node.children:
                        add_node(child, curr, "?")

        # --- Guardado ---
        folder_path = f"Ronda_{round_number}/{folder}"
        os.makedirs(folder_path, exist_ok=True)
        filename = f"{folder_path}/LoreTree_cliente{self.client_id}_Supertree_ronda_{round_number}"
        add_node(root_node)
        dot.render(filename, format="png", cleanup=True)
        return f"{filename}.png"
    




    def save_lore_tree_image(self, root_node, round_number, feature_names, numeric_features, scaler, unique_labels, encoder, tree_type="LoreTree", folder = "LoreTree"):

        dot = Digraph()
        node_id = [0]

        def base_name(feat):
            match = re.match(r"([a-zA-Z0-9\- ]+)", feat)
            return match.group(1).strip() if match else feat

        def add_node(node, parent=None, edge_label=""):
            curr = str(node_id[0])
            node_id[0] += 1 

            # Etiqueta del nodo
            if node.is_leaf:
                class_index = int(np.argmax(node.labels))
                class_label = unique_labels[class_index]
                label = f"class: {class_label}\n{node.labels}"
            else:
                try:
                    fname = feature_names[node.feat]
                    label = base_name(fname)
                except:
                    label = f"X_{node.feat}"

            dot.node(curr, label)
            if parent:
                dot.edge(parent, curr, label=edge_label)

            # Recorrido binario (OneHot y splits normales)
            if hasattr(node, "_left_child") or hasattr(node, "_right_child"):
                try:
                    fname = feature_names[node.feat]
                except:
                    fname = f"X_{node.feat}"

                # Si es OneHot
                if "=" in fname or "_" in fname:
                    if "=" in fname:
                        var, val = fname.split("=")
                    else:
                        var, val = fname.split("_", 1)
                    var = var.strip()
                    val = val.strip()
                    left_label = f'≠ "{val}"'   # <= 0.5 → no es ese valor
                    right_label = f'= "{val}"'  # > 0.5  → sí es ese valor
                else:
                    original_feat = base_name(fname)

                    if original_feat in encoder.dataset_descriptor["categorical"]:
                        val_idx = int(node.thresh)
                        vals_cat = encoder.dataset_descriptor["categorical"][original_feat]["distinct_values"]
                        val = vals_cat[val_idx] if val_idx < len(vals_cat) else f"desconocido({val_idx})"
                        left_label = f'= "{val}"'
                        right_label = f'≠ "{val}"'

                    elif fname in numeric_features:
                        idx = numeric_features.index(fname)
                        mean = scaler.mean_[idx]
                        std = scaler.scale_[idx]
                        thresh = node.thresh * std + mean
                        left_label = f"<= {thresh:.2f}"
                        right_label = f"> {thresh:.2f}"

                    else:
                        left_label = "≤ ?"
                        right_label = "> ?"

                if hasattr(node, "_left_child") and node._left_child:
                    add_node(node._left_child, curr, left_label)
                if hasattr(node, "_right_child") and node._right_child:
                    add_node(node._right_child, curr, right_label)

        folder_path = f"Ronda_{round_number}/{folder}"
        os.makedirs(folder_path, exist_ok=True)
        filename = f"{folder_path}/{tree_type.lower()}_cliente_{self.client_id}_ronda_{round_number}"
        add_node(root_node)
        dot.render(filename, format="png", cleanup=True)
        return f"{filename}.png"
            

    
    def _save_local_tree(self, root_node, round_number, feature_names, numeric_features, scaler, unique_labels, encoder, tree_type= "LocalTree"):
        dot = Digraph()
        node_id = [0]

        def base_name(feat):
            # Extrae solo el nombre de la variable, antes de '_' o '=' o espacios
            match = re.match(r"([a-zA-Z0-9\- ]+)", feat)
            return match.group(1).strip() if match else feat

        def add_node(node, parent=None, edge_label=""):
            curr = str(node_id[0])
            node_id[0] += 1 

            # Etiqueta del nodo
            if node.is_leaf:
                class_index = np.argmax(node.labels)
                class_label = unique_labels[class_index]
                label = f"class: {class_label}\n{node.labels}"
            else:
                try:
                    fname = feature_names[node.feat]
                    label = base_name(fname)
                except:
                    label = f"X_{node.feat}"

            dot.node(curr, label)
            if parent:
                dot.edge(parent, curr, label=edge_label)

            # Árbol tipo SuperTree
            if hasattr(node, "children") and node.children is not None and hasattr(node, "intervals"):
                for i, child in enumerate(node.children):
                    try:
                        fname = feature_names[node.feat]
                    except:
                        fname = f"X_{node.feat}"

                    original_feat = base_name(fname)
                    if original_feat in encoder.dataset_descriptor["categorical"]:
                        val_idx = node.intervals[i] if i == 0 else node.intervals[i - 1]
                        val_idx = int(val_idx)
                        vals_cat = encoder.dataset_descriptor["categorical"][original_feat]["distinct_values"]
                        val = vals_cat[val_idx] if val_idx < len(vals_cat) else f"desconocido({val_idx})"
                        edge = f'= "{val}"' if i == 0 else f'≠ "{val}"'
                    elif original_feat in numeric_features:
                        idx = numeric_features.index(original_feat)
                        mean = scaler.mean_[idx]
                        std = scaler.scale_[idx]
                        val = node.intervals[i] if i == 0 else node.intervals[i - 1]
                        val = val * std + mean
                        edge = f"<= {val:.2f}" if i == 0 else f"> {val:.2f}"
                    else:
                        edge = "?"

                    add_node(child, curr, edge)

            elif hasattr(node, "_left_child") or hasattr(node, "_right_child"):
                try:
                    fname = feature_names[node.feat]
                except:
                    fname = f"X_{node.feat}"

                # Si es OneHot
                if "_" in fname:
                    var, val = fname.split("_", 1)
                    var = var.strip()
                    val = val.strip()
                    # La split es: Si sex_ Male <= 0.5  (NO es Male)
                    #              Si sex_ Male > 0.5   (SÍ es Male)
                    left_label = f'≠ "{val}"'   # <= 0.5 → no es ese valor
                    right_label = f'= "{val}"'  # > 0.5  → sí es ese valor
                else:
                    original_feat = base_name(fname)

                    if original_feat in encoder.dataset_descriptor["categorical"]:
                        val_idx = int(node.thresh)
                        vals_cat = encoder.dataset_descriptor["categorical"][original_feat]["distinct_values"]
                        val = vals_cat[val_idx] if val_idx < len(vals_cat) else f"desconocido({val_idx})"
                        left_label = f'= "{val}"'
                        right_label = f'≠ "{val}"'

                    elif fname in numeric_features:
                        idx = numeric_features.index(fname)
                        mean = scaler.mean_[idx]
                        std = scaler.scale_[idx]
                        thresh = node.thresh * std + mean
                        left_label = f"<= {thresh:.2f}"
                        right_label = f"> {thresh:.2f}"
                        
                    else:
                        left_label = "≤ ?"
                        right_label = "> ?"

                if node._left_child:
                    add_node(node._left_child, curr, left_label)
                if node._right_child:
                    add_node(node._right_child, curr, right_label)

        add_node(root_node)
        folder = f"Ronda_{round_number}/{tree_type}_Cliente_{self.client_id}"
        os.makedirs(folder, exist_ok=True)
        filepath = f"{folder}/{tree_type.lower()}_cliente_{self.client_id}_ronda_{round_number}"
        dot.render(filepath, format="png", cleanup=True)




def client_fn(context: Context):
    partition_id = context.node_config["partition-id"]
    num_partitions = context.node_config["num-partitions"]
    
    # dataset_name = context.node_config.get("dataset_name", "pablopalacios23/Iris")
    # class_col = context.node_config.get("class_col", "target")

    dataset_name = DATASET_NAME 
    class_col = CLASS_COLUMN 

    (X_train, y_train,X_test, y_test,dataset, feature_names,label_encoder, scaler,numeric_features, encoder, preprocessor) = load_data_general(flower_dataset_name=dataset_name,class_col=class_col,partition_id=partition_id,num_partitions=num_partitions)

    tree_model = DecisionTreeClassifier(max_depth=5, min_samples_split=2, random_state=42)

    input_dim = X_train.shape[1]
    output_dim = len(np.unique(y_train))
    nn_model = Net(input_dim, output_dim)
    return FlowerClient(tree_model=tree_model, 
                        nn_model=nn_model,
                        X_train=X_train,
                        y_train=y_train,
                        X_test=X_test,
                        y_test=y_test,
                        dataset=dataset,
                        client_id=partition_id + 1,
                        feature_names=feature_names,
                        label_encoder=label_encoder,
                        scaler=scaler,
                        numeric_features=numeric_features,
                        encoder=encoder,
                        preprocessor=preprocessor).to_client()

client_app = ClientApp(client_fn=client_fn)


# Servidor

In [5]:
# ============================
# 📦 IMPORTACIONES NECESARIAS
# ============================
import os
import time
import json
import numpy as np
from typing import List, Tuple, Dict
from sklearn.tree import DecisionTreeClassifier

from flwr.common import Context, Metrics, Scalar, ndarrays_to_parameters
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.strategy import FedAvg

from graphviz import Digraph
from lore_sa.surrogate.decision_tree import SuperTree

import torch
import torch.nn as nn
import torch.nn.functional as F


# ============================
# ⚙️ CONFIGURACIÓN GLOBAL
# ============================
# MIN_AVAILABLE_CLIENTS = 4
# NUM_SERVER_ROUNDS = 2

FEATURES = []  # se rellenan dinámicamente
UNIQUE_LABELS = []
LATEST_SUPERTREE_JSON = None
GLOBAL_MAPPING_JSON = None
FEATURE_NAMES_JSON = None
GLOBAL_SCALER_JSON = None


# ============================
# 🧠 UTILIDADES MODELO
# ============================
def create_model(input_dim, output_dim):
    from __main__ import Net  # necesario si Net está en misma libreta
    return Net(input_dim, output_dim)


def get_model_parameters(tree_model, nn_model):
    tree_params = [-1, 2, 1]
    nn_weights = [v.cpu().detach().numpy() for v in nn_model.state_dict().values()]
    return {
        "tree": tree_params,
        "nn": nn_weights,
    }

def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Dict[str, Scalar]:
    total = sum(n for n, _ in metrics)
    avg: Dict[str, List[float]] = {}
    for n, met in metrics:
        for k, v in met.items():
            if isinstance(v, (float, int)):
                avg.setdefault(k, []).append(n * float(v))
    return {k: sum(vs) / total for k, vs in avg.items()}

# ============================
# 🚀 SERVIDOR FLOWER
# ============================

def server_fn(context: Context) -> ServerAppComponents:
    global FEATURES, UNIQUE_LABELS

    # Justo antes de llamar a create_model
    if not FEATURES or not UNIQUE_LABELS:
        
        load_data_general(DATASET_NAME, CLASS_COLUMN, partition_id=0, num_partitions=NUM_CLIENTS)


    FEATURES = FEATURES or ["feat_0", "feat_1"]  # fallback por si no se cargó antes
    UNIQUE_LABELS = UNIQUE_LABELS or ["Class_0", "Class_1"]


    model = create_model(len(FEATURES), len(UNIQUE_LABELS))
    initial_params = ndarrays_to_parameters(get_model_parameters(None, model)["nn"])

    strategy = FedAvg(
        min_available_clients=MIN_AVAILABLE_CLIENTS,
        fit_metrics_aggregation_fn=weighted_average,
        evaluate_metrics_aggregation_fn=weighted_average,
        initial_parameters=initial_params,
    )

    strategy.configure_fit = _inject_round(strategy.configure_fit)
    strategy.configure_evaluate = _inject_round(strategy.configure_evaluate)
    original_aggregate = strategy.aggregate_evaluate

    def custom_aggregate_evaluate(server_round, results, failures):
        global LATEST_SUPERTREE_JSON, GLOBAL_MAPPING_JSON, FEATURE_NAMES_JSON, GLOBAL_SCALER_JSON
        aggregated_metrics = original_aggregate(server_round, results, failures)


        try:
            print(f"\n[SERVIDOR] 🌲 Generando SuperTree - Ronda {server_round}")
            tree_dicts = []
            all_distincts = defaultdict(set)
            client_encoders = {}

            for (_, evaluate_res) in results:
                metrics = evaluate_res.metrics
                for key, value in metrics.items():
                    if key.startswith("distinct_values_"):
                        client_id = key.split("_")[-1]
                        client_encoders[client_id] = json.loads(value)
                        for feat, d in client_encoders[client_id].items():
                            all_distincts[feat].update(d["distinct_values"])

            global_mapping = {feat: sorted(list(vals)) for feat, vals in all_distincts.items()}

            all_means = []
            all_stds = []

            for (_, evaluate_res) in results:
                metrics = evaluate_res.metrics
                for key, value in metrics.items():
                    if key.startswith("tree_ensemble_"):
                        client_id = key.split("_")[-1]
                        trees_list = json.loads(value)
                        local_encoder = client_encoders[client_id]
                        feature_names = json.loads(metrics.get(f"encoded_feature_names_{client_id}"))
                        numeric_features = json.loads(metrics.get(f"numeric_features_{client_id}"))
                        unique_labels = json.loads(metrics.get(f"unique_labels_{client_id}"))
                        scaler = {
                            "mean": json.loads(metrics.get(f"scaler_mean_{client_id}")),
                            "std": json.loads(metrics.get(f"scaler_std_{client_id}")),
                        }

                        # Guarda los scalers de cada cliente
                        all_means.append(scaler["mean"])
                        all_stds.append(scaler["std"])
                        
                        for tdict in trees_list:
                            root = SuperTree.Node.from_dict(tdict)

                            # print("Local tree del cliente", client_id)
                            # print("root:", root)
                            # print("type:", type(root))
                            # print("dir(root):", dir(root))
                            # print("\n")

                            tree_dicts.append(root)

                # Calcular el scaler promedio
                global_mean = np.mean(np.stack(all_means), axis=0)
                global_std = np.mean(np.stack(all_stds), axis=0)
                global_scaler = {"mean": global_mean, "std": global_std}


                            
            # print(tree_dicts)
            
            if not tree_dicts:
                print("[SERVIDOR] ⚠️ No se recibieron árboles. Se omite SuperTree.")
                return aggregated_metrics
            
            supertree = SuperTree()
            roots = tree_dicts
            
            supertree.mergeDecisionTrees(roots, num_classes=len(UNIQUE_LABELS), feature_names=feature_names, categorical_features=list(global_mapping.keys()), global_mapping=global_mapping)
            # print("\n[SERVIDOR] SuperTree unpruned:")
            # print(supertree)
            # print("\n")

            # print("\n[SERVIDOR] SuperTree prune_redundant_leaves_full:")
            supertree.prune_redundant_leaves_full()
            # print(supertree)
            # print("\n")

            # print("\n[SERVIDOR] SuperTree merge_equal_class_leaves:")
            supertree.merge_equal_class_leaves()
            # print(supertree)
            # print("\n")
            
            # print("\n")


            # print("supertree.root.to_dict(): ", supertree.root.to_dict())
            # print("type:", type(supertree.root))
            # print("dir(supertree.root): ", dir(supertree.root))
            # print("\n")



            # print("\n[SERVIDOR] 🌳 SuperTree legible (nombre de variables):")
            # print_supertree_legible_fusionado(
            #     supertree.root,
            #     feature_names=feature_names,
            #     class_names=UNIQUE_LABELS,
            #     numeric_features=numeric_features,
            #     scaler=global_scaler,  # <-- ahora el scaler promedio
            #     global_mapping=global_mapping
            # )
            

            save_supertree_plot(
                root_node=supertree.root,
                round_number=server_round,
                feature_names=feature_names,
                class_names=UNIQUE_LABELS,
                numeric_features=numeric_features,
                scaler=global_scaler,
                global_mapping=global_mapping
            )

            LATEST_SUPERTREE_JSON = json.dumps(supertree.root.to_dict())

            GLOBAL_MAPPING_JSON = json.dumps(global_mapping)

            FEATURE_NAMES_JSON = json.dumps(feature_names)
      
            global_scaler = {
                "mean": global_mean.tolist(),
                "std": global_std.tolist()
            }

            GLOBAL_SCALER_JSON = json.dumps(global_scaler)

        except Exception as e:
            print(f"[SERVIDOR] ❌ Error en SuperTree: {e}")

        time.sleep(3)
        return aggregated_metrics

    strategy.aggregate_evaluate = custom_aggregate_evaluate
    return ServerAppComponents(strategy=strategy, config=ServerConfig(num_rounds=NUM_SERVER_ROUNDS))

# ============================
# 🧩 FUNCIONES AUXILIARES
# ============================
def _inject_round(original_fn):
    def wrapper(server_round, parameters, client_manager):
        global LATEST_SUPERTREE_JSON, GLOBAL_MAPPING_JSON, FEATURE_NAMES_JSON, GLOBAL_SCALER_JSON
        instructions = original_fn(server_round, parameters, client_manager)
        for _, ins in instructions:
            ins.config["server_round"] = server_round
            
            if LATEST_SUPERTREE_JSON:
                ins.config["supertree"] = LATEST_SUPERTREE_JSON
                ins.config["global_mapping"] = GLOBAL_MAPPING_JSON
                ins.config["feature_names"] = FEATURE_NAMES_JSON
                ins.config["global_scaler"] = GLOBAL_SCALER_JSON
                
        return instructions
    return wrapper



def print_supertree_legible_fusionado(
    node,
    feature_names,
    class_names,
    numeric_features,
    scaler,  # dict con mean y std
    global_mapping,
    depth=0
):
    indent = "|   " * depth
    if node is None:
        print(f"{indent}[Nodo None]")
        return

    if getattr(node, "is_leaf", False):
        class_idx = int(np.argmax(node.labels))
        print(f"{indent}class: {class_names[class_idx]} (pred: {node.labels})")
        return

    feat_idx = node.feat
    feat_name = feature_names[feat_idx]
    intervals = node.intervals
    children = node.children

    if feat_name in numeric_features:
        idx = numeric_features.index(feat_name)
        mean = scaler["mean"][idx]
        std = scaler["std"][idx]
        # Construir los límites reales
        bounds = [-np.inf] + list(intervals)
        for i, child in enumerate(children):
            left = bounds[i]
            right = bounds[i+1]
            left_real = left * std + mean if np.isfinite(left) else -np.inf
            right_real = right * std + mean if np.isfinite(right) else np.inf

            if i == 0:
                cond = f"{feat_name} ≤ {right_real:.2f}"
            elif i == len(children)-1:
                cond = f"{feat_name} > {left_real:.2f}"
            else:
                cond = f"{feat_name} ∈ ({left_real:.2f}, {right_real:.2f}]"
            print(f"{indent}{cond}")
            print_supertree_legible_fusionado(
                child, feature_names, class_names, numeric_features, scaler, global_mapping, depth + 1
            )

    elif "=" in feat_name:
        var, val = feat_name.split("=", 1)
        var = var.strip()
        val = val.strip()
        for i, child in enumerate(children):
            cond = f'{var} == "{val}"' if i == 0 else f'{var} != "{val}"'
            print(f"{indent}{cond}")
            print_supertree_legible_fusionado(
                child, feature_names, class_names, numeric_features, scaler, global_mapping, depth + 1
            )
    else:
        print(f"{indent}{feat_name} [tipo desconocido]")
        for child in children:
            print_supertree_legible_fusionado(
                child, feature_names, class_names, numeric_features, scaler, global_mapping, depth + 1
            )



def save_supertree_plot(
    root_node,
    round_number,
    feature_names,
    class_names,
    numeric_features,
    scaler,           # dict con mean y std
    global_mapping,
    folder="Supertree"
):
    dot = Digraph()
    node_id = [0]

    def add_node(node, parent=None, edge_label=""):
        curr = str(node_id[0])
        node_id[0] += 1

        # Etiqueta del nodo
        if node.is_leaf:
            class_index = int(np.argmax(node.labels))
            class_label = class_names[class_index]
            label = f"class: {class_label}\n{node.labels}"
        else:
            fname = feature_names[node.feat]
            # Si es OneHot: "sex_ Male"
            if "_" in fname:
                var, val = fname.split("_", 1)
                label = var.strip()
            else:
                label = fname

        dot.node(curr, label)
        if parent:
            dot.edge(parent, curr, label=edge_label)

        # Nodos hijos
        if not node.is_leaf:
            fname = feature_names[node.feat]
            # --- Caso OneHotEncoder ---
            if "_" in fname:
                var, val = fname.split("_", 1)
                var = var.strip()
                val = val.strip()
                left_label = f'≠ "{val}"'    # <= 0.5
                right_label = f'= "{val}"'   # > 0.5
                # Solo 2 hijos: left y right
                add_node(node.children[0], curr, left_label)
                add_node(node.children[1], curr, right_label)

            # --- Caso numérica ---
            elif fname in numeric_features:
                idx = numeric_features.index(fname)
                mean = scaler["mean"][idx]
                std = scaler["std"][idx]
                bounds = [-np.inf] + list(node.intervals)
                for i, child in enumerate(node.children):
                    left = bounds[i]
                    right = bounds[i+1]
                    left_real = left * std + mean if np.isfinite(left) else -np.inf
                    right_real = right * std + mean if np.isfinite(right) else np.inf
                    if i == 0:
                        cond = f"≤ {right_real:.2f}"
                    elif i == len(node.children)-1:
                        cond = f"> {left_real:.2f}"
                    else:
                        cond = f"∈ ({left_real:.2f}, {right_real:.2f}]"
                    add_node(child, curr, cond)

            # --- Caso categórica ordinal ---
            elif fname in global_mapping:
                vals_cat = global_mapping[fname]
                for i, child in enumerate(node.children):
                    val = vals_cat[node.intervals[i] if i < len(node.intervals) else -1]
                    edge = f'= "{val}"' if i == 0 else f'≠ "{val}"'
                    add_node(child, curr, edge)
            else:
                # Desconocido
                for child in node.children:
                    add_node(child, curr, "?")

    # --- Guardado ---
    folder_path = f"Ronda_{round_number}/{folder}"
    os.makedirs(folder_path, exist_ok=True)
    filename = f"{folder_path}/supertree_ronda_{round_number}"
    add_node(root_node)
    dot.render(filename, format="png", cleanup=True)
    # print(f"Árbol SuperTree guardado en: {filename}.png")
    return f"{filename}.png"




# ============================
# 🔧 INICIALIZAR SERVER APP
# ============================
server_app = ServerApp(server_fn=server_fn)



In [6]:
from flwr.simulation import run_simulation
import logging
import warnings
import ray

warnings.filterwarnings("ignore", category=DeprecationWarning)


logging.getLogger('matplotlib').setLevel(logging.WARNING)
logging.getLogger("filelock").setLevel(logging.WARNING)
logging.getLogger("ray").setLevel(logging.WARNING)
logging.getLogger('graphviz').setLevel(logging.WARNING)
logging.getLogger().setLevel(logging.WARNING)  # O ERROR para ocultar aún más
logging.getLogger("urllib3").setLevel(logging.WARNING)
logging.getLogger("fsspec").setLevel(logging.WARNING)
# logging.getLogger("flwr").setLevel(logging.WARNING)




ray.shutdown()  # Apagar cualquier sesión previa de Ray
ray.init(local_mode=True)  # Desactiva multiprocessing, usa un solo proceso principal

backend_config = {"num_cpus": 1}

run_simulation(
    server_app=server_app,
    client_app=client_app,
    num_supernodes=NUM_CLIENTS,
    backend_config=backend_config,
)


2025-07-16 12:24:34,365	INFO worker.py:1832 -- Started a local Ray instance. View the dashboard at [1m[32mhttp://127.0.0.1:8265 [39m[22m
2025-07-16 12:24:38,568 flwr         DEBUG    Asyncio event loop already running.
:job_id:01000000
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor


:job_id:01000000
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor


[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=2, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Using initial global parameters provided by strategy
[92mINFO [0m:      Starting evaluation of initial global parameters
[92mINFO [0m:      Evaluation returned no results (`None`)
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: strategy sampled 4 clients (out of 4)
[92mINFO [0m:      aggregate_fit: received 4 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 4 clients (out of 4)


[CLIENTE 3] ✅ Red neuronal entrenada
[CLIENTE 1] ✅ Red neuronal entrenada
[CLIENTE 2] ✅ Red neuronal entrenada
[CLIENTE 4] ✅ Red neuronal entrenada


[92mINFO [0m:      aggregate_evaluate: received 4 results and 0 failures



[SERVIDOR] 🌲 Generando SuperTree - Ronda 1


[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 4 clients (out of 4)


[CLIENTE 4] ✅ Red neuronal entrenada
[CLIENTE 3] ✅ Red neuronal entrenada


[92mINFO [0m:      aggregate_fit: received 4 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 4 clients (out of 4)


[CLIENTE 1] ✅ Red neuronal entrenada
[CLIENTE 2] ✅ Red neuronal entrenada
Recibiendo supertree....
Recibiendo supertree....
Recibiendo supertree....
Recibiendo supertree....

[CLIENTE 4] 🧪 Instancia a explicar (decodificada):
age                        41.0
hours-per-week             40.0
workclass               Private
education                   9th
marital-status        Separated
occupation        Other-service
relationship      Not-in-family
race                      White
sex                        Male
native-country    United-States
dtype: object
🧪 Clase real:  <=50K


pred_class: ' >50K'
Ninguna regla cubre la instancia. No hay explicación factual disponible para esta predicción.


cf_rules_por_clase: {' <=50K': ['age > 48.50', 'hours-per-week ≤ 24.50']}


Silhouette: 0.922
Fidelity: 0.569

[CLIENTE 2] 🧪 Instancia a explicar (decodificada):
age                             53.0
hours-per-week                  40.0
workclass                  Local-gov
education               Some

[92mINFO [0m:      aggregate_evaluate: received 4 results and 0 failures


Coverage: 0.880
Precision: 1.000

[CLIENTE 3] 🧪 Instancia a explicar (decodificada):
age                             65.0
hours-per-week                  15.0
workclass                    Private
education                  Bachelors
marital-status    Married-civ-spouse
occupation                     Sales
relationship                 Husband
race                           White
sex                             Male
native-country         United-States
dtype: object
🧪 Clase real:  <=50K


pred_class: ' >50K'
Ninguna regla cubre la instancia. No hay explicación factual disponible para esta predicción.


cf_rules_por_clase: {' <=50K': ['age > 49.94', 'hours-per-week ≤ 30.15']}


Silhouette: 0.564
Fidelity: 0.971
Coverage: 0.000
Precision: 0.000

[SERVIDOR] 🌲 Generando SuperTree - Ronda 2


[92mINFO [0m:      
[92mINFO [0m:      [SUMMARY]
[92mINFO [0m:      Run finished 2 round(s) in 93.80s
[92mINFO [0m:      	History (loss, distributed):
[92mINFO [0m:      		round 1: 0.0
[92mINFO [0m:      		round 2: 0.0
[92mINFO [0m:      
