# <a href="https://colab.research.google.com/github/supriyag123/PHD_Pub/blob/main/AGENTIC-MODULE6-REVISED-DecisionExpertLLM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# In [28]:


# ============================================================
# Fault Classification Pipeline
# ============================================================

# NOTE: The AdaptiveWindowAgent source is pasted inline below so that the whole
# pipeline lives in a single file; later it can be imported as a package instead.


# ====== AdaptiveWindowAgent ======
# =====================================================
# AdaptiveWindowAgent (improved version)
# =====================================================
# agents/adaptive_window_agent.py
import numpy as np
import pandas as pd
import pickle
import json
import os
from collections import deque
from typing import Dict, Any
import datetime as dt
import logging
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestRegressor  # ‚úÖ MOVED TO TOP
from statsmodels.tsa.vector_ar.var_model import VAR
import keras
import tensorflow as tf
import warnings
warnings.filterwarnings('ignore')
from datetime import datetime

logger = logging.getLogger(__name__)

import sqlite3
import json
from datetime import datetime
import json


class EventStore:
    """SQLite-backed, append-only log of pipeline events.

    Every event is one row of the ``events`` table holding an ISO timestamp,
    an event type and up to three JSON documents (packet, expert, human).

    The store owns one ``sqlite3`` connection for its whole lifetime. Call
    :meth:`close` when finished, or use the store as a context manager::

        with EventStore("event_store.db") as store:
            store.save_event("ANOMALY", packet)
    """

    def __init__(self, db_path="event_store.db"):
        """Open (or create) the database at *db_path* and ensure the schema exists."""
        self.conn = sqlite3.connect(db_path)
        self._init_tables()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False  # never swallow the caller's exception

    def close(self):
        """Close the underlying connection. Safe to call more than once."""
        conn, self.conn = self.conn, None
        if conn is not None:
            conn.close()

    def _init_tables(self):
        """Create the ``events`` table if it does not exist yet."""
        # `with connection` commits on success and rolls back on error.
        with self.conn:
            self.conn.execute("""
            CREATE TABLE IF NOT EXISTS events (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                timestamp TEXT,
                event_type TEXT,
                packet_json TEXT,
                expert_json TEXT,
                human_json TEXT
            )
        """)

    @staticmethod
    def _to_json_or_null(payload):
        """Serialise *payload* to JSON, mapping only ``None`` to SQL NULL.

        Falsy-but-present payloads such as ``{}`` or ``[]`` are stored as
        JSON so that "no opinion supplied" and "empty opinion" stay
        distinguishable.
        """
        if payload is None:
            return None
        return json.dumps(payload, default=str)

    def save_event(self, event_type, packet, expert=None, human=None):
        """Append one event, stamped with the current local time.

        Args:
            event_type: Label stored in the ``event_type`` column.
            packet: JSON-serialisable payload (non-serialisable members are
                stringified via ``default=str``).
            expert: Optional expert payload; ``None`` is stored as NULL.
            human: Optional human payload; ``None`` is stored as NULL.
        """
        row = (
            datetime.now().isoformat(),
            event_type,
            json.dumps(packet, default=str),
            self._to_json_or_null(expert),
            self._to_json_or_null(human),
        )
        with self.conn:
            self.conn.execute(
                "INSERT INTO events(timestamp,event_type,packet_json,expert_json,human_json) VALUES (?,?,?,?,?)",
                row,
            )

    def fetch_recent(self, limit=100):
        """Return the decoded packets of the newest *limit* events, newest first."""
        # Only the packet column is decoded, so only that column is selected.
        rows = self.conn.execute(
            "SELECT packet_json FROM events ORDER BY id DESC LIMIT ?", (limit,)
        ).fetchall()
        return [json.loads(packet_json) for (packet_json,) in rows]

#########################################################################
# Window Agent - Global Context or Global Predictive Context
#########################################################################

