In [2]:
import gymnasium as gym
import numpy as np
from stable_baselines3 import PPO
from firecastrl_env.envs.wildfire_env import WildfireEnv

# --- Re-define Wrapper for Evaluation Context ---
# (You can also import this if it's in a separate file)
class SafeWildfireWrapper(gym.ObservationWrapper):
    """Observation wrapper that removes non-finite values from ``obs['cells']``.

    The ``'cells'`` entry is passed through ``np.nan_to_num`` with
    ``posinf=-1.0``: positive infinity becomes ``-1.0``, while NaN and
    negative infinity receive NumPy's default replacements (``0.0`` and the
    most negative finite value of the dtype, respectively).

    Observations that have no ``'cells'`` key are returned untouched.
    """

    def observation(self, obs):
        """Return ``obs`` with its ``'cells'`` entry sanitised, if present.

        The entry is reassigned on the incoming ``obs`` object itself, so the
        caller's mapping is modified in place and the same object is returned.
        """
        if 'cells' not in obs:
            return obs

        sanitised_cells = np.nan_to_num(obs['cells'], posinf=-1.0)
        obs['cells'] = sanitised_cells
        return obs

def evaluate_agent(model_path, agent_name, num_episodes=5, seed=None):
    """Roll out a saved PPO policy in the wildfire environment and print a summary.

    Each episode is played to termination/truncation using deterministic
    actions, and the ``'cells_burnt'`` value from the final ``info`` dict plus
    the accumulated reward are recorded.

    Args:
        model_path: Path handed to ``PPO.load`` (with or without ``.zip``).
        agent_name: Human-readable label used in the printed report.
        num_episodes: Number of evaluation episodes to run.
        seed: Optional integer base seed. When given, episode ``i`` is reset
            with ``seed + i`` so runs are reproducible while episodes still
            differ from one another. ``None`` (the default) resets without an
            explicit seed, exactly as before.

    Returns:
        The mean ``cells_burnt`` over all episodes, NaN when
        ``num_episodes`` is zero or negative, or ``None`` when the model file
        could not be found.
    """
    print("========================================")
    print(f"🔥 EVALUATING: {agent_name}")
    print("========================================")

    # Load the model *before* building the environment so that a missing
    # checkpoint cannot leave an un-closed environment behind.
    try:
        model = PPO.load(model_path)
    except FileNotFoundError:
        shown_path = str(model_path)
        if not shown_path.endswith(".zip"):
            shown_path += ".zip"
        print(f"❌ Could not find model file: {shown_path}")
        return None

    # Three agents, matching the configuration the policies were trained with.
    env = SafeWildfireWrapper(WildfireEnv(num_agents=3))

    burnt_per_episode = []
    reward_per_episode = []

    try:
        for episode in range(num_episodes):
            episode_seed = None if seed is None else seed + episode
            obs, info = env.reset(seed=episode_seed)
            episode_reward = 0
            done = False

            while not done:
                # deterministic=True evaluates the policy's chosen action
                # rather than sampling, so no exploration noise is included.
                action, _ = model.predict(obs, deterministic=True)
                obs, reward, terminated, truncated, info = env.step(action)
                episode_reward += reward
                done = terminated or truncated

            burnt = info['cells_burnt']
            burnt_per_episode.append(burnt)
            reward_per_episode.append(episode_reward)

            print(f"  Episode {episode + 1}: Burnt Cells = {burnt} | Reward = {episode_reward:.2f}")
    finally:
        # Release the environment even if a rollout raises part-way through.
        env.close()

    if burnt_per_episode:
        avg_burnt = np.mean(burnt_per_episode)
        avg_reward = np.mean(reward_per_episode)
    else:
        # No episodes were run; report NaN without triggering NumPy's
        # "mean of empty slice" warning.
        avg_burnt = avg_reward = np.float64("nan")

    print("----------------------------------------")
    print(f"✅ FINAL RESULT - {agent_name}")
    print(f"   Average Cells Burnt: {avg_burnt:.2f}")
    print(f"   Average Reward:      {avg_reward:.2f}")
    print("========================================\n")

    return avg_burnt

# --- RUN THE COMPARISON ---

# Experiment A: the cooperative squad.
# Requires "ppo_fire_squad_coop.zip" to be present in the working directory.
score_coop = evaluate_agent(
    model_path="ppo_fire_squad_coop",
    agent_name="Cooperative Squad",
)

# Experiment B: the greedy squad.
# Enable only once train_greedy has finished and saved "ppo_fire_squad_greedy.zip".
# score_greedy = evaluate_agent("ppo_fire_squad_greedy", "Greedy Squad")

# Comparison (uncomment once both evaluations have been run).
# evaluate_agent returns None when a model file is missing, so test against
# None explicitly -- an average of 0.0 burnt cells is a valid score.
# if score_coop is not None and score_greedy is not None:
#     diff = score_greedy - score_coop
#     print(f"CONCLUSION: Cooperative strategy saved {diff:.2f} more trees on average!")

🔥 EVALUATING: Cooperative Squad
  Episode 1: Burnt Cells = 312 | Reward = -383.91
  Episode 2: Burnt Cells = 1198 | Reward = -1287.19
  Episode 3: Burnt Cells = 871 | Reward = -956.33
  Episode 4: Burnt Cells = 871 | Reward = -956.33
  Episode 5: Burnt Cells = 871 | Reward = -956.33
----------------------------------------
✅ FINAL RESULT - Cooperative Squad
   Average Cells Burnt: 824.60
   Average Reward:      -908.02

