In [1]:
# =======================
# 📦 IMPORTACIONES
# =======================
import warnings
import time
import sys
import random
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple

from sklearn.preprocessing import LabelEncoder
from sklearn.tree import DecisionTreeClassifier, plot_tree, export_text
from sklearn.metrics import (
    log_loss, accuracy_score, precision_score, recall_score, 
    f1_score, confusion_matrix, roc_auc_score
)

from flwr.client import ClientApp, NumPyClient
from flwr.common import Context, NDArrays, Metrics, Scalar, ndarrays_to_parameters
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.strategy import FedAvg
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner

from graphviz import Digraph

from lore_sa.dataset import TabularDataset
from lore_sa.bbox import sklearn_classifier_bbox
from lore_sa.encoder_decoder.tabular_enc import ColumnTransformerEnc
from lore_sa.lore import TabularGeneticGeneratorLore
from lore_sa.surrogate.decision_tree import SuperTree

from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import StandardScaler, OrdinalEncoder

import torch
import torch.nn as nn
import torch.nn.functional as F

# =======================
# ⚙️ VARIABLES GLOBALES
# =======================
# Class labels; filled in place by load_data_general the first time it runs.
UNIQUE_LABELS = []
# Feature names (numeric first, then categorical); overwritten in place by
# every load_data_general call.
FEATURES = []
NUM_SERVER_ROUNDS = 2
NUM_CLIENTS = 2
MIN_AVAILABLE_CLIENTS = 2
fds = None  # Cached FederatedDataset
CAT_ENCODINGS = {}
USING_DATASET = None


class Net(nn.Module):
    """Small fully connected classifier that outputs raw logits.

    The first hidden layer has ``max(8, 2 * input_dim)`` units and the second
    has half as many.
    """

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        width = max(8, input_dim * 2)  # roughly proportional to the input size
        half_width = width // 2

        # Layer attribute names become state_dict keys, so they stay fc1..fc3.
        self.fc1 = nn.Linear(input_dim, width)
        self.fc2 = nn.Linear(width, half_width)
        self.fc3 = nn.Linear(half_width, output_dim)

    def forward(self, x):
        """Apply fc1 and fc2 with ReLU, then fc3 with no activation."""
        hidden = F.relu(self.fc1(x))
        hidden = F.relu(self.fc2(hidden))
        logits = self.fc3(hidden)
        return logits

# =======================
# 🔧 UTILIDADES MODELO
# =======================

def get_model_parameters(tree_model, nn_model):
    """Collect the tree hyper-parameters and the network weights.

    Args:
        tree_model: Estimator exposing ``get_params()`` with ``max_depth``,
            ``min_samples_split`` and ``min_samples_leaf`` entries, or ``None``
            when only the network weights matter (the server calls it that
            way). ``None`` yields the defaults ``[-1, 2, 1]``.
        nn_model: ``torch.nn.Module`` whose ``state_dict`` is exported.

    Returns:
        ``{"tree": [max_depth, min_samples_split, min_samples_leaf],
        "nn": [ndarray, ...]}``. An unlimited ``max_depth`` (``None``) is
        encoded as ``-1``; the arrays follow ``state_dict`` order.
    """
    if tree_model is None:
        tree_params = [-1, 2, 1]
    else:
        hyper = tree_model.get_params()  # fetch once instead of once per key
        tree_params = [
            int(hyper["max_depth"] or -1),
            int(hyper["min_samples_split"]),
            int(hyper["min_samples_leaf"]),
        ]
    nn_weights = [v.cpu().detach().numpy() for v in nn_model.state_dict().values()]
    return {
        "tree": tree_params,
        "nn": nn_weights,
    }


def set_model_params(tree_model, nn_model, params):
    """Push hyper-parameters into the tree and weights into the network.

    Args:
        tree_model: Estimator with ``set_params`` (skipped when ``None`` or
            when it has no such method).
        nn_model: ``torch.nn.Module`` updated through ``load_state_dict``.
        params: Mapping with ``"tree"`` (``[max_depth, min_samples_split,
            min_samples_leaf]``; a non-positive depth means unlimited) and
            ``"nn"`` (arrays in ``state_dict`` order).
    """
    tree_params, nn_weights = params["tree"], params["nn"]

    # The tree is optional: only touch it when it can actually be configured.
    if tree_model is not None and hasattr(tree_model, "set_params"):
        depth = tree_params[0]
        tree_model.set_params(
            max_depth=depth if depth > 0 else None,
            min_samples_split=tree_params[1],
            min_samples_leaf=tree_params[2],
        )

    # Overwrite state entries pairwise, in state_dict order, then reload.
    new_state = nn_model.state_dict()
    for name, array in zip(list(new_state.keys()), nn_weights):
        new_state[name] = torch.tensor(array)
    nn_model.load_state_dict(new_state)


# =======================
# 📥 CARGAR DATOS
# =======================

