<a href="https://colab.research.google.com/github/supriyag123/PHD_Pub/blob/main/AGENTIC-COMBINED-AGENTS.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [13]:
# ====== Imports ======
import numpy as np
import pandas as pd
import pickle, json, os, logging, warnings
from collections import deque
from typing import Dict, Any
from datetime import datetime
from sklearn.preprocessing import StandardScaler
from statsmodels.tsa.vector_ar.var_model import VAR
import keras, tensorflow as tf
import matplotlib.pyplot as plt
warnings.filterwarnings("ignore")

# ====== RobustSensorAgent & RobustMasterAgent ======
import pickle
import os
from collections import deque
from typing import Dict, List, Tuple
import warnings
warnings.filterwarnings('ignore')

# ML libraries
from sklearn.metrics import mean_squared_error
from scipy import stats
from scipy.spatial.distance import jensenshannon

# Deep learning
try:
    from tensorflow.keras.models import load_model
    KERAS_AVAILABLE = True
except ImportError:
    KERAS_AVAILABLE = False


# =====================================================
# ROBUST SENSOR AGENT - Observes ONE sensor with AE model
# =====================================================

class RobustSensorAgent:
    """
    Robust Sensor Agent for ONE sensor with anomaly & drift detection.

    Loads a pretrained Keras autoencoder (``.h5``) plus an optional
    ``<model>_metadata.pkl`` sidecar providing ``rolling_stats`` and an
    ``error_history`` used as the drift baseline. Each observed window is
    scored by AE reconstruction MSE, compared against an adaptive
    ``mean + threshold_k * std`` threshold, checked for distribution drift
    against the baseline, and fed into a simple retrain vote.
    """

    def __init__(self,
                 sensor_id: int,
                 model_path: str = None,
                 window_length: int = 50,
                 memory_size: int = 1000,
                 threshold_k: float = 2.0,
                 drift_threshold: float = 0.1):
        """
        Args:
            sensor_id: Identifier echoed in every result dict.
            model_path: Optional path to a ``.h5`` AE model, loaded immediately.
            window_length: Required length of each observed subsequence.
            memory_size: Capacity of the error/data history buffers.
            threshold_k: Std multiplier in the adaptive anomaly threshold.
            drift_threshold: Jensen-Shannon distance above which drift is flagged.
        """
        self.sensor_id = sensor_id
        self.window_length = window_length
        self.threshold_k = threshold_k
        self.drift_threshold = drift_threshold

        # Model & metadata
        self.model = None
        self.scaler = None  # not used by this class; kept for external callers
        self.is_model_loaded = False

        # Buffers
        self.error_memory = deque(maxlen=memory_size)
        self.data_memory = deque(maxlen=memory_size)
        self.recent_errors = deque(maxlen=100)

        # Rolling stats (overwritten by metadata and by _update_rolling_stats)
        self.rolling_stats = {'mean': 0.0, 'std': 1.0, 'q95': 0.0, 'q99': 0.0}
        self.baseline_errors = None

        # Counters
        self.total_processed = 0
        self.anomalies_detected = 0
        self.drift_detected_count = 0
        self.last_stats_update = datetime.now()

        if model_path:
            self.load_model(model_path)

    def load_model(self, model_path: str) -> bool:
        """Load a pretrained AE model and its optional metadata sidecar.

        Args:
            model_path: Path to a Keras ``.h5`` autoencoder. A sibling file
                ``<name>_metadata.pkl`` is read when present.

        Returns:
            True on success, False otherwise (the error is printed, not raised).
        """
        try:
            if KERAS_AVAILABLE and model_path.endswith('.h5'):
                self.model = load_model(model_path, compile=False)

                # Metadata sidecar file: swap only the trailing extension so a
                # '.h5' elsewhere in the path is left untouched.
                metadata_path = model_path[:-len('.h5')] + '_metadata.pkl'
                if os.path.exists(metadata_path):
                    # NOTE: pickle can execute arbitrary code; only load trusted files.
                    with open(metadata_path, 'rb') as f:
                        metadata = pickle.load(f)
                    # Merge so keys missing from the sidecar keep their defaults.
                    self.rolling_stats.update(metadata.get('rolling_stats') or {})
                    if 'error_history' in metadata:
                        self.baseline_errors = np.asarray(metadata['error_history'], dtype=float).ravel()
            else:
                raise ValueError("Unsupported model format – expecting .h5 AE model")

            self.is_model_loaded = True
            print(f"✅ AE model loaded for sensor {self.sensor_id}")
            return True

        except Exception as e:
            print(f"❌ Failed to load AE model for sensor {self.sensor_id}: {e}")
            return False

    def observe(self, sensor_subsequence: np.ndarray) -> Dict:
        """Score one subsequence and return anomaly/drift/retrain flags.

        Args:
            sensor_subsequence: 1-D window of length ``window_length``.

        Returns:
            A result dict. On failure (no model, wrong length) it contains only
            ``sensor_id``, ``error`` and ``timestamp``.
        """
        if not self.is_model_loaded:
            return {"sensor_id": self.sensor_id, "error": "no_model_loaded", "timestamp": datetime.now()}

        if len(sensor_subsequence) != self.window_length:
            return {"sensor_id": self.sensor_id,
                    "error": f"invalid_length_expected_{self.window_length}_got_{len(sensor_subsequence)}",
                    "timestamp": datetime.now()}

        # 1. Anomaly score
        anomaly_score = self._compute_robust_anomaly_score(sensor_subsequence)

        # 2. Update memory
        self.data_memory.append(sensor_subsequence.copy())
        self.error_memory.append(anomaly_score)
        self.recent_errors.append(anomaly_score)

        # 3. Refresh rolling stats every 10th observation once 50 errors exist.
        # Count observations via total_processed: len(error_memory) stops growing
        # at memory_size, which would otherwise trigger a refresh on every call.
        n_seen = self.total_processed + 1
        if len(self.error_memory) >= 50 and n_seen % 10 == 0:
            self._update_rolling_stats(list(self.error_memory)[-50:])

        # 4. Flags
        is_anomaly = self._check_adaptive_anomaly(anomaly_score)
        drift_flag = self._check_advanced_drift()
        needs_retrain = self._check_retrain_need()
        confidence = self._compute_robust_confidence(anomaly_score)

        # 5. Update counters
        self.total_processed += 1
        if is_anomaly:
            self.anomalies_detected += 1
        if drift_flag:
            self.drift_detected_count += 1

        return {
            "sensor_id": self.sensor_id,
            "timestamp": datetime.now(),
            "is_anomaly": bool(is_anomaly),
            "drift_flag": bool(drift_flag),
            "needs_retrain_flag": bool(needs_retrain),
            "anomaly_score": float(anomaly_score),
            "confidence": float(confidence),
            "threshold_used": float(self.rolling_stats['mean'] + self.threshold_k * self.rolling_stats['std']),
            "anomaly_rate": self.anomalies_detected / max(1, self.total_processed),
            "drift_rate": self.drift_detected_count / max(1, self.total_processed)
        }

    def _compute_robust_anomaly_score(self, subsequence: np.ndarray) -> float:
        """Return the AE reconstruction MSE; fall back to the window variance on failure."""
        try:
            X = subsequence.reshape(1, self.window_length, 1)  # [batch, timesteps, features]
            reconstruction = self.model.predict(X, verbose=0)
            error = mean_squared_error(subsequence.flatten(), reconstruction.flatten())
            return max(0.0, error)
        except Exception as e:
            print(f"⚠️ AE inference failed for sensor {self.sensor_id}: {e}")
            return np.var(subsequence)

    def _update_rolling_stats(self, errors: List[float]):
        """Recompute mean/std/q95/q99 of ``errors`` and stamp the update time."""
        errors_array = np.array(errors)
        self.rolling_stats['mean'] = np.mean(errors_array)
        self.rolling_stats['std'] = np.std(errors_array) + 1e-8  # keep std strictly positive
        self.rolling_stats['q95'] = np.percentile(errors_array, 95)
        self.rolling_stats['q99'] = np.percentile(errors_array, 99)
        self.last_stats_update = datetime.now()

    def _check_adaptive_anomaly(self, score: float) -> bool:
        """True when ``score`` exceeds ``mean + threshold_k * std``."""
        threshold = self.rolling_stats['mean'] + self.threshold_k * self.rolling_stats['std']
        return score > threshold

    def _check_advanced_drift(self) -> bool:
        """Compare recent vs. baseline error distributions.

        Uses the Jensen-Shannon distance between 20-bin histograms built on the
        baseline's bin edges. If that cannot be computed, falls back to a
        two-sample KS test (p < 0.05). Needs a non-empty baseline and at least
        30 recent errors.
        """
        if self.baseline_errors is None or len(self.baseline_errors) == 0 \
                or len(self.recent_errors) < 30:
            return False
        baseline = np.asarray(self.baseline_errors, dtype=float).ravel()
        recent = np.asarray(self.recent_errors, dtype=float)
        try:
            counts_base, edges = np.histogram(baseline, bins=20)
            # Clip into the baseline range so out-of-range errors land in the
            # extreme bins. Dropping them left an empty histogram -> NaN ->
            # "no drift" for exactly the most severe shifts.
            counts_recent, _ = np.histogram(np.clip(recent, edges[0], edges[-1]), bins=edges)
            p = counts_base.astype(float) + 1e-10
            q = counts_recent.astype(float) + 1e-10
            p /= p.sum()
            q /= q.sum()
            js_distance = jensenshannon(p, q)
            if not np.isfinite(js_distance):
                raise ValueError("non-finite Jensen-Shannon distance")
            return bool(js_distance > self.drift_threshold)
        except Exception:
            try:
                _, p_value = stats.ks_2samp(baseline, recent)
                return bool(p_value < 0.05)
            except Exception:
                return False

    def _check_retrain_need(self) -> bool:
        """Vote on retraining: True when at least 2 of 4 criteria hold.

        Criteria: recent anomaly rate > 30%, drift flagged on > 10% of
        observations, recent mean error > 2x the rolling mean, and rolling
        stats older than 7 days.
        """
        if len(self.error_memory) < 100:
            return False
        recent_errors = list(self.error_memory)[-50:]
        threshold = self.rolling_stats['mean'] + self.threshold_k * self.rolling_stats['std']
        anomaly_rate = sum(1 for e in recent_errors if e > threshold) / len(recent_errors)
        criteria = [
            anomaly_rate > 0.3,
            self.drift_detected_count > 0.1 * self.total_processed,
            np.mean(recent_errors) > 2.0 * self.rolling_stats['mean'] if len(recent_errors) > 0 else False,
            (datetime.now() - self.last_stats_update).days > 7
        ]
        return sum(criteria) >= 2

    def _compute_robust_confidence(self, score: float) -> float:
        """Map the distance from the threshold (in std units) to [0, 1], saturating at 3 std."""
        if self.rolling_stats['std'] == 0:
            return 0.5
        threshold = self.rolling_stats['mean'] + self.threshold_k * self.rolling_stats['std']
        distance_from_threshold = abs(score - threshold) / self.rolling_stats['std']
        return min(1.0, distance_from_threshold / 3.0)