class AdaptiveWindowAgent:
    """
    Adaptive Window Agent:
    - Predicts window size using MLP
    - Evaluates forecast with RF/persistence
    - Computes:
        * FDS: Forecast Deviation Score (normalized error)
        * FDI: Forecast Drift Index (JSD over FDS distribution)
        * WSS: Window Shift Score (normalized window size)
        * WDI: Window Drift Index (JSD over window size distribution)
    - Detects anomaly (local) + drift (regime) events.
    """

    def __init__(self, agent_id="adaptive_window_agent",
                 model_path=None, checkpoint_path=None,
                 baseline_path=None, nsp_model_path=None):
        """Set up all agent state, then load the baseline and both models.

        Args:
            agent_id: Label used in console output.
            model_path: Path of the Keras model that predicts the window
                size. ``None`` selects the default Google Drive location.
            checkpoint_path: Stored on the instance; the constructor itself
                does not read it.
            baseline_path: Path of the pickled baseline statistics.
                ``None`` selects the default Google Drive location.
            nsp_model_path: Path of the Keras next-step predictor (NSP).
                ``None`` selects the default Google Drive location.

        Raises:
            Whatever ``keras.models.load_model`` raises when the NSP model
            cannot be loaded. A missing baseline file or window-size model
            is only reported on stdout (see ``_load_baseline`` and
            ``load_model``).
        """
        self.agent_id = agent_id

        # --------------------------------------------------
        # File locations (each can be overridden by the caller)
        # --------------------------------------------------
        self.model_path = model_path or "/content/drive/MyDrive/PHD/2025/DGRNet-MLP-Versions/METROPM_MLP_model_10Sec.keras"
        self.baseline_path = baseline_path or "/content/drive/MyDrive/PHD/2025/DGRNet-MLP-Versions/METROPM_MLP_baseline.pkl"
        self.nsp_model_path = nsp_model_path or "/content/drive/MyDrive/PHD/2025/NSP_LSTM_next_step.keras"
        self.checkpoint_path = checkpoint_path

        # --------------------------------------------------
        # Core model
        # --------------------------------------------------
        self.model = None
        self.transformer = StandardScaler()
        self.transformer_fitted = False
        self.is_model_loaded = False

        # Baseline error stats (median / MAD); overwritten from the baseline file
        self.rolling_stats = {
            'median': 0.0,
            'mad': 1.0,
        }

        # --------------------------------------------------
        # Metrics memory
        # --------------------------------------------------
        # Raw forecast errors
        self.error_memory = deque(maxlen=300)     # long-term errors
        self.recent_errors = deque(maxlen=50)     # kept for backward compatibility

        # FDS (Forecast Deviation Score) history
        self.fds_memory = deque(maxlen=300)
        self.recent_fds = deque(maxlen=50)

        # Predicted window-size history
        self.window_memory = deque(maxlen=300)
        self.recent_windows = deque(maxlen=50)

        # Metrics of the most recent step (echoed in every result packet)
        self.last_fds = 0.0
        self.last_fdi = 0.0
        self.last_wss = 0.0
        self.last_wdi = 0.0

        # Baseline distributions (populated by _load_baseline when available)
        self.baseline_errors = None
        self.baseline_fds = None
        self.baseline_windows = None
        self.window_mean = 50.0    # a reasonable mid value
        self.window_std = 10.0     # non-zero, avoids division by zero

        # Optional histogram bins stored in the baseline file
        self.window_hist_bins = None
        self.window_hist_counts = None

        # Debug flag (OFF by default)
        self.debug = False

        # --------------------------------------------------
        # Anomaly / drift settings
        # --------------------------------------------------
        self.threshold_k = 3.0
        self.anomaly_cooldown = 0
        self.anomaly_cooldown_steps = 5

        self.drift_threshold = 0.25          # for FDI (JSD over FDS)
        self.window_drift_threshold = 0.20   # for WDI (JSD over window sizes)
        self.consecutive_drift_votes = 0
        self.drift_cooldown = 0
        self.drift_votes_required = 10
        self.drift_cooldown_steps = 100

        # --------------------------------------------------
        # Running statistics and retraining buffer
        # --------------------------------------------------
        self.performance_stats = {
            'total_predictions': 0,
            'avg_mse': 0.0,
            'avg_mae': 0.0,
            'last_retrain_time': None,
            'drift_events': 0,
            'anomaly_events': 0,
            'retraining_events': 0
        }

        self.retraining_data = {
            'x_buffer': deque(maxlen=10000),
            'y_buffer': deque(maxlen=10000)
        }

        # --------------------------------------------------
        # Prediction history (for reporting)
        # --------------------------------------------------
        self.prediction_history = deque(maxlen=1000)
        self.mse_history = deque(maxlen=200)
        self.mae_history = deque(maxlen=200)

        # --------------------------------------------------
        # Load baseline (errors + windows), then the window-size model
        # --------------------------------------------------
        self._load_baseline()
        self.load_model()

        # --------------------------------------------------
        # Load NSP (Next-Step Predictor); fails fast if it is unavailable
        # --------------------------------------------------
        self.nsp_model = keras.models.load_model(self.nsp_model_path)
        print("‚úÖ Loaded NSP LSTM next-step predictor")

        # Announce readiness only once everything above has succeeded.
        print(f"AdaptiveWindowAgent {self.agent_id} initialized")

    # =================== BASELINE LOADING ===================

    def _load_baseline(self):
        """
        Load baseline stats:
          - baseline_errors, median, mad
          - baseline_windows, window_mean, window_std
          - optional histogram bins/counts for windows
        """
        if os.path.exists(self.baseline_path):
            with open(self.baseline_path, "rb") as f:
                base = pickle.load(f)

            # Error baseline
            self.baseline_errors = np.array(base["baseline_errors"])
            self.rolling_stats["median"] = base["median"]
            self.rolling_stats["mad"] = base["mad"]

            # Precompute baseline FDS distribution
            med = self.rolling_stats["median"]
            mad = self.rolling_stats["mad"] if self.rolling_stats["mad"] > 0 else 1e-6
            self.baseline_fds = (self.baseline_errors - med) / (mad + 1e-8)

            # Window baseline (may or may not exist)
            if "baseline_windows" in base:
                self.baseline_windows = np.array(base["baseline_windows"])
                self.window_mean = float(base.get("window_mean", np.mean(self.baseline_windows)))
                self.window_std = float(base.get("window_std", np.std(self.baseline_windows) + 1e-8))
                self.window_hist_bins = np.array(base.get("window_hist_bins", [])) if "window_hist_bins" in base else None
                self.window_hist_counts = np.array(base.get("window_hist_counts", [])) if "window_hist_counts" in base else None
            else:
                self.baseline_windows = None
                self.window_mean = 0.0
                self.window_std = 1.0
                self.window_hist_bins = None
                self.window_hist_counts = None

            print("‚úÖ Loaded baseline error + window distribution.")
            print(f"   Error median={self.rolling_stats['median']:.6f}, MAD={self.rolling_stats['mad']:.6f}")
            if self.baseline_windows is not None:
                print(f"   Window mean={self.window_mean:.3f}, std={self.window_std:.3f}")
        else:
            print("‚ö†Ô∏è No baseline found. Using live history only.")
            self.baseline_errors = None
            self.baseline_fds = None
            self.baseline_windows = None

    # =================== MODEL LOADING ===================

    def load_model(self):
        try:
            if os.path.exists(self.model_path):
                self.model = keras.models.load_model(self.model_path)
                self.is_model_loaded = True
                print(f"‚úÖ Loaded MLP model from {self.model_path}")

                # Try to load transformer
                transformer_path = self.model_path.replace('.keras', '_transformer.pkl')
                if os.path.exists(transformer_path):
                    with open(transformer_path, 'rb') as f:
                        self.transformer = pickle.load(f)
                    self.transformer_fitted = True
                else:
                    # Fit transformer from true window labels
                    y_original = np.load(
                        "/content/drive/MyDrive/PHD/2025/TEMP_OUTPUT_METROPM/generated-data-true-window2.npy"
                    )
                    self.transformer.fit(y_original.reshape(-1, 1))
                    self.transformer_fitted = True
                    with open(transformer_path, 'wb') as f:
                        pickle.dump(self.transformer, f)
                    print("‚ö†Ô∏è No transformer found, fitted a new one.")
            else:
                print(f"‚ùå Model not found at {self.model_path}")
        except Exception as e:
            print(f"‚ùå Error loading model: {e}")

    # =================== FORECAST EVALUATION ===================

    def evaluate_forecast_performance(self, sequence_3d, predicted_window, n_future=1):
        try:
            seq = np.asarray(sequence_3d)
            T, F = seq.shape

            W = int(predicted_window)
            if W < 2:
                W = 2
            if W > T - n_future - 1:
                W = max(2, T - n_future - 1)

            # --- Prepare NSP input ---
            window = seq[-W:-n_future, :]   # shape (W-1, F)
            x = window[np.newaxis, ...]      # shape (1, W-1, F)

            # --- NSP prediction ---
            y_pred = self.nsp_model.predict(x, verbose=0)[0]  # (F,)
            y_true = seq[-n_future, :]                        # (F,)

            mse = float(np.mean((y_true - y_pred) ** 2))
            mae = float(np.mean(np.abs(y_true - y_pred)))

            return {
                "mse": mse,
                "mae": mae,
                "forecast_success": True,
                "actual_values": y_true.tolist(),
                "predicted_values": y_pred.tolist(),
                "window_size_used": W,
                "method": "NSP_LSTM"
            }

        except Exception as e:
            return {
                "mse": 9999.0,
                "mae": 9999.0,
                "forecast_success": False,
                "error": str(e),
                "method": "NSP_LSTM"
            }

    # =================== PERSISTENCE FALLBACK ===================

    def _persistence_forecast(self, seq, target_sensor_index, n_future):
        """
        Persistence fallback for RF evaluation.
        Last-value-carried-forward for target sensor.
        """
        try:
            seq = np.asarray(seq)
            if len(seq) < 2:
                return {
                    'mse': 9999,
                    'mae': 9999,
                    'forecast_success': False,
                    'error': 'Sequence too short',
                    'method': 'Persistence'
                }

            last_value = seq[-1, target_sensor_index]
            predicted_vals = [last_value]
            actual = [seq[-1, target_sensor_index]]

            mse = 0.0
            mae = 0.0

            return {
                'mse': mse,
                'mae': mae,
                'forecast_success': True,
                'actual_values': actual,
                'predicted_values': predicted_vals,
                'target_sensor_index': target_sensor_index,
                'method': 'Persistence',
                'note': 'persistence_fallback'
            }

        except Exception as e:
            return {
                'mse': 9999,
                'mae': 9999,
                'forecast_success': False,
                'error': str(e),
                'method': 'Persistence',
                'note': 'persistence_fallback_failed'
            }

    # =================== PREDICTION PIPELINE ===================

    def predict_window_size(self, feature_vector, sequence_3d):
        """
        Main entrypoint:
          - Predict window W_t
          - Evaluate forecast error e_t
          - Compute FDS (S_t) and WSS (Z_t)
          - Update drift/anomaly logic (FDI, WDI, events)
          - Return full metrics packet
        """
        if not self.is_model_loaded:
            return {'predicted_window': 20, 'error': "Model not loaded"}

        try:
            if feature_vector.ndim == 1:
                feature_vector = feature_vector.reshape(1, -1)

            pred_raw = self.model.predict(feature_vector, verbose=0)

            if self.transformer_fitted:
                predicted_window = int(round(self.transformer.inverse_transform(pred_raw)[0, 0]))
            else:
                predicted_window = int(round(pred_raw[0, 0]))

            # ----------------------------------------
            # WINDOW CLAMP ‚Äî HARD SAFETY FIX
            # ----------------------------------------
            # Prevent negative, zero, or extreme window sizes
            predicted_window = max(2, predicted_window)        # lower bound

            # Forecast evaluation
            forecast_metrics = self.evaluate_forecast_performance(sequence_3d, predicted_window, n_future=1)

            fds = None
            wss = None

            if forecast_metrics.get("forecast_success", False):
                mse = forecast_metrics["mse"]
                mae = forecast_metrics["mae"]

                # Basic stats
                self.mse_history.append(mse)
                self.mae_history.append(mae)
                self.error_memory.append(mse)

                self.performance_stats['total_predictions'] += 1
                self.performance_stats['avg_mse'] = float(np.mean(self.mse_history))
                self.performance_stats['avg_mae'] = float(np.mean(self.mae_history))

                # ---------- Forecast Deviation Score (FDS) ----------
                baseline_median = self.rolling_stats["median"]
                baseline_mad = self.rolling_stats["mad"] if self.rolling_stats["mad"] > 0 else 1e-6
                fds = (mse - baseline_median) / (baseline_mad + 1e-8)
                fds = float(fds) if fds is not None and not np.isnan(fds) else 0.0


                self.last_fds = fds
                self.fds_memory.append(fds)
                self.recent_fds.append(fds)

                # ---------- Window Shift Score (WSS) ----------
                if self.baseline_windows is not None and self.window_std > 0:
                    wss = (predicted_window - self.window_mean) / (self.window_std + 1e-8)
                else:
                    wss = 0.0

                wss = float(wss)
                self.last_wss = wss
                self.window_memory.append(predicted_window)
                self.recent_windows.append(predicted_window)

            # Event (ANOMALY / DRIFT) + Drift indices
            event, sev, fdi, wdi = self._check_for_event()
            self.last_fdi = fdi
            self.last_wdi = wdi

            # Save history for reporting
            record = {
                'timestamp': dt.datetime.now(),
                'predicted_window': predicted_window,
                'forecast_metrics': forecast_metrics,
                'fds': self.last_fds,
                'wss': self.last_wss,
                'fdi': self.last_fdi,
                'wdi': self.last_wdi,
                'event_type': event,
                'severity': sev
            }
            self.prediction_history.append(record)

            return {
                'predicted_window': predicted_window,
                'forecast_metrics': forecast_metrics,
                'fds': self.last_fds,
                'fdi': self.last_fdi,
                'wss': self.last_wss,
                'wdi': self.last_wdi,
                'event_type': event,
                'severity': sev,
                'performance_stats': self.get_recent_performance()
            }
        except Exception as e:
            return {'predicted_window': 20, 'error': str(e)}

    # =================== EVENT LOGIC (ANOMALY + DRIFT) ===================

    def _check_for_event(self):
        """
        Event detection for the Adaptive Window Agent.

        - ANOMALY: deviation of last MSE from baseline (median + k * MAD)
        - DRIFT:
            * FDI: JSD between recent FDS distribution and baseline FDS
            * WDI: JSD between recent window distribution and baseline window distribution
        """
        # Require enough history
        if len(self.error_memory) < 30:
            return None, 0.0, None, None

        last_mse = float(self.error_memory[-1])
        live_errors = np.array(self.error_memory)

        # ---------- BASELINE STATS ----------
        if self.baseline_errors is not None and len(self.baseline_errors) > 10:
            base_errors = np.array(self.baseline_errors)
            baseline_median = np.median(base_errors)
            baseline_mad = np.median(np.abs(base_errors - baseline_median)) + 1e-8
        else:
            baseline_median = np.median(live_errors)
            baseline_mad = np.median(np.abs(live_errors - baseline_median)) + 1e-8

        # Update rolling_stats so other components can see latest baseline-ish values
        self.rolling_stats["median"] = baseline_median
        self.rolling_stats["mad"] = baseline_mad

        # ---------- LIVE STATS ----------
        live_median = np.median(live_errors)
        live_mad = np.median(np.abs(live_errors - live_median)) + 1e-8

        # ---------- ANOMALY THRESHOLD ----------
        baseline_threshold = baseline_median + self.threshold_k * baseline_mad
        live_threshold = live_median + self.threshold_k * live_mad

        anomaly_threshold = 0.8 * baseline_threshold + 0.2 * live_threshold

        is_anomaly = last_mse > anomaly_threshold

        if self.anomaly_cooldown > 0:
            self.anomaly_cooldown -= 1
            is_anomaly = False
        elif is_anomaly:
            self.anomaly_cooldown = self.anomaly_cooldown_steps

        if is_anomaly:
            severity = (last_mse - anomaly_threshold) / (baseline_mad + 1e-6)
            severity = float(severity)
            self.performance_stats["anomaly_events"] += 1
            if self.debug:
                print(f"[ANOMALY] mse={last_mse:.6f}, thr={anomaly_threshold:.6f}, sev={severity:.3f}")
            return "ANOMALY", severity, 0.0, 0.0

        # ---------- DRIFT (FDI + WDI) ----------
        fdi = None
        wdi = None

        # FDI: JSD over FDS distribution
        if self.baseline_fds is not None and len(self.recent_fds) >= 30:
            base_fds = np.asarray(self.baseline_fds)
            recent_fds = np.asarray(self.recent_fds)

            hist_base, bins = np.histogram(base_fds, bins=25, density=True)
            hist_recent, _ = np.histogram(recent_fds, bins=bins, density=True)

            hist_base = hist_base / (hist_base.sum() + 1e-12)
            hist_recent = hist_recent / (hist_recent.sum() + 1e-12)

            fdi = float(jensenshannon(hist_base + 1e-12, hist_recent + 1e-12))

        # WDI: JSD over window-size distribution
        if self.baseline_windows is not None and len(self.recent_windows) >= 30:
            base_win = np.asarray(self.baseline_windows)
            recent_win = np.asarray(self.recent_windows)

            hist_w_base, bins_w = np.histogram(base_win, bins=20, density=True)
            hist_w_recent, _ = np.histogram(recent_win, bins=bins_w, density=True)

            hist_w_base = hist_w_base / (hist_w_base.sum() + 1e-12)
            hist_w_recent = hist_w_recent / (hist_w_recent.sum() + 1e-12)

            wdi = float(jensenshannon(hist_w_base + 1e-12, hist_w_recent + 1e-12))

        # Decide drift if either index is high
        is_drift_fdi = fdi is not None and fdi > self.drift_threshold
        is_drift_wdi = wdi is not None and wdi > self.window_drift_threshold

        is_drift = is_drift_fdi or is_drift_wdi

        if self.drift_cooldown > 0:
            self.drift_cooldown -= 1
            is_drift = False
        else:
            if is_drift:
                self.consecutive_drift_votes += 1
            else:
                self.consecutive_drift_votes = 0

        if self.consecutive_drift_votes >= self.drift_votes_required:
            self.consecutive_drift_votes = 0
            self.drift_cooldown = self.drift_cooldown_steps
            self.performance_stats["drift_events"] += 1
            if self.debug:
                print(f"[DRIFT] FDI={fdi:.4f} WDI={wdi:.4f}")
            fdi = float(fdi) if fdi is not None else 0.0
            wdi = float(wdi) if wdi is not None else 0.0
            return "DRIFT", fdi, fdi, wdi

        # Make safe for printing
        fdi = float(fdi) if fdi is not None else 0.0
        wdi = float(wdi) if wdi is not None else 0.0

        return None, 0.0, fdi, wdi


    # =================== HELPERS ===================

    def get_recent_performance(self):
        all_preds = list(self.prediction_history)

        successful_predictions = [
            p for p in all_preds
            if p.get('forecast_metrics', {}).get('forecast_success', False)
        ]

        return {
            'total_predictions': len(all_preds),
            'successful_predictions': len(successful_predictions),
            'success_rate': len(successful_predictions) / max(len(all_preds), 1),
            'drift_events': self.performance_stats['drift_events'],
            'anomaly_events': self.performance_stats['anomaly_events'],
            'retraining_events': self.performance_stats['retraining_events'],
            'recent_mse': float(np.mean(list(self.mse_history)[-10:])) if self.mse_history else 0,
            'avg_mse': float(np.mean(self.mse_history)) if self.mse_history else 0,
            'recent_mae': float(np.mean(list(self.mae_history)[-10:])) if self.mae_history else 0,
            'avg_mae': float(np.mean(self.mae_history)) if self.mae_history else 0,
            'transformer_fitted': self.transformer_fitted,
            'last_fdi': self.last_fdi,
            'last_wdi': self.last_wdi,
        }

    def save_performance_state(self, filepath: str):
        """Save performance statistics + prediction history to JSON"""
        try:
            state = {
                'performance_stats': self.performance_stats.copy(),
                'prediction_history': list(self.prediction_history)[-100:],
                'mse_history': list(self.mse_history),
                'mae_history': list(self.mae_history),
                'transformer_fitted': self.transformer_fitted
            }
            with open(filepath, 'w') as f:
                json.dump(state, f, indent=2, default=str)
            print(f"‚úÖ Performance state saved to {filepath}")
        except Exception as e:
            print(f"‚ùå Failed to save performance state: {e}")


#===============================================================================================================================================
###  SENSOR AGENTS - INDIVIDUAL AND MASTER
#--------------------------------------==========================================================================================================
import numpy as np
import pickle
import os
from collections import deque
from datetime import datetime
from typing import Dict, List, Tuple
import warnings
import pandas as pd
warnings.filterwarnings('ignore')

# ML libraries
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error
from scipy import stats
from scipy.spatial.distance import jensenshannon

# Deep learning
try:
    from tensorflow.keras.models import load_model
    KERAS_AVAILABLE = True
