In [1]:
# =======================
# 📦 IMPORTACIONES
# =======================
import warnings
import time
import sys
import random
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple

from sklearn.preprocessing import LabelEncoder
from sklearn.tree import DecisionTreeClassifier, plot_tree, export_text
from sklearn.metrics import (
    log_loss, accuracy_score, precision_score, recall_score, 
    f1_score, confusion_matrix, roc_auc_score
)

from flwr.client import ClientApp, NumPyClient
from flwr.common import Context, NDArrays, Metrics, Scalar, ndarrays_to_parameters
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.strategy import FedAvg
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner

from graphviz import Digraph

from lore_sa.dataset import TabularDataset
from lore_sa.bbox import sklearn_classifier_bbox
from lore_sa.lore import TabularGeneticGeneratorLore
from lore_sa.surrogate.decision_tree import SuperTree

import torch
import torch.nn as nn
import torch.nn.functional as F

# =======================
# ⚙️ VARIABLES GLOBALES
# =======================
UNIQUE_LABELS = []
FEATURES = []
NUM_SERVER_ROUNDS = 2
NUM_CLIENTS = 2
MIN_AVAILABLE_CLIENTS = 2
fds = None  # Cache del FederatedDataset

class Net(nn.Module):
    def __init__(self, input_dim: int, output_dim: int):
        super(Net, self).__init__()
        hidden_dim = max(8, input_dim * 2)  # algo proporcional

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim // 2)
        self.fc3 = nn.Linear(hidden_dim // 2, output_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

# =======================
# 🔧 UTILIDADES MODELO
# =======================

def get_model_parameters(tree_model, nn_model):
    tree_params = [
        int(tree_model.get_params()["max_depth"] or -1),
        int(tree_model.get_params()["min_samples_split"]),
        int(tree_model.get_params()["min_samples_leaf"]),
    ]
    nn_weights = [v.cpu().detach().numpy() for v in nn_model.state_dict().values()]
    return {
        "tree": tree_params,
        "nn": nn_weights,
    }


def set_model_params(tree_model, nn_model, params):
    tree_params = params["tree"]
    nn_weights = params["nn"]

    # Solo si tree_model no es None y tiene set_params
    if tree_model is not None and hasattr(tree_model, "set_params"):
        max_depth = tree_params[0] if tree_params[0] > 0 else None
        tree_model.set_params(
            max_depth=max_depth,
            min_samples_split=tree_params[1],
            min_samples_leaf=tree_params[2],
        )

    # Actualizar pesos de la red neuronal
    state_dict = nn_model.state_dict()
    for (key, _), val in zip(state_dict.items(), nn_weights):
        state_dict[key] = torch.tensor(val)
    nn_model.load_state_dict(state_dict)

    
# =======================
# 🌲 VISUALIZAR SUPERTREE
# =======================

def visualize_supertree(tree, feature_names=None, class_names=None, filename="supertree"):
    dot = Digraph()
    node_id = [0]

    def add_node(node, parent_id=None, edge_label=''):
        curr_id = str(node_id[0])
        node_id[0] += 1

        if node.is_leaf:
            class_index = np.argmax(node.labels)
            class_label = class_names[class_index] if class_names else f"class {class_index}"
            label = f"class: {class_label}\n{node.labels}"
        else:
            fname = f"X_{node.feat}" if feature_names is None else feature_names[node.feat]
            label = f"{fname}"

        dot.node(curr_id, label)

        if parent_id is not None:
            dot.edge(parent_id, curr_id, label=edge_label)

        if not node.is_leaf:
            for i, child in enumerate(node.children):
                label = f"<= {node.intervals[i]:.2f}" if i == 0 else f"> {node.intervals[i - 1]:.2f}"
                add_node(child, curr_id, label)

    add_node(tree)
    dot.render(filename, format='png', cleanup=True)
    print(f"[SERVIDOR] 🌲 SuperTree guardado como '{filename}.png'")

# =======================
# 📄 CONVERTIR ÁRBOL EN TEXTO A NODO
# =======================

def from_text_representation(text: str) -> SuperTree.Node:
    lines = [line.rstrip() for line in text.split("\n") if line.strip()]
    root = None
    stack = []

    for line in lines:
        indent = len(line) - len(line.lstrip())
        level = indent // 4
        content = line.strip()

        if "class:" in content:
            class_info = content.split("class: ")[-1]
            node = SuperTree.Node(is_leaf=True)
            node.predicted_class = class_info
        else:
            feat, cond = content.split(" <= ")
            node = SuperTree.Node(is_leaf=False)
            node.feature = feat.strip()
            node.threshold = float(cond.strip())

        while len(stack) > level:
            stack.pop()

        if stack:
            stack[-1].children.append(node)
        else:
            root = node

        stack.append(node)

    return root

SuperTree.Node.from_text_representation = staticmethod(from_text_representation)

# =======================
# 📥 CARGAR DATOS
# =======================

def load_data(partition_id: int, num_partitions: int):
    global fds, UNIQUE_LABELS, FEATURES

    if fds is None:
        partitioner = IidPartitioner(num_partitions=num_partitions)
        fds = FederatedDataset(dataset="hitorilabs/iris", partitioners={"train": partitioner})

    dataset = fds.load_partition(partition_id, "train").with_format("pandas")[:]
    target_column = dataset.columns[-1]

    if dataset[target_column].dtype == "object":
        label_encoder = LabelEncoder()
        dataset[target_column] = label_encoder.fit_transform(dataset[target_column])
    else:
        dataset[target_column] = dataset[target_column].map({0: "Setosa", 1: "Versicolor", 2: "Virginica"})

    dataset.rename(columns={target_column: "target"}, inplace=True)

    if not UNIQUE_LABELS:
        UNIQUE_LABELS = dataset["target"].unique().tolist()
    if not FEATURES:
        FEATURES = dataset.drop(columns=["target"]).columns.tolist()

    tabular_dataset = TabularDataset(dataset, "target")

    # Train/Test split (80/20)
    X = dataset[FEATURES]
    y = dataset["target"]
    split_idx = int(0.8 * len(X))
    X_train, X_test = X.iloc[:split_idx], X.iloc[split_idx:]
    y_train, y_test = y.iloc[:split_idx], y.iloc[split_idx:]

    return X_train, y_train, X_test, y_test, tabular_dataset

# =======================
# 🧪 PRUEBA DE CARGA LOCAL (solo en ejecución directa)
# =======================

if __name__ == "__main__":
    X_train, y_train, X_test, y_test, dataset = load_data(partition_id=0, num_partitions=NUM_CLIENTS)

    print("UNIQUE_LABELS:", UNIQUE_LABELS)
    print("FEATURES:", FEATURES)

    print("\nContenido del TabularDataset:")
    print(dataset.df.head())

    print("\nDescriptor del TabularDataset:")
    print(dataset.descriptor)

2025-05-09 12:05:35,085	INFO util.py:154 -- Outdated packages:
  ipywidgets==7.8.1 found, needs ipywidgets>=8
Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
2025-05-09 12:05:39,580 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): huggingface.co:443
2025-05-09 12:05:39,795 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /datasets/hitorilabs/iris/resolve/main/README.md HTTP/11" 200 0
2025-05-09 12:05:39,919 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /datasets/hitorilabs/iris/resolve/fa62476c42edcf9259f895f43da1a7bf9e2697ae/iris.py HTTP/11" 404 0
2025-05-09 12:05:39,921 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): s3.amazonaws.com:443
2025-05-09 12:05:40,243 urllib3.connectionpool DEBUG    https://s3.amazonaws.com:443 "HEAD /datasets.huggingface.co/datasets/datasets/hitorilabs/iris/hitorilabs/iris.py HTTP/11" 404 0
2025-05-09 12:05:40,607 urllib3.connectionpool DEBUG