# =====================================================
# ROBUST MASTER AGENT
# =====================================================

class RobustMasterAgent:
    """Aggregate per-sensor results into system-level anomaly/drift/retrain decisions.

    Each system-level flag fires when the fraction of sensors raising the
    corresponding per-sensor flag reaches its threshold. Sensors whose result
    carries an ``error`` count as not flagging.
    """

    def __init__(self, sensor_agents: "List[RobustSensorAgent]",
                 system_anomaly_threshold: float = 0.3,
                 drift_threshold: float = 0.2,
                 retrain_threshold: float = 0.15):
        """
        Args:
            sensor_agents: One agent per column of the system input, in column order.
            system_anomaly_threshold: Min fraction of anomalous sensors for a system anomaly.
            drift_threshold: Min fraction of drifting sensors for system drift.
            retrain_threshold: Min fraction of sensors requesting retrain.
        """
        self.sensor_agents = sensor_agents
        self.num_sensors = len(sensor_agents)
        self.system_anomaly_threshold = system_anomaly_threshold
        self.drift_threshold = drift_threshold
        self.retrain_threshold = retrain_threshold

    def process_system_input(self, system_subsequence: np.ndarray) -> Dict:
        """Process a [window_length, num_sensors] multivariate subsequence.

        Returns:
            ``{"timestamp", "sensor_results", "system_decisions"}``, or
            ``{"error", "timestamp"}`` when the input shape is rejected.
        """
        timestamp = datetime.now()
        system_subsequence = np.asarray(system_subsequence)
        # Reject non-2-D input explicitly; `.shape[1]` would raise IndexError.
        if system_subsequence.ndim != 2:
            return {"error": f"Expected a 2-D [window_length, num_sensors] array, "
                             f"got shape {system_subsequence.shape}",
                    "timestamp": timestamp}
        if system_subsequence.shape[1] != self.num_sensors:
            return {"error": f"Expected {self.num_sensors} sensors, got {system_subsequence.shape[1]}",
                    "timestamp": timestamp}

        # 1. Collect sensor observations (column i -> agent i)
        sensor_results = [agent.observe(system_subsequence[:, i])
                          for i, agent in enumerate(self.sensor_agents)]

        # 2. Simple aggregation
        anomalies = sum(1 for r in sensor_results if r.get("is_anomaly"))
        drifts = sum(1 for r in sensor_results if r.get("drift_flag"))
        retrains = sum(1 for r in sensor_results if r.get("needs_retrain_flag"))

        denom = max(1, self.num_sensors)
        anomaly_rate = anomalies / denom
        drift_rate = drifts / denom
        retrain_rate = retrains / denom

        system_decisions = {
            "system_anomaly": anomaly_rate >= self.system_anomaly_threshold,
            "system_drift": drift_rate >= self.drift_threshold,
            "system_needs_retrain": retrain_rate >= self.retrain_threshold,
            "anomaly_rate": anomaly_rate,
            "drift_rate": drift_rate,
            "retrain_rate": retrain_rate
        }

        return {
            "timestamp": timestamp,
            "sensor_results": sensor_results,
            "system_decisions": system_decisions
        }


