# =============================================================================
# REINFORCEMENT LEARNING AGENT TRAINING NOTEBOOK
# =============================================================================
## Purpose:
    - Load the pre-trained baseline GRU model and processed dataset.
    - Define the custom Reinforcement Learning (RL) environment (`SustainableAIAgentEnv`),
      which encodes two sustainability objectives: reducing inference FLOPs, and reducing the parameter count (used as a proxy for training cost).
    - Train a Proximal Policy Optimization (PPO) agent within this environment.
    - Save the trained PPO agent and identify the best optimization strategy discovered.
# =============================================================================

# === Clone Repository & Install Dependencies ===

In [None]:
!rm -rf Sustainable_AI_Agent_Project
!git clone https://github.com/trongjhuongwr/Sustainable_AI_Agent_Project.git
%cd Sustainable_AI_Agent_Project

In [None]:
!pip install -q --extra-index-url https://download.pytorch.org/whl/cu121 -r /kaggle/working/Sustainable_AI_Agent_Project/requirements.txt

# 1. Import Libraries and Configuration

In [None]:
import os
import warnings
import logging
import json
import copy
import random

# Suppress specific warnings for cleaner output
os.environ["GYM_DISABLE_WARNINGS"] = "true"
warnings.filterwarnings("ignore", module="gymnasium")
warnings.filterwarnings("ignore", category=UserWarning)
logging.getLogger("gymnasium").setLevel(logging.ERROR)
logging.getLogger("stable_baselines3").setLevel(logging.ERROR)

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import pandas as pd
import gymnasium as gym
from gymnasium import spaces
from sklearn.metrics import accuracy_score
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.callbacks import BaseCallback
from codecarbon import EmissionsTracker
from ptflops import get_model_complexity_info
import torch_pruning as tp
from tqdm.notebook import tqdm
from builtins import print as builtin_print

print("Libraries imported successfully.")

# 2. Configuration Class

In [None]:
# Defines hyperparameters, file paths, and environment parameters used throughout the notebook.
class Config:
    """Notebook-wide settings: artifact paths, model shape, PPO budget and reward shaping."""

    # ---- Input artifacts ----
    # Saved tensors of the processed dataset.
    PROCESSED_DATA_PATH = '/kaggle/working/processed_data.pt'
    # State dictionary of the trained baseline model.
    BASELINE_MODEL_PATH = '/kaggle/working/baseline_model.pth'

    # ---- Output artifacts ----
    # Where the PPO agent is saved to (and reloaded from).
    AGENT_SAVE_PATH = "/kaggle/working/sustainable_ai_agent_ppo.zip"
    # Where the description of the best discovered action is written.
    BEST_ACTION_SAVE_PATH = "/kaggle/working/best_action.json"
    # Directory receiving TensorBoard logs.
    TENSORBOARD_LOG_PATH = "/kaggle/working/ppo_tensorboard/"

    # ---- Data shape (kept consistent with baseline training) ----
    SEQUENCE_LENGTH = 30
    INPUT_DIM = 4

    # ---- Model architecture (must match the baseline checkpoint) ----
    HIDDEN_DIM = 256
    N_LAYERS = 2
    OUTPUT_DIM = 1
    # Kept so the model is rebuilt with the same constructor arguments as the
    # baseline; dropout is inactive in eval mode.
    DROPOUT = 0.2

    # ---- PPO training budget ----
    # Total number of environment steps to train for.
    TOTAL_TIMESTEPS = 10000
    # The agent state is saved every this-many steps.
    TIMESTEPS_PER_CHUNK = 2000
    # Seed used for reproducibility.
    SEED = 42

    # ---- Reward shaping and constraints ----
    # Relative factor applied to the baseline accuracy: an accuracy below
    # baseline * 0.95 (i.e. more than a 5% relative drop) gets the heavy penalty.
    ACCURACY_PENALTY_THRESHOLD = 0.95
    # Multiplier on the accuracy delta when no heavy penalty applies.
    ACCURACY_REWARD_SCALE = 10.0
    # Multiplier on the inference FLOPs reduction ratio.
    FLOPS_REWARD_SCALE_INFERENCE = 1.5
    # Multiplier on the parameter reduction ratio (training energy proxy).
    PARAMS_REWARD_SCALE_TRAINING = 1.0
    # Added to the reward when action 0 (no optimization) is chosen.
    INACTION_PENALTY = -1.0
    # Heavy penalty used when an environment step fails (e.g. optimization error).
    ENV_ERROR_REWARD = -10.0

    # ---- Evaluation inside the environment ----
    # Batch size of the evaluation DataLoader.
    EVAL_BATCH_SIZE = 64
    # Number of batches used for the CodeCarbon energy measurement at env init.
    CODECARBON_BATCHES = 10

    # ---- Computation device ----
    DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Seed every random number generator used by the notebook (Python, NumPy, PyTorch).
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(Config.SEED)

# CUDA generators are seeded only when a GPU is present.
# Deterministic cuDNN (cudnn.deterministic = True, cudnn.benchmark = False) is
# intentionally not enabled here; it can be switched on for full
# reproducibility at a possible performance cost.
if torch.cuda.is_available():
    torch.cuda.manual_seed(Config.SEED)
    torch.cuda.manual_seed_all(Config.SEED)

