In [1]:
# =======================
# 📦 IMPORTACIONES
# =======================

# Built-in
import os
import sys
import re
import time
import json
import random
import warnings
from typing import List, Tuple, Dict
import operator

# NumPy, Pandas, Matplotlib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Sklearn
from sklearn.tree import DecisionTreeClassifier, plot_tree, export_text
from sklearn.preprocessing import LabelEncoder, StandardScaler, OrdinalEncoder
from sklearn.preprocessing import OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.metrics import (
    log_loss, accuracy_score, precision_score, recall_score,
    f1_score, confusion_matrix, roc_auc_score, pairwise_distances
)
from sklearn.exceptions import NotFittedError
from collections import defaultdict

# Flower
from flwr.client import ClientApp, NumPyClient
from flwr.common import (
    Context, NDArrays, Metrics, Scalar,
    ndarrays_to_parameters, parameters_to_ndarrays
)
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.strategy import FedAvg
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner
from sklearn.model_selection import train_test_split

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F

# LORE
from lore_sa.dataset import TabularDataset
from lore_sa.bbox import sklearn_classifier_bbox
from lore_sa.encoder_decoder import ColumnTransformerEnc
from lore_sa.lore import TabularGeneticGeneratorLore
from lore_sa.surrogate.decision_tree import SuperTree
from lore_sa.rule import Expression, Rule

# Otros
from graphviz import Digraph


