# =============================================================================

# TRAIN RL AGENT NOTEBOOK

Purpose: load the baseline model and processed data, define the RL environment, train a PPO agent, and save the trained agent.

# =============================================================================

# Import Configuration and Libraries

In [None]:
import os
import warnings
import logging
import json

# Silence noisy third-party warnings and log output so the training logs stay readable.
os.environ["GYM_DISABLE_WARNINGS"] = "true"
# Filter order matters (each call is inserted at the front of the filter list),
# so keep gymnasium-module first, then the UserWarning category.
for _filter_kwargs in ({"module": "gymnasium"}, {"category": UserWarning}):
    warnings.filterwarnings("ignore", **_filter_kwargs)
for _logger_name in ("gymnasium", "stable_baselines3"):
    logging.getLogger(_logger_name).setLevel(logging.ERROR)

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import pandas as pd
import gymnasium as gym
from gymnasium import spaces
import copy
from sklearn.metrics import accuracy_score
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env
from codecarbon import EmissionsTracker
from ptflops import get_model_complexity_info
import torch_pruning as tp
from tqdm.notebook import tqdm # For agent training progress bar
from builtins import print as builtin_print # To avoid conflict with tqdm print

# Confirms that the import cell above ran to completion.
print("Libraries imported.")

# Configuration Class

In [None]:
class Config:
    """Notebook-wide settings, read directly as class attributes (``Config.NAME``)."""

    # Paths for loading/saving
    PROCESSED_DATA_PATH = '/kaggle/working/processed_data.pt' # Input: processed data produced by the baseline notebook
    BASELINE_MODEL_PATH = '/kaggle/working/baseline_model.pth' # Input: baseline model state_dict
    AGENT_SAVE_PATH = "/kaggle/working/sustainable_ai_agent_ppo.zip" # Output: trained PPO agent
    BEST_ACTION_SAVE_PATH = "/kaggle/working/best_action.json" # Output: best action info (JSON)

    # Data parameters: used to build the model, as the ptflops input shape,
    # and for the example input used when pruning.
    SEQUENCE_LENGTH = 30
    INPUT_DIM = 4

    # Model parameters (must match the baseline checkpoint)
    HIDDEN_DIM = 256
    N_LAYERS = 2
    OUTPUT_DIM = 1
    # Passed to the constructor only to mirror the baseline architecture.
    # Dropout has no weights, so it does not affect state_dict loading, and it
    # is inactive in eval mode.
    DROPOUT = 0.2

    # RL agent training
    TOTAL_TIMESTEPS = 10000 # Target total environment steps across all training runs
    TIMESTEPS_PER_CHUNK = 2000 # Steps requested per learn() call; the agent is saved after each chunk
    TENSORBOARD_LOG_PATH = "/kaggle/working/ppo_tensorboard/"

    # RL environment reward shaping
    ACCURACY_PENALTY_THRESHOLD = 0.95 # Accuracy below 95% of the baseline accuracy -> accuracy term becomes ENV_ERROR_REWARD
    ACCURACY_REWARD_SCALE = 10 # Multiplier applied to the accuracy delta
    FLOPS_REWARD_SCALE = 1.5 # Multiplier applied to the fractional FLOPs reduction
    INACTION_PENALTY = -1.0 # Added when an action applies neither pruning nor quantization
    ENV_ERROR_REWARD = -10.0 # Reward when an env step fails; also reused as the accuracy-threshold penalty

    # Evaluation parameters (baseline metrics in the env)
    EVAL_BATCH_SIZE = 64
    CODECARBON_BATCHES = 10 # Number of batches run under CodeCarbon for the energy measurement

    # Device passed to the PPO agent (model evaluation inside the env runs on CPU)
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Running on device: {Config.DEVICE}")
print(f"Loading processed data from: {Config.PROCESSED_DATA_PATH}")
print(f"Loading baseline model from: {Config.BASELINE_MODEL_PATH}")
print(f"Agent will be saved to: {Config.AGENT_SAVE_PATH}")

# Helper Functions and Model Definition

# (Copied from the baseline notebook; needed by the RL environment and for loading the baseline model.)

In [None]:
def count_parameters(model):
    """Return how many scalar parameters of ``model`` have ``requires_grad`` set."""
    trainable = filter(lambda param: param.requires_grad, model.parameters())
    return sum(param.numel() for param in trainable)