# Summarize the active configuration.
print(f"Configuration loaded. Using device: {Config.DEVICE}")
print(f"Seed set to: {Config.SEED}")
print(f"Loading processed data from: {Config.PROCESSED_DATA_PATH}")
print(f"Loading baseline model from: {Config.BASELINE_MODEL_PATH}")
print(f"Agent will be saved to: {Config.AGENT_SAVE_PATH}")

# 3. Utility Functions and Model Definition

In [None]:
# Includes functions for evaluating model metrics (accuracy, parameters, FLOPs, energy),
# applying optimization techniques (pruning, quantization), and the `WeatherGRU` model class
# definition (required for loading the baseline state).

def count_parameters(model):
    """Return the total element count of all parameters that require gradients."""
    total = 0
    for param in model.parameters():
        # Frozen parameters (requires_grad=False) are not counted.
        if param.requires_grad:
            total += param.numel()
    return total

def evaluate_model(model, loader, codecarbon_batches=10):
    """Performs comprehensive evaluation (accuracy, params, FLOPs, energy).

    The model is deep-copied and moved to the CPU so that every measurement
    (including the CodeCarbon energy estimate) is taken on the same device and
    the caller's model is left untouched.

    Args:
        model (nn.Module): The model instance to evaluate.
        loader (DataLoader): DataLoader for the test/evaluation dataset,
            yielding ``(inputs, labels)`` pairs.
        codecarbon_batches (int): Number of batches to run inference on for
            the energy measurement.

    Returns:
        dict: ``accuracy``, ``energy_kwh``, ``co2_eq_kg``, ``flops`` and
        ``params``. ``energy_kwh``/``co2_eq_kg`` are 0 when CodeCarbon fails
        or records nothing; ``flops`` is 0 when ptflops cannot analyse the
        model.
    """
    model_cpu = copy.deepcopy(model).cpu()  # Ensure evaluation is on CPU
    device = torch.device("cpu")
    model_cpu.eval()  # Set model to evaluation mode

    # 1. Accuracy Calculation (over the whole loader)
    y_true, y_pred = [], []
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            outputs = model_cpu(inputs)
            preds = (outputs > 0.5).float()  # Binary classification threshold
            y_pred.extend(preds.cpu().numpy())
            y_true.extend(labels.cpu().numpy())
    accuracy = accuracy_score(y_true, y_pred)

    # 2. Energy and CO2 Emission Estimation (using CodeCarbon)
    energy_kwh = 0
    co2_eq_kg = 0
    try:
        # Configure tracker for process-level tracking and local output
        tracker = EmissionsTracker(log_level="error", output_dir="/kaggle/working/", tracking_mode="process")
        tracker.start()
        try:
            with torch.no_grad():
                for i, (inputs, _) in enumerate(loader):
                    if i >= codecarbon_batches:
                        break  # Limit measurement duration
                    model_cpu(inputs.to(device))
        finally:
            # Always stop the tracker, even if inference raised, so a started
            # tracker is never left running after this function returns.
            tracker.stop()
        # Safely access emission data, handling cases where tracking might be too short
        emissions_data = tracker.final_emissions_data
        if emissions_data:
            energy_kwh = emissions_data.energy_consumed or 0
            co2_eq_kg = emissions_data.emissions or 0
        else:
            builtin_print("Warning: CodeCarbon tracker did not record final emissions data (runtime may be too short).")
    except Exception as e:
        # Energy measurement is best-effort: report zeros rather than abort the evaluation.
        builtin_print(f"Warning: CodeCarbon measurement failed - {e}")

    # 3. Parameter Count
    params = count_parameters(model_cpu)

    # 4. FLOPs Estimation (using ptflops)
    flops = 0
    try:
        # Provide input dimensions (without the batch axis) for complexity analysis
        macs, _ = get_model_complexity_info(
            model_cpu, (Config.SEQUENCE_LENGTH, Config.INPUT_DIM),
            as_strings=False, print_per_layer_stat=False, verbose=False)
        flops = macs * 2  # Approximate FLOPs = 2 * MACs
    except (KeyError, AttributeError, RuntimeError, TypeError):
        # ptflops often fails with quantized models or certain layer types.
        # Report 0 FLOPs as a signal of significant theoretical reduction in these cases.
        flops = 0

    return {
        "accuracy": accuracy,
        "energy_kwh": energy_kwh,
        "co2_eq_kg": co2_eq_kg,
        "flops": max(0.0, flops),  # Ensure non-negative
        "params": max(0.0, params)  # Ensure non-negative
    }

