Preparación del entorno

In [9]:
import numpy as np
import joblib
from envs.multi_step_attack_env import AttackEnvMultiStep
from stable_baselines3.common.callbacks import BaseCallback
from envs.defender_env import DefenderEnv
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3 import PPO, SAC

Funciones auxiliares

In [13]:
class DefenderLoggerCallback(BaseCallback):
    """Stable-Baselines3 callback that logs windowed episode metrics for the defender.

    For every finished episode it computes:
      * the total reward,
      * the attack detection rate (attack steps with ``pred == 1`` / attack steps),
      * the true-negative rate (normal steps with ``pred == 0`` / normal steps),
      * the mean threshold over the steps that reported one.

    Every ``log_freq`` finished episodes, the means over that window are printed
    and appended to the ``history_*`` lists, and the window buffers are cleared.

    Per-step data is read from the ``info`` dicts of the environment (keys
    ``"y_sample"``, ``"pred"``, ``"threshold"``); missing keys are skipped.
    Episode counters are kept separately for each VecEnv sub-environment, so the
    metrics remain correct when training with more than one environment.

    Args:
        log_freq: Number of finished episodes per logging window (must be > 0).
        verbose: Verbosity level forwarded to ``BaseCallback``.

    Raises:
        ValueError: If ``log_freq`` is not a positive integer.
    """

    def __init__(self, log_freq: int = 100, verbose: int = 0):
        if log_freq <= 0:
            # Fail fast instead of hitting a ZeroDivisionError on the first episode.
            raise ValueError(f"log_freq must be a positive integer, got {log_freq}")
        super().__init__(verbose)
        self.log_freq = log_freq
        self.episode_count = 0

        # Per-window buffers (one entry per finished episode)
        self.rewards_buffer = []
        self.det_rate_buffer = []
        self.tn_rate_buffer = []
        self.threshold_buffer = []

        # History (one entry per logging window)
        self.history_episodes = []
        self.history_mean_reward = []
        self.history_det_rate = []
        self.history_tn_rate = []
        self.history_mean_threshold = []

        # In-progress episode statistics, one set per sub-environment
        self._reset_ep_stats()

    @staticmethod
    def _new_ep_stats() -> dict:
        """Return a fresh, zeroed set of per-episode counters."""
        return {
            "reward": 0.0,
            "attacks": 0,
            "attacks_correct": 0,
            "normals": 0,
            "normals_correct": 0,
            "thresh_sum": 0.0,
            "thresh_count": 0,
        }

    def _reset_ep_stats(self):
        """Drop all in-progress episode statistics (for every sub-environment)."""
        # One counter dict per VecEnv sub-environment. The list grows lazily in
        # _on_step because the number of sub-environments is only known there.
        self._env_stats = []

    def _finish_episode(self, stats: dict) -> None:
        """Convert one finished episode's counters into metrics and buffer them."""
        self.episode_count += 1

        det_rate = (stats["attacks_correct"] / stats["attacks"]) if stats["attacks"] > 0 else 0.0
        tn_rate = (stats["normals_correct"] / stats["normals"]) if stats["normals"] > 0 else 0.0
        # Average only over the steps that actually reported a threshold.
        mean_th = (stats["thresh_sum"] / stats["thresh_count"]) if stats["thresh_count"] > 0 else 0.0

        self.rewards_buffer.append(stats["reward"])
        self.det_rate_buffer.append(det_rate)
        self.tn_rate_buffer.append(tn_rate)
        self.threshold_buffer.append(mean_th)

        if self.episode_count % self.log_freq == 0:
            self._log_window()

    def _log_window(self) -> None:
        """Store and print the means of the current window, then clear its buffers."""
        mean_reward = np.mean(self.rewards_buffer)
        mean_det = np.mean(self.det_rate_buffer)
        mean_tn = np.mean(self.tn_rate_buffer)
        mean_th = np.mean(self.threshold_buffer)

        self.history_episodes.append(self.episode_count)
        self.history_mean_reward.append(mean_reward)
        self.history_det_rate.append(mean_det)
        self.history_tn_rate.append(mean_tn)
        self.history_mean_threshold.append(mean_th)

        print(
            f"[DefenderLogger] Episodios: {self.episode_count} | "
            f"Reward medio: {mean_reward:.3f} | "
            f"Detección ataques: {mean_det*100:.1f}% | "
            f"Acierto normales: {mean_tn*100:.1f}% | "
            f"Threshold medio: {mean_th:.3f}"
        )

        self.rewards_buffer.clear()
        self.det_rate_buffer.clear()
        self.tn_rate_buffer.clear()
        self.threshold_buffer.clear()

    def _on_step(self) -> bool:
        """Accumulate per-step stats for each sub-environment and flush finished episodes.

        Returns:
            Always True (this callback never requests early stopping).
        """
        infos = self.locals["infos"]
        rewards = self.locals["rewards"]
        dones = self.locals["dones"]

        # Make sure every sub-environment has its own counters.
        while len(self._env_stats) < len(infos):
            self._env_stats.append(self._new_ep_stats())

        for env_idx, (info, r, done) in enumerate(zip(infos, rewards, dones)):
            stats = self._env_stats[env_idx]
            stats["reward"] += float(r)

            y = info.get("y_sample")
            pred = info.get("pred")
            th = info.get("threshold")

            if y is not None and pred is not None:
                if y == 1:
                    stats["attacks"] += 1
                    if pred == 1:
                        stats["attacks_correct"] += 1
                else:
                    stats["normals"] += 1
                    if pred == 0:
                        stats["normals_correct"] += 1

            if th is not None:
                stats["thresh_sum"] += float(th)
                stats["thresh_count"] += 1

            if done:
                self._finish_episode(stats)
                # Start a clean episode for this sub-environment only.
                self._env_stats[env_idx] = self._new_ep_stats()

        return True


