In [1]:
# =======================
# 📦 IMPORTACIONES
# =======================

# Built-in
import os
import sys
import time
import json
import random
import warnings
from typing import List, Tuple, Dict
import operator

# NumPy, Pandas, Matplotlib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Sklearn
from sklearn.tree import DecisionTreeClassifier, plot_tree, export_text
from sklearn.preprocessing import LabelEncoder, StandardScaler, OrdinalEncoder
from sklearn.compose import ColumnTransformer
from sklearn.metrics import (
    log_loss, accuracy_score, precision_score, recall_score,
    f1_score, confusion_matrix, roc_auc_score, pairwise_distances
)
from sklearn.exceptions import NotFittedError

# Flower
from flwr.client import ClientApp, NumPyClient
from flwr.common import (
    Context, NDArrays, Metrics, Scalar,
    ndarrays_to_parameters, parameters_to_ndarrays
)
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.strategy import FedAvg
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F

# LORE
from lore_sa.dataset import TabularDataset
from lore_sa.bbox import sklearn_classifier_bbox
from lore_sa.encoder_decoder import ColumnTransformerEnc
from lore_sa.lore import TabularGeneticGeneratorLore
from lore_sa.surrogate.decision_tree import SuperTree
from lore_sa.rule import Expression, Rule

# Otros
from graphviz import Digraph