def apply_l1_pruning(model, amount):
    """Return a CPU copy of ``model`` pruned by torch_pruning's magnitude pruner (L1 importance).

    GRU modules are passed as ``ignored_layers`` so the pruner leaves them alone.
    NOTE(review): this was previously described as *unstructured* pruning that
    zeros weights; torch_pruning is generally a structural pruning library, so
    that wording should be confirmed against the installed version.

    Args:
        model (nn.Module): The original model; it is never modified.
        amount (float): Target pruning ratio. Values outside ``(0, 1]`` skip
            pruning entirely.

    Returns:
        nn.Module: A new model instance on the CPU. It is an unpruned copy when
        ``amount`` is out of range or when the pruning step raises.
    """
    ratio_is_valid = 0 < amount <= 1.0
    if not ratio_is_valid:
        # Nothing to prune (or nonsensical ratio): hand back a plain copy.
        return copy.deepcopy(model).cpu()

    pruned = copy.deepcopy(model).cpu()
    pruned.eval()  # Pruning is carried out with the model in evaluation mode

    # Collect the recurrent layers that the pruner must skip.
    gru_modules = []
    for module in pruned.modules():
        if isinstance(module, nn.GRU):
            gru_modules.append(module)

    # Dummy single-sample batch handed to the pruner as example input.
    dummy_batch = torch.randn(1, Config.SEQUENCE_LENGTH, Config.INPUT_DIM)

    pruner = tp.pruner.MagnitudePruner(
        pruned,
        dummy_batch,
        importance=tp.importance.MagnitudeImportance(p=1),  # L1 magnitude importance
        pruning_ratio=amount,  # Global pruning ratio
        ignored_layers=gru_modules,
    )

    try:
        pruner.step()
    except Exception as e:
        builtin_print(f"Error during pruning step with amount {amount}: {e}")
        # Fall back to an untouched copy of the input model.
        return copy.deepcopy(model).cpu()

    return pruned

def apply_dynamic_quantization(model):
    """Return a dynamically quantized (qint8) CPU copy of ``model``.

    Only ``nn.GRU`` and ``nn.Linear`` modules are targeted. The input model is
    never modified.

    Args:
        model (nn.Module): The original model.

    Returns:
        nn.Module: A new quantized model on the CPU, or a plain CPU copy of the
        input if quantization raises.
    """
    # Layer types handed to quantize_dynamic.
    target_types = {nn.GRU, nn.Linear}

    cpu_copy = copy.deepcopy(model).cpu()
    cpu_copy.eval()  # Quantize with the model in evaluation mode

    try:
        return torch.quantization.quantize_dynamic(
            cpu_copy,
            target_types,
            dtype=torch.qint8,
        )
    except Exception as e:
        builtin_print(f"Error during dynamic quantization: {e}")
        # Quantization failed: fall back to a fresh, unquantized copy.
        return copy.deepcopy(model).cpu()

