In [1]:
# =======================
# 📦 IMPORTACIONES
# =======================
import warnings
import time
import sys
import random
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple

from sklearn.preprocessing import LabelEncoder
from sklearn.tree import DecisionTreeClassifier, plot_tree, export_text
from sklearn.metrics import (
    log_loss, accuracy_score, precision_score, recall_score, 
    f1_score, confusion_matrix, roc_auc_score
)

from flwr.client import ClientApp, NumPyClient
from flwr.common import Context, NDArrays, Metrics, Scalar, ndarrays_to_parameters
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.strategy import FedAvg
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner

from graphviz import Digraph

from lore_sa.dataset import TabularDataset
from lore_sa.bbox import sklearn_classifier_bbox
from lore_sa.encoder_decoder.tabular_enc import ColumnTransformerEnc
from lore_sa.lore import TabularGeneticGeneratorLore
from lore_sa.surrogate.decision_tree import SuperTree

from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import StandardScaler, OrdinalEncoder

import torch
import torch.nn as nn
import torch.nn.functional as F

# =======================
# ⚙️ VARIABLES GLOBALES
# =======================
UNIQUE_LABELS = []
FEATURES = []
NUM_SERVER_ROUNDS = 2
NUM_CLIENTS = 2
MIN_AVAILABLE_CLIENTS = 2
fds = None  # Cache del FederatedDataset
CAT_ENCODINGS = {}


class Net(nn.Module):
    def __init__(self, input_dim: int, output_dim: int):
        super(Net, self).__init__()
        hidden_dim = max(8, input_dim * 2)  # algo proporcional

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim // 2)
        self.fc3 = nn.Linear(hidden_dim // 2, output_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

# =======================
# 🔧 UTILIDADES MODELO
# =======================

def get_model_parameters(tree_model, nn_model):
    tree_params = [
        int(tree_model.get_params()["max_depth"] or -1),
        int(tree_model.get_params()["min_samples_split"]),
        int(tree_model.get_params()["min_samples_leaf"]),
    ]
    nn_weights = [v.cpu().detach().numpy() for v in nn_model.state_dict().values()]
    return {
        "tree": tree_params,
        "nn": nn_weights,
    }


def set_model_params(tree_model, nn_model, params):
    tree_params = params["tree"]
    nn_weights = params["nn"]

    # Solo si tree_model no es None y tiene set_params
    if tree_model is not None and hasattr(tree_model, "set_params"):
        max_depth = tree_params[0] if tree_params[0] > 0 else None
        tree_model.set_params(
            max_depth=max_depth,
            min_samples_split=tree_params[1],
            min_samples_leaf=tree_params[2],
        )

    # Actualizar pesos de la red neuronal
    state_dict = nn_model.state_dict()
    for (key, _), val in zip(state_dict.items(), nn_weights):
        state_dict[key] = torch.tensor(val)
    nn_model.load_state_dict(state_dict)


# =======================
# 📥 CARGAR DATOS
# =======================

def get_feature_names_from_column_transformer(preprocessor):
    output_features = []

    for name, trans, cols in preprocessor.transformers_:
        if hasattr(trans, 'get_feature_names_out'):
            names = trans.get_feature_names_out()
            # Quitar prefijos tipo "num__age"
            names = [n.split('__')[-1] for n in names]
        else:
            names = cols  # Fallback
        output_features.extend(names)

    return output_features

def load_data(partition_id: int, num_partitions: int):
    global fds, UNIQUE_LABELS, FEATURES, CAT_ENCODINGS
    
    if fds is None:
        partitioner = IidPartitioner(num_partitions=num_partitions)
        fds = FederatedDataset(dataset="pablopalacios23/adult_small", partitioners={"train": partitioner})

    # Cargar y preparar dataset
    dataset = fds.load_partition(partition_id, "train").with_format("pandas")[:]
    # dataset = dataset.applymap(lambda x: np.nan if isinstance(x, str) and x.strip() == "?" else x)
    # dataset.dropna(inplace=True)
    target_column = dataset.columns[-1]

    # Codificar clases
    label_encoder = LabelEncoder()
    y_encoded = label_encoder.fit_transform(dataset[target_column])
    dataset[target_column] = label_encoder.inverse_transform(y_encoded)
    dataset.rename(columns={target_column: "class"}, inplace=True)

    # Guardar etiquetas únicas
    if not UNIQUE_LABELS:
        UNIQUE_LABELS[:] = label_encoder.classes_.tolist()
        print("UNIQUE_LABELS:", UNIQUE_LABELS)

    # Eliminar columnas no útiles
    dataset.drop(['fnlwgt', 'education-num', 'capital-gain', 'capital-loss'], axis=1, inplace=True)
    

    # TabularDataset con clases legibles
    tabular_dataset = TabularDataset(dataset.copy(), class_name="class")
    descriptor = tabular_dataset.descriptor

    numeric_features = list(descriptor["numeric"].keys())
    categorical_features = list(descriptor["categorical"].keys())

    # Guardar nombres de features
    if not FEATURES:
        FEATURES[:] = numeric_features + categorical_features
        print("FEATURES:", FEATURES)

    # Codificar X con ColumnTransformer
    scaler = StandardScaler()
    
    preprocessor = ColumnTransformer(
        transformers=[
            ("num", scaler, numeric_features),
            ("cat", OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=-1), categorical_features),
        ]
    )


    X = dataset[FEATURES]
    scaler.fit(X[numeric_features])  # Ajustar antes de guardar
    X_encoded = preprocessor.fit_transform(X)

    encoder = ColumnTransformerEnc(tabular_dataset.descriptor)
    feature_names = list(encoder.encoded_features.values())

    # Volver a codificar la clase a enteros
    dataset["class"] = y_encoded
    y = y_encoded

    # Separar train/test
    split_idx = int(0.8 * len(X_encoded))
    X_train, X_test = X_encoded[:split_idx], X_encoded[split_idx:]
    y_train, y_test = y[:split_idx], y[split_idx:]

    return X_train, y_train, X_test, y_test, tabular_dataset, feature_names, scaler, numeric_features, label_encoder, encoder