# =====================================================
# SENSOR SYSTEM CREATION
# =====================================================

def create_robust_system(num_sensors: int, models_dir: str, *,
                         window_length: int = 50,
                         memory_size: int = 1000,
                         threshold_k: float = 2.0,
                         sensor_drift_threshold: float = 0.1,
                         system_anomaly_threshold: float = 0.3,
                         system_drift_threshold: float = 0.2,
                         retrain_threshold: float = 0.15,
                         ) -> Tuple[List[RobustSensorAgent], RobustMasterAgent]:
    """Create one RobustSensorAgent per sensor plus a RobustMasterAgent.

    Models are looked up as ``<models_dir>/sensor_<id>_model.h5``. A sensor
    whose file is missing gets an agent without a model, and its ``observe``
    returns ``{"error": "no_model_loaded", ...}``.

    The keyword-only arguments expose the previously hard-coded agent and
    master configuration; their defaults reproduce the original values.

    Returns:
        ``(sensor_agents, master)``.
    """
    print(f"🚀 Creating robust system with {num_sensors} sensors")
    sensor_agents = []
    missing = []
    for sensor_id in range(num_sensors):
        model_path = os.path.join(models_dir, f"sensor_{sensor_id}_model.h5")
        has_model = os.path.exists(model_path)
        if not has_model:
            missing.append(sensor_id)
        agent = RobustSensorAgent(sensor_id=sensor_id,
                                  model_path=model_path if has_model else None,
                                  window_length=window_length,
                                  memory_size=memory_size,
                                  threshold_k=threshold_k,
                                  drift_threshold=sensor_drift_threshold)
        sensor_agents.append(agent)

    if missing:
        # These agents will only ever report "no_model_loaded".
        print(f"⚠️ No model file found for sensors: {missing}")

    master = RobustMasterAgent(sensor_agents=sensor_agents,
                               system_anomaly_threshold=system_anomaly_threshold,
                               drift_threshold=system_drift_threshold,
                               retrain_threshold=retrain_threshold)
    print(f"✅ Created system: {len([a for a in sensor_agents if a.is_model_loaded])}/{num_sensors} models loaded")
    return sensor_agents, master

# ====== AdaptiveWindowAgent ======

# agents/adaptive_window_agent.py
import numpy as np
import pandas as pd
import pickle
import json
import os
from collections import deque
from typing import Dict, Any
import datetime as dt
import logging
from sklearn.preprocessing import StandardScaler
from statsmodels.tsa.vector_ar.var_model import VAR
import keras
import tensorflow as tf
import warnings
warnings.filterwarnings('ignore')

