In [1]:
import gymnasium as gym
import numpy as np
from stable_baselines3 import PPO  # <--- THIS WAS MISSING
from firecastrl_env.envs.wildfire_env import WildfireEnv
from firecastrl_env.envs.environment import helper 

# --- 1. Define wrappers (must be re-defined if the kernel was restarted) ---
class SafeWildfireWrapper(gym.ObservationWrapper):
    """Sanitise non-finite values in the ``'cells'`` entry of an observation.

    When the observation contains a ``'cells'`` key, that entry is replaced
    by ``np.nan_to_num(cells, posinf=-1.0)``: NaN becomes 0.0, +inf becomes
    -1.0 and -inf becomes the most negative finite value of the array's
    float dtype. Observations without ``'cells'`` pass through unchanged.

    The inherited ``gym.ObservationWrapper.__init__`` is used as-is.
    """

    def observation(self, obs):
        # Guard clause: nothing to clean.
        if 'cells' not in obs:
            return obs
        # NOTE(review): the cleaned array is written back into the *same*
        # dict object the inner env returned. If that env hands out its
        # internal state dict as the observation, this also replaces the
        # env's own 'cells' entry -- confirm that is acceptable.
        obs['cells'] = np.nan_to_num(obs['cells'], posinf=-1.0)
        return obs

class MultiAgentRewardWrapper(gym.Wrapper):
    """Replace the wrapped env's reward with a shaped, shared team reward.

    Supported modes (case-insensitive):

    * ``"cooperative"`` -- rewards quenched cells, penalises cells that are
      still burning, and penalises wasted water drops.
    * ``"competitive"`` -- "greedy" variant: rewards quenched cells and
      penalises wasted water drops, but ignores cells still burning.

    In both modes the shaped reward is clipped to
    ``[-REWARD_CLIP, REWARD_CLIP]`` and returned from ``step`` as a Python
    ``float``. The env's own reward is discarded.

    Args:
        env: Environment to wrap. Its ``unwrapped`` env must expose
            ``num_agents``, ``state`` (with ``'last_action'``,
            ``'helicopter_coord'`` and ``'quenched_cells'``), ``gridWidth``
            and ``cells``; these are read by the reward computation.
        mode: ``"cooperative"`` (default) or ``"competitive"``.

    Raises:
        ValueError: If ``mode`` is not one of ``MODES``.
    """

    MODES = ("cooperative", "competitive")

    # Reward-shaping coefficients. They are class attributes so that a
    # subclass can tune them without touching the logic.
    QUENCH_REWARD = 10.0       # multiplied by quenched_cells[0]
    BURNING_PENALTY = 0.1      # per burning cell (cooperative mode only)
    WASTED_DROP_PENALTY = 2.0  # per agent whose drop is judged wasted
    REWARD_CLIP = 50.0         # symmetric clip bound on the shaped reward

    # Action id this wrapper treats as "drop water".
    DROP_WATER_ACTION = 4
    # fireState value under which a drop is NOT counted as wasted
    # (presumably "burning" -- confirm against the env's fire-state values).
    TARGET_FIRE_STATE = 1

    def __init__(self, env, mode="cooperative"):
        super().__init__(env)
        self.mode = mode.lower()
        if self.mode not in self.MODES:
            # Previously any unrecognised string (e.g. a typo) silently
            # selected the competitive reward.
            raise ValueError(
                f"Unknown reward mode {mode!r}; expected one of {self.MODES}"
            )

    def step(self, action):
        """Step the inner env and replace its reward with the shaped one."""
        obs, _, terminated, truncated, info = self.env.step(action)
        if self.mode == "cooperative":
            new_reward = self._calculate_cooperative(info, obs)
        else:
            new_reward = self._calculate_competitive(info)
        return obs, float(new_reward), terminated, truncated, info

    def _calculate_cooperative(self, info, obs):
        """Return the clipped cooperative reward.

        ``QUENCH_REWARD * obs['quenched_cells'][0]
        - BURNING_PENALTY * info['cells_burning']
        - wasted-water penalty``
        """
        curr_burning = info['cells_burning']
        # NOTE(review): named "total" -- if quenched_cells is cumulative over
        # the episode, the same quenched cells are rewarded again on every
        # step. Confirm whether the env reports a per-step or running count.
        total_extinguished = obs['quenched_cells'][0]
        reward = self.QUENCH_REWARD * total_extinguished
        reward -= self.BURNING_PENALTY * curr_burning
        reward -= self._calculate_wasted_water_penalty()
        return np.clip(reward, -self.REWARD_CLIP, self.REWARD_CLIP)

    def _calculate_competitive(self, info):
        """Return the clipped greedy reward (no penalty for burning cells).

        ``QUENCH_REWARD * state['quenched_cells'][0] - wasted-water penalty``
        """
        wasted_penalty = self._calculate_wasted_water_penalty()
        # NOTE(review): unlike the cooperative branch, this reads the count
        # from the unwrapped env's state rather than from the observation.
        # Presumably the two values are identical -- confirm.
        total_extinguished = self.env.unwrapped.state['quenched_cells'][0]
        reward = self.QUENCH_REWARD * total_extinguished
        reward -= wasted_penalty
        return np.clip(reward, -self.REWARD_CLIP, self.REWARD_CLIP)

    def _calculate_wasted_water_penalty(self):
        """Return the total penalty for water drops that missed a target.

        For each agent whose ``last_action`` equals ``DROP_WATER_ACTION``,
        this looks up the grid cell under its helicopter. It adds
        ``WASTED_DROP_PENALTY`` if that cell's ``fireState`` is not
        ``TARGET_FIRE_STATE``.
        """
        # NOTE(review): this runs after env.step, so it sees the cell's fire
        # state *after* the drop was applied. If a successful drop changes
        # fireState within the same step, successful drops are penalised
        # too -- confirm against the env's step ordering.
        base_env = self.env.unwrapped
        state = base_env.state
        penalty = 0.0
        for i in range(base_env.num_agents):
            last_act = state['last_action'][i]
            hx, hy = state['helicopter_coord'][i]
            if last_act != self.DROP_WATER_ACTION:
                continue
            cell_idx = helper.get_grid_index_for_location(
                hx, hy, base_env.gridWidth
            )
            if base_env.cells[cell_idx].fireState != self.TARGET_FIRE_STATE:
                penalty += self.WASTED_DROP_PENALTY
        return penalty

