In [2]:
# -*- coding: utf-8 -*-
"""
OOP Synthetic Traffic Generator + large preprocessing (Pandas) + PySpark-equivalent pipeline
Run as script. If pyspark available, pyspark path will be shown and can be used.

REFRACTOR NOTES (Jan 2026)

Goals (without changing observable behavior: schema/labels/attack meaning):
- `produce()` is O(1), single-sample, no Python loops, no rejection/correction.
- Batch generation is vectorized and can scale to millions of rows.
- Clear separation:
  - Timestamp sequencing
  - Attack sampling
  - Feature sampling (attack semantics + data drift)
  - Oracle/labeling (normal vs concept drift decision)

Correctness guarantees checklist (by construction):
- Same feature schema + attack labels.
- Monotonic, block-based timestamps.
- Supports `trend in {'normal','data_drift','concept_drift'}`.
- Reproducible via the injected RNG (no global randomness).
- No rejection sampling and no "call oracle to fix samples".
"""

import random
import numpy as np
import pandas as pd
import xgboost as xgb
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.metrics import classification_report, accuracy_score
import warnings, json, os
warnings.filterwarnings("ignore")

# reproducibilidad
SEED = 42
RNG = np.random.RandomState(SEED)
random.seed(SEED)
np.random.seed(SEED)

# -------------------------
# Config global / schema
# -------------------------
SCHEMA = {
    'src_port': {'type': 'int', 'range': (1024, 65535)},
    'dst_port': {'type': 'int', 'range': (1, 65535)},
    'protocol': {'type': 'cat', 'vals': ['TCP', 'UDP', 'ICMP']},
    'packet_count': {'type': 'int', 'range': (1, 2000)},
    'conn_state': {'type': 'cat', 'vals': ['EST', 'SYN', 'FIN', 'RST']},
    'bytes_transferred': {'type': 'float', 'range': (0.0, 2e6)},
    'timestamp': {'type': 'time'}
}
BASE_FEATURES = list(SCHEMA.keys())

NUM_CLASSES = 6
ATTACK_LABELS = {0:'Normal',1:'DoS',2:'Probe',3:'R2L',4:'U2R',5:'Worm'}
ATTACK_TO_ID = {v:k for k,v in ATTACK_LABELS.items()}
ATTACK_PRIORS = {'Normal':0.942,'DoS':0.02,'Probe':0.01,'R2L':0.015,'U2R':0.008,'Worm':0.005}
ATTACK_NAMES = list(ATTACK_PRIORS.keys())
ATTACK_PROBS = list(ATTACK_PRIORS.values())

_PROTO_VALS = np.array(SCHEMA['protocol']['vals'], dtype=object)
_CONN_VALS = np.array(SCHEMA['conn_state']['vals'], dtype=object)


# -------------------------
# Fast categorical sampling (Alias method)
# -------------------------
class _AliasSampler:
    """O(1) discrete sampler after O(K) preprocessing."""

    def __init__(self, probs, rng: np.random.RandomState):
        p = np.asarray(probs, dtype=np.float64)
        if p.ndim != 1:
            raise ValueError("probs must be a 1D array")
        if np.any(p < 0):
            raise ValueError("probs must be non-negative")
        s = float(p.sum())
        if not np.isfinite(s) or s <= 0:
            raise ValueError("probs must sum to a positive finite value")
        p = p / s

        self.rng = rng
        self.K = int(p.size)
        self.prob = np.empty(self.K, dtype=np.float64)
        self.alias = np.empty(self.K, dtype=np.int32)

        scaled = p * self.K
        small = []
        large = []
        for i in range(self.K):
            (small if scaled[i] < 1.0 else large).append(i)

        while small and large:
            sidx = small.pop()
            lidx = large.pop()
            self.prob[sidx] = scaled[sidx]
            self.alias[sidx] = lidx
            scaled[lidx] = (scaled[lidx] + scaled[sidx]) - 1.0
            (small if scaled[lidx] < 1.0 else large).append(lidx)

        for idx in large:
            self.prob[idx] = 1.0
            self.alias[idx] = idx
        for idx in small:
            self.prob[idx] = 1.0
            self.alias[idx] = idx

    def sample_one(self) -> int:
        k = int(self.rng.randint(0, self.K))
        return k if float(self.rng.random()) < float(self.prob[k]) else int(self.alias[k])

    def sample_n(self, n: int) -> np.ndarray:
        n = int(n)
        k = self.rng.randint(0, self.K, size=n)
        u = self.rng.random(size=n)
        return np.where(u < self.prob[k], k, self.alias[k]).astype(np.int32, copy=False)