# Module-level logger; AdaptiveWindowAgent._check_for_event reports ANOMALY/DRIFT
# warnings and detection errors through it.
logger = logging.getLogger(__name__)

class AdaptiveWindowAgent:
    """
    Agent A: Adaptive Window Management with Enhanced MLP

    Capabilities:
    1. Predict window size using trained MLP
    2. Calculate actual performance using VAR forecast
    3. Track accuracy and performance statistics
    4. Monitor for anomaly vs drift in prediction performance
    5. Buffer (feature vector, predicted window) pairs for retraining

    NOTE(review): this class fills ``retraining_data`` but defines no retrain
    method. The y-buffer also holds the MLP's own predictions rather than
    ground-truth windows. Confirm the intended retraining target.
    """

    def __init__(self, agent_id: str = "adaptive_window_agent",
                 model_path: str = None,
                 checkpoint_path: str = None):
        """
        Args:
            agent_id: Name used in status output.
            model_path: Path to the ``.keras`` MLP (defaults to the Drive path).
            checkpoint_path: Stored on the instance; not read by this class.
        """
        self.agent_id = agent_id
        self.model_path = model_path or "/content/drive/MyDrive/PHD/2025/DGRNet-MLP-Versions/METROPM_MLP_model_Daily.keras"
        self.checkpoint_path = checkpoint_path or "/content/drive/MyDrive/PHD/2025/TEMP_OUTPUT_METROPM/ckp2.weights.h5"

        # Core MLP
        self.model = None
        self.transformer = StandardScaler()
        self.transformer_fitted = False
        self.is_model_loaded = False

        # Performance tracking
        self.prediction_history = deque(maxlen=1000)
        self.mse_history = deque(maxlen=200)
        self.mae_history = deque(maxlen=200)

        # Event detection parameters
        self.drift_detection_window = 20
        self.drift_threshold_mse = 0.2
        self.drift_threshold_mae = 0.2
        self.consecutive_poor_predictions = 0
        self.cooldown_counter = 0

        # Stats
        self.performance_stats = {
            'total_predictions': 0,
            'avg_mse': 0.0,
            'avg_mae': 0.0,
            'last_retrain_time': None,
            'drift_events': 0,
            'anomaly_events': 0,
            'retraining_events': 0
        }

        # Buffers for retraining
        self.retraining_data = {
            'x_buffer': deque(maxlen=10000),
            'y_buffer': deque(maxlen=10000)
        }

        self.load_model()
        print(f"AdaptiveWindowAgent {self.agent_id} initialized")
        print(f"Model loaded: {self.is_model_loaded}")
        print(f"Transformer fitted: {self.transformer_fitted}")

    def load_model(self):
        """Load the trained MLP and its target-scaling transformer.

        The transformer is read from ``<model>_transformer.pkl``. If that file
        is absent, a StandardScaler is fitted on the saved true-window array
        and then written to that path. Errors are printed, not raised.
        """
        try:
            if os.path.exists(self.model_path):
                self.model = keras.models.load_model(self.model_path)
                self.is_model_loaded = True
                print(f"Loaded MLP model from {self.model_path}")

                transformer_path = self.model_path.replace('.keras', '_transformer.pkl')
                if os.path.exists(transformer_path):
                    # NOTE: pickle can execute arbitrary code; only load trusted files.
                    with open(transformer_path, 'rb') as f:
                        self.transformer = pickle.load(f)
                    self.transformer_fitted = True
                    print("Loaded saved transformer")
                else:
                    y_original = np.load('/content/drive/MyDrive/PHD/2025/TEMP_OUTPUT_METROPM/generated-data-true-window2.npy')
                    self.transformer = StandardScaler()
                    self.transformer.fit(y_original.reshape(-1, 1))
                    self.transformer_fitted = True
                    with open(transformer_path, 'wb') as f:
                        pickle.dump(self.transformer, f)
                    print(f"Fitted transformer on {len(y_original)} samples and saved")
            else:
                print(f"Model file not found at {self.model_path}")
        except Exception as e:
            print(f"Error loading model: {e}")

    def evaluate_forecast_performance(self, sequence_3d: np.ndarray, predicted_window: int, n_future: int = 1) -> Dict[str, Any]:
        """Fit a VAR with ``predicted_window`` lags and score an n-step forecast.

        The last ``n_future`` rows are held out. The VAR is fitted on the
        remaining rows (with constant columns dropped) and its forecast is
        compared with the held-out rows.

        Args:
            sequence_3d: 2-D array [time, variables] despite its name; it is
                wrapped directly in a DataFrame.
            predicted_window: Requested lag order, clipped to [1, len(train) - 2].
            n_future: Number of trailing rows to forecast.

        Returns:
            Dict with ``mse``, ``mae`` and ``forecast_success``. On success it
            also has ``used_columns``, ``actual_values`` and
            ``predicted_values``. Failures report ``mse = mae = 99999``.
        """
        try:
            df = pd.DataFrame(sequence_3d, columns=[f'V{i+1}' for i in range(sequence_3d.shape[1])])
            df_train, df_test = df[0:-n_future], df[-n_future:]

            # Drop constant columns
            constant_columns = [col for col in df_train.columns
                                if df_train[col].nunique() <= 1 or df_train[col].var() < 1e-12]
            df_train = df_train.drop(columns=constant_columns, errors="ignore")
            df_test = df_test.drop(columns=constant_columns, errors="ignore")

            if len(df_train.columns) < 2:
                return {'mse': 99999, 'mae': 99999, 'forecast_success': False}

            k = max(1, min(predicted_window, len(df_train) - 2))

            model = VAR(df_train)
            model_fitted = None
            # Try trend specifications in order and keep the first that fits.
            for trend in ['n', 'c', 'ct', 'ctt']:
                try:
                    model_fitted = model.fit(maxlags=k, trend=trend)
                    break
                except Exception:  # a bare except would also swallow KeyboardInterrupt
                    continue
            if model_fitted is None:
                return {'mse': 99999, 'mae': 99999, 'forecast_success': False}

            forecast_input = df_train.values[-model_fitted.k_ar:]
            fc = model_fitted.forecast(y=forecast_input, steps=n_future)
            df_forecast = pd.DataFrame(fc, index=df.index[-n_future:], columns=df_train.columns)

            common_cols = [c for c in df_forecast.columns if c in df_test.columns]
            actual = df_test[common_cols].values.flatten()
            predicted = df_forecast[common_cols].values.flatten()

            mse = np.mean((actual - predicted) ** 2)
            mae = np.mean(np.abs(actual - predicted))

            return {
                'mse': float(mse),
                'mae': float(mae),
                'forecast_success': True,
                'used_columns': common_cols,
                'actual_values': actual.tolist(),
                'predicted_values': predicted.tolist()
            }
        except Exception as e:
            return {'mse': 99999, 'mae': 99999, 'forecast_success': False, 'error': str(e)}

    def predict_window_size(self, feature_vector: np.ndarray, sequence_3d: np.ndarray) -> Dict[str, Any]:
        """Predict a window size with the MLP, then evaluate it via a VAR forecast.

        Returns:
            Dict with ``predicted_window``, ``forecast_metrics``, ``confidence``,
            ``performance_stats``, ``event_type`` and ``prediction_id``. On
            failure it contains ``predicted_window=20``, ``confidence=0.0`` and
            ``error``.
        """
        if not self.is_model_loaded:
            return {'predicted_window': 20, 'confidence': 0.0, 'error': "Model not loaded"}
        try:
            if feature_vector.ndim == 1:
                feature_vector = feature_vector.reshape(1, -1)

            prediction_raw = self.model.predict(feature_vector, verbose=0)
            # The MLP output is in scaled units; undo the scaling when available.
            if self.transformer_fitted:
                predicted_window = self.transformer.inverse_transform(prediction_raw)[0, 0]
            else:
                predicted_window = prediction_raw[0, 0]
            predicted_window = int(round(predicted_window))

            forecast_metrics = self.evaluate_forecast_performance(sequence_3d, predicted_window, n_future=1)

            prediction_record = {
                'timestamp': dt.datetime.now(),
                'predicted_window': predicted_window,
                'forecast_metrics': forecast_metrics,
                'forecast_success': forecast_metrics.get('forecast_success', False)
            }

            if forecast_metrics.get('forecast_success', False):
                self.mse_history.append(forecast_metrics['mse'])
                self.mae_history.append(forecast_metrics['mae'])
                self.performance_stats['total_predictions'] += 1
                self.performance_stats['avg_mse'] = np.mean(self.mse_history)
                self.performance_stats['avg_mae'] = np.mean(self.mae_history)

                prediction_record['event_type'] = self._check_for_event()
            else:
                self.consecutive_poor_predictions += 1
                prediction_record['event_type'] = None

            self.prediction_history.append(prediction_record)
            self.retraining_data['x_buffer'].append(feature_vector.flatten())
            self.retraining_data['y_buffer'].append(predicted_window)

            return {
                'predicted_window': predicted_window,
                'forecast_metrics': forecast_metrics,
                'confidence': self._calculate_confidence(prediction_record),
                'performance_stats': self.get_recent_performance(),
                'event_type': prediction_record['event_type'],
                'prediction_id': len(self.prediction_history)
            }
        except Exception as e:
            return {'predicted_window': 20, 'confidence': 0.0, 'error': str(e)}

    def _check_for_event(self) -> str:
        """Classify recent forecast error as "ANOMALY", "DRIFT" or None.

        Compares an EMA of the last ``drift_detection_window`` MSE/MAE values
        with their median. If both ratios are elevated, a counter grows. A
        streak of 1-2 that then recovers is an ANOMALY; a streak reaching 5 is
        DRIFT. Low absolute error resets the counter.
        """
        if len(self.mse_history) < self.drift_detection_window:
            return None
        try:
            def ema(values, alpha=0.3):
                ema_val = values[0]
                for v in values[1:]:
                    ema_val = alpha * v + (1 - alpha) * ema_val
                return ema_val

            mse_vals = list(self.mse_history)[-self.drift_detection_window:]
            mae_vals = list(self.mae_history)[-self.drift_detection_window:]
            ema_mse, ema_mae = ema(mse_vals), ema(mae_vals)
            baseline_mse, baseline_mae = np.median(mse_vals), np.median(mae_vals)

            mse_ratio = ema_mse / max(baseline_mse, 1e-5)
            mae_ratio = ema_mae / max(baseline_mae, 1e-5)

            # Absolute thresholds: errors this small never count as events.
            if ema_mse < 0.02 and ema_mae < 0.08:
                self.consecutive_poor_predictions = 0
                return None

            event_condition = (mse_ratio > (1 + self.drift_threshold_mse) and
                               mae_ratio > (1 + self.drift_threshold_mae))

            if event_condition:
                self.consecutive_poor_predictions += 1
            else:
                # A short streak that recovered is a transient anomaly.
                if 0 < self.consecutive_poor_predictions < 3:
                    self.performance_stats['anomaly_events'] += 1
                    logger.warning(f"ANOMALY detected: EMA_MSE={ema_mse:.4f}, EMA_MAE={ema_mae:.4f}")
                    self.consecutive_poor_predictions = 0
                    return "ANOMALY"
                self.consecutive_poor_predictions = 0

            if self.consecutive_poor_predictions >= 5:
                self.performance_stats['drift_events'] += 1
                logger.warning(f"DRIFT detected: MSE ratio={mse_ratio:.3f}, MAE ratio={mae_ratio:.3f}, "
                               f"EMA_MSE={ema_mse:.4f}, EMA_MAE={ema_mae:.4f}")
                self.consecutive_poor_predictions = 0
                return "DRIFT"

            return None
        except Exception as e:
            logger.error(f"Event detection error: {e}")
            return None

    def _calculate_confidence(self, prediction_record: Dict) -> float:
        """Confidence in [0.1, 1.0] from recent error vs. 4x the 25th-percentile error.

        Returns 0.5 until 10 successful forecasts exist.
        """
        if len(self.mse_history) < 10:
            return 0.5
        recent_mse = np.mean(list(self.mse_history)[-10:])
        recent_mae = np.mean(list(self.mae_history)[-10:])
        # Floor the scale: a zero 25th percentile (perfect forecasts) made this
        # 0/0 -> NaN, collapsing confidence to the 0.1 floor.
        mse_scale = max(np.percentile(list(self.mse_history), 25) * 4, 1e-12)
        mae_scale = max(np.percentile(list(self.mae_history), 25) * 4, 1e-12)
        mse_conf = max(0, 1 - (recent_mse / mse_scale))
        mae_conf = max(0, 1 - (recent_mae / mae_scale))
        return min(1.0, max(0.1, (mse_conf + mae_conf) / 2))

    def get_recent_performance(self) -> Dict[str, Any]:
        """Summarise recent forecast quality and event counters.

        ``successful_predictions`` and ``success_rate`` both refer to the last
        50 predictions, while ``total_predictions`` is the full history length.
        """
        recent = list(self.prediction_history)[-50:]
        successful_predictions = [p for p in recent if p.get('forecast_success', False)]
        return {
            'total_predictions': len(self.prediction_history),
            'successful_predictions': len(successful_predictions),
            # Same window as the numerator; dividing by the full history made
            # the rate decay toward 50/len(history).
            'success_rate': len(successful_predictions) / max(len(recent), 1),
            'drift_events': self.performance_stats['drift_events'],
            'anomaly_events': self.performance_stats['anomaly_events'],
            'retraining_events': self.performance_stats['retraining_events'],
            'recent_mse': np.mean(list(self.mse_history)[-10:]) if self.mse_history else 0,
            'avg_mse': np.mean(self.mse_history) if self.mse_history else 0,
            'recent_mae': np.mean(list(self.mae_history)[-10:]) if self.mae_history else 0,
            'avg_mae': np.mean(self.mae_history) if self.mae_history else 0,
            'transformer_fitted': self.transformer_fitted
        }

    def save_performance_state(self, filepath: str):
        """Write stats, the last 100 prediction records and error histories as JSON.

        Non-JSON values (timestamps, numpy scalars) are stringified via ``default=str``.
        """
        state = {
            'performance_stats': self.performance_stats.copy(),
            'prediction_history': list(self.prediction_history)[-100:],
            'mse_history': list(self.mse_history),
            'mae_history': list(self.mae_history),
            'transformer_fitted': self.transformer_fitted
        }
        with open(filepath, 'w') as f:
            json.dump(state, f, default=str, indent=2)

