In [1]:
# =======================
# 📦 IMPORTACIONES
# =======================
import warnings
import time
import sys
import random
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple

from sklearn.preprocessing import LabelEncoder
from sklearn.tree import DecisionTreeClassifier, plot_tree, export_text
from sklearn.metrics import (
    log_loss, accuracy_score, precision_score, recall_score, 
    f1_score, confusion_matrix, roc_auc_score
)

from flwr.client import ClientApp, NumPyClient
from flwr.common import Context, NDArrays, Metrics, Scalar, ndarrays_to_parameters
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.strategy import FedAvg
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner

from graphviz import Digraph

from lore_sa.dataset import TabularDataset
from lore_sa.bbox import sklearn_classifier_bbox
from lore_sa.encoder_decoder.tabular_enc import ColumnTransformerEnc
from lore_sa.lore import TabularGeneticGeneratorLore
from lore_sa.surrogate.decision_tree import SuperTree

from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import StandardScaler, OrdinalEncoder

import torch
import torch.nn as nn
import torch.nn.functional as F

# =======================
# ⚙️ GLOBAL VARIABLES
# =======================
# Federated run configuration.
NUM_SERVER_ROUNDS = 2
NUM_CLIENTS = 4            # also used as the number of IID partitions when loading data
MIN_AVAILABLE_CLIENTS = 4

# Module-level state. UNIQUE_LABELS and FEATURES are filled in place by
# load_data_general (slice assignment), so other references to these same list
# objects see the update.
UNIQUE_LABELS = []
FEATURES = []
CAT_ENCODINGS = {}
fds = None  # cached FederatedDataset (created lazily by load_data_general)
USING_DATASET = None


class Net(nn.Module):
    """Small fully connected classifier: input -> h -> h // 2 -> logits.

    The hidden width ``h`` is ``max(8, 2 * input_dim)``. ``forward`` returns raw
    logits (no softmax). The attribute names ``fc1``/``fc2``/``fc3`` determine the
    ``state_dict`` keys exchanged with the federated server.
    """

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        wide = max(8, input_dim * 2)  # scale hidden width with the number of inputs
        narrow = wide // 2

        self.fc1 = nn.Linear(input_dim, wide)
        self.fc2 = nn.Linear(wide, narrow)
        self.fc3 = nn.Linear(narrow, output_dim)

    def forward(self, x):
        """Apply the two ReLU hidden layers, then the linear output layer."""
        for hidden in (self.fc1, self.fc2):
            x = F.relu(hidden(x))
        return self.fc3(x)

# =======================
# 🔧 UTILIDADES MODELO
# =======================

def get_model_parameters(tree_model, nn_model):
    """Collect the shareable parameters of the hybrid (tree + neural network) model.

    Args:
        tree_model: estimator exposing ``get_params()`` with ``max_depth``,
            ``min_samples_split`` and ``min_samples_leaf`` keys.
        nn_model: ``torch.nn.Module`` whose weights are exported.

    Returns:
        dict with
          * ``"tree"``: ``[max_depth, min_samples_split, min_samples_leaf]`` as ints;
            a falsy ``max_depth`` (``None`` = unlimited) is encoded as ``-1``.
          * ``"nn"``: the ``state_dict`` tensors as NumPy arrays, in ``state_dict`` order.
    """
    hyper = tree_model.get_params()
    depth = int(hyper["max_depth"] or -1)
    split_leaf = [int(hyper[name]) for name in ("min_samples_split", "min_samples_leaf")]

    weights = []
    for tensor in nn_model.state_dict().values():
        weights.append(tensor.cpu().detach().numpy())

    return {"tree": [depth, *split_leaf], "nn": weights}


def set_model_params(tree_model, nn_model, params):
    """Apply tree hyper-parameters and network weights to the local models.

    This is the inverse of ``get_model_parameters``.

    Args:
        tree_model: estimator exposing ``set_params`` (the tree part is skipped when
            it is ``None`` or has no ``set_params``).
        nn_model: ``torch.nn.Module`` whose ``state_dict`` order matches ``params["nn"]``.
        params: dict with
          * ``"tree"``: ``[max_depth, min_samples_split, min_samples_leaf]``;
            ``max_depth <= 0`` means unlimited depth (``None``).
          * ``"nn"``: sequence of array-likes, one per ``state_dict`` entry.
            An empty sequence leaves the network weights untouched.

    Raises:
        ValueError: if a non-empty ``params["nn"]`` does not contain exactly one array
            per ``state_dict`` entry. Previously ``zip`` silently truncated and the
            network was only partially overwritten.
    """
    tree_params = params["tree"]
    nn_weights = list(params["nn"])

    if tree_model is not None and hasattr(tree_model, "set_params"):
        max_depth = tree_params[0] if tree_params[0] > 0 else None
        tree_model.set_params(
            max_depth=max_depth,
            min_samples_split=tree_params[1],
            min_samples_leaf=tree_params[2],
        )

    state_dict = nn_model.state_dict()
    if not nn_weights:
        # Nothing received (e.g. an empty parameter payload): keep current weights.
        return
    if len(nn_weights) != len(state_dict):
        raise ValueError(
            f"Expected {len(state_dict)} weight arrays for the network, got {len(nn_weights)}"
        )

    # Replace every entry in state_dict order; load_state_dict then copies the
    # values into the model's existing parameters/buffers.
    for key, val in zip(list(state_dict), nn_weights):
        state_dict[key] = torch.tensor(val)
    nn_model.load_state_dict(state_dict)


# =======================
# 📥 CARGAR DATOS
# =======================

def load_data_general(flower_dataset_name: str, class_col: str, partition_id: int, num_partitions: int):
    """Load, clean and encode one client's IID partition of a Flower/Hugging Face dataset.

    Steps:
      1. (Re)build the cached ``FederatedDataset`` when needed and load partition ``partition_id``.
      2. Apply dataset-specific cleaning (``adult_small``, ``churn``).
      3. Build a LORE ``TabularDataset`` / descriptor and label-encode ``class_col``.
      4. Scale numeric features and ordinal-encode categorical ones.
      5. Split 80/20 in row order (no shuffling).

    Side effects: updates the module globals ``fds``, ``USING_DATASET``, ``FEATURES``
    and (when empty) ``UNIQUE_LABELS``.

    Returns:
        ``(X_train, y_train, X_test, y_test, tabular_dataset, feature_names,
        label_encoder, numeric_scaler, numeric_features, encoder, preprocessor)``.
    """
    global fds, UNIQUE_LABELS, FEATURES, USING_DATASET

    # The FederatedDataset is cached across calls. Rebuild it when the requested
    # dataset changes; otherwise a notebook session that switches DATASET_NAME
    # would keep silently serving the previously loaded data.
    if fds is None or USING_DATASET != flower_dataset_name:
        partitioner = IidPartitioner(num_partitions=num_partitions)
        fds = FederatedDataset(dataset=flower_dataset_name, partitioners={"train": partitioner})
        USING_DATASET = flower_dataset_name
        UNIQUE_LABELS.clear()  # labels of a previous dataset must not leak into this one

    dataset = fds.load_partition(partition_id, "train").with_format("pandas")[:]
    name = flower_dataset_name.lower()

    if "adult_small" in name:
        drop_cols = ['fnlwgt', 'education-num', 'capital-gain', 'capital-loss']
        dataset = dataset.drop(columns=[col for col in drop_cols if col in dataset.columns])
        # Drop rows with unknown (" ?") workclass/occupation. Take an explicit copy so
        # later column assignments do not hit pandas' SettingWithCopy behaviour.
        keep = ~dataset["workclass"].isin([" ?"]) & ~dataset["occupation"].isin([" ?"])
        dataset = dataset[keep].copy()

    elif "churn" in name:
        drop_cols = ['customerID', 'TotalCharges']
        dataset = dataset.drop(columns=[col for col in drop_cols if col in dataset.columns])
        dataset['MonthlyCharges'] = pd.to_numeric(dataset['MonthlyCharges'], errors='coerce')
        dataset['tenure'] = pd.to_numeric(dataset['tenure'], errors='coerce')
        dataset = dataset.dropna(subset=['MonthlyCharges', 'tenure'])

    # Low-cardinality text columns become categoricals so LORE treats them as such.
    for col in dataset.select_dtypes(include=["object"]).columns:
        if dataset[col].nunique() < 50:
            dataset[col] = dataset[col].astype("category")

    tabular_dataset = TabularDataset(dataset.copy(), class_name=class_col)
    descriptor = tabular_dataset.descriptor

    # Make sure every categorical descriptor entry carries its distinct values.
    for col, info in descriptor["categorical"].items():
        if "distinct_values" not in info:
            info["distinct_values"] = list(dataset[col].dropna().unique())

    label_encoder = LabelEncoder()
    dataset[class_col] = label_encoder.fit_transform(dataset[class_col])
    dataset.rename(columns={class_col: "class"}, inplace=True)
    y = dataset["class"]

    if not UNIQUE_LABELS:
        UNIQUE_LABELS[:] = label_encoder.classes_.tolist()

    numeric_features = list(descriptor["numeric"].keys())
    categorical_features = list(descriptor["categorical"].keys())
    FEATURES[:] = numeric_features + categorical_features

    # Column positions inside X_array (numeric first, then categorical).
    numeric_indices = list(range(len(numeric_features)))
    categorical_indices = list(range(len(numeric_features), len(FEATURES)))

    X_array = dataset[FEATURES].to_numpy()

    preprocessor = ColumnTransformer([
        ("num", StandardScaler(), numeric_indices),
        ("cat", OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=-1), categorical_indices)
    ])

    X_encoded = preprocessor.fit_transform(X_array)

    encoder = ColumnTransformerEnc(descriptor)
    feature_names = list(encoder.encoded_features.values())

    split_idx = int(0.8 * len(X_encoded))
    return (
        X_encoded[:split_idx], y[:split_idx],
        X_encoded[split_idx:], y[split_idx:],
        tabular_dataset, feature_names, label_encoder,
        preprocessor.named_transformers_["num"], numeric_features, encoder, preprocessor
    )