# -------------------------
# SyntheticTrafficGenerator (refactored, same external API)
# -------------------------
class SyntheticTrafficGenerator:
    """Rule-based traffic generator with explicit drift modes.

    External behavior preserved:
    - Same schema (src_port, dst_port, protocol, packet_count, conn_state, bytes_transferred, timestamp)
    - Same labels (0..5 with ATTACK_LABELS mapping)
    - Same `produce()` output JSON schema
    - Same `generate_dataset()` return shape/columns

    Drift semantics:
    - normal: baseline distributions + deterministic label = sampled attack id
    - data_drift: feature distributions shift; label semantics unchanged
    - concept_drift: features sampled like normal; label assigned by concept oracle (fX_concept)

    Performance:
    - `produce()` is O(1), has no loops and no rejection/correction.
    - `generate_dataset()` is vectorized.
    """

    def __init__(
        self,
        start_ts: str = "2026-01-12 18:00:00",
        epsilon_seconds: int = 60,
        rng: np.random.RandomState = None,
    ):
        self.start_ts = pd.to_datetime(start_ts)
        self.epsilon = int(epsilon_seconds)
        self.rng = rng if rng is not None else np.random.RandomState(None)
        self.n_step = 0

        # Attack sampling is O(1) per sample (no loops in produce).
        self._attack_sampler = _AliasSampler(ATTACK_PROBS, self.rng)

        # Precompute categorical encodings for vectorized paths.
        self._proto_to_idx = {v: i for i, v in enumerate(_PROTO_VALS.tolist())}
        self._conn_to_idx = {v: i for i, v in enumerate(_CONN_VALS.tolist())}

    # -------- Timestamp sequencing (monotonic, block-based) --------
    def next_timestamp(self):
        i = self.n_step
        block_start = self.start_ts + pd.Timedelta(seconds=i * self.epsilon)
        delta = float(self.rng.uniform(0, self.epsilon))
        ts = block_start + pd.Timedelta(seconds=delta)
        self.n_step += 1
        return ts.replace(microsecond=0)

    def next_timestamps(self, n: int) -> pd.DatetimeIndex:
        n = int(n)
        i = self.n_step + np.arange(n, dtype=np.int64)
        base = self.start_ts.to_datetime64() + (i * self.epsilon).astype('timedelta64[s]')
        delta = (self.rng.uniform(0, self.epsilon, size=n)).astype(np.int64)
        ts = base + delta.astype('timedelta64[s]')
        self.n_step += n
        return pd.to_datetime(ts).floor('S')

    def reset(self, start_ts: str = None, epsilon_seconds: int = None):
        if start_ts is not None:
            self.start_ts = pd.to_datetime(start_ts)
        if epsilon_seconds is not None:
            self.epsilon = int(epsilon_seconds)
        self.n_step = 0

    # -------- Latent structure helpers (explicit, human interpretable) --------
    @staticmethod
    def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    @staticmethod
    def cat_to_num(val, choices):
        idx = choices.index(val)
        return -1 + 2 * idx / (len(choices) - 1) if len(choices) > 1 else 0.0

    # -------- Oracle logic (kept compatible; used for concept drift + evaluation) --------
    def fX_normal(self, row):
        pc = float(row['packet_count']) / 2000.0
        bt = float(row['bytes_transferred']) / 2e6
        sp = float(row['src_port']) / 65535.0
        dp = float(row['dst_port']) / 65535.0
        proto = self.cat_to_num(row['protocol'], SCHEMA['protocol']['vals'])
        state = self.cat_to_num(row['conn_state'], SCHEMA['conn_state']['vals'])
        score = np.zeros(NUM_CLASSES)
        score[1] += 3.0 * pc**2 + 1.5 * self.sigmoid(proto - 0.5) + 0.5 * (1 - sp) + 0.3 * pc * proto
        score[2] += 1.5 * np.isin(row['dst_port'], [21, 22, 23, 80, 443]) + 2.0 * (pc * (1 - pc)) + 0.5 * dp * pc + 0.3 * dp * state
        score[3] += 2.0 * state * (1 - bt) + 0.5 * (1 - dp) + 0.3 * bt * state + 1.0 * (1 - pc) * state
        score[4] += 1.5 * bt**2 + 0.5 * state * sp + 0.3 * bt * pc + 0.5 * sp * (1 - dp)
        score[5] += 2.0 * bt * pc + 0.5 * (1 * (row['protocol'] != 'ICMP')) + 0.3 * sp * dp
        score[0] += 1.0 - abs(pc - 0.3) - abs(bt - 0.3) + 0.2 * (1 - abs(proto)) + 0.1 * (1 - abs(state))
        score += 0.02 * self.rng.randn(NUM_CLASSES)
        return int(np.argmax(score))

    def fX_concept(self, row):
        pc = float(row['packet_count']) / 2000.0
        bt = float(row['bytes_transferred']) / 2e6
        sp = float(row['src_port']) / 65535.0
        dp = float(row['dst_port']) / 65535.0
        proto = self.cat_to_num(row['protocol'], SCHEMA['protocol']['vals'])
        state = self.cat_to_num(row['conn_state'], SCHEMA['conn_state']['vals'])
        score = np.zeros(NUM_CLASSES)
        score[1] += 4.0 * bt + 0.8 * (row['protocol'] == 'ICMP') - 0.3 * pc
        src_mod = (row['src_port'] % 1000) / 1000.0
        score[2] += 3.0 * (src_mod < 0.2) + 1.0 * np.isin(row['dst_port'], [21, 22, 23, 80, 443])
        score[3] += 2.0 * (row['protocol'] == 'ICMP') + 1.0 * (0.05 < bt < 0.2)
        score[4] += 2.0 * (pc > 0.7) + 1.2 * bt
        score[5] += 1.5 * sp * dp + 0.7 * (0.2 < pc < 0.6)
        score[0] += 0.8 - 0.2 * (abs(pc - 0.35) + abs(bt - 0.35))
        score += 0.05 * self.rng.randn(NUM_CLASSES)
        return int(np.argmax(score))

    # -------- Feature sampling (attack semantics + drift; no rejection) --------
    def _sample_attack_id_one(self) -> int:
        return int(self._attack_sampler.sample_one())

    def _sample_attack_ids(self, n: int) -> np.ndarray:
        return self._attack_sampler.sample_n(n)

    def _baseline_features_one(self, trend: str):
        if trend in ('normal', 'concept_drift'):
            src_port = int(self.rng.randint(1024, 65535))
            dst_port = int(self.rng.randint(1, 65535))
            proto = str(self.rng.choice(['TCP', 'UDP']))
            packet_count = int(self.rng.randint(1, 400))
            conn_state = 'EST'
            bytes_transferred = float(self.rng.uniform(1e3, 5e5))
        else:  # data_drift
            # Moderate covariate shift: move normals towards heavier traffic and noisier states
            # but keep attack semantics unchanged (labels remain coherent).
            src_port = int(1024 + (45000 - 1024) * float(self.rng.beta(a=2, b=4)))
            dst_port = int(1 + (45000 - 1) * float(self.rng.beta(a=2, b=4)))
            proto = str(self.rng.choice(['ICMP', 'TCP', 'UDP'], p=[0.25, 0.40, 0.35]))
            packet_count = int(50 + (1200 - 50) * float(self.rng.beta(a=2, b=5)))
            conn_state = str(self.rng.choice(['EST', 'SYN', 'RST'], p=[0.55, 0.35, 0.10]))
            bytes_transferred = float(1e4 + (1.2e6 - 1e4) * float(self.rng.beta(a=2, b=5)))
        return src_port, dst_port, proto, packet_count, conn_state, bytes_transferred

    def _apply_attack_semantics_one(
        self,
        attack_id: int,
        src_port: int,
        dst_port: int,
        proto: str,
        packet_count: int,
        conn_state: str,
        bytes_transferred: float,
    ):
        # Deterministic-by-construction attack regions (no oracle correction).
        if attack_id == ATTACK_TO_ID['DoS']:
            packet_count = int(self.rng.randint(1200, 2000))
            proto = str(self.rng.choice(['ICMP', 'UDP']))
        elif attack_id == ATTACK_TO_ID['Probe']:
            dst_port = int(self.rng.choice([21, 22, 23, 80, 443]))
            packet_count = int(self.rng.randint(200, 800))
        elif attack_id == ATTACK_TO_ID['R2L']:
            conn_state = 'RST'
            bytes_transferred = float(self.rng.uniform(0, 200))
        elif attack_id == ATTACK_TO_ID['U2R']:
            bytes_transferred = float(self.rng.uniform(8e5, 1.3e6))
            conn_state = 'EST'
        elif attack_id == ATTACK_TO_ID['Worm']:
            bytes_transferred = float(self.rng.uniform(1.5e6, 2e6))
            packet_count = int(self.rng.randint(600, 1400))
        # Normal: keep baseline.
        return src_port, dst_port, proto, packet_count, conn_state, bytes_transferred

    def _sample_row_one(self, attack_id: int, trend: str):
        src_port, dst_port, proto, packet_count, conn_state, bytes_transferred = self._baseline_features_one(trend)
        src_port, dst_port, proto, packet_count, conn_state, bytes_transferred = self._apply_attack_semantics_one(
            attack_id, src_port, dst_port, proto, packet_count, conn_state, bytes_transferred
        )
        if trend == 'data_drift':
            packet_count = int(np.clip(packet_count + float(self.rng.normal(0, 80)), 1, 2000))
            bytes_transferred = float(np.clip(bytes_transferred * float(self.rng.lognormal(mean=0.0, sigma=0.20)), 0.0, 2e6))
        row = {
            'src_port': int(src_port),
            'dst_port': int(dst_port),
            'protocol': proto,
            'packet_count': int(packet_count),
            'conn_state': conn_state,
            'bytes_transferred': float(bytes_transferred),
            'timestamp': self.next_timestamp(),
        }
        return row

    # -------- Public single-sample API (fast, O(1), loop-free) --------
    def produce(self, trend: str = 'normal'):
        """Generate exactly one JSON-friendly sample.

        Requirements met:
        - No loops
        - No rejection sampling
        - No oracle correction
        - Only: sample features, deterministic label, timestamp, format output
        """
        attack_id = self._sample_attack_id_one()
        row = self._sample_row_one(attack_id, trend)

        if trend == 'concept_drift':
            label = self.fX_concept(row)
        else:
            label = int(attack_id)

        return {
            "timestamp": row["timestamp"].isoformat(),
            "properties": {
                "src_port": int(row["src_port"]),
                "dst_port": int(row["dst_port"]),
                "protocol": row["protocol"],
                "packet_count": int(row["packet_count"]),
                "conn_state": row["conn_state"],
                "bytes_transferred": float(row["bytes_transferred"]),
            },
            "label": int(label),
        }

    # -------- Vectorized oracles for batch (no Python loops over rows) --------
    def _fX_concept_vec(
        self,
        src_port: np.ndarray,
        dst_port: np.ndarray,
        proto_idx: np.ndarray,
        packet_count: np.ndarray,
        conn_idx: np.ndarray,
        bytes_transferred: np.ndarray,
    ) -> np.ndarray:
        pc = packet_count.astype(np.float64) / 2000.0
        bt = bytes_transferred.astype(np.float64) / 2e6
        sp = src_port.astype(np.float64) / 65535.0
        dp = dst_port.astype(np.float64) / 65535.0

        is_icmp = (proto_idx == int(self._proto_to_idx['ICMP']))

        score = np.zeros((pc.size, NUM_CLASSES), dtype=np.float64)
        score[:, 1] += 4.0 * bt + 0.8 * is_icmp.astype(np.float64) - 0.3 * pc
        src_mod = (src_port % 1000).astype(np.float64) / 1000.0
        score[:, 2] += 3.0 * (src_mod < 0.2).astype(np.float64) + 1.0 * np.isin(dst_port, [21, 22, 23, 80, 443]).astype(np.float64)
        score[:, 3] += 2.0 * is_icmp.astype(np.float64) + 1.0 * ((bt > 0.05) & (bt < 0.2)).astype(np.float64)
        score[:, 4] += 2.0 * (pc > 0.7).astype(np.float64) + 1.2 * bt
        score[:, 5] += 1.5 * sp * dp + 0.7 * ((pc > 0.2) & (pc < 0.6)).astype(np.float64)
        score[:, 0] += 0.8 - 0.2 * (np.abs(pc - 0.35) + np.abs(bt - 0.35))
        score += 0.05 * self.rng.randn(pc.size, NUM_CLASSES)
        return np.argmax(score, axis=1).astype(np.int32, copy=False)

    # -------- Batch generation (vectorized; shares same distributions/semantics) --------
    def generate_dataset(self, n: int, trend: str = 'normal'):
        """Generate a pandas DataFrame with an `attack` column.

        Vectorized generation:
        - Samples all attacks in bulk.
        - Samples timestamps in bulk.
        - Applies attack semantics via boolean masks.
        """
        n = int(n)
        attack_id = self._sample_attack_ids(n)
        ts = self.next_timestamps(n)

        # Baseline distributions
        if trend in ('normal', 'concept_drift'):
            src_port = self.rng.randint(1024, 65535, size=n).astype(np.int32)
            dst_port = self.rng.randint(1, 65535, size=n).astype(np.int32)
            proto_idx = self.rng.choice(
                np.array([self._proto_to_idx['TCP'], self._proto_to_idx['UDP']], dtype=np.int32),
                size=n,
            ).astype(np.int32)
            packet_count = self.rng.randint(1, 400, size=n).astype(np.int32)
            conn_idx = np.full(n, self._conn_to_idx['EST'], dtype=np.int32)
            bytes_transferred = self.rng.uniform(1e3, 5e5, size=n).astype(np.float64)
        else:  # data_drift
            # Moderate covariate shift (aim: val_data_drift accuracy ~0.8-0.9).
            src_port = (1024 + (45000 - 1024) * self.rng.beta(a=2, b=4, size=n)).astype(np.int32)
            dst_port = (1 + (45000 - 1) * self.rng.beta(a=2, b=4, size=n)).astype(np.int32)
            proto_idx = self.rng.choice(
                np.array([self._proto_to_idx['ICMP'], self._proto_to_idx['TCP'], self._proto_to_idx['UDP']], dtype=np.int32),
                p=[0.25, 0.40, 0.35],
                size=n,
            ).astype(np.int32)
            packet_count = (50 + (1200 - 50) * self.rng.beta(a=2, b=5, size=n)).astype(np.int32)
            conn_idx = self.rng.choice(
                np.array([self._conn_to_idx['EST'], self._conn_to_idx['SYN'], self._conn_to_idx['RST']], dtype=np.int32),
                p=[0.55, 0.35, 0.10],
                size=n,
            ).astype(np.int32)
            bytes_transferred = (1e4 + (1.2e6 - 1e4) * self.rng.beta(a=2, b=5, size=n)).astype(np.float64)

        # Attack masks
        m_dos = (attack_id == ATTACK_TO_ID['DoS'])
        m_probe = (attack_id == ATTACK_TO_ID['Probe'])
        m_r2l = (attack_id == ATTACK_TO_ID['R2L'])
        m_u2r = (attack_id == ATTACK_TO_ID['U2R'])
        m_worm = (attack_id == ATTACK_TO_ID['Worm'])

        # Apply attack semantics (vectorized; no rejection)
        if np.any(m_dos):
            k = int(m_dos.sum())
            packet_count[m_dos] = self.rng.randint(1200, 2000, size=k)
            proto_idx[m_dos] = self.rng.choice(
                np.array([self._proto_to_idx['ICMP'], self._proto_to_idx['UDP']], dtype=np.int32),
                size=k,
            )
        if np.any(m_probe):
            k = int(m_probe.sum())
            dst_port[m_probe] = self.rng.choice(np.array([21, 22, 23, 80, 443], dtype=np.int32), size=k)
            packet_count[m_probe] = self.rng.randint(200, 800, size=k)
        if np.any(m_r2l):
            k = int(m_r2l.sum())
            conn_idx[m_r2l] = self._conn_to_idx['RST']
            bytes_transferred[m_r2l] = self.rng.uniform(0, 200, size=k)
        if np.any(m_u2r):
            k = int(m_u2r.sum())
            conn_idx[m_u2r] = self._conn_to_idx['EST']
            bytes_transferred[m_u2r] = self.rng.uniform(8e5, 1.3e6, size=k)
        if np.any(m_worm):
            k = int(m_worm.sum())
            bytes_transferred[m_worm] = self.rng.uniform(1.5e6, 2e6, size=k)
            packet_count[m_worm] = self.rng.randint(600, 1400, size=k)
        if trend == 'data_drift':
            packet_count = np.clip(packet_count.astype(np.float64) + self.rng.normal(0, 80, size=n), 1, 2000).astype(np.int32)
            bytes_transferred = np.clip(bytes_transferred * self.rng.lognormal(mean=0.0, sigma=0.20, size=n), 0.0, 2e6).astype(np.float64)

        # Decode categoricals
        protocol = _PROTO_VALS[proto_idx]
        conn_state = _CONN_VALS[conn_idx]

        df = pd.DataFrame({
            'src_port': src_port.astype(int),
            'dst_port': dst_port.astype(int),
            'protocol': protocol,
            'packet_count': packet_count.astype(int),
            'conn_state': conn_state,
            'bytes_transferred': bytes_transferred.astype(float),
            'timestamp': ts,
        })

        if trend == 'concept_drift':
            df['attack'] = self._fX_concept_vec(
                src_port=src_port,
                dst_port=dst_port,
                proto_idx=proto_idx,
                packet_count=packet_count,
                conn_idx=conn_idx,
                bytes_transferred=bytes_transferred,
            ).astype(int)
        else:
            df['attack'] = attack_id.astype(int)

        df['timestamp'] = pd.to_datetime(df['timestamp'])
        return df

    # Optional explicit batch API for streaming simulators
    def produce_batch(self, n: int, trend: str = 'normal'):
        """Return a list of JSON-friendly dicts (like `produce()`) for integration ease."""
        df = self.generate_dataset(n, trend=trend)
        out = []
        # This loop is intentionally outside `produce()`; batch callers can stream/serialize.
        for r in df.itertuples(index=False):
            out.append({
                "timestamp": pd.Timestamp(r.timestamp).isoformat(),
                "properties": {
                    "src_port": int(r.src_port),
                    "dst_port": int(r.dst_port),
                    "protocol": str(r.protocol),
                    "packet_count": int(r.packet_count),
                    "conn_state": str(r.conn_state),
                    "bytes_transferred": float(r.bytes_transferred),
                },
                "label": int(r.attack),
            })
        return out