# (optional) plotting helper


# ==================== PLOTTING ====================
import matplotlib.pyplot as plt

def plot_all_features_forecasts(test_sequences, agent, max_features=12):
    """Run ``agent`` over each sequence and plot one-step forecasts vs. actuals.

    Args:
        test_sequences: Iterable of 2-D arrays; each is passed as
            ``agent.predict_window_size(seq.flatten(), seq)``.
        agent: Object exposing ``predict_window_size`` (e.g. AdaptiveWindowAgent).
        max_features: Number of ``V1..V{max_features}`` panels to draw.

    Only sequences with a successful forecast are plotted, positioned at their
    index in ``test_sequences``. ANOMALY/DRIFT events are marked in red/orange.
    Returns None, and draws nothing if no forecast succeeded.
    """
    seq_indices, all_actual, all_forecast, event_points = [], [], [], []
    for i, seq in enumerate(test_sequences):
        result = agent.predict_window_size(seq.flatten(), seq)
        fm = result.get("forecast_metrics", {})
        if not fm.get("forecast_success", False):
            continue
        actual_row = [np.nan] * max_features
        forecast_row = [np.nan] * max_features
        for j, col in enumerate(fm.get("used_columns", [])):
            idx = int(col.replace("V", "")) - 1
            if idx < max_features:
                actual_row[idx] = fm["actual_values"][j]
                forecast_row[idx] = fm["predicted_values"][j]
        if result.get("event_type"):
            # Store the ROW index into the plotted matrices, not `i`. Skipped
            # sequences made `i` misaligned or out of bounds for the lookup.
            event_points.append((len(all_actual), result["event_type"]))
        seq_indices.append(i)
        all_actual.append(actual_row)
        all_forecast.append(forecast_row)
    if not all_actual:
        return
    actual_matrix = np.array(all_actual)
    forecast_matrix = np.array(all_forecast)
    x = np.array(seq_indices)
    # squeeze=False keeps an array of Axes even when max_features == 1.
    _, axes = plt.subplots(max_features, 1, figsize=(12, 2.5 * max_features),
                           sharex=True, squeeze=False)
    axes = axes[:, 0]
    for idx, ax in enumerate(axes):
        ax.plot(x, actual_matrix[:, idx], label="Actual", marker="o", alpha=0.6)
        ax.plot(x, forecast_matrix[:, idx], label="Forecast", marker="x", alpha=0.6)
        labelled = set()
        for row, etype in event_points:
            # Label each event type once, on the first panel only.
            label = etype if idx == 0 and etype not in labelled else ""
            labelled.add(etype)
            ax.scatter(x[row], actual_matrix[row, idx], color="red" if etype == "ANOMALY" else "orange",
                       marker="D", label=label)
        ax.set_ylabel(f"V{idx + 1}")
        ax.grid(True, alpha=0.3)
        if idx == 0:
            ax.legend()
    axes[-1].set_xlabel("Sequence index")
    plt.tight_layout()
    plt.show()