2025-06-17 09:25:35,897	INFO util.py:154 -- Outdated packages:
  ipywidgets==7.8.1 found, needs ipywidgets>=8
Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
2025-06-17 09:25:41,239 graphviz._tools DEBUG    deprecate positional args: graphviz.backend.piping.pipe(['renderer', 'formatter', 'neato_no_op', 'quiet'])
2025-06-17 09:25:41,242 graphviz._tools DEBUG    deprecate positional args: graphviz.backend.rendering.render(['renderer', 'formatter', 'neato_no_op', 'quiet'])
2025-06-17 09:25:41,245 graphviz._tools DEBUG    deprecate positional args: graphviz.backend.unflattening.unflatten(['stagger', 'fanout', 'chain', 'encoding'])
2025-06-17 09:25:41,247 graphviz._tools DEBUG    deprecate positional args: graphviz.backend.viewing.view(['quiet'])
2025-06-17 09:25:41,255 graphviz._tools DEBUG    deprecate positional args: graphviz.quoting.quote(['is_html_string', 'is_valid_id', 'dot_keywords', 'endswith_odd_number_of_backslashes', 'escape_unescaped

In [2]:
# =======================
# ⚙️ VARIABLES GLOBALES
# =======================
UNIQUE_LABELS = []
FEATURES = []
NUM_SERVER_ROUNDS = 2
NUM_CLIENTS = 2
MIN_AVAILABLE_CLIENTS = NUM_CLIENTS
fds = None  # Cache del FederatedDataset
CAT_ENCODINGS = {}
USING_DATASET = None





class Net(nn.Module):
    def __init__(self, input_dim: int, output_dim: int):
        super(Net, self).__init__()
        hidden_dim = max(8, input_dim * 2)  # algo proporcional

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim // 2)
        self.fc3 = nn.Linear(hidden_dim // 2, output_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
    

class TorchNNWrapper:
    def __init__(self, model):
        self.model = model
        self.model.eval()

    def predict(self, X):
        X = np.array(X, dtype=np.float32)
        with torch.no_grad():
            X_tensor = torch.tensor(X, dtype=torch.float32)
            outputs = self.model(X_tensor)
            return outputs.argmax(dim=1).numpy()

    def predict_proba(self, X):
        X = np.array(X, dtype=np.float32)
        with torch.no_grad():
            X_tensor = torch.tensor(X, dtype=torch.float32)
            outputs = self.model(X_tensor)
            probs = F.softmax(outputs, dim=1)
            return probs.numpy()

# =======================
# 🔧 UTILIDADES MODELO
# =======================

def get_model_parameters(tree_model, nn_model):
    tree_params = [
        int(tree_model.get_params()["max_depth"] or -1),
        int(tree_model.get_params()["min_samples_split"]),
        int(tree_model.get_params()["min_samples_leaf"]),
    ]
    nn_weights = [v.cpu().detach().numpy() for v in nn_model.state_dict().values()]
    return {
        "tree": tree_params,
        "nn": nn_weights,
    }


def set_model_params(tree_model, nn_model, params):
    tree_params = params["tree"]
    nn_weights = params["nn"]

    # Solo si tree_model no es None y tiene set_params
    if tree_model is not None and hasattr(tree_model, "set_params"):
        max_depth = tree_params[0] if tree_params[0] > 0 else None
        tree_model.set_params(
            max_depth=max_depth,
            min_samples_split=tree_params[1],
            min_samples_leaf=tree_params[2],
        )

    # Actualizar pesos de la red neuronal
    state_dict = nn_model.state_dict()
    for (key, _), val in zip(state_dict.items(), nn_weights):
        state_dict[key] = torch.tensor(val)
    nn_model.load_state_dict(state_dict)


# =======================
# 📥 CARGAR DATOS
# =======================

def load_data_general(flower_dataset_name: str, class_col: str, partition_id: int, num_partitions: int):
    global fds, UNIQUE_LABELS, FEATURES

    if fds is None:
        partitioner = IidPartitioner(num_partitions=num_partitions)
        fds = FederatedDataset(dataset=flower_dataset_name, partitioners={"train": partitioner})

    dataset = fds.load_partition(partition_id, "train").with_format("pandas")[:]

    if "adult_small" in flower_dataset_name.lower():
        drop_cols = ['fnlwgt', 'education-num', 'capital-gain', 'capital-loss']
        dataset.drop(columns=[col for col in drop_cols if col in dataset.columns], inplace=True)
        dataset = dataset[~dataset["workclass"].isin([" ?"])]
        dataset = dataset[~dataset["occupation"].isin([" ?"])]

    elif "churn" in flower_dataset_name.lower():
        drop_cols = ['customerID', 'TotalCharges']
        dataset.drop(columns=[col for col in drop_cols if col in dataset.columns], inplace=True)
        dataset['MonthlyCharges'] = pd.to_numeric(dataset['MonthlyCharges'], errors='coerce')
        dataset['tenure'] = pd.to_numeric(dataset['tenure'], errors='coerce')
        dataset.dropna(subset=['MonthlyCharges', 'tenure'], inplace=True)


    for col in dataset.select_dtypes(include=["object"]).columns:
        if dataset[col].nunique() < 50:
            dataset[col] = dataset[col].astype("category")

    class_original = dataset[class_col].copy()

    tabular_dataset = TabularDataset(dataset.copy(), class_name=class_col)
    descriptor = tabular_dataset.descriptor

    # AÑADIR DISTINCT_VALUES si falta en categóricas
    for col, info in descriptor["categorical"].items():
        if "distinct_values" not in info:
            info["distinct_values"] = list(dataset[col].dropna().unique())

    label_encoder = LabelEncoder()
    dataset[class_col] = label_encoder.fit_transform(dataset[class_col])

    dataset.rename(columns={class_col: "class"}, inplace=True)
    y = dataset["class"].reset_index(drop=True).to_numpy()

    if not UNIQUE_LABELS:
        UNIQUE_LABELS[:] = label_encoder.classes_.tolist()

    numeric_features = list(descriptor["numeric"].keys())
    categorical_features = list(descriptor["categorical"].keys())
    FEATURES[:] = numeric_features + categorical_features

    numeric_indices = list(range(len(numeric_features)))
    categorical_indices = list(range(len(numeric_features), len(FEATURES)))

    X_array = dataset[FEATURES].to_numpy()

    preprocessor = ColumnTransformer([
        ("num", StandardScaler(), numeric_indices),
        ("cat", OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=-1), categorical_indices)
    ])

    X_encoded = preprocessor.fit_transform(X_array)

    encoder = ColumnTransformerEnc(descriptor)
    feature_names = list(encoder.encoded_features.values())

    split_idx = int(0.8 * len(X_encoded))
    return (
        X_encoded[:split_idx], y[:split_idx],
        X_encoded[split_idx:], y[split_idx:],
        tabular_dataset, feature_names, label_encoder,
        preprocessor.named_transformers_["num"], numeric_features, encoder, preprocessor
    )

# =======================



DATASET_NAME = "pablopalacios23/adult_small"
CLASS_COLUMN = "class"


# DATASET_NAME = "pablopalacios23/Iris"
# CLASS_COLUMN = "target"


# DEMASIADO GRANDE EL DATASET :/
# DATASET_NAME = "pablopalacios23/churn"
# CLASS_COLUMN = "Churn" 
 

# =======================


# load_data_general(DATASET_NAME, CLASS_COLUMN, partition_id=0, num_partitions=NUM_CLIENTS)

In [3]:
X_train, y_train, X_test, y_test, dataset, feature_names, label_encoder, scaler, numeric_features, encoder, preprocessor = load_data_general(
    DATASET_NAME, CLASS_COLUMN, partition_id=1, num_partitions=2
)

# Mostrar 5 primeros valores
print("\n📦 X_train (primeras filas):")
print(pd.DataFrame(X_train))

print("\n🎯 y_train (primeros valores):")
print(y_train)

print("\n📦 X_test (primeras filas):")
print(pd.DataFrame(X_test))

print("\n🎯 y_test (primeros valores):")
print(y_test)


2025-06-17 09:25:41,315 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): huggingface.co:443
2025-06-17 09:25:41,492 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /datasets/pablopalacios23/adult_small/resolve/main/README.md HTTP/11" 404 0
2025-06-17 09:25:41,823 urllib3.connectionpool DEBUG    https://huggingface.co:443 "GET /api/datasets/pablopalacios23/adult_small HTTP/11" 200 612
2025-06-17 09:25:41,950 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /datasets/pablopalacios23/adult_small/resolve/475f19aed5f80dea1d48deab705f11928fe27493/adult_small.py HTTP/11" 404 0
2025-06-17 09:25:41,954 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): s3.amazonaws.com:443
2025-06-17 09:25:42,261 urllib3.connectionpool DEBUG    https://s3.amazonaws.com:443 "HEAD /datasets.huggingface.co/datasets/datasets/pablopalacios23/adult_small/pablopalacios23/adult_small.py HTTP/11" 404 0
2025-06-17 09:25:42,386 urllib3.connectionpool DEBUG