# -------------------------
# PREPROCESSING - PANDAS (large, realistic)
# -------------------------
def preprocess_pandas(df, fit_encoders=True, fit_scaler=True, global_stats=None):
    """
    Preprocesamiento complejo estilo PySpark-ready:
    - Categóricas -> LabelEncoder
    - Features log-transform / ratios
    - Rolling/window stats por session
    - Estadísticas globales (mean/std) para normalización
    """

    df = df.copy()

    # 1) timestamp -> segundos desde epoch
    df['timestamp'] = pd.to_datetime(df['timestamp'])
    df['timestamp_epoch'] = df['timestamp'].astype('int64') // 10**9

    # 2) crear feature log + small offset para evitar log(0)
    df['bytes_log'] = np.log1p(df['bytes_transferred'])
    df['packet_log'] = np.log1p(df['packet_count'])

    # 3) ratios
    df['bytes_per_packet'] = df['bytes_transferred'] / (df['packet_count'] + 1)
    df['bytes_per_packet_log'] = np.log1p(df['bytes_per_packet'])

    # 4) session_id: combinación de src_port + dst_port + protocol (puede traducirse a PySpark)
    df['session_id'] = df['src_port'].astype(str) + '-' + df['dst_port'].astype(str) + '-' + df['protocol']

    # 5) rolling stats por sesión (usando transform para mantener índice)
    df['prev_bytes_mean_3'] = df.groupby('session_id')['bytes_log'] \
                                .transform(lambda s: s.shift().rolling(3, min_periods=1).mean())
    df['prev_packet_mean_3'] = df.groupby('session_id')['packet_log'] \
                                 .transform(lambda s: s.shift().rolling(3, min_periods=1).mean())
    df['prev_event_count_3'] = df.groupby('session_id')['packet_count'] \
                                  .transform(lambda s: s.shift().rolling(3, min_periods=1).count())

    # 6) features agregadas por session (global session stats)
    df['session_bytes_max'] = df.groupby('session_id')['bytes_log'].transform('max')
    df['session_bytes_min'] = df.groupby('session_id')['bytes_log'].transform('min')
    df['session_packet_mean'] = df.groupby('session_id')['packet_log'].transform('mean')

    # 7) rolling stats pueden producir NaN en los primeros eventos de cada sesión
    roll_cols = ['prev_bytes_mean_3', 'prev_packet_mean_3', 'prev_event_count_3']
    df[roll_cols] = df[roll_cols].fillna(0)

    # 8) codificación de categóricas (fit nuevo o reusar encoders existentes)
    if isinstance(fit_encoders, dict):
        encoders = fit_encoders
        fit_new_encoders = False
    else:
        encoders = {}
        fit_new_encoders = bool(fit_encoders)

    for col in ['protocol', 'conn_state']:
        if fit_new_encoders:
            le = LabelEncoder()
            df[col] = le.fit_transform(df[col])
            encoders[col] = le
        else:
            le = encoders[col]
            mapping = {cls: int(i) for i, cls in enumerate(le.classes_)}
            df[col] = df[col].map(mapping).fillna(-1).astype(int)

    # 9) columnas finales para modelado (antes de normalizar)
    base_feature_cols = ['src_port', 'dst_port', 'packet_count', 'bytes_transferred', 'bytes_log', 'packet_log',
                         'bytes_per_packet', 'bytes_per_packet_log',
                         'prev_bytes_mean_3', 'prev_packet_mean_3', 'prev_event_count_3',
                         'session_bytes_max', 'session_bytes_min', 'session_packet_mean',
                         'protocol', 'conn_state', 'timestamp_epoch']

    # 10) escalado: aceptar `fit_scaler=True` (fit) o un StandardScaler ya entrenado
    X = df[base_feature_cols].fillna(0).values
    if isinstance(fit_scaler, StandardScaler):
        scaler = fit_scaler
    elif bool(fit_scaler):
        scaler = StandardScaler()
        scaler.fit(X)
    else:
        scaler = None

    if scaler is not None:
        Xs = scaler.transform(X)
        global_stats = {
            f: {'mean': float(scaler.mean_[i]), 'std': float(scaler.scale_[i])}
            for i, f in enumerate(base_feature_cols)
        }
        for i, f in enumerate(base_feature_cols):
            df[f + '_norm'] = Xs[:, i]
    else:
        # fallback: usar stats provistas o calcularlas desde el df
        if global_stats is None:
            global_stats = {f: {'mean': df[f].mean(), 'std': df[f].std()} for f in base_feature_cols}
        for f in base_feature_cols:
            df[f + '_norm'] = (df[f].fillna(0) - global_stats[f]['mean']) / (global_stats[f]['std'] + 1e-6)

    feature_cols = [f + '_norm' for f in base_feature_cols]
    return df, encoders, scaler, global_stats, feature_cols