# ====== CoordinatorAgent ======

class CoordinatorAgent:
    """
    Coordinator Agent - orchestrates outputs from:
      1. RobustMasterAgent (sensor-level aggregation)
      2. AdaptiveWindowAgent (global window-level perspective)

    Combines them into a final system decision:
      - Anomaly: weighted score of the sensor anomaly rate and the window
        ANOMALY flag reaches ``anomaly_threshold``, or both sides agree
      - Drift: weighted drift score reaches ``drift_threshold``, or either side
        flags drift
      - Retraining: master requests it, or window drift coincides with an
        anomaly score above ``retrain_anomaly_threshold``
    """

    def __init__(self, anomaly_weight: float = 0.6, window_weight: float = 0.4,
                 anomaly_threshold: float = 0.3, drift_threshold: float = 0.2,
                 retrain_anomaly_threshold: float = 0.2, max_history: int = None):
        """
        Args:
            anomaly_weight: Weight of the sensor-level rate in both fused scores.
            window_weight: Weight of the window-agent flag in both fused scores.
            anomaly_threshold: Fused anomaly score needed for ``final_anomaly``.
            drift_threshold: Fused drift score needed for ``final_drift``.
            retrain_anomaly_threshold: Anomaly score above which window drift
                triggers retraining.
            max_history: Keep only the most recent N decisions in ``history``
                (None = unbounded, the original behaviour).
        """
        self.anomaly_weight = anomaly_weight
        self.window_weight = window_weight
        self.anomaly_threshold = anomaly_threshold
        self.drift_threshold = drift_threshold
        self.retrain_anomaly_threshold = retrain_anomaly_threshold
        self.max_history = max_history
        self.history = []

    def fuse(self, master_output: dict, window_output: dict) -> dict:
        """Fuse one master output and one window-agent output into a decision dict.

        Missing keys default to "no flag" / 0.0, so a master error result
        (which has no ``system_decisions``) contributes nothing.
        """
        timestamp = datetime.now()

        # Master outputs
        sys_dec = master_output.get("system_decisions", {})
        sys_anomaly = sys_dec.get("system_anomaly", False)
        sys_drift = sys_dec.get("system_drift", False)
        sys_retrain = sys_dec.get("system_needs_retrain", False)
        anomaly_rate = sys_dec.get("anomaly_rate", 0.0)
        drift_rate = sys_dec.get("drift_rate", 0.0)

        # Window outputs
        event_type = window_output.get("event_type")
        win_anomaly = window_output.get("window_anomaly_flag", False) or event_type == "ANOMALY"
        win_drift = window_output.get("window_drift_flag", False) or event_type == "DRIFT"
        predicted_window = window_output.get("predicted_window", None)

        # Fusion logic (anomaly_weight applies to the sensor side of both scores)
        anomaly_score = (self.anomaly_weight * anomaly_rate +
                         self.window_weight * (1.0 if win_anomaly else 0.0))
        drift_score = (self.anomaly_weight * drift_rate +
                       self.window_weight * (1.0 if win_drift else 0.0))

        final_anomaly = anomaly_score >= self.anomaly_threshold or (sys_anomaly and win_anomaly)
        final_drift = drift_score >= self.drift_threshold or (sys_drift or win_drift)
        final_retrain = sys_retrain or (win_drift and anomaly_score > self.retrain_anomaly_threshold)

        if final_anomaly and final_drift:
            alert = "CRITICAL"
        elif final_anomaly:
            alert = "HIGH"
        elif final_drift:
            alert = "MEDIUM"
        else:
            alert = "NORMAL"

        decision = {
            "timestamp": timestamp,
            "final_anomaly": final_anomaly,
            "final_drift": final_drift,
            "final_retrain": final_retrain,
            "alert_level": alert,
            "scores": {
                "anomaly_score": anomaly_score,
                "drift_score": drift_score,
                "sensor_anomaly_rate": anomaly_rate,
                "sensor_drift_rate": drift_rate
            },
            "window_agent": {
                "predicted_window": predicted_window,
                "window_anomaly": win_anomaly,
                "window_drift": win_drift
            }
        }

        self.history.append(decision)
        # Bound memory in long-running streams when requested.
        if self.max_history is not None and len(self.history) > self.max_history:
            del self.history[:len(self.history) - self.max_history]
        return decision

