In [15]:
import os
import numpy as np
import joblib
from stable_baselines3 import SAC
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
from stable_baselines3.common.callbacks import BaseCallback
from envs.v2.defender_env_3 import DefenderEnvV3
from envs.v2.multi_step_limited_attack_env import AttackEnvLimitedMultiStep

In [16]:
import numpy as np
from stable_baselines3.common.callbacks import BaseCallback


class DefenderLoggerCallback(BaseCallback):
    def __init__(self, log_freq: int = 1000, verbose: int = 0):
        super().__init__(verbose)
        self.log_freq = log_freq
        self.episode_count = 0

        # Buffers por ventana
        self.rewards_buffer = []            # Recompensas de los pasos
        self.threshold_buffer = []          # Threshold global
        self.p_attack_buffer = []           # p_attack del clasificador base
        self.adv_dist_buffer = []           # Distancia adversarial (||x_adv - x_orig||)
        self.adv_prog_buffer = []           # Progreso del ataque (step/max_steps)
        self.dist_center_buffer = []        # Distancia normalizada al centro del cluster normal
        self.cluster_score_buffer = []      # 1/(1+dist_center_norm)

        # Contadores para TPR / FPR
        self.attack_count = 0               # Nº muestras de ataque
        self.tp_count = 0                   # Ataques bien detectados
        self.normal_count = 0               # Nº muestras normales
        self.fp_count = 0                   # Falsos positivos

        # Históricos para graficar luego
        self.history_episodes = []
        self.history_mean_reward = []
        self.history_detection_rate = []      # TPR
        self.history_false_positive_rate = [] # FPR
        self.history_mean_threshold = []
        self.history_mean_p_attack = []
        self.history_mean_adv_dist = []
        self.history_mean_adv_prog = []
        self.history_mean_dist_center = []
        self.history_mean_cluster_score = []

    def _on_step(self) -> bool:

        infos = self.locals["infos"]
        rewards = self.locals["rewards"]
        dones = self.locals["dones"]

        for info, r, done in zip(infos, rewards, dones):

            # Reward del paso
            self.rewards_buffer.append(float(r))

            if isinstance(info, dict):

                # Threshold global
                if "threshold" in info:
                    self.threshold_buffer.append(float(info["threshold"]))

                # Probabilidad de ataque del clasificador
                if "p_attack" in info:
                    self.p_attack_buffer.append(float(info["p_attack"]))

                # Extra de V3: distancia adversarial, progreso, etc.
                if "adv_distance" in info:
                    self.adv_dist_buffer.append(float(info["adv_distance"]))

                if "adv_progress" in info:
                    self.adv_prog_buffer.append(float(info["adv_progress"]))

                if "dist_center_norm" in info:
                    self.dist_center_buffer.append(float(info["dist_center_norm"]))

                if "cluster_score" in info:
                    self.cluster_score_buffer.append(float(info["cluster_score"]))

                # Etiqueta real y predicción del defensor
                if "y_sample" in info and "pred" in info:
                    y = int(info["y_sample"])   # 0 = normal, 1 = ataque
                    pred = int(info["pred"])    # 0 = normal, 1 = ataque

                    if y == 1:
                        self.attack_count += 1
                        if pred == 1:
                            self.tp_count += 1
                    else:
                        self.normal_count += 1
                        if pred == 1:
                            self.fp_count += 1

            # Fin de episodio
            if done:
                self.episode_count += 1

                if self.episode_count > 0 and self.episode_count % self.log_freq == 0:

                    # Reward media por paso en la ventana
                    mean_reward = np.mean(self.rewards_buffer) if self.rewards_buffer else 0.0

                    # TPR/FPR
                    detection_rate = (self.tp_count / self.attack_count) if self.attack_count > 0 else 0.0
                    false_positive_rate = (self.fp_count / self.normal_count) if self.normal_count > 0 else 0.0

                    # Stats de threshold y p_attack
                    mean_threshold = np.mean(self.threshold_buffer) if self.threshold_buffer else 0.0
                    mean_p_attack = np.mean(self.p_attack_buffer) if self.p_attack_buffer else 0.0

                    # Stats extra V3
                    mean_adv_dist = np.mean(self.adv_dist_buffer) if self.adv_dist_buffer else 0.0
                    mean_adv_prog = np.mean(self.adv_prog_buffer) if self.adv_prog_buffer else 0.0
                    mean_dist_center = np.mean(self.dist_center_buffer) if self.dist_center_buffer else 0.0
                    mean_cluster_score = np.mean(self.cluster_score_buffer) if self.cluster_score_buffer else 0.0

                    # Guardamos históricos
                    self.history_episodes.append(self.episode_count)
                    self.history_mean_reward.append(mean_reward)
                    self.history_detection_rate.append(detection_rate)
                    self.history_false_positive_rate.append(false_positive_rate)
                    self.history_mean_threshold.append(mean_threshold)
                    self.history_mean_p_attack.append(mean_p_attack)
                    self.history_mean_adv_dist.append(mean_adv_dist)
                    self.history_mean_adv_prog.append(mean_adv_prog)
                    self.history_mean_dist_center.append(mean_dist_center)
                    self.history_mean_cluster_score.append(mean_cluster_score)

                    # Log por consola
                    print(
                        f"[DefenderLoggerV3] Episodios: {self.episode_count:6d} | "
                        f"Reward media: {mean_reward: .3f} | "
                        f"TPR: {detection_rate*100:5.1f}% | "
                        f"FPR: {false_positive_rate*100:5.1f}% | "
                        f"th: {mean_threshold: .3f} | "
                        f"p_attack: {mean_p_attack: .3f} | "
                        f"adv_dist: {mean_adv_dist: .3f} | "
                        f"adv_prog: {mean_adv_prog: .3f} | "
                        f"dist_centro: {mean_dist_center: .3f}"
                    )

                    # Reset buffers ventana
                    self.rewards_buffer.clear()
                    self.threshold_buffer.clear()
                    self.p_attack_buffer.clear()
                    self.adv_dist_buffer.clear()
                    self.adv_prog_buffer.clear()
                    self.dist_center_buffer.clear()
                    self.cluster_score_buffer.clear()
                    self.attack_count = 0
                    self.tp_count = 0
                    self.normal_count = 0
                    self.fp_count = 0

        return True