class WeatherGRU(nn.Module):
    """GRU followed by a linear head and a sigmoid.

    The attribute names (``gru``, ``fc``, ``sigmoid``) determine the
    state_dict keys, so they must stay as they are for the saved baseline
    weights to load.
    """

    def __init__(self, input_dim, hidden_dim, n_layers, output_dim, dropout):
        super().__init__()
        # Inter-layer dropout is only passed to the GRU when it is stacked.
        gru_dropout = dropout if n_layers > 1 else 0
        self.gru = nn.GRU(
            input_dim,
            hidden_dim,
            n_layers,
            batch_first=True,
            dropout=gru_dropout,
        )
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Run the GRU over ``x`` and map the last time step through ``fc`` and the sigmoid."""
        sequence_out, _hidden = self.gru(x)
        last_step = sequence_out[:, -1, :]
        logits = self.fc(last_step)
        return self.sigmoid(logits)

print("Utility functions and WeatherGRU class defined.")

# 4. Load Baseline Model and Test Data

In [None]:
# Loads the necessary artifacts generated by the `train_baseline.ipynb` notebook:
# the processed test dataset tensors and the state dictionary of the trained baseline `WeatherGRU` model.

# Load processed test data tensors.
# map_location='cpu' keeps the load working on CPU-only hosts even if the
# tensors were saved from a CUDA device; all evaluation in this notebook runs
# on the CPU anyway.
try:
    processed_data = torch.load(Config.PROCESSED_DATA_PATH, map_location=torch.device('cpu'))
    X_test_tensor = processed_data['X_test']
    y_test_tensor = processed_data['y_test']
    builtin_print(f"Processed test data loaded from {Config.PROCESSED_DATA_PATH}")
except FileNotFoundError:
    builtin_print(f"Error: Processed data file not found at {Config.PROCESSED_DATA_PATH}. Please run train_baseline.ipynb first.")
    raise
except Exception as e:
    builtin_print(f"Error loading processed data: {e}")
    raise

# Create DataLoader for the test set (used for environment initialization and evaluation).
# shuffle=False keeps the batch order fixed between evaluations.
test_dataset = TensorDataset(X_test_tensor, y_test_tensor)
test_loader = DataLoader(test_dataset, batch_size=Config.EVAL_BATCH_SIZE, shuffle=False)
builtin_print("Test DataLoader created.")

# Initialize the GRU model structure (must match the saved checkpoint)
baseline_model_instance = WeatherGRU(
    input_dim=Config.INPUT_DIM,
    hidden_dim=Config.HIDDEN_DIM,
    n_layers=Config.N_LAYERS,
    output_dim=Config.OUTPUT_DIM,
    dropout=Config.DROPOUT # Ensure dropout matches the saved model structure
)

# Load the trained baseline model state dictionary
try:
    # map_location='cpu' ensures the model loads correctly regardless of the device it was trained on
    baseline_model_instance.load_state_dict(torch.load(Config.BASELINE_MODEL_PATH, map_location=torch.device('cpu')))
    baseline_model_instance.eval() # Set to evaluation mode by default
    builtin_print(f"Baseline model state loaded from {Config.BASELINE_MODEL_PATH}")
except FileNotFoundError:
    builtin_print(f"Error: Baseline model file not found at {Config.BASELINE_MODEL_PATH}. Please run train_baseline.ipynb first.")
    raise
except Exception as e:
    builtin_print(f"Error loading baseline model state: {e}")
    raise

# Keep a clean copy on the CPU for the environment.
# nn.Module.cpu() returns the module itself, so an explicit deepcopy is needed
# for this to be an independent copy rather than an alias of baseline_model_instance.
baseline_model_cpu = copy.deepcopy(baseline_model_instance).cpu()

# 5. Define Custom Reinforcement Learning Environment

In [None]:
# This section defines the `SustainableAIAgentEnv` class, which inherits from `gym.Env`.
# It encapsulates the optimization problem: the action space (pruning ratios, optionally combined with quantization),
# the observation space (performance metrics relative to the baseline), and the reward function
# that incentivizes both accuracy maintenance and resource reduction (inference FLOPs and parameter count).

class SustainableAIAgentEnv(gym.Env):
    """Custom Gymnasium environment for optimizing GRU model sustainability.

    Every episode is a single step: the agent picks one optimization recipe,
    the recipe is applied to a fresh copy of the baseline model, and the
    episode terminates.

    Action Space:
        Discrete(16): Corresponds to combinations of pruning ratios [0.0, 0.1, ..., 0.7]
                     and the application of dynamic quantization.
                     Actions 0-7: Pruning only.
                     Actions 8-15: Pruning + Quantization.

    Observation Space (Box): [current_accuracy, accuracy_delta_vs_baseline, params_reduction_ratio, flops_reduction_ratio]

    Reward Function:
        Combines rewards/penalties for:
        - Accuracy change relative to baseline (heavy penalty for drops > threshold).
        - Reduction in estimated inference FLOPs.
        - Reduction in model parameters (proxy for training energy reduction).
        - Inaction (choosing 0% pruning and no quantization).
    """
    metadata = {'render_modes': []} # Required by Gymnasium API

    def __init__(self, baseline_model, data_loader):
        """Initializes the environment.
        Args:
            baseline_model (nn.Module): The pre-trained baseline model instance (on CPU).
            data_loader (DataLoader): DataLoader for evaluation (typically test set).
        """
        super().__init__()
        # NOTE: .cpu() returns the same module object, so the environment shares
        # the model with the caller; step() always works on deep copies of it.
        self.base_model = baseline_model.cpu() # Ensure base model is on CPU
        self.base_model.eval()
        self.loader = data_loader

        builtin_print("Initializing RL environment: Calculating baseline metrics...")
        # Perform a full evaluation once at initialization to get baseline reference metrics
        self.baseline_metrics = evaluate_model(self.base_model, self.loader, Config.CODECARBON_BATCHES)

        # Handle potential division by zero if baseline FLOPs/params are 0 (e.g., from ptflops error)
        if self.baseline_metrics.get("flops", 0) == 0:
            self.baseline_metrics["flops"] = 1 # Avoid division by zero
            builtin_print("Warning: Baseline FLOPs calculated as 0. Setting to 1 to prevent division errors.")
        if self.baseline_metrics.get("params", 0) == 0:
            self.baseline_metrics["params"] = 1 # Avoid division by zero
            builtin_print("Warning: Baseline Parameters calculated as 0. Setting to 1 to prevent division errors.")

        builtin_print(f"Baseline Metrics initialized: {self.baseline_metrics}")

        # --- Action Space Definition ---
        self.pruning_levels = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
        n_pruning_levels = len(self.pruning_levels)
        self.action_space = spaces.Discrete(n_pruning_levels * 2) # 8 pruning levels * (prune_only + prune_and_quantize)

        # --- Observation Space Definition ---
        # [current_accuracy, accuracy_delta, params_reduction_ratio, flops_reduction_ratio]
        # Bounds ensure values stay within a reasonable range (e.g., accuracy 0-1, reduction 0-1, delta -1 to 1)
        self.observation_space = spaces.Box(
            low=np.array([0.0, -1.0, 0.0, 0.0], dtype=np.float32),
            high=np.array([1.0, 1.0, 1.0, 1.0], dtype=np.float32),
            dtype=np.float32
        )

        # Store initial observation for reset
        self.initial_obs = np.array([
            self.baseline_metrics.get("accuracy", 0.0),
            0.0, # Initial accuracy delta
            0.0, # Initial params reduction
            0.0  # Initial FLOPs reduction
        ], dtype=np.float32)

    def step(self, action):
        """Executes one step in the environment based on the agent's action.

        Any exception raised while decoding, optimizing or evaluating is
        caught: the step then returns the initial observation, the
        ``Config.ENV_ERROR_REWARD`` penalty and an ``info`` dict holding only
        an ``"error"`` message.

        Args:
            action (int): The discrete action selected by the agent. Python
                ints, NumPy integer scalars, size-1 NumPy arrays and
                single-element torch tensors are accepted.
        Returns:
            tuple: (observation, reward, terminated, truncated, info)
        """
        terminated = True # This environment is episodic with a single step
        truncated = False # Not using truncation
        info = {} # Dictionary for logging diagnostic information

        try:
            # --- 1. Decode Action ---
            # Normalize to a native int so that the values stored in `info`
            # (action id, quantization flag) are plain Python types and the
            # dict stays JSON-serializable.
            if isinstance(action, (torch.Tensor, np.ndarray, np.generic)):
                action_val = int(action.item())
            else:
                action_val = int(action)
            n_pruning_levels = len(self.pruning_levels)
            pruning_idx = action_val % n_pruning_levels
            pruning_amount = self.pruning_levels[pruning_idx]
            apply_quant = action_val >= n_pruning_levels

            # --- 2. Apply Optimization(s) ---
            # Start with a fresh copy of the baseline model (on CPU)
            optimized_model = copy.deepcopy(self.base_model)
            if pruning_amount > 0:
                optimized_model = apply_l1_pruning(optimized_model, pruning_amount)
            if apply_quant:
                optimized_model = apply_dynamic_quantization(optimized_model)
            # If action is 0 (no prune, no quant), optimized_model remains a copy of the baseline

            # --- 3. Lightweight Evaluation ---
            # Evaluate the optimized model quickly (on one batch) to get metrics for reward/state
            metrics = self._evaluate_lightweight(optimized_model)

            # --- 4. Calculate Reward Components ---
            # NOTE(review): current_acc is measured on a single batch while
            # baseline_acc comes from the full-loader evaluation in __init__,
            # so the delta for an unchanged model is not guaranteed to be 0.
            baseline_acc = float(self.baseline_metrics.get("accuracy", 0.0))
            current_acc = float(metrics.get("accuracy", 0.0))
            accuracy_delta = current_acc - baseline_acc

            # a) Accuracy Reward/Penalty
            if current_acc < (baseline_acc * Config.ACCURACY_PENALTY_THRESHOLD):
                accuracy_reward = Config.ENV_ERROR_REWARD # Heavy penalty for significant accuracy drop
            else:
                accuracy_reward = accuracy_delta * Config.ACCURACY_REWARD_SCALE # Proportional reward/penalty

            # b) Inference Resource Reward (FLOPs Reduction)
            baseline_flops = self.baseline_metrics.get("flops", 1) # Default to 1 if missing
            current_flops = metrics.get("flops", baseline_flops)
            flops_reduction_ratio = float(max(0.0, 1.0 - (current_flops / baseline_flops))) # Ensure non-negative
            resource_reward_inference = flops_reduction_ratio * Config.FLOPS_REWARD_SCALE_INFERENCE

            # c) Training Resource Reward Proxy (Parameter Reduction)
            baseline_params = self.baseline_metrics.get("params", 1) # Default to 1 if missing
            current_params = metrics.get("params", baseline_params)
            params_reduction_ratio = float(max(0.0, 1.0 - (current_params / baseline_params))) # Ensure non-negative
            resource_reward_training = params_reduction_ratio * Config.PARAMS_REWARD_SCALE_TRAINING

            # d) Inaction Penalty
            inaction_penalty = Config.INACTION_PENALTY if pruning_amount == 0 and not apply_quant else 0.0

            # e) Total Reward (native float)
            reward = float(accuracy_reward + resource_reward_inference + resource_reward_training + inaction_penalty)

            # --- 5. Construct Observation (Next State) ---
            obs = np.array([
                current_acc,
                accuracy_delta,
                params_reduction_ratio,
                flops_reduction_ratio
            ], dtype=np.float32)

            # Clip observation to ensure it's within the defined space bounds
            obs = np.clip(obs, self.observation_space.low, self.observation_space.high)

            # --- 6. Populate Info Dictionary ---
            info = {
                "action": action_val,
                "pruning_amount": pruning_amount,
                "quantized": apply_quant,
                "accuracy": current_acc,
                "accuracy_delta": accuracy_delta,
                "flops_reduction": flops_reduction_ratio,
                "params_reduction": params_reduction_ratio,
                "reward_components": {
                    "accuracy": float(accuracy_reward),
                    "flops_inference": float(resource_reward_inference),
                    "params_training": float(resource_reward_training),
                    "inaction": float(inaction_penalty)
                },
                "reward": reward
            }

            return obs, reward, terminated, truncated, info

        except Exception as e:
            builtin_print(f"Error during environment step with action {action}: {e}")
            # Return initial state, heavy penalty, and terminate episode on error
            obs, _ = self.reset()
            return obs, Config.ENV_ERROR_REWARD, True, False, {"error": str(e)}

    def _evaluate_lightweight(self, model):
        """Performs a quick evaluation on a single batch for state/reward calculation.

        Accuracy is computed on the first batch yielded by the loader. FLOPs
        and parameter counts come from ptflops; if ptflops raises (or returns
        values that cannot be multiplied), FLOPs are reported as 0 and the
        parameter count as a quarter of the baseline's. If anything else
        fails, the baseline metrics are returned unchanged.

        Always runs on CPU.
        Args:
            model (nn.Module): The optimized model (on CPU).
        Returns:
            dict: Dictionary with 'accuracy', 'flops', 'params'.
        """
        model.eval()
        device = torch.device("cpu") # Lightweight eval always on CPU
        # Initialize metrics with baseline defaults in case of evaluation error
        metrics = {
            "accuracy": self.baseline_metrics.get("accuracy", 0.0),
            "flops": self.baseline_metrics.get("flops", 1),
            "params": self.baseline_metrics.get("params", 1)
        }

        try:
            inputs, labels = next(iter(self.loader)) # Get a single batch
            inputs, labels = inputs.to(device), labels.to(device)

            # Calculate accuracy on the batch
            with torch.no_grad():
                outputs = model(inputs)
                preds = (outputs > 0.5).float()
                # Use numpy for accuracy calculation here
                accuracy = accuracy_score(labels.cpu().numpy(), preds.cpu().numpy())
                metrics["accuracy"] = accuracy

            # Attempt to calculate FLOPs and Params
            try:
                macs, params_val = get_model_complexity_info(
                    model, (Config.SEQUENCE_LENGTH, Config.INPUT_DIM),
                    as_strings=False, print_per_layer_stat=False, verbose=False)
                flops_val = macs * 2
                metrics["flops"] = flops_val
                metrics["params"] = params_val
            except (RuntimeError, KeyError, AttributeError, TypeError):
                # ptflops failed (likely due to quantization)
                # Estimate benefits:
                metrics["flops"] = 0 # Signal significant theoretical reduction
                # Estimate quantized params roughly as 1/4 of baseline float32 params
                metrics["params"] = self.baseline_metrics.get("params", 1) / 4

        except StopIteration:
            builtin_print("Warning: DataLoader is empty in _evaluate_lightweight.")
        except Exception as e:
            builtin_print(f"Error during _evaluate_lightweight: {e}")
            # Keep default metrics (baseline) if evaluation fails

        # Ensure reported metrics are non-negative
        metrics["flops"] = max(0.0, metrics["flops"])
        metrics["params"] = max(0.0, metrics["params"])

        return metrics

    def reset(self, seed=None, options=None):
        """Resets the environment to the initial state.
        Args:
            seed (int, optional): Seed for the random number generator.
            options (dict, optional): Additional options (not used here).
        Returns:
            tuple: (initial_observation, info)
        """
        super().reset(seed=seed)
        info = {} # Must return an info dictionary
        # Hand out a copy so callers that modify the returned array in place
        # cannot corrupt the stored baseline observation.
        return self.initial_obs.copy(), info

    def render(self):
        """Render the environment (not applicable here)."""
        pass

    def close(self):
        """Close the environment (cleanup, not applicable here)."""
        pass

print("SustainableAIAgentEnv class defined.")

# 6. Initialize and Verify Environment

In [None]:
# Instantiates the custom environment. An optional `check_env` call from
# `stable_baselines3.common.env_checker` is left commented out; enable it to verify Gymnasium API compatibility.

try:
    # Instantiate the environment with the baseline model and test loader.
    # The keyword names must match SustainableAIAgentEnv.__init__(baseline_model, data_loader);
    # the previous `model=` / `loader=` keywords raised a TypeError.
    env = SustainableAIAgentEnv(baseline_model=baseline_model_cpu, data_loader=test_loader)

    # Optional: Verify the environment conforms to the Gymnasium API
    # check_env(env) # This can sometimes raise warnings even if functional; skip if causing issues
    builtin_print("\nRL Environment initialized successfully.")

except Exception as e:
    # Report the failure, then re-raise so the notebook stops here.
    builtin_print(f"\nError initializing RL Environment: {e}")
    raise

# 7. Train PPO Agent

In [None]:
# Initializes or loads a PPO agent from `stable_baselines3` and trains it using the custom environment.
# Training proceeds in chunks, saving the agent's state periodically.
# A custom callback is used to log detailed information about each optimization attempt during training.

# --- Agent Initialization or Loading ---
# Try to resume from a saved agent first; fall back to a fresh PPO agent when
# no file exists or loading fails.
agent = None

if os.path.exists(Config.AGENT_SAVE_PATH):
    builtin_print(f"\n--- Loading pre-trained Agent from: {Config.AGENT_SAVE_PATH} ---")
    try:
        agent = PPO.load(Config.AGENT_SAVE_PATH, env=env, device=Config.DEVICE)
        builtin_print("Agent loaded successfully.")
    except Exception as e:
        builtin_print(f"Error loading agent: {e}. Creating a new agent.")
        agent = None  # Discard anything partially loaded
else:
    builtin_print(f"\n--- Agent file not found at {Config.AGENT_SAVE_PATH}. Creating a new PPO Agent. ---")

if agent is None:
    # Single construction path shared by the "missing file" and "failed load" cases.
    agent = PPO(
        "MlpPolicy",
        env,
        verbose=0,
        device=Config.DEVICE,
        tensorboard_log=Config.TENSORBOARD_LOG_PATH,
        seed=Config.SEED,
    )
    builtin_print("New PPO Agent created.")

# --- Custom Callback for History Logging ---
# Module-level log: one info dict per recorded environment step.
experiment_history = []

class HistoryCallback(BaseCallback):
    """Callback that appends each step's info dict to ``experiment_history``."""

    def _on_step(self) -> bool:
        # `infos` holds one info dict per environment; absent key means nothing to record.
        for step_info in self.locals.get('infos', ()):
            if not step_info:
                continue
            # Skip dicts that lack the expected keys (e.g. the error-path info).
            if "action" not in step_info or "reward" not in step_info:
                continue
            experiment_history.append(step_info.copy())  # Store a copy
        # Returning True tells the trainer to keep going.
        return True