def evaluate_model(model, loader, codecarbon_batches=10):
    """Run a full evaluation of ``model`` on CPU.

    The caller's model is never modified: a deep copy is moved to CPU and put
    in eval mode before any measurement.

    Args:
        model: torch module whose outputs are thresholded at 0.5 to get
            binary predictions.
        loader: iterable of ``(inputs, labels)`` batches.
        codecarbon_batches: number of batches run under the CodeCarbon
            emissions tracker for the energy/CO2 measurement.

    Returns:
        dict with keys ``accuracy`` (over every batch in ``loader``),
        ``energy_kwh``, ``co2_eq_kg``, ``flops`` and ``params``.
        ``energy_kwh``/``co2_eq_kg`` stay 0 if CodeCarbon fails or records no
        data; ``flops`` is 0 if ptflops cannot analyse the model (e.g. a
        dynamically-quantized model).
    """
    model_cpu = copy.deepcopy(model).cpu()
    device = torch.device("cpu")
    model_cpu.eval() # Ensure model is in evaluation mode

    # Accuracy over the whole loader
    y_true, y_pred = [], []
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            outputs = model_cpu(inputs)
            preds = (outputs > 0.5).float()
            y_pred.extend(preds.cpu().numpy())
            y_true.extend(labels.cpu().numpy())
    accuracy = accuracy_score(y_true, y_pred)

    # Energy and CO2 with CodeCarbon (best effort: failures only warn)
    energy_kwh = 0
    co2_eq_kg = 0
    try:
        # Save emissions locally in /kaggle/working/
        tracker = EmissionsTracker(log_level="error", output_dir="/kaggle/working/", tracking_mode="process")
        tracker.start()
        try:
            with torch.no_grad():
                for i, (inputs, _) in enumerate(loader):
                    if i >= codecarbon_batches:
                        break # Limit measurement time
                    model_cpu(inputs.to(device))
        finally:
            # Stop the tracker even if inference raised, so the CodeCarbon
            # measurement is not left running.
            tracker.stop()
        # Emissions data may be missing if the measured run was too short.
        emissions_data = tracker.final_emissions_data
        if emissions_data:
            energy_kwh = emissions_data.energy_consumed or 0
            co2_eq_kg = emissions_data.emissions or 0
        else:
            builtin_print("Warning: CodeCarbon tracker did not record final emissions data.")
    except Exception as e:
        builtin_print(f"Warning: CodeCarbon measurement failed - {e}")

    # Params
    params = count_parameters(model_cpu)

    # FLOPs with ptflops (MACs x 2)
    flops = 0
    try:
        macs, _ = get_model_complexity_info(
            model_cpu, (Config.SEQUENCE_LENGTH, Config.INPUT_DIM),
            as_strings=False, print_per_layer_stat=False, verbose=False)
        flops = macs * 2
    except (KeyError, AttributeError, RuntimeError, TypeError):
        # Likely a quantized model or an unsupported layer (TypeError also
        # covers macs being None). Report 0 to signal a large theoretical
        # reduction.
        flops = 0

    return {
        "accuracy": accuracy, "energy_kwh": energy_kwh, "co2_eq_kg": co2_eq_kg,
        "flops": flops, "params": params,
    }

def apply_l1_pruning(model, amount):
    """Return a structurally pruned CPU copy of ``model``; the input is untouched.

    Channels are ranked with torch_pruning's L1 magnitude importance (p=1) and
    removed at ratio ``amount``. Every ``nn.GRU`` submodule is excluded from
    pruning.

    NOTE(review): with the GRU excluded, the only other parameterised layer in
    WeatherGRU is the final Linear (output_dim=1, its input tied to the GRU
    output), so this may remove few or no parameters. Verify by comparing
    parameter counts before and after.
    """
    pruned = copy.deepcopy(model).cpu()
    pruned.eval() # Keep pruning deterministic w.r.t. train-only layers
    gru_layers = [module for module in pruned.modules() if isinstance(module, nn.GRU)]
    # Random dummy batch that torch_pruning traces to build its dependency graph.
    sample_batch = torch.randn(1, Config.SEQUENCE_LENGTH, Config.INPUT_DIM)
    l1_importance = tp.importance.MagnitudeImportance(p=1)
    pruner = tp.pruner.MagnitudePruner(
        pruned,
        example_inputs=sample_batch,
        importance=l1_importance,
        pruning_ratio=amount,
        ignored_layers=gru_layers,
    )
    pruner.step()
    return pruned

