# RL Policy Exploration Notebook

This notebook provides an interactive environment to load, visualize, and analyze the behavior of trained Reinforcement Learning agents.

In [None]:
import os

import gymnasium as gym
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import ray
import seaborn as sns
import yaml
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.registry import register_env

from src.fleet_simulator.simpy_delivery_environment import SimpyDeliveryEnvironment
from src.graph_routing_engine.astar_optimization_logic import AStarRouting
from src.agent_brains.multi_agent_rl_policy import MultiAgentPolicyContainer
from src.feature_forge.graph_embedding_features import GraphEmbeddingFeatures

# Start Ray only when this kernel has no live runtime yet, so that re-running
# the cell does not attempt a second initialisation.
ray_was_running = ray.is_initialized()
if not ray_was_running:
    ray.init(ignore_reinit_error=True)

## 1. Configuration and Setup

In [None]:
# --- File locations used throughout the notebook ---------------------------
CONFIG_PATH = 'conf/rl_agent_params.yaml'
ENV_CONFIG_PATH = 'conf/environments/dev.yaml'
ROUTING_CONFIG_PATH = 'conf/routing_engine_config.yaml'
OSM_PROCESSING_CONFIG_PATH = 'conf/osm_processing_config.yaml'
SCENARIO_PATH = 'data_nexus/simulation_scenarios/tehran_fleet_scenario.pkl'
RL_CHECKPOINT_DIR = 'rl_model_registry/latest_checkpoint'
GNN_MODEL_PATH = 'rl_model_registry/gcn_model.pth'


def _read_yaml(path):
    """Parse the YAML file at ``path`` with ``yaml.safe_load`` and return the result."""
    with open(path, 'r') as stream:
        return yaml.safe_load(stream)


rl_agent_config = _read_yaml(CONFIG_PATH)
osm_config = _read_yaml(OSM_PROCESSING_CONFIG_PATH)

# The serialized graph's location is taken from the OSM processing config.
graph_path = osm_config['graph_serialization']['output_path']
graph = nx.read_gml(graph_path)

# Routing engine built on the loaded graph.
router = AStarRouting(graph, ROUTING_CONFIG_PATH)

# Graph-embedding feature extractor, with its weights loaded from disk.
gnn_embedder = GraphEmbeddingFeatures(CONFIG_PATH, graph_path)
gnn_embedder.load_model_weights(GNN_MODEL_PATH)

class MockRLlibEnv(gym.Env):
    """Placeholder environment that emits random observations and a fixed reward.

    It exposes a ``Discrete(5)`` action space and an unbounded ``Box`` observation
    space of shape ``(128,)``. Episodes never terminate or truncate.
    """

    def __init__(self, env_config=None):
        """Create the spaces.

        Args:
            env_config: Accepted so the class can be used as an env creator
                (it is called with a single config argument); the value is
                not read.
        """
        self.action_space = gym.spaces.Discrete(5)
        self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(128,))

    def reset(self, seed=None, options=None):
        """Start a new episode and return ``(observation, info)``.

        Args:
            seed: Optional seed. When given, it seeds both the environment's
                own RNG and the observation space, so the sampled
                observations are reproducible.
            options: Unused.
        """
        # Let the base class seed `self.np_random`, as the Gymnasium API expects.
        super().reset(seed=seed)
        if seed is not None:
            # Observations are drawn from the space's RNG, which is separate
            # from `self.np_random`, so it has to be seeded explicitly.
            self.observation_space.seed(seed)
        return self.observation_space.sample(), {}

    def step(self, action):
        """Ignore ``action`` and return a random observation with reward 0.1.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` where both
            flags are always ``False`` and ``info`` is empty.
        """
        return self.observation_space.sample(), 0.1, False, False, {}

## 2. Load RL Agent Policy

In [None]:
# Register the placeholder env at module level so that the "mock_env" id is
# known to the Tune registry before any algorithm that names it is constructed.
register_env("mock_env", MockRLlibEnv)


def _policy_spec_from_config(policy_name):
    """Build the ``(policy_cls, obs_space, action_space, config)`` tuple for one policy.

    The observation shape and the number of discrete actions are read from
    ``rl_agent_config['multi_agent_config']['policies'][policy_name]``.
    """
    policy_params = rl_agent_config['multi_agent_config']['policies'][policy_name]
    obs_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=policy_params['obs_space'])
    action_space = gym.spaces.Discrete(policy_params['action_space'][0])
    # `None` as the policy class leaves the choice to the algorithm's default.
    return (None, obs_space, action_space, {})


class RLLibPolicyLoader(Algorithm):
    """Inference wrapper around a two-policy PPO algorithm restored from disk.

    ``setup`` builds an inner PPO algorithm with ``driver_policy`` and
    ``fleet_manager_policy``, restores ``RL_CHECKPOINT_DIR`` into it when that
    directory exists, and ``compute_single_action`` forwards to the inner
    algorithm.

    NOTE(review): this is a stop-gap. Loading the policies directly with
    ``Policy.from_checkpoint`` would avoid building a second algorithm purely
    for inference.
    """

    def setup(self, config):
        """Build the inner PPO algorithm and restore the checkpoint if present.

        Args:
            config: Configuration forwarded unchanged to ``Algorithm.setup``.
        """
        super().setup(config)

        mock_config = (
            PPOConfig()
            .environment("mock_env")
            .framework("torch")
            .multi_agent(
                policies={
                    policy_name: _policy_spec_from_config(policy_name)
                    for policy_name in ("driver_policy", "fleet_manager_policy")
                },
                policy_mapping_fn=MultiAgentPolicyContainer.map_agent_to_policy,
            )
        )
        self._trainer_for_restore = mock_config.build()
        if os.path.exists(RL_CHECKPOINT_DIR):
            # NOTE(review): an absolute path is passed because some Ray releases
            # treat the checkpoint location as a URI and reject relative paths;
            # confirm against the pinned Ray version.
            self._trainer_for_restore.restore(os.path.abspath(RL_CHECKPOINT_DIR))
            print(f"RL Agent restored from {RL_CHECKPOINT_DIR}")
        else:
            print(f"Warning: RL Agent checkpoint not found at {RL_CHECKPOINT_DIR}. Using untrained policy.")

    def cleanup(self):
        """Stop the inner PPO algorithm before releasing this wrapper's resources."""
        # `setup` may have failed before the inner algorithm was created.
        inner_algorithm = getattr(self, "_trainer_for_restore", None)
        if inner_algorithm is not None:
            inner_algorithm.stop()
        super().cleanup()

    def compute_single_action(self, obs, policy_id, explore):
        """Return the inner algorithm's action for one observation.

        Args:
            obs: Observation passed to the selected policy.
            policy_id: Name of the policy to query, e.g. ``"driver_policy"``.
            explore: Forwarded as the ``explore`` flag of the inner call.
        """
        return self._trainer_for_restore.compute_single_action(obs, policy_id=policy_id, explore=explore)

# Build the policy wrapper that is handed to the simulator for inference.
loader_config = {"env": "mock_env"}
rl_inferer_for_sim = RLLibPolicyLoader(loader_config)


## 3. Simulate and Analyze Policy Behavior

In [None]:
# Run the delivery simulation with the loaded policy wrapper and collect metrics.
sim_env = SimpyDeliveryEnvironment(ENV_CONFIG_PATH, SCENARIO_PATH, router, rl_inferer_for_sim)
metrics_df = sim_env.run_simulation(until=1800)  # Run for 30 minutes simulated time

# One figure, two side-by-side panels: order counts and cumulative driver metrics.
fig, (orders_ax, drivers_ax) = plt.subplots(1, 2, figsize=(15, 5))

order_series = (
    ('num_delivered_orders', 'Delivered'),
    ('num_pending_orders', 'Pending'),
    ('num_in_transit_orders', 'In Transit'),
)
for column, legend_label in order_series:
    sns.lineplot(x='time', y=column, data=metrics_df, label=legend_label, ax=orders_ax)
orders_ax.set_title('Order Status Over Time')
orders_ax.set_xlabel('Simulation Time (s)')
orders_ax.set_ylabel('Number of Orders')
orders_ax.legend()

driver_series = (
    ('total_driver_distance', 'Total Distance'),
    ('total_driver_time', 'Total Time'),
)
for column, legend_label in driver_series:
    sns.lineplot(x='time', y=column, data=metrics_df, label=legend_label, ax=drivers_ax)
drivers_ax.set_title('Driver Cumulative Metrics')
drivers_ax.set_xlabel('Simulation Time (s)')
drivers_ax.set_ylabel('Value')
drivers_ax.legend()

fig.tight_layout()
plt.show()

# Last row of the metrics frame, i.e. the state at the end of the run.
print("\nFinal Simulation Metrics:")
print(metrics_df.iloc[-1])

## 4. Individual Agent Action Visualization (Conceptual)

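This section would normally sample observations from the environment and feed them to each policy to inspect the predicted actions. Because the notebook has no concrete environment interaction loop at this point, the cell below uses randomly generated placeholder observations, so the printed actions only illustrate the call pattern.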

In [None]:
# Example: query each policy once with a synthetic observation.
# The observation shapes come from the same config entries that define the
# policies' observation spaces, so the samples cannot drift out of sync with
# the policies. The generator is seeded so the cell is reproducible.
SAMPLE_OBS_SEED = 0
obs_rng = np.random.default_rng(SAMPLE_OBS_SEED)
policy_params = rl_agent_config['multi_agent_config']['policies']

driver_obs_shape = tuple(policy_params['driver_policy']['obs_space'])
dummy_driver_obs = obs_rng.random(driver_obs_shape, dtype=np.float32)
driver_action = rl_inferer_for_sim.compute_single_action(dummy_driver_obs, policy_id="driver_policy", explore=False)
print(f"Sample Driver Action: {driver_action}")

dispatcher_obs_shape = tuple(policy_params['fleet_manager_policy']['obs_space'])
dummy_dispatcher_obs = obs_rng.random(dispatcher_obs_shape, dtype=np.float32)
dispatcher_action = rl_inferer_for_sim.compute_single_action(dummy_dispatcher_obs, policy_id="fleet_manager_policy", explore=False)
print(f"Sample Dispatcher Action: {dispatcher_action}")

# Further analysis would involve mapping actions back to meaningful decisions (e.g., node IDs, route segments)