# -------------------------
# Pyspark equivalent preprocessing (if you want distributed)
# -------------------------
def preprocess_spark(spark, spark_df):
    """
    Example PySpark pipeline equivalent. This returns a Spark DataFrame with similar derived features.
    This function is illustrative; to run it you need pyspark configured.
    """
    from pyspark.sql import functions as F
    from pyspark.sql.window import Window
    from pyspark.ml.feature import StringIndexer, VectorAssembler, StandardScaler

    df = spark_df
    # time features
    df = df.withColumn('timestamp_ts', F.col('timestamp').cast('timestamp'))
    df = df.withColumn('hour', F.hour('timestamp_ts'))
    df = df.withColumn('dayofweek', F.dayofweek('timestamp_ts') - 1)  # pyspark dayofweek 1..7
    df = df.withColumn('is_weekend', F.when(F.col('dayofweek').isin(5, 6), 1).otherwise(0))
    # sin/cos for hour
    df = df.withColumn('hour_sin', F.sin(2 * 3.1415926535 * F.col('hour') / F.lit(24.0)))
    df = df.withColumn('hour_cos', F.cos(2 * 3.1415926535 * F.col('hour') / F.lit(24.0)))

    # flags
    well_known = [21, 22, 23, 25, 53, 80, 110, 143, 443]
    df = df.withColumn('dst_well_known', F.when(F.col('dst_port').isin(well_known), 1).otherwise(0))
    df = df.withColumn('src_port_bucket', (F.col('src_port') / 10000).cast('int'))

    # logs & ratios
    df = df.withColumn('bytes_log', F.log1p(F.col('bytes_transferred')))
    df = df.withColumn('packet_log', F.log1p(F.col('packet_count')))
    df = df.withColumn('pkt_per_byte', F.col('packet_count') / (F.col('bytes_transferred') + F.lit(1.0)))

    # crosses
    df = df.withColumn('protocol_conn', F.concat_ws('_', F.col('protocol'), F.col('conn_state')))

    # session id
    df = df.withColumn('session_id', F.concat_ws('_', F.col('src_port').cast('string'), F.col('dst_port').cast('string'), F.col('protocol')))

    # window functions per session (previous 3 events)
    w = Window.partitionBy('session_id').orderBy('timestamp_ts').rowsBetween(-3, -1)
    df = df.withColumn('prev_bytes_mean_3', F.avg('bytes_log').over(w))
    df = df.withColumn('prev_packet_mean_3', F.avg('packet_log').over(w))
    df = df.withColumn('prev_event_count_3', F.count('packet_count').over(w))

    # encode categorical columns using StringIndexer (example)
    protocol_indexer = StringIndexer(inputCol='protocol', outputCol='protocol_idx').fit(df)
    df = protocol_indexer.transform(df)
    conn_indexer = StringIndexer(inputCol='conn_state', outputCol='conn_state_idx').fit(df)
    df = conn_indexer.transform(df)
    pc_indexer = StringIndexer(inputCol='protocol_conn', outputCol='protocol_conn_idx').fit(df)
    df = pc_indexer.transform(df)

    # choose features and assemble
    feat_cols = ['src_port', 'dst_port', 'packet_count', 'bytes_transferred', 'bytes_log', 'packet_log', 'pkt_per_byte',
                 'hour', 'dayofweek', 'is_weekend', 'hour_sin', 'hour_cos',
                 'dst_well_known', 'src_port_bucket', 'protocol_idx', 'conn_state_idx', 'protocol_conn_idx',
                 'prev_bytes_mean_3', 'prev_packet_mean_3', 'prev_event_count_3']
    assembler = VectorAssembler(inputCols=feat_cols, outputCol='features_vec')
    df = assembler.transform(df)

    # standardize
    scaler = StandardScaler(inputCol='features_vec', outputCol='features_scaled', withMean=True, withStd=True)
    scaler_model = scaler.fit(df)
    df = scaler_model.transform(df)

    return df  # features in 'features_scaled' vector