In [17]:
# ----------------------------------------------------------
# 2. Cargar dataset y atacante
# ----------------------------------------------------------

data = np.load("../../data/v2/synthetic_2d.npz")
X_train, X_test = data["X_train"], data["X_test"]
y_train, _ = data["y_train"], data["y_test"]

normal_samples = X_train[y_train == 0]
attack_samples = X_train[y_train == 1]

clf = joblib.load("../../classifiers/v2/logreg_synthetic_2d.joblib")

attack_env = AttackEnvLimitedMultiStep(
    attack_samples=attack_samples,
    clf=clf,
    threshold=0.5,
    epsilon=0.7,
    penalty=0.01,
    max_steps=5,
)

attacker_model = SAC.load("../../agents/v2/sac_attacker_limited_multistep")

In [18]:
# ----------------------------------------------------------
# 3. Crear entorno DEFENSOR V3
# ----------------------------------------------------------

def make_defender_env():
    return DefenderEnvV3(
        normal_samples=normal_samples,
        attack_env=attack_env,
        attacker_model=attacker_model,
        init_threshold=0.5,
        delta_max=0.05,
        min_threshold=0.2,
        max_threshold=0.8,
        attack_prob=0.3,
        episode_length=80,
        tp_reward=3.0,
        fn_penalty=-6.0,
        tn_reward=1.0,
        fp_penalty=-2.0,
        move_penalty=0.1,
        sensitivity_penalty=0.05,
        extreme_penalty=0.5,
    )


venv_def = DummyVecEnv([make_defender_env])

# Normalizar observaciones
venv_def = VecNormalize(venv_def, norm_obs=True, norm_reward=False)

