Preparación del entorno

In [3]:
import numpy as np
import joblib
from envs.v2.multi_step_limited_attack_env import AttackEnvLimitedMultiStep
from stable_baselines3.common.callbacks import BaseCallback
from envs.v2.defender_env import DefenderEnv
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3 import PPO, SAC

Funciones auxiliares

In [None]:
class DefenderLoggerCallback(BaseCallback):
    """Stable-Baselines3 callback that tracks defender metrics per episode.

    On every training step it reads ``y_sample`` (true label), ``pred``
    (defender prediction) and ``threshold`` from each sub-environment's
    ``info`` dict and accumulates them for the episode in progress. When an
    episode ends, its total reward, attack detection rate, normal-traffic
    accuracy and mean threshold are pushed into window buffers. Every
    ``log_freq`` finished episodes the window means are appended to the
    ``history_*`` lists, printed, and the buffers are cleared.

    Episode accumulators are kept per sub-environment index, so metrics stay
    correct when the vectorized env has more than one sub-environment.

    Args:
        log_freq: number of finished episodes per logging window (>= 1).
        verbose: verbosity level forwarded to ``BaseCallback``.

    Raises:
        ValueError: if ``log_freq`` is smaller than 1.
    """

    def __init__(self, log_freq: int = 100, verbose: int = 0):
        super().__init__(verbose)
        # A non-positive window would make `episode_count % log_freq` fail or never fire.
        if log_freq < 1:
            raise ValueError(f"log_freq must be >= 1, got {log_freq}")
        self.log_freq = log_freq
        self.episode_count = 0

        # Window buffers: one entry per finished episode, cleared after each log.
        self.rewards_buffer = []
        self.det_rate_buffer = []
        self.tn_rate_buffer = []
        self.threshold_buffer = []

        # History: one entry per logged window.
        self.history_episodes = []
        self.history_mean_reward = []
        self.history_det_rate = []
        self.history_tn_rate = []
        self.history_mean_threshold = []

        # In-progress episode accumulators, keyed by sub-environment index.
        self._ep_stats = {}
        self._reset_ep_stats()

    @staticmethod
    def _new_ep_stats() -> dict:
        """Return an empty accumulator for one in-progress episode."""
        return {
            "reward": 0.0,
            "steps": 0,
            "attacks": 0,
            "attacks_correct": 0,
            "normals": 0,
            "normals_correct": 0,
            "thresh_sum": 0.0,
            "thresh_count": 0,
        }

    def _reset_ep_stats(self, env_idx: int = 0) -> None:
        """Start a fresh episode accumulator for sub-environment ``env_idx``."""
        self._ep_stats[env_idx] = self._new_ep_stats()

    def _on_step(self) -> bool:
        """Accumulate this step's metrics for every sub-env; always returns True (keep training)."""
        infos = self.locals["infos"]
        rewards = self.locals["rewards"]
        dones = self.locals["dones"]

        for env_idx, (info, reward, done) in enumerate(zip(infos, rewards, dones)):
            if env_idx not in self._ep_stats:
                self._reset_ep_stats(env_idx)
            stats = self._ep_stats[env_idx]

            stats["reward"] += float(reward)
            stats["steps"] += 1

            y = info.get("y_sample")
            pred = info.get("pred")
            # Only score steps whose info carries both the label and the prediction.
            if y is not None and pred is not None:
                if y == 1:
                    stats["attacks"] += 1
                    if pred == 1:
                        stats["attacks_correct"] += 1
                else:
                    stats["normals"] += 1
                    if pred == 0:
                        stats["normals_correct"] += 1

            th = info.get("threshold")
            if th is not None:
                stats["thresh_sum"] += float(th)
                stats["thresh_count"] += 1

            if done:
                self._record_episode(stats)
                self._reset_ep_stats(env_idx)

        return True

    def _record_episode(self, stats: dict) -> None:
        """Push one finished episode into the window buffers; log when the window is full."""
        self.episode_count += 1

        # Rates default to 0.0 when the episode contained no sample of that class.
        det_rate = stats["attacks_correct"] / stats["attacks"] if stats["attacks"] > 0 else 0.0
        tn_rate = stats["normals_correct"] / stats["normals"] if stats["normals"] > 0 else 0.0
        # Average only over the steps that actually reported a threshold.
        ep_mean_th = (
            stats["thresh_sum"] / stats["thresh_count"] if stats["thresh_count"] > 0 else 0.0
        )

        self.rewards_buffer.append(stats["reward"])
        self.det_rate_buffer.append(det_rate)
        self.tn_rate_buffer.append(tn_rate)
        self.threshold_buffer.append(ep_mean_th)

        if self.episode_count % self.log_freq == 0:
            self._log_window()

    def _log_window(self) -> None:
        """Append the window means to the history, print them, and clear the window buffers."""
        mean_reward = np.mean(self.rewards_buffer)
        mean_det = np.mean(self.det_rate_buffer)
        mean_tn = np.mean(self.tn_rate_buffer)
        mean_th = np.mean(self.threshold_buffer)

        self.history_episodes.append(self.episode_count)
        self.history_mean_reward.append(mean_reward)
        self.history_det_rate.append(mean_det)
        self.history_tn_rate.append(mean_tn)
        self.history_mean_threshold.append(mean_th)

        print(
            f"[DefenderLogger] Episodios: {self.episode_count} | "
            f"Reward medio: {mean_reward:.3f} | "
            f"Detección ataques: {mean_det*100:.1f}% | "
            f"Acierto normales: {mean_tn*100:.1f}% | "
            f"Threshold medio: {mean_th:.3f}"
        )

        self.rewards_buffer.clear()
        self.det_rate_buffer.clear()
        self.tn_rate_buffer.clear()
        self.threshold_buffer.clear()