# -------------------------
# Helper: try to save Parquet (simulate distributed storage)
# -------------------------
def save_parquet(df, path):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    df.to_parquet(path, index=False)
    print(f"Saved parquet to {path}")


# -------------------------
# Full demo pipeline (generate -> preprocess -> train) using pandas preprocess
# -------------------------
def run_demo():
    gen = SyntheticTrafficGenerator(start_ts="2026-01-12 18:00:00", epsilon_seconds=60, rng=RNG)
    TRAIN_N, VAL_N = 2000, 3000

    print("Generating datasets...")
    df_train = gen.generate_dataset(TRAIN_N, trend='normal')
    df_val_normal = gen.generate_dataset(VAL_N, trend='normal')
    df_val_data = gen.generate_dataset(VAL_N, trend='data_drift')
    df_val_concept = gen.generate_dataset(VAL_N, trend='concept_drift')

    # optionally save to parquet for distributed load
    save_parquet(df_train, "out/train.parquet")
    save_parquet(df_val_normal, "out/val_normal.parquet")
    save_parquet(df_val_data, "out/val_data.parquet")
    save_parquet(df_val_concept, "out/val_concept.parquet")

    # build concept_train (val_normal features relabeled by concept)
    df_concept_train = df_val_normal.copy()
    # Keep prior behavior (concept oracle over normal-distribution features), but vectorized for scale.
    proto_idx = df_concept_train['protocol'].map(gen._proto_to_idx).astype(np.int32).values
    conn_idx = df_concept_train['conn_state'].map(gen._conn_to_idx).astype(np.int32).values
    df_concept_train['attack'] = gen._fX_concept_vec(
        src_port=df_concept_train['src_port'].astype(np.int32).values,
        dst_port=df_concept_train['dst_port'].astype(np.int32).values,
        proto_idx=proto_idx,
        packet_count=df_concept_train['packet_count'].astype(np.int32).values,
        conn_idx=conn_idx,
        bytes_transferred=df_concept_train['bytes_transferred'].astype(np.float64).values,
    ).astype(int)

    # preprocess train (fits encoders/scaler)
    print("Preprocessing train (Pandas)...")
    df_train_proc, encoders, scaler, global_stats, feature_cols = preprocess_pandas(df_train)

    # preprocess val using same encoders/scaler / global_stats
    print("Preprocessing val_normal (Pandas)...")
    df_valn_proc, _, _, _, _ = preprocess_pandas(df_val_normal, fit_encoders=encoders, fit_scaler=scaler, global_stats=global_stats)
    df_vald_proc, _, _, _, _ = preprocess_pandas(df_val_data, fit_encoders=encoders, fit_scaler=scaler, global_stats=global_stats)
    df_valc_proc, _, _, _, _ = preprocess_pandas(df_val_concept, fit_encoders=encoders, fit_scaler=scaler, global_stats=global_stats)
    df_concept_train_proc, _, _, _, _ = preprocess_pandas(df_concept_train, fit_encoders=encoders, fit_scaler=scaler, global_stats=global_stats)

    # optionally save to parquet for distributed load
    #save_parquet(df_train_proc, "out/train_proc.parquet")
    #save_parquet(df_valn_proc, "out/val_normal_proc.parquet")
    #save_parquet(df_vald_proc, "out/val_data_proc.parquet")
    #save_parquet(df_valc_proc, "out/val_concept_proc.parquet")

    # prepare arrays for xgboost (use feature_cols returned by preprocess)
    X_train = df_train_proc[feature_cols].values
    y_train = df_train_proc['attack'].values
    Xn = df_valn_proc[feature_cols].values
    yn = df_valn_proc['attack'].values
    Xd = df_vald_proc[feature_cols].values
    yd = df_vald_proc['attack'].values
    Xc = df_valc_proc[feature_cols].values
    yc = df_valc_proc['attack'].values

    # train helper (GPU fallback)
    def train_xgb(X_tr, y_tr, eval_set):
        params = {
            'n_estimators': 100,
            'max_depth': 7,
            'learning_rate': 0.1,
            'objective': 'multi:softprob',
            'num_class': NUM_CLASSES,
            'verbosity': 1,
            'eval_metric': ['mlogloss', 'merror'],
        }
        try:
            model = xgb.XGBClassifier(**params, tree_method='hist', device='gpu')
            model.fit(X_tr, y_tr, eval_set=eval_set, verbose=False)
            print("Trained with GPU.")
        except Exception as e:
            print("GPU failed:", e, "Falling back to CPU.")
            model = xgb.XGBClassifier(**params, tree_method='hist', device='cpu')
            model.fit(X_tr, y_tr, eval_set=eval_set, verbose=False)
        return model

    print("\n=== EXP: train on normal, eval on normal/data/concept ===")
    model_norm = train_xgb(X_train, y_train, eval_set=[(Xn, yn)])

    def run_report(m, Xv, yv, name):
        ypred = m.predict(Xv)
        print(f"\n--- {name} --- Accuracy: {accuracy_score(yv, ypred):.4f}")
        print(classification_report(yv, ypred, target_names=[ATTACK_LABELS[i] for i in range(NUM_CLASSES)], digits=4))

    run_report(model_norm, Xn, yn, "val_normal")
    run_report(model_norm, Xd, yd, "val_data_drift")
    run_report(model_norm, Xc, yc, "val_concept_drift")

    # retrain on concept (as before)
    print("\n=== EXP: retrain on concept (using df_concept_train_proc) ===")
    X_concept_train = df_concept_train_proc[feature_cols].values
    y_concept_train = df_concept_train_proc['attack'].values
    X_concept_val = df_valc_proc[feature_cols].values
    y_concept_val = df_valc_proc['attack'].values
    model_concept = train_xgb(X_concept_train, y_concept_train, eval_set=[(X_concept_val, y_concept_val)])
    run_report(model_concept, X_concept_val, y_concept_val, "concept_val (after retrain)")

    print("\nSample processed train head:")
    print(df_train_proc.head(6))
    print("\nFeature columns used:", feature_cols)