history_callback = HistoryCallback()

# --- Training Loop (Chunk-based with saving) ---
# Resume from the agent's own step counter when it exposes one, else from 0.
timesteps_trained_so_far = getattr(agent, 'num_timesteps', 0)

if timesteps_trained_so_far < Config.TOTAL_TIMESTEPS:
    remaining_timesteps = Config.TOTAL_TIMESTEPS - timesteps_trained_so_far
    # Round up so a partial final chunk still gets its own iteration.
    chunks_to_run = int(np.ceil(remaining_timesteps / Config.TIMESTEPS_PER_CHUNK))

    builtin_print("\n--- Starting/Resuming Agent Training ---")
    builtin_print(f"Target Timesteps: {Config.TOTAL_TIMESTEPS} | Current Timesteps: {timesteps_trained_so_far}")
    builtin_print(f"Training will proceed in {chunks_to_run} chunk(s) of up to {Config.TIMESTEPS_PER_CHUNK} steps each.")

    for i in range(chunks_to_run):
        current_chunk_number = i + 1
        # Cap the chunk at whatever is still missing from the overall target.
        steps_this_chunk = min(Config.TIMESTEPS_PER_CHUNK, Config.TOTAL_TIMESTEPS - agent.num_timesteps)

        # Safety net: nothing left to train.
        if steps_this_chunk <= 0:
            break

        builtin_print(f"\n--- Running Training Chunk {current_chunk_number}/{chunks_to_run} ({steps_this_chunk} steps) ---")

        try:
            # Train for this chunk only. `reset_num_timesteps=False` keeps the
            # agent's timestep counter running across chunks, and the callback
            # records each step's info dict.
            agent.learn(
                total_timesteps=steps_this_chunk,
                callback=history_callback,
                reset_num_timesteps=False,
                progress_bar=True,
            )

            # Checkpoint after every completed chunk.
            agent.save(Config.AGENT_SAVE_PATH)
            current_total_steps = agent.num_timesteps
            builtin_print(f"Agent state saved to {Config.AGENT_SAVE_PATH}. Total timesteps trained: {current_total_steps}.")

        except Exception as train_error:
            builtin_print(f"Error occurred during training chunk {current_chunk_number}: {train_error}")
            builtin_print("Attempting to save agent state before exiting...")
            # Best-effort snapshot under a separate name for later recovery.
            try:
                agent.save(Config.AGENT_SAVE_PATH + "_error")
                builtin_print(f"Agent error state saved to {Config.AGENT_SAVE_PATH}_error")
            except Exception as save_error:
                builtin_print(f"Could not save agent state after error: {save_error}")
            # A failed chunk ends the training loop.
            break

        # Stop early once the overall target has been met.
        if agent.num_timesteps >= Config.TOTAL_TIMESTEPS:
            builtin_print(f"\n--- 🎉 Target of {Config.TOTAL_TIMESTEPS} timesteps reached. Finishing training. ---")
            break