In [19]:
# ----------------------------------------------------------
# 4. SAC para el DEFENSOR V3
# ----------------------------------------------------------

policy_kwargs = dict(net_arch=[256, 256])

model_defender = SAC(
    "MlpPolicy",
    venv_def,
    learning_rate=3e-4,
    buffer_size=400_000,
    batch_size=256,
    tau=0.02,
    train_freq=1,
    gradient_steps=1,
    gamma=0.99,
    policy_kwargs=policy_kwargs,
    verbose=0,
)

In [20]:
# ----------------------------------------------------------
# 5. Entrenamiento
# ----------------------------------------------------------
callback = DefenderLoggerCallback(log_freq=100)

model_defender.learn(total_timesteps=100_000, callback=callback)

[DefenderLoggerV3] Episodios:    100 | Reward media:  0.734 | TPR:  27.4% | FPR:  28.0% | th:  0.400 | p_attack:  0.255 | adv_dist:  0.000 | adv_prog:  0.000 | dist_centro:  0.000
[DefenderLoggerV3] Episodios:    200 | Reward media:  0.908 | TPR:  29.3% | FPR:  29.7% | th:  0.347 | p_attack:  0.255 | adv_dist:  0.000 | adv_prog:  0.000 | dist_centro:  0.000
[DefenderLoggerV3] Episodios:    300 | Reward media:  0.937 | TPR:  29.5% | FPR:  31.6% | th:  0.347 | p_attack:  0.261 | adv_dist:  0.000 | adv_prog:  0.000 | dist_centro:  0.000
[DefenderLoggerV3] Episodios:    400 | Reward media:  0.957 | TPR:  29.6% | FPR:  30.3% | th:  0.336 | p_attack:  0.254 | adv_dist:  0.000 | adv_prog:  0.000 | dist_centro:  0.000
[DefenderLoggerV3] Episodios:    500 | Reward media:  0.991 | TPR:  30.8% | FPR:  30.1% | th:  0.336 | p_attack:  0.254 | adv_dist:  0.000 | adv_prog:  0.000 | dist_centro:  0.000
[DefenderLoggerV3] Episodios:    600 | Reward media:  0.989 | TPR:  28.8% | FPR:  30.6% | th:  0.331

<stable_baselines3.sac.sac.SAC at 0x1510586aad0>

In [None]:
# ----------------------------------------------------------
# 6. Guardar modelo + normalizador
# ----------------------------------------------------------
os.makedirs("../../agents/v2/", exist_ok=True)

model_defender.save("../../agents/v2/sac_defender_3")
venv_def.save("../../agents/v2/sac_defender_3_vecnorm.pkl")

print("Modelo V3 guardado correctamente.")

Modelo V3 guardado correctamente.


In [24]:
# ----------------------------------------------------------
# 7. Evaluación rápida
# ----------------------------------------------------------

# Preparar entorno para evaluación
eval_env = DummyVecEnv([make_defender_env])
eval_env = VecNormalize.load("../../agents/v2/sac_defender_3_vecnorm.pkl", eval_env)
eval_env.training = False
eval_env.norm_reward = False

obs = eval_env.reset()

TP = 0
FP = 0
TA = 0  # ataques totales
TN = 0
NO = 0  # normales totales

for _ in range(200):
    action, _ = model_defender.predict(obs, deterministic=True)
    obs, reward, dones, infos = eval_env.step(action)

    info = infos[0]
    y = info["y_sample"]
    pred = info["pred"]

    if y == 1:
        TA += 1
        TP += 1 if pred == 1 else 0
    else:
        NO += 1
        FP += 1 if pred == 1 else 0
        TN += 1 if pred == 0 else 0

TPR = TP / TA if TA > 0 else 0
FPR = FP / NO if NO > 0 else 0

print("\n===== Evaluación rápida =====")
print(f"TPR: {TPR*100:.2f}%")
print(f"FPR: {FPR*100:.2f}%")


===== Evaluación rápida =====
TPR: 31.25%
FPR: 30.15%