Creación del área de trabajo

In [None]:
# Load the synthetic 2-D dataset. Using np.load as a context manager closes the
# underlying .npz archive once the arrays have been read; otherwise the NpzFile
# keeps the file handle open for the lifetime of the kernel.
with np.load("../data/synthetic_2d.npz") as data:
    X_train, X_test = data["X_train"], data["X_test"]
    y_train, y_test = data["y_train"], data["y_test"]

# Split the training set into normal (label 0) and attack (label 1) samples.
normal_samples = X_train[y_train == 0]
attack_samples = X_train[y_train == 1]

# Show the shapes of the two sample sets.
print("Normal:", normal_samples.shape, "Attack:", attack_samples.shape)

Normal: (800, 2) Attack: (800, 2)


In [None]:
# Load the classical classifier that the attacker tries to evade.
# NOTE(review): joblib.load unpickles the file; only load trusted artifacts.
clf = joblib.load("../classifiers/logreg_synthetic_2d.joblib")

# AttackEnvMultiStep configured the same way as when the SAC attacker was trained.
# `threshold` only affects the attacker's reward, not the defender.
attack_env_kwargs = dict(
    threshold=0.5,
    epsilon=0.7,
    penalty=0.01,
    max_steps=5,
)
attack_env = AttackEnvMultiStep(attack_samples=attack_samples, clf=clf, **attack_env_kwargs)

# Load the pre-trained SAC attacker policy.
attacker_model = SAC.load("../agents/sac_attacker_multistep")

In [None]:
def make_defender_env():
    """Zero-argument factory returning a new DefenderEnv for the defender's VecEnv.

    The env reuses the module-level ``normal_samples``, ``attack_env`` and
    ``attacker_model`` objects. They are referenced, not copied, so every env
    built here shares the same ``attack_env`` instance.
    """
    env_config = {
        "normal_samples": normal_samples,
        "attack_env": attack_env,
        "attacker_model": attacker_model,
        "init_threshold": 0.5,
        "delta_max": 0.1,
        "min_threshold": 0.05,
        "max_threshold": 0.95,
        "attack_prob": 0.5,  # half normal / half attack
        "episode_length": 50,
        "extremal_penalty": 0.1,
    }
    return DefenderEnv(**env_config)


# DummyVecEnv takes a list of env factories; here a single defender env.
venv_def = DummyVecEnv([make_defender_env])


In [None]:
# Fixed seed so that "Restart & Run All" reproduces the training run.
# PPO(seed=...) seeds python/numpy/torch and the VecEnv. Whether DefenderEnv's
# internal sampling and the SAC attacker are fully covered depends on how they
# draw random numbers. TODO confirm.
DEFENDER_SEED = 0

# Policy/value network: two hidden layers of 128 units.
policy_kwargs = dict(net_arch=[128, 128])

defender_model = PPO(
    "MlpPolicy",
    venv_def,
    verbose=0,
    learning_rate=3e-4,
    n_steps=2048,
    batch_size=64,
    policy_kwargs=policy_kwargs,
    seed=DEFENDER_SEED,
)