else:
    builtin_print(f"\n--- Agent already trained for {timesteps_trained_so_far} timesteps. Skipping training. ---")

builtin_print("\n--- Agent Training Loop Completed ---")

# 8. Analyze Training History and Determine Best Action

In [None]:
# Examines the `experiment_history` logged by the callback to find the optimization action
# (pruning ratio + quantization) that yielded the highest reward during training.
# If no history is available (e.g., agent loaded after training), it predicts the best action
# based on the agent's final learned policy.
# The best action information is saved to a JSON file for potential use in the evaluation notebook.

best_action_info = None
best_action_code = None


def _format_metric(value, spec=".4f"):
    """Format *value* with *spec*, tolerating metrics that were never logged.

    Returns "N/A" for ``None`` (the default used below when a key is missing
    from the history) and the plain string form for values that do not
    support the numeric format spec, instead of raising ``TypeError``.
    """
    if value is None:
        return "N/A"
    try:
        return format(value, spec)
    except (TypeError, ValueError):
        return str(value)


# --- Analysis based on logged history (preferred) ---
if experiment_history:
    # Convert history list of dicts to a Pandas DataFrame for easier analysis
    history_df = pd.DataFrame(experiment_history)
    builtin_print("\n--- Agent Training Experiment History (Sample) ---")
    # Display head and tail for a quick overview
    builtin_print("First 5 steps:")
    builtin_print(history_df.head().to_markdown(index=False))
    builtin_print("\nLast 5 steps:")
    builtin_print(history_df.tail().to_markdown(index=False))

    # A best step can only be chosen when at least one non-NaN reward exists;
    # otherwise `idxmax` has nothing valid to return and the policy-based
    # fallback should be used instead.
    has_valid_reward = (
        not history_df.empty
        and 'reward' in history_df.columns
        and bool(history_df['reward'].notna().any())
    )

    # Find the row (step) corresponding to the maximum reward achieved
    if has_valid_reward:
        best_step_index = history_df['reward'].idxmax()
        best_step_data = history_df.loc[best_step_index]

        # Extract relevant information from the best step. Keys that were not
        # logged by the environment fall back to the defaults given here.
        best_action_info = {
            'source': 'history_max_reward',
            'step_index': int(best_step_index), # Ensure index is standard int
            'pruning_amount': best_step_data.get('pruning_amount', 0.0),
            'quantized': best_step_data.get('quantized', False),
            'reward_at_step': best_step_data.get('reward', None),
            'accuracy_at_step': best_step_data.get('accuracy', None), # From lightweight eval
            'flops_reduct_at_step': best_step_data.get('flops_reduction', None),
            'params_reduct_at_step': best_step_data.get('params_reduction', None),
            # Optionally include reward components if logged
            'reward_components_at_step': best_step_data.get('reward_components', None)
        }
        best_action_code = int(best_step_data.get('action', -1)) # Get the action code as int

        builtin_print("\n" + "="*60)
        builtin_print("BEST OPTIMIZATION STRATEGY FOUND DURING TRAINING (Max Reward)")
        builtin_print("="*60)
        builtin_print(f"Action Code: {best_action_code}")
        builtin_print(f"Pruning Ratio: {best_action_info['pruning_amount']*100:.0f}%")
        builtin_print(f"Quantization Applied: {best_action_info['quantized']}")
        # Metrics may be None when the corresponding key was never logged, so
        # they go through the tolerant formatter rather than a bare ':.4f'.
        builtin_print(f"Achieved Reward: {_format_metric(best_action_info['reward_at_step'])}")
        builtin_print(f"  (Accuracy at step [lightweight eval]: {_format_metric(best_action_info['accuracy_at_step'])})")
        builtin_print(f"  (FLOPs Reduction at step: {_format_metric(best_action_info['flops_reduct_at_step'])})")
        builtin_print(f"  (Params Reduction at step: {_format_metric(best_action_info['params_reduct_at_step'])})")
        builtin_print("="*60)
    else:
        builtin_print("\nWarning: Experiment history is empty, lacks 'reward' column, or has no valid reward values. Cannot determine best action from history.")