Creación del área de trabajo: carga de datos, clasificador, atacante y entorno del defensor

In [6]:
# Load the synthetic 2-D dataset. Using np.load as a context manager closes the
# underlying .npz file handle once the arrays have been read into memory.
with np.load("../../data/v2/synthetic_2d.npz") as data:
    X_train, X_test = data["X_train"], data["X_test"]
    y_train, y_test = data["y_train"], data["y_test"]

# Split the training set into normal (label 0) and attack (label 1) samples.
normal_samples = X_train[y_train == 0]
attack_samples = X_train[y_train == 1]

# Sanity checks: features/labels are aligned and both classes are present,
# because the attacker and defender environments are built from these subsets.
assert len(X_train) == len(y_train), "X_train / y_train length mismatch"
assert len(X_test) == len(y_test), "X_test / y_test length mismatch"
assert len(normal_samples) > 0 and len(attack_samples) > 0, "training split must contain both classes"

# Print the shapes of the per-class subsets
print("Normal:", normal_samples.shape, "Attack:", attack_samples.shape)

Normal: (800, 2) Attack: (800, 2)


In [7]:
# Classical classifier (a logistic regression, judging by the file name).
# NOTE: joblib.load unpickles the file, so only load trusted artifacts.
clf = joblib.load("../../classifiers/v2/logreg_synthetic_2d.joblib")

# Rebuild AttackEnvLimitedMultiStep with the same settings we used to train the SAC attacker.
attack_env = AttackEnvLimitedMultiStep(
    attack_samples=attack_samples,
    clf=clf,
    max_steps=5,
    epsilon=0.7,
    penalty=0.01,
    threshold=0.5,  # only affects the attacker's reward, not the defender
)

# Pretrained SAC attacker policy
attacker_model = SAC.load("../../agents/v2/sac_attacker_limited_multistep")

In [8]:
# Factory that builds the defender environment.
def make_defender_env(**overrides):
    """Build a ``DefenderEnv`` for the threshold-tuning defender.

    Uses the module-level ``normal_samples``, ``attack_env`` and
    ``attacker_model`` plus the default hyper-parameters below. Any keyword
    argument in ``overrides`` replaces the matching ``DefenderEnv`` argument,
    e.g. ``make_defender_env(attack_prob=0.2)``. Calling it with no
    arguments builds exactly the default training environment, so it can
    still be passed directly to ``DummyVecEnv``.

    Note: unless ``attack_env`` is overridden, every env returned shares the
    same ``attack_env`` object. Keep that in mind when vectorizing over
    several sub-environments.
    """
    env_kwargs = dict(
        normal_samples=normal_samples,
        attack_env=attack_env,
        attacker_model=attacker_model,
        init_threshold=0.5,
        delta_max=0.1,
        min_threshold=0.05,
        max_threshold=0.95,
        attack_prob=0.5,        # half normal / half attack
        episode_length=50,
        extremal_penalty=0.1,
    )
    env_kwargs.update(overrides)
    return DefenderEnv(**env_kwargs)

# Wrap the factory in a single-process vectorized env with one sub-environment.
venv_def = DummyVecEnv(env_fns=[make_defender_env])

In [9]:
# Defender network architecture: two hidden layers of 128 units.
policy_kwargs = dict(net_arch=[128, 128])

