## <center>CSE 546: Reinforcement Learning</center>
### <center>Prof. Alina Vereshchaka</center>
#### <center>Fall 2025</center>

# Welcome to the Bonus: Firecast RL

## IMPORTS & SETUP

In [4]:
import gymnasium as gym
import numpy as np
# Adjust this import based on your folder structure, or pass the helper module if needed
from firecastrl_env.envs.environment import helper 

class MultiAgentRewardWrapper(gym.Wrapper):
    """Replace the wrapped environment's reward with a team-level shaped reward.

    Two reward modes are supported:

    * ``"cooperative"`` - reward extinguished cells, penalise every burning
      cell, and penalise water drops on non-burning cells.
    * ``"competitive"`` - reward extinguished cells and penalise wasted drops,
      but apply no penalty for burning cells.

    Args:
        env: Environment to wrap. Its unwrapped base environment must expose
            ``state`` (with ``'last_action'``, ``'helicopter_coord'`` and
            ``'quenched_cells'`` entries), ``num_agents``, ``gridWidth`` and
            ``cells``; these are the attributes this wrapper reads.
        mode: ``"cooperative"`` or ``"competitive"`` (case-insensitive).
        drop_action: Action id that means "drop water". Defaults to ``4``.
        burning_state: Value of ``cell.fireState`` that means "burning".
            May be a raw value or an enum member. Defaults to ``1``;
            NOTE(review): confirm this against the environment's fire-state enum.

    Raises:
        ValueError: If ``mode`` is not one of the two supported modes.
    """

    VALID_MODES = ("cooperative", "competitive")

    # Reward-shaping coefficients shared by both modes.
    EXTINGUISH_REWARD = 10.0    # per quenched cell
    BURNING_PENALTY = 0.1       # per currently burning cell (cooperative only)
    WASTED_DROP_PENALTY = 2.0   # per drop on a non-burning cell
    REWARD_CLIP = 50.0          # final reward is clipped to [-REWARD_CLIP, +REWARD_CLIP]

    def __init__(self, env, mode="cooperative", drop_action=4, burning_state=1):
        super().__init__(env)
        # str() so that a non-string mode (e.g. None) produces the documented
        # ValueError below rather than an AttributeError from .lower().
        self.mode = str(mode).lower()
        if self.mode not in self.VALID_MODES:
            raise ValueError("Mode must be 'cooperative' or 'competitive'")
        self.drop_action = drop_action
        self.burning_state = burning_state

    def step(self, action):
        """Step the wrapped environment and substitute the shaped reward.

        Returns:
            ``(obs, reward, terminated, truncated, info)`` where everything
            except ``reward`` is passed through unchanged from the wrapped
            environment and ``reward`` is a Python ``float``.
        """
        # The wrapped environment's own reward is intentionally discarded.
        obs, _original_reward, terminated, truncated, info = self.env.step(action)

        if self.mode == "cooperative":
            new_reward = self._calculate_cooperative(info, obs)
        else:
            # Pass obs so both modes read the quenched count from the same source.
            new_reward = self._calculate_competitive(info, obs)

        return obs, float(new_reward), terminated, truncated, info

    def _calculate_cooperative(self, info, obs):
        """Cooperative reward: extinguish fires and keep the burning area small.

        Args:
            info: Step info dict; ``info['cells_burning']`` is read.
            obs: Step observation; ``obs['quenched_cells'][0]`` is read.

        Returns:
            The reward clipped to ``[-REWARD_CLIP, REWARD_CLIP]``.
        """
        curr_burning = info['cells_burning']
        total_extinguished = obs['quenched_cells'][0]

        reward = 0.0
        # 1. Team achievement: extinguished cells.
        reward += self.EXTINGUISH_REWARD * total_extinguished
        # 2. Team penalty: any fire still burning costs the whole team.
        reward -= self.BURNING_PENALTY * curr_burning
        # 3. Wasted water, so agents still learn to aim.
        reward -= self._calculate_wasted_water_penalty()

        return np.clip(reward, -self.REWARD_CLIP, self.REWARD_CLIP)

    def _calculate_competitive(self, info, obs=None):
        """Competitive ("greedy") reward: score hits, ignore fire spread.

        Identical to the cooperative reward except that there is NO penalty
        for currently burning cells. Because ``quenched_cells`` is a single
        aggregated count, this is the sum of the agents' greedy objectives
        rather than a true per-agent score.

        Args:
            info: Step info dict. Unused; kept for signature compatibility.
            obs: Step observation. When given, ``obs['quenched_cells'][0]`` is
                used, matching the cooperative mode. When ``None``, the count
                is read from the unwrapped environment's ``state`` instead.

        Returns:
            The reward clipped to ``[-REWARD_CLIP, REWARD_CLIP]``.
        """
        if obs is not None:
            total_extinguished = obs['quenched_cells'][0]
        else:
            # Fallback for callers that do not supply the observation.
            total_extinguished = self.env.unwrapped.state['quenched_cells'][0]

        reward = 0.0
        reward += self.EXTINGUISH_REWARD * total_extinguished
        reward -= self._calculate_wasted_water_penalty()

        return np.clip(reward, -self.REWARD_CLIP, self.REWARD_CLIP)

    def _calculate_wasted_water_penalty(self):
        """Sum the penalty for every agent that dropped water on a non-burning cell.

        Reads the unwrapped environment directly (bypassing other wrappers):
        each agent's ``last_action`` and ``helicopter_coord`` from ``state``,
        and the ``fireState`` of the cell under the helicopter.

        NOTE(review): this runs after ``env.step()``, so it sees the post-step
        cell state. If the environment already moves a successfully hit cell
        out of the burning state during the step, successful drops would be
        counted as wasted - confirm against the environment's step order.

        Returns:
            ``WASTED_DROP_PENALTY`` times the number of wasted drops, as a float.
        """
        base_env = self.env.unwrapped
        state = base_env.state

        # Normalise enum members to their underlying value so the comparison
        # also works when fireState / burning_state are plain Enum members
        # (a plain Enum member never compares equal to an int).
        burning_value = getattr(self.burning_state, "value", self.burning_state)

        penalty = 0.0
        for i in range(base_env.num_agents):
            if state['last_action'][i] != self.drop_action:
                continue  # this agent did not attempt a drop

            hx, hy = state['helicopter_coord'][i]
            cell_idx = helper.get_grid_index_for_location(hx, hy, base_env.gridWidth)
            fire_state = base_env.cells[cell_idx].fireState
            fire_state = getattr(fire_state, "value", fire_state)

            if fire_state != burning_value:
                penalty += self.WASTED_DROP_PENALTY

        return penalty

In [5]:
from firecastrl_env.envs.wildfire_env import WildfireEnv
# from multi_agent_wrappers import MultiAgentRewardWrapper # If in separate file

# Train a PPO agent on the wildfire environment using the cooperative reward.
#
# NOTE(review): `SafeWildfireWrapper` and `PPO` are not imported or defined in
# the cells visible here. Presumably they come from an earlier cell (PPO most
# likely from stable_baselines3, judging by the training log below) - confirm,
# and import them at the top so the notebook survives Restart & Run All.
# NOTE(review): no seed is passed anywhere in this cell, so results are
# presumably not reproducible run-to-run - confirm whether that is intended.

# 1. Init Base Env (three agents)
raw_env = WildfireEnv(num_agents=3)

# 2. Apply Safety Wrapper (Fixes Infinity Bug)
safe_env = SafeWildfireWrapper(raw_env)

# 3. Apply Reward Strategy: replace the env reward with the cooperative team reward
coop_env = MultiAgentRewardWrapper(safe_env, mode="cooperative")

# 4. Train, then save the model under the name "ppo_fire_squad_coop".
#    The observation is handled with "MultiInputPolicy"; the reward wrapper
#    above indexes the observation by key (obs['quenched_cells']).
model_coop = PPO("MultiInputPolicy", coop_env, verbose=1)
model_coop.learn(total_timesteps=100_000)
model_coop.save("ppo_fire_squad_coop")

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
----------------------------------
| rollout/           |           |
|    ep_len_mean     | 694       |
|    ep_rew_mean     | -3.56e+03 |
| time/              |           |
|    fps             | 21        |
|    iterations      | 1         |
|    time_elapsed    | 93        |
|    total_timesteps | 2048      |
----------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 677         |
|    ep_rew_mean          | -3.63e+03   |
| time/                   |             |
|    fps                  | 20          |
|    iterations           | 2           |
|    time_elapsed         | 197         |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.008162169 |
|    clip_fraction        | 0.0499      |
|    clip_range           | 0.2         |
|    entro