2025-06-30 12:50:18,793	INFO util.py:154 -- Outdated packages:
  ipywidgets==7.8.1 found, needs ipywidgets>=8
Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
2025-06-30 12:50:21,907 graphviz._tools DEBUG    deprecate positional args: graphviz.backend.piping.pipe(['renderer', 'formatter', 'neato_no_op', 'quiet'])
2025-06-30 12:50:21,907 graphviz._tools DEBUG    deprecate positional args: graphviz.backend.rendering.render(['renderer', 'formatter', 'neato_no_op', 'quiet'])
2025-06-30 12:50:21,923 graphviz._tools DEBUG    deprecate positional args: graphviz.backend.unflattening.unflatten(['stagger', 'fanout', 'chain', 'encoding'])
2025-06-30 12:50:21,923 graphviz._tools DEBUG    deprecate positional args: graphviz.backend.viewing.view(['quiet'])
2025-06-30 12:50:21,923 graphviz._tools DEBUG    deprecate positional args: graphviz.quoting.quote(['is_html_string', 'is_valid_id', 'dot_keywords', 'endswith_odd_number_of_backslashes', 'escape_unescaped

In [2]:
# =======================
# ⚙️ VARIABLES GLOBALES
# =======================
UNIQUE_LABELS = []
FEATURES = []
NUM_SERVER_ROUNDS = 2
NUM_CLIENTS = 2
MIN_AVAILABLE_CLIENTS = NUM_CLIENTS
fds = None  # Cache del FederatedDataset
CAT_ENCODINGS = {}
USING_DATASET = None





class Net(nn.Module):
    def __init__(self, input_dim: int, output_dim: int):
        super(Net, self).__init__()
        hidden_dim = max(8, input_dim * 2)  # algo proporcional

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim // 2)
        self.fc3 = nn.Linear(hidden_dim // 2, output_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
    

class TorchNNWrapper:
    def __init__(self, model):
        self.model = model
        self.model.eval()

    def predict(self, X):
        X = np.array(X, dtype=np.float32)
        with torch.no_grad():
            X_tensor = torch.tensor(X, dtype=torch.float32)
            outputs = self.model(X_tensor)
            return outputs.argmax(dim=1).numpy()

    def predict_proba(self, X):
        X = np.array(X, dtype=np.float32)
        with torch.no_grad():
            X_tensor = torch.tensor(X, dtype=torch.float32)
            outputs = self.model(X_tensor)
            probs = F.softmax(outputs, dim=1)
            return probs.numpy()

# =======================
# 🔧 UTILIDADES MODELO
# =======================

def get_model_parameters(tree_model, nn_model):
    tree_params = [
        int(tree_model.get_params()["max_depth"] or -1),
        int(tree_model.get_params()["min_samples_split"]),
        int(tree_model.get_params()["min_samples_leaf"]),
    ]
    nn_weights = [v.cpu().detach().numpy() for v in nn_model.state_dict().values()]
    return {
        "tree": tree_params,
        "nn": nn_weights,
    }


def set_model_params(tree_model, nn_model, params):
    tree_params = params["tree"]
    nn_weights = params["nn"]

    # Solo si tree_model no es None y tiene set_params
    if tree_model is not None and hasattr(tree_model, "set_params"):
        max_depth = tree_params[0] if tree_params[0] > 0 else None
        tree_model.set_params(
            max_depth=max_depth,
            min_samples_split=tree_params[1],
            min_samples_leaf=tree_params[2],
        )

    # Actualizar pesos de la red neuronal
    state_dict = nn_model.state_dict()
    for (key, _), val in zip(state_dict.items(), nn_weights):
        state_dict[key] = torch.tensor(val)
    nn_model.load_state_dict(state_dict)


# =======================
# 📥 CARGAR DATOS
# =======================

def load_data_general(flower_dataset_name: str, class_col: str, num_partitions: int):
    global UNIQUE_LABELS, FEATURES

    if "csv" in flower_dataset_name:
        dataset = pd.read_csv(flower_dataset_name)
    else:
        from flwr_datasets import FederatedDataset
        fds = FederatedDataset(dataset=flower_dataset_name, partitioners={"train": IidPartitioner(num_partitions=NUM_CLIENTS)})
        dataset = fds.load_partition(0, "train").with_format("pandas")[:]

    # LIMPIEZA SEGÚN DATASET
    if "adult_small" in flower_dataset_name.lower():
        drop_cols = ['fnlwgt', 'education-num', 'capital-gain', 'capital-loss']
        dataset.drop(columns=[col for col in drop_cols if col in dataset.columns], inplace=True)
        dataset = dataset[~dataset["workclass"].isin([" ?"])]
        dataset = dataset[~dataset["occupation"].isin([" ?"])]
    elif "churn" in flower_dataset_name.lower():
        drop_cols = ['customerID', 'TotalCharges']
        dataset.drop(columns=[col for col in drop_cols if col in dataset.columns], inplace=True)
        dataset['MonthlyCharges'] = pd.to_numeric(dataset['MonthlyCharges'], errors='coerce')
        dataset['tenure'] = pd.to_numeric(dataset['tenure'], errors='coerce')
        dataset.dropna(subset=['MonthlyCharges', 'tenure'], inplace=True)

    # Convierte strings a category si pocas categorías
    for col in dataset.select_dtypes(include=["object"]).columns:
        if dataset[col].nunique() < 50:
            dataset[col] = dataset[col].astype("category")

    label_encoder = LabelEncoder()
    dataset[class_col] = label_encoder.fit_transform(dataset[class_col])
    dataset.rename(columns={class_col: "class"}, inplace=True)
    y = dataset["class"].reset_index(drop=True).to_numpy()
    if not UNIQUE_LABELS:
        UNIQUE_LABELS[:] = label_encoder.classes_.tolist()

    numeric_features = list(dataset.select_dtypes(include=[np.number]).columns)
    numeric_features = [f for f in numeric_features if f != "class"]
    categorical_features = [c for c in dataset.columns if str(dataset[c].dtype) == "category"]
    FEATURES[:] = numeric_features + categorical_features

    numeric_indices = list(range(len(numeric_features)))
    categorical_indices = list(range(len(numeric_features), len(FEATURES)))
    X_array = dataset[FEATURES].to_numpy()

    preprocessor = ColumnTransformer([
        ("num", StandardScaler(), numeric_indices),
        ("cat", OneHotEncoder(sparse_output=False, handle_unknown="ignore"), categorical_indices)
    ])
    X_encoded = preprocessor.fit_transform(X_array)

    ohe = preprocessor.named_transformers_["cat"]
    onehot_feature_names = ohe.get_feature_names_out([FEATURES[i] for i in categorical_indices])
    all_feature_names = numeric_features + list(onehot_feature_names)

    # Divide en particiones para federado
    X_parts = np.array_split(X_encoded, num_partitions)
    y_parts = np.array_split(y, num_partitions)

    # En cada partición, haz split train/test y guarda en listas
    X_train_parts, X_test_parts, y_train_parts, y_test_parts = [], [], [], []
    for X_part, y_part in zip(X_parts, y_parts):
        X_tr, X_te, y_tr, y_te = train_test_split(
            X_part, y_part, test_size=0.2, random_state=42, stratify=y_part
        )
        X_train_parts.append(X_tr)
        X_test_parts.append(X_te)
        y_train_parts.append(y_tr)
        y_test_parts.append(y_te)

    # Devuelve también el DataFrame original limpio
    return (
        X_train_parts, X_test_parts, y_train_parts, y_test_parts,
        dataset,                          # <--- Añade el dataset original aquí
        all_feature_names, label_encoder, preprocessor.named_transformers_["num"], numeric_features, preprocessor
    )
# =======================



DATASET_NAME = "pablopalacios23/adult_small"
CLASS_COLUMN = "class"


# DATASET_NAME = "pablopalacios23/Iris"
# CLASS_COLUMN = "target"


# DEMASIADO GRANDE EL DATASET :/
# DATASET_NAME = "pablopalacios23/churn"
# CLASS_COLUMN = "Churn" 
 

# =======================


# load_data_general(DATASET_NAME, CLASS_COLUMN, partition_id=0, num_partitions=NUM_CLIENTS)

In [3]:
(X_train_parts, X_test_parts, y_train_parts, y_test_parts, 
 dataset, feature_names, label_encoder, scaler, numeric_features, preprocessor) = load_data_general(
    DATASET_NAME, CLASS_COLUMN, num_partitions=NUM_CLIENTS
)

# Selecciona partición (por ejemplo, cliente 0)
X_train = X_train_parts[0]
X_test  = X_test_parts[0]
y_train = y_train_parts[0]
y_test  = y_test_parts[0]

print("\n📦 X_train (primeras filas):")
print(pd.DataFrame(X_train).head())

print("\n🎯 y_train (primeros valores):")
print(y_train[:5])

print("\n📦 X_test (primeras filas):")
print(pd.DataFrame(X_test).head())

print("\n🎯 y_test (primeros valores):")
print(y_test[:5])

print("\n🗃️ DataFrame original limpio:")
print(dataset.head())

2025-06-30 12:50:21,973 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): huggingface.co:443
2025-06-30 12:50:22,341 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /datasets/pablopalacios23/adult_small/resolve/main/README.md HTTP/11" 404 0
2025-06-30 12:50:22,474 urllib3.connectionpool DEBUG    https://huggingface.co:443 "GET /api/datasets/pablopalacios23/adult_small HTTP/11" 200 612
2025-06-30 12:50:22,613 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /datasets/pablopalacios23/adult_small/resolve/475f19aed5f80dea1d48deab705f11928fe27493/adult_small.py HTTP/11" 404 0
2025-06-30 12:50:22,615 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): s3.amazonaws.com:443
2025-06-30 12:50:22,920 urllib3.connectionpool DEBUG    https://s3.amazonaws.com:443 "HEAD /datasets.huggingface.co/datasets/datasets/pablopalacios23/adult_small/pablopalacios23/adult_small.py HTTP/11" 404 0
2025-06-30 12:50:23,056 urllib3.connectionpool DEBUG

ValueError: The least populated class in y has only 1 member, which is too few. The minimum number of groups for any class cannot be less than 2.

# Cliente

In [None]:
# ==========================
# 🌼 CLIENTE FLOWER
# ==========================
import operator
import warnings
import os
import json
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import (
    log_loss, accuracy_score, precision_score, recall_score,
    f1_score, roc_auc_score
)
from sklearn.exceptions import NotFittedError

import torch
import torch.nn as nn
import torch.nn.functional as F

from flwr.client import NumPyClient
from flwr.common import Context
from flwr.common import parameters_to_ndarrays

from lore_sa.dataset import TabularDataset
from lore_sa.bbox import sklearn_classifier_bbox
from lore_sa.lore import TabularGeneticGeneratorLore
from lore_sa.rule import Expression, Rule
from lore_sa.surrogate.decision_tree import SuperTree
from lore_sa.encoder_decoder import ColumnTransformerEnc

from sklearn.metrics import pairwise_distances

from graphviz import Digraph

class TorchNNWrapper:
    def __init__(self, model):
        self.model = model
        self.model.eval()

    def predict(self, X):
        X = np.array(X, dtype=np.float32)
        with torch.no_grad():
            X_tensor = torch.tensor(X, dtype=torch.float32)
            outputs = self.model(X_tensor)
            return outputs.argmax(dim=1).numpy()

    def predict_proba(self, X):
        X = np.array(X, dtype=np.float32)
        with torch.no_grad():
            X_tensor = torch.tensor(X, dtype=torch.float32)
            outputs = self.model(X_tensor)
            probs = F.softmax(outputs, dim=1)
            return probs.numpy()
        

class FlowerClient(NumPyClient):
    def __init__(self, tree_model, nn_model, X_train, y_train, X_test, y_test, dataset, client_id, feature_names, label_encoder, scaler, numeric_features, encoder, preprocessor):
        self.tree_model = tree_model
        self.nn_model = nn_model
        self.X_train = X_train
        self.y_train = y_train
        self.X_test = X_test
        self.y_test = y_test
        self.dataset = dataset
        self.client_id = client_id
        self.feature_names = feature_names
        self.label_encoder = label_encoder
        self.scaler = scaler
        self.numeric_features = numeric_features
        self.encoder = encoder
        self.unique_labels = label_encoder.classes_.tolist()
        self.y_train_nn = y_train.astype(np.int64)
        self.y_test_nn = y_test.astype(np.int64)
        self.received_supertree = None
        self.preprocessor = preprocessor

    def _train_nn(self, epochs=10, lr=0.01):
        self.nn_model.train()
        optimizer = torch.optim.Adam(self.nn_model.parameters(), lr=lr)
        loss_fn = nn.CrossEntropyLoss()
        X_tensor = torch.tensor(self.X_train, dtype=torch.float32)
        y_tensor = torch.tensor(self.y_train_nn, dtype=torch.long)

        for _ in range(epochs):
            optimizer.zero_grad()
            outputs = self.nn_model(X_tensor)
            loss = loss_fn(outputs, y_tensor)
            loss.backward()
            optimizer.step()
        print(f"[CLIENTE {self.client_id}] ✅ Red neuronal entrenada")

    def decode_X_onehot(X, numeric_features, categorical_features, feature_names, original_categories, scaler):
        # X: array, feature_names: nombres de las columnas one-hot + numéricas

        # Desescalar numéricas
        X_num = X[:, :len(numeric_features)]
        X_num_inv = scaler.inverse_transform(X_num)
        df_num = pd.DataFrame(X_num_inv, columns=numeric_features)

        # Decodificar categóricas (buscar columna activa para cada variable original)
        df_cat = {}
        idx = len(numeric_features)
        for cat in categorical_features:
            n_values = len(original_categories[cat])
            cols = feature_names[idx:idx+n_values]
            # Para una sola muestra, np.argmax(X_cat) te dice el valor activado
            active = np.argmax(X[:, idx:idx+n_values], axis=1)
            vals = [original_categories[cat][i] for i in active]
            df_cat[cat] = vals
            idx += n_values

        df_cat = pd.DataFrame(df_cat)
        return pd.concat([df_num, df_cat], axis=1)



    def fit(self, parameters, config):
        set_model_params(self.tree_model, self.nn_model, {"tree": [-1, 2, 1], "nn": parameters})
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            # print(f"[CLIENTE {self.client_id}]")
            # print("X TEST instancia a explicar:")
            # print(pd.DataFrame(self.X_test).iloc[2])
            # print("y TEST:")
            # print(pd.DataFrame(self.y_test).iloc[2])

            # print(pd.DataFrame(self.X_train).shape)


            # print(f"[CLIENTE {self.client_id}]")
            # df_decoded = self.decode_X(self.X_test, self.preprocessor, self.numeric_features, self.encoder)
            
            # print("Instancia a explicar (EN EL FIT):")
            # print(df_decoded.iloc[2])

            # print("target:")
            # print(df_decoded.iloc[2][-1])

            print(self.X_train.shape, self.y_train.shape)

            self.tree_model.fit(self.X_train, self.y_train)
            self._train_nn()
        
            # self.print_tree_human_readable(self.tree_model, FEATURES, self.numeric_features, self.scaler, self.encoder)
            # print("\n")


        nn_weights = get_model_parameters(self.tree_model, self.nn_model)["nn"]
        return nn_weights, len(self.X_train), {}
    


    def evaluate(self, parameters, config):


        set_model_params(self.tree_model, self.nn_model, {"tree": [-1, 2, 1], "nn": parameters})

        if "supertree" in config:
            try:
                print("Recibiendo supertree....")
                supertree_dict = json.loads(config["supertree"])
                self.received_supertree = SuperTree.convert_SuperNode_to_Node(SuperTree.SuperNode.from_dict(supertree_dict))
                self.global_mapping = json.loads(config["global_mapping"])
                self.feature_names = json.loads(config["feature_names"])

                # self.received_supertree = SuperTree.SuperNode.from_dict(supertree_dict)

            except Exception as e:
                print(f"[CLIENTE {self.client_id}] ❌ Error al recibir SuperTree: {e}")

        try:
            _ = self.tree_model.predict(self.X_test)
        except NotFittedError:
            self.tree_model.fit(self.X_train, self.y_train)

        supertree = SuperTree()
        root_node = supertree.rec_buildTree(self.tree_model, list(range(self.X_train.shape[1])), len(self.unique_labels))
        round_number = config.get("server_round", 1)
        self._save_local_tree(root_node, round_number, FEATURES, self.numeric_features, self.scaler, UNIQUE_LABELS, self.encoder)
        tree_json = json.dumps([root_node.to_dict()])

        
        if self.received_supertree is not None:
            self._explain_local_and_global(config)

        return 0.0, len(self.X_test), {
            f"tree_ensemble_{self.client_id}": tree_json,
            f"scaler_mean_{self.client_id}": json.dumps(self.scaler.mean_.tolist()),
            f"scaler_std_{self.client_id}": json.dumps(self.scaler.scale_.tolist()),
            f"encoded_feature_names_{self.client_id}": json.dumps(FEATURES),
            f"numeric_features_{self.client_id}": json.dumps(self.numeric_features),
            f"unique_labels_{self.client_id}": json.dumps(self.unique_labels),
            f"encoder_descriptor_{self.client_id}": json.dumps(self.encoder.dataset_descriptor),
            f"distinct_values_{self.client_id}": json.dumps(self.encoder.dataset_descriptor["categorical"])
        }
    
    def _explain_local_and_global(self, config):
        from sklearn.metrics import accuracy_score
        import numpy as np
    
        num_row = 2

        # Reconstruir DataFrame original codificado
        # feature_cols = self.feature_names

        # 1. Visualizar instancia escalada y decodificada usando el encoder/preprocessor ORIGINAL
        # print(f"\n[CLIENTE {self.client_id}] 🧪 Instancia a explicar (escalada):")
        # print(pd.Series(self.X_test[num_row], index=self.feature_names))

        decoded_instance = self.decode_X(
            self.X_test[num_row].reshape(1, -1),
            self.preprocessor,
            self.numeric_features,
            self.encoder
        ).iloc[0]

        print(f"\n[CLIENTE {self.client_id}] 🧪 Instancia a explicar (decodificada):")
        print(decoded_instance)
        print(f"[CLIENTE {self.client_id}] 🧪 Clase real: {self.label_encoder.inverse_transform([self.y_test_nn[num_row]])[0]}")

        # 2. Construir DataFrame para LORE (si es necesario, solo para TabularDataset)
        local_df = pd.DataFrame(self.X_test, columns=FEATURES).astype(np.float32)
        local_df["target"] = self.label_encoder.inverse_transform(self.y_test_nn)
        local_tabular_dataset = TabularDataset(local_df, class_name="target")

        # Explicabilidad local
        nn_wrapper = TorchNNWrapper(self.nn_model)
        bbox = sklearn_classifier_bbox.sklearnBBox(nn_wrapper)
        lore = TabularGeneticGeneratorLore(bbox, local_tabular_dataset)

        instance_scaled = local_tabular_dataset.df.iloc[num_row][:-1]



        # Explicación LORE
        explanation = lore.explain_instance(instance_scaled.astype(np.float32), merge=True, num_classes=len(UNIQUE_LABELS), feature_names= self.feature_names, categorical_features=list(self.global_mapping.keys()), global_mapping=self.global_mapping)
        lore_tree = explanation["merged_tree"]
        round_number = config.get("server_round", 1)

        encoder_for_print = {
            "categorical": {
                k: {"distinct_values": v} for k, v in self.global_mapping.items()
            }
        }


        self.print_tree_readable(
            lore_tree.root,
            self.feature_names,
            self.unique_labels,
            numeric_features=self.numeric_features,
            scaler={
                "mean": [self.scaler.mean_[self.numeric_features.index(f)] for f in self.numeric_features],
                "std": [self.scaler.scale_[self.numeric_features.index(f)] for f in self.numeric_features]
            },
            encoder=encoder_for_print
        )

        
        self._save_lore_tree(lore_tree.root, round_number)



    def pretty_print_rule(rule, feature_names):
        conditions = []
        for i, cond in enumerate(rule):
            # Si es columna onehot
            for feat in feature_names:
                if cond.startswith(feat) and "_"+feat in cond:
                    var, val = feat.split('_', 1)
                    conditions.append(f'{var} = "{val}"')
        print(" AND ".join(conditions))



    
    
    def print_tree_readable(self, node, feature_names, class_names, numeric_features, scaler, encoder, depth=0):
        indent = "|   " * depth
        if node.is_leaf:
            class_idx = int(np.argmax(node.labels))
            print(f"{indent}|--- class: {class_names[class_idx]}")
            return

        feat_name = feature_names[node.feat]
        base_feat = feat_name.split("=")[0] if "=" in feat_name else feat_name

        if base_feat in encoder["categorical"]:
            val_idx = int(node.thresh)
            try:
                val = encoder["categorical"][base_feat]["distinct_values"][val_idx]
            except IndexError:
                val = f"[desconocido ({val_idx})]"
            print(f"{indent}|--- {base_feat} <= \"{val}\"")
            self.print_tree_readable(node._left_child, feature_names, class_names, numeric_features, scaler, encoder, depth + 1)
            print(f"{indent}|--- {base_feat} > \"{val}\"")
            self.print_tree_readable(node._right_child, feature_names, class_names, numeric_features, scaler, encoder, depth + 1)

        elif base_feat in numeric_features:
            idx = numeric_features.index(base_feat)
            threshold = node.thresh * scaler["std"][idx] + scaler["mean"][idx]
            print(f"{indent}|--- {base_feat} <= {threshold:.2f}")
            self.print_tree_readable(node._left_child, feature_names, class_names, numeric_features, scaler, encoder, depth + 1)
            print(f"{indent}|--- {base_feat} > {threshold:.2f}")
            self.print_tree_readable(node._right_child, feature_names, class_names, numeric_features, scaler, encoder, depth + 1)


            

    
    def _save_local_tree(self, root_node, round_number, feature_names, numeric_features, scaler, unique_labels, encoder, tree_type= "LocalTree"):
        dot = Digraph()
        node_id = [0]

        def base_name(feat):
            return feat.split('=')[0] if '=' in feat else feat

        def add_node(node, parent=None, edge_label=""):
            curr = str(node_id[0])
            node_id[0] += 1

            # Etiqueta del nodo
            if node.is_leaf:
                class_index = np.argmax(node.labels)
                class_label = unique_labels[class_index]
                label = f"class: {class_label}\n{node.labels}"
            else:
                try:
                    fname = feature_names[node.feat]
                    label = base_name(fname)
                except:
                    label = f"X_{node.feat}"

            dot.node(curr, label)
            if parent:
                dot.edge(parent, curr, label=edge_label)

            # Árbol tipo SuperTree
            if hasattr(node, "children") and node.children is not None and hasattr(node, "intervals"):
                for i, child in enumerate(node.children):
                    try:
                        fname = feature_names[node.feat]
                    except:
                        fname = f"X_{node.feat}"

                    original_feat = base_name(fname)
                    if original_feat in encoder.dataset_descriptor["categorical"]:
                        val_idx = node.intervals[i] if i == 0 else node.intervals[i - 1]
                        val_idx = int(val_idx)
                        val = encoder.dataset_descriptor["categorical"][original_feat]["distinct_values"][val_idx] if val_idx < len(encoder.dataset_descriptor["categorical"][original_feat]["distinct_values"]) else f"desconocido({val_idx})"
                        edge = f'≠ "{val}"' if i == 0 else f'= "{val}"'
                    elif original_feat in numeric_features:
                        idx = numeric_features.index(original_feat)
                        mean = scaler.mean_[idx]
                        std = scaler.scale_[idx]
                        val = node.intervals[i] if i == 0 else node.intervals[i - 1]
                        val = val * std + mean
                        edge = f"<= {val:.2f}" if i == 0 else f"> {val:.2f}"
                    else:
                        edge = "?"

                    add_node(child, curr, edge)

            elif hasattr(node, "_left_child") or hasattr(node, "_right_child"):
                try:
                    fname = feature_names[node.feat]
                except:
                    fname = f"X_{node.feat}"

                original_feat = base_name(fname)
                if original_feat in encoder.dataset_descriptor["categorical"]:
                    val_idx = int(node.thresh)
                    val = encoder.dataset_descriptor["categorical"][original_feat]["distinct_values"][val_idx] if val_idx < len(encoder.dataset_descriptor["categorical"][original_feat]["distinct_values"]) else f"desconocido({val_idx})"
                    left_label = f'= "{val}"'
                    right_label = f'≠ "{val}"'
                elif original_feat in numeric_features:
                    idx = numeric_features.index(original_feat)
                    mean = scaler.mean_[idx]
                    std = scaler.scale_[idx]
                    thresh = node.thresh * std + mean
                    left_label = f"<= {thresh:.2f}"
                    right_label = f"> {thresh:.2f}"
                else:
                    left_label = "≤ ?"
                    right_label = "> ?"

                if node._left_child:
                    add_node(node._left_child, curr, left_label)
                if node._right_child:
                    add_node(node._right_child, curr, right_label)

        add_node(root_node)
        folder = f"Ronda_{round_number}/{tree_type}_Cliente_{self.client_id}"
        os.makedirs(folder, exist_ok=True)
        filepath = f"{folder}/{tree_type.lower()}_cliente_{self.client_id}_ronda_{round_number}"
        dot.render(filepath, format="png", cleanup=True)

    def _save_lore_tree(self, root_node, round_number):
         self._save_generic_tree(
            root_node,
            round_number,
            tree_type="LoreTree",
            feature_names=self.feature_names,
            categorical_features=list(self.global_mapping.keys()),
            global_mapping=self.global_mapping,
            scaler=self.scaler,
            numeric_features=self.numeric_features
        )

    def _save_merged_tree(self, root_node, round_number):
        self._save_generic_tree(
            root_node, 
            round_number, 
            tree_type="MergedTree"
        )

    def _save_generic_tree(self, root_node, round_number, tree_type, feature_names=None,categorical_features=None,global_mapping=None,scaler=None,numeric_features=None):
        from graphviz import Digraph
        import numpy as np
        import os

        dot = Digraph()
        node_id = [0]

        def base_name(feat):
            return feat.split('=')[0] if '=' in feat else feat

        def add_node(node, parent=None, edge_label=""):
            curr = str(node_id[0])
            node_id[0] += 1

            # Etiqueta del nodo
            if node.is_leaf:
                class_index = np.argmax(node.labels)
                class_label = self.unique_labels[class_index]
                label = f"class: {class_label}\n{node.labels}"
            else:
                try:
                    fname = feature_names[node.feat] if feature_names else f"X_{node.feat}"
                    label = base_name(fname)
                except Exception:
                    label = f"X_{node.feat}"

            dot.node(curr, label)
            if parent:
                dot.edge(parent, curr, label=edge_label)

            # Árbol tipo SuperTree (ramas como lista children+intervals)
            if hasattr(node, "children") and node.children is not None and hasattr(node, "intervals"):
                for i, child in enumerate(node.children):
                    try:
                        fname = feature_names[node.feat] if feature_names else f"X_{node.feat}"
                    except Exception:
                        fname = f"X_{node.feat}"

                    original_feat = base_name(fname)
                    edge = "?"

                    # Categórica
                    if categorical_features and original_feat in categorical_features and global_mapping:
                        idx = node.intervals[i]
                        if original_feat in global_mapping and idx < len(global_mapping[original_feat]):
                            val_real = global_mapping[original_feat][idx]
                            if len(node.children) == 2:
                                edge = f"= {val_real}" if i == 0 else f"≠ {val_real}"
                            else:
                                edge = f"= {val_real}"
                        else:
                            edge = f"= {idx}"
                    # Numérica
                    elif scaler and numeric_features and original_feat in numeric_features:
                        idx = numeric_features.index(original_feat)
                        mean = scaler.mean_[idx]
                        std = scaler.scale_[idx]
                        val = node.intervals[i]
                        val_real = val * std + mean
                        if len(node.children) == 2:
                            edge = f"<= {val_real:.2f}" if i == 0 else f"> {val_real:.2f}"
                        else:
                            edge = f"= {val_real:.2f}"
                    # Genérico
                    else:
                        val = node.intervals[i]
                        if len(node.children) == 2:
                            edge = f"<= {val:.2f}" if i == 0 else f"> {val:.2f}"
                        else:
                            edge = f"= {val:.2f}"

                    add_node(child, curr, edge)

            # Árbol binario clásico (ramas como _left_child/_right_child)
            elif hasattr(node, "_left_child") or hasattr(node, "_right_child"):
                try:
                    fname = feature_names[node.feat] if feature_names else f"X_{node.feat}"
                except Exception:
                    fname = f"X_{node.feat}"

                original_feat = base_name(fname)
                # Categórica
                if categorical_features and original_feat in categorical_features and global_mapping:
                    idx = int(node.thresh) if node.thresh is not None else None
                    if idx is not None and original_feat in global_mapping and idx < len(global_mapping[original_feat]):
                        val_real = global_mapping[original_feat][idx]
                        left_label = f"= {val_real}"
                        right_label = f"≠ {val_real}"
                    else:
                        left_label = "= ?"
                        right_label = "≠ ?"
                # Numérica
                elif scaler and numeric_features and original_feat in numeric_features:
                    idx = numeric_features.index(original_feat)
                    mean = scaler.mean_[idx]
                    std = scaler.scale_[idx]
                    thresh = node.thresh * std + mean if node.thresh is not None else None
                    left_label = f"<= {thresh:.2f}" if thresh is not None else "≤ ?"
                    right_label = f"> {thresh:.2f}" if thresh is not None else "> ?"
                # Genérico
                else:
                    thresh = node.thresh
                    left_label = f"<= {thresh:.2f}" if thresh is not None else "≤ ?"
                    right_label = f"> {thresh:.2f}" if thresh is not None else "> ?"

                if hasattr(node, "_left_child") and node._left_child:
                    add_node(node._left_child, curr, left_label)
                if hasattr(node, "_right_child") and node._right_child:
                    add_node(node._right_child, curr, right_label)

        add_node(root_node)
        folder = f"Ronda_{round_number}/{tree_type}_Cliente_{self.client_id}"
        os.makedirs(folder, exist_ok=True)
        filepath = f"{folder}/{tree_type.lower()}_cliente_{self.client_id}_ronda_{round_number}"
        dot.render(filepath, format="png", cleanup=True)



def client_fn(context: Context):
    partition_id = context.node_config["partition-id"]

    print("partition_id:", partition_id)
    
    num_partitions = context.node_config["num-partitions"]

    dataset_name = DATASET_NAME 
    class_col = CLASS_COLUMN 

    (X_parts, y_parts, feature_names, label_encoder, scaler, numeric_features, preprocessor) = load_data_general(
        flower_dataset_name=dataset_name,
        class_col=class_col,
        num_partitions=num_partitions
    )
    X_train = X_parts[partition_id]
    y_train = y_parts[partition_id]

    input_dim = X_train.shape[1]
    output_dim = len(np.unique(y_train))
    
    print("Shape X_train:", X_train.shape)
    print("Features:", feature_names)
    print("Modelo input_dim:", input_dim)

    tree_model = DecisionTreeClassifier(max_depth=5, min_samples_split=2, random_state=42)
    nn_model = Net(input_dim, output_dim)
    return FlowerClient(
        tree_model=tree_model,
        nn_model=nn_model,
        X_train=X_train,
        y_train=y_train,
        X_test=X_test,
        y_test=y_test,
        dataset=dataset,
        client_id=partition_id + 1,
        feature_names=feature_names,
        label_encoder=label_encoder,
        scaler=scaler,
        numeric_features=numeric_features,
        encoder=encoder,
        preprocessor=preprocessor
    ).to_client()

client_app = ClientApp(client_fn=client_fn)


# Servidor

In [None]:
# ============================
# 📦 IMPORTACIONES NECESARIAS
# ============================
import os
import time
import json
import numpy as np
from typing import List, Tuple, Dict
from sklearn.tree import DecisionTreeClassifier

from flwr.common import Context, Metrics, Scalar, ndarrays_to_parameters
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.strategy import FedAvg

from graphviz import Digraph
from lore_sa.surrogate.decision_tree import SuperTree

import torch
import torch.nn as nn
import torch.nn.functional as F


# ============================
# ⚙️ CONFIGURACIÓN GLOBAL
# ============================
# MIN_AVAILABLE_CLIENTS = 4
# NUM_SERVER_ROUNDS = 2

FEATURES = []  # se rellenan dinámicamente
UNIQUE_LABELS = []
LATEST_SUPERTREE_JSON = None
GLOBAL_MAPPING_JSON = None
FEATURE_NAMES_JSON = None


# ============================
# 🧠 UTILIDADES MODELO
# ============================
def create_model(input_dim, output_dim):
    from __main__ import Net  # necesario si Net está en misma libreta
    return Net(input_dim, output_dim)


def get_model_parameters(tree_model, nn_model):
    tree_params = [-1, 2, 1]
    nn_weights = [v.cpu().detach().numpy() for v in nn_model.state_dict().values()]
    return {
        "tree": tree_params,
        "nn": nn_weights,
    }

def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Dict[str, Scalar]:
    total = sum(n for n, _ in metrics)
    avg: Dict[str, List[float]] = {}
    for n, met in metrics:
        for k, v in met.items():
            if isinstance(v, (float, int)):
                avg.setdefault(k, []).append(n * float(v))
    return {k: sum(vs) / total for k, vs in avg.items()}

# ============================
# 🚀 SERVIDOR FLOWER
# ============================

def server_fn(context: Context) -> ServerAppComponents:
    global FEATURES, UNIQUE_LABELS

    # Justo antes de llamar a create_model
    if not FEATURES or not UNIQUE_LABELS:
        
        load_data_general(DATASET_NAME, CLASS_COLUMN, num_partitions=NUM_CLIENTS)


    FEATURES = FEATURES or ["feat_0", "feat_1"]  # fallback por si no se cargó antes
    UNIQUE_LABELS = UNIQUE_LABELS or ["Class_0", "Class_1"]


    model = create_model(len(FEATURES), len(UNIQUE_LABELS))
    initial_params = ndarrays_to_parameters(get_model_parameters(None, model)["nn"])

    strategy = FedAvg(
        min_available_clients=MIN_AVAILABLE_CLIENTS,
        fit_metrics_aggregation_fn=weighted_average,
        evaluate_metrics_aggregation_fn=weighted_average,
        initial_parameters=initial_params,
    )

    strategy.configure_fit = _inject_round(strategy.configure_fit)
    strategy.configure_evaluate = _inject_round(strategy.configure_evaluate)
    original_aggregate = strategy.aggregate_evaluate

    def custom_aggregate_evaluate(server_round, results, failures):
        global LATEST_SUPERTREE_JSON, GLOBAL_MAPPING_JSON, FEATURE_NAMES_JSON
        aggregated_metrics = original_aggregate(server_round, results, failures)

        try:
            print(f"\n[SERVIDOR] 🌲 Generando SuperTree - Ronda {server_round}")
            tree_dicts = []
            all_distincts = defaultdict(set)
            client_encoders = {}

            for (_, evaluate_res) in results:
                metrics = evaluate_res.metrics
                for key, value in metrics.items():
                    if key.startswith("distinct_values_"):
                        client_id = key.split("_")[-1]
                        client_encoders[client_id] = json.loads(value)
                        for feat, d in client_encoders[client_id].items():
                            all_distincts[feat].update(d["distinct_values"])

            global_mapping = {feat: sorted(list(vals)) for feat, vals in all_distincts.items()}

            for (_, evaluate_res) in results:
                metrics = evaluate_res.metrics
                for key, value in metrics.items():
                    if key.startswith("tree_ensemble_"):
                        client_id = key.split("_")[-1]
                        trees_list = json.loads(value)
                        local_encoder = client_encoders[client_id]
                        feature_names = json.loads(metrics.get(f"encoded_feature_names_{client_id}"))
                        numeric_features = json.loads(metrics.get(f"numeric_features_{client_id}"))
                        unique_labels = json.loads(metrics.get(f"unique_labels_{client_id}"))
                        scaler = {
                            "mean": json.loads(metrics.get(f"scaler_mean_{client_id}")),
                            "std": json.loads(metrics.get(f"scaler_std_{client_id}")),
                        }

                        for tdict in trees_list:
                            root = SuperTree.Node.from_dict(tdict)

                            def normalize_thresholds(node):
                                if node is None or node.is_leaf:
                                    return

                                fname = feature_names[node.feat]

                                if fname in numeric_features:
                                    try:
                                        idx = numeric_features.index(fname)
                                        real_val = node.thresh * scaler["std"][idx] + scaler["mean"][idx]
                                        node.real_thresh = real_val  # ⬅️ Lo guardamos
                                    except Exception as e:
                                        print(f"[WARNING] No se pudo desescalar {fname}: {e}")

                                elif fname in local_encoder and fname in global_mapping:
                                    try:
                                        local_vals = local_encoder[fname]["distinct_values"]
                                        real_val = local_vals[int(node.thresh)]
                                        node.thresh = global_mapping[fname].index(real_val)
                                    except Exception as e:
                                        print(f"[WARNING] No se pudo normalizar {fname}: {e}")

                                normalize_thresholds(getattr(node, "_left_child", None))
                                normalize_thresholds(getattr(node, "_right_child", None))

                            normalize_thresholds(root)

                            encoder_for_print = {
                                "categorical": {
                                    k: {"distinct_values": v}
                                    for k, v in global_mapping.items()
                                    if k in FEATURES and k not in numeric_features
                                }
                            }

                            # print(f"\n[CLIENTE {client_id}] 🌳 Árbol normalizado:")
                            # print_tree_readable(
                            #     root,
                            #     feature_names,
                            #     unique_labels,
                            #     numeric_features=numeric_features,
                            #     scaler={
                            #         "mean": [scaler["mean"][i] for i, f in enumerate(feature_names) if f in numeric_features],
                            #         "std": [scaler["std"][i] for i, f in enumerate(feature_names) if f in numeric_features]
                            #     },
                            #     encoder=encoder_for_print
                            # )
                            # print("Local tree del cliente", client_id)
                            # print(root.to_dict())
                            tree_dicts.append(root)
                            
            # print(tree_dicts)
            
            if not tree_dicts:
                print("[SERVIDOR] ⚠️ No se recibieron árboles. Se omite SuperTree.")
                return aggregated_metrics
            
                        
            supertree = SuperTree()
            # print("feature_names: ", feature_names)
            # print("global_mapping: ", global_mapping)
            # print("global_mapping keys: ", list(global_mapping.keys()))
            
            supertree.mergeDecisionTrees(tree_dicts, num_classes=len(UNIQUE_LABELS), feature_names=feature_names, categorical_features=list(global_mapping.keys()), global_mapping=global_mapping)
            supertree.prune_redundant_leaves_full()
            supertree.merge_equal_class_leaves()

            # print("\n[SERVIDOR] 🌳 SuperTree legible:")
            
            # print_supertree_readable_fusionado(
            #     node=supertree.root,
            #     global_mapping=global_mapping,
            #     feature_names=feature_names,
            #     class_names=UNIQUE_LABELS,
            #     numeric_features=numeric_features,
            #     scaler=scaler  # sin acceder a .mean_ ni .scale_
            # )

            _save_supertree_plot(
                root_node=supertree.root,
                round_number=server_round,
                feature_names=feature_names,
                class_names=UNIQUE_LABELS,
                scaler_means=scaler["mean"],
                scaler_stds=scaler["std"],
                global_mapping=global_mapping,
                numeric_features=numeric_features
            )

            LATEST_SUPERTREE_JSON = json.dumps(supertree.root.to_dict())
            GLOBAL_MAPPING_JSON = json.dumps(global_mapping)
            FEATURE_NAMES_JSON = json.dumps(feature_names)

        except Exception as e:
            print(f"[SERVIDOR] ❌ Error en SuperTree: {e}")

        time.sleep(3)
        return aggregated_metrics

    strategy.aggregate_evaluate = custom_aggregate_evaluate
    return ServerAppComponents(strategy=strategy, config=ServerConfig(num_rounds=NUM_SERVER_ROUNDS))

# ============================
# 🧩 FUNCIONES AUXILIARES
# ============================
def _inject_round(original_fn):
    def wrapper(server_round, parameters, client_manager):
        global LATEST_SUPERTREE_JSON, GLOBAL_MAPPING_JSON, FEATURE_NAMES_JSON
        instructions = original_fn(server_round, parameters, client_manager)
        for _, ins in instructions:
            ins.config["server_round"] = server_round
            
            if LATEST_SUPERTREE_JSON:
                ins.config["supertree"] = LATEST_SUPERTREE_JSON
                ins.config["global_mapping"] = GLOBAL_MAPPING_JSON
                ins.config["feature_names"] = FEATURE_NAMES_JSON
                
        return instructions
    return wrapper

def print_supertree_readable_fusionado(node, global_mapping, feature_names, class_names, numeric_features, scaler, depth=0):
    if node is None:
        print(f"{'|   ' * depth}|--- [Nodo None]")
        return

    indent = "|   " * depth

    if node.is_leaf:
        class_idx = int(np.argmax(node.labels))
        print(f"{indent}|--- class: {class_names[class_idx]}")
        return

    feat_name = feature_names[node.feat]
    base_feat = feat_name.split("=")[0] if "=" in feat_name else feat_name

    is_numeric = base_feat in numeric_features
    is_categorical = base_feat in global_mapping and not is_numeric

    for i, child in enumerate(node.children):
        if hasattr(node, "intervals") and node.intervals is not None:
            if i < len(node.intervals):
                val_idx = node.intervals[i]
            else:
                val_idx = node.intervals[0]
        else:
            val_idx = i  # fallback

        if is_numeric:
            idx = numeric_features.index(base_feat)
            if hasattr(node, "real_thresh"):
                real_val = node.real_thresh
            else:
                real_val = val_idx * scaler["std"][idx] + scaler["mean"][idx]
            op = "≤" if i == 0 else ">"
            print(f"{indent}|--- {base_feat} {op} {real_val:.2f}")
        elif is_categorical:
            values = global_mapping[base_feat]
            val_idx = int(val_idx)
            val = values[val_idx] if 0 <= val_idx < len(values) else f"[desconocido {val_idx}]"
            op = "=" if i == 0 else "≠"
            print(f"{indent}|--- {base_feat} {op} \"{val}\"")
        else:
            print(f"{indent}|--- {base_feat} [tipo desconocido]")

        # Recursivo
        print_supertree_readable_fusionado(
            node=child,
            global_mapping=global_mapping,
            feature_names=feature_names,
            class_names=class_names,
            numeric_features=numeric_features,
            scaler=scaler,
            depth=depth + 1
        )


def print_tree_readable(node, feature_names, class_names, numeric_features, scaler, encoder, depth=0):
    indent = "|   " * depth
    if node.is_leaf:
        class_idx = int(np.argmax(node.labels))
        print(f"{indent}|--- class: {class_names[class_idx]}")
        return

    feat_name = feature_names[node.feat]
    base_feat = feat_name.split("=")[0] if "=" in feat_name else feat_name

    if base_feat in encoder["categorical"]:
        val_idx = int(node.thresh)
        try:
            val = encoder["categorical"][base_feat]["distinct_values"][val_idx]
        except IndexError:
            val = f"[desconocido ({val_idx})]"
        print(f"{indent}|--- {base_feat} <= \"{val}\"")
        print_tree_readable(node._left_child, feature_names, class_names, numeric_features, scaler, encoder, depth + 1)
        print(f"{indent}|--- {base_feat} > \"{val}\"")
        print_tree_readable(node._right_child, feature_names, class_names, numeric_features, scaler, encoder, depth + 1)

    elif base_feat in numeric_features:
        idx = numeric_features.index(base_feat)
        threshold = node.thresh * scaler["std"][idx] + scaler["mean"][idx]
        print(f"{indent}|--- {base_feat} <= {threshold:.2f}")
        print_tree_readable(node._left_child, feature_names, class_names, numeric_features, scaler, encoder, depth + 1)
        print(f"{indent}|--- {base_feat} > {threshold:.2f}")
        print_tree_readable(node._right_child, feature_names, class_names, numeric_features, scaler, encoder, depth + 1)


def _save_supertree_plot(root_node, round_number, feature_names=None, class_names=None,
                         scaler_means=None, scaler_stds=None,
                         global_mapping=None, numeric_features=None):
    dot = Digraph()
    node_id = [0]

    def base_name(feat):
        return feat.split('=')[0] if '=' in feat else feat

    def add_node(node, parent=None, label=""):
        curr = str(node_id[0])
        node_id[0] += 1

        if node.is_leaf:
            class_index = int(np.argmax(node.labels))
            class_label = class_names[class_index] if class_names else f"Clase {class_index}"
            label_text = f"Clase: {class_label}\n{node.labels}"
        else:
            try:
                fname = feature_names[node.feat]
                label_text = base_name(fname)
            except Exception:
                label_text = f"X_{node.feat}"

        dot.node(curr, label_text)
        if parent:
            dot.edge(parent, curr, label=label)

        if not node.is_leaf:
            for i, child in enumerate(node.children):
                try:
                    fname = feature_names[node.feat]
                    base = base_name(fname)
                except Exception:
                    fname = f"X_{node.feat}"
                    base = fname

                if global_mapping and base in global_mapping:
                    # Variable categórica
                    values = global_mapping[base]
                    try:
                        val_idx = node.intervals[i] if node.intervals and i < len(node.intervals) else node.thresh
                        val = values[int(val_idx)] if int(val_idx) < len(values) else f"? ({val_idx})"
                    except Exception:
                        val = "?"
                    edge_label = f'= "{val}"' if i == 0 else f'≠ "{val}"'

                elif numeric_features and base in numeric_features:
                    # Variable numérica
                    try:
                        idx = feature_names.index(base)
                        raw_val = node.intervals[i] if node.intervals and i < len(node.intervals) else node.thresh
                        val = raw_val * scaler_stds[idx] + scaler_means[idx] if scaler_means and scaler_stds else raw_val
                        edge_label = f"≤ {val:.2f}" if i == 0 else f"> {val:.2f}"
                    except Exception:
                        edge_label = "?"

                else:
                    edge_label = "?"

                add_node(child, curr, edge_label)

    add_node(root_node)
    folder = f"Ronda_{round_number}/Supertree"
    os.makedirs(folder, exist_ok=True)
    filename = f"{folder}/supertree_ronda_{round_number}"
    dot.render(filename, format="png", cleanup=True)
    return f"{filename}.png"

# ============================
# 🔧 INICIALIZAR SERVER APP
# ============================
server_app = ServerApp(server_fn=server_fn)



In [None]:
from flwr.simulation import run_simulation
import logging
import warnings
import ray

warnings.filterwarnings("ignore", category=DeprecationWarning)


logging.getLogger('matplotlib').setLevel(logging.WARNING)
logging.getLogger("filelock").setLevel(logging.WARNING)
logging.getLogger("ray").setLevel(logging.WARNING)
logging.getLogger('graphviz').setLevel(logging.WARNING)
# logging.getLogger("flwr").setLevel(logging.WARNING)




ray.shutdown()  # Apagar cualquier sesión previa de Ray
ray.init(local_mode=True)  # Desactiva multiprocessing, usa un solo proceso principal

backend_config = {"num_cpus": 1}

run_simulation(
    server_app=server_app,
    client_app=client_app,
    num_supernodes=NUM_CLIENTS,
    backend_config=backend_config,
)


2025-06-30 12:43:33,670	INFO worker.py:1832 -- Started a local Ray instance. View the dashboard at [1m[32mhttp://127.0.0.1:8265 [39m[22m
2025-06-30 12:43:37,120 flwr         DEBUG    Asyncio event loop already running.
2025-06-30 12:43:37,140 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): huggingface.co:443
:job_id:01000000
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
2025-06-30 12:43:37,292 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /datasets/pablopalacios23/adult_small/resolve/main/README.md HTTP/11" 404 0


:job_id:01000000
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor


2025-06-30 12:43:37,412 urllib3.connectionpool DEBUG    https://huggingface.co:443 "GET /api/datasets/pablopalacios23/adult_small HTTP/11" 200 612
2025-06-30 12:43:37,533 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /datasets/pablopalacios23/adult_small/resolve/475f19aed5f80dea1d48deab705f11928fe27493/adult_small.py HTTP/11" 404 0
2025-06-30 12:43:37,533 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): s3.amazonaws.com:443
2025-06-30 12:43:37,835 urllib3.connectionpool DEBUG    https://s3.amazonaws.com:443 "HEAD /datasets.huggingface.co/datasets/datasets/pablopalacios23/adult_small/pablopalacios23/adult_small.py HTTP/11" 404 0
2025-06-30 12:43:37,956 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /datasets/pablopalacios23/adult_small/resolve/475f19aed5f80dea1d48deab705f11928fe27493/README.md HTTP/11" 404 0
2025-06-30 12:43:38,078 urllib3.connectionpool DEBUG    https://huggingface.co:443 "GET /api/datasets/pablopalacios23/adult_sm

partition_id:partition_id: 1
 0
partition_id: 0
partition_id: 1


[92mINFO [0m:      aggregate_evaluate: received 0 results and 2 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 2)
[91mERROR [0m:     An exception was raised when processing a message by RayBackend
[91mERROR [0m:     [36mray::ClientAppActor.run()[39m (pid=16296, ip=127.0.0.1, actor_id=7808a3678f11f291ffb7abdc01000000, repr=<flwr.simulation.ray_transport.ray_actor._modify_class.<locals>.Class object at 0x000002E6A2A269F0>)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\pablo\anaconda3\Lib\site-packages\flwr\client\client_app.py", line 143, in __call__
    return self._call(message, context)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\pablo\anaconda3\Lib\site-packages\flwr\client\client_app.py", line 126, in ffn
    out_message = handle_legacy_message_from_msgtype(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\pablo\anaconda3\Lib\sit


[SERVIDOR] 🌲 Generando SuperTree - Ronda 1
[SERVIDOR] ⚠️ No se recibieron árboles. Se omite SuperTree.
partition_id: 0
partition_id: 1


[92mINFO [0m:      aggregate_fit: received 0 results and 2 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 2 clients (out of 2)
[91mERROR [0m:     An exception was raised when processing a message by RayBackend
[91mERROR [0m:     An exception was raised when processing a message by RayBackend
[91mERROR [0m:     [36mray::ClientAppActor.run()[39m (pid=16296, ip=127.0.0.1, actor_id=c2678289c86a2d620cf4a23b01000000, repr=<flwr.simulation.ray_transport.ray_actor._modify_class.<locals>.Class object at 0x000002E6A2A267B0>)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\pablo\anaconda3\Lib\site-packages\flwr\client\client_app.py", line 143, in __call__
    return self._call(message, context)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\pablo\anaconda3\Lib\site-packages\flwr\client\client_app.py", line 126, in ffn
    out_message = handle_legacy_message_from_msgtype(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "

partition_id: 0
partition_id: 1

[SERVIDOR] 🌲 Generando SuperTree - Ronda 2
[SERVIDOR] ⚠️ No se recibieron árboles. Se omite SuperTree.