UNIQUE_LABELS: ['Versicolor', 'Virginica', 'Setosa']
FEATURES: ['petal_length', 'petal_width', 'sepal_length', 'sepal_width']

Contenido del TabularDataset:
   petal_length  petal_width  sepal_length  sepal_width      target
0           4.8          1.8           5.9          3.2  Versicolor
1           3.5          1.0           5.7          2.6  Versicolor
2           5.6          1.4           6.1          2.6   Virginica
3           1.5          0.2           4.6          3.1      Setosa
4           4.9          1.8           6.3          2.7   Virginica

Descriptor del TabularDataset:
{'numeric': {'petal_length': {'index': 0, 'min': 1.100000023841858, 'max': 6.699999809265137, 'mean': 3.9026663, 'std': 1.7214837074279785, 'median': 4.5, 'q1': 1.7000000476837158, 'q3': 5.099999904632568}, 'petal_width': {'index': 1, 'min': 0.10000000149011612, 'max': 2.5, 'mean': 1.228, 'std': 0.7334774732589722, 'median': 1.399999976158142, 'q1': 0.3500000089406967, 'q3': 1.7999999523162842}, 'sep

### Definir el cliente federado con Flower

In [2]:
import warnings
import os
import json
from IPython.display import display
import graphviz
import numpy as np
import pandas as pd
from graphviz import Digraph
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import (
    log_loss, accuracy_score, precision_score, recall_score,
    f1_score, roc_auc_score
)
from sklearn.exceptions import NotFittedError

from flwr.client import NumPyClient
from flwr.common import Context

from lore_sa.dataset import TabularDataset
from lore_sa.bbox import sklearn_classifier_bbox
from lore_sa.lore import TabularGeneticGeneratorLore
from lore_sa.surrogate.decision_tree import SuperTree
from lore_sa.encoder_decoder import ColumnTransformerEnc

from flwr.common import parameters_to_ndarrays

class TorchNNWrapper:
    def __init__(self, model):
        self.model = model
        self.model.eval()

    def predict(self, X):
        X = np.array(X, dtype=np.float32)  # 👈 Forzamos tipo compatible
        with torch.no_grad():
            X_tensor = torch.tensor(X, dtype=torch.float32)
            outputs = self.model(X_tensor)
            return outputs.argmax(dim=1).numpy()

    def predict_proba(self, X):
        X = np.array(X, dtype=np.float32)  # 👈 Igual aquí
        with torch.no_grad():
            X_tensor = torch.tensor(X, dtype=torch.float32)
            outputs = self.model(X_tensor)
            probs = torch.nn.functional.softmax(outputs, dim=1)
            return probs.numpy()

class FlowerClient(NumPyClient):
    def __init__(self, tree_model, nn_model, X_train, y_train, X_test, y_test, dataset, client_id):
        self.tree_model = tree_model
        self.nn_model = nn_model
        self.X_train = X_train.values
        self.y_train = y_train.values
        self.y_train_nn = LabelEncoder().fit_transform(self.y_train).astype(np.int64)
        self.X_test = X_test.values
        self.y_test = y_test.values
        self.dataset = dataset
        self.unique_labels = np.unique(y_train)
        self.client_id = client_id
        self.received_supertree = None

    def _train_nn(self, epochs=10, lr=0.01):
        self.nn_model.train()
        optimizer = torch.optim.Adam(self.nn_model.parameters(), lr=lr)
        loss_fn = nn.CrossEntropyLoss()

        X_tensor = torch.tensor(self.X_train, dtype=torch.float32)
        y_tensor = torch.tensor(self.y_train_nn, dtype=torch.long)
        

        for epoch in range(epochs):
            optimizer.zero_grad()
            outputs = self.nn_model(X_tensor)
            loss = loss_fn(outputs, y_tensor)
            loss.backward()
            optimizer.step()

        print(f"[CLIENTE {self.client_id}] ✅ Red neuronal entrenada ")

    def fit(self, parameters, config):
        set_model_params(self.tree_model, self.nn_model, {"tree": [-1, 2, 1], "nn": parameters})
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self.tree_model.fit(self.X_train, self.y_train)
            self._train_nn()
        nn_weights = get_model_parameters(self.tree_model, self.nn_model)["nn"]
        return nn_weights, len(self.X_train), {}
    
    


    def evaluate(self, parameters, config):
        set_model_params(self.tree_model, self.nn_model, {"tree": [-1, 2, 1], "nn": parameters})

        if "supertree" in config:
            try:
                supertree_dict = json.loads(config["supertree"])
                print(f"[CLIENTE {self.client_id}] recibiendo SuperTree...")
                self.received_supertree = SuperTree.SuperNode.from_dict(supertree_dict)
            except Exception as e:
                print(f"[CLIENTE {self.client_id}] ❌ Error al recibir SuperTree: {e}")

        try:
            _ = self.tree_model.predict(self.X_test)
        except NotFittedError:
            self.tree_model.fit(self.X_train, self.y_train)

        y_pred = self.tree_model.predict(self.X_test)
        y_proba = self.tree_model.predict_proba(self.X_test)

        supertree = SuperTree()
        root_node = supertree.rec_buildTree(self.tree_model, list(range(self.X_train.shape[1])))
        round_number = config.get("server_round", 1)
        self._save_local_tree(root_node, round_number)

        tree_json = json.dumps([root_node.to_dict()])
        print(f"[CLIENTE {self.client_id}] ✅ Árbol local generado y enviado")

        self._explain_local_and_global(config)

        return float(log_loss(self.y_test, y_proba)), len(self.X_test), {
            "Accuracy": accuracy_score(self.y_test, y_pred),
            "Precision": precision_score(self.y_test, y_pred, average="weighted", zero_division=1),
            "Recall": recall_score(self.y_test, y_pred, average="weighted"),
            "F1_Score": f1_score(self.y_test, y_pred, average="weighted"),
            "AUC": roc_auc_score(self.y_test, y_proba, multi_class="ovr"),
            "tree_ensemble": tree_json,
        }
    
    def _print_tree_structure(self, node, prefix=""):
        if node is None:
            print(f"{prefix}None")
            return
        leaf_info = f"LEAF → {node.labels}" if node.is_leaf else f"feat: {node.feat}, intervals: {getattr(node, 'intervals', None)}"
        print(f"{prefix}{leaf_info}")
        if hasattr(node, "children") and node.children:
            for idx, child in enumerate(node.children):
                self._print_tree_structure(child, prefix + f"  [{idx}]→ ")


    def _explain_local_and_global(self, config):
        num_row = 5
        local_df = pd.DataFrame(self.X_train, columns=self.dataset.df.columns[:-1]).astype(np.float32)
        local_df["target"] = self.y_train.astype(str)

        local_tabular_dataset = TabularDataset(local_df, class_name="target")
        descriptor = local_tabular_dataset.get_descriptor()
        encoder = ColumnTransformerEnc(descriptor)

        # 🧠 Wrapper y generador LORE
        nn_wrapper = TorchNNWrapper(self.nn_model)
        bbox = sklearn_classifier_bbox.sklearnBBox(nn_wrapper)
        lore = TabularGeneticGeneratorLore(bbox, local_tabular_dataset)

        instance = local_tabular_dataset.df.iloc[num_row][:-1]
        target = local_tabular_dataset.df.iloc[num_row][-1]

        # 🔍 Generamos el árbol LORE
        explanation = lore.explain_instance(instance.astype(np.float32), merge=True)
        lore_tree = explanation["merged_tree"]
        self.lore_tree_root = lore_tree.root
        round_number = config.get("server_round", 1)
        self._save_lore_tree(lore_tree.root, round_number)

        # 🌲 Fusionamos con SuperTree si está presente

        print(f"[CLIENTE {self.client_id}]  LORE:")
        self._print_tree_structure(self.lore_tree_root)
        print(f"[CLIENTE {self.client_id}] SuperTree:")
        self._print_tree_structure(self.received_supertree)


        if self.received_supertree is not None:
            merged_tree = SuperTree()
            merged_root = merged_tree.mergeDecisionTrees(
                roots=[self.lore_tree_root, self.received_supertree],
                num_classes=len(self.unique_labels),
                feature_names=self.dataset.df.columns[:-1].tolist(),
            )

            if merged_root is None:
                print("[DEBUG] ❌ El merge no generó un árbol válido.")
                return

            merged_tree.root = merged_root
            merged_tree.prune_redundant_leaves_full()
            merged_tree.merge_equal_class_leaves()

            print("Arbol podado")
            self._print_tree_structure(merged_tree.root)


            self._save_tree(merged_tree.root, round_number, f"LoreTree+Supertree_cliente_{self.client_id}_ronda_{round_number}", f"LoreTree+Supertree_Cliente_{self.client_id}")

        
        # rule = explanation["rule"]
    
    def _save_local_tree(self, tree, round_number):
        self._save_tree(
            tree,
            round_number,
            f"arbol_local_cliente_{self.client_id}_ronda_{round_number}",
            f"ArbolLocal_Cliente_{self.client_id}"
        )

    def _save_tree(self, root_node, round_number, filename, subfolder):
        dot = Digraph()
        node_id = [0]


        def add_node(node, parent_id=None, edge_label=""):
            curr_id = str(node_id[0])
            node_id[0] += 1

            if node.is_leaf:
                class_index = np.argmax(node.labels)
                class_label = str(self.unique_labels[class_index]) if len(self.unique_labels) > 0 else f"class {class_index}"
                label = f"class: {class_label}\n{node.labels}"
            else:
                fname = self.dataset.df.columns[node.feat] if node.feat is not None else "?"
                label = f"{fname}"

            dot.node(curr_id, label)

            if parent_id:
                dot.edge(parent_id, curr_id, label=edge_label)

            if hasattr(node, "children") and hasattr(node, "intervals"):
                for i, child in enumerate(node.children):
                    # 👇 Aquí cambiamos inf por una notación comprensible
                    left = node.intervals[i - 1] if i > 0 else None
                    right = node.intervals[i] if i < len(node.intervals) else None

                    if left is None:
                        thr_label = f"≤ {right:.2f}" if right != float("inf") else f"> {node.intervals[-2]:.2f}"
                    else:
                        thr_label = f"> {left:.2f}" if right == float("inf") else f"({left:.2f}, {right:.2f}]"

                    add_node(child, curr_id, thr_label)

            elif hasattr(node, "_left_child") or hasattr(node, "_right_child"):
                if node._left_child:
                    add_node(node._left_child, curr_id, f"≤ {node.thresh:.2f}")
                if node._right_child:
                    add_node(node._right_child, curr_id, f"> {node.thresh:.2f}")

        add_node(root_node)

        round_folder = f"Ronda_{round_number}"
        os.makedirs(round_folder, exist_ok=True)
        folder_path = f"{round_folder}/{subfolder}"
        os.makedirs(folder_path, exist_ok=True)

        filepath = f"{folder_path}/{filename}"
        dot.render(filepath, format="png", cleanup=True)
        # print(f"[CLIENTE {self.client_id}] 🖼 Árbol guardado como '{filepath}.png'")

    def _save_lore_tree(self, root_node, round_number):
        dot = Digraph()
        node_id = [0]

        def add_node(node, parent_id=None, edge_label=""):
            curr_id = str(node_id[0])
            node_id[0] += 1

            if node.is_leaf:
                class_index = np.argmax(node.labels)
                class_label = str(self.unique_labels[class_index]) if len(self.unique_labels) > 0 else f"class {class_index}"
                label = f"Clase: {class_label}\n{node.labels}"
            else:
                fname = self.dataset.df.columns[node.feat] if node.feat is not None else "?"
                label = f"{fname}"

            dot.node(curr_id, label)

            if parent_id:
                dot.edge(parent_id, curr_id, label=edge_label)

            if not node.children:
                return

            for i, child in enumerate(node.children):
                left = node.intervals[i - 1] if i > 0 else None
                right = node.intervals[i] if i < len(node.intervals) else None

                if left is None:
                    thr_label = f"≤ {right:.2f}" if right != float("inf") else f"> {node.intervals[-2]:.2f}"
                else:
                    thr_label = f"> {left:.2f}" if right == float("inf") else f"({left:.2f}, {right:.2f}]"

                add_node(child, curr_id, thr_label)

        add_node(root_node)

        round_folder = f"Ronda_{round_number}"
        os.makedirs(round_folder, exist_ok=True)
        folder_path = f"{round_folder}/LoreTree_Cliente_{self.client_id}"
        os.makedirs(folder_path, exist_ok=True)

        filepath = f"{folder_path}/lore_tree_cliente_{self.client_id}_ronda_{round_number}"
        dot.render(filepath, format="png", cleanup=True)
        print(f"[CLIENTE {self.client_id}] 📄 LoreTree guardado en '{filepath}.png'")

def create_tree_model():
    return DecisionTreeClassifier(max_depth=5, min_samples_split=2, random_state=42)

def client_fn(context: Context):
    partition_id = context.node_config["partition-id"]
    num_partitions = context.node_config["num-partitions"]
    X_train, y_train, X_test, y_test, dataset = load_data(partition_id, num_partitions)
    tree_model = create_tree_model()
    input_dim = X_train.shape[1]
    output_dim = len(np.unique(y_train))
    nn_model = Net(input_dim, output_dim)
    return FlowerClient(tree_model, nn_model, X_train, y_train, X_test, y_test, dataset, client_id=partition_id + 1).to_client()

client_app = ClientApp(client_fn=client_fn)


# Configurar el Servidor de Flower

In [3]:
# ============================
# 📦 IMPORTACIONES NECESARIAS
# ============================
import os
import time
import json
import numpy as np
from typing import List, Tuple, Dict
from sklearn.tree import DecisionTreeClassifier

from flwr.common import Context, Metrics, Scalar, ndarrays_to_parameters
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.strategy import FedAvg

from graphviz import Digraph
from lore_sa.surrogate.decision_tree import SuperTree

import torch
import torch.nn as nn
import torch.nn.functional as F



# ============================
# ⚖️ CONFIGURACIÓN GLOBAL
# ============================
MIN_AVAILABLE_CLIENTS = 2
NUM_SERVER_ROUNDS = 2
FEATURES = ["sepal_length", "sepal_width", "petal_length", "petal_width"]
UNIQUE_LABELS = ["Setosa", "Versicolor", "Virginica"]
LATEST_SUPERTREE_JSON = None  # 🌲 Guardar árbol generado

# ============================
# 🧐 MODELO Y UTILIDADES
# ============================

def create_model():
    input_dim = len(FEATURES)
    output_dim = len(UNIQUE_LABELS)
    return Net(input_dim, output_dim)

def get_model_parameters(tree_model, nn_model):
    tree_params = [ -1, 2, 1 ]  # Valores por defecto para el servidor
    nn_weights = [v.cpu().detach().numpy() for v in nn_model.state_dict().values()]
    return {
        "tree": tree_params,
        "nn": nn_weights,
    }

def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Dict[str, Scalar]:
    total = sum(n for n, _ in metrics)
    avg: Dict[str, List[float]] = {}
    for n, met in metrics:
        for k, v in met.items():
            if isinstance(v, (float, int)):
                avg.setdefault(k, []).append(n * float(v))
    return {k: sum(vs) / total for k, vs in avg.items()}

# ============================
# 🚀 SERVIDOR FLOWER
# ============================

def server_fn(context: Context) -> ServerAppComponents:
    model = create_model()
    initial_params = ndarrays_to_parameters(get_model_parameters(None, model)["nn"])


    strategy = FedAvg(
        min_available_clients=MIN_AVAILABLE_CLIENTS,
        fit_metrics_aggregation_fn=weighted_average,
        evaluate_metrics_aggregation_fn=weighted_average,
        initial_parameters=initial_params,
    )

    strategy.configure_fit = _inject_round(strategy.configure_fit)
    strategy.configure_evaluate = _inject_round(strategy.configure_evaluate)

    original_aggregate = strategy.aggregate_evaluate

    def custom_aggregate_evaluate(server_round, results, failures):
        global LATEST_SUPERTREE_JSON
        aggregated_metrics = original_aggregate(server_round, results, failures)

        try:
            print(f"\n[SERVIDOR] 🌲 Generando SuperTree - Ronda {server_round}")
            tree_dicts = []
            total_arboles = 0

            for client_idx, (_, evaluate_res) in enumerate(results):
                metrics = evaluate_res.metrics
                trees_json = metrics.get("tree_ensemble", None)
                if trees_json:
                    try:
                        trees_list = json.loads(trees_json)
                        for tdict in trees_list:
                            root = SuperTree.Node.from_dict(tdict)
                            if root:
                                tree_dicts.append(root)
                                total_arboles += 1
                    except Exception as e:
                        print(f"[CLIENTE {client_idx+1}] ❌ Error al parsear árbol: {e}")

            # print(f"[SERVIDOR] 📊 Total de árboles: {total_arboles}")

            if not tree_dicts:
                print("[SERVIDOR] ⚠️ No se recibieron árboles. Se omite SuperTree.")
                return aggregated_metrics

            supertree = SuperTree()
            supertree.mergeDecisionTrees(tree_dicts, num_classes=len(UNIQUE_LABELS), feature_names=FEATURES)
            supertree.prune_redundant_leaves_full()
            supertree.merge_equal_class_leaves()

            _save_supertree_plot(supertree.root, server_round, feature_names=FEATURES, class_names=UNIQUE_LABELS)
            LATEST_SUPERTREE_JSON = json.dumps(supertree.root.to_dict())

        except Exception as e:
            print(f"[SERVIDOR] ❌ Error en SuperTree: {e}")

        time.sleep(10)
        return aggregated_metrics

    strategy.aggregate_evaluate = custom_aggregate_evaluate
    config = ServerConfig(num_rounds=NUM_SERVER_ROUNDS)
    return ServerAppComponents(strategy=strategy, config=config)

# ============================
# 📂 HELPERS
# ============================

def _inject_round(original_fn):
    def wrapper(server_round, parameters, client_manager):
        global LATEST_SUPERTREE_JSON
        instructions = original_fn(server_round, parameters, client_manager)
        for _, ins in instructions:
            ins.config["server_round"] = server_round
            if LATEST_SUPERTREE_JSON:
                ins.config["supertree"] = LATEST_SUPERTREE_JSON
        return instructions
    return wrapper

def _save_supertree_plot(root_node, round_number, feature_names=None, class_names=None):
    round_folder = f"Ronda_{round_number}"
    os.makedirs(round_folder, exist_ok=True)

    supertree_folder = f"{round_folder}/Supertree"
    os.makedirs(supertree_folder, exist_ok=True)

    dot = Digraph()
    node_id = [0]

    def add_node(node, parent=None, label=""):
        curr = str(node_id[0])
        node_id[0] += 1

        if node.is_leaf:
            class_index = np.argmax(node.labels)
            class_label = class_names[class_index] if class_names else f"Clase {class_index}"
            label_text = f"Clase: {class_label}\n{node.labels}"
        else:
            fname = f"X_{node.feat}" if feature_names is None else feature_names[node.feat]
            label_text = f"{fname}"

        dot.node(curr, label_text)

        if parent:
            dot.edge(parent, curr, label=label)

        if not node.is_leaf:
            for i, child in enumerate(node.children):
                thr_label = f"<= {node.intervals[i]:.2f}" if i == 0 else f"> {node.intervals[i - 1]:.2f}"
                add_node(child, curr, thr_label)

    add_node(root_node)
    filename = f"{supertree_folder}/supertree_ronda_{round_number}"
    dot.render(filename, format="png", cleanup=True)
    # print(f"[SERVIDOR] ✅ SuperTree guardado como '{filename}.png'")

# ============================
# 🔧 INICIALIZAR SERVER APP
# ============================
server_app = ServerApp(server_fn=server_fn)


**Pasos que se realizan en el notebook:**

1. El servidor inicializa el modelo y lo envía a cada uno de los clientes.

2. Cada cliente entrena un RandomForest con su respectivo subconjunto de datos o partición que hemos realizado al principio.

3. Los clientes entrenan, y mandan sus hiperparámetros (Nº de árboles, profundidad, etc.) al servidor.

4. El servidor combina los parámetros y actualiza el modelo global.

5. Se mide el rendimiento del modelo sobre cada cliente, obteniendo también sus contrafactuales y se repite el proceso las rondas que deseemos.

# Ejecutar la Simulación Federada


In [4]:
from flwr.simulation import run_simulation
import logging
import warnings
import ray

warnings.filterwarnings("ignore", category=DeprecationWarning)


logging.getLogger('matplotlib').setLevel(logging.WARNING)
logging.getLogger("filelock").setLevel(logging.WARNING)
logging.getLogger("ray").setLevel(logging.WARNING)
logging.getLogger('graphviz').setLevel(logging.WARNING)
# logging.getLogger("flwr").setLevel(logging.WARNING)




ray.shutdown()  # Apagar cualquier sesión previa de Ray
ray.init(local_mode=True)  # Desactiva multiprocessing, usa un solo proceso principal

backend_config = {"num_cpus": 1}

run_simulation(
    server_app=server_app,
    client_app=client_app,
    num_supernodes=NUM_CLIENTS,
    backend_config=backend_config,
)


2025-05-09 12:05:47,693	INFO worker.py:1832 -- Started a local Ray instance. View the dashboard at [1m[32mhttp://127.0.0.1:8265 [39m[22m
2025-05-09 12:05:53,135 flwr         DEBUG    Asyncio event loop already running.
[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=2, no round_timeout
:job_id:01000000
[92mINFO [0m:      
:actor_name:ClientAppActor
[92mINFO [0m:      [INIT]
:actor_name:ClientAppActor
:actor_name:ClientAppActor
[92mINFO [0m:      Using initial global parameters provided by strategy
:actor_name:ClientAppActor
[92mINFO [0m:      Starting evaluation of initial global parameters
:actor_name:ClientAppActor
:actor_name:ClientAppActor
[92mINFO [0m:      Evaluation returned no results (`None`)
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 2)


:job_id:01000000
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor


[92mINFO [0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 2 clients (out of 2)


[CLIENTE 1] ✅ Red neuronal entrenada 
[CLIENTE 2] ✅ Red neuronal entrenada 
[CLIENTE 2] ✅ Árbol local generado y enviado
[CLIENTE 1] ✅ Árbol local generado y enviado
[CLIENTE 2] 📄 LoreTree guardado en 'Ronda_1/LoreTree_Cliente_2/lore_tree_cliente_2_ronda_1.png'
[CLIENTE 2]  LORE:
feat: 1, intervals: [2.8067674772644042, 2.8067774772644043, inf]
  [0]→ feat: 2, intervals: [1.7460639476776123, inf]
  [0]→   [0]→ LEAF → [2. 0.]
  [0]→   [1]→ feat: 0, intervals: [4.802204608917236, inf]
  [0]→   [1]→   [0]→ LEAF → [2. 0.]
  [0]→   [1]→   [1]→ feat: 1, intervals: [2.3419976234436035, inf]
  [0]→   [1]→   [1]→   [0]→ LEAF → [0. 2.]
  [0]→   [1]→   [1]→   [1]→ feat: 1, intervals: [2.744604706764221, inf]
  [0]→   [1]→   [1]→   [1]→   [0]→ LEAF → [2. 0.]
  [0]→   [1]→   [1]→   [1]→   [1]→ LEAF → [0. 1.]
  [1]→ feat: 2, intervals: [1.7460639476776123, inf]
  [1]→   [0]→ LEAF → [2. 0.]
  [1]→   [1]→ feat: 0, intervals: [4.802204608917236, inf]
  [1]→   [1]→   [0]→ LEAF → [2. 0.]
  [1]→   [1]→   

[92mINFO [0m:      aggregate_evaluate: received 2 results and 0 failures


[CLIENTE 1] 📄 LoreTree guardado en 'Ronda_1/LoreTree_Cliente_1/lore_tree_cliente_1_ronda_1.png'
[CLIENTE 1]  LORE:
feat: 1, intervals: [3.6198809146881104, inf]
  [0]→ LEAF → [0. 2.]
  [1]→ feat: 2, intervals: [5.880961431732178, 5.880971431732178, inf]
  [1]→   [0]→ feat: 0, intervals: [7.554922832717896, 7.5549328327178955, inf]
  [1]→   [0]→   [0]→ feat: 3, intervals: [2.310054911842346, 2.310064911842346, inf]
  [1]→   [0]→   [0]→   [0]→ feat: 1, intervals: [3.926494850387573, 3.9265048503875732, inf]
  [1]→   [0]→   [0]→   [0]→   [0]→ feat: 1, intervals: [3.884672522544861, inf]
  [1]→   [0]→   [0]→   [0]→   [0]→   [0]→ LEAF → [2. 0.]
  [1]→   [0]→   [0]→   [0]→   [0]→   [1]→ LEAF → [0. 1.]
  [1]→   [0]→   [0]→   [0]→   [1]→ LEAF → [0. 1.]
  [1]→   [0]→   [0]→   [0]→   [2]→ LEAF → [1. 0.]
  [1]→   [0]→   [0]→   [1]→ feat: 1, intervals: [3.926494850387573, 3.9265048503875732, inf]
  [1]→   [0]→   [0]→   [1]→   [0]→ feat: 1, intervals: [3.884672522544861, inf]
  [1]→   [0]→   [0]→  

[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 2)
[92mINFO [0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 2 clients (out of 2)


[CLIENTE 2] ✅ Red neuronal entrenada 
[CLIENTE 1] ✅ Red neuronal entrenada 
[CLIENTE 2] recibiendo SuperTree...[CLIENTE 1] recibiendo SuperTree...

[CLIENTE 1] ✅ Árbol local generado y enviado
[CLIENTE 2] ✅ Árbol local generado y enviado
[CLIENTE 2] 📄 LoreTree guardado en 'Ronda_2/LoreTree_Cliente_2/lore_tree_cliente_2_ronda_2.png'
[CLIENTE 2]  LORE:
feat: 3, intervals: [1.7100539939308166, 1.7100639939308167, inf]
  [0]→ feat: 2, intervals: [6.40087628364563, inf]
  [0]→   [0]→ LEAF → [2. 0.]
  [0]→   [1]→ LEAF → [0. 1.]
  [1]→ feat: 2, intervals: [6.40087628364563, inf]
  [1]→   [0]→ LEAF → [2. 0.]
  [1]→   [1]→ LEAF → [0. 1.]
  [2]→ feat: 1, intervals: [2.8597201244735717, 2.8597301244735718, inf]
  [2]→   [0]→ feat: 2, intervals: [3.3639074563980103, inf]
  [2]→   [0]→   [0]→ LEAF → [2. 0.]
  [2]→   [0]→   [1]→ feat: 0, intervals: [6.732056617736816, inf]
  [2]→   [0]→   [1]→   [0]→ LEAF → [0. 2.]
  [2]→   [0]→   [1]→   [1]→ feat: 0, intervals: [7.098649978637695, inf]
  [2]→   [0]

[92mINFO [0m:      aggregate_evaluate: received 2 results and 0 failures



[SERVIDOR] 🌲 Generando SuperTree - Ronda 2


[92mINFO [0m:      
[92mINFO [0m:      [SUMMARY]
[92mINFO [0m:      Run finished 2 round(s) in 140.68s
[92mINFO [0m:      	History (loss, distributed):
[92mINFO [0m:      		round 1: 2.220446049250313e-16
[92mINFO [0m:      		round 2: 2.220446049250313e-16
[92mINFO [0m:      	History (metrics, distributed, evaluate):
[92mINFO [0m:      	{'AUC': [(1, 1.0), (2, 1.0)],
[92mINFO [0m:      	 'Accuracy': [(1, 1.0), (2, 1.0)],
[92mINFO [0m:      	 'F1_Score': [(1, 1.0), (2, 1.0)],
[92mINFO [0m:      	 'Precision': [(1, 1.0), (2, 1.0)],
[92mINFO [0m:      	 'Recall': [(1, 1.0), (2, 1.0)]}
[92mINFO [0m:      