def load_data_general(flower_dataset_name: str, class_col: str, partition_id: int, num_partitions: int):
    """Load one federated partition and preprocess it for the tree and the network.

    Side effects: creates and caches the FederatedDataset in the global ``fds``
    on the first call, fills ``UNIQUE_LABELS`` if it is still empty, and
    overwrites ``FEATURES`` in place.

    Args:
        flower_dataset_name: Dataset identifier handed to ``FederatedDataset``.
        class_col: Name of the target column in the raw data.
        partition_id: Partition to load from the "train" split.
        num_partitions: Partition count for the ``IidPartitioner``; only used
            when the cache is created.

    Returns:
        An 11-tuple ``(X_train, y_train, X_test, y_test, tabular_dataset,
        feature_names, label_encoder, scaler, numeric_features, encoder,
        preprocessor)`` where train/test are the first 80% / last 20% of the
        rows in their loaded order.
    """
    global fds, UNIQUE_LABELS, FEATURES

    # NOTE(review): the cache is keyed on nothing, so a later call with another
    # dataset name or partition count silently reuses the first dataset.
    if fds is None:
        partitioner = IidPartitioner(num_partitions=num_partitions)
        fds = FederatedDataset(dataset=flower_dataset_name, partitioners={"train": partitioner})

    dataset = fds.load_partition(partition_id, "train").with_format("pandas")[:]

    # Dataset-specific cleanup for "adult_small": drop unused columns and the
    # rows whose workclass/occupation is the " ?" placeholder.
    if "adult_small" in flower_dataset_name.lower():
        drop_cols = ['fnlwgt', 'education-num', 'capital-gain', 'capital-loss']
        dataset.drop(columns=[col for col in drop_cols if col in dataset.columns], inplace=True)
        dataset = dataset[~dataset["workclass"].isin([" ?"])]
        dataset = dataset[~dataset["occupation"].isin([" ?"])]

    # Mark low-cardinality (< 50 distinct values) object columns as categorical.
    for col in dataset.select_dtypes(include=["object"]).columns:
        if dataset[col].nunique() < 50:
            dataset[col] = dataset[col].astype("category")

    # Copy of the raw class column (not used further in this function).
    class_original = dataset[class_col].copy()

    # Build the TabularDataset BEFORE encoding the class so its descriptor is
    # derived from the original values.
    tabular_dataset = TabularDataset(dataset.copy(), class_name=class_col)
    descriptor = tabular_dataset.descriptor

    # Encode the class as integers and expose it under the fixed name "class".
    label_encoder = LabelEncoder()
    dataset[class_col] = label_encoder.fit_transform(dataset[class_col])
    dataset.rename(columns={class_col: "class"}, inplace=True)
    y = dataset["class"]

    # Only the first successful load defines the global label list.
    if not UNIQUE_LABELS:
        UNIQUE_LABELS[:] = label_encoder.classes_.tolist()

    # Feature order: numeric columns first, then categorical ones.
    numeric_features = list(descriptor["numeric"].keys())
    categorical_features = list(descriptor["categorical"].keys())
    FEATURES[:] = numeric_features + categorical_features

    # Positional indices into the array built below (same order as FEATURES).
    numeric_indices = list(range(len(numeric_features)))
    categorical_indices = list(range(len(numeric_features), len(FEATURES)))

    # Plain array with the columns in FEATURES order.
    X_array = dataset[FEATURES].to_numpy()

    # Index-based preprocessing: standardise numeric columns, ordinal-encode
    # categorical ones (unknown categories map to -1).
    preprocessor = ColumnTransformer([
        ("num", StandardScaler(), numeric_indices),
        ("cat", OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=-1), categorical_indices)
    ])

    # NOTE(review): fitted on the whole partition, i.e. before the train/test
    # split below.
    X_encoded = preprocessor.fit_transform(X_array)

    # Encoder used by the explainability code; also provides the encoded
    # feature names.
    encoder = ColumnTransformerEnc(descriptor)
    feature_names = list(encoder.encoded_features.values())

    # Ordered 80/20 split (no shuffling).
    split_idx = int(0.8 * len(X_encoded))
    return (
        X_encoded[:split_idx], y[:split_idx],
        X_encoded[split_idx:], y[split_idx:],
        tabular_dataset, feature_names, label_encoder,
        preprocessor.named_transformers_["num"], numeric_features, encoder, preprocessor
    )

# For Iris: load partition 0 of 2 once at cell level (this also fills the
# FEATURES / UNIQUE_LABELS globals).
load_data_general("pablopalacios23/Iris", "target", partition_id=0, num_partitions=2)

# Alternative dataset:
# load_data_general("pablopalacios23/adult_small", "class", partition_id=0, num_partitions=2)

2025-05-26 14:34:26,523	INFO util.py:154 -- Outdated packages:
  ipywidgets==7.8.1 found, needs ipywidgets>=8
Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
2025-05-26 14:34:29,998 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): huggingface.co:443
2025-05-26 14:34:30,578 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /datasets/pablopalacios23/Iris/resolve/main/README.md HTTP/11" 404 0
2025-05-26 14:34:30,734 urllib3.connectionpool DEBUG    https://huggingface.co:443 "GET /api/datasets/pablopalacios23/Iris HTTP/11" 200 609
2025-05-26 14:34:30,882 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /datasets/pablopalacios23/Iris/resolve/6bbdbfec420ddde25fd56eb2d01f4bb904d94740/Iris.py HTTP/11" 404 0
2025-05-26 14:34:30,885 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): s3.amazonaws.com:443
2025-05-26 14:34:31,253 urllib3.connectionpool DEBUG    https://s3.amazonaws.com:443 "