# =======================
# 🧪 PRUEBA DE CARGA LOCAL (solo en ejecución directa)
# =======================

if __name__ == "__main__":
    X_train, y_train, X_test, y_test, dataset, feature_names, scaler, numeric_features, label_encoder, encoder = load_data(partition_id=0, num_partitions=NUM_CLIENTS)


print(dataset.df.head())


2025-06-13 09:30:20,332	INFO util.py:154 -- Outdated packages:
  ipywidgets==7.8.1 found, needs ipywidgets>=8
Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
2025-06-13 09:30:25,535 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): huggingface.co:443
2025-06-13 09:30:25,913 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /datasets/pablopalacios23/adult_small/resolve/main/README.md HTTP/11" 404 0
2025-06-13 09:30:26,648 urllib3.connectionpool DEBUG    https://huggingface.co:443 "GET /api/datasets/pablopalacios23/adult_small HTTP/11" 200 612
2025-06-13 09:30:26,802 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /datasets/pablopalacios23/adult_small/resolve/475f19aed5f80dea1d48deab705f11928fe27493/adult_small.py HTTP/11" 404 0
2025-06-13 09:30:26,806 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): s3.amazonaws.com:443
2025-06-13 09:30:27,108 urllib3.connectionpool DEBUG    ht

UNIQUE_LABELS: [' <=50K', ' >50K']
FEATURES: ['age', 'hours-per-week', 'workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'native-country']
   age     workclass      education       marital-status        occupation  \
0   30       Private   Some-college        Never-married             Sales   
1   58       Private      Bachelors   Married-civ-spouse   Exec-managerial   
2   50   Federal-gov   Some-college             Divorced      Adm-clerical   
3   33       Private        HS-grad   Married-civ-spouse             Sales   
4   62       Private   Some-college   Married-civ-spouse      Adm-clerical   

     relationship    race      sex  hours-per-week  native-country   class  
0   Not-in-family   White     Male              50   United-States    >50K  
1         Husband   White     Male              45   United-States    >50K  
2       Unmarried   White   Female              40   United-States   <=50K  
3         Husband   White     Male           

# Definir el cliente federado con Flower

In [2]:
# ==========================
# 🌼 CLIENTE FLOWER (ADULT)
# ==========================
import warnings
import os
import json
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import (
    log_loss, accuracy_score, precision_score, recall_score,
    f1_score, roc_auc_score
)
from sklearn.exceptions import NotFittedError

import torch
import torch.nn as nn
import torch.nn.functional as F

from flwr.client import NumPyClient
from flwr.common import Context
from flwr.common import parameters_to_ndarrays

from lore_sa.dataset import TabularDataset
from lore_sa.bbox import sklearn_classifier_bbox
from lore_sa.lore import TabularGeneticGeneratorLore
from lore_sa.surrogate.decision_tree import SuperTree
from lore_sa.encoder_decoder import ColumnTransformerEnc

from graphviz import Digraph

class TorchNNWrapper:
    def __init__(self, model):
        self.model = model
        self.model.eval()

    def predict(self, X):
        X = np.array(X, dtype=np.float32)
        with torch.no_grad():
            X_tensor = torch.tensor(X, dtype=torch.float32)
            outputs = self.model(X_tensor)
            return outputs.argmax(dim=1).numpy()

    def predict_proba(self, X):
        X = np.array(X, dtype=np.float32)
        with torch.no_grad():
            X_tensor = torch.tensor(X, dtype=torch.float32)
            outputs = self.model(X_tensor)
            probs = F.softmax(outputs, dim=1)
            return probs.numpy()

class FlowerClient(NumPyClient):
    def __init__(self, tree_model, nn_model, X_train, y_train, X_test, y_test, dataset, client_id, feature_names, scaler, numeric_features, label_encoder, encoder):
        self.tree_model = tree_model
        self.nn_model = nn_model
        self.X_train = X_train
        self.y_train = y_train
        self.X_test = X_test
        self.y_test = y_test
        self.dataset = dataset
        self.label_encoder = label_encoder
        self.unique_labels = UNIQUE_LABELS
        self.y_train_nn = y_train.astype(np.int64)
        self.client_id = client_id
        self.received_supertree = None
        self.feature_names = feature_names
        self.scaler = scaler
        self.numeric_features = numeric_features
        self.encoder = encoder

    def _train_nn(self, epochs=10, lr=0.01):
        self.nn_model.train()
        optimizer = torch.optim.Adam(self.nn_model.parameters(), lr=lr)
        loss_fn = nn.CrossEntropyLoss()
        X_tensor = torch.tensor(self.X_train, dtype=torch.float32)
        y_tensor = torch.tensor(self.y_train_nn, dtype=torch.long)

        for _ in range(epochs):
            optimizer.zero_grad()
            outputs = self.nn_model(X_tensor)
            loss = loss_fn(outputs, y_tensor)
            loss.backward()
            optimizer.step()

        print(f"[CLIENTE {self.client_id}] ✅ Red neuronal entrenada")

    def fit(self, parameters, config):
        set_model_params(self.tree_model, self.nn_model, {"tree": [-1, 2, 1], "nn": parameters})
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self.tree_model.fit(self.X_train, self.y_train)
            self._train_nn()
        nn_weights = get_model_parameters(self.tree_model, self.nn_model)["nn"]
        return nn_weights, len(self.X_train), {}

    def evaluate(self, parameters, config):
        set_model_params(self.tree_model, self.nn_model, {"tree": [-1, 2, 1], "nn": parameters})

        if "supertree" in config:
            try:
                supertree_dict = json.loads(config["supertree"])
                self.received_supertree = SuperTree.convert_SuperNode_to_Node(SuperTree.SuperNode.from_dict(supertree_dict))
            except Exception as e:
                print(f"[CLIENTE {self.client_id}] ❌ Error al recibir SuperTree: {e}")

        try:
            _ = self.tree_model.predict(self.X_test)
        except NotFittedError:
            self.tree_model.fit(self.X_train, self.y_train)

        y_pred = self.tree_model.predict(self.X_test)
        y_proba = self.tree_model.predict_proba(self.X_test)

        supertree = SuperTree()
        root_node = supertree.rec_buildTree(self.tree_model, list(range(self.X_train.shape[1])), len(self.unique_labels))
        round_number = config.get("server_round", 1)
        self._save_local_tree(root_node, round_number)
        tree_json = json.dumps([root_node.to_dict()])

        if self.received_supertree is not None:
            self._explain_local_and_global(config)

        return float(log_loss(self.y_test, y_proba)), len(self.X_test), {
            "Accuracy": accuracy_score(self.y_test, y_pred),
            "Precision": precision_score(self.y_test, y_pred, average="weighted", zero_division=1),
            "Recall": recall_score(self.y_test, y_pred, average="weighted"),
            "F1_Score": f1_score(self.y_test, y_pred, average="weighted"),
            "AUC": roc_auc_score(self.y_test, y_proba[:, 1]),
            "tree_ensemble": tree_json,
            "scaler_mean": json.dumps(self.scaler.mean_.tolist()),
            "scaler_std": json.dumps(self.scaler.scale_.tolist()),
            "encoded_feature_names": json.dumps(self.feature_names)
        }
    
    def _explain_local_and_global(self, config): 
        num_row = 5

        local_df = pd.DataFrame(self.X_train, columns=self.dataset.df.columns[:-1]).astype(np.float32)
        local_df["target"] = self.label_encoder.inverse_transform(self.y_train_nn)

        local_tabular_dataset = TabularDataset(local_df, class_name="target")
        descriptor = local_tabular_dataset.get_descriptor()

        encoder = ColumnTransformerEnc(descriptor)
        encoder.set_classes(self.unique_labels)
        self.encoder = encoder  # guardar por si quieres usar fuera

        # Local explicabilidad (LORE)
        nn_wrapper = TorchNNWrapper(self.nn_model)
        bbox = sklearn_classifier_bbox.sklearnBBox(nn_wrapper)
        lore = TabularGeneticGeneratorLore(bbox, local_tabular_dataset)

        instance = local_tabular_dataset.df.iloc[num_row][:-1]
        target = local_tabular_dataset.df.iloc[num_row][-1]

        instance_array = instance.values.reshape(1, -1).astype(np.float32)
        pred_idx = self.nn_model(torch.tensor(instance_array)).argmax(dim=1).item()
        pred_label = self.label_encoder.inverse_transform([pred_idx])[0]
        print(f"[CLIENTE {self.client_id}] 🤖 Predicción de la red neuronal: {pred_label}")

        explanation = lore.explain_instance(instance.astype(np.float32), merge=True)
        lore_tree = explanation["merged_tree"]
        round_number = config.get("server_round", 1)
        self._save_lore_tree(lore_tree.root, round_number)

        # Fusionar con supertree
        merged_tree = SuperTree()
        node_LORE_Tree = SuperTree.convert_SuperNode_to_Node(lore_tree.root)
        merged_root = merged_tree.mergeDecisionTrees(
            roots=[node_LORE_Tree, self.received_supertree],
            num_classes=len(self.unique_labels),
            feature_names=self.dataset.df.columns[:-1].tolist()
        )
        merged_tree.root = merged_root
        merged_tree.prune_redundant_leaves_full()
        merged_tree.merge_equal_class_leaves()
        self._save_merged_tree(merged_tree.root, round_number)

        z_encoded = encoder.encode([instance.values])[0]
        z_encoded = np.array([z_encoded], dtype=np.float32)
        decoded_instance = encoder.decode(z_encoded)[0]

        rule = merged_tree.get_rule(z=z_encoded[0], encoder=encoder)
        crules, _ = merged_tree.get_counterfactual_rules_merged(z_encoded[0], encoder)

        print(f"\n[CLIENTE {self.client_id}] 🧪 Instancia a explicar:")
        print(pd.Series(decoded_instance, index=self.dataset.df.columns[:-1]))

        print(f" [CLIENTE {self.client_id}] 🧪 Clase real: {target}")

        print(f"\n [CLIENTE {self.client_id}] 📜 Regla de explicación del árbol fusionado:")
        for p in rule.premises:
            op = p.operator.__name__.replace("le", "≤").replace("lt", "<").replace("ge", "≥").replace("gt", ">").replace("eq", "=")
            value = p.value
            if isinstance(value, str):
                print(f"   - {p.variable} {op} {value}")
            else:
                print(f"   - {p.variable} {op} {value:.3f}")
        print(f" ⇒ {rule.consequences.variable} = {rule.consequences.value}")

        actual_class = rule.consequences.value

        print(f"\n🧬 [CLIENTE {self.client_id}] Contrafactuales sugeridos:")
        idx = 1
        for cf in crules:
            if cf.consequences.value == actual_class:
                continue
            print(f"\n  ⚡ Contrafactual #{idx}:")
            for p in cf.premises:
                op = p.operator.__name__.replace("le", "≤").replace("lt", "<").replace("ge", "≥").replace("gt", ">").replace("eq", "=")
                value = p.value
                if isinstance(value, str):
                    print(f"   - {p.variable} {op} {value}")
                else:
                    print(f"   - {p.variable} {op} {value:.3f}")
            print(f"   ⇒ {cf.consequences.variable} = {cf.consequences.value}")
            idx += 1

        

    def _save_local_tree(self, root_node, round_number):
        dot = Digraph()
        node_id = [0]

        def base_name(feat):
            return feat.split('=')[0] if '=' in feat else feat

        def add_node(node, parent_id=None, edge_label=""):
            curr_id = str(node_id[0])
            node_id[0] += 1

            if node.is_leaf:
                class_index = np.argmax(node.labels)
                class_label = str(self.unique_labels[class_index])
                label = f"class: {class_label}\n{node.labels}"
            else:
                try:
                    fname = self.feature_names[node.feat]
                    label = base_name(fname)
                except:
                    fname = f"X_{node.feat}"
                    label = fname

            dot.node(curr_id, label)

            if parent_id:
                dot.edge(parent_id, curr_id, label=edge_label)

            if not node.is_leaf:
                if hasattr(node, "intervals"):  # SuperNode con múltiples hijos
                    for i, child in enumerate(node.children):
                        try:
                            fname = self.feature_names[node.feat]
                        except:
                            fname = f"X_{node.feat}"

                        if '=' in fname:
                            attr, val = fname.split('=')
                            edge = f"= {val}" if i == 1 else f"≠ {val}"
                        else:
                            # Desescalar valores numéricos
                            original_feat = base_name(fname)
                            if hasattr(self, "scaler") and original_feat in self.numeric_features:
                                idx = self.numeric_features.index(original_feat)
                                mean = self.scaler.mean_[idx]
                                std = self.scaler.scale_[idx]
                                if i == 0:
                                    edge_val = node.intervals[i] * std + mean
                                    edge = f"<= {edge_val:.2f}"
                                else:
                                    edge_val = node.intervals[i - 1] * std + mean
                                    edge = f"> {edge_val:.2f}"
                            else:
                                if i == 0:
                                    edge = f"<= {node.intervals[i]:.2f}"
                                else:
                                    edge = f"> {node.intervals[i - 1]:.2f}"

                        add_node(child, curr_id, edge)
                else:  # Binario (árbol sklearn puro)
                    try:
                        fname = self.feature_names[node.feat]
                    except:
                        fname = f"X_{node.feat}"

                    if '=' in fname:
                        attr, val = fname.split('=')
                        left_label = f"≠ {val}"
                        right_label = f"= {val}"
                    else:
                        original_feat = base_name(fname)
                        if hasattr(self, "scaler") and original_feat in self.numeric_features:
                            idx = self.numeric_features.index(original_feat)
                            mean = self.scaler.mean_[idx]
                            std = self.scaler.scale_[idx]
                            thresh = node.thresh * std + mean
                            left_label = f"<= {thresh:.2f}"
                            right_label = f"> {thresh:.2f}"
                        else:
                            left_label = f"<= {node.thresh:.2f}"
                            right_label = f"> {node.thresh:.2f}"

                    if node._left_child:
                        add_node(node._left_child, curr_id, left_label)
                    if node._right_child:
                        add_node(node._right_child, curr_id, right_label)

        add_node(root_node)
        folder = f"Ronda_{round_number}/Arbol_Local_Cliente_{self.client_id}"
        os.makedirs(folder, exist_ok=True)
        dot.render(f"{folder}/arbol_local_cliente_{self.client_id}_ronda_{round_number}", format="png", cleanup=True)

    def _save_lore_tree(self, root_node, round_number):
        self._save_generic_tree(
            root_node, 
            round_number, 
            tree_type="LoreTree"
        )

    def _save_merged_tree(self, root_node, round_number):
        self._save_generic_tree(
            root_node, 
            round_number, 
            tree_type="MergedTree"
        )

    def _save_generic_tree(self, root_node, round_number, tree_type):
        dot = Digraph()
        node_id = [0]

        def base_name(feat):
            return feat.split('=')[0] if '=' in feat else feat

        def add_node(node, parent=None, edge_label=""):
            curr = str(node_id[0])
            node_id[0] += 1

            if node.is_leaf:
                class_index = np.argmax(node.labels)
                class_label = self.unique_labels[class_index]
                label = f"class: {class_label}\n{node.labels}"
            else:
                try:
                    fname = self.feature_names[node.feat]
                    label = base_name(fname)
                except:
                    label = f"X_{node.feat}"

            dot.node(curr, label)
            if parent:
                dot.edge(parent, curr, label=edge_label)

            if not node.is_leaf:
                for i, child in enumerate(node.children):
                    try:
                        fname = self.feature_names[node.feat]
                    except:
                        fname = f"X_{node.feat}"

                    if '=' in fname:
                        attr, val = fname.split('=')
                        edge = f"= {val}" if i == 1 else f"≠ {val}"
                    else:
                        original_feat = base_name(fname)
                        if original_feat in self.numeric_features:
                            idx = self.numeric_features.index(original_feat)
                            mean = self.scaler.mean_[idx]
                            std = self.scaler.scale_[idx]
                            val = node.intervals[i] if i == 0 else node.intervals[i - 1]
                            val = val * std + mean
                            edge = f"<= {val:.2f}" if i == 0 else f"> {val:.2f}"
                        else:
                            val = node.intervals[i] if i == 0 else node.intervals[i - 1]
                            edge = f"<= {val:.2f}" if i == 0 else f"> {val:.2f}"

                    add_node(child, curr, edge)

        add_node(root_node)
        folder = f"Ronda_{round_number}/{tree_type}_Cliente_{self.client_id}"
        os.makedirs(folder, exist_ok=True)
        filepath = f"{folder}/{tree_type.lower()}_cliente_{self.client_id}_ronda_{round_number}"
        dot.render(filepath, format="png", cleanup=True)


def create_tree_model():
    return DecisionTreeClassifier(max_depth=5, min_samples_split=2, random_state=42)

def client_fn(context: Context):
    partition_id = context.node_config["partition-id"]
    num_partitions = context.node_config["num-partitions"]
    X_train, y_train, X_test, y_test, dataset, feature_names, scaler, numeric_features, label_encoder, encoder = load_data(partition_id, num_partitions)
    tree_model = create_tree_model()
    input_dim = X_train.shape[1]
    output_dim = len(np.unique(y_train))
    nn_model = Net(input_dim, output_dim)
    return FlowerClient(tree_model, nn_model, X_train, y_train, X_test, y_test, dataset, client_id=partition_id + 1, feature_names=feature_names, scaler=scaler, numeric_features=numeric_features, label_encoder=label_encoder, encoder = encoder).to_client()

client_app = ClientApp(client_fn=client_fn)



# Configurar el Servidor de Flower

In [3]:
# ============================
# 📦 IMPORTACIONES NECESARIAS
# ============================
import os
import time
import json
import numpy as np
from typing import List, Tuple, Dict
from sklearn.tree import DecisionTreeClassifier

from flwr.common import Context, Metrics, Scalar, ndarrays_to_parameters
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.strategy import FedAvg

from graphviz import Digraph
from lore_sa.surrogate.decision_tree import SuperTree

import torch
import torch.nn as nn
import torch.nn.functional as F



# ============================
# ⚖️ CONFIGURACIÓN GLOBAL
# ============================
MIN_AVAILABLE_CLIENTS = 2
NUM_SERVER_ROUNDS = 2
FEATURES = ['age', 'hours-per-week', 'workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'native-country']
UNIQUE_LABELS = [' <=50K', ' >50K']
LATEST_SUPERTREE_JSON = None  # 🌲 Guardar árbol generado

# ============================
# 🧐 MODELO Y UTILIDADES
# ============================

def create_model():
    input_dim = len(FEATURES)
    output_dim = len(UNIQUE_LABELS)
    return Net(input_dim, output_dim)

def get_model_parameters(tree_model, nn_model):
    tree_params = [ -1, 2, 1 ]  # Valores por defecto para el servidor
    nn_weights = [v.cpu().detach().numpy() for v in nn_model.state_dict().values()]
    return {
        "tree": tree_params,
        "nn": nn_weights,
    }

def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Dict[str, Scalar]:
    total = sum(n for n, _ in metrics)
    avg: Dict[str, List[float]] = {}
    for n, met in metrics:
        for k, v in met.items():
            if isinstance(v, (float, int)):
                avg.setdefault(k, []).append(n * float(v))
    return {k: sum(vs) / total for k, vs in avg.items()}

# ============================
# 🚀 SERVIDOR FLOWER
# ============================

def server_fn(context: Context) -> ServerAppComponents:
    model = create_model()
    initial_params = ndarrays_to_parameters(get_model_parameters(None, model)["nn"])


    strategy = FedAvg(
        min_available_clients=MIN_AVAILABLE_CLIENTS,
        fit_metrics_aggregation_fn=weighted_average,
        evaluate_metrics_aggregation_fn=weighted_average,
        initial_parameters=initial_params,
    )

    strategy.configure_fit = _inject_round(strategy.configure_fit)
    strategy.configure_evaluate = _inject_round(strategy.configure_evaluate)

    original_aggregate = strategy.aggregate_evaluate

    def custom_aggregate_evaluate(server_round, results, failures):
        global LATEST_SUPERTREE_JSON
        scaler_means = None
        scaler_stds = None
        aggregated_metrics = original_aggregate(server_round, results, failures)

        try:
            print(f"\n[SERVIDOR] 🌲 Generando SuperTree - Ronda {server_round}")
            tree_dicts = []
            total_arboles = 0

            for client_idx, (_, evaluate_res) in enumerate(results):
                metrics = evaluate_res.metrics
                trees_json = metrics.get("tree_ensemble", None)
                if metrics.get("scaler_mean") and metrics.get("scaler_std"):
                    scaler_means = json.loads(metrics["scaler_mean"])
                    scaler_stds = json.loads(metrics["scaler_std"])

                if "encoded_feature_names" in metrics:
                    feature_names = json.loads(metrics["encoded_feature_names"])

                if trees_json:
                    try:
                        trees_list = json.loads(trees_json)
                        for tdict in trees_list:
                            root = SuperTree.Node.from_dict(tdict)
                            if root:
                                tree_dicts.append(root)
                                total_arboles += 1
                    except Exception as e:
                        print(f"[CLIENTE {client_idx+1}] ❌ Error al parsear árbol: {e}")

            # print(f"[SERVIDOR] 📊 Total de árboles: {total_arboles}")

            if not tree_dicts:
                print("[SERVIDOR] ⚠️ No se recibieron árboles. Se omite SuperTree.")
                return aggregated_metrics

            supertree = SuperTree()
            supertree.mergeDecisionTrees(tree_dicts, num_classes=len(UNIQUE_LABELS), feature_names=feature_names)
            supertree.prune_redundant_leaves_full()
            supertree.merge_equal_class_leaves()

            _save_supertree_plot(supertree.root, server_round, feature_names=feature_names, class_names=UNIQUE_LABELS, scaler_means=scaler_means, scaler_stds=scaler_stds)
            LATEST_SUPERTREE_JSON = json.dumps(supertree.root.to_dict())

        except Exception as e:
            print(f"[SERVIDOR] ❌ Error en SuperTree: {e}")

        time.sleep(10)
        return aggregated_metrics

    strategy.aggregate_evaluate = custom_aggregate_evaluate
    config = ServerConfig(num_rounds=NUM_SERVER_ROUNDS)
    return ServerAppComponents(strategy=strategy, config=config)

# ============================
# 📂 HELPERS
# ============================

def _inject_round(original_fn):
    def wrapper(server_round, parameters, client_manager):
        global LATEST_SUPERTREE_JSON
        instructions = original_fn(server_round, parameters, client_manager)
        for _, ins in instructions:
            ins.config["server_round"] = server_round
            if LATEST_SUPERTREE_JSON:
                ins.config["supertree"] = LATEST_SUPERTREE_JSON
        return instructions
    return wrapper

def _save_supertree_plot(root_node, round_number, feature_names=None, class_names=None, scaler_means=None, scaler_stds=None):
    round_folder = f"Ronda_{round_number}"
    os.makedirs(round_folder, exist_ok=True)

    supertree_folder = f"{round_folder}/Supertree"
    os.makedirs(supertree_folder, exist_ok=True)

    dot = Digraph()
    node_id = [0]

    def base_name(feat):
        return feat.split('=')[0] if '=' in feat else feat

    def add_node(node, parent=None, label=""):
        curr = str(node_id[0])
        node_id[0] += 1

        if node.is_leaf:
            class_index = np.argmax(node.labels)
            class_label = class_names[class_index] if class_names else f"Clase {class_index}"
            label_text = f"Clase: {class_label}\n{node.labels}"
        else:
            try:
                fname = feature_names[node.feat]
                label_text = base_name(fname)
            except:
                label_text = f"X_{node.feat}"

        dot.node(curr, label_text)

        if parent:
            dot.edge(parent, curr, label=label)

        if not node.is_leaf:
            for i, child in enumerate(node.children):
                try:
                    feat_val = feature_names[node.feat]
                except:
                    feat_val = f"X_{node.feat}"

                # Si categórico
                if '=' in feat_val:
                    attr, val = feat_val.split('=')
                    edge_label = f"= {val}" if i == 1 else f"≠ {val}"
                else:
                    if scaler_means and scaler_stds:
                        idx = feature_names.index(feat_val)
                        val = node.intervals[i] if i == 0 else node.intervals[i - 1]
                        val = val * scaler_stds[idx] + scaler_means[idx]
                    else:
                        val = node.intervals[i] if i == 0 else node.intervals[i - 1]

                    edge_label = f"<= {val:.2f}" if i == 0 else f"> {val:.2f}"

                add_node(child, curr, edge_label)

    add_node(root_node)
    filename = f"{supertree_folder}/supertree_ronda_{round_number}"
    dot.render(filename, format="png", cleanup=True)
    # print(f"[SERVIDOR] ✅ SuperTree guardado como '{filename}.png'")

# ============================
# 🔧 INICIALIZAR SERVER APP
# ============================
server_app = ServerApp(server_fn=server_fn)


**Pasos que se realizan en el notebook:**

1. El servidor inicializa el modelo y lo envía a cada uno de los clientes.

2. Cada cliente entrena un RandomForest con su respectivo subconjunto de datos o partición que hemos realizado al principio.

3. Los clientes entrenan, y mandan sus hiperparámetros (Nº de árboles, profundidad, etc.) al servidor.

4. El servidor combina los parámetros y actualiza el modelo global.

5. Se mide el rendimiento del modelo sobre cada cliente, obteniendo también sus contrafactuales y se repite el proceso las rondas que deseemos.

# Ejecutar la Simulación Federada


In [4]:
from flwr.simulation import run_simulation
import logging
import warnings
import ray

warnings.filterwarnings("ignore", category=DeprecationWarning)


logging.getLogger('matplotlib').setLevel(logging.WARNING)
logging.getLogger("filelock").setLevel(logging.WARNING)
logging.getLogger("ray").setLevel(logging.WARNING)
logging.getLogger('graphviz').setLevel(logging.WARNING)
# logging.getLogger("flwr").setLevel(logging.WARNING)




ray.shutdown()  # Apagar cualquier sesión previa de Ray
ray.init(local_mode=True)  # Desactiva multiprocessing, usa un solo proceso principal

backend_config = {"num_cpus": 1}

run_simulation(
    server_app=server_app,
    client_app=client_app,
    num_supernodes=NUM_CLIENTS,
    backend_config=backend_config,
)


2025-06-13 09:30:36,830	INFO worker.py:1832 -- Started a local Ray instance. View the dashboard at [1m[32mhttp://127.0.0.1:8265 [39m[22m
2025-06-13 09:30:42,525 flwr         DEBUG    Asyncio event loop already running.
:job_id:01000000
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=2, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Using initial global parameters provided by strategy
[92mINFO [0m:      Starting evaluation of initial global parameters
[92mINFO [0m:      Evaluation returned no results (`None`)
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 2)


:job_id:01000000
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor


[92mINFO [0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 2 clients (out of 2)


[CLIENTE 2] ✅ Red neuronal entrenada
[CLIENTE 1] ✅ Red neuronal entrenada


[91mERROR [0m:     An exception was raised when processing a message by RayBackend
[91mERROR [0m:     [36mray::ClientAppActor.run()[39m (pid=24436, ip=127.0.0.1, actor_id=7210065d006045db7502f5f901000000, repr=<flwr.simulation.ray_transport.ray_actor._modify_class.<locals>.Class object at 0x0000023F09EAE210>)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\pablo\anaconda3\Lib\site-packages\flwr\client\client_app.py", line 143, in __call__
    return self._call(message, context)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\pablo\anaconda3\Lib\site-packages\flwr\client\client_app.py", line 126, in ffn
    out_message = handle_legacy_message_from_msgtype(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\pablo\anaconda3\Lib\site-packages\flwr\client\message_handler\message_handler.py", line 135, in handle_legacy_message_from_msgtype
    evaluate_res = maybe_call_evaluate(
                   ^^^^^^^^^^^^^^^^^^^^
  File "c:\Use


[SERVIDOR] 🌲 Generando SuperTree - Ronda 1


[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 2)
[92mINFO [0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 2 clients (out of 2)


[CLIENTE 1] ✅ Red neuronal entrenada
[CLIENTE 2] ✅ Red neuronal entrenada
[CLIENTE 1] 🤖 Predicción de la red neuronal:  >50K
[CLIENTE 2] 🤖 Predicción de la red neuronal:  <=50K


[91mERROR [0m:     An exception was raised when processing a message by RayBackend



[CLIENTE 1] 🧪 Instancia a explicar:
age              -0.561812
workclass         1.131578
education         3.000000
marital-status    2.000000
occupation        1.000000
relationship      3.000000
race              0.000000
sex               1.000000
hours-per-week    1.000000
native-country    0.000000
dtype: float32
 [CLIENTE 1] 🧪 Clase real:  >50K

 [CLIENTE 1] 📜 Regla de explicación del árbol fusionado:
   - relationship > 2.000
 ⇒ target =  >50K

🧬 [CLIENTE 1] Contrafactuales sugeridos:

  ⚡ Contrafactual #1:
   - relationship ≤ 1.704
   - education > 2.785
   - race ≤ 1.980
   - race > 1.000
   ⇒ target =  <=50K

[CLIENTE 2] 🧪 Instancia a explicar:
age               0.065198
workclass        -0.039315
education         3.000000
marital-status    1.000000
occupation        0.000000
relationship      4.000000
race              1.000000
sex               2.000000
hours-per-week    0.000000
native-country    2.000000
dtype: float32
 [CLIENTE 2] 🧪 Clase real:  <=50K

 [CLIENTE 2] 📜 

[91mERROR [0m:     [36mray::ClientAppActor.run()[39m (pid=24436, ip=127.0.0.1, actor_id=a183fe86e33380b64fdda70101000000, repr=<flwr.simulation.ray_transport.ray_actor._modify_class.<locals>.Class object at 0x0000023F09EAE450>)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\pablo\anaconda3\Lib\site-packages\flwr\client\client_app.py", line 143, in __call__
    return self._call(message, context)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\pablo\anaconda3\Lib\site-packages\flwr\client\client_app.py", line 126, in ffn
    out_message = handle_legacy_message_from_msgtype(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\pablo\anaconda3\Lib\site-packages\flwr\client\message_handler\message_handler.py", line 135, in handle_legacy_message_from_msgtype
    evaluate_res = maybe_call_evaluate(
                   ^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\pablo\anaconda3\Lib\site-packages\flwr\client\client.py", line 244, in maybe_call_


[SERVIDOR] 🌲 Generando SuperTree - Ronda 2


[92mINFO [0m:      
[92mINFO [0m:      [SUMMARY]
[92mINFO [0m:      Run finished 2 round(s) in 90.45s
[92mINFO [0m:      	History (loss, distributed):
[92mINFO [0m:      		round 1: 24.029102259411435
[92mINFO [0m:      		round 2: 24.029102259411435
[92mINFO [0m:      	History (metrics, distributed, evaluate):
[92mINFO [0m:      	{'AUC': [(1, 0.5), (2, 0.5)],
[92mINFO [0m:      	 'Accuracy': [(1, 0.3333333333333333), (2, 0.3333333333333333)],
[92mINFO [0m:      	 'F1_Score': [(1, 0.16666666666666666), (2, 0.16666666666666666)],
[92mINFO [0m:      	 'Precision': [(1, 0.7777777777777778), (2, 0.7777777777777778)],
[92mINFO [0m:      	 'Recall': [(1, 0.3333333333333333), (2, 0.3333333333333333)]}
[92mINFO [0m:      