# Fixed seed for PPO's pseudo-random generators so training runs are reproducible.
# NOTE(review): full determinism also depends on how DefenderEnv and the SAC
# attacker handle seeding — TODO confirm.
DEFENDER_SEED = 0

defender_model = PPO(
    "MlpPolicy",
    venv_def,
    verbose=0,
    learning_rate=3e-4,
    n_steps=2048,   # environment steps collected per policy update
    batch_size=64,  # minibatch size each update is split into
    policy_kwargs=policy_kwargs,
    seed=DEFENDER_SEED,
)

# Custom callback that prints windowed defender metrics every 100 episodes
callback = DefenderLoggerCallback(log_freq=100)

# Train the defender
defender_model.learn(total_timesteps=100_000, callback=callback)

# Save the trained defender model
defender_model.save("../../agents/v2/ppo_defender_threshold")

[DefenderLogger] Episodios: 100 | Reward medio: 7.886 | Detección ataques: 39.3% | Acierto normales: 60.1% | Threshold medio: 0.417
[DefenderLogger] Episodios: 200 | Reward medio: 15.316 | Detección ataques: 47.8% | Acierto normales: 47.9% | Threshold medio: 0.220
[DefenderLogger] Episodios: 300 | Reward medio: 17.725 | Detección ataques: 52.0% | Acierto normales: 44.9% | Threshold medio: 0.149
[DefenderLogger] Episodios: 400 | Reward medio: 17.339 | Detección ataques: 55.4% | Acierto normales: 44.0% | Threshold medio: 0.138
[DefenderLogger] Episodios: 500 | Reward medio: 16.257 | Detección ataques: 53.2% | Acierto normales: 45.9% | Threshold medio: 0.143
[DefenderLogger] Episodios: 600 | Reward medio: 16.936 | Detección ataques: 53.9% | Acierto normales: 44.3% | Threshold medio: 0.154
[DefenderLogger] Episodios: 700 | Reward medio: 18.037 | Detección ataques: 55.2% | Acierto normales: 44.3% | Threshold medio: 0.140
[DefenderLogger] Episodios: 800 | Reward medio: 17.475 | Detección ata

In [None]:
def evaluate_defender(model, env, n_episodes=200):
    """Evaluate a trained defender using its deterministic policy.

    Runs ``n_episodes`` episodes on the first sub-environment of the
    vectorized ``env``. At each step it compares the true label
    ``info["y_sample"]`` with the defender's prediction ``info["pred"]``.
    Steps whose info lacks either key are skipped.

    Args:
        model: trained policy exposing ``predict(obs, deterministic=...)``.
        env: vectorized env exposing ``reset``, ``step`` and
            ``get_attr("episode_length")``. Only sub-env 0 is scored.
        n_episodes: number of episodes to run.

    Returns:
        ``(det_rate, tn_rate)``:
        - ``det_rate`` is the fraction of attack steps predicted as attack (pred == 1).
        - ``tn_rate`` is the fraction of normal steps predicted as normal (pred == 0).
        Either rate is 0.0 if no sample of that class was seen.
    """
    # Queried once rather than on every step. Used as a hard cap so the loop
    # terminates even if the env never reports done.
    max_steps = env.get_attr("episode_length")[0]

    correct_attacks = 0
    total_attacks = 0
    correct_normals = 0
    total_normals = 0

    for _ in range(n_episodes):
        obs = env.reset()
        for _step in range(max_steps):
            action, _ = model.predict(obs, deterministic=True)
            obs, _rewards, dones, infos = env.step(action)
            info0 = infos[0]

            y = info0.get("y_sample")
            pred = info0.get("pred")
            if y is not None and pred is not None:
                if y == 1:
                    total_attacks += 1
                    if pred == 1:
                        correct_attacks += 1
                else:
                    total_normals += 1
                    if pred == 0:
                        correct_normals += 1

            # VecEnv `dones` already folds termination and truncation together.
            if dones[0]:
                break

    det_rate = correct_attacks / total_attacks if total_attacks > 0 else 0.0
    tn_rate = correct_normals / total_normals if total_normals > 0 else 0.0
    return det_rate, tn_rate

# Score the trained defender over 200 episodes using deterministic policy actions.
# NOTE(review): both the defender env and the attacker env are built from the
# training split (X_train); X_test is not used in this evaluation.
det_rate, tn_rate = evaluate_defender(model=defender_model, env=venv_def, n_episodes=200)

print(
    f"Tasa detección ataques: {det_rate:.3f}\n"
    f"Tasa acierto en tráfico normal: {tn_rate:.3f}"
)


Tasa detección ataques: 0.008
Tasa acierto en tráfico normal: 0.992