def apply_dynamic_quantization(model):
    """Return an int8 dynamically-quantized CPU copy of ``model``.

    ``nn.GRU`` and ``nn.Linear`` submodules are swapped for their dynamic
    quantized counterparts with ``qint8`` weights. The input model is not
    modified.
    """
    model_copy = copy.deepcopy(model).cpu()
    model_copy.eval() # Quantize an eval-mode copy
    target_layers = {nn.GRU, nn.Linear}
    return torch.quantization.quantize_dynamic(model_copy, target_layers, dtype=torch.qint8)

class WeatherGRU(nn.Module):
    """Stacked-GRU binary classifier matching the saved baseline architecture.

    The attribute names ``gru`` and ``fc`` define the state_dict keys, so they
    must stay in sync with the checkpoint being loaded.

    Args:
        input_dim: features per timestep.
        hidden_dim: GRU hidden size.
        n_layers: number of stacked GRU layers.
        output_dim: output size of the final linear layer.
        dropout: inter-layer GRU dropout; forced to 0 for a single layer,
            where PyTorch would otherwise warn.
    """

    def __init__(self, input_dim, hidden_dim, n_layers, output_dim, dropout):
        super().__init__()
        gru_dropout = 0 if n_layers <= 1 else dropout
        self.gru = nn.GRU(input_dim, hidden_dim, n_layers,
                          batch_first=True, dropout=gru_dropout)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Map batch-first sequences to sigmoid outputs computed from the last timestep."""
        sequence_output = self.gru(x)[0]
        last_step = sequence_output[:, -1, :]
        return self.sigmoid(self.fc(last_step))

print("Helper functions and WeatherGRU class defined.")

# Load Baseline Model and Test Data

In [None]:
# Load processed test data. map_location='cpu' lets tensors that were saved
# during a GPU session load on a CPU-only machine; the RL environment
# evaluates on CPU anyway.
try:
    processed_data = torch.load(Config.PROCESSED_DATA_PATH, map_location=torch.device('cpu'))
    X_test_tensor = processed_data['X_test']
    y_test_tensor = processed_data['y_test']
    print(f"Processed test data loaded from {Config.PROCESSED_DATA_PATH}")
except FileNotFoundError:
    print(f"Error: Processed data file not found at {Config.PROCESSED_DATA_PATH}. Run train_baseline.ipynb first.")
    raise
except Exception as e:
    print(f"Error loading processed data: {e}")
    raise

# Create Test DataLoader (needed for env initialization). shuffle=False keeps
# batch order fixed, so the env's one-batch evaluation always sees the same batch.
test_dataset = TensorDataset(X_test_tensor, y_test_tensor)
test_loader = DataLoader(test_dataset, batch_size=Config.EVAL_BATCH_SIZE, shuffle=False)
print("Test DataLoader created.")

# Initialize and load baseline model (hyperparameters must match the checkpoint)
baseline_model = WeatherGRU(
    input_dim=Config.INPUT_DIM,
    hidden_dim=Config.HIDDEN_DIM,
    n_layers=Config.N_LAYERS,
    output_dim=Config.OUTPUT_DIM,
    dropout=Config.DROPOUT # Mirrors the baseline constructor call
)
try:
    # map_location ensures the state dict loads regardless of the device it was saved from
    baseline_model.load_state_dict(torch.load(Config.BASELINE_MODEL_PATH, map_location=torch.device('cpu')))
    baseline_model.eval() # Set to evaluation mode
    print(f"Baseline model state loaded from {Config.BASELINE_MODEL_PATH}")
except FileNotFoundError:
    print(f"Error: Baseline model file not found at {Config.BASELINE_MODEL_PATH}. Run train_baseline.ipynb first.")
    raise
except Exception as e:
    print(f"Error loading baseline model state: {e}")
    raise

# Define RL Environment

In [None]:
class SustainableAIAgentEnv(gym.Env):
    """Single-step Gymnasium environment for choosing a GRU compression action.

    Each episode is one decision. The agent picks a pruning ratio, optionally
    combined with dynamic quantization. The compressed copy of the base model
    is scored on one batch, and the episode terminates.

    Action space (``Discrete(16)``):
        0..7  -> L1 pruning at ``pruning_levels[a]`` only
        8..15 -> L1 pruning at ``pruning_levels[a - 8]`` + dynamic quantization

    Observation (float32, clipped to the space bounds):
        ``[current_accuracy, accuracy_delta, params_reduction, flops_reduction]``

    Reward is the sum of three terms:
        - Accuracy term: ``Config.ACCURACY_REWARD_SCALE`` x the accuracy delta
          vs. the uncompressed model on the same batch. It is replaced by
          ``Config.ENV_ERROR_REWARD`` if accuracy falls below
          ``Config.ACCURACY_PENALTY_THRESHOLD`` x that reference.
        - Resource term: ``Config.FLOPS_REWARD_SCALE`` x the fractional FLOPs
          reduction.
        - Inaction term: ``Config.INACTION_PENALTY`` when neither pruning nor
          quantization is applied.
    """
    metadata = {'render_modes': []} # Required for Gymnasium; no rendering supported

    def __init__(self, model, loader):
        super().__init__()
        # nn.Module.cpu() moves the caller's model in place and returns it.
        self.base_model = model.cpu()
        self.base_model.eval()
        self.loader = loader

        builtin_print("Calculating baseline metrics for RL environment initialization...")
        # Full evaluation once (whole loader) to get baseline metrics
        self.baseline_metrics = evaluate_model(self.base_model, self.loader, Config.CODECARBON_BATCHES)
        # FLOPs is used as a divisor in step(); never let it be zero.
        if self.baseline_metrics.get("flops", 0) == 0:
            self.baseline_metrics["flops"] = 1
            builtin_print("Warning: Baseline FLOPs calculated as 0. Setting to 1 to avoid division errors.")
        builtin_print(f"Baseline Metrics initialized: {self.baseline_metrics}")

        # step() scores compressed models on a single batch (_evaluate_lightweight).
        # Compare them against the uncompressed model on that same batch, not
        # against the full-test-set accuracy above. Otherwise the gap between
        # one batch and the whole set biases every reward, and can trip the
        # accuracy penalty even for the no-op action. A deep copy is profiled,
        # mirroring evaluate_model, so the base model itself is not instrumented.
        self.reference_accuracy = self._evaluate_lightweight(copy.deepcopy(self.base_model))["accuracy"]
        builtin_print(f"Reference accuracy on the lightweight evaluation batch: {self.reference_accuracy:.4f}")

        # Action space: [0..7] = pruning only, [8..15] = pruning + quantization
        self.pruning_levels = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
        self.action_space = spaces.Discrete(len(self.pruning_levels) * 2)

        # Observation space: [current_accuracy, acc_vs_reference, params_reduction, flops_reduction]
        self.observation_space = spaces.Box(
            low=np.array([0.0, -1.0, 0.0, 0.0], dtype=np.float32),
            high=np.array([1.0, 1.0, 1.0, 1.0], dtype=np.float32),
            dtype=np.float32
        )

    def step(self, action):
        """Apply the compression encoded by ``action`` and score the result.

        Returns:
            The Gymnasium 5-tuple ``(obs, reward, terminated, truncated, info)``.
            ``terminated`` is always True (one decision per episode) and
            ``truncated`` always False.

        Any exception raised while optimizing or scoring is printed and
        converted into ``Config.ENV_ERROR_REWARD``, paired with the reset
        observation and ``{"error": message}`` as info.
        """
        terminated = True # Episode ends after one step
        truncated = False # Not using truncation

        try:
            # Actions may arrive as Python ints, numpy scalars/arrays or torch
            # tensors. Normalise to a plain int so info/history values are
            # JSON-friendly.
            if isinstance(action, torch.Tensor):
                action_val = int(action.item())
            else:
                action_val = int(np.asarray(action).item())
            if not self.action_space.contains(action_val):
                raise ValueError(f"action {action_val} is outside {self.action_space}")

            pruning_amount = self.pruning_levels[action_val % len(self.pruning_levels)]
            apply_quant = action_val >= len(self.pruning_levels)

            # Apply optimization(s) to a fresh copy; helpers return CPU models.
            # Action 0 (no prune, no quant) leaves the plain baseline copy.
            optimized_model = copy.deepcopy(self.base_model)
            if pruning_amount > 0:
                optimized_model = apply_l1_pruning(optimized_model, pruning_amount)
            if apply_quant:
                optimized_model = apply_dynamic_quantization(optimized_model)

            # Lightweight one-batch evaluation on CPU
            metrics = self._evaluate_lightweight(optimized_model)

            # Accuracy term, relative to the uncompressed model on the same batch
            accuracy_delta = metrics["accuracy"] - self.reference_accuracy
            if metrics["accuracy"] < self.reference_accuracy * Config.ACCURACY_PENALTY_THRESHOLD:
                accuracy_reward = Config.ENV_ERROR_REWARD # Heavy penalty if accuracy drops too much
            else:
                accuracy_reward = accuracy_delta * Config.ACCURACY_REWARD_SCALE

            # Resource term (baseline FLOPs guarded against zero)
            baseline_flops = self.baseline_metrics.get("flops", 1) or 1
            flops_reduction = max(0.0, 1.0 - (metrics["flops"] / baseline_flops))
            resource_reward = flops_reduction * Config.FLOPS_REWARD_SCALE

            inaction_penalty = Config.INACTION_PENALTY if pruning_amount == 0 and not apply_quant else 0.0

            reward = accuracy_reward + resource_reward + inaction_penalty

            baseline_params = self.baseline_metrics.get("params", 1) or 1
            params_reduction = max(0.0, 1.0 - (metrics["params"] / baseline_params))

            # Next observation, clipped to the declared bounds
            obs = np.array([
                metrics["accuracy"],
                accuracy_delta,
                params_reduction,
                flops_reduction
            ], dtype=np.float32)
            obs = np.clip(obs, self.observation_space.low, self.observation_space.high)

            info = {
                "action": action_val,
                "pruning_amount": pruning_amount,
                "quantized": apply_quant,
                "accuracy": metrics["accuracy"],
                "flops_reduction": flops_reduction,
                "params_reduction": params_reduction,
                "reward": reward
            }

            return obs, reward, terminated, truncated, info

        except Exception as e:
            builtin_print(f"Error during env.step with action {action}: {e}")
            # Return initial state, heavy penalty, and terminate
            initial_obs, _ = self.reset()
            return initial_obs, Config.ENV_ERROR_REWARD, terminated, truncated, {"error": str(e)}

    def _evaluate_lightweight(self, model):
        """Quickly score ``model`` on the first batch yielded by ``self.loader``.

        Returns a dict with ``accuracy``, ``flops`` and ``params``. Defaults are
        accuracy 0.0 and the baseline FLOPs/params, kept if the loader is empty
        or evaluation fails. If ptflops cannot analyse the model (e.g. it is
        quantized), FLOPs is reported as 0 and params are estimated as a
        quarter of the baseline (int8 vs float32 weights).
        """
        model.eval()
        device = torch.device("cpu") # Lightweight eval always on CPU
        metrics = {"accuracy": 0.0, "flops": self.baseline_metrics.get("flops", 1), "params": self.baseline_metrics.get("params", 1)} # Default to baseline

        try:
            inputs, labels = next(iter(self.loader)) # Get one batch
            inputs, labels = inputs.to(device), labels.to(device)

            with torch.no_grad():
                outputs = model(inputs)
                preds = (outputs > 0.5).float()
                accuracy = accuracy_score(labels.cpu().numpy(), preds.cpu().numpy())
                metrics["accuracy"] = accuracy

            # Try to get FLOPs and Params
            try:
                macs, params = get_model_complexity_info(
                    model, (Config.SEQUENCE_LENGTH, Config.INPUT_DIM),
                    as_strings=False, print_per_layer_stat=False, verbose=False)
                flops = macs * 2
                metrics["flops"] = flops if flops > 0 else 0
                metrics["params"] = params if params > 0 else 0
            except (RuntimeError, KeyError, AttributeError, TypeError):
                # Failed (likely quantized): report maximal FLOPs reduction
                metrics["flops"] = 0
                # Estimate quantized params as 1/4 of baseline float32 params
                metrics["params"] = (self.baseline_metrics.get("params", 1) or 1) / 4

        except StopIteration:
            builtin_print("Warning: DataLoader is empty in _evaluate_lightweight.")
        except Exception as e:
            builtin_print(f"Error during _evaluate_lightweight: {e}")
            # Keep default metrics (baseline) if evaluation fails

        # Ensure flops and params are non-negative
        metrics["flops"] = max(0.0, metrics["flops"])
        metrics["params"] = max(0.0, metrics["params"])

        return metrics

    def reset(self, seed=None, options=None):
        """Return the fixed initial observation (baseline accuracy, zero deltas) and an empty info dict."""
        super().reset(seed=seed)
        initial_obs = np.array([
            self.baseline_metrics.get("accuracy", 0.0),
            0.0,
            0.0,
            0.0
        ], dtype=np.float32)
        info = {} # Gymnasium API requires an info dict
        return initial_obs, info

# Initialize and Check Environment

In [None]:
# Build the RL environment. Gymnasium's check_env is intentionally not run:
# it can emit warnings even when the environment works.
try:
    env = SustainableAIAgentEnv(model=baseline_model, loader=test_loader)
except Exception as init_error:
    print(f"\nError initializing RL Environment: {init_error}")
    raise
else:
    print("\nRL Environment initialized successfully.")

# Train PPO Agent

In [None]:
def _create_agent():
    """Build a fresh PPO agent on ``env`` using the notebook configuration."""
    return PPO("MlpPolicy", env, verbose=0, device=Config.DEVICE,
               tensorboard_log=Config.TENSORBOARD_LOG_PATH)


# Resume from a saved agent when one exists; otherwise (or if loading fails) start fresh.
agent = None
if os.path.exists(Config.AGENT_SAVE_PATH):
    print(f"\n--- Loading pre-trained Agent from: {Config.AGENT_SAVE_PATH} ---")
    try:
        agent = PPO.load(Config.AGENT_SAVE_PATH, env=env, device=Config.DEVICE)
        print("Agent loaded successfully.")
    except Exception as e:
        print(f"Error loading agent: {e}. Creating a new agent.")
else:
    print(f"\n--- Agent file not found at {Config.AGENT_SAVE_PATH}. Creating a new PPO Agent. ---")
if agent is None:
    agent = _create_agent()
    print("New PPO Agent created.")

# --- Training Loop ---
experiment_history = [] # One info dict per environment step seen during training


def history_callback(local_vars, global_vars):
    """SB3 callable callback: record each non-empty step ``info`` that carries an action.

    Always returns True so training continues.
    """
    step_infos = local_vars.get('infos', [])
    experiment_history.extend(info.copy() for info in step_infos if info and "action" in info)
    return True


timesteps_trained_so_far = getattr(agent, 'num_timesteps', 0)

if timesteps_trained_so_far < Config.TOTAL_TIMESTEPS:
    remaining_timesteps = Config.TOTAL_TIMESTEPS - timesteps_trained_so_far
    chunks_to_run = int(np.ceil(remaining_timesteps / Config.TIMESTEPS_PER_CHUNK))

    print("\n--- Starting/Resuming Agent Training ---")
    print(f"Target: {Config.TOTAL_TIMESTEPS} steps | Already trained: {timesteps_trained_so_far} steps.")
    print(f"Will run {chunks_to_run} more chunk(s) of {Config.TIMESTEPS_PER_CHUNK} steps each.")

    for current_chunk in range(1, chunks_to_run + 1):
        # PPO may overshoot a request (it finishes whole rollouts), so size
        # each chunk from the live step counter rather than a fixed schedule.
        steps_this_chunk = min(Config.TIMESTEPS_PER_CHUNK, Config.TOTAL_TIMESTEPS - agent.num_timesteps)
        if steps_this_chunk <= 0:
            break

        print(f"\n--- Running Chunk {current_chunk}/{chunks_to_run} ({steps_this_chunk} steps) ---")

        try:
            agent.learn(
                total_timesteps=steps_this_chunk,
                reset_num_timesteps=False, # Keep counting from the existing total
                progress_bar=True,
                callback=history_callback
            )
            # Checkpoint after every chunk
            agent.save(Config.AGENT_SAVE_PATH)
            print(f"Agent saved. Total timesteps trained: {agent.num_timesteps}.")
        except Exception as e:
            print(f"Error during training chunk {current_chunk}: {e}")
            print("Attempting to save agent state before exiting...")
            error_save_path = Config.AGENT_SAVE_PATH + "_error"
            try:
                agent.save(error_save_path)
                print(f"Agent state saved to {error_save_path}")
            except Exception as save_e:
                print(f"Could not save agent state after error: {save_e}")
            break # Stop training on error

        if agent.num_timesteps >= Config.TOTAL_TIMESTEPS:
            print(f"\n--- 🎉 Target of {Config.TOTAL_TIMESTEPS} timesteps reached. ---")
            break
else:
    print(f"\n--- Agent already trained for {timesteps_trained_so_far} timesteps. Skipping training. ---")

print("\n--- Agent Training Loop Finished ---")

# Analyze Training History and Find Best Action

In [None]:
def _print_table(df):
    """Print ``df`` as a markdown table, or as plain text if ``tabulate`` is unavailable."""
    try:
        print(df.to_markdown(index=False))
    except ImportError:
        # DataFrame.to_markdown depends on the optional 'tabulate' package.
        print(df.to_string(index=False))


best_solution_info_agent = None
best_action_agent = None

if experiment_history:
    history_df = pd.DataFrame(experiment_history)
    print("\n--- Agent Experiment History (Sample): ---")
    _print_table(history_df.head()) # First few rows
    print("...")
    _print_table(history_df.tail()) # Last few rows

    # Pick the logged step with the highest reward. Rewards come from the
    # environment's one-batch lightweight evaluation, not a full re-evaluation.
    best_step = history_df.loc[history_df['reward'].idxmax()]
    best_action_agent = int(best_step['action'])

    # Plain Python types keep the JSON export below straightforward.
    best_solution_info_agent = {
        'action': best_action_agent,
        'pruning_amount': float(best_step['pruning_amount']),
        'quantized': bool(best_step['quantized']),
        'accuracy_at_step': float(best_step['accuracy']), # Info from lightweight eval
        'reward_at_step': float(best_step['reward'])
    }

    print("\n" + "="*50)
    print("BEST SOLUTION FOUND DURING TRAINING (based on reward)")
    print("="*50)
    print(f"Action: {best_action_agent}")
    print(f"Pruning: {best_solution_info_agent['pruning_amount']*100:.0f}%")
    print(f"Quantized: {best_solution_info_agent['quantized']}")
    print(f"Achieved Reward: {best_solution_info_agent['reward_at_step']:.4f}")
    print(f"(Lightweight Eval Accuracy at that step: {best_solution_info_agent['accuracy_at_step']:.4f})")
    print("="*50)

else:
    # No steps were logged in this session (e.g. training was skipped), so
    # ask the final policy for its deterministic choice from the reset state.
    print("\nNo experiment history recorded. Predicting best action from final agent policy.")
    try:
        obs, _ = env.reset()
        action_pred, _ = agent.predict(obs, deterministic=True)
        best_action_agent = int(action_pred.item())

        # Decode the predicted action the same way the environment does
        pruning_idx_pred = best_action_agent % len(env.pruning_levels)
        pruning_amount_pred = env.pruning_levels[pruning_idx_pred]
        quantized_pred = best_action_agent >= len(env.pruning_levels)

        best_solution_info_agent = {
            'action': best_action_agent,
            'pruning_amount': pruning_amount_pred,
            'quantized': quantized_pred
            # No reward/accuracy info available directly from predict
        }

        print("\n" + "="*50)
        print("BEST ACTION PREDICTED BY FINAL AGENT POLICY 🤖")
        print("="*50)
        print(f"Action: {best_action_agent}")
        print(f"Pruning: {best_solution_info_agent['pruning_amount']*100:.0f}%")
        print(f"Quantized: {best_solution_info_agent['quantized']}")
        print("="*50)

    except Exception as e:
        print(f"Error predicting action from final policy: {e}")
        print("Cannot determine best action.")

# Save the best action info (including the action index) for the evaluation notebook
if best_solution_info_agent:
    try:
        with open(Config.BEST_ACTION_SAVE_PATH, 'w', encoding='utf-8') as f:
            # Convert any remaining numpy scalars to native Python types
            serializable_info = {k: (v.item() if isinstance(v, np.generic) else v) for k, v in best_solution_info_agent.items()}
            json.dump(serializable_info, f, indent=4)
        print(f"\nBest action info saved to {Config.BEST_ACTION_SAVE_PATH}")
    except Exception as e:
        print(f"\nError saving best action info: {e}")