##################################################
# GROUND TRUTH ANOMALIES VS. WHICH AGENT IS RIGHT
############################################################

import matplotlib.pyplot as plt
import numpy as np

def plot_agent_vs_groundtruth(results, anomaly_labels, pred_labels_h1=None, pred_labels_h5=None, max_samples=200):
    """
    Compare Master, Window, Coordinator outputs vs ground truth labels.

    Panels: (0) master anomaly rate + window events, (1) coordinator alert
    level, (2) ground-truth detection labels, and (3) ground-truth prediction
    horizons, only when H1 or H5 labels are given.

    Args:
        results: list of dicts from the demo loop, each with "master", "window"
            and "coordinator" entries.
        anomaly_labels: array of shape [N] (0/1 ground truth detection).
        pred_labels_h1: optional array of shape [N] (0/1 ground truth prediction 1h).
        pred_labels_h5: optional array of shape [N] (0/1 ground truth prediction 5h).
        max_samples: number of samples to visualize.
    """
    n = min(len(results), max_samples, len(anomaly_labels))
    x = np.arange(n)
    level_codes = {"NORMAL": 0, "MEDIUM": 1, "HIGH": 2, "CRITICAL": 3}
    window_codes = {"DRIFT": 1, "ANOMALY": 0.5}

    # Extract agent outputs. Errored/unknown entries become NaN (not dropped)
    # so every series stays aligned with x.
    master_anom_rate = [r['master'].get('system_decisions', {}).get('anomaly_rate', np.nan)
                        for r in results[:n]]
    window_events = [window_codes.get(r['window'].get('event_type'), 0) for r in results[:n]]
    coord_alert = [level_codes.get(r['coordinator'].get('alert_level'), np.nan) for r in results[:n]]

    # Ground truth
    gt_detect = np.asarray(anomaly_labels)[:n]
    gt_h1 = np.asarray(pred_labels_h1)[:n] if pred_labels_h1 is not None else None
    gt_h5 = np.asarray(pred_labels_h5)[:n] if pred_labels_h5 is not None else None
    has_horizons = gt_h1 is not None or gt_h5 is not None

    # --- Plot ---
    _, axes = plt.subplots(4 if has_horizons else 3, 1, figsize=(14, 8), sharex=True)

    axes[0].plot(x, master_anom_rate, label="Master anomaly rate", color="blue")
    axes[0].scatter(x, window_events, label="Window events (0.5=ANOM,1=DRIFT)", color="orange", marker="x")
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

    axes[1].plot(x, coord_alert, label="Coordinator Alert Level (0-3)", color="red")
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)

    axes[2].step(x, gt_detect, where="mid", label="GT Anomaly Detection", color="green")
    axes[2].legend()
    axes[2].grid(True, alpha=0.3)

    if has_horizons:
        # Horizon labels go on their own panel (previously this panel was
        # created but left empty). Slice x in case a label array is shorter.
        ax_h = axes[3]
        if gt_h1 is not None:
            ax_h.step(x[:len(gt_h1)], gt_h1, where="mid", label="GT 1h Prediction", color="purple", linestyle="--")
        if gt_h5 is not None:
            ax_h.step(x[:len(gt_h5)], gt_h5, where="mid", label="GT 5h Prediction", color="brown", linestyle=":")
        ax_h.legend()
        ax_h.grid(True, alpha=0.3)

    axes[-1].set_xlabel("Sequence index")
    plt.tight_layout()
    plt.show()