except ImportError:
    KERAS_AVAILABLE = False

from sklearn.metrics import (
    confusion_matrix,
    classification_report,
    roc_auc_score,
    average_precision_score,
    precision_recall_curve,
    roc_curve
)


# =====================================================
# ROBUST SENSOR AGENT - Observes ONE sensor with AE model
# =====================================================

class RobustSensorAgent:
    """
    Robust Sensor Agent for ONE sensor with advanced anomaly & drift detection.

    Loads a pretrained autoencoder (AE) model plus its metadata (baseline error
    statistics), scores every incoming window by reconstruction error, and
    derives anomaly / drift / retrain flags from robust rolling statistics
    (median and MAD of the recent reconstruction errors).
    """

    def __init__(self,
                 sensor_id: int,
                 model_path: str = None,
                 window_length: int = 10,  # K
                 memory_size: int = 1000,
                 threshold_k: float = 2.0,
                 drift_threshold: float = 0.1,
                 warmup_steps: int = 100):
        """
        Args:
            sensor_id: Identifier echoed back in every result dict.
            model_path: Optional path to a ``.h5`` AE model; loaded immediately
                when given.
            window_length: Number of timesteps expected in each observed window.
            memory_size: Capacity of the error and raw-data history buffers.
            threshold_k: Anomaly threshold is ``median + threshold_k * mad``.
            drift_threshold: Jensen-Shannon distance above which a step counts
                as a drift vote.
            warmup_steps: Number of initial observations during which the
                rolling statistics stay pinned to the training baseline.
        """
        self.sensor_id = sensor_id
        self.window_length = window_length
        self.threshold_k = threshold_k
        self.drift_threshold = drift_threshold
        self.warmup_steps = warmup_steps

        # Model & metadata
        self.model = None
        self.scaler = None
        self.is_model_loaded = False

        # Buffers
        self.error_memory = deque(maxlen=memory_size)
        self.data_memory = deque(maxlen=memory_size)
        self.recent_errors = deque(maxlen=100)

        # Rolling stats
        self.rolling_stats = {
            'median': 0.0,
            'mad': 1.0,
            'mean': 0.0,   # backward compatibility (mirrors 'median')
            'std': 1.0,    # backward compatibility (mirrors 'mad')
            'q95': 0.0,
            'q99': 0.0
        }
        self.baseline_errors = None

        # Counters
        self.total_processed = 0
        self.anomalies_detected = 0
        self.drift_detected_count = 0
        self.last_stats_update = datetime.now()

        # Remaining steps during which a new anomaly / drift flag is suppressed
        self.anomaly_cooldown = 0
        self.drift_cooldown = 0

        self.anomaly_cooldown_steps = 5    # tunable
        self.drift_cooldown_steps = 10     # tunable

        self.consecutive_drift_votes = 0
        self.consecutive_anomaly_votes = 0

        if model_path:
            self.load_model(model_path)

    def load_model(self, model_path: str) -> bool:
        """
        Load the pretrained AE model and, when present, its metadata pickle.

        The metadata file is looked up next to the model by replacing the
        ``_model.h5`` suffix with ``_metadata.pkl``. Its ``baseline_stats``
        entry seeds the rolling statistics and the baseline error distribution
        used for drift detection.

        Returns:
            True when the model was loaded, False on any failure (the error is
            printed, not raised).
        """
        try:
            if KERAS_AVAILABLE and model_path.endswith('.h5'):
                self.model = load_model(model_path, compile=False)

                metadata_path = model_path.replace('_model.h5', '_metadata.pkl')

                if os.path.exists(metadata_path):
                    # NOTE(review): pickle.load can execute arbitrary code;
                    # only point this at metadata files from a trusted source.
                    with open(metadata_path, 'rb') as f:
                        metadata = pickle.load(f)

                    baseline = metadata.get('baseline_stats', None)

                    if baseline is not None:
                        # Initialise rolling stats from the training baseline
                        self.rolling_stats['median'] = baseline.get('median')
                        self.rolling_stats['mad'] = baseline.get('mad')

                        # Backward compatibility for other parts of system
                        self.rolling_stats['mean'] = self.rolling_stats['median']
                        self.rolling_stats['std'] = self.rolling_stats['mad']

                        self.rolling_stats['q95'] = baseline['q95']
                        self.rolling_stats['q99'] = baseline['q99']

                        # Baseline distribution for drift detection
                        self.baseline_errors = np.array(baseline['baseline_errors'])

                # AE was trained on raw, NOT scaled
                self.scaler = None

            else:
                raise ValueError("Unsupported model format ‚Äì expecting .h5 AE model")

            self.is_model_loaded = True
            print(f"‚úÖ AE model loaded for sensor {self.sensor_id}")
            return True

        except Exception as e:
            print(f"‚ùå Failed to load AE model for sensor {self.sensor_id}: {e}")
            return False

    def observe(self, sensor_subsequence: np.ndarray) -> Dict:
        """
        Observe one window of length ``window_length`` and return its flags.

        Returns:
            A dict with ``sensor_id``, ``timestamp``, the boolean flags
            ``is_anomaly`` / ``drift_flag`` / ``needs_retrain_flag``, the
            ``anomaly_score``, ``confidence``, ``threshold_used`` and the
            running ``anomaly_rate`` / ``drift_rate``. When no model is loaded
            or the window has the wrong length, a dict with an ``error`` key is
            returned instead and no state is updated.
        """
        if not self.is_model_loaded:
            return {"sensor_id": self.sensor_id, "error": "no_model_loaded", "timestamp": datetime.now()}

        if len(sensor_subsequence) != self.window_length:
            return {"sensor_id": self.sensor_id,
                    "error": f"invalid_length_expected_{self.window_length}_got_{len(sensor_subsequence)}",
                    "timestamp": datetime.now()}

        # 1. Anomaly score
        anomaly_score = self._compute_robust_anomaly_score(sensor_subsequence)

        # 2. Update memory
        self.data_memory.append(sensor_subsequence.copy())
        self.error_memory.append(anomaly_score)
        self.recent_errors.append(anomaly_score)

        # 3. Rolling statistics
        if self.total_processed < self.warmup_steps:
            # Warm-up: stats ignore live data and stay pinned to the training
            # baseline. Without a baseline (no metadata file) keep whatever
            # stats are already set instead of crashing on a None baseline.
            if self.baseline_errors is not None and len(self.baseline_errors) > 0:
                med = np.median(self.baseline_errors)
                mad = np.median(np.abs(self.baseline_errors - med)) + 1e-8

                self.rolling_stats['median'] = med
                self.rolling_stats['mad'] = mad
                self.rolling_stats['mean'] = med     # backward compatibility
                self.rolling_stats['std'] = mad
        else:
            # After warm-up, refresh from the last 50 errors every 10th
            # observation. The cadence is driven by the observation counter:
            # the buffer length stops changing once the deque is full, so it
            # cannot be used as a clock.
            observations_seen = self.total_processed + 1
            if len(self.error_memory) >= 50 and observations_seen % 10 == 0:
                self._update_rolling_stats(list(self.error_memory)[-50:])

        # 4. Flags
        is_anomaly = self._check_adaptive_anomaly(anomaly_score)
        drift_flag = self._check_advanced_drift()
        needs_retrain = self._check_retrain_need()
        confidence = self._compute_robust_confidence(anomaly_score)

        # 5. Update counters
        self.total_processed += 1
        if is_anomaly:
            self.anomalies_detected += 1
        if drift_flag:
            self.drift_detected_count += 1

        return {
            "sensor_id": self.sensor_id,
            "timestamp": datetime.now(),
            "is_anomaly": bool(is_anomaly),
            "drift_flag": bool(drift_flag),
            "needs_retrain_flag": bool(needs_retrain),
            "anomaly_score": float(anomaly_score),
            "confidence": float(confidence),
            "threshold_used": float(self.rolling_stats['median'] + self.threshold_k * self.rolling_stats['mad']),
            "anomaly_rate": self.anomalies_detected / max(1, self.total_processed),
            "drift_rate": self.drift_detected_count / max(1, self.total_processed)
        }

    def _compute_robust_anomaly_score(self, subsequence: np.ndarray) -> float:
        """
        Return the AE reconstruction MSE of the RAW window.

        If inference fails for any reason, a warning is printed and the
        variance of the raw window is returned as a fallback score.
        """
        try:
            # Model input shape: [1, window_length, 1]
            X = subsequence.reshape(1, self.window_length, 1)
            reconstruction = self.model.predict(X, verbose=0)

            error = mean_squared_error(
                subsequence.flatten(),
                reconstruction.flatten()
            )
            return max(0.0, error)

        except Exception as e:
            print(f"‚ö†Ô∏è AE inference failed for sensor {self.sensor_id}: {e}")
            # Fallback: variance of raw subsequence
            return float(np.var(subsequence))

    def _update_rolling_stats(self, errors: List[float]):
        """Recompute median, MAD and the q95/q99 bands from ``errors``."""
        errors_array = np.array(errors)

        median = np.median(errors_array)
        mad = np.median(np.abs(errors_array - median)) + 1e-8  # avoid zero

        self.rolling_stats['median'] = median
        self.rolling_stats['mad'] = mad

        # Backward compatibility fields (for plotting)
        self.rolling_stats['mean'] = median
        self.rolling_stats['std'] = mad

        # Percentile bands
        self.rolling_stats['q95'] = np.percentile(errors_array, 95)
        self.rolling_stats['q99'] = np.percentile(errors_array, 99)

        self.last_stats_update = datetime.now()

    def _check_adaptive_anomaly(self, score: float) -> bool:
        """
        Flag ``score`` as anomalous when it exceeds ``median + k * mad``.

        After a positive flag, the next ``anomaly_cooldown_steps`` calls return
        False regardless of the score.
        """
        median = self.rolling_stats.get('median', self.rolling_stats['mean'])
        mad = self.rolling_stats.get('mad', self.rolling_stats['std'])
        threshold = median + self.threshold_k * mad
        is_anomaly_now = score > threshold

        # Cooldown active -> suppress anomaly
        if self.anomaly_cooldown > 0:
            self.anomaly_cooldown -= 1
            return False

        # No cooldown and anomaly happened -> activate cooldown
        if is_anomaly_now:
            self.anomaly_cooldown = self.anomaly_cooldown_steps
            return True

        return False

    def _check_advanced_drift(self) -> bool:
        """
        Compare the recent error distribution with the training baseline.

        Both are binned on the baseline's 20-bin histogram and compared with
        the Jensen-Shannon distance. Three consecutive votes above
        ``drift_threshold`` raise the flag and start a cooldown. If the
        histogram comparison raises, a two-sample KS test (p < 0.05) is used
        instead.

        Returns False when there is no baseline or fewer than 30 recent errors.
        """
        if (self.baseline_errors is None
                or len(self.baseline_errors) == 0
                or len(self.recent_errors) < 30):
            return False
        try:
            hist_baseline, bins = np.histogram(self.baseline_errors, bins=20, density=True)

            # Errors outside the baseline range would be dropped by
            # np.histogram; if ALL of them are outside, the density is 0/0 =
            # NaN and the comparison below silently evaluates to False. Clip
            # them into the edge bins so out-of-range mass still counts.
            recent = np.clip(list(self.recent_errors), bins[0], bins[-1])
            hist_recent, _ = np.histogram(recent, bins=bins, density=True)

            hist_baseline += 1e-10
            hist_recent += 1e-10
            hist_baseline /= hist_baseline.sum()
            hist_recent /= hist_recent.sum()

            js_divergence = jensenshannon(hist_baseline, hist_recent)
            is_drift_now = js_divergence > self.drift_threshold

            # Cooldown active -> suppress
            if self.drift_cooldown > 0:
                self.drift_cooldown -= 1
                return False

            # Multi-step confirmation: require 3 consecutive drift votes
            if is_drift_now:
                self.consecutive_drift_votes += 1
            else:
                self.consecutive_drift_votes = 0

            if self.consecutive_drift_votes >= 3:
                self.drift_cooldown = self.drift_cooldown_steps
                self.consecutive_drift_votes = 0
                return True

            return False

        except Exception:
            try:
                _, p_value = stats.ks_2samp(self.baseline_errors, list(self.recent_errors))
                return p_value < 0.05
            except Exception:
                return False

    def _check_retrain_need(self) -> bool:
        """
        Suggest retraining when at least two of four criteria hold.

        Criteria (evaluated on the last 50 errors once 100 are available):
        recent anomaly rate above 30%, drift flagged on more than 10% of
        processed windows, recent mean error above twice the rolling centre,
        and rolling stats not refreshed for more than 7 days.
        """
        if len(self.error_memory) < 100:
            return False
        recent_errors = list(self.error_memory)[-50:]
        threshold = self.rolling_stats['mean'] + self.threshold_k * self.rolling_stats['std']
        anomaly_rate = sum(1 for e in recent_errors if e > threshold) / len(recent_errors)
        criteria = [
            anomaly_rate > 0.3,
            self.drift_detected_count > 0.1 * self.total_processed,
            np.mean(recent_errors) > 2.0 * self.rolling_stats['mean'] if len(recent_errors) > 0 else False,
            (datetime.now() - self.last_stats_update).days > 7
        ]
        return sum(criteria) >= 2

    def _compute_robust_confidence(self, score: float) -> float:
        """
        Map the distance of ``score`` from the anomaly threshold to [0, 1].

        Uses a logistic function of ``(score - threshold) / mad``, so a score
        exactly at the threshold yields 0.5. Returns 0.5 when ``mad`` is 0.
        """
        median = self.rolling_stats.get('median')
        mad = self.rolling_stats.get('mad')

        if mad == 0:
            return 0.5

        threshold = median + self.threshold_k * mad

        z = (score - threshold) / mad  # how far beyond threshold?

        # Smooth probability-like mapping
        confidence = 1 / (1 + np.exp(-z))

        return float(np.clip(confidence, 0.0, 1.0))