# -------------------------
# Entrypoint
# -------------------------
if __name__ == "__main__":
    run_demo()

    # if you want to try the pyspark pipeline, uncomment next lines (requires pyspark installed/configured):
    # from pyspark.sql import SparkSession
    # spark = SparkSession.builder.master("local[4]").appName("synth").getOrCreate()
    # df_train_spark = spark.read.parquet("out/train_proc.parquet")
    # df_train_spark.show(5)
    # spark.stop()


Generating datasets...
Saved parquet to out/train.parquet
Saved parquet to out/val_normal.parquet
Saved parquet to out/val_data.parquet
Saved parquet to out/val_concept.parquet
Preprocessing train (Pandas)...
Preprocessing val_normal (Pandas)...

=== EXP: train on normal, eval on normal/data/concept ===
Trained with GPU.

--- val_normal --- Accuracy: 0.9960
              precision    recall  f1-score   support

      Normal     0.9979    0.9982    0.9980      2801
         DoS     1.0000    1.0000    1.0000        68
       Probe     0.8387    0.8667    0.8525        30
         R2L     1.0000    1.0000    1.0000        57
         U2R     1.0000    0.8966    0.9455        29
        Worm     0.9375    1.0000    0.9677        15

    accuracy                         0.9960      3000
   macro avg     0.9623    0.9602    0.9606      3000