# --- Fallback: Predict action from the final agent policy ---
# Only runs when the history-based analysis did not produce a result.
if best_action_info is None:
    builtin_print("\n--- Predicting best action from the final trained agent policy ---")
    try:
        # Reset first so the policy is queried on the initial observation.
        obs, _ = env.reset()
        # Deterministic prediction from the learned policy.
        action_pred, _ = agent.predict(obs, deterministic=True)
        best_action_code = action_pred.item()

        # Decode the action code: its remainder modulo the number of pruning
        # levels indexes `env.pruning_levels`, and codes at or above that
        # count are treated as "quantized".
        n_pruning_levels = len(env.pruning_levels)
        pruning_idx_pred = best_action_code % n_pruning_levels
        pruning_amount_pred = env.pruning_levels[pruning_idx_pred]
        quantized_pred = best_action_code >= n_pruning_levels

        # Reward and metric values are not available from `predict`, so only
        # the decoded strategy is recorded.
        best_action_info = dict(
            source='final_policy_prediction',
            pruning_amount=pruning_amount_pred,
            quantized=quantized_pred,
        )

        separator = "=" * 60
        builtin_print("\n" + separator)
        builtin_print("BEST ACTION PREDICTED BY FINAL AGENT POLICY")
        builtin_print(separator)
        builtin_print(f"Action Code: {best_action_code}")
        builtin_print(f"Pruning Ratio: {best_action_info['pruning_amount']*100:.0f}%")
        builtin_print(f"Quantization Applied: {best_action_info['quantized']}")
        builtin_print(separator)

    except Exception as predict_error:
        builtin_print(f"\nError predicting action from final policy: {predict_error}")
        builtin_print("Could not determine the best action.")