# =====================================================
# ROBUST MASTER AGENT
# =====================================================

class RobustMasterAgent:
    """
    System-level aggregator over a list of per-sensor agents.

    Each column of the incoming window is handed to the sensor agent at the
    same index; the fraction of sensors raising each flag is then compared
    against the configured thresholds to produce system-wide decisions.
    """

    def __init__(self, sensor_agents: List[RobustSensorAgent],
                 system_anomaly_threshold: float = 0.3,
                 drift_threshold: float = 0.2,
                 retrain_threshold: float = 0.15):
        """
        Args:
            sensor_agents: One agent per sensor column, in column order.
            system_anomaly_threshold: Minimum fraction of anomalous sensors
                for a system anomaly.
            drift_threshold: Minimum fraction of drifting sensors for system
                drift.
            retrain_threshold: Minimum fraction of sensors requesting a
                retrain for a system retrain suggestion.
        """
        self.sensor_agents = sensor_agents
        self.num_sensors = len(sensor_agents)
        self.system_anomaly_threshold = system_anomaly_threshold
        self.drift_threshold = drift_threshold
        self.retrain_threshold = retrain_threshold

    def process_system_input(self, system_subsequence: np.ndarray) -> Dict:
        """
        Process one multivariate window shaped [window_length, num_sensors].

        Returns:
            ``{"timestamp", "sensor_results", "system_decisions"}``, or
            ``{"error", "timestamp"}`` when the column count does not match
            the number of sensor agents.
        """
        timestamp = datetime.now()

        n_columns = system_subsequence.shape[1]
        if n_columns != self.num_sensors:
            return {"error": f"Expected {self.num_sensors} sensors, got {n_columns}",
                    "timestamp": timestamp}

        # Per-sensor observations, one column per agent
        sensor_results = [
            agent.observe(system_subsequence[:, column])
            for column, agent in enumerate(self.sensor_agents)
        ]

        # Fraction of sensors raising each flag; results lacking the key
        # (e.g. error dicts) count as "not raised".
        denominator = max(1, self.num_sensors)

        def fraction_raising(flag_name):
            raised = sum(1 for outcome in sensor_results if outcome.get(flag_name))
            return raised / denominator

        anomaly_rate = fraction_raising("is_anomaly")
        drift_rate = fraction_raising("drift_flag")
        retrain_rate = fraction_raising("needs_retrain_flag")

        return {
            "timestamp": timestamp,
            "sensor_results": sensor_results,
            "system_decisions": {
                "system_anomaly": anomaly_rate >= self.system_anomaly_threshold,
                "system_drift": drift_rate >= self.drift_threshold,
                "system_needs_retrain": retrain_rate >= self.retrain_threshold,
                "anomaly_rate": anomaly_rate,
                "drift_rate": drift_rate,
                "retrain_rate": retrain_rate,
            },
        }



# =====================================================
# SYSTEM CREATION
# =====================================================

def create_robust_system(num_sensors: int, models_dir: str, win_length: int, warmup_steps: int = 100) -> Tuple[List[RobustSensorAgent], RobustMasterAgent]:
    """
    Build one RobustSensorAgent per sensor plus the RobustMasterAgent over them.

    For sensor ``i`` the model file ``<models_dir>/sensor_<i>_model.h5`` is
    used when it exists; otherwise the agent is created without a model.

    Returns:
        ``(sensor_agents, master)``.
    """
    print(f"üöÄ Creating robust system with {num_sensors} sensors")

    def build_agent(index):
        # Only hand over the path when the model file is actually present
        candidate = os.path.join(models_dir, f"sensor_{index}_model.h5")
        usable_path = candidate if os.path.exists(candidate) else None
        return RobustSensorAgent(sensor_id=index,
                                 model_path=usable_path,
                                 window_length=win_length,
                                 memory_size=1000,
                                 threshold_k=2.0,
                                 drift_threshold=0.1,
                                 warmup_steps=warmup_steps)

    sensor_agents = [build_agent(index) for index in range(num_sensors)]

    master = RobustMasterAgent(sensor_agents=sensor_agents,
                               system_anomaly_threshold=0.3,
                               drift_threshold=0.2,
                               retrain_threshold=0.15)

    loaded_count = sum(1 for agent in sensor_agents if agent.is_model_loaded)
    print(f"‚úÖ Created system: {loaded_count}/{num_sensors} models loaded")

    return sensor_agents, master



# =====================================================
# DECISION AGENT (replaces / extends CoordinatorAgent)
# =====================================================
from typing import Optional

# =====================================================
# DecisionAgent V2-PURE (Prediction-focused fusion)
# =====================================================
from typing import Optional, Dict, Any
from datetime import datetime
import numpy as np