# Print windowed defender metrics every 100 finished episodes.
callback = DefenderLoggerCallback(log_freq=100)

defender_model.learn(total_timesteps=100_000, callback=callback)
defender_model.save("../agents/ppo_defender_threshold")


[DefenderLogger] Episodios: 100 | Reward medio: -39.643 | Detección ataques: 1.2% | Acierto normales: 99.2% | Threshold medio: 0.484
[DefenderLogger] Episodios: 200 | Reward medio: -38.690 | Detección ataques: 1.2% | Acierto normales: 99.0% | Threshold medio: 0.447
[DefenderLogger] Episodios: 300 | Reward medio: -41.617 | Detección ataques: 1.5% | Acierto normales: 98.9% | Threshold medio: 0.466
[DefenderLogger] Episodios: 400 | Reward medio: -40.107 | Detección ataques: 0.7% | Acierto normales: 98.9% | Threshold medio: 0.526
[DefenderLogger] Episodios: 500 | Reward medio: -38.919 | Detección ataques: 0.9% | Acierto normales: 99.4% | Threshold medio: 0.596
[DefenderLogger] Episodios: 600 | Reward medio: -39.866 | Detección ataques: 0.6% | Acierto normales: 99.4% | Threshold medio: 0.648
[DefenderLogger] Episodios: 700 | Reward medio: -38.708 | Detección ataques: 0.7% | Acierto normales: 99.1% | Threshold medio: 0.629


In [6]:
def evaluate_defender(model, env, n_episodes=200):
    """Run the defender policy deterministically and measure its detection quality.

    Steps the VecEnv ``env`` for ``n_episodes`` episodes using
    ``model.predict(obs, deterministic=True)``. Only the first sub-environment's
    info dict is inspected. For each step, ``info["y_sample"] == 1`` marks an
    attack sample (anything else counts as normal). ``info["pred"]`` is the
    defender's decision.

    An episode ends when the VecEnv reports ``done`` for env 0, or when
    ``info["step_count"]`` reaches the env's ``episode_length`` attribute.

    Args:
        model: Policy exposing an SB3-style ``predict(obs, deterministic=...)``.
        env: VecEnv exposing ``reset``, ``step`` and ``get_attr``.
        n_episodes: Number of episodes to evaluate.

    Returns:
        Tuple ``(det_rate, tn_rate)``:
          * ``det_rate`` is the fraction of attack steps predicted as 1.
          * ``tn_rate`` is the fraction of normal steps predicted as 0.
        Each rate is 0.0 when no step of that kind was seen.
    """
    correct_attacks = 0
    total_attacks = 0
    correct_normals = 0
    total_normals = 0

    # Loop-invariant: query the episode length once rather than on every step.
    episode_length = env.get_attr("episode_length")[0]

    obs = env.reset()
    for _ in range(n_episodes):
        episode_over = False
        while not episode_over:
            action, _ = model.predict(obs, deterministic=True)
            obs, _reward, dones, infos = env.step(action)
            info0 = infos[0]

            y = info0["y_sample"]
            pred = info0["pred"]

            if y == 1:
                total_attacks += 1
                if pred == 1:
                    correct_attacks += 1
            else:
                total_normals += 1
                if pred == 0:
                    correct_normals += 1

            # Index env 0 explicitly. The truth value of the whole `dones` array
            # is ambiguous (ValueError) as soon as the VecEnv has more than one env.
            episode_over = bool(dones[0]) or info0.get("step_count", 0) >= episode_length

        obs = env.reset()

    det_rate = correct_attacks / total_attacks if total_attacks > 0 else 0.0
    tn_rate = correct_normals / total_normals if total_normals > 0 else 0.0
    return det_rate, tn_rate

# Evaluate the trained defender with a deterministic policy on `venv_def`.
# NOTE(review): `venv_def` is built from X_train samples; X_test/y_test are not
# used by the cells above, so these rates are in-sample. Consider a separate
# evaluation env built from the test split.
det_rate, tn_rate = evaluate_defender(model=defender_model, env=venv_def, n_episodes=200)

print(
    f"Tasa detección ataques: {det_rate:.3f}",
    f"Tasa acierto en tráfico normal: {tn_rate:.3f}",
    sep="\n",
)


Tasa detección ataques: 0.008
Tasa acierto en tráfico normal: 0.992