# =======================
# Dataset selection: keep exactly one (DATASET_NAME, CLASS_COLUMN) pair active.
#
# Alternatives:
#   DATASET_NAME = "pablopalacios23/adult_small"; CLASS_COLUMN = "class"
#   DATASET_NAME = "pablopalacios23/churn";       CLASS_COLUMN = "Churn"
#       (the churn dataset is too large to run here)
# =======================
DATASET_NAME = "pablopalacios23/Iris"
CLASS_COLUMN = "target"

# Sanity run: load partition 0 so the notebook displays the preprocessed output.
load_data_general(DATASET_NAME, CLASS_COLUMN, partition_id=0, num_partitions=NUM_CLIENTS)

2025-06-10 13:53:24,119	INFO util.py:154 -- Outdated packages:
  ipywidgets==7.8.1 found, needs ipywidgets>=8
Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
2025-06-10 13:53:27,431 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): huggingface.co:443
2025-06-10 13:53:27,598 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /datasets/pablopalacios23/Iris/resolve/main/README.md HTTP/11" 404 0
2025-06-10 13:53:27,722 urllib3.connectionpool DEBUG    https://huggingface.co:443 "GET /api/datasets/pablopalacios23/Iris HTTP/11" 200 610
2025-06-10 13:53:27,848 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /datasets/pablopalacios23/Iris/resolve/6bbdbfec420ddde25fd56eb2d01f4bb904d94740/Iris.py HTTP/11" 404 0
2025-06-10 13:53:27,851 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): s3.amazonaws.com:443
2025-06-10 13:53:28,187 urllib3.connectionpool DEBUG    https://s3.amazonaws.com:443 "