class DecisionAgent:
    """
    DecisionAgent V2-PURE:

    - INPUTS (same call signature as before):
        * master_output: RobustMasterAgent.process_system_input(...)
        * window_output: AdaptiveWindowAgent.predict_window_size(...)
        * model_outputs: optional dict, expected keys:
              - "failure_prob": float in [0,1]  (ML failure prediction)
              - "prototype_score": float (optional, diagnostic only)
              - "transformer_prob": float (optional, diagnostic only)
        * metadata: optional dict (e.g. index, timestamps)

    - LOGIC:
        * Detection layer (ANOMALY / DRIFT / RETRAIN)
              -> derived ONLY from sensor + window signals
        * Prediction layer (FAILURE PREDICTION)
              -> derived from ML failure_prob; a moderate probability can be
                 promoted to a failure when the sensor context is strong
        * Fusion:
              -> prediction dominates alert_level / severity
              -> detection still influences severity but has lower weight

    - OUTPUT (unified decision packet):
        {
          "timestamp": ISO-8601 string
          "final_failure": bool               # prediction-driven
          "final_anomaly": bool               # sensor-driven
          "final_drift": bool                 # sensor/window-driven
          "final_retrain": bool               # retrain suggestion
          "alert_level": "NORMAL"/"LOW"/"MEDIUM"/"HIGH"/"CRITICAL",
          "scores": {
              "prediction_failure_prob": float,
              "prediction_fused_score": float,      # main score for alerting
              "sensor_anomaly_rate": float,
              "sensor_drift_rate": float,
              "sensor_retrain_rate": float,
              "window_event_type": str or None,
              "window_event_is_anomaly": 0/1,
              "window_event_is_drift": 0/1,
              "window_mse": float or None,
              "prototype_score": float or None,
              "transformer_prob": float or None,
          },
          "window_agent": {
              "predicted_window": int or None,
              "window_anomaly": bool,
              "window_drift": bool,
          },
          "raw": {
              "master_output": ...,
              "window_output": ...,
              "model_outputs": ...,
              "metadata": ...,
          },
        }
    """

    def __init__(
        self,
        # weights for FUSION (prediction vs detection)
        prediction_weight: float = 0.7,     # ML failure prediction dominates
        sensor_weight: float = 0.2,         # sensor anomaly contribution
        drift_weight: float = 0.1,          # drift contribution

        # thresholds
        failure_threshold: float = 0.5,     # base failure decision
        failure_critical_threshold: float = 0.8,  # high-risk cutoff

        anomaly_rate_threshold: float = 0.3,    # for final_anomaly
        drift_rate_threshold: float = 0.2,      # for final_drift
        retrain_rate_threshold: float = 0.15,   # for final_retrain
    ):
        # Fusion weights (normalised to sum to 1 inside decide())
        self.prediction_weight = prediction_weight
        self.sensor_weight = sensor_weight
        self.drift_weight = drift_weight

        self.failure_threshold = failure_threshold
        self.failure_critical_threshold = failure_critical_threshold

        self.anomaly_rate_threshold = anomaly_rate_threshold
        self.drift_rate_threshold = drift_rate_threshold
        self.retrain_rate_threshold = retrain_rate_threshold

        # Every decision packet returned by decide() is appended here
        self.history = []

    # -------------------------------------------------
    # Helper: probability sanitising
    # -------------------------------------------------
    @staticmethod
    def _sanitize_probability(value) -> float:
        """
        Coerce a model probability into a finite float in [0, 1].

        ``None`` and non-finite values (NaN, +/-inf) are treated as "no
        prediction" and become 0.0. A NaN would otherwise poison the fused
        score, fail every ``<`` comparison in the alert mapping and end up as
        a CRITICAL alert.
        """
        if value is None:
            return 0.0
        prob = float(value)
        if not np.isfinite(prob):
            return 0.0
        return float(np.clip(prob, 0.0, 1.0))

    # -------------------------------------------------
    # Main decision function
    # -------------------------------------------------
    def decide(
        self,
        master_output: Dict[str, Any],
        window_output: Dict[str, Any],
        model_outputs: Optional[Dict[str, Any]] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """
        Fuse sensor, window and ML-model signals into one decision packet.

        Args:
            master_output: Dict carrying a ``system_decisions`` sub-dict; may
                be empty or None.
            window_output: Dict optionally carrying ``event_type``,
                ``predicted_window`` and ``forecast_metrics``; may be empty or
                None.
            model_outputs: Optional dict with ``failure_prob`` and the
                diagnostic ``prototype_score`` / ``transformer_prob``.
            metadata: Optional dict passed through untouched under ``raw``.

        Returns:
            The decision packet described in the class docstring. The packet
            is also appended to ``self.history``.
        """
        timestamp = datetime.now()

        # -----------------------------
        # 1) Extract master/system info
        # -----------------------------
        sys_dec = master_output.get("system_decisions", {}) if master_output else {}
        sys_anomaly = bool(sys_dec.get("system_anomaly", False))
        sys_drift = bool(sys_dec.get("system_drift", False))
        sys_retrain = bool(sys_dec.get("system_needs_retrain", False))

        anomaly_rate = float(sys_dec.get("anomaly_rate", 0.0))
        drift_rate = float(sys_dec.get("drift_rate", 0.0))
        retrain_rate = float(sys_dec.get("retrain_rate", 0.0))

        # -----------------------------
        # 2) Extract window agent info
        # -----------------------------
        window_event_type = window_output.get("event_type", None) if window_output else None
        win_anomaly = (window_event_type == "ANOMALY")
        win_drift = (window_event_type == "DRIFT")

        predicted_window = window_output.get("predicted_window", None) if window_output else None

        # forecast_metrics may be absent or explicitly None
        forecast_metrics = window_output.get("forecast_metrics", None) if window_output else None
        window_mse = forecast_metrics.get("mse", None) if forecast_metrics is not None else None

        # -----------------------------
        # 3) Extract ML prediction info
        # -----------------------------
        failure_prob = None
        prototype_score = None
        transformer_prob = None

        if model_outputs is not None:
            # failure prediction is the only ML input to the decision logic
            failure_prob = model_outputs.get("failure_prob", None)
            # kept as diagnostic/extra info, NOT used for anomaly detection
            prototype_score = model_outputs.get("prototype_score", None)
            transformer_prob = model_outputs.get("transformer_prob", None)

        # Guard: missing / non-finite -> 0.0, otherwise clipped to [0, 1]
        failure_prob = self._sanitize_probability(failure_prob)

        # -----------------------------
        # 4) Detection layer (ANOMALY / DRIFT / RETRAIN)
        #    -> purely sensor + window based
        # -----------------------------
        # final_anomaly: anomaly rate over threshold OR system flag OR window anomaly
        final_anomaly = (
            anomaly_rate >= self.anomaly_rate_threshold
            or sys_anomaly
            or win_anomaly
        )

        # final_drift: drift rate over threshold OR system flag OR window drift
        final_drift = (
            drift_rate >= self.drift_rate_threshold
            or sys_drift
            or win_drift
        )

        # final_retrain: explicit retrain flag OR high retrain rate
        # OR drift accompanied by a non-trivial anomaly rate
        final_retrain = (
            sys_retrain
            or retrain_rate >= self.retrain_rate_threshold
            or (final_drift and anomaly_rate > 0.1)
        )

        # -----------------------------
        # 5) Prediction layer (FAILURE)
        # -----------------------------
        # Base failure flag purely from probability
        base_failure = failure_prob >= self.failure_threshold

        # Context-aware adjustment: a moderate failure_prob (>= 70% of the
        # threshold) is promoted to a failure when sensor anomalies are strong
        strong_sensor_context = (
            anomaly_rate > (self.anomaly_rate_threshold * 1.2)
            or (win_anomaly and anomaly_rate > 0.1)
        )

        if not base_failure and failure_prob >= (self.failure_threshold * 0.7) and strong_sensor_context:
            final_failure = True
        else:
            final_failure = base_failure

        # -----------------------------
        # 6) Fused severity / alert scoring
        # -----------------------------
        # normalise weights; degenerate weights fall back to prediction-only
        w_sum = self.prediction_weight + self.sensor_weight + self.drift_weight
        if w_sum <= 0:
            w_pred = 1.0
            w_sens = 0.0
            w_drift = 0.0
        else:
            w_pred = self.prediction_weight / w_sum
            w_sens = self.sensor_weight / w_sum
            w_drift = self.drift_weight / w_sum

        # detection "intensity", bounded in [0,1]
        detection_intensity = np.clip(anomaly_rate, 0.0, 1.0)
        drift_intensity = np.clip(drift_rate, 0.0, 1.0)

        # fused prediction-driven risk score, with small fixed bonuses for
        # window-level events
        prediction_fused_score = (
            w_pred * failure_prob +
            w_sens * detection_intensity +
            w_drift * drift_intensity +
            (0.05 if win_anomaly else 0.0) +
            (0.03 if win_drift else 0.0)
        )
        prediction_fused_score = float(np.clip(prediction_fused_score, 0.0, 1.0))

        # -----------------------------
        # 7) Map fused score + flags -> alert_level
        # -----------------------------
        alert_level = self._compute_alert_level(
            failure_prob=failure_prob,
            fused_score=prediction_fused_score,
            final_failure=final_failure,
            final_anomaly=final_anomaly,
            final_drift=final_drift,
        )

        # -----------------------------
        # 8) Build decision packet
        # -----------------------------
        decision = {
            "timestamp": timestamp.isoformat(),
            # prediction-focused
            "final_failure": bool(final_failure),
            # detection-focused
            "final_anomaly": bool(final_anomaly),
            "final_drift": bool(final_drift),
            "final_retrain": bool(final_retrain),
            # human-facing
            "alert_level": alert_level,
            # structured scores
            "scores": {
                # Prediction side
                "prediction_failure_prob": float(failure_prob),
                "prediction_fused_score": float(prediction_fused_score),

                # Detection side
                "sensor_anomaly_rate": float(anomaly_rate),
                "sensor_drift_rate": float(drift_rate),
                "sensor_retrain_rate": float(retrain_rate),

                # Window side
                "window_event_type": window_event_type,
                "window_event_is_anomaly": 1.0 if win_anomaly else 0.0,
                "window_event_is_drift": 1.0 if win_drift else 0.0,
                "window_mse": float(window_mse) if window_mse is not None else None,

                # Extra ML diagnostics (not used for detection)
                "prototype_score": float(prototype_score) if prototype_score is not None else None,
                "transformer_prob": float(transformer_prob) if transformer_prob is not None else None,
            },
            "window_agent": {
                "predicted_window": predicted_window,
                "window_anomaly": bool(win_anomaly),
                "window_drift": bool(win_drift),
            },
            "raw": {
                "master_output": master_output,
                "window_output": window_output,
                "model_outputs": model_outputs,
                "metadata": metadata,
            },
        }

        self.history.append(decision)
        return decision

    # -------------------------------------------------
    # Helper: alert level mapping
    # -------------------------------------------------
    def _compute_alert_level(
        self,
        failure_prob: float,
        fused_score: float,
        final_failure: bool,
        final_anomaly: bool,
        final_drift: bool,
    ) -> str:
        """
        Map scores + flags into a discrete alert level.

        ML failure prediction dominates; sensor anomalies modulate.

        Returns:
            One of "NORMAL", "LOW", "MEDIUM", "HIGH", "CRITICAL".
        """
        # Hard overrides for critical cases
        if failure_prob >= self.failure_critical_threshold and (final_anomaly or final_drift):
            return "CRITICAL"

        if final_failure and fused_score >= 0.9:
            return "CRITICAL"

        # NORMAL is only reachable when no flag is raised at all; with any
        # flag set, even a tiny fused score is reported as at least LOW.
        no_flags_raised = not (final_failure or final_anomaly or final_drift)
        if no_flags_raised and fused_score < 0.3:
            return "NORMAL"

        # Score bands
        if fused_score < 0.5:
            return "LOW"
        if fused_score < 0.7:
            return "MEDIUM"
        if fused_score < 0.9:
            return "HIGH"

        return "CRITICAL"


# =====================================================
# UPGRADED EXPERT AGENT (SOTA-STYLE FOR METROPT)
# =====================================================
import json
from datetime import datetime
from typing import Optional, Dict, Any, List
import numpy as np


class ExpertAgent:
    """
    LLM-backed diagnosis agent for anomalies flagged upstream.

    Given a decision packet, the current multivariate sensor window and the
    recent event history, the agent:
      - builds a text prompt (system description, sensor metadata, known
        fault patterns, scores/flags, per-sensor summary, a numeric snapshot
        of the window tail, compressed history, JSON response schema);
      - asks the configured LLM for a JSON verdict;
      - validates the verdict so that it always carries the string keys
        summary, explanation, likely_fault, recommended_action, severity.

    When no client is configured, or every attempt fails, a fixed fallback
    verdict is returned instead of raising.
    """

    # Closed vocabularies advertised to the LLM in the prompt schema.
    ALLOWED_ACTIONS = frozenset({
        "IGNORE",
        "MONITOR",
        "ACK_AND_INVESTIGATE",
        "SCHEDULE_MAINTENANCE",
        "IMMEDIATE_SHUTDOWN",
    })
    ALLOWED_SEVERITIES = frozenset({"LOW", "MEDIUM", "HIGH", "CRITICAL"})

    # Substitutes for keys the LLM omitted or returned as null.
    # Insertion order defines the key order of the validated result.
    _LLM_DEFAULTS = {
        "summary": "No summary provided by LLM.",
        "explanation": "No explanation provided by LLM.",
        "likely_fault": "Unknown",
        "recommended_action": "ACK_AND_INVESTIGATE",
        "severity": "MEDIUM",
    }

    def __init__(
        self,
        event_store,
        system_description: str = "MetroPT Air Production Unit (APU)",
        llm_client: Optional[object] = None,
        history_limit: int = 100,
        max_history_for_prompt: int = 10,
        window_preview_len: int = 10,
        sensor_metadata: Optional[Dict[int, str]] = None,
        fault_knowledge: Optional[List[str]] = None,
        max_retries: int = 2,
        llm_model: str = "gpt-4o-mini",
        max_output_tokens: int = 400,
    ):
        """
        event_store            : object exposing fetch_recent(limit=...) (past events)
        system_description     : short description of the monitored system
        llm_client             : client exposing responses.create(...) (e.g. an
                                 OpenAI client); None disables the LLM call
        history_limit          : how many past events to fetch from the store
        max_history_for_prompt : how many of those are shown to the LLM
        window_preview_len     : last N timesteps per sensor shown to the LLM
        sensor_metadata        : mapping sensor column index -> description
        fault_knowledge        : list of domain fault patterns (strings)
        max_retries            : number of LLM attempts before falling back
        llm_model              : model name passed to responses.create
        max_output_tokens      : output token cap passed to responses.create
        """
        self.store = event_store
        self.system_description = system_description
        self.llm_client = llm_client
        self.history_limit = history_limit
        self.max_history_for_prompt = max_history_for_prompt
        self.window_preview_len = window_preview_len
        self.max_retries = max_retries
        self.llm_model = llm_model
        self.max_output_tokens = max_output_tokens

        # Default MetroPT-style sensor metadata (can be overridden)
        if sensor_metadata is None:
            self.sensor_metadata = {
                0: "TP2 (bar): the measure of the pressure on the compressor",
                1: "TP3 (bar): the measure of the pressure generated at the pneumatic panel",
                2: "H1 (bar) ‚Äì the measure of the pressure generated due to pressure drop when the discharge of the cyclonic separator filter occurs",
                3: "DV pressure (bar): the measure of the pressure drop generated when the towers discharge air dryers; a zero reading indicates that the compressor is operating under load",
                4: "Reservoirs (bar): the measure of the downstream pressure of the reservoirs, which should be close to the pneumatic panel pressure (TP3)",
                5: "Oil Temperature (¬∫C) :  the measure of the oil temperature on the compressor",
                6: "Motor Current (A) ‚Äì  the measure of the current of one phase of the three-phase motor; it presents values close to 0A - when it turns off, 4A - when working offloaded, 7A - when working under load, and 9A - when it starts working",
                7: "COMP - the electrical signal of the air intake valve on the compressor; it is active when there is no air intake, indicating that the compressor is either turned off or operating in an offloaded state.",
                8: "DV electric ‚Äì the electrical signal that controls the compressor outlet valve; it is active when the compressor is functioning under load and inactive when the compressor is either off or operating in an offloaded state.",
                9: "TOWERS ‚Äì the electrical signal that defines the tower responsible for drying the air and the tower responsible for draining the humidity removed from the air; when not active, it indicates that tower one is functioning; when active, it indicates that tower two is in operation.",
                10: "MPG ‚Äì the electrical signal responsible for starting the compressor under load by activating the intake valve when the pressure in the air production unit (APU) falls below 8.2 bar; it activates the COMP sensor, which assumes the same behaviour as the MPG sensor",
                11: "LPS ‚Äì the electrical signal that detects and activates when the pressure drops below 7 bars",

            }
        else:
            self.sensor_metadata = sensor_metadata

        # Default MetroPT-style fault patterns (can be overridden)
        if fault_knowledge is None:
            self.fault_knowledge = [
                "Air Leak: gradual pressure decay + increased compressor runtime.",
                "Blockage: oscillatory or unstable pressure + abnormal valve cycling.",
                "Overheating: rising oil temperature + increased motor current.",
                "Valve Stuck: valve digital state frozen while pressure behaviour is abnormal.",
                "Short Cycling: frequent compressor start/stop in short intervals.",
                "Sensor Failure: flat-line, impossible values, or inconsistent readings.",
            ]
        else:
            self.fault_knowledge = fault_knowledge

    # =====================================================
    # PUBLIC ENTRYPOINT
    # =====================================================
    def analyse_anomaly(
        self,
        decision_packet: dict,
        system_subsequence: np.ndarray,
        extra_context: Optional[dict] = None,
    ) -> dict:
        """
        Main entry point.

          - decision_packet    : dict of scores/flags for the current window
          - system_subsequence : [window_length, num_sensors] array (a 1-D
                                 array is treated as a single sensor)
          - extra_context      : anything else (index, timestamps, etc.)

        Returns an expert_packet dict with:
          - timestamp       (ISO string, local time)
          - decision_packet (the input, unchanged)
          - prompt_used     (full prompt text)
          - llm_result      (validated JSON from the model, or the fallback)
        """
        # 1) Fetch recent past events from the store
        recent_raw_decisions = self.store.fetch_recent(limit=self.history_limit)

        # 2) Build rich prompt
        prompt = self._build_prompt(
            decision_packet=decision_packet,
            system_subsequence=system_subsequence,
            extra_context=extra_context,
            recent_events=recent_raw_decisions,
        )

        # 3) Call LLM with JSON-only contract
        llm_result = self._call_llm_with_json(prompt)

        return {
            "timestamp": datetime.now().isoformat(),
            "decision_packet": decision_packet,
            "prompt_used": prompt,
            "llm_result": llm_result,
        }

    # =====================================================
    # PROMPT CONSTRUCTION
    # =====================================================
    @staticmethod
    def _json_default(value: Any) -> Any:
        """json.dumps fallback: unwrap NumPy scalars/arrays, else use str()."""
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, np.ndarray):
            return value.tolist()
        return str(value)

    def _summarise_sensors(self, master_output: Optional[dict]) -> List[dict]:
        """Flatten master_output["sensor_results"] into plain-Python records."""
        if master_output is None:
            return []
        summary = []
        for i, res in enumerate(master_output.get("sensor_results", [])):
            summary.append({
                "sensor_index": i,
                "name": self.sensor_metadata.get(i, f"Sensor_{i}"),
                "is_anomaly": bool(res.get("is_anomaly", False)),
                "drift_flag": bool(res.get("drift_flag", False)),
                "needs_retrain": bool(res.get("needs_retrain_flag", False)),
                "anomaly_score": float(res.get("anomaly_score", 0.0)),
                "confidence": float(res.get("confidence", 0.0)),
            })
        return summary

    def _window_snapshot(self, system_subsequence) -> tuple:
        """
        Return (preview_len, {sensor_name: last preview_len values}).

        Raises ValueError when the input is neither 1-D nor 2-D.
        """
        seq = np.asarray(system_subsequence)
        if seq.ndim == 1:
            # A single sensor: treat the values as one column.
            seq = seq.reshape(-1, 1)
        if seq.ndim != 2:
            raise ValueError(
                "system_subsequence must be [timesteps, sensors]; "
                f"got shape {seq.shape}"
            )
        n_steps, n_sensors = seq.shape

        preview_len = max(0, min(self.window_preview_len, n_steps))
        # Slice from an explicit start index: seq[-0:] would return the
        # WHOLE window when preview_len is 0.
        tail = seq[n_steps - preview_len:]

        snapshot = {}
        for j in range(n_sensors):
            name = self.sensor_metadata.get(j, f"Sensor_{j}")
            snapshot[name] = tail[:, j].round(4).tolist()
        return preview_len, snapshot

    def _summarise_history(self, recent_events: Optional[List[dict]]) -> List[dict]:
        """Reduce the newest stored events to timestamp/alert/score records."""
        summaries = []
        for ev in (recent_events or [])[: self.max_history_for_prompt]:
            try:
                scores_ev = ev.get("scores", {})
                summaries.append({
                    "timestamp": ev.get("timestamp", ""),
                    "alert_level": ev.get("alert_level", "UNKNOWN"),
                    "anomaly_score": scores_ev.get("anomaly_score", None),
                    "drift_score": scores_ev.get("drift_score", None),
                })
            except AttributeError:
                # Record (or its "scores" entry) is not a dict: history is
                # best-effort context, so skip it rather than fail the prompt.
                continue
        return summaries

    def _build_prompt(
        self,
        decision_packet: dict,
        system_subsequence: np.ndarray,
        extra_context: Optional[dict],
        recent_events: List[dict],
    ) -> str:
        """
        Build a rich natural-language + structured prompt for the LLM.
        Includes:
          - system description
          - sensor metadata
          - known fault patterns
          - current decision scores + flags
          - per-sensor anomaly info (if available)
          - small numeric snapshot of current window
          - compressed recent history
          - explicit JSON response schema

        Raises ValueError if system_subsequence is not 1-D or 2-D.
        """
        # 1) Top-level scores and flags
        scores = decision_packet.get("scores", {})
        window_info = decision_packet.get("window_agent", {})
        alert_level = decision_packet.get("alert_level", "UNKNOWN")
        final_anomaly = decision_packet.get("final_anomaly", False)
        final_drift = decision_packet.get("final_drift", False)
        final_retrain = decision_packet.get("final_retrain", False)

        # 2) Raw agent/model outputs (if available)
        raw_block = decision_packet.get("raw", {})
        model_outputs = raw_block.get("model_outputs", {})
        prototype_score = model_outputs.get("prototype_score", None)
        transformer_prob = model_outputs.get("transformer_prob", None)

        # 3) Per-sensor stats, 4) window tail, 5) compressed history
        per_sensor_summary = self._summarise_sensors(raw_block.get("master_output", None))
        preview_len, sensor_window_dict = self._window_snapshot(system_subsequence)
        history_summaries = self._summarise_history(recent_events)

        # 6) Sensor metadata & fault patterns as text
        sensor_meta_list = [f"{idx}: {desc}" for idx, desc in self.sensor_metadata.items()]
        fault_knowledge_text = "\n".join(f"- {fk}" for fk in self.fault_knowledge)

        # The window-agent block may carry NumPy scalars (e.g. a predicted
        # window length), which plain json.dumps rejects.
        window_info_json = json.dumps(window_info, default=self._json_default)

        # 7) Final instruction with JSON schema
        schema_instruction = """
You MUST respond with ONLY a single valid JSON object, no extra text.
The JSON MUST have exactly the following keys at the top level:

- "summary": short 1-2 sentence description of what is happening.
- "explanation": 2-6 sentences, clear reasoning in industrial / physical terms.
- "likely_fault": short label of the most likely fault type, or "Unknown".
- "recommended_action": one of:
    - "IGNORE"
    - "MONITOR"
    - "ACK_AND_INVESTIGATE"
    - "SCHEDULE_MAINTENANCE"
    - "IMMEDIATE_SHUTDOWN"
- "severity": one of: "LOW", "MEDIUM", "HIGH", "CRITICAL".

Do NOT include any extra keys outside these five.
Do NOT include any surrounding text, markdown, or commentary.
Return ONLY the JSON object.
"""

        prompt = f"""
You are an expert industrial fault diagnosis assistant for: {self.system_description}.

System context:
- This is a compressed-air production unit (APU) of a metro train.
- Sensors include analog (pressure, current, temperature) and digital (valves, states).

Sensor metadata (index -> description):
{json.dumps(sensor_meta_list, indent=2)}

Known fault patterns (domain knowledge):
{fault_knowledge_text}

CURRENT DECISION PACKET:

- Alert level: {alert_level}

PREDICTION SIGNALS (from ML failure model):
- final_failure (prediction-based): {decision_packet.get("final_failure")}
- prediction_failure_prob: {scores.get("prediction_failure_prob")}
- prediction_fused_score: {scores.get("prediction_fused_score")}

DETECTION SIGNALS (from sensor & window agents):
- final_anomaly (sensor-based): {final_anomaly}
- final_drift (sensor/window-based): {final_drift}
- final_retrain: {final_retrain}

All prediction and detection signals are separate:
- Prediction = future failure likelihood (ML)
- Detection = current abnormalities (sensor + drift + window)

Window agent info:
{window_info_json}

Prototype / transformer outputs (if any):
- prototype_score: {prototype_score}
- transformer_prob: {transformer_prob}

Per-sensor anomaly summary:
{json.dumps(per_sensor_summary, indent=2)}

Current sensor window snapshot (last {preview_len} timesteps for each sensor):
{json.dumps(sensor_window_dict, indent=2)}

Recent history of decisions (compressed):
{json.dumps(history_summaries, indent=2)}

Extra context:
{json.dumps(extra_context, default=str)}

Your tasks:
1. Decide whether this anomaly is likely REAL or a FALSE POSITIVE.
2. Infer which fault pattern (if any) best matches the evidence.
3. Explain the reasoning in terms of sensor behaviour (pressure, current, valves, temperature).
4. Recommend the next action for the human operator, considering safety and cost.
5. Assign a severity level: LOW, MEDIUM, HIGH, or CRITICAL.

{schema_instruction}
"""
        return prompt

    # =====================================================
    # LLM CALL + JSON HANDLING
    # =====================================================
    @staticmethod
    def _fallback_result(summary: str, explanation: str) -> dict:
        """Build the conservative verdict used when the LLM is unavailable."""
        return {
            "summary": summary,
            "explanation": explanation,
            "likely_fault": "Unknown",
            "recommended_action": "ACK_AND_INVESTIGATE",
            "severity": "MEDIUM",
        }

    def _call_llm_with_json(self, prompt: str) -> dict:
        """
        Call the LLM via self.llm_client and return a validated JSON dict.

        Makes up to self.max_retries attempts; any exception (transport error,
        unparsable or non-object JSON) counts as a failed attempt.  Never
        raises: without a client, or after the last failed attempt, a fixed
        fallback verdict is returned.
        """
        # If no client configured, return fallback
        if self.llm_client is None:
            return self._fallback_result(
                "Anomaly detected (fallback, no LLM client configured).",
                "ExpertAgent has no LLM client; returning default recommendation.",
            )

        last_error = None

        for attempt in range(self.max_retries):
            try:
                response = self.llm_client.responses.create(
                    model=self.llm_model,
                    input=prompt,
                    max_output_tokens=self.max_output_tokens,
                )

                # Prefer the aggregated text; otherwise read the first
                # content item of the first output explicitly.
                raw_text = getattr(response, "output_text", None)
                if raw_text is None:
                    raw_text = response.output[0].content[0].text

                return self._validate_llm_json(self._robust_json_parse(raw_text))

            except Exception as e:
                # Deliberately broad: every failure mode triggers a retry.
                print(f"‚ö†Ô∏è ExpertAgent LLM error (attempt {attempt+1}): {e}")
                last_error = e

        # Final fallback if everything fails
        return self._fallback_result(
            "Anomaly detected (LLM fallback).",
            f"LLM failed or returned invalid JSON. Last error: {last_error}",
        )

    def _robust_json_parse(self, text: str) -> dict:
        """
        Best-effort JSON parsing:
          - first try json.loads on the stripped text
          - if that fails, parse the span from the first '{' to the last '}'
            (covers replies wrapped in prose or markdown fences)

        Raises ValueError (json.JSONDecodeError) when neither works.
        """
        text = text.strip()
        try:
            return json.loads(text)
        except ValueError:
            start = text.find("{")
            end = text.rfind("}")
            if start != -1 and end > start:
                return json.loads(text[start : end + 1])
            # No brace-delimited candidate: surface the original error.
            raise

    @staticmethod
    def _normalise_label(value: str) -> str:
        """Upper-case a label and join its words with underscores."""
        return "_".join(value.upper().replace("-", " ").split())

    def _validate_llm_json(self, obj: Any) -> dict:
        """
        Enforce the response schema.

          - obj must be a dict, otherwise ValueError is raised
          - the result has exactly the five schema keys, all strings
          - missing or null values are replaced by defaults
          - recommended_action / severity are normalised (case, spaces and
            hyphens) and clamped to the allowed sets
        """
        if not isinstance(obj, dict):
            raise ValueError("LLM output is not a JSON object")

        cleaned = {}
        for key, default in self._LLM_DEFAULTS.items():
            value = obj.get(key)
            # A JSON null must not become the literal string "None".
            cleaned[key] = default if value is None else str(value)

        action = self._normalise_label(cleaned["recommended_action"])
        if action not in self.ALLOWED_ACTIONS:
            action = self._LLM_DEFAULTS["recommended_action"]
        cleaned["recommended_action"] = action

        severity = self._normalise_label(cleaned["severity"])
        if severity not in self.ALLOWED_SEVERITIES:
            severity = self._LLM_DEFAULTS["severity"]
        cleaned["severity"] = severity

        return cleaned