# --- 2. Train COMPETITIVE (GREEDY) ---
# Build the env stack (raw env -> non-finite sanitiser -> reward shaping),
# train PPO on it and save the trained model as "ppo_fire_squad_greedy".
# The print literals previously held mis-decoded UTF-8 ("ðŸ’°", "âœ…");
# they are restored to the intended emoji.
print("💰 Starting COMPETITIVE/GREEDY Training...")
raw_env = WildfireEnv(num_agents=3)
safe_env = SafeWildfireWrapper(raw_env)

# CHANGE MODE HERE ("cooperative" or "competitive")
comp_env = MultiAgentRewardWrapper(safe_env, mode="competitive")

model_comp = PPO("MultiInputPolicy", comp_env, verbose=1)
model_comp.learn(total_timesteps=100_000)
model_comp.save("ppo_fire_squad_greedy")
print("✅ Greedy Model Saved!")

💰 Starting COMPETITIVE/GREEDY Training...
Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 417      |
|    ep_rew_mean     | -491     |
| time/              |          |
|    fps             | 12       |
|    iterations      | 1        |
|    time_elapsed    | 158      |
|    total_timesteps | 2048     |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 468         |
|    ep_rew_mean          | -477        |
| time/                   |             |
|    fps                  | 12          |
|    iterations           | 2           |
|    time_elapsed         | 319         |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.020205367 |
|    clip_fraction        | 0.228       |
|    clip_range  