weighted avg     0.9961    0.9960    0.9960      3000


--- val_data_drift --- Accuracy: 0.8757
              precision    recall  f1-score   support

 

In [10]:
# producer method returning JSON-like dict
producer = SyntheticTrafficGenerator(start_ts= "2026-01-12 18:00:00", epsilon_seconds= 100, rng= RNG)

outputs = []
for _ in range(10000):
    outputs.append(producer.produce(trend='normal'))

df_output = pd.DataFrame(outputs)
df_output['label'].value_counts()

label
0    9400
1     193
3     152
2     126
4      85
5      44
Name: count, dtype: int64

In [8]:
df_output

Unnamed: 0,timestamp,properties,label
0,2026-01-12T18:01:22.000000751,"{'src_port': 47116, 'dst_port': 58479, 'protoc...",0
1,2026-01-12T18:01:43.000000605,"{'src_port': 50868, 'dst_port': 57682, 'protoc...",0
2,2026-01-12T18:03:36.000000236,"{'src_port': 56232, 'dst_port': 14910, 'protoc...",0
3,2026-01-12T18:05:13.000000090,"{'src_port': 15575, 'dst_port': 63430, 'protoc...",0
4,2026-01-12T18:07:19.000000017,"{'src_port': 59834, 'dst_port': 29222, 'protoc...",0
...,...,...,...
9995,2026-01-24T07:39:39.000000396,"{'src_port': 26355, 'dst_port': 18439, 'protoc...",3
9996,2026-01-24T07:40:17.000000218,"{'src_port': 28504, 'dst_port': 10880, 'protoc...",0
9997,2026-01-24T07:42:06.000000976,"{'src_port': 43586, 'dst_port': 15772, 'protoc...",0
9998,2026-01-24T07:44:10.000000990,"{'src_port': 37432, 'dst_port': 35668, 'protoc...",0