# =====================================================
# SIMPLE HUMAN LOOP + ALERT STUB
# =====================================================

import smtplib
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart


def send_alert_to_human(expert_packet: dict):
    """
    Send a plain-text email alert to a human operator over SMTP (STARTTLS).

    The message contains the LLM summary, recommended action and severity,
    followed by JSON dumps of the full LLM result and the decision packet.

    Connection settings are read from the module-level names EMAIL_SENDER,
    EMAIL_RECIPIENT, EMAIL_PASSWORD, SMTP_SERVER and SMTP_PORT.

    Delivery is best-effort: any failure while connecting or sending is
    printed and swallowed so that alerting cannot crash the caller.
    """

    # ----------------------
    # Extract ExpertAgent output
    # ----------------------
    result = expert_packet.get("llm_result", {})
    summary = result.get("summary", "No summary")
    action = result.get("recommended_action", "N/A")
    severity = result.get("severity", "N/A")

    # default=str: the decision packet embeds raw agent outputs, which may
    # hold values json cannot encode natively (e.g. NumPy scalars/arrays);
    # stringifying them keeps the alert from failing before it is sent.
    decision_packet_json = json.dumps(
        expert_packet.get("decision_packet", {}),
        indent=2,
        default=str,
    )
    llm_json = json.dumps(result, indent=2, default=str)

    # ----------------------
    # Email content
    # ----------------------
    subject = f"‚ö†Ô∏è APU Alert ‚Äì Severity: {severity}"

    body = f"""
Human Operator,

An alert has been generated by the APU Monitoring System.

-------------------
SUMMARY
-------------------
{summary}

-------------------
RECOMMENDED ACTION
-------------------
{action}

-------------------
SEVERITY
-------------------
{severity}

-------------------
FULL LLM RESULT
-------------------
{llm_json}

-------------------
RAW DECISION PACKET
-------------------
{decision_packet_json}

Timestamp: {expert_packet.get('timestamp')}
"""

    # ----------------------
    # Create email object
    # ----------------------
    msg = MIMEMultipart()
    msg["From"] = EMAIL_SENDER
    msg["To"] = EMAIL_RECIPIENT
    msg["Subject"] = subject

    msg.attach(MIMEText(body, "plain"))

    # ----------------------
    # Send email via SMTP
    # ----------------------
    try:
        # The context manager closes the connection even when starttls,
        # login or sendmail raises.
        with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as server:
            server.starttls()
            server.login(EMAIL_SENDER, EMAIL_PASSWORD)
            server.sendmail(EMAIL_SENDER, EMAIL_RECIPIENT, msg.as_string())
    except Exception as e:
        print("‚ùå Failed to send alert email:", e)
        return

    print("\n=== EMAIL ALERT SENT SUCCESSFULLY ===")
    print(f"To: {EMAIL_RECIPIENT}")
    print(f"Summary: {summary}")
    print("======================================\n")