📦 X_train (primeras filas):
           0         1    2    3    4    5    6    7    8    9
0  -1.789634  0.018802  2.0  3.0  2.0  4.0  1.0  2.0  0.0  2.0
1  -1.318677  0.808484  2.0  4.0  0.0  2.0  1.0  2.0  0.0  2.0
2  -0.753530 -0.112812  2.0  3.0  1.0  7.0  0.0  1.0  1.0  2.0
3   0.941912  1.203324  3.0  4.0  1.0  7.0  2.0  2.0  0.0  0.0
4   0.376765 -0.112812  2.0  4.0  1.0  8.0  0.0  2.0  1.0  2.0
5  -0.094191 -0.112812  2.0  1.0  0.0  3.0  1.0  2.0  0.0  2.0
6   1.789634 -1.757982  2.0  2.0  1.0  7.0  0.0  2.0  1.0  2.0
7   0.376765  1.861392  2.0  3.0  1.0  9.0  0.0  2.0  1.0  2.0
8   0.000000 -0.112812  1.0  2.0  1.0  6.0  0.0  2.0  1.0  2.0
9   0.470956  0.216222  2.0  2.0  0.0  5.0  1.0  2.0  1.0  2.0
10  1.695442 -2.218629  4.0  4.0  4.0  1.0  1.0  2.0  0.0  2.0

🎯 y_train (primeros valores):
[0 0 1 1 1 0 0 0 1 0 0]

📦 X_test (primeras filas):
          0         1    2    3    4    5    6    7    8    9
0 -0.470956 -0.112812  0.0  2.0  1.0  0.0  2.0  0.0  0.0  1.0
1 -0.470

In [5]:
import pandas as pd

def decode_X(X, preprocessor, numeric_features, encoder):
    num_scaler = preprocessor.named_transformers_["num"]
    cat_encoder = preprocessor.named_transformers_["cat"]

    # Separar partes numéricas y categóricas
    X_num = X[:, :len(numeric_features)]
    X_cat = X[:, len(numeric_features):]

    # Desescalar numéricas
    X_num_inv = num_scaler.inverse_transform(X_num)

    # Decodificar categóricas a índices enteros
    X_cat_inv = X_cat.astype(int)

    # Crear DataFrame
    df_num = pd.DataFrame(X_num_inv, columns=numeric_features)
    df_cat = pd.DataFrame(X_cat_inv, columns=list(encoder.dataset_descriptor["categorical"].keys()))

    # Mapear índices a valores originales
    for col in df_cat.columns:
        valores = encoder.dataset_descriptor["categorical"][col]["distinct_values"]
        df_cat[col] = df_cat[col].map(dict(enumerate(valores)))

    print(encoder.dataset_descriptor["categorical"]["occupation"]["distinct_values"])

    return pd.concat([df_num, df_cat], axis=1)