In [8]:
import pandas as pd

df = pd.read_parquet("out/val_normal.parquet")
df

Unnamed: 0,src_port,dst_port,protocol,packet_count,conn_state,bytes_transferred,timestamp,attack
0,12653,59768,UDP,337,EST,149161.748507,2026-01-14 03:20:38,0
1,4880,31756,UDP,1,EST,49214.113299,2026-01-14 03:21:25,0
2,32075,13744,UDP,109,EST,447438.194602,2026-01-14 03:22:53,0
3,2099,6227,TCP,125,EST,338880.869919,2026-01-14 03:23:31,0
4,49547,12719,TCP,64,EST,324407.094029,2026-01-14 03:24:30,0
...,...,...,...,...,...,...,...,...
2995,19364,37638,UDP,170,EST,150686.117201,2026-01-16 05:15:57,0
2996,4884,55002,UDP,12,EST,29026.916108,2026-01-16 05:16:09,0
2997,20183,48360,TCP,168,EST,238008.552157,2026-01-16 05:17:03,0
2998,28889,55120,TCP,343,EST,105251.418170,2026-01-16 05:18:34,0


In [12]:
df_output = pd.read_parquet("part-00000-94e12034-20de-42c4-b60f-7c4a9db74502-c000.snappy.parquet")
df_output

Unnamed: 0,src_port,dst_port,protocol,packet_count,conn_state,bytes_transferred,timestamp,attack,timestamp_ts,hour,...,hour_sin,hour_cos,bytes_log,packet_log,protocol_conn,protocol_idx,conn_state_idx,protocol_conn_idx,features_vec,features_scaled
0,12653,59768,UDP,337,EST,149161.748507,1768360838000000000,0,1717-03-31 08:59:44.950139928,4,...,0.866025,0.5,11.912793,5.823046,UDP_EST,1.0,0.0,1.0,"{'type': 1, 'size': None, 'indices': None, 'va...","{'type': 1, 'size': None, 'indices': None, 'va..."
1,4880,31756,UDP,1,EST,49214.113299,1768360885000000000,0,1717-03-31 08:59:44.950139928,4,...,0.866025,0.5,10.803956,0.693147,UDP_EST,1.0,0.0,1.0,"{'type': 1, 'size': None, 'indices': None, 'va...","{'type': 1, 'size': None, 'indices': None, 'va..."
2,32075,13744,UDP,109,EST,447438.194602,1768360973000000000,0,1717-03-31 08:59:44.950139928,4,...,0.866025,0.5,13.011296,4.700480,UDP_EST,1.0,0.0,1.0,"{'type': 1, 'size': None, 'indices': None, 'va...","{'type': 1, 'size': None, 'indices': None, 'va..."
3,2099,6227,TCP,125,EST,338880.869919,1768361011000000000,0,1717-03-31 08:59:44.950139928,4,...,0.866025,0.5,12.733407,4.836282,TCP_EST,0.0,0.0,0.0,"{'type': 1, 'size': None, 'indices': None, 'va...","{'type': 1, 'size': None, 'indices': None, 'va..."
4,49547,12719,TCP,64,EST,324407.094029,1768361070000000000,0,1717-03-31 08:59:44.950139928,4,...,0.866025,0.5,12.689758,4.174387,TCP_EST,0.0,0.0,0.0,"{'type': 1, 'size': None, 'indices': None, 'va...","{'type': 1, 'size': None, 'indices': None, 'va..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2995,19364,37638,UDP,170,EST,150686.117201,1768540557000000000,0,1717-03-31 08:59:44.950139928,4,...,0.866025,0.5,11.922961,5.141664,UDP_EST,1.0,0.0,1.0,"{'type': 1, 'size': None, 'indices': None, 'va...","{'type': 1, 'size': None, 'indices': None, 'va..."
2996,4884,55002,UDP,12,EST,29026.916108,1768540569000000000,0,1717-03-31 08:59:44.950139928,4,...,0.866025,0.5,10.276013,2.564949,UDP_EST,1.0,0.0,1.0,"{'type': 1, 'size': None, 'indices': None, 'va...","{'type': 1, 'size': None, 'indices': None, 'va..."
2997,20183,48360,TCP,168,EST,238008.552157,1768540623000000000,0,1717-03-31 08:59:44.950139928,4,...,0.866025,0.5,12.380066,5.129899,TCP_EST,0.0,0.0,0.0,"{'type': 1, 'size': None, 'indices': None, 'va...","{'type': 1, 'size': None, 'indices': None, 'va..."
2998,28889,55120,TCP,343,EST,105251.418170,1768540714000000000,0,1717-03-31 08:59:44.950139928,4,...,0.866025,0.5,11.564117,5.840642,TCP_EST,0.0,0.0,0.0,"{'type': 1, 'size': None, 'indices': None, 'va...","{'type': 1, 'size': None, 'indices': None, 'va..."