# =========================
# DEMO LOOP
# =========================

if __name__ == "__main__":
    models_dir = "/content/drive/MyDrive/PHD/2025/TEMP_OUTPUT_METROPM/sensor/model"
    mlp_path = "/content/drive/MyDrive/PHD/2025/DGRNet-MLP-Versions/METROPM_MLP_model_Daily.keras"
    data_path = "/content/drive/MyDrive/PHD/2025/TEMP_OUTPUT_METROPM/multivariate_long_sequences-TRAIN-Daily-DIRECT-VAR.npy"
    labels_dir = "/content/drive/MyDrive/PHD/2025/TEMP_OUTPUT_METROPM"
    holdout_size = 1000

    # Load test data first so the sensor count comes from the data itself.
    # A hard-coded count that disagrees with the data makes the master reject
    # every sample.
    subsequences = np.load(data_path)
    holdout = subsequences[-holdout_size:]
    num_sensors = holdout.shape[-1]  # each sample is [window_length, num_sensors]

    # create agents
    sensor_agents, master = create_robust_system(num_sensors=num_sensors, models_dir=models_dir)
    window_agent = AdaptiveWindowAgent(model_path=mlp_path)
    coordinator = CoordinatorAgent()

    results = []
    for i, seq in enumerate(holdout, 1):
        master_out = master.process_system_input(seq)
        window_out = window_agent.predict_window_size(seq.flatten(), seq)
        final = coordinator.fuse(master_out, window_out)
        results.append({"master": master_out, "window": window_out, "coordinator": final})

        print("\n" + "=" * 60)
        print(f"Sample {i}")
        # master_out has no "system_decisions" when the input shape was rejected.
        print("Master:", master_out.get("system_decisions", {"error": master_out.get("error")}))
        print("Window:", {"predicted_window": window_out.get("predicted_window"),
                          "event": window_out.get("event_type")})
        print("Coordinator:", final["alert_level"], final["scores"])

    def _align_to_holdout(labels):
        """Slice labels covering the full dataset down to the holdout tail."""
        # holdout is the LAST `holdout_size` subsequences, while the plot reads
        # labels from index 0. NOTE(review): a label array with one entry per
        # subsequence is assumed index-aligned with `subsequences`; others are
        # assumed to already match the holdout. Confirm.
        return labels[-len(holdout):] if len(labels) == len(subsequences) else labels

    anomaly_labels = _align_to_holdout(np.load(os.path.join(labels_dir, "anomaly_labels_detection.npy")))
    h1_labels = _align_to_holdout(np.load(os.path.join(labels_dir, "anomaly_labels_H1.npy")))
    h5_labels = _align_to_holdout(np.load(os.path.join(labels_dir, "anomaly_labels_H5.npy")))

    plot_agent_vs_groundtruth(results, anomaly_labels, h1_labels, h5_labels, max_samples=200)

In [None]:
# Mount Google Drive so the /content/drive/... model, data and label paths used
# above are readable.
# NOTE(review): this cell comes after the demo cell; in a fresh runtime it
# presumably has to run first. Consider moving it to the top of the notebook.
from google.colab import drive
drive.mount('/content/drive')