def get_human_feedback_stub(expert_packet: dict) -> dict:
    """
    Placeholder for the operator's response.

    A real deployment would obtain this from a UI / operator; the stub
    ignores ``expert_packet`` and returns a synthetic acknowledgement
    stamped with the current local time.
    """
    acknowledged_at = datetime.now().isoformat()
    feedback = dict(
        timestamp=acknowledged_at,
        final_decision="ACK_AND_LOG",
        notes="Stub human feedback ‚Äì replace with real operator input.",
    )
    return feedback
##################################################
# GROUND-TRUTH ANOMALIES VS AGENT DECISIONS: WHICH AGENT IS RIGHT
############################################################
import matplotlib.pyplot as plt
import numpy as np

def plot_agent_vs_groundtruth(
    results,
    detection_labels=None,
    h1_labels=None,
    h3_labels=None,
    h12_labels=None,
    max_samples=200
):
    """
    Plot agent decisions against ground-truth labels in six stacked rows.

    Rows: (0) coordinator anomaly/drift flags, (1) detection labels,
    (2) H1/H3/H12 prediction labels, (3) coordinator alert level,
    (4) predicted window size with window events and ground-truth fault
    markers, (5) ML failure probability and final failure decision.

    results      : sequence of dicts; each needs r["coordinator"] (with
                   "final_anomaly", "final_drift", "scores", "alert_level")
                   and r["window"] (with "predicted_window", optionally
                   "event_type").
    *_labels     : optional 0/1 sequences aligned with results; only the
                   first max_samples entries are used, and shorter
                   sequences are plotted as far as they go.
    max_samples  : maximum number of leading results to display.
    """
    n = min(max_samples, len(results))
    shown = results[:n]
    t = np.arange(n)

    coordinator = [r["coordinator"] for r in shown]
    agent_anomaly = [1 if c["final_anomaly"] else 0 for c in coordinator]
    agent_drift = [1 if c["final_drift"] else 0 for c in coordinator]
    failure_prob = [c["scores"].get("prediction_failure_prob", 0.0) for c in coordinator]
    final_failure = [1 if c.get("final_failure") else 0 for c in coordinator]

    def _head(labels):
        """First n labels as an array, or None when not supplied."""
        return None if labels is None else np.asarray(labels)[:n]

    det = _head(detection_labels)
    h1 = _head(h1_labels)
    h3 = _head(h3_labels)
    h12 = _head(h12_labels)

    window_sizes = [r["window"]["predicted_window"] for r in shown]
    window_events = [r["window"].get("event_type") for r in shown]

    fig, axes = plt.subplots(6, 1, figsize=(15, 17), sharex=True)

    # Row 1: agent outputs
    axes[0].plot(t, agent_anomaly, label="Agent Anomaly", color="red")
    axes[0].plot(t, agent_drift, label="Agent Drift", color="orange")
    axes[0].set_ylabel("Agent")
    axes[0].legend()

    # Row 2: detection labels (x built from the label length so a label
    # array shorter than n does not raise a shape mismatch)
    if det is not None:
        axes[1].plot(np.arange(len(det)), det, label="Detection Labels", color="blue")
        axes[1].set_ylabel("Detect")
        axes[1].legend()

    # Row 3: prediction labels (1h, 3h, 12h)
    horizons = ((h1, "H1", "green"), (h3, "H3", "purple"), (h12, "H12", "brown"))
    for labels, name, color in horizons:
        if labels is not None:
            axes[2].plot(np.arange(len(labels)), labels, label=name, color=color)
    axes[2].set_ylabel("Prediction")
    if any(labels is not None for labels, _, _ in horizons):
        axes[2].legend()

    # Row 4: coordinator alert level. "LOW" needs its own tick: without it
    # LOW alerts fall back to 0 and are indistinguishable from NORMAL.
    alert_map = {"NORMAL": 0, "LOW": 1, "MEDIUM": 2, "HIGH": 3, "CRITICAL": 4}
    alerts = [alert_map.get(c["alert_level"], 0) for c in coordinator]
    axes[3].plot(t, alerts, label="Alert Level", color="black")
    axes[3].set_yticks([0, 1, 2, 3, 4])
    axes[3].set_yticklabels(["NORMAL", "LOW", "MED", "HIGH", "CRIT"])
    axes[3].set_ylabel("Coordinator")
    axes[3].legend()

    # Row 5: AdaptiveWindowAgent predictions
    axes[4].plot(t, window_sizes, label="Predicted Window", color="blue")

    # One scatter call per event type, so the legend entry exists whenever
    # the event occurs at all (not only when it occurs at index 0).
    event_styles = (
        ("DRIFT", "orange", "x", "Window Drift"),
        ("ANOMALY", "red", "o", "Window Anomaly"),
    )
    for event_type, color, marker, label in event_styles:
        idx = [i for i, evt in enumerate(window_events) if evt == event_type]
        if idx:
            axes[4].scatter(
                idx,
                [window_sizes[i] for i in idx],
                color=color,
                marker=marker,
                label=label,
            )

    def _mark_faults(labels, color, alpha, label):
        """Draw a dashed vertical line at each index whose label equals 1."""
        if labels is None:
            return
        first = True
        for i, val in enumerate(labels):
            if val == 1:
                # Only the first line of a series carries the legend label.
                axes[4].axvline(
                    i,
                    color=color,
                    linestyle="--",
                    alpha=alpha,
                    label=label if first else None,
                )
                first = False

    # Overlay ground-truth fault events (vertical lines)
    _mark_faults(det, "red", 0.3, "Fault (Detection)")
    _mark_faults(h1, "green", 0.2, "Fault (H1)")
    _mark_faults(h3, "purple", 0.2, "Fault (H3)")
    _mark_faults(h12, "brown", 0.2, "Fault (H12)")

    axes[4].set_ylabel("Window Size")
    axes[4].legend()

    # Row 6: ML failure prediction probability
    axes[5].plot(t, failure_prob, label="Failure Prob (ML)", color="green")
    axes[5].plot(t, final_failure, label="Final Failure Decision", color="red", linestyle="--")
    axes[5].set_ylabel("Fail Prob")
    axes[5].set_ylim([-0.1, 1.1])
    axes[5].legend()

    axes[-1].set_xlabel("Sample index")
    plt.tight_layout()
    plt.show()

#=========================================================================================================================================
#================= TESTING MAIN CODE========================================================================================================================
#=========================================================================================================================================
#=========================================================================================================================================
#############################################
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from xgboost import XGBClassifier
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
import joblib


    # -------------------------------------------------
    # 1. LOAD DATA + MASK + LABELS
    # -------------------------------------------------
data_path = "/content/drive/MyDrive/PHD/2025/TEMP_OUTPUT_METROPM/multivariate_long_sequences-TRAIN-10Sec-DIRECT-VAR.npy"
label_path = "/content/drive/MyDrive/PHD/2025/TEMP_OUTPUT_METROPM/window_labels_3class.npy"
train_mask_path = "/content/drive/MyDrive/PHD/2025/TEMP_OUTPUT_METROPM/train_mask.npy"
test_mask_path = "/content/drive/MyDrive/PHD/2025/TEMP_OUTPUT_METROPM/test_mask.npy"
holdout_mask_path = "/content/drive/MyDrive/PHD/2025/TEMP_OUTPUT_METROPM/holdout_mask.npy"

# Windows, shape (N, W, S), and one class label per window, shape (N,),
# with values in {0, 1, 2}.
X = np.load(data_path)
y = np.load(label_path)

# Boolean selectors for the three splits.
train_mask = np.load(train_mask_path).astype(bool)
test_mask = np.load(test_mask_path).astype(bool)
holdout_mask = np.load(holdout_mask_path).astype(bool)

# Sanity check (VERY important): the three splits must be pairwise disjoint.
# Any window shared between two splits would leak data and bias the
# performance estimates.
assert not (train_mask & test_mask).any()
assert not (train_mask & holdout_mask).any()
assert not (test_mask & holdout_mask).any()

X_train = X[train_mask]
y_train = y[train_mask]
X_test = X[test_mask]
y_test = y[test_mask]
X_hold = X[holdout_mask]
y_hold = y[holdout_mask]

print("Train:", X_train.shape)
print("Test :", X_test.shape)
print("Hold :", X_hold.shape)


import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np



#------------------------------------------------------------------------------------------
# ------------- Transformer for 3-class classification - NO CONTRASTIVE LEARNING
#----------------------------------------------------------------------------------------

def transformer_encoder(x, head_size, num_heads, ff_dim, dropout=0.2):
    """
    Apply one post-norm Transformer encoder block to ``x``.

    The block is self-attention -> dropout -> residual add -> layer norm,
    followed by a two-layer feed-forward network -> residual add -> layer
    norm.  The second feed-forward layer projects back to the width of its
    input so that the residual addition is well defined.

    x         : Keras tensor; its last axis is the model width
    head_size : ``key_dim`` of the attention layer
    num_heads : number of attention heads
    ff_dim    : hidden width of the feed-forward network
    dropout   : rate used inside attention and after each sub-layer's
                first stage
    """
    # --- Self-attention sub-layer (query and value are both x) ---
    self_attention = layers.MultiHeadAttention(
        key_dim=head_size,
        num_heads=num_heads,
        dropout=dropout,
    )
    attended = self_attention(x, x)
    attended = layers.Dropout(dropout)(attended)
    residual = layers.Add()([x, attended])
    x = layers.LayerNormalization(epsilon=1e-6)(residual)

    # --- Position-wise feed-forward sub-layer ---
    hidden = layers.Dense(ff_dim, activation="relu")(x)
    hidden = layers.Dropout(dropout)(hidden)
    projected = layers.Dense(x.shape[-1])(hidden)
    residual = layers.Add()([x, projected])

    return layers.LayerNormalization(epsilon=1e-6)(residual)

def build_transformer(
    window_len,
    n_features,
    n_classes=3,
    *,
    d_model=64,
    head_size=64,
    num_heads=4,
    ff_dim=128,
    num_blocks=2,
    dropout=0.2,
    classifier_units=64,
    classifier_dropout=0.3,
):
    """
    Build an (uncompiled) Transformer classifier over fixed-length windows.

    Architecture: Dense projection to ``d_model`` -> ``num_blocks`` encoder
    blocks -> global average pooling over time -> Dense(relu) -> Dropout ->
    softmax over ``n_classes``.

    window_len         : timesteps per input window
    n_features         : features (sensors) per timestep
    n_classes          : number of output classes
    d_model            : width the features are projected to
    head_size          : ``key_dim`` of each attention layer
    num_heads          : attention heads per encoder block
    ff_dim             : feed-forward hidden width per encoder block
    num_blocks         : number of stacked encoder blocks
    dropout            : dropout rate inside the encoder blocks
    classifier_units   : width of the hidden classifier layer
    classifier_dropout : dropout rate before the softmax layer

    The keyword-only defaults reproduce the original fixed architecture.
    Returns a ``keras.Model`` mapping (window_len, n_features) inputs to
    class probabilities.
    """
    inputs = keras.Input(shape=(window_len, n_features))

    # Project features to model dimension
    x = layers.Dense(d_model)(inputs)

    # Transformer blocks
    for _ in range(num_blocks):
        x = transformer_encoder(
            x,
            head_size=head_size,
            num_heads=num_heads,
            ff_dim=ff_dim,
            dropout=dropout,
        )

    # Pooling + classifier
    x = layers.GlobalAveragePooling1D()(x)
    x = layers.Dense(classifier_units, activation="relu")(x)
    x = layers.Dropout(classifier_dropout)(x)

    outputs = layers.Dense(n_classes, activation="softmax")(x)

    return keras.Model(inputs, outputs)