df_decoded = decode_X(X_train, preprocessor, numeric_features, encoder)
print(df_decoded.head())

[' Other-service', ' Farming-fishing', ' Sales', ' Tech-support', ' Handlers-cleaners', ' Transport-moving', ' Protective-serv', ' Prof-specialty', ' Craft-repair', ' Adm-clerical']
    age  hours-per-week          workclass   education       marital-status  \
0  27.0            42.0          Local-gov   Bachelors   Married-civ-spouse   
1  32.0            54.0          Local-gov         9th        Never-married   
2  38.0            40.0          Local-gov   Bachelors             Divorced   
3  56.0            60.0   Self-emp-not-inc         9th             Divorced   
4  50.0            40.0          Local-gov         9th             Divorced   

           occupation    relationship                 race      sex  \
0   Handlers-cleaners         Husband   Asian-Pac-Islander   Female   
1               Sales         Husband   Asian-Pac-Islander   Female   
2      Prof-specialty   Not-in-family                Black     Male   
3      Prof-specialty            Wife   Asian-Pac-Islander 

In [5]:
tree_model = DecisionTreeClassifier(max_depth=5, min_samples_split=2, random_state=42)

tree_model.fit(X_train, y_train)

In [6]:
from sklearn.tree import export_text

print(export_text(tree_model, feature_names=FEATURES))

|--- occupation <= 5.50
|   |--- class: 0
|--- occupation >  5.50
|   |--- occupation <= 8.50
|   |   |--- age <= 1.37
|   |   |   |--- class: 1
|   |   |--- age >  1.37
|   |   |   |--- class: 0
|   |--- occupation >  8.50
|   |   |--- class: 0



In [7]:
import re

def print_tree_human_readable(tree, feature_names, numeric_features, scaler, encoder):
    tree_text = export_text(tree, feature_names=feature_names)
    lines = tree_text.splitlines()

    for line in lines:
        stripped = line.strip()
        if "class:" in stripped:
            print(line)
            continue

        for feature in feature_names:
            if feature in stripped:
                match = re.search(rf"{re.escape(feature)}\s*(<=|<|>|>=)\s*([\d\.]+)", stripped)
                if match:
                    op, val = match.groups()
                    if feature in numeric_features:
                        idx = numeric_features.index(feature)
                        real_val = float(val) * scaler.scale_[idx] + scaler.mean_[idx]
                        new_condition = f'{feature} {op} {real_val:.2f}'
                    else:
                        val = int(float(val))
                        valores = encoder.dataset_descriptor["categorical"][feature]["distinct_values"]
                        real_cat = valores[val] if val < len(valores) else f"[desconocido ({val})]"
                        new_condition = f'{feature} {op} "{real_cat}"'
                    
                    # Reemplazar solo la condición exacta
                    line = re.sub(rf"{re.escape(feature)}\s*(<=|<|>|>=)\s*[\d\.]+", new_condition, line)
                print(line)
                break


In [8]:
print_tree_human_readable(tree_model, FEATURES, numeric_features, scaler, encoder)

|--- occupation <= " Transport-moving"
|   |--- class: 0
|--- occupation > " Transport-moving"
|   |--- occupation <= " Craft-repair"
|   |   |--- age <= 60.54
|   |   |   |--- class: 1
|   |   |--- age > 60.54
|   |   |   |--- class: 0
|   |--- occupation > " Craft-repair"
|   |   |--- class: 0


In [9]:
supertree = SuperTree()
root_node = supertree.rec_buildTree(tree_model, list(range(X_train.shape[1])), 2)
round_number = "1"

from graphviz import Digraph
import numpy as np
import os

