In [5]:
import json
import os
from typing import Any, Dict, Tuple
import flax
import flax.serialization
import gymnasium as gym
import numpy as np
import dataclasses
from luxai_s3.env import LuxAIS3Env
from luxai_s3.wrappers import LuxAIS3GymEnv
from luxai_s3.params import EnvParams, env_params_ranges
from luxai_s3.state import serialize_env_actions, serialize_env_states
from luxai_s3.utils import to_numpy
from stable_baselines3.common.env_util import make_vec_env
import gymnasium as gym
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3 import PPO
import numpy as np
import jax
import jax.numpy as jnp
import dataclasses
from typing import Callable, Optional, Dict, Any
from luxai_s3.env import LuxAIS3Env
from luxai_s3.params import EnvParams, env_params_ranges
from stable_baselines3.common.callbacks import BaseCallback

In [6]:
class StableBaselinesLuxAI(gym.Wrapper):
    def __init__(self, env: LuxAIS3GymEnv, player_id: int = 0):
        super().__init__(env)
        self.player_id = player_id
        self._setup_spaces()

    def _setup_spaces(self):
        """
        Sets up the observation and action spaces. The critical change here is
        that we're using MultiDiscrete for actions to ensure proper integer handling.
        """
        # For the action space, we need 16 units, each with 3 values:
        # - First value: action type (0-5)
        # - Second value: sap x coordinate (-4 to 4)
        # - Third value: sap y coordinate (-4 to 4)
        self.action_space = gym.spaces.Box(
            low=np.array([0, -4, -4] * 16),
            high=np.array([5, 4, 4] * 16),
            dtype=np.int32  # Changed to int32 to ensure integer actions
        )

        # Calculate observation space size
        map_size = 24  # The game map is 24x24
        max_units = 16 
        total_obs_size = 1862  # Using the actual size we observed

        self.observation_space = gym.spaces.Box(
            low=-float('inf'),
            high=float('inf'),
            shape=(total_obs_size,),
            dtype=np.float32
        )

    def step(self, action):
        """
        Takes a step in the environment with careful type handling for actions.
        The key change is proper action type conversion.
        """
        # Convert action to integer type and proper shape
        action = np.round(action).astype(np.int32)  # Round and convert to int
        shaped_action = {
            f'player_{self.player_id}': action.reshape(16, 3),
            f'player_{1-self.player_id}': np.zeros((16, 3), dtype=np.int32)
        }
        
        # Take step in environment
        obs, rewards, terminated, truncated, info = self.env.step(shaped_action)
        
        # Process observation and other returns
        player_obs = obs[f'player_{self.player_id}']
        flat_obs = self._flatten_observation(player_obs)
        player_reward = rewards[f'player_{self.player_id}']
        player_terminated = terminated[f'player_{self.player_id}']
        player_truncated = truncated[f'player_{self.player_id}']


        
        return flat_obs, player_reward, player_terminated, player_truncated, info

    def _flatten_observation(self, obs_dict):
        """
        Converts dictionary observation to flat array with proper type handling.
        """
        components = []
        
        # Process unit data
        unit_positions = obs_dict['units']['position'].reshape(-1)
        unit_energies = obs_dict['units']['energy'].reshape(-1)
        components.extend([
            unit_positions.astype(np.float32),
            unit_energies.astype(np.float32)
        ])
        
        # Process map data
        units_mask = obs_dict['units_mask'].reshape(-1)
        sensor_mask = obs_dict['sensor_mask'].reshape(-1)
        map_energy = obs_dict['map_features']['energy'].reshape(-1)
        map_tiles = obs_dict['map_features']['tile_type'].reshape(-1)
        components.extend([
            units_mask.astype(np.float32),
            sensor_mask.astype(np.float32),
            map_energy.astype(np.float32),
            map_tiles.astype(np.float32)
        ])
        
        # Process game state data
        components.extend([
            obs_dict['team_points'].astype(np.float32),
            obs_dict['team_wins'].astype(np.float32),
            np.array([obs_dict['steps']], dtype=np.float32),
            np.array([obs_dict['match_steps']], dtype=np.float32)
        ])
        
        return np.concatenate(components)

    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None):
        """Reset the environment with proper observation flattening."""
        obs, info = self.env.reset(seed=seed, options=options)
        player_obs = obs[f'player_{self.player_id}']
        flat_obs = self._flatten_observation(player_obs)
        return flat_obs, info

In [7]:


# Create environment
base_env = LuxAIS3GymEnv(numpy_output=True)
wrapped_env = StableBaselinesLuxAI(base_env, player_id=0)
vec_env = DummyVecEnv([lambda: wrapped_env])

# Create and train model with additional parameters to handle discrete actions
model = PPO(
    "MlpPolicy",
    vec_env,
    verbose=1,
    learning_rate=3e-5,
    n_steps=4096,
    batch_size=128,
    n_epochs=15,
    gamma=0.99,
    policy_kwargs=dict(
        net_arch=dict(
            pi=[256, 256],
            vf=[256, 256]
        )
    ),
    device='cpu'  # Use CPU as recommended for MLP
)


model.learn(total_timesteps=1000000)



Using cpu device
-----------------------------
| time/              |      |
|    fps             | 255  |
|    iterations      | 1    |
|    time_elapsed    | 16   |
|    total_timesteps | 4096 |
-----------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 246          |
|    iterations           | 2            |
|    time_elapsed         | 33           |
|    total_timesteps      | 8192         |
| train/                  |              |
|    approx_kl            | 0.0150145525 |
|    clip_fraction        | 0.171        |
|    clip_range           | 0.2          |
|    entropy_loss         | -68.1        |
|    explained_variance   | -0.00209     |
|    learning_rate        | 3e-05        |
|    loss                 | 61.9         |
|    n_updates            | 15           |
|    policy_gradient_loss | -0.0587      |
|    std                  | 1            |
|    value_loss           | 162          |

<stable_baselines3.ppo.ppo.PPO at 0x729d428df0d0>

In [8]:
# Save the model
from stable_baselines3.common.evaluation import evaluate_policy
try:
    model.save("lux_ai_model")
    print("Model saved successfully!")
except Exception as e:
    print(f"Error saving model: {e}")

Model saved successfully!


In [9]:
mean_reward, std_reward = evaluate_policy(model, vec_env, n_eval_episodes=10, deterministic=True)

# Print results
print(f"Mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")



Mean reward: 324.20 +/- 247.31