(array([[-1.24524215e-02,  4.37539503e-01,  5.24959056e-01,
          7.85098281e-01],
        [-2.45935324e-01, -1.03180958e+00, -2.34440172e-01,
         -3.12941273e-01],
        [ 2.21030481e-01, -1.03180958e+00,  9.92281658e-01,
          2.36078504e-01],
        [-1.53009129e+00,  1.92647990e-01, -1.40274668e+00,
         -1.41098083e+00],
        [ 4.54513384e-01, -7.86918062e-01,  5.83374381e-01,
          7.85098281e-01],
        [ 1.15496209e+00,  4.37539503e-01,  1.05069698e+00,
          1.47137300e+00],
        [-4.79418227e-01, -1.52159260e+00, -1.17609521e-01,
         -3.12941273e-01],
        [ 1.97215225e+00, -5.22435228e-02,  1.57643491e+00,
          1.19686311e+00],
        [ 1.73866935e+00, -5.42026549e-01,  1.28435828e+00,
          9.22353225e-01],
        [-4.79418227e-01,  2.88645463e+00, -1.46116200e+00,
         -1.41098083e+00],
        [ 9.21479189e-01,  6.82431017e-01,  1.05069698e+00,
          1.19686311e+00],
        [-1.06312548e+00,  1.41710556e+00, 

# Cliente

In [2]:
# ==========================
# 🌼 CLIENTE FLOWER
# ==========================
import operator
import warnings
import os
import json
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import (
    log_loss, accuracy_score, precision_score, recall_score,
    f1_score, roc_auc_score
)
from sklearn.exceptions import NotFittedError

import torch
import torch.nn as nn
import torch.nn.functional as F

from flwr.client import NumPyClient
from flwr.common import Context
from flwr.common import parameters_to_ndarrays

from lore_sa.dataset import TabularDataset
from lore_sa.bbox import sklearn_classifier_bbox
from lore_sa.lore import TabularGeneticGeneratorLore
from lore_sa.rule import Expression, Rule
from lore_sa.surrogate.decision_tree import SuperTree
from lore_sa.encoder_decoder import ColumnTransformerEnc

from graphviz import Digraph

class TorchNNWrapper:
    """Adapter giving a torch classifier ``predict`` / ``predict_proba`` methods.

    The wrapped model is switched to eval mode on construction; both methods
    return NumPy arrays.
    """

    def __init__(self, model):
        self.model = model
        self.model.eval()

    def _logits(self, X):
        """Cast ``X`` to float32 and run it through the wrapped model."""
        features = np.array(X, dtype=np.float32)
        batch = torch.tensor(features, dtype=torch.float32)
        return self.model(batch)

    def predict(self, X):
        """Return the arg-max class index for every row of ``X``."""
        with torch.no_grad():
            winners = self._logits(X).argmax(dim=1)
            return winners.numpy()

    def predict_proba(self, X):
        """Return the soft-max class probabilities for every row of ``X``."""
        with torch.no_grad():
            probs = F.softmax(self._logits(X), dim=1)
            return probs.numpy()
        

class FlowerClient(NumPyClient):
    def __init__(self, tree_model, nn_model, X_train, y_train, X_test, y_test, dataset, client_id, feature_names, label_encoder, scaler, numeric_features, encoder, preprocessor):
        self.tree_model = tree_model
        self.nn_model = nn_model
        self.X_train = X_train
        self.y_train = y_train
        self.X_test = X_test
        self.y_test = y_test
        self.dataset = dataset
        self.client_id = client_id
        self.feature_names = feature_names
        self.label_encoder = label_encoder
        self.scaler = scaler
        self.numeric_features = numeric_features
        self.encoder = encoder
        self.unique_labels = label_encoder.classes_.tolist()
        self.y_train_nn = y_train.astype(np.int64)
        self.received_supertree = None
        self.preprocessor = preprocessor

    def _train_nn(self, epochs=10, lr=0.01):
        self.nn_model.train()
        optimizer = torch.optim.Adam(self.nn_model.parameters(), lr=lr)
        loss_fn = nn.CrossEntropyLoss()
        X_tensor = torch.tensor(self.X_train, dtype=torch.float32)
        y_tensor = torch.tensor(self.y_train_nn, dtype=torch.long)

        for _ in range(epochs):
            optimizer.zero_grad()
            outputs = self.nn_model(X_tensor)
            loss = loss_fn(outputs, y_tensor)
            loss.backward()
            optimizer.step()
        print(f"[CLIENTE {self.client_id}] ✅ Red neuronal entrenada")



    def fit(self, parameters, config):
        """Load the received network weights, train both models, return the new weights.

        The tree hyper-parameters are always sent to ``set_model_params`` as
        ``[-1, 2, 1]``; only the network weights travel back to the server.

        Returns:
            ``(nn_weights, number_of_training_rows, {})``.
        """
        incoming = {"tree": [-1, 2, 1], "nn": parameters}
        set_model_params(self.tree_model, self.nn_model, incoming)

        # Training warnings are silenced for both the tree and the network.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self.tree_model.fit(self.X_train, self.y_train)
            self._train_nn()

        exported = get_model_parameters(self.tree_model, self.nn_model)
        return exported["nn"], len(self.X_train), {}
    


    def evaluate(self, parameters, config):
        """Evaluate the local decision tree and ship it, serialised, to the server.

        Steps: load the received network weights, optionally decode a SuperTree
        from ``config["supertree"]``, make sure the tree is fitted, save and
        serialise the local tree, run the explanation routine when a SuperTree
        is available, and compute the metrics.

        Note that every reported metric comes from ``self.tree_model``; the
        network is only used inside ``_explain_local_and_global``.

        Returns:
            ``(log_loss, number_of_test_rows, metrics)`` where ``metrics`` also
            carries the JSON-encoded tree, scaler statistics and encoded
            feature names.
        """
        set_model_params(self.tree_model, self.nn_model, {"tree": [-1, 2, 1], "nn": parameters})

        # A SuperTree sent by the server is kept for later rounds as well; a
        # decoding failure is reported but does not abort the evaluation.
        if "supertree" in config:
            try:
                print("Recibiendo supertree....")
                supertree_dict = json.loads(config["supertree"])
                self.received_supertree = SuperTree.convert_SuperNode_to_Node(SuperTree.SuperNode.from_dict(supertree_dict))
            except Exception as e:
                print(f"[CLIENTE {self.client_id}] ❌ Error al recibir SuperTree: {e}")

        # Probe prediction: fit the tree here if evaluate runs before any fit.
        try:
            _ = self.tree_model.predict(self.X_test)
        except NotFittedError:
            self.tree_model.fit(self.X_train, self.y_train)

        y_pred = self.tree_model.predict(self.X_test)
        y_proba = self.tree_model.predict_proba(self.X_test)

        # Convert the fitted sklearn tree into SuperTree nodes (features are
        # referred to by column position), save a picture and serialise it.
        supertree = SuperTree()
        root_node = supertree.rec_buildTree(self.tree_model, list(range(self.X_train.shape[1])), len(self.unique_labels))
        round_number = config.get("server_round", 1)
        self._save_local_tree(root_node, round_number)
        tree_json = json.dumps([root_node.to_dict()])

        if self.received_supertree is not None:
            self._explain_local_and_global(config)

        if len(np.unique(self.y_test)) == 2:
            # Binary classification: use the probability of the positive class
            auc = roc_auc_score(self.y_test, y_proba[:, 1])
        else:
            # Multi-class classification (one-vs-rest)
            auc = roc_auc_score(self.y_test, y_proba, multi_class="ovr")

        return float(log_loss(self.y_test, y_proba)), len(self.X_test), {
            "Accuracy": accuracy_score(self.y_test, y_pred),
            "Precision": precision_score(self.y_test, y_pred, average="weighted", zero_division=1),
            "Recall": recall_score(self.y_test, y_pred, average="weighted"),
            "F1_Score": f1_score(self.y_test, y_pred, average="weighted"),
            "AUC": auc,
            "tree_ensemble": tree_json,
            "scaler_mean": json.dumps(self.scaler.mean_.tolist()),
            "scaler_std": json.dumps(self.scaler.scale_.tolist()),
            "encoded_feature_names": json.dumps(self.feature_names)
        }

    def _explain_local_and_global(self, config):
        """Explain one training instance with LORE and the merged (local + global) tree.

        The routine builds a LORE explanation for the network's prediction on
        row ``num_row`` of the training data, merges the LORE tree with the
        SuperTree received from the server, saves both trees as images and
        prints the factual rule plus counterfactual rules with numeric
        thresholds converted back to original units.

        Args:
            config: Round configuration; ``config.get("server_round", 1)`` is
                used to name the output folders.
        """
        num_row = 5

        # Rebuild a DataFrame from the already-encoded training matrix.
        local_df = pd.DataFrame(self.X_train, columns=FEATURES).astype(np.float32)
        local_df["target"] = self.label_encoder.inverse_transform(self.y_train_nn)

        # TabularDataset + descriptor for the encoded data
        local_tabular_dataset = TabularDataset(local_df, class_name="target")
        descriptor = local_tabular_dataset.get_descriptor()

        encoder = ColumnTransformerEnc(descriptor)
        encoder.set_classes(self.unique_labels)
        self.encoder = encoder  # optional: keep the latest encoder around

        # Local explainability: wrap the network as a black box for LORE.
        nn_wrapper = TorchNNWrapper(self.nn_model)
        bbox = sklearn_classifier_bbox.sklearnBBox(nn_wrapper)
        lore = TabularGeneticGeneratorLore(bbox, local_tabular_dataset)

        # Positional access via .iloc: the row is indexed by column names, so
        # plain [-1] would rely on pandas' positional fallback for integer keys.
        row = local_tabular_dataset.df.iloc[num_row]
        instance_scaled = row.iloc[:-1]
        target = row.iloc[-1]
        instance_array = instance_scaled.values.reshape(1, -1).astype(np.float32)
        pred_idx = self.nn_model(torch.tensor(instance_array)).argmax(dim=1).item()
        pred_label = self.label_encoder.inverse_transform([pred_idx])[0]
        print(f"[CLIENTE {self.client_id}] 🤖 Predicción de la red neuronal: {pred_label}")

        # LORE explanation
        explanation = lore.explain_instance(instance_scaled.astype(np.float32), merge=True)
        lore_tree = explanation["merged_tree"]
        round_number = config.get("server_round", 1)
        self._save_lore_tree(lore_tree.root, round_number)

        # Merge the LORE tree with the SuperTree received from the server.
        merged_tree = SuperTree()
        node_LORE_Tree = SuperTree.convert_SuperNode_to_Node(lore_tree.root)
        merged_root = merged_tree.mergeDecisionTrees(
            roots=[node_LORE_Tree, self.received_supertree],
            num_classes=len(self.unique_labels),
            feature_names=self.dataset.df.columns[:-1].tolist()
        )
        merged_tree.root = merged_root
        merged_tree.prune_redundant_leaves_full()
        merged_tree.merge_equal_class_leaves()
        self._save_merged_tree(merged_tree.root, round_number)

        # Round-trip the raw instance through the encoder to show it readably.
        raw_instance = self.dataset.df.iloc[num_row][:-1].values.reshape(1, -1)
        z_encoded = encoder.encode(raw_instance)[0]
        decoded_instance = encoder.decode(np.array([z_encoded]))[0]
        decoded_instance = pd.Series(decoded_instance, index=self.dataset.df.columns[:-1])

        # Undo the standard scaling of the numeric variables.
        if hasattr(self, 'numeric_features') and hasattr(self, 'scaler'):
            for i, col in enumerate(self.numeric_features):
                if col in decoded_instance.index:
                    decoded_instance[col] = instance_scaled[col] * self.scaler.scale_[i] + self.scaler.mean_[i]

        print(f"\n[CLIENTE {self.client_id}] 🧪 Instancia a explicar:")
        print(decoded_instance)
        print(f"[CLIENTE {self.client_id}] 🧪 Clase real: {target}")

        # Preprocess the instance again to obtain z_encoded. The numerically
        # encoded instance is used on purpose, to avoid strings.
        raw_instance_df = pd.DataFrame(instance_scaled.values.reshape(1, -1), columns=self.dataset.df.columns[:-1])
        raw_instance_preprocessed = self.preprocessor.transform(raw_instance_df)
        z_encoded = encoder.encode(raw_instance_preprocessed).astype(np.float32)[0]

        tree_str = self.tree_to_str(merged_tree.root, self.feature_names)

        # Factual rule: first path of the merged tree that ends in the
        # predicted class.
        rules = self.extract_rules_from_str(tree_str, target_class=pred_idx)

        if rules:
            print(f"\n[CLIENTE {self.client_id}] 📜 Regla desde árbol:")
            for cond in rules[0]:
                self._print_descaled_condition(cond)
            print(f" ⇒ target = {pred_label}")
        else:
            print(f"[CLIENTE {self.client_id}] ⚠️ No se encontró regla para clase {pred_label}")

        # Counterfactuals: one path per alternative class.
        cf_rules = self.extract_counterfactual_rules(tree_str, pred_idx)
        if cf_rules:
            print(f"\n🧬 [CLIENTE {self.client_id}] Contrafactuales sugeridos:")
            for class_idx, rule in cf_rules.items():
                label = self.label_encoder.inverse_transform([class_idx])[0]
                print(f" ⇒ Posible clase: {label}")
                # dict.fromkeys drops repeated premises while keeping order.
                for cond in dict.fromkeys(rule):
                    self._print_descaled_condition(cond)

    def _print_descaled_condition(self, cond):
        """Print one rule premise, converting numeric thresholds to original units.

        A premise that starts with a numeric feature name and has the form
        ``<var> > a ∧ ≤ b`` or ``<var> <op> a`` is printed with its thresholds
        multiplied by the scaler's ``scale_`` and shifted by its ``mean_``.
        Anything else is printed unchanged.
        """
        import re  # local import: the notebook cells do not import re globally

        for var in self.numeric_features:
            if not cond.startswith(var):
                continue
            match1 = re.match(rf"{re.escape(var)} (≤|<|>|≥) ([\d\.\-e]+)", cond)
            match2 = re.match(rf"{re.escape(var)} > ([\d\.\-e]+) ∧ ≤ ([\d\.\-e]+)", cond)
            idx = self.numeric_features.index(var)
            scale, mean = self.scaler.scale_[idx], self.scaler.mean_[idx]

            # The interval form is tested first because the single-threshold
            # pattern also matches its prefix.
            if match2:
                low, high = map(float, match2.groups())
                low_real = low * scale + mean
                high_real = high * scale + mean
                print(f"   - {var} > {low_real:.2f} ∧ ≤ {high_real:.2f}")
                return
            if match1:
                op, val = match1.groups()
                val_real = float(val) * scale + mean
                print(f"   - {var} {op} {val_real:.2f}")
                return
        print(f"   - {cond}")

        


    def extract_rules_from_str(self, tree_str, target_class):
        lines = tree_str.strip().split("\n")
        path = []
        rules = []

        def recurse(idx, indent_level):
            seen = set()  # ← Aquí guardamos las premisas únicas
            while idx < len(lines):
                line = lines[idx]
                current_indent = len(line) - len(line.lstrip())

                if current_indent < indent_level:
                    return idx
                if "⮕" in line:
                    if f"class = {target_class}" in line:
                        cleaned = []
                        for cond in path:
                            if cond not in seen:
                                cleaned.append(cond)
                                seen.add(cond)
                        rules.append(cleaned)
                    return idx + 1
                elif "if" in line:
                    condition = line.strip()[3:]  # remove 'if '
                    path.append(condition)
                    idx = recurse(idx + 1, current_indent + 2)
                    path.pop()
                else:
                    idx += 1
            return idx

        recurse(0, 0)
        return rules
    

    
    def extract_counterfactual_rules(self, tree_str, predicted_class):
        lines = tree_str.strip().split("\n")
        path = []
        counterfactuals = {}
        
        def recurse(idx, indent_level):
            while idx < len(lines):
                line = lines[idx]
                current_indent = len(line) - len(line.lstrip())

                if current_indent < indent_level:
                    return idx
                if "⮕" in line:
                    import re
                    match = re.search(r"class = (\d+)", line)
                    if match:
                        class_idx = int(match.group(1))
                        if class_idx != predicted_class and class_idx not in counterfactuals:
                            counterfactuals[class_idx] = list(path)
                    return idx + 1
                elif "if" in line:
                    condition = line.strip()[3:]
                    path.append(condition)
                    idx = recurse(idx + 1, current_indent + 2)
                    path.pop()
                else:
                    idx += 1
            return idx

        recurse(0, 0)
        return counterfactuals






    def tree_to_str(self, node, feature_names, depth=0):
        indent = "  " * depth
        result = ""

        if node.is_leaf:
            class_idx = np.argmax(node.labels)
            result += f"{indent}⮕ Leaf: class = {class_idx} | {node.labels}\n"
        else:
            fname = feature_names[node.feat] if node.feat < len(feature_names) else f"X_{node.feat}"

            is_cat = "=" in fname
            base_feat = fname.split("=")[0] if is_cat else fname
            cat_val = fname.split("=")[1] if is_cat else None

            for i, child in enumerate(node.children):
                if is_cat:
                    cond = f"{base_feat} {'≠' if i == 0 else '='} {cat_val}"
                else:
                    if i == 0:
                        cond = f"{base_feat} ≤ {node.intervals[i]:.3f}"
                    elif i < len(node.intervals):
                        cond = f"{base_feat} > {node.intervals[i-1]:.3f} ∧ ≤ {node.intervals[i]:.3f}"
                    else:
                        cond = f"{base_feat} > {node.intervals[i-1]:.3f}"

                result += f"{indent}if {cond}\n"
                result += self.tree_to_str(child, feature_names, depth + 1)

        return result
    
    def _save_local_tree(self, root_node, round_number):
        self._save_generic_tree(
            root_node,
            round_number,
            tree_type="Arbol_Local"
        )

    def _save_lore_tree(self, root_node, round_number):
        self._save_generic_tree(
            root_node, 
            round_number, 
            tree_type="LoreTree"
        )

    def _save_merged_tree(self, root_node, round_number):
        self._save_generic_tree(
            root_node, 
            round_number, 
            tree_type="MergedTree"
        )

    def _save_generic_tree(self, root_node, round_number, tree_type):
        """Render a tree to PNG with Graphviz.

        The image is written to
        ``Ronda_<round>/<tree_type>_Cliente_<id>/<tree_type lower>_cliente_<id>_ronda_<round>.png``.
        Two node layouts are supported: nodes with ``children`` and
        ``intervals`` (SuperTree style) and nodes with ``_left_child`` /
        ``_right_child`` (binary style). Thresholds of numeric features are
        converted back to original units with ``self.scaler``.

        Args:
            root_node: Root of the tree to draw.
            round_number: Server round, used in the folder and file names.
            tree_type: Label used in the folder and file names.
        """
        from graphviz import Digraph
        import numpy as np
        import os

        dot = Digraph()
        node_id = [0]  # mutable counter shared with the nested function

        def base_name(feat):
            return feat.split('=')[0] if '=' in feat else feat

        def feature_name(node):
            # Fall back to a positional name when the index cannot be resolved.
            # `except Exception` (not a bare except) so that KeyboardInterrupt
            # and SystemExit are never swallowed.
            try:
                return self.feature_names[node.feat]
            except Exception:
                return f"X_{node.feat}"

        def to_original_units(fname, value):
            # Undo the standard scaling for numeric features; leave other
            # values untouched.
            original_feat = base_name(fname)
            if hasattr(self, "scaler") and original_feat in self.numeric_features:
                idx = self.numeric_features.index(original_feat)
                mean = self.scaler.mean_[idx]
                std = self.scaler.scale_[idx]
                return value * std + mean
            return value

        def add_node(node, parent=None, edge_label=""):
            curr = str(node_id[0])
            node_id[0] += 1

            # Node label
            if node.is_leaf:
                class_index = np.argmax(node.labels)
                class_label = self.unique_labels[class_index]
                label = f"class: {class_label}\n{node.labels}"
            else:
                try:
                    label = base_name(self.feature_names[node.feat])
                except Exception:
                    label = f"X_{node.feat}"

            dot.node(curr, label)
            if parent:
                dot.edge(parent, curr, label=edge_label)

            # SuperTree-style node
            if hasattr(node, "children") and node.children is not None and hasattr(node, "intervals"):
                for i, child in enumerate(node.children):
                    fname = feature_name(node)

                    if '=' in fname:
                        attr, val = fname.split('=')
                        edge = f"= {val}" if i == 1 else f"≠ {val}"
                    else:
                        # NOTE(review): for nodes with more than two children
                        # the middle edges only show their lower bound.
                        raw = node.intervals[i] if i == 0 else node.intervals[i - 1]
                        val = to_original_units(fname, raw)
                        edge = f"<= {val:.2f}" if i == 0 else f"> {val:.2f}"

                    add_node(child, curr, edge)

            # Classic binary tree node
            elif hasattr(node, "_left_child") or hasattr(node, "_right_child"):
                fname = feature_name(node)

                if '=' in fname:
                    attr, val = fname.split('=')
                    left_label = f"≠ {val}"
                    right_label = f"= {val}"
                else:
                    if node.thresh is None:
                        thresh = None
                    else:
                        thresh = to_original_units(fname, node.thresh)

                    if thresh is not None:
                        left_label = f"<= {thresh:.2f}"
                        right_label = f"> {thresh:.2f}"
                    else:
                        left_label = "≤ ?"
                        right_label = "> ?"

                if node._left_child:
                    add_node(node._left_child, curr, left_label)
                if node._right_child:
                    add_node(node._right_child, curr, right_label)

        add_node(root_node)
        folder = f"Ronda_{round_number}/{tree_type}_Cliente_{self.client_id}"
        os.makedirs(folder, exist_ok=True)
        filepath = f"{folder}/{tree_type.lower()}_cliente_{self.client_id}_ronda_{round_number}"
        dot.render(filepath, format="png", cleanup=True)



def client_fn(context: Context):
    """Build the Flower client for the partition described by ``context``.

    Reads ``partition-id`` and ``num-partitions`` (required) plus
    ``dataset_name`` and ``class_col`` (optional, defaulting to the Iris
    dataset) from ``context.node_config``, loads that partition and wires a
    decision tree and a ``Net`` into a ``FlowerClient``.

    Returns:
        The client converted with ``to_client()``.
    """
    node_config = context.node_config
    partition_id = node_config["partition-id"]
    num_partitions = node_config["num-partitions"]

    dataset_name = node_config.get("dataset_name", "pablopalacios23/Iris")
    class_col = node_config.get("class_col", "target")

    # Alternative dataset:
    # dataset_name = node_config.get("dataset_name", "pablopalacios23/adult_small")
    # class_col = node_config.get("class_col", "class")

    (
        X_train, y_train,
        X_test, y_test,
        dataset, feature_names, label_encoder,
        scaler, numeric_features, encoder, preprocessor,
    ) = load_data_general(
        flower_dataset_name=dataset_name,
        class_col=class_col,
        partition_id=partition_id,
        num_partitions=num_partitions,
    )

    tree_model = DecisionTreeClassifier(max_depth=5, min_samples_split=2, random_state=42)

    input_dim = X_train.shape[1]
    # Size the output layer from the label encoder, not from the labels that
    # happen to occur in the training slice: a class missing from y_train
    # would otherwise shrink the network and make its weights incompatible
    # with the parameters sent by the server.
    output_dim = len(label_encoder.classes_)
    nn_model = Net(input_dim, output_dim)

    client = FlowerClient(
        tree_model=tree_model,
        nn_model=nn_model,
        X_train=X_train,
        y_train=y_train,
        X_test=X_test,
        y_test=y_test,
        dataset=dataset,
        client_id=partition_id + 1,  # 1-based id used in logs and file names
        feature_names=feature_names,
        label_encoder=label_encoder,
        scaler=scaler,
        numeric_features=numeric_features,
        encoder=encoder,
        preprocessor=preprocessor,
    )
    return client.to_client()

# Flower client application; client_fn is called to build each client.
client_app = ClientApp(client_fn=client_fn)


# Servidor

In [3]:
# ============================
# 📦 IMPORTACIONES NECESARIAS
# ============================
import os
import time
import json
import numpy as np
from typing import List, Tuple, Dict
from sklearn.tree import DecisionTreeClassifier

from flwr.common import Context, Metrics, Scalar, ndarrays_to_parameters
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.strategy import FedAvg

from graphviz import Digraph
from lore_sa.surrogate.decision_tree import SuperTree

import torch
import torch.nn as nn
import torch.nn.functional as F


# ============================
# ⚙️ GLOBAL CONFIGURATION
# ============================
# Passed to FedAvg as min_available_clients in server_fn.
MIN_AVAILABLE_CLIENTS = 2
# Number of rounds given to ServerConfig in server_fn.
NUM_SERVER_ROUNDS = 2
FEATURES = []  # filled in dynamically; server_fn falls back to placeholders while it is empty
# Class labels; handled the same way as FEATURES by server_fn.
UNIQUE_LABELS = []
# JSON dump of the most recent SuperTree root. Written after each evaluate
# aggregation and forwarded to clients under config["supertree"].
LATEST_SUPERTREE_JSON = None

# ============================
# 🧠 UTILIDADES MODELO
# ============================
def create_model(input_dim, output_dim):
    """Instantiate the notebook's ``Net`` with the given layer sizes.

    ``Net`` is imported from ``__main__`` at call time because the class
    lives in another notebook cell rather than in an importable module.

    Args:
        input_dim: forwarded as the first positional argument of ``Net``.
        output_dim: forwarded as the second positional argument of ``Net``.

    Returns:
        The newly constructed ``Net`` instance.
    """
    from __main__ import Net as network_cls  # resolved lazily, on every call

    network = network_cls(input_dim, output_dim)
    return network


def get_model_parameters(tree_model, nn_model):
    """Bundle the parameters of both models into a single dictionary.

    Args:
        tree_model: not read by this function; the "tree" entry is always
            the fixed placeholder ``[-1, 2, 1]``.
        nn_model: object exposing ``state_dict()``; every value is moved to
            CPU, detached and converted to a NumPy array.

    Returns:
        ``{"tree": [-1, 2, 1], "nn": [...]}`` where the "nn" list keeps the
        order of the state-dict values.
    """
    params = {"tree": [-1, 2, 1], "nn": []}
    for tensor in nn_model.state_dict().values():
        params["nn"].append(tensor.cpu().detach().numpy())
    return params

def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Dict[str, Scalar]:
    """Aggregate per-client metrics into example-weighted means.

    Each numeric metric is averaged over the clients that actually reported
    it, weighted by their example counts. Non-numeric values (strings,
    bytes, ...) are ignored.

    Compared with dividing every key by the grand total of examples, this
    keeps a metric unbiased when only some clients report it, and avoids a
    ZeroDivisionError when the example counts sum to zero.

    Args:
        metrics: ``(num_examples, metrics_dict)`` pairs, one per client.

    Returns:
        Mapping from metric name to its weighted mean. Keys whose reporting
        clients contributed zero examples in total are omitted; an empty
        input yields an empty dict.
    """
    weighted_sums: Dict[str, float] = {}
    weights: Dict[str, int] = {}
    for num_examples, client_metrics in metrics:
        for key, value in client_metrics.items():
            if not isinstance(value, (float, int)):
                continue
            weighted_sums[key] = weighted_sums.get(key, 0.0) + num_examples * float(value)
            weights[key] = weights.get(key, 0) + num_examples
    # A zero weight means no examples backed this metric: there is no mean to report.
    return {
        key: weighted_sums[key] / weights[key]
        for key in weighted_sums
        if weights[key] > 0
    }

# ============================
# 🚀 SERVIDOR FLOWER
# ============================

def server_fn(context: Context) -> ServerAppComponents:
    """Assemble the server side of the federation.

    Builds a FedAvg strategy seeded with the weights of a freshly created
    network, wraps its ``configure_fit`` / ``configure_evaluate`` hooks with
    ``_inject_round`` and replaces ``aggregate_evaluate`` with a version
    that also merges the clients' decision trees into a SuperTree.

    Args:
        context: Flower context; not read by this function.

    Returns:
        ServerAppComponents holding the patched strategy and a ServerConfig
        that runs NUM_SERVER_ROUNDS rounds.
    """
    global FEATURES, UNIQUE_LABELS

    # The return value is discarded: the call is made only for its side
    # effects (presumably filling FEATURES / UNIQUE_LABELS; confirm in
    # load_data_general).
    if not (FEATURES and UNIQUE_LABELS):
        load_data_general("pablopalacios23/Iris", "target", partition_id=0, num_partitions=2)
        # Alternative dataset:
        # load_data_general("pablopalacios23/adult_small", "class", partition_id=0, num_partitions=2)

    # Placeholders for whatever is still empty after the load above.
    if not FEATURES:
        FEATURES = ["feat_0", "feat_1"]
    if not UNIQUE_LABELS:
        UNIQUE_LABELS = ["Class_0", "Class_1"]

    seed_net = create_model(len(FEATURES), len(UNIQUE_LABELS))
    seed_weights = get_model_parameters(None, seed_net)["nn"]
    initial_params = ndarrays_to_parameters(seed_weights)

    strategy = FedAvg(
        min_available_clients=MIN_AVAILABLE_CLIENTS,
        fit_metrics_aggregation_fn=weighted_average,
        evaluate_metrics_aggregation_fn=weighted_average,
        initial_parameters=initial_params,
    )

    # Every fit/evaluate instruction gets the round number (and the latest
    # SuperTree, once one exists) added to its config.
    for hook_name in ("configure_fit", "configure_evaluate"):
        setattr(strategy, hook_name, _inject_round(getattr(strategy, hook_name)))

    # Keep the stock implementation before it is replaced below.
    base_aggregate_evaluate = strategy.aggregate_evaluate

    def custom_aggregate_evaluate(server_round, results, failures):
        """Run the stock aggregation, then rebuild, plot and publish the SuperTree."""
        global LATEST_SUPERTREE_JSON
        aggregated = base_aggregate_evaluate(server_round, results, failures)

        try:
            print(f"\n[SERVIDOR] 🌲 Generando SuperTree - Ronda {server_round}")
            roots = []
            scaler_means = scaler_stds = feature_names = None

            for client_no, (_, evaluate_res) in enumerate(results, start=1):
                reported = evaluate_res.metrics

                # Scaler statistics and feature names are overwritten on
                # every client that reports them, so the last one wins.
                if reported.get("scaler_mean") and reported.get("scaler_std"):
                    scaler_means = json.loads(reported["scaler_mean"])
                    scaler_stds = json.loads(reported["scaler_std"])
                if "encoded_feature_names" in reported:
                    feature_names = json.loads(reported["encoded_feature_names"])

                serialized_trees = reported.get("tree_ensemble")
                if not serialized_trees:
                    continue
                try:
                    for tree_dict in json.loads(serialized_trees):
                        root = SuperTree.Node.from_dict(tree_dict)
                        if root:
                            roots.append(root)
                except Exception as exc:
                    # A malformed payload from one client must not stop the others.
                    print(f"[CLIENTE {client_no}] ❌ Error al parsear árbol: {exc}")

            if roots:
                supertree = SuperTree()
                supertree.mergeDecisionTrees(roots, num_classes=len(UNIQUE_LABELS), feature_names=feature_names)
                supertree.prune_redundant_leaves_full()
                supertree.merge_equal_class_leaves()

                _save_supertree_plot(supertree.root, server_round, feature_names, UNIQUE_LABELS, scaler_means, scaler_stds)
                LATEST_SUPERTREE_JSON = json.dumps(supertree.root.to_dict())
            else:
                print("[SERVIDOR] ⚠️ No se recibieron árboles. Se omite SuperTree.")
                return aggregated

        except Exception as exc:
            # Best effort: a SuperTree failure never discards the aggregated metrics.
            print(f"[SERVIDOR] ❌ Error en SuperTree: {exc}")

        time.sleep(3)
        return aggregated

    strategy.aggregate_evaluate = custom_aggregate_evaluate
    run_config = ServerConfig(num_rounds=NUM_SERVER_ROUNDS)
    return ServerAppComponents(strategy=strategy, config=run_config)

# ============================
# 🧩 FUNCIONES AUXILIARES
# ============================
def _inject_round(original_fn):
    """Decorate a strategy ``configure_*`` method so its instructions carry extra config.

    The returned wrapper calls ``original_fn`` unchanged and then, for each
    ``(client, instruction)`` pair it produced, writes into the
    instruction's ``config``:

    * ``"server_round"``: the current round number, always;
    * ``"supertree"``: the module-level LATEST_SUPERTREE_JSON, only when
      that global is truthy.

    Args:
        original_fn: callable taking ``(server_round, parameters,
            client_manager)`` and returning an iterable of 2-tuples.

    Returns:
        The wrapping callable, with the same signature and return value as
        ``original_fn``.
    """
    def wrapper(server_round, parameters, client_manager):
        global LATEST_SUPERTREE_JSON
        instructions = original_fn(server_round, parameters, client_manager)
        for _client, instruction in instructions:
            cfg = instruction.config
            cfg["server_round"] = server_round
            if not LATEST_SUPERTREE_JSON:
                continue
            cfg["supertree"] = LATEST_SUPERTREE_JSON
        return instructions

    return wrapper

def _save_supertree_plot(root_node, round_number, feature_names=None, class_names=None, scaler_means=None, scaler_stds=None):
    """Render the tree rooted at ``root_node`` to a PNG with Graphviz.

    The graph is rendered to ``Ronda_<round_number>/Supertree/supertree_ronda_<round_number>``
    in PNG format; the folder is created when missing.

    Args:
        root_node: node exposing ``is_leaf`` and ``labels`` (leaves) or
            ``feat``, ``children`` and ``intervals`` (internal nodes).
        round_number: used to build the output folder and file name.
        feature_names: optional sequence indexed by ``node.feat``. A name of
            the form ``"attr=value"`` is drawn as a categorical split
            (``= value`` on child 1, ``≠ value`` on every other child) and
            only ``attr`` is shown in the node. Any other name is drawn as
            a numeric threshold split. When a name cannot be looked up the
            node is labelled ``X_<feat>``.
        class_names: optional sequence indexed by the arg-max of a leaf's
            ``labels``; without it leaves are labelled ``Clase <index>``.
        scaler_means: optional per-feature means, indexed by the position of
            the feature name in ``feature_names``.
        scaler_stds: optional per-feature standard deviations, same indexing.
            When both scaler arguments are truthy, thresholds are shown as
            ``threshold * std + mean``.

    Changes from the previous implementation:
        * only lookup failures (TypeError / IndexError / KeyError) trigger
          the ``X_<feat>`` fallback, instead of a bare ``except``;
        * names containing more than one ``=`` are split on the first one
          only, instead of raising ValueError;
        * when the feature name is a placeholder, thresholds are left
          unscaled instead of failing in ``feature_names.index``.
    """
    dot = Digraph()
    node_id = [0]  # one-element list so the nested function can mutate the counter

    def resolve_feature(node):
        """Return ``(name, resolved)`` for an internal node's split feature."""
        try:
            return feature_names[node.feat], True
        except (TypeError, IndexError, KeyError):
            # No names supplied, or the index is not covered by them.
            return f"X_{node.feat}", False

    def add_node(node, parent=None, label=""):
        """Emit ``node`` (and, recursively, its subtree) into ``dot``."""
        curr = str(node_id[0])
        node_id[0] += 1

        if node.is_leaf:
            class_index = np.argmax(node.labels)
            class_label = class_names[class_index] if class_names else f"Clase {class_index}"
            dot.node(curr, f"Clase: {class_label}\n{node.labels}")
        else:
            feat_name, resolved = resolve_feature(node)
            # Show only the attribute part of "attr=value" names.
            dot.node(curr, feat_name.partition("=")[0])

        if parent is not None:
            dot.edge(parent, curr, label=label)

        if node.is_leaf:
            return

        if "=" in feat_name:
            # Categorical split: child 1 is the "equals" branch.
            _, category = feat_name.split("=", 1)
            for i, child in enumerate(node.children):
                add_node(child, curr, f"= {category}" if i == 1 else f"≠ {category}")
            return

        # Numeric split: map thresholds back to the original scale when the
        # statistics are available and the feature could be identified.
        rescale = bool(resolved and scaler_means and scaler_stds)
        if rescale:
            idx = feature_names.index(feat_name)
        for i, child in enumerate(node.children):
            # Child 0 and child 1 share intervals[0]; child i > 1 uses intervals[i - 1].
            val = node.intervals[max(i - 1, 0)]
            if rescale:
                val = val * scaler_stds[idx] + scaler_means[idx]
            edge_label = f"<= {val:.2f}" if i == 0 else f"> {val:.2f}"
            add_node(child, curr, edge_label)

    add_node(root_node)
    folder = f"Ronda_{round_number}/Supertree"
    os.makedirs(folder, exist_ok=True)
    filename = f"{folder}/supertree_ronda_{round_number}"
    dot.render(filename, format="png", cleanup=True)

# ============================
# 🔧 INITIALISE SERVER APP
# ============================
# Flower ServerApp built around server_fn; handed to run_simulation as the server side.
server_app = ServerApp(server_fn=server_fn)



In [4]:
from flwr.simulation import run_simulation
import logging
import warnings
import ray

# Silence DeprecationWarnings for the whole run.
warnings.filterwarnings("ignore", category=DeprecationWarning)

# Raise these third-party loggers to WARNING so the round logs stay readable.
logging.getLogger('matplotlib').setLevel(logging.WARNING)
logging.getLogger("filelock").setLevel(logging.WARNING)
logging.getLogger("ray").setLevel(logging.WARNING)
logging.getLogger('graphviz').setLevel(logging.WARNING)
# logging.getLogger("flwr").setLevel(logging.WARNING)

ray.shutdown()  # stop any Ray session left over from a previous run
ray.init(local_mode=True)  # local mode: no multiprocessing, a single main process

# Per-client resources must be nested under "client_resources". A bare
# top-level {"num_cpus": 1} is not a key the simulation backend reads, so
# its default per-client resources were being applied instead.
backend_config = {"client_resources": {"num_cpus": 1, "num_gpus": 0.0}}

run_simulation(
    server_app=server_app,
    client_app=client_app,
    num_supernodes=NUM_CLIENTS,
    backend_config=backend_config,
)


2025-05-26 14:34:38,087	INFO worker.py:1832 -- Started a local Ray instance. View the dashboard at [1m[32mhttp://127.0.0.1:8265 [39m[22m
2025-05-26 14:34:41,904 flwr         DEBUG    Asyncio event loop already running.
:job_id:01000000
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=2, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Using initial global parameters provided by strategy
[92mINFO [0m:      Starting evaluation of initial global parameters
[92mINFO [0m:      Evaluation returned no results (`None`)
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 2)


:job_id:01000000
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor


[92mINFO [0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 2 clients (out of 2)


[CLIENTE 1] ✅ Red neuronal entrenada
[CLIENTE 2] ✅ Red neuronal entrenada


[92mINFO [0m:      aggregate_evaluate: received 2 results and 0 failures



[SERVIDOR] 🌲 Generando SuperTree - Ronda 1


[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 2)
[92mINFO [0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 2 clients (out of 2)


[CLIENTE 1] ✅ Red neuronal entrenada
[CLIENTE 2] ✅ Red neuronal entrenada
Recibiendo supertree....
Recibiendo supertree....
[CLIENTE 1] 🤖 Predicción de la red neuronal: versicolor
[CLIENTE 2] 🤖 Predicción de la red neuronal: setosa

[CLIENTE 1] 🧪 Instancia a explicar:
sepal length (cm)    6.9
sepal width (cm)     3.2
petal length (cm)    5.7
petal width (cm)     2.3
dtype: object
[CLIENTE 1] 🧪 Clase real: virginica

[CLIENTE 1] 📜 Regla desde árbol:
   - sepal length (cm) ≤ 5.37
   - sepal width (cm) ≤ 2.61
   - petal length (cm) > 2.93
 ⇒ target = versicolor

🧬 [CLIENTE 1] Contrafactuales sugeridos:
 ⇒ Posible clase: setosa
   - sepal length (cm) ≤ 5.37
   - sepal width (cm) ≤ 2.61
   - petal length (cm) ≤ 2.93
 ⇒ Posible clase: virginica
   - sepal length (cm) > 5.37
   - sepal width (cm) ≤ 3.74
   - petal length (cm) > 2.93
   - sepal width (cm) ≤ 3.63
   - sepal length (cm) ≤ 5.67
   - sepal length (cm) ≤ 5.63
   - sepal length (cm) ≤ 5.47
   - petal length (cm) ≤ 4.01
   - petal wi

[92mINFO [0m:      aggregate_evaluate: received 2 results and 0 failures



[CLIENTE 2] 🧪 Instancia a explicar:
sepal length (cm)    4.7
sepal width (cm)     3.2
petal length (cm)    1.6
petal width (cm)     0.2
dtype: object
[CLIENTE 2] 🧪 Clase real: setosa

[CLIENTE 2] 📜 Regla desde árbol:
   - petal length (cm) ≤ 2.60
 ⇒ target = setosa

🧬 [CLIENTE 2] Contrafactuales sugeridos:
 ⇒ Posible clase: versicolor
   - petal length (cm) > 2.60
   - petal length (cm) ≤ 4.71
   - petal width (cm) ≤ 1.75
   - sepal width (cm) ≤ 2.31
 ⇒ Posible clase: virginica
   - petal length (cm) > 2.60
   - petal length (cm) ≤ 4.71
   - petal width (cm) > 1.75 ∧ ≤ 1.95
   - sepal width (cm) ≤ 2.31

[SERVIDOR] 🌲 Generando SuperTree - Ronda 2


[92mINFO [0m:      
[92mINFO [0m:      [SUMMARY]
[92mINFO [0m:      Run finished 2 round(s) in 64.73s
[92mINFO [0m:      	History (loss, distributed):
[92mINFO [0m:      		round 1: 2.220446049250313e-16
[92mINFO [0m:      		round 2: 2.220446049250313e-16
[92mINFO [0m:      	History (metrics, distributed, evaluate):
[92mINFO [0m:      	{'AUC': [(1, 1.0), (2, 1.0)],
[92mINFO [0m:      	 'Accuracy': [(1, 1.0), (2, 1.0)],
[92mINFO [0m:      	 'F1_Score': [(1, 1.0), (2, 1.0)],
[92mINFO [0m:      	 'Precision': [(1, 1.0), (2, 1.0)],
[92mINFO [0m:      	 'Recall': [(1, 1.0), (2, 1.0)]}
[92mINFO [0m:      
