In [11]:
from typing import Any, Dict
import gymnasium as gym
import numpy as np
from luxai_s3.wrappers import LuxAIS3GymEnv
import gymnasium as gym
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3 import PPO
import numpy as np
from typing import Optional, Dict, Any

In [12]:
class LuxAIWrapper(gym.Wrapper):
    def __init__(self, env: LuxAIS3GymEnv, player_id: int = 0):
        """
        A comprehensive wrapper for the LuxAI environment that handles both:
        1. Stable Baselines compatibility
        2. Reward calculations
        
        Args:
            env: The base LuxAI environment
            player_id: Which player this agent controls (0 or 1)
        """
        super().__init__(env)
        self.player_id = player_id
        
        # Set up state tracking for dense rewards
        self.previous_total_energy = None
        self.previous_relic_points = None
        self.previous_unit_count = None
        
        # Set up the observation and action spaces
        self._setup_spaces()


    def _setup_spaces(self):
        """
        Configures the observation and action spaces for compatibility with 
        Stable Baselines. We use Box spaces for both to ensure compatibility
        with standard neural network policies.
        """
        # Action space configuration:
        # - Each unit (16 total) needs 3 values:
        #   1. Action type (0-5)
        #   2. Sap x coordinate (-4 to 4)
        #   3. Sap y coordinate (-4 to 4)
        self.action_space = gym.spaces.Box(
            low=np.array([0, -4, -4] * 16),
            high=np.array([5, 4, 4] * 16),
            dtype=np.int32
        )

        # Observation space configuration:
        # Using a flat vector representation of the game state
        total_obs_size = 1862  # Total size of flattened observation
        self.observation_space = gym.spaces.Box(
            low=-float('inf'),
            high=float('inf'),
            shape=(total_obs_size,),
            dtype=np.float32
        )


    def calculate_reward(self, obs : Dict[str, Any]) -> float:
        """
        Calculates a reward signal based on multiple
        gameplay factors.
        
        Args:
            obs: Dictionary containing the current observation
            
        Returns:
            float: The caluclated reward
        """
        current_reward = 0.0
        # TODO: Maybe extract more from environment
        # Extract current state information
        unit_energies = obs['units']['energy'].reshape(-1)
        unit_mask = obs['units_mask'].reshape(-1)
        team_points = obs['team_points']
        
        # Calculate current metrics
        total_energy = np.sum(unit_energies * unit_mask)
        current_unit_count = np.sum(unit_mask)
        current_relic_points = team_points[self.player_id]
        # TODO: Need to tinker with these rewards
        # --Energy management reward--
        # Reward for maintaining and increasing total energy across all units
        if self.previous_total_energy is not None:
            energy_diff = total_energy - self.previous_total_energy
            current_reward += 0.001 * energy_diff
        
        # --Unit survival reward--
        # Significant reward for maintaining and growing the unit count
        if self.previous_unit_count is not None:
            unit_diff = current_unit_count - self.previous_unit_count
            current_reward += 0.1 * unit_diff
        
        # --Relic control reward--
        # Major reward for controlling relic points as it's a key victory condition
        if self.previous_relic_points is not None:
            points_diff = current_relic_points - self.previous_relic_points
            current_reward += 0.2 * points_diff
        
        # --Unit health reward--
        # Small reward for maintaining healthy units
        alive_units = unit_energies[unit_mask > 0]
        if len(alive_units) > 0:
            avg_unit_energy = np.mean(alive_units)
            current_reward += 0.0005 * avg_unit_energy
                
        # Update state tracking for next step
        self.previous_total_energy = total_energy
        self.previous_unit_count = current_unit_count
        self.previous_relic_points = current_relic_points
        
        return current_reward

    def _flatten_observation(self, obs:  Dict[str, Any]) -> Any:
        """
        Converts the dictionary observation into a flat array for the neural network.
        This standardizes the input format for the learning algorithm.
        """
        components = []
        
        # Process unit information
        unit_positions = obs['units']['position'].reshape(-1)
        unit_energies = obs['units']['energy'].reshape(-1)
        components.extend([
            unit_positions.astype(np.float32),
            unit_energies.astype(np.float32)
        ])
        
        # Process map information
        units_mask = obs['units_mask'].reshape(-1)
        sensor_mask = obs['sensor_mask'].reshape(-1)
        map_energy = obs['map_features']['energy'].reshape(-1)
        map_tiles = obs['map_features']['tile_type'].reshape(-1)
        components.extend([
            units_mask.astype(np.float32),
            sensor_mask.astype(np.float32),
            map_energy.astype(np.float32),
            map_tiles.astype(np.float32)
        ])
        
        # Process game state information
        components.extend([
            obs['team_points'].astype(np.float32),
            obs['team_wins'].astype(np.float32),
            np.array([obs['steps']], dtype=np.float32),
            np.array([obs['match_steps']], dtype=np.float32)
        ])
        
        return np.concatenate(components)

    def _unflatten_observation(self, obs) -> Dict[str, Any]:
        """
        Converts a flat observation back into dictionary format for reward calculation.
        """
        map_size = 24
        max_units = 16
        idx = 0
        
        obs_dict = {}
        
        # Reconstruct unit information
        obs_dict['units'] = {
            'position': obs[idx:idx + max_units * 2].reshape(max_units, 2),
            'energy': obs[idx + max_units * 2:idx + max_units * 3].reshape(max_units, 1)
        }
        idx += max_units * 3
        
        # Reconstruct map information
        obs_dict['units_mask'] = obs[idx:idx + max_units].reshape(max_units)
        idx += max_units
        
        obs_dict['sensor_mask'] = obs[idx:idx + map_size * map_size].reshape(map_size, map_size)
        idx += map_size * map_size
        
        obs_dict['map_features'] = {
            'energy': obs[idx:idx + map_size * map_size].reshape(map_size, map_size),
            'tile_type': obs[idx + map_size * map_size:idx + 2 * map_size * map_size].reshape(map_size, map_size)
        }
        idx += 2 * map_size * map_size
        
        # Reconstruct game state information
        obs_dict['team_points'] = obs[idx:idx + 2]
        idx += 2
        obs_dict['team_wins'] = obs[idx:idx + 2]
        idx += 2
        obs_dict['steps'] = obs[idx]
        obs_dict['match_steps'] = obs[idx + 1]
        
        return obs_dict

    def step(self, action: Dict[str, list])->tuple:
        """
        Take a step in the environment with the provided action and return the updated state.

        Args:
            action (Dict[str, list]): The action to be taken by the player.
        Returns:
            tuple: A tuple containing the following:
                - obs (np.ndarray): The flattened observation for the current player after the step.
                - total_reward (float): The total reward for the current player, combining environment and additional rewards.
                - terminated (bool): Whether the episode has ended for the current player.
                - truncated (bool): Whether the episode was truncated for the current player.
                - info (dict): Additional information from the environment.
        """

       
        # Convert and reshape action for the environment
        shaped_action = {
            f'player_{self.player_id}': action.reshape(16, 3).astype(np.int32),
            f'player_{1-self.player_id}': np.zeros((16, 3), dtype=np.int32)
        }
        
        # Take step in environment
        obs, rewards, terminated, truncated, info = self.env.step(shaped_action)
        
        # Get observation for our player
        player_obs = obs[f'player_{self.player_id}']
        
        # Calculate reward
        original_reward = rewards[f'player_{self.player_id}']
        added_reward = self.calculate_reward(player_obs)
        
        # Combine rewards
        total_reward = original_reward + added_reward

        # Convert observation to flat format
        obs = self._flatten_observation(player_obs)
        
        return obs, total_reward, terminated[f'player_{self.player_id}'], truncated[f'player_{self.player_id}'], info

    
    
    
    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None)-> tuple:
        """
        Resets the environment and initializes tracking variables for rewards.

        Args:
            seed (Optional[int]): An optional seed for resetting the environment's random state.
            options (Optional[Dict[str, Any]]): Optional dictionary containing additional reset parameters.

        Returns:
            tuple: A tuple containing:
                - dict: The observation for the specified player.
                - dict: Additional information about the environment after reset.

        """
        # Reset reward tracking
        self.previous_total_energy = None
        self.previous_relic_points = None
        self.previous_unit_count = None
        
        # Reset environment
        obs, info = self.env.reset(seed=seed, options=options)
        player_obs = obs[f'player_{self.player_id}']
        obs = self._flatten_observation(player_obs)
        
        return obs, info