# Input geometry is read off the training array: axis 1 is the window
# length, axis 2 the number of features per timestep.
window_len  = X_train.shape[1]
n_features = X_train.shape[2]

from sklearn.utils.class_weight import compute_class_weight
n_classes = 3
model = build_transformer(window_len, n_features, n_classes)

# Labels are integer class ids, hence the "Sparse" loss and metrics.
# top2_acc counts a prediction as correct when the true class is among the
# two highest-probability classes.
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=[
        keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
        keras.metrics.SparseTopKCategoricalAccuracy(k=2, name="top2_acc"),
    ]
)


# Class weights (IMBALANCE HANDLING): one weight per class present in
# y_train, computed with scikit-learn's "balanced" mode.
class_weights = compute_class_weight(
    class_weight="balanced",
    classes=np.unique(y_train),
    y=y_train
)

# Keras takes the weights as a {class_id: weight} mapping; np.unique returns
# the classes in the same (sorted) order used for `classes` above.
class_weight = {
    i: w for i, w in zip(np.unique(y_train), class_weights)
}

print("Class weights:", class_weight)
model.summary()

# Stop once val_loss has not improved for 2 consecutive epochs and roll the
# model back to the best epoch's weights.
early_stop = keras.callbacks.EarlyStopping(
    monitor="val_loss",
    patience=2,
    restore_best_weights=True
)


#####------------------------Train and Save----------------------------
# NOTE: the TEST split doubles as the validation set here, so early stopping
# selects weights using it; the HOLDOUT split is not touched during training.
classifier = model
history = classifier.fit(
    X_train, y_train,
    validation_data=(X_test, y_test),   # TEST = offline validation
    epochs=50,
    batch_size=256,
    class_weight=class_weight,
    callbacks=[early_stop],
    verbose=2
)

save_dir = "/content/drive/MyDrive/PHD/2025/models/transformer_3class"
os.makedirs(save_dir, exist_ok=True)

# Saved without optimizer state: the file is meant for inference-only reload.
classifier.save(
    os.path.join(save_dir, "transformer_classifier.keras"),
    include_optimizer=False   # important for clean reload
)

print("‚úÖ Transformer classifier saved")

####-----------------------------Load and test-----------------------------

import tensorflow as tf
import tensorflow.keras.backend as K
from sklearn.metrics import classification_report, confusion_matrix


def focal_loss(gamma=2.0, alpha=0.25):
    """Build a binary focal-loss callable.

    Args:
        gamma: Exponent applied to ``1 - p_t``; larger values shrink the
            contribution of elements the model already predicts well.
        alpha: Weight applied where the target is 1; elements whose target
            is 0 are weighted by ``1 - alpha``.

    Returns:
        A function ``loss(y_true, y_pred)`` that returns the mean over all
        elements of ``alpha_weight * (1 - p_t) ** gamma * binary_crossentropy``.
    """
    def loss(y_true, y_pred):
        targets = tf.cast(y_true, tf.float32)

        # Keep predictions strictly inside (0, 1) before they reach the log.
        eps = K.epsilon()
        preds = K.clip(y_pred, eps, 1 - eps)

        # p_t: probability the model assigns to the true label of each element.
        prob_true = targets * preds + (1 - targets) * (1 - preds)
        weight = targets * alpha + (1 - targets) * (1 - alpha)
        focus = K.pow(1.0 - prob_true, gamma)

        cross_entropy = K.binary_crossentropy(targets, preds)
        return K.mean(weight * focus * cross_entropy)

    return loss



from tensorflow.keras.models import load_model

# ---- Reload the saved classifier -------------------------------------------
model_path = "/content/drive/MyDrive/PHD/2025/models/transformer_3class/transformer_classifier.keras"

classifier = load_model(
    model_path,
    custom_objects={"focal_loss": focal_loss},  # only if you actually used it
)
classifier.trainable = False  # safety

print("‚úÖ Transformer loaded for TEST")

# ---- TEST split ------------------------------------------------------------
# The predicted label of each window is the arg-max over the class axis.
probs_test = classifier.predict(X_test)
y_pred_test = probs_test.argmax(axis=1)

print("\n=== TEST SET PERFORMANCE ===")
print(classification_report(y_true=y_test, y_pred=y_pred_test, digits=4))
print("Confusion matrix:\n", confusion_matrix(y_true=y_test, y_pred=y_pred_test))

# ---- HOLDOUT split (final figures) -----------------------------------------
probs_hold = classifier.predict(X_hold)
y_pred_hold = probs_hold.argmax(axis=1)

print("\n=== HOLDOUT SET PERFORMANCE (FINAL) ===")
print(classification_report(y_true=y_hold, y_pred=y_pred_hold, digits=4))
print("Confusion matrix:\n", confusion_matrix(y_true=y_hold, y_pred=y_pred_hold))



def transformer_signals(classifier, X_window):
    """Turn the classifier's 3-class output for ONE window into named scores.

    Only the first sample of ``X_window`` is ever reported, so array-like
    input is cut down to that first sample before inference instead of
    running the whole batch through the model and discarding the rest.

    Args:
        classifier: Object exposing ``predict(x, verbose=0)`` that returns
            one row of three class probabilities per input sample.
        X_window: Batch of windows with the sample axis first. Objects
            without an ``ndim`` attribute (e.g. a list of inputs) are passed
            to ``predict`` unchanged.

    Returns:
        dict with ``p_normal`` (class 0), ``early_fault_score`` (class 1),
        ``severe_fault_score`` (class 2) and ``anomaly_score``
        (``1 - p_normal``), all as Python floats.
    """
    # Slice only objects that advertise array semantics; a plain list may be
    # a multi-input bundle for predict() and must not be truncated.
    if getattr(X_window, "ndim", 0) >= 1:
        X_window = X_window[:1]

    probs = classifier.predict(X_window, verbose=0)[0]  # shape (3,)
    p_normal, p_early, p_severe = probs

    return {
        "p_normal": float(p_normal),
        "early_fault_score": float(p_early),
        "severe_fault_score": float(p_severe),
        "anomaly_score": float(1.0 - p_normal),
    }

# Sanity check of the signal dictionary. X_hold is passed whole, but
# transformer_signals reads only row 0 of the prediction, so `signals`
# describes the first holdout window only.
signals = transformer_signals(classifier, X_hold) #Get actual prob distribution now
print(signals)


# =====================================================
#  MAIN ORCHESTRATION: DECISION + EXPERT + EVENT STORE
# =====================================================
#
# Runs every HOLDOUT window through: sensor-level master agent -> adaptive
# window agent -> transformer probabilities -> DecisionAgent fusion, and asks
# the LLM-backed ExpertAgent for an analysis only on HIGH / CRITICAL alerts.
# Every processed window is written to the event store.
#
# Objects this block expects to exist already:
# - window_agent  (AdaptiveWindowAgent)
# - master        (RobustMasterAgent)
# - classifier    (the 3-class transformer)
# - X_hold        (holdout windows, sample axis first)
# - EventStore, DecisionAgent, ExpertAgent, get_human_feedback_stub

print("\nüöÄ Running DecisionAgent + ExpertAgent pipeline on HOLDOUT...\n")

# 1) Initialise storage + agents
import os
from google.colab import userdata
os.environ["OPENAI_API_KEY"] = userdata.get('OPENAI_API_KEY')
from openai import OpenAI
openai_client = OpenAI()

event_store = EventStore(db_path="event_store.db")

# -----------------------------------------------------
# 1) Decision + Expert agents
# -----------------------------------------------------
decision_agent = DecisionAgent(
    prediction_weight=0.7,   # transformer dominates
    sensor_weight=0.2,
    drift_weight=0.1,
    failure_threshold=0.5,
    failure_critical_threshold=0.8,
)

expert_agent = ExpertAgent(
    event_store=event_store,
    system_description="MetroPT multivariate APU system",
    llm_client=openai_client,
    history_limit=100,
    max_history_for_prompt=10,
    window_preview_len=30,
)

# -----------------------------------------------------
# 2) Run HOLDOUT
# -----------------------------------------------------
print("\nüöÄ Running Decision + Expert pipeline on HOLDOUT...\n")

decision_results = []
decision_flags = []

# Transformer inference is done ONCE for the whole holdout set. Calling
# classifier.predict() on a single window inside the loop pays Keras'
# per-call overhead for every sample, which dominates the run time when the
# holdout set is large.
hold_probs = classifier.predict(X_hold, verbose=0)
n_hold = len(X_hold)

for i, seq in enumerate(X_hold):
    try:
        # ---------------------------------------------
        # a) Sensor-level detection
        # ---------------------------------------------
        master_out = master.process_system_input(seq)

        # ---------------------------------------------
        # b) Window agent
        # ---------------------------------------------
        features = seq.flatten()
        window_out = window_agent.predict_window_size(features, seq)

        # ---------------------------------------------
        # c) Transformer prediction (ML truth)
        # ---------------------------------------------
        probs = hold_probs[i]  # class probabilities for window i

        model_outputs = {
            "failure_prob": float(probs[2]),   # class 2 = FAULT
            "transformer_prob": float(probs[1]),  # class 1 = WARNING
            "prototype_score": None,
        }

        # ---------------------------------------------
        # d) Decision fusion
        # ---------------------------------------------
        decision_packet = decision_agent.decide(
            master_output=master_out,
            window_output=window_out,
            model_outputs=model_outputs,
            metadata={"index": int(i)},
        )

        expert_packet = None
        human_feedback = None

        # ---------------------------------------------
        # e) LLM only for HIGH / CRITICAL
        # ---------------------------------------------
        if decision_packet["alert_level"] in ["HIGH", "CRITICAL"]:
            expert_packet = expert_agent.analyse_anomaly(
                decision_packet=decision_packet,
                system_subsequence=seq,
                extra_context={"index": int(i)},
            )

            human_feedback = get_human_feedback_stub(expert_packet)

            event_store.save_event(
                event_type="ALERT",
                packet=decision_packet,
                expert=expert_packet,
                human=human_feedback,
            )
        else:
            event_store.save_event(
                event_type="NORMAL",
                packet=decision_packet,
                expert=None,
                human=None,
            )

        decision_results.append({
            "master": master_out,
            "window": window_out,
            "decision": decision_packet,
            "expert": expert_packet,
            "human_feedback": human_feedback,
        })

        decision_flags.append(1 if decision_packet["final_failure"] else 0)

        if (i + 1) % 500 == 0:
            print(f"Processed {i+1}/{n_hold} samples...")

    except Exception as e:
        # Best-effort per sample: report the error and record the window as
        # a non-failure so results stay index-aligned with X_hold.
        print(f"‚ö†Ô∏è Sample {i} failed: {e}")
        decision_results.append({
            "master": None,
            "window": None,
            "decision": {"final_failure": False},
            "expert": None,
            "human_feedback": None,
        })
        decision_flags.append(0)

decision_flags = np.array(decision_flags)

print("\n‚úÖ Pipeline completed.")
print("Failure flag distribution:", np.unique(decision_flags, return_counts=True))
# NOTE(review): this prints the same decision_flags array as the line above
# under a different label; confirm which flags "Anomaly" was meant to show.
print("Anomaly flags distribution:", np.unique(decision_flags, return_counts=True))



=== HOLDOUT SET PERFORMANCE (FINAL) ===
              precision    recall  f1-score   support

           0     0.9996    0.9896    0.9946    344234
           1     0.1739    0.9408    0.2936       726
           2     0.7771    0.8832    0.8268      1721

    accuracy                         0.9890    346681
   macro avg     0.6502    0.9379    0.7050    346681
weighted avg     0.9968    0.9890    0.9923    346681

Confusion matrix:
 [[340667   3131    436]
 [    43    683      0]
 [    88    113   1520]]
{'p_normal': 0.5116674900054932, 'early_fault_score': 0.43738940358161926, 'severe_fault_score': 0.05094306915998459, 'anomaly_score': 0.48833250999450684}


In [None]:
# Colab-only cell: mount Google Drive at /content/drive. The model paths used
# in this notebook point under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')