(array([[ 0.0097124 ,  0.37652839,  0.52950125,  0.81472018],
        [-0.23633508, -1.21325815, -0.20940727, -0.24826237],
        [ 0.25575988, -1.21325815,  0.98421419,  0.2832289 ],
        [-1.58959622,  0.11156397, -1.34618962, -1.31124492],
        [ 0.50180736, -0.94829372,  0.58634037,  0.81472018],
        [ 1.2399498 ,  0.37652839,  1.0410533 ,  1.47908427],
        [-0.48238256, -1.74318699, -0.09572904, -0.24826237],
        [ 2.10111598, -0.15340046,  1.55260536,  1.21333863],
        [ 1.8550685 , -0.6833293 ,  1.26840977,  0.947593  ],
        [-0.48238256,  3.02617262, -1.40302874, -1.31124492],
        [ 0.99390232,  0.64149281,  1.0410533 ,  1.21333863],
        [-1.09750126,  1.43638608, -1.40302874, -1.31124492],
        [-1.34354874,  0.90645724, -1.11883315, -1.31124492],
        [ 0.6248311 , -0.94829372,  0.81369684,  0.947593  ],
        [ 1.60902102, -0.15340046,  1.09789242,  0.54897454],
        [-1.09750126, -0.15340046, -1.2893505 , -1.31124492],
        

# Cliente

In [None]:
# ==========================
# 🌼 CLIENTE FLOWER
# ==========================
import operator
import warnings
import os
import json
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import (
    log_loss, accuracy_score, precision_score, recall_score,
    f1_score, roc_auc_score
)
from sklearn.exceptions import NotFittedError

import torch
import torch.nn as nn
import torch.nn.functional as F

from flwr.client import NumPyClient
from flwr.common import Context
from flwr.common import parameters_to_ndarrays

from lore_sa.dataset import TabularDataset
from lore_sa.bbox import sklearn_classifier_bbox
from lore_sa.lore import TabularGeneticGeneratorLore
from lore_sa.rule import Expression, Rule
from lore_sa.surrogate.decision_tree import SuperTree
from lore_sa.encoder_decoder import ColumnTransformerEnc

from sklearn.metrics import pairwise_distances

from graphviz import Digraph

class TorchNNWrapper:
    """Adapt a PyTorch classifier to the scikit-learn ``predict``/``predict_proba`` API.

    The wrapped module is switched to evaluation mode once, when the wrapper is built.
    Inputs can be anything ``np.array`` accepts (lists, arrays, DataFrames). They are
    cast to float32 before the forward pass.
    """

    def __init__(self, model):
        self.model = model
        self.model.eval()

    def _logits(self, X):
        """Run the wrapped model on ``X`` without building an autograd graph."""
        batch = np.array(X, dtype=np.float32)
        with torch.no_grad():
            return self.model(torch.tensor(batch, dtype=torch.float32))

    def predict(self, X):
        """Return the index of the highest-scoring class for every input row."""
        return self._logits(X).argmax(dim=1).numpy()

    def predict_proba(self, X):
        """Return softmax probabilities over the output classes, one row per input."""
        return F.softmax(self._logits(X), dim=1).numpy()
        

class FlowerClient(NumPyClient):
    def __init__(self, tree_model, nn_model, X_train, y_train, X_test, y_test, dataset, client_id, feature_names, label_encoder, scaler, numeric_features, encoder, preprocessor):
        """Keep one client's models, data splits and preprocessing artefacts.

        Args:
            tree_model: local decision-tree classifier.
            nn_model: local PyTorch network (its weights are federated).
            X_train, y_train, X_test, y_test: encoded local splits.
            dataset: LORE ``TabularDataset`` built from the raw partition.
            client_id: identifier used in log messages and output file names.
            feature_names: encoded feature names.
            label_encoder: fitted ``LabelEncoder`` for the class column.
            scaler: fitted numeric scaler (``mean_``/``scale_`` are read later).
            numeric_features: names of the scaled numeric features.
            encoder: LORE encoder for the partition.
            preprocessor: fitted ``ColumnTransformer``.
        """
        # Models.
        self.tree_model = tree_model
        self.nn_model = nn_model

        # Local data.
        self.X_train, self.y_train = X_train, y_train
        self.X_test, self.y_test = X_test, y_test
        self.dataset = dataset

        # Identity and preprocessing artefacts.
        self.client_id = client_id
        self.feature_names = feature_names
        self.label_encoder = label_encoder
        self.scaler = scaler
        self.numeric_features = numeric_features
        self.encoder = encoder

        # Derived values: class labels and int64 targets for CrossEntropyLoss.
        self.unique_labels = label_encoder.classes_.tolist()
        self.y_train_nn = y_train.astype(np.int64)

        # Global SuperTree, set once the server sends one during evaluate().
        self.received_supertree = None
        self.preprocessor = preprocessor

    def _train_nn(self, epochs=10, lr=0.01):
        """Train the local network with full-batch Adam on the client's training split.

        Args:
            epochs: number of full-batch optimisation steps.
            lr: Adam learning rate.
        """
        self.nn_model.train()
        optimizer = torch.optim.Adam(self.nn_model.parameters(), lr=lr)
        loss_fn = nn.CrossEntropyLoss()

        # Convert through NumPy explicitly. ``self.y_train_nn`` is a pandas Series, and
        # ``torch.tensor(series)`` walks it element by element via ``series[i]``. That
        # lookup is label-based, so it fails (KeyError) or misbehaves when the index is
        # not 0..n-1, e.g. after rows were filtered out of the partition.
        X_tensor = torch.tensor(np.asarray(self.X_train, dtype=np.float32))
        y_tensor = torch.tensor(np.asarray(self.y_train_nn, dtype=np.int64))

        for _ in range(epochs):
            optimizer.zero_grad()
            loss = loss_fn(self.nn_model(X_tensor), y_tensor)
            loss.backward()
            optimizer.step()
        print(f"[CLIENTE {self.client_id}] ✅ Red neuronal entrenada")



    def fit(self, parameters, config):
        """Run one local training round.

        The global NN weights are loaded, the tree hyper-parameters are reset to
        ``[-1, 2, 1]`` (unlimited depth, min split 2, min leaf 1), the tree is refit and
        the network is trained.

        Returns:
            ``(nn_weights, num_train_examples, {})``. Only the NN weights are sent back.
        """
        incoming = {"tree": [-1, 2, 1], "nn": parameters}
        set_model_params(self.tree_model, self.nn_model, incoming)

        # sklearn/torch warnings are noisy for the tiny local partitions.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self.tree_model.fit(self.X_train, self.y_train)
            self._train_nn()

        updated = get_model_parameters(self.tree_model, self.nn_model)
        return updated["nn"], len(self.X_train), {}
    


    def evaluate(self, parameters, config):
        """Load global NN weights, optionally receive the global SuperTree, report the local tree.

        Args:
            parameters: global NN weights (one array per ``state_dict`` entry).
            config: round configuration sent by the server. Keys read here:
                ``"supertree"`` (JSON-serialised ``SuperNode`` dict, optional) and
                ``"server_round"`` (defaults to 1).

        Returns:
            ``(0.0, num_test_examples, metrics)``. ``metrics`` carries the local tree as
            JSON, the numeric scaler statistics and the encoded feature names.
        """
        set_model_params(self.tree_model, self.nn_model, {"tree": [-1, 2, 1], "nn": parameters})

        if "supertree" in config:
            try:
                print("Recibiendo supertree....")
                supertree_dict = json.loads(config["supertree"])
                self.received_supertree = SuperTree.convert_SuperNode_to_Node(
                    SuperTree.SuperNode.from_dict(supertree_dict)
                )
            except Exception as e:
                # Best effort: a malformed SuperTree must not abort the evaluation
                # round. The previously received one (if any) is kept.
                print(f"[CLIENTE {self.client_id}] ❌ Error al recibir SuperTree: {e}")

        # Fit lazily when evaluate() runs before any fit() on this client.
        try:
            y_pred = self.tree_model.predict(self.X_test)
        except NotFittedError:
            self.tree_model.fit(self.X_train, self.y_train)
            y_pred = self.tree_model.predict(self.X_test)
        y_proba = self.tree_model.predict_proba(self.X_test)

        supertree = SuperTree()
        root_node = supertree.rec_buildTree(
            self.tree_model, list(range(self.X_train.shape[1])), len(self.unique_labels)
        )
        round_number = config.get("server_round", 1)
        self._save_local_tree(root_node, round_number)
        tree_json = json.dumps([root_node.to_dict()])

        if self.received_supertree is not None:
            self._explain_local_and_global(config)

        # AUC is currently informational only (not reported; see the commented metrics).
        # The local test split is tiny and often lacks some class. ROC AUC is then
        # undefined and sklearn raises ValueError, so fall back to NaN instead of
        # failing the round. Choose binary vs multiclass from the model's probability
        # columns, not from the classes that happen to appear in y_test.
        try:
            if y_proba.shape[1] == 2:
                auc = roc_auc_score(self.y_test, y_proba[:, 1])
            else:
                auc = roc_auc_score(
                    self.y_test, y_proba, multi_class="ovr", labels=self.tree_model.classes_
                )
        except ValueError:
            auc = float("nan")

        return 0.0, len(self.X_test), {
            # "Accuracy": accuracy_score(self.y_test, y_pred),
            # "Precision": precision_score(self.y_test, y_pred, average="weighted", zero_division=1),
            # "Recall": recall_score(self.y_test, y_pred, average="weighted"),
            # "F1_Score": f1_score(self.y_test, y_pred, average="weighted"),
            # "AUC": auc,
            "tree_ensemble": tree_json,
            "scaler_mean": json.dumps(self.scaler.mean_.tolist()),
            "scaler_std": json.dumps(self.scaler.scale_.tolist()),
            "encoded_feature_names": json.dumps(self.feature_names)
        }

    def _explain_local_and_global(self, config):
        """Explain one local training instance with LORE plus the received global SuperTree.

        Pipeline:
          1. Wrap the local NN as a black box and run LORE on row ``num_row`` of the
             scaled local training data.
          2. Merge LORE's surrogate tree with ``self.received_supertree``, prune the
             result and persist it.
          3. Extract the factual rule and the counterfactual rules for the NN's
             predicted class, de-scale numeric thresholds and print them.
          4. Compute explanation metrics on the LORE neighbourhood Z: silhouette,
             fidelity, coverage, precision, complexity and counterfactual dissimilarity.

        Args:
            config: round configuration; only ``"server_round"`` is read (default 1).

        Side effects: prints to stdout, calls ``_save_lore_tree``/``_save_merged_tree``,
        writes ``dfZ_coverage_cliente_<client_id>.csv`` and replaces ``self.encoder``.
        """
        num_row = 5  # index of the local training row that gets explained
        round_number = config.get("server_round", 1)

        # ---- 1. Local data in the scaled/encoded feature space -------------------
        local_df = pd.DataFrame(self.X_train, columns=FEATURES).astype(np.float32)
        local_df["target"] = self.label_encoder.inverse_transform(self.y_train_nn)

        local_tabular_dataset = TabularDataset(local_df, class_name="target")
        descriptor = local_tabular_dataset.get_descriptor()

        encoder = ColumnTransformerEnc(descriptor)
        encoder.set_classes(self.unique_labels)
        self.encoder = encoder  # optional

        nn_wrapper = TorchNNWrapper(self.nn_model)
        bbox = sklearn_classifier_bbox.sklearnBBox(nn_wrapper)
        lore = TabularGeneticGeneratorLore(bbox, local_tabular_dataset)

        row = local_tabular_dataset.df.iloc[num_row]
        instance_scaled = row.iloc[:-1]
        # Positional access: ``row[-1]`` on a label-indexed Series relies on a
        # deprecated pandas fallback.
        target = row.iloc[-1]
        instance_array = instance_scaled.values.reshape(1, -1).astype(np.float32)
        with torch.no_grad():
            pred_idx = self.nn_model(torch.tensor(instance_array)).argmax(dim=1).item()
        pred_label = self.label_encoder.inverse_transform([pred_idx])[0]
        print(f"[CLIENTE {self.client_id}] 🤖 Predicción de la red neuronal: {pred_label}")

        # ---- 2. LORE explanation and merge with the global SuperTree -------------
        explanation = lore.explain_instance(instance_scaled.astype(np.float32), merge=True)
        lore_tree = explanation["merged_tree"]
        self._save_lore_tree(lore_tree.root, round_number)

        merged_tree = SuperTree()
        node_LORE_Tree = SuperTree.convert_SuperNode_to_Node(lore_tree.root)
        merged_tree.root = merged_tree.mergeDecisionTrees(
            roots=[node_LORE_Tree, self.received_supertree],
            num_classes=len(self.unique_labels),
            feature_names=self.dataset.df.columns[:-1].tolist()
        )
        merged_tree.prune_redundant_leaves_full()
        merged_tree.merge_equal_class_leaves()
        self._save_merged_tree(merged_tree.root, round_number)

        # Human-readable version of the explained instance.
        feature_cols = self.dataset.df.columns[:-1]
        raw_instance = self.dataset.df.iloc[num_row].iloc[:-1].values.reshape(1, -1)
        z_encoded = encoder.encode(raw_instance)[0]
        decoded_instance = pd.Series(encoder.decode(np.array([z_encoded]))[0], index=feature_cols)
        if hasattr(self, 'numeric_features') and hasattr(self, 'scaler'):
            for i, col in enumerate(self.numeric_features):
                if col in decoded_instance.index:
                    decoded_instance[col] = instance_scaled[col] * self.scaler.scale_[i] + self.scaler.mean_[i]

        print(f"\n[CLIENTE {self.client_id}] 🧪 Instancia a explicar:")
        print(decoded_instance)
        print(f"[CLIENTE {self.client_id}] 🧪 Clase real: {target}")

        # NOTE: the former re-encoding of ``instance_scaled`` through ``self.preprocessor``
        # was removed. Its result was never used, and it re-scaled already scaled values.

        # ---- 3. Factual and counterfactual rules ---------------------------------
        tree_str = self.tree_to_str(merged_tree.root, self.feature_names)

        final_rule = []
        rules = self.extract_rules_from_str(tree_str, target_class=pred_idx)
        if rules:
            final_rule = [self._descale_rule_condition(cond) for cond in rules[0]]
            print(f"\n[CLIENTE {self.client_id}] 📜 Regla desde árbol:")
            final_rule = self.filtrar_condiciones_redundantes(final_rule)
            for line in final_rule:
                print(f"   - {line}")
            print(f" ⇒ target = {pred_label}")
        else:
            print(f"[CLIENTE {self.client_id}] ⚠️ No se encontró regla para clase {pred_label}")

        # Guard against a missing result so the dissimilarity loop below cannot crash.
        cf_rules = self.extract_counterfactual_rules(tree_str, pred_idx) or {}
        if cf_rules:
            print(f"\n🧬 [CLIENTE {self.client_id}] Contrafactuales sugeridos:")
            for class_idx, rule in cf_rules.items():
                label = self.label_encoder.inverse_transform([class_idx])[0]
                print(f" ⇒ Posible clase: {label}")
                # dict.fromkeys drops repeated conditions while keeping first-seen order.
                final_cf = [self._descale_rule_condition(cond) for cond in dict.fromkeys(rule)]
                for line in self.filtrar_condiciones_redundantes(final_cf):
                    print(f"   - {line}")

        # ---- 4. Explanation metrics ----------------------------------------------
        Z = explanation["neighborhood_Z"]            # synthetic neighbourhood (scaled space)
        y_surrogate = explanation["neighborhood_Yb"]  # surrogate-tree predictions on Z
        y_nn = nn_wrapper.predict(Z)                  # black-box predictions on Z
        instance_vec = instance_scaled.values.reshape(1, -1).astype(float)

        # Silhouette: mean distance from x to Z points of the other classes vs of its own class.
        dfZ = pd.DataFrame(Z, columns=feature_cols)
        Z_plus = dfZ[y_nn == pred_idx]
        Z_minus = dfZ[y_nn != pred_idx]
        if not Z_plus.empty and not Z_minus.empty:
            dist_same = pairwise_distances(instance_vec, Z_plus.values.astype(float)).mean()
            dist_diff = pairwise_distances(instance_vec, Z_minus.values.astype(float)).mean()
            silhouette = (dist_diff - dist_same) / max(dist_diff, dist_same)
        else:
            silhouette = np.nan

        # Fidelity: agreement between the black box and the surrogate on Z.
        fidelity = accuracy_score(y_nn, y_surrogate)

        # Rules are expressed in original units, so de-scale Z before checking them.
        dfZ_coverage = pd.DataFrame(Z, columns=feature_cols)
        if hasattr(self, 'scaler'):
            for i, col in enumerate(self.numeric_features):
                if col in dfZ_coverage.columns:
                    dfZ_coverage[col] = dfZ_coverage[col] * self.scaler.scale_[i] + self.scaler.mean_[i]
        dfZ_coverage['class'] = y_surrogate

        mask_class_predicha = (dfZ_coverage['class'] == pred_idx).to_numpy()
        dfZ_coverage[mask_class_predicha].to_csv(f"dfZ_coverage_cliente_{self.client_id}.csv", index=False)

        # Coverage: fraction of the WHOLE neighbourhood Z that satisfies the factual rule
        # premise. Precision: among the covered points, fraction the black box assigns to
        # the rule's class. (Previously both were restricted to surrogate-predicted rows.)
        cumplen_regla = np.array(
            [self._row_satisfies_rule(fila, final_rule) for _, fila in dfZ_coverage.iterrows()],
            dtype=bool,
        )
        num_cumplen = int(cumplen_regla.sum())
        print("num_cumplen: ", num_cumplen)
        coverage = num_cumplen / len(dfZ_coverage) if len(dfZ_coverage) else 0.0
        precision = float((y_nn[cumplen_regla] == pred_idx).mean()) if num_cumplen else 0

        # Complexity: more conditions means a less interpretable rule.
        complexity = len(final_rule)

        # Dissimilarity: Euclidean distance (scaled space) between x and the point
        # obtained by forcing x to satisfy each counterfactual rule.
        dissimilarities = {}
        for class_idx, cf_rule in cf_rules.items():
            instance_cf = instance_scaled.copy()
            for cond in cf_rule:
                cond = cond.replace("≤", "<=").replace("≥", ">=").strip()
                for parte in cond.split("∧"):
                    self._apply_cf_condition(parte.strip(), instance_cf)

            dissimilarity = pairwise_distances(
                instance_vec,
                instance_cf.values.reshape(1, -1).astype(float),
                metric='euclidean'
            )[0][0]
            label = self.label_encoder.inverse_transform([class_idx])[0]
            dissimilarities[label] = dissimilarity

        print(f"\n📊 [CLIENTE {self.client_id}] Métricas de explicación:")
        print(f" - Silhouette: {silhouette:.3f}")
        print(f" - Fidelity: {fidelity:.2f}")
        print(f" - Coverage: {coverage:.3f}")
        print(f" - Precision: {precision:.3f}")
        print(f" - Complexity: {complexity:.3f}")
        for label, dissim in dissimilarities.items():
            print(f" - Dissimilarity con clase {label}: {dissim:.3f}")

    def _descale_rule_condition(self, cond):
        """Rewrite a numeric rule condition from scaled units back to original units.

        Handles ``"<var> > a ∧ ≤ b"`` intervals and ``"<var> <op> v"`` bounds
        (op in ≤ < > ≥) for features in ``self.numeric_features``. Values are
        de-scaled as ``v * scale_ + mean_`` and formatted with 2 decimals. Any other
        condition (e.g. categorical) is returned unchanged.
        """
        import re

        for idx, var in enumerate(self.numeric_features):
            if not cond.startswith(var):
                continue
            interval = re.match(rf"{re.escape(var)} > ([\d\.\-e]+) ∧ ≤ ([\d\.\-e]+)", cond)
            if interval:
                low, high = map(float, interval.groups())
                low_real = low * self.scaler.scale_[idx] + self.scaler.mean_[idx]
                high_real = high * self.scaler.scale_[idx] + self.scaler.mean_[idx]
                return f"{var} > {low_real:.2f} ∧ ≤ {high_real:.2f}"
            bound = re.match(rf"{re.escape(var)} (≤|<|>|≥) ([\d\.\-e]+)", cond)
            if bound:
                op, val = bound.groups()
                val_real = float(val) * self.scaler.scale_[idx] + self.scaler.mean_[idx]
                return f"{var} {op} {val_real:.2f}"
        return cond

    def _row_satisfies_rule(self, fila, condiciones):
        """Return True when the pandas row ``fila`` satisfies every condition in ``condiciones``.

        Numeric conditions (``<``, ``<=``, ``>``, ``>=``, intervals joined with ``∧``) are
        evaluated by rewriting column names into ``fila[...]`` lookups. Categorical
        conditions (``=``/``≠``) are compared as strings.

        Raises:
            ValueError: if no column name can be found in an interval condition.
        """

        def bind(expr):
            # Replace each column name occurring in the expression with a row lookup.
            for col in fila.index:
                if col in expr:
                    expr = expr.replace(col, f"fila[{repr(col)}]")
            return expr

        def holds(expr):
            # NOTE(security): the expression text comes from tree feature names and
            # thresholds. Evaluate it with no builtins and only ``fila`` in scope.
            return eval(bind(expr), {"__builtins__": {}}, {"fila": fila})

        for cond in condiciones:
            cond = cond.replace("≤", "<=").replace("≥", ">=").strip()
            if "∧" in cond:
                # "var > a ∧ <= b": the second part omits the variable, so reuse the
                # variable found in an earlier part.
                variable = None
                partes = []
                for parte in cond.split("∧"):
                    parte = parte.strip()
                    for col in fila.index:
                        if col in parte:
                            variable = col
                            break
                    if variable is None:
                        raise ValueError(f"No se encontró la variable en {parte}")
                    if variable not in parte:
                        parte = f"{variable} {parte}"
                    partes.append(parte)
                if not all(holds(p) for p in partes):
                    return False
            elif ">" in cond or "<" in cond:
                if not holds(cond):
                    return False
            elif " ≠ " in cond:
                var, val = cond.split(" ≠ ")
                if str(fila.get(var.strip(), "")) == val.strip():
                    return False
            elif "=" in cond:
                var, val = cond.split("=")
                if str(fila.get(var.strip(), "")) != val.strip():
                    return False
        return True

    @staticmethod
    def _apply_cf_condition(cond, x_cf):
        """Minimally modify ``x_cf`` in place so that it satisfies ``"<var> <op> <value>"``.

        Lower bounds (``>``, ``>=``) set the feature to ``value + 1e-3``; upper bounds set
        it to ``value``. Conditions that do not parse, or that name a variable missing from
        ``x_cf``, are ignored. Assigning such a variable would add a new entry to the Series
        and break the distance computation.
        """
        import re

        match = re.match(r"(.+?)\s*(<=|<|>|>=)\s*([\d\.\-e]+)", cond)
        if not match:
            return
        var, op, val = match.groups()
        if var not in x_cf.index:
            return
        val = float(val)
        x_cf[var] = val + 1e-3 if op in (">", ">=") else val

        
    def filtrar_condiciones_redundantes(self, condiciones):
        """Collapse the numeric conditions on each feature into a single, tightest bound.

        The conditions of a rule come from one root-to-leaf path and are ANDed together.
        The effective region for a feature is therefore bounded below by the LARGEST
        lower bound and above by the SMALLEST upper bound; looser bounds are implied and
        redundant. (Previously min/max were swapped, which kept the loosest bounds.)

        Args:
            condiciones: strings such as ``"x > 1.00"``, ``"x ≤ 2.00"`` or
                ``"x > 1.00 ∧ ≤ 2.00"``. Anything else (e.g. categorical conditions)
                is passed through unchanged.

        Returns:
            Pass-through conditions first (in input order), then one condition per
            numeric feature formatted with 2 decimals: ``"x > lo ∧ ≤ hi"`` when both
            bounds exist and ``lo < hi``; otherwise only the lower bound if there is
            one, else the upper bound.
        """
        import re
        from collections import defaultdict

        condiciones_filtradas = []
        agrupadas = defaultdict(list)

        for cond in condiciones:
            # The interval pattern is tried first because the simple pattern would also
            # match its "x > a" prefix.
            match_intervalo = re.match(r"(.+?) > ([\d\.\-e]+) ∧ ≤ ([\d\.\-e]+)", cond)
            match_simple = re.match(r"(.+?) (≤|<|>|≥) ([\d\.\-e]+)", cond)

            if match_intervalo:
                var, low, high = match_intervalo.groups()
                agrupadas[var].append(("intervalo", float(low), float(high), cond))
            elif match_simple:
                var, op, val = match_simple.groups()
                agrupadas[var].append((op, float(val), cond))
            else:
                condiciones_filtradas.append(cond)  # categorical or unrecognised format

        for var, items in agrupadas.items():
            low_vals = []
            high_vals = []

            for item in items:
                if item[0] == "intervalo":
                    low_vals.append(item[1])
                    high_vals.append(item[2])
                elif item[0] in {">", "≥"}:
                    low_vals.append(item[1])
                elif item[0] in {"<", "≤"}:
                    high_vals.append(item[1])

            # Conjunction of bounds: keep the tightest ones.
            low_final = max(low_vals) if low_vals else None
            high_final = min(high_vals) if high_vals else None

            if low_final is not None and high_final is not None and low_final < high_final:
                condiciones_filtradas.append(f"{var} > {low_final:.2f} ∧ ≤ {high_final:.2f}")
            elif low_final is not None:
                condiciones_filtradas.append(f"{var} > {low_final:.2f}")
            elif high_final is not None:
                condiciones_filtradas.append(f"{var} ≤ {high_final:.2f}")

        return condiciones_filtradas


    def extract_rules_from_str(self, tree_str, target_class):
        """Return the condition paths that end in a leaf of ``target_class``.

        ``tree_str`` is indented text in the format produced by ``tree_to_str``:
        ``if <condition>`` lines nested two spaces per level, with leaves written as
        ``⮕ Leaf: class = <idx> | <labels>``.

        Args:
            tree_str: Textual dump of the tree.
            target_class: Class to look for. It is compared, as text, against the
                exact token following ``class = `` in each leaf line.

        Returns:
            One list per matching leaf (top-to-bottom order). Each list holds the
            conditions from root to leaf, with duplicates removed (first occurrence
            kept).
        """
        import re

        leaf_class_re = re.compile(r"class = (\S+)")
        target = str(target_class)
        lines = tree_str.strip().split("\n")
        path = []
        rules = []

        def recurse(idx, indent_level):
            while idx < len(lines):
                line = lines[idx]
                current_indent = len(line) - len(line.lstrip())

                if current_indent < indent_level:
                    return idx
                if "⮕" in line:
                    match = leaf_class_re.search(line)
                    # Exact token comparison: a substring test would let class 1
                    # also match leaves of class 10, 11, ...
                    if match and match.group(1) == target:
                        # dict.fromkeys drops duplicates while keeping first-seen order.
                        rules.append(list(dict.fromkeys(path)))
                    return idx + 1
                elif "if" in line:
                    condition = line.strip()[3:]  # remove 'if '
                    path.append(condition)
                    idx = recurse(idx + 1, current_indent + 2)
                    path.pop()
                else:
                    idx += 1
            return idx

        recurse(0, 0)
        return rules
    

    
    def extract_counterfactual_rules(self, tree_str, predicted_class):
        """Map each class other than ``predicted_class`` to the first path reaching it.

        The tree text is in the ``tree_to_str`` format: ``if <condition>`` lines
        indented two spaces per level and ``⮕ Leaf: class = <idx> | ...`` leaves.
        The text is walked depth-first, top to bottom. The first leaf seen for each
        alternative class wins.

        Args:
            tree_str: Textual dump of the tree.
            predicted_class: Integer class index to exclude.

        Returns:
            ``{class_idx: [condition, ...]}`` with the root-to-leaf conditions of the
            first leaf of every other class that appears in the tree.
        """
        import re

        leaf_pattern = re.compile(r"class = (\d+)")
        tree_lines = tree_str.strip().split("\n")
        n_lines = len(tree_lines)
        current_path = []
        found = {}

        def walk(pos, min_indent):
            while pos < n_lines:
                text = tree_lines[pos]
                depth = len(text) - len(text.lstrip())

                if depth < min_indent:
                    return pos

                if "⮕" in text:
                    hit = leaf_pattern.search(text)
                    if hit is not None:
                        leaf_class = int(hit.group(1))
                        is_new_alternative = leaf_class != predicted_class and leaf_class not in found
                        if is_new_alternative:
                            found[leaf_class] = current_path.copy()
                    return pos + 1

                if "if" not in text:
                    pos += 1
                    continue

                # Strip the leading "if " and descend into this branch.
                current_path.append(text.strip()[3:])
                pos = walk(pos + 1, depth + 2)
                current_path.pop()
            return pos

        walk(0, 0)
        return found






    def tree_to_str(self, node, feature_names, depth=0):
        """Serialize a tree node and its subtree as indented ``if`` / leaf lines.

        Leaves become ``⮕ Leaf: class = <argmax(labels)> | <labels>``. For an
        internal node, each child gets an ``if <condition>`` line followed by its
        own subtree, one indentation level (two spaces) deeper.

        Features named ``attr=value`` are rendered as ``attr ≠ value`` (first child)
        / ``attr = value`` (other children). Numeric features use ``node.intervals``:
        ``≤ t0`` for the first child, ``> t(i-1) ∧ ≤ t(i)`` for middle children and
        ``> t(last)`` for the final one (three decimals).

        Args:
            node: Tree node exposing ``is_leaf``, ``labels``, ``feat``, ``children``
                and (for numeric splits) ``intervals``.
            feature_names: Names indexed by ``node.feat``; out-of-range indices are
                rendered as ``X_<feat>``.
            depth: Current indentation depth.

        Returns:
            The text of the subtree, newline-terminated.
        """
        pad = "  " * depth

        if node.is_leaf:
            winner = np.argmax(node.labels)
            return f"{pad}⮕ Leaf: class = {winner} | {node.labels}\n"

        if node.feat < len(feature_names):
            fname = feature_names[node.feat]
        else:
            fname = f"X_{node.feat}"

        categorical = "=" in fname
        if categorical:
            pieces = fname.split("=")
            base_feat, cat_val = pieces[0], pieces[1]
        else:
            base_feat, cat_val = fname, None

        chunks = []
        for pos, child in enumerate(node.children):
            if categorical:
                cond = f"{base_feat} {'=' if pos else '≠'} {cat_val}"
            elif pos == 0:
                cond = f"{base_feat} ≤ {node.intervals[0]:.3f}"
            elif pos < len(node.intervals):
                cond = f"{base_feat} > {node.intervals[pos - 1]:.3f} ∧ ≤ {node.intervals[pos]:.3f}"
            else:
                cond = f"{base_feat} > {node.intervals[pos - 1]:.3f}"
            chunks.append(f"{pad}if {cond}\n")
            chunks.append(self.tree_to_str(child, feature_names, depth + 1))

        return "".join(chunks)
    
    def _save_local_tree(self, root_node, round_number):
        """Render ``root_node`` via ``_save_generic_tree`` with the ``Arbol_Local`` prefix."""
        tree_kind = "Arbol_Local"
        self._save_generic_tree(root_node, round_number, tree_type=tree_kind)

    def _save_lore_tree(self, root_node, round_number):
        """Render ``root_node`` via ``_save_generic_tree`` with the ``LoreTree`` prefix."""
        tree_kind = "LoreTree"
        self._save_generic_tree(root_node, round_number, tree_type=tree_kind)

    def _save_merged_tree(self, root_node, round_number):
        """Render ``root_node`` via ``_save_generic_tree`` with the ``MergedTree`` prefix."""
        tree_kind = "MergedTree"
        self._save_generic_tree(root_node, round_number, tree_type=tree_kind)

    def _save_generic_tree(self, root_node, round_number, tree_type):
        """Render a tree to ``Ronda_<round>/<tree_type>_Cliente_<id>/`` as a PNG via graphviz.

        Two node layouts are supported:

        * Nodes with ``children`` and ``intervals``: one edge per child. Numeric edges
          read ``<= t`` for the first child, ``> a, <= b`` for middle children of a
          multi-way split, and ``> t`` for the last child.
        * Nodes with ``_left_child`` / ``_right_child`` and ``thresh`` (binary splits).

        Features named ``attr=value`` are drawn as ``≠ value`` / ``= value`` splits.
        Numeric thresholds are mapped back with ``self.scaler`` (``mean_``/``scale_``)
        when the feature appears in ``self.numeric_features``; otherwise they are
        shown as stored.

        Args:
            root_node: Root node of the tree to draw.
            round_number: Federated round, used in the folder and file names.
            tree_type: Folder prefix; its lower-cased form prefixes the file name.
        """
        from graphviz import Digraph
        import numpy as np
        import os

        dot = Digraph()
        node_id = [0]  # mutable counter shared with the nested recursion

        def base_name(feat):
            return feat.split('=')[0] if '=' in feat else feat

        def feature_name(feat_idx):
            # Fall back to a positional name when names are missing or the index is invalid.
            try:
                return self.feature_names[feat_idx]
            except (AttributeError, IndexError, KeyError, TypeError):
                return f"X_{feat_idx}"

        def descale(fname, value):
            # Undo standardization only for features the scaler was fitted on.
            scaler = getattr(self, "scaler", None)
            numeric = getattr(self, "numeric_features", None)
            if value is None or scaler is None or numeric is None:
                return value
            numeric = list(numeric)
            original_feat = base_name(fname)
            if original_feat not in numeric:
                return value
            idx = numeric.index(original_feat)
            return value * scaler.scale_[idx] + scaler.mean_[idx]

        def add_node(node, parent=None, edge_label=""):
            curr = str(node_id[0])
            node_id[0] += 1

            # Node label: predicted class for leaves, feature name otherwise.
            if node.is_leaf:
                class_index = np.argmax(node.labels)
                class_label = self.unique_labels[class_index]
                label = f"class: {class_label}\n{node.labels}"
            else:
                label = base_name(feature_name(node.feat))

            dot.node(curr, label)
            if parent is not None:
                dot.edge(parent, curr, label=edge_label)

            children = getattr(node, "children", None)
            if children is not None and hasattr(node, "intervals"):
                # SuperTree-style node: one edge per child.
                if not children:
                    return
                fname = feature_name(node.feat)
                for i, child in enumerate(children):
                    if '=' in fname:
                        val = fname.split('=', 1)[1]
                        edge = f"= {val}" if i == 1 else f"≠ {val}"
                    elif i == 0:
                        edge = f"<= {descale(fname, node.intervals[0]):.2f}"
                    elif i < len(node.intervals):
                        # Middle child of a multi-way split: show both ends of its interval.
                        low = descale(fname, node.intervals[i - 1])
                        high = descale(fname, node.intervals[i])
                        edge = f"> {low:.2f}, <= {high:.2f}"
                    else:
                        edge = f"> {descale(fname, node.intervals[i - 1]):.2f}"
                    add_node(child, curr, edge)

            elif hasattr(node, "_left_child") or hasattr(node, "_right_child"):
                # Classic binary node; either child attribute may be absent.
                left = getattr(node, "_left_child", None)
                right = getattr(node, "_right_child", None)
                if not left and not right:
                    return

                fname = feature_name(node.feat)
                if '=' in fname:
                    val = fname.split('=', 1)[1]
                    left_label = f"≠ {val}"
                    right_label = f"= {val}"
                else:
                    thresh = descale(fname, getattr(node, "thresh", None))
                    if thresh is not None:
                        left_label = f"<= {thresh:.2f}"
                        right_label = f"> {thresh:.2f}"
                    else:
                        left_label = "≤ ?"
                        right_label = "> ?"

                if left:
                    add_node(left, curr, left_label)
                if right:
                    add_node(right, curr, right_label)

        add_node(root_node)
        folder = f"Ronda_{round_number}/{tree_type}_Cliente_{self.client_id}"
        os.makedirs(folder, exist_ok=True)
        filepath = f"{folder}/{tree_type.lower()}_cliente_{self.client_id}_ronda_{round_number}"
        dot.render(filepath, format="png", cleanup=True)



def client_fn(context: Context):
    """Build the Flower client for the data partition assigned to this simulated node.

    Loads the partition with ``load_data_general`` (using the notebook-level
    ``DATASET_NAME`` / ``CLASS_COLUMN``), then creates a depth-5 decision tree and a
    ``Net`` classifier. Everything is wrapped in a ``FlowerClient``.

    Args:
        context: Flower context whose ``node_config`` provides ``"partition-id"``
            and ``"num-partitions"``.

    Returns:
        The ``FlowerClient`` converted with ``.to_client()``.
    """
    partition_id = context.node_config["partition-id"]
    num_partitions = context.node_config["num-partitions"]

    (
        X_train, y_train, X_test, y_test, dataset, feature_names,
        label_encoder, scaler, numeric_features, encoder, preprocessor,
    ) = load_data_general(
        flower_dataset_name=DATASET_NAME,
        class_col=CLASS_COLUMN,
        partition_id=partition_id,
        num_partitions=num_partitions,
    )

    tree_model = DecisionTreeClassifier(max_depth=5, min_samples_split=2, random_state=42)

    input_dim = X_train.shape[1]
    # Size the output layer from the label encoder's class list instead of the classes
    # that happen to appear in this partition. A partition missing a class would
    # otherwise build a smaller output layer whose weights cannot be averaged with the
    # other clients' (or the server's model). If the encoder was fitted per partition,
    # this equals the local count anyway — NOTE(review): confirm it is fitted globally.
    classes = getattr(label_encoder, "classes_", None)
    output_dim = len(classes) if classes is not None else len(np.unique(y_train))
    nn_model = Net(input_dim, output_dim)

    return FlowerClient(tree_model=tree_model,
                        nn_model=nn_model,
                        X_train=X_train,
                        y_train=y_train,
                        X_test=X_test,
                        y_test=y_test,
                        dataset=dataset,
                        client_id=partition_id + 1,
                        feature_names=feature_names,
                        label_encoder=label_encoder,
                        scaler=scaler,
                        numeric_features=numeric_features,
                        encoder=encoder,
                        preprocessor=preprocessor).to_client()


client_app = ClientApp(client_fn=client_fn)


# Servidor

In [3]:
# ============================
# 📦 IMPORTACIONES NECESARIAS
# ============================
import os
import time
import json
import numpy as np
from typing import List, Tuple, Dict
from sklearn.tree import DecisionTreeClassifier

from flwr.common import Context, Metrics, Scalar, ndarrays_to_parameters
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.strategy import FedAvg

from graphviz import Digraph
from lore_sa.surrogate.decision_tree import SuperTree

import torch
import torch.nn as nn
import torch.nn.functional as F


# ============================
# ⚙️ GLOBAL CONFIGURATION
# ============================
# MIN_AVAILABLE_CLIENTS and NUM_SERVER_ROUNDS are taken from the first cell.

# Filled in dynamically: server_fn calls load_data_general when these are empty.
FEATURES: list = []
UNIQUE_LABELS: list = []

# Serialized JSON of the latest merged SuperTree; forwarded to clients via the
# "supertree" config entry once it exists.
LATEST_SUPERTREE_JSON = None

# ============================
# 🧠 UTILIDADES MODELO
# ============================
def create_model(input_dim, output_dim):
    """Instantiate the shared ``Net`` architecture.

    ``Net`` is imported from ``__main__`` because it is defined in an earlier
    notebook cell, i.e. in the interactive namespace.

    Args:
        input_dim: Number of input features.
        output_dim: Number of output classes.

    Returns:
        A new ``Net`` instance.
    """
    from __main__ import Net as net_cls  # defined in the first notebook cell

    return net_cls(input_dim, output_dim)


def get_model_parameters(tree_model, nn_model):
    """Bundle the parameters exchanged with Flower.

    Args:
        tree_model: Not used; the tree entry is a fixed placeholder.
        nn_model: Torch module whose ``state_dict`` tensors are exported.

    Returns:
        ``{"tree": [-1, 2, 1], "nn": [...]}`` where ``"nn"`` holds one numpy array
        per ``state_dict`` entry, in ``state_dict`` order.
    """
    exported = []
    for tensor in nn_model.state_dict().values():
        exported.append(tensor.cpu().detach().numpy())
    return dict(tree=[-1, 2, 1], nn=exported)

def weighted_average(metrics: "List[Tuple[int, Metrics]]") -> "Dict[str, Scalar]":
    """Aggregate client metrics as example-count-weighted means, key by key.

    Only numeric values (``int``/``float``, which includes ``bool``) are averaged.
    Other values, such as the JSON strings clients send, are ignored.

    Each key is normalised by the example counts of the clients that actually
    reported it, so a metric sent by only some clients is not diluted by the rest.
    If those clients report zero examples in total, the plain mean of their values
    is returned instead of dividing by zero.

    Args:
        metrics: ``(num_examples, metrics_dict)`` pairs, one per client.

    Returns:
        ``{metric_name: weighted_mean}``; empty when nothing numeric was reported.
    """
    weighted_sums: Dict[str, float] = {}
    weights: Dict[str, float] = {}
    raw_values: Dict[str, List[float]] = {}

    for num_examples, client_metrics in metrics:
        for key, value in client_metrics.items():
            if not isinstance(value, (float, int)):
                continue
            value = float(value)
            weighted_sums[key] = weighted_sums.get(key, 0.0) + num_examples * value
            weights[key] = weights.get(key, 0) + num_examples
            raw_values.setdefault(key, []).append(value)

    result = {}
    for key, total in weighted_sums.items():
        if weights[key]:
            result[key] = total / weights[key]
        else:
            result[key] = sum(raw_values[key]) / len(raw_values[key])
    return result

# ============================
# 🚀 SERVIDOR FLOWER
# ============================

def server_fn(context: Context) -> ServerAppComponents:
    """Build the Flower server: a FedAvg strategy extended with SuperTree merging.

    The strategy's ``configure_fit`` / ``configure_evaluate`` are wrapped by
    ``_inject_round`` so every instruction carries the round number and, once
    available, the latest SuperTree JSON. ``aggregate_evaluate`` is wrapped so that
    after the normal FedAvg aggregation the server:

    * merges the trees sent by clients (``tree_ensemble`` metric) into a SuperTree;
    * prunes it, plots it, and stores it in ``LATEST_SUPERTREE_JSON``.

    Errors in the SuperTree step are printed and do not abort the round.

    Args:
        context: Flower server context (not read here).

    Returns:
        ``ServerAppComponents`` with the strategy and ``NUM_SERVER_ROUNDS`` rounds.
    """
    global FEATURES, UNIQUE_LABELS

    # The model shape depends on the dataset, so load it if nothing is cached yet.
    needs_metadata = not (FEATURES and UNIQUE_LABELS)
    if needs_metadata:
        load_data_general(DATASET_NAME, CLASS_COLUMN, partition_id=0, num_partitions=NUM_CLIENTS)

    # Placeholders in case loading did not populate the globals.
    if not FEATURES:
        FEATURES = ["feat_0", "feat_1"]
    if not UNIQUE_LABELS:
        UNIQUE_LABELS = ["Class_0", "Class_1"]

    global_model = create_model(len(FEATURES), len(UNIQUE_LABELS))
    nn_arrays = get_model_parameters(None, global_model)["nn"]

    strategy = FedAvg(
        min_available_clients=MIN_AVAILABLE_CLIENTS,
        fit_metrics_aggregation_fn=weighted_average,
        evaluate_metrics_aggregation_fn=weighted_average,
        initial_parameters=ndarrays_to_parameters(nn_arrays),
    )

    for hook in ("configure_fit", "configure_evaluate"):
        setattr(strategy, hook, _inject_round(getattr(strategy, hook)))
    base_aggregate_evaluate = strategy.aggregate_evaluate

    def collect_client_payload(results):
        """Gather parsed trees, scaler stats and encoded feature names from results."""
        roots, means, stds, names = [], None, None, None
        for position, (_, eval_res) in enumerate(results):
            payload = eval_res.metrics

            if payload.get("scaler_mean") and payload.get("scaler_std"):
                means = json.loads(payload["scaler_mean"])
                stds = json.loads(payload["scaler_std"])

            if "encoded_feature_names" in payload:
                names = json.loads(payload["encoded_feature_names"])

            serialized = payload.get("tree_ensemble")
            if not serialized:
                continue
            try:
                for tree_dict in json.loads(serialized):
                    root = SuperTree.Node.from_dict(tree_dict)
                    if root:
                        roots.append(root)
            except Exception as e:
                print(f"[CLIENTE {position+1}] ❌ Error al parsear árbol: {e}")
        return roots, means, stds, names

    def custom_aggregate_evaluate(server_round, results, failures):
        global LATEST_SUPERTREE_JSON
        aggregated = base_aggregate_evaluate(server_round, results, failures)

        try:
            print(f"\n[SERVIDOR] 🌲 Generando SuperTree - Ronda {server_round}")
            roots, means, stds, names = collect_client_payload(results)

            if not roots:
                print("[SERVIDOR] ⚠️ No se recibieron árboles. Se omite SuperTree.")
                return aggregated

            merged = SuperTree()
            merged.mergeDecisionTrees(roots, num_classes=len(UNIQUE_LABELS), feature_names=names)
            merged.prune_redundant_leaves_full()
            merged.merge_equal_class_leaves()

            _save_supertree_plot(merged.root, server_round, names, UNIQUE_LABELS, means, stds)
            LATEST_SUPERTREE_JSON = json.dumps(merged.root.to_dict())

        except Exception as e:
            print(f"[SERVIDOR] ❌ Error en SuperTree: {e}")

        time.sleep(3)
        return aggregated

    strategy.aggregate_evaluate = custom_aggregate_evaluate
    return ServerAppComponents(strategy=strategy, config=ServerConfig(num_rounds=NUM_SERVER_ROUNDS))

# ============================
# 🧩 FUNCIONES AUXILIARES
# ============================
def _inject_round(original_fn):
    """Wrap a strategy ``configure_*`` method so every instruction carries round info.

    The wrapper calls ``original_fn`` and, for each ``(client, ins)`` pair it returns,
    sets ``ins.config["server_round"]``. When ``LATEST_SUPERTREE_JSON`` is non-empty,
    it also sets ``ins.config["supertree"]`` to that JSON.

    Args:
        original_fn: Callable ``(server_round, parameters, client_manager)`` returning
            ``(client, ins)`` pairs whose ``ins`` has a ``config`` dict.

    Returns:
        The wrapping function, carrying ``original_fn``'s metadata (``functools.wraps``).
    """
    import functools

    @functools.wraps(original_fn)
    def wrapper(server_round, parameters, client_manager):
        instructions = original_fn(server_round, parameters, client_manager)
        # Read the module-level value at call time so each round forwards the
        # SuperTree produced by the previous evaluate aggregation.
        supertree_json = LATEST_SUPERTREE_JSON
        for _, ins in instructions:
            ins.config["server_round"] = server_round
            if supertree_json:
                ins.config["supertree"] = supertree_json
        return instructions

    return wrapper

def _save_supertree_plot(root_node, round_number, feature_names=None, class_names=None, scaler_means=None, scaler_stds=None):
    """Render the merged SuperTree to ``Ronda_<round>/Supertree/supertree_ronda_<round>.png``.

    Leaves show the class with the highest count (``class_names`` if given) and
    the raw label counts. Internal nodes show the feature name. Features named
    ``attr=value`` become ``≠ value`` / ``= value`` edges. Numeric edges read
    ``<= t`` (first child), ``> a, <= b`` (middle children of a multi-way split) and
    ``> t`` (last child).

    Args:
        root_node: Root of the SuperTree.
        round_number: Federated round, used in folder and file names.
        feature_names: Encoded feature names indexed by ``node.feat``.
        class_names: Class labels indexed by leaf argmax.
        scaler_means, scaler_stds: Standardization statistics. When both are given,
            a numeric threshold of feature ``i`` is mapped back as
            ``t * stds[i] + means[i]``, provided ``i`` is within their length.
    """
    dot = Digraph()
    node_id = [0]  # mutable counter shared with the nested recursion

    def base_name(feat):
        return feat.split('=')[0] if '=' in feat else feat

    def feature_name(feat_idx):
        # Fall back to a positional name when names are missing or the index is invalid.
        try:
            return feature_names[feat_idx]
        except (IndexError, KeyError, TypeError):
            return f"X_{feat_idx}"

    def descale(feat_idx, value):
        # NOTE(review): assumes numeric columns occupy the same positions in the
        # encoded feature list and in the scaler statistics — confirm on the client side.
        if scaler_means and scaler_stds and 0 <= feat_idx < min(len(scaler_means), len(scaler_stds)):
            return value * scaler_stds[feat_idx] + scaler_means[feat_idx]
        return value

    def edge_label(node, i, fname):
        if '=' in fname:
            val = fname.split('=', 1)[1]
            return f"= {val}" if i == 1 else f"≠ {val}"
        if i == 0:
            return f"<= {descale(node.feat, node.intervals[0]):.2f}"
        low = descale(node.feat, node.intervals[i - 1])
        if i < len(node.intervals):
            # Middle child of a multi-way split: show both ends of its interval.
            high = descale(node.feat, node.intervals[i])
            return f"> {low:.2f}, <= {high:.2f}"
        return f"> {low:.2f}"

    def add_node(node, parent=None, label=""):
        curr = str(node_id[0])
        node_id[0] += 1

        if node.is_leaf:
            class_index = np.argmax(node.labels)
            class_label = class_names[class_index] if class_names else f"Clase {class_index}"
            label_text = f"Clase: {class_label}\n{node.labels}"
        else:
            label_text = base_name(feature_name(node.feat))

        dot.node(curr, label_text)
        if parent is not None:
            dot.edge(parent, curr, label=label)

        if node.is_leaf:
            return
        fname = feature_name(node.feat)
        for i, child in enumerate(node.children):
            add_node(child, curr, edge_label(node, i, fname))

    add_node(root_node)
    folder = f"Ronda_{round_number}/Supertree"
    os.makedirs(folder, exist_ok=True)
    filename = f"{folder}/supertree_ronda_{round_number}"
    dot.render(filename, format="png", cleanup=True)

# ============================
# 🔧 INITIALIZE SERVER APP
# ============================
# Flower ServerApp; its strategy and run config come from server_fn.
server_app = ServerApp(server_fn=server_fn)



In [4]:
from flwr.simulation import run_simulation
import logging
import warnings
import ray

warnings.filterwarnings("ignore", category=DeprecationWarning)

# Keep noisy third-party loggers at WARNING so the simulation output stays readable.
for _noisy_logger in ("matplotlib", "filelock", "ray", "graphviz"):
    logging.getLogger(_noisy_logger).setLevel(logging.WARNING)

ray.shutdown()  # Shut down any previous Ray session
ray.init(local_mode=True)  # Run Ray in a single process (no multiprocessing)

# Flower's simulation engine reads per-client resources from the "client_resources"
# key of backend_config. A top-level "num_cpus" key is not part of that schema, so
# the former {"num_cpus": 1} silently left the default client resources in place.
backend_config = {"client_resources": {"num_cpus": 1, "num_gpus": 0.0}}

run_simulation(
    server_app=server_app,
    client_app=client_app,
    num_supernodes=NUM_CLIENTS,
    backend_config=backend_config,
)


2025-06-10 13:53:34,250	INFO worker.py:1832 -- Started a local Ray instance. View the dashboard at [1m[32mhttp://127.0.0.1:8265 [39m[22m
2025-06-10 13:53:37,732 flwr         DEBUG    Asyncio event loop already running.
:job_id:01000000
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=2, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Using initial global parameters provided by strategy
[92mINFO [0m:      Starting evaluation of initial global parameters
[92mINFO [0m:      Evaluation returned no results (`None`)
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: strategy sampled 4 clients (out of 4)


:job_id:01000000
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor
:actor_name:ClientAppActor


[92mINFO [0m:      aggregate_fit: received 4 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 4 clients (out of 4)


[CLIENTE 4] ✅ Red neuronal entrenada
[CLIENTE 1] ✅ Red neuronal entrenada
[CLIENTE 3] ✅ Red neuronal entrenada
[CLIENTE 2] ✅ Red neuronal entrenada


[92mINFO [0m:      aggregate_evaluate: received 4 results and 0 failures



[SERVIDOR] 🌲 Generando SuperTree - Ronda 1


[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 4 clients (out of 4)
[92mINFO [0m:      aggregate_fit: received 4 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 4 clients (out of 4)


[CLIENTE 2] ✅ Red neuronal entrenada
[CLIENTE 4] ✅ Red neuronal entrenada
[CLIENTE 1] ✅ Red neuronal entrenada
[CLIENTE 3] ✅ Red neuronal entrenada
Recibiendo supertree....
Recibiendo supertree....
Recibiendo supertree....
Recibiendo supertree....
[CLIENTE 2] 🤖 Predicción de la red neuronal: virginica
[CLIENTE 4] 🤖 Predicción de la red neuronal: setosa
[CLIENTE 3] 🤖 Predicción de la red neuronal: virginica
[CLIENTE 1] 🤖 Predicción de la red neuronal: virginica

[CLIENTE 1] 🧪 Instancia a explicar:
sepal length (cm)    6.9
sepal width (cm)     3.2
petal length (cm)    5.7
petal width (cm)     2.3
dtype: object
[CLIENTE 1] 🧪 Clase real: virginica

[CLIENTE 1] 📜 Regla desde árbol:
   - sepal length (cm) ≤ 5.79
   - sepal width (cm) ≤ 2.65
   - petal length (cm) > 3.19 ∧ ≤ 5.03
   - petal width (cm) > 0.85
 ⇒ target = virginica

🧬 [CLIENTE 1] Contrafactuales sugeridos:
 ⇒ Posible clase: setosa
   - sepal length (cm) ≤ 5.79
   - sepal width (cm) ≤ 2.65
   - petal length (cm) ≤ 3.19
 ⇒ Posibl

[92mINFO [0m:      aggregate_evaluate: received 4 results and 0 failures



[CLIENTE 3] 🧪 Instancia a explicar:
sepal length (cm)    6.1
sepal width (cm)     2.9
petal length (cm)    4.7
petal width (cm)     1.4
dtype: object
[CLIENTE 3] 🧪 Clase real: versicolor

[CLIENTE 3] 📜 Regla desde árbol:
   - sepal length (cm) ≤ 5.61
   - petal length (cm) > 3.11
   - petal width (cm) > 1.10 ∧ ≤ 1.78
 ⇒ target = virginica

🧬 [CLIENTE 3] Contrafactuales sugeridos:
 ⇒ Posible clase: setosa
   - sepal length (cm) ≤ 5.61
   - petal length (cm) ≤ 3.11
 ⇒ Posible clase: versicolor
   - sepal length (cm) > 5.40 ∧ ≤ 5.61
   - petal length (cm) > 2.09 ∧ ≤ 3.11
   - petal width (cm) ≤ 1.10
num_cumplen:  0

📊 [CLIENTE 3] Métricas de explicación:
 - Silhouette: 0.943
 - Fidelity: 1.00
 - Coverage: 0.000
 - Precision: 0.000
 - Complexity: 3.000
 - Dissimilarity con clase setosa: 1.658
 - Dissimilarity con clase versicolor: 1.550

[SERVIDOR] 🌲 Generando SuperTree - Ronda 2


[92mINFO [0m:      
[92mINFO [0m:      [SUMMARY]
[92mINFO [0m:      Run finished 2 round(s) in 118.04s
[92mINFO [0m:      	History (loss, distributed):
[92mINFO [0m:      		round 1: 0.0
[92mINFO [0m:      		round 2: 0.0
[92mINFO [0m:      