def save_tree_image(root_node, round_number, feature_names, numeric_features, scaler, unique_labels, encoder, output_folder="Arbol_Local"):
    dot = Digraph()
    node_id = [0]

    def base_name(feat):
        return feat.split('=')[0] if '=' in feat else feat

    def add_node(node, parent=None, edge_label=""):
        curr = str(node_id[0])
        node_id[0] += 1

        # Etiqueta del nodo
        if node.is_leaf:
            class_index = np.argmax(node.labels)
            class_label = unique_labels[class_index]
            label = f"class: {class_label}\n{node.labels}"
        else:
            try:
                fname = feature_names[node.feat]
                label = base_name(fname)
            except:
                label = f"X_{node.feat}"

        dot.node(curr, label)
        if parent:
            dot.edge(parent, curr, label=edge_label)

        # Árbol tipo SuperTree
        if hasattr(node, "children") and node.children is not None and hasattr(node, "intervals"):
            for i, child in enumerate(node.children):
                try:
                    fname = feature_names[node.feat]
                except:
                    fname = f"X_{node.feat}"

                original_feat = base_name(fname)
                if original_feat in encoder.dataset_descriptor["categorical"]:
                    val_idx = node.intervals[i] if i == 0 else node.intervals[i - 1]
                    val_idx = int(val_idx)
                    val = encoder.dataset_descriptor["categorical"][original_feat]["distinct_values"][val_idx] if val_idx < len(encoder.dataset_descriptor["categorical"][original_feat]["distinct_values"]) else f"desconocido({val_idx})"
                    edge = f'≠ "{val}"' if i == 0 else f'= "{val}"'
                elif original_feat in numeric_features:
                    idx = numeric_features.index(original_feat)
                    mean = scaler.mean_[idx]
                    std = scaler.scale_[idx]
                    val = node.intervals[i] if i == 0 else node.intervals[i - 1]
                    val = val * std + mean
                    edge = f"<= {val:.2f}" if i == 0 else f"> {val:.2f}"
                else:
                    edge = "?"

                add_node(child, curr, edge)

        elif hasattr(node, "_left_child") or hasattr(node, "_right_child"):
            try:
                fname = feature_names[node.feat]
            except:
                fname = f"X_{node.feat}"

            original_feat = base_name(fname)
            if original_feat in encoder.dataset_descriptor["categorical"]:
                val_idx = int(node.thresh)
                val = encoder.dataset_descriptor["categorical"][original_feat]["distinct_values"][val_idx] if val_idx < len(encoder.dataset_descriptor["categorical"][original_feat]["distinct_values"]) else f"desconocido({val_idx})"
                left_label = f'≠ "{val}"'
                right_label = f'= "{val}"'
            elif original_feat in numeric_features:
                idx = numeric_features.index(original_feat)
                mean = scaler.mean_[idx]
                std = scaler.scale_[idx]
                thresh = node.thresh * std + mean
                left_label = f"<= {thresh:.2f}"
                right_label = f"> {thresh:.2f}"
            else:
                left_label = "≤ ?"
                right_label = "> ?"

            if node._left_child:
                add_node(node._left_child, curr, left_label)
            if node._right_child:
                add_node(node._right_child, curr, right_label)

    add_node(root_node)
    folder = f"Ronda_{round_number}/{output_folder}"
    os.makedirs(folder, exist_ok=True)
    filepath = f"{folder}/arbol_ronda_{round_number}"
    dot.render(filepath, format="png", cleanup=True)
    return f"{filepath}.png"

# Example call:
save_tree_image(
    root_node,
    round_number="1",
    feature_names=FEATURES,
    numeric_features=numeric_features,
    scaler=scaler,
    unique_labels=UNIQUE_LABELS,
    encoder=encoder
)


2025-06-16 09:53:08,681 graphviz._tools DEBUG    os.makedirs('Ronda_1/Arbol_Local')
2025-06-16 09:53:08,682 graphviz.saving DEBUG    write lines to 'Ronda_1/Arbol_Local/arbol_ronda_1'
2025-06-16 09:53:08,682 graphviz.backend.execute DEBUG    run [WindowsPath('dot'), '-Kdot', '-Tpng', '-O', 'arbol_ronda_1']


2025-06-16 09:53:08,786 graphviz.rendering DEBUG    delete 'Ronda_1/Arbol_Local/arbol_ronda_1'


'Ronda_1/Arbol_Local/arbol_ronda_1.png'