# --- Save the best action information to a JSON file ---
# This allows the evaluation notebook to directly use the best strategy without reloading the agent
def _to_json_compatible(value):
    """Recursively replace NumPy scalars/arrays with built-in Python types.

    Handles values nested inside dicts, lists and tuples (for example a
    reward-components dict), which a flat top-level conversion would miss.
    Single-element arrays collapse to a scalar; larger arrays become lists.
    Any other value is returned unchanged.
    """
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        # `.item()` is only valid for exactly one element.
        return value.item() if value.size == 1 else value.tolist()
    if isinstance(value, dict):
        return {
            (key.item() if isinstance(key, np.generic) else key): _to_json_compatible(item)
            for key, item in value.items()
        }
    if isinstance(value, (list, tuple)):
        return [_to_json_compatible(item) for item in value]
    return value


if best_action_info:
    try:
        # Ensure data types are JSON serializable (convert numpy types if present)
        serializable_info = _to_json_compatible(best_action_info)

        # Add the action code itself to the saved info
        if best_action_code is not None:
            serializable_info['action_code'] = _to_json_compatible(best_action_code)

        # Serialize completely BEFORE opening the file: if serialization
        # fails, an existing (or new) file is not left truncated/corrupt.
        payload = json.dumps(serializable_info, indent=4)
        with open(Config.BEST_ACTION_SAVE_PATH, 'w', encoding='utf-8') as f:
            f.write(payload)
        builtin_print(f"\nBest action information saved to {Config.BEST_ACTION_SAVE_PATH}")
    except Exception as e:
        # Saving is best-effort; report the problem without aborting the notebook.
        builtin_print(f"\nError saving best action information to JSON: {e}")
else:
    builtin_print("\nNo best action information was determined, skipping save to JSON.")