In [None]:
# Create environment
base_env = LuxAIS3GymEnv(numpy_output=True)
wrapped_env = LuxAIWrapper(base_env, player_id=0)
vec_env = DummyVecEnv([lambda: wrapped_env])

# Create and train model with additional parameters to handle discrete actions
model = PPO(
    "MlpPolicy",
    vec_env,
    verbose=1,
    learning_rate=3e-4, 
    n_steps=4096,
    batch_size=128,
    n_epochs=15,
    gamma=0.99,
    policy_kwargs=dict(
        net_arch=dict(
            pi=[256, 256],
            vf=[256, 256]
        )
    ),
    device='cpu'  # Use CPU as recommended for MLP
)


model.learn(total_timesteps=1000000)

Using cpu device
-----------------------------
| time/              |      |
|    fps             | 206  |
|    iterations      | 1    |
|    time_elapsed    | 19   |
|    total_timesteps | 4096 |
-----------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 197         |
|    iterations           | 2           |
|    time_elapsed         | 41          |
|    total_timesteps      | 8192        |
| train/                  |             |
|    approx_kl            | 0.015079102 |
|    clip_fraction        | 0.177       |
|    clip_range           | 0.2         |
|    entropy_loss         | -68.1       |
|    explained_variance   | 0.018       |
|    learning_rate        | 3e-05       |
|    loss                 | 42.2        |
|    n_updates            | 15          |
|    policy_gradient_loss | -0.0564     |
|    std                  | 1           |
|    value_loss           | 107         |
-----------------

KeyboardInterrupt: 

In [8]:
# Save the model
from stable_baselines3.common.evaluation import evaluate_policy
try:
    model.save("lux_ai_model")
    print("Model saved successfully!")
except Exception as e:
    print(f"Error saving model: {e}")

Model saved successfully!


In [9]:
mean_reward, std_reward = evaluate_policy(model, vec_env, n_eval_episodes=10, deterministic=True)

# Print results
print(f"Mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")



Mean reward: 324.20 +/- 247.31
