In [1]:
import gym
from gym import spaces
import numpy as np

class SimpleMaskableEnv(gym.Env):
    """A simple custom environment with action masking."""
    
    def __init__(self):
        super(SimpleMaskableEnv, self).__init__()
        self.action_space = spaces.Discrete(5)  # 5 possible actions
        self.observation_space = spaces.Box(low=0, high=1, shape=(3,), dtype=np.float32)
        self.state = None
        self.current_step = 0

    def reset(self):
        self.current_step = 0
        self.state = np.random.rand(3).astype(np.float32)
        return self.state, {}

    def step(self, action):
        assert self.action_space.contains(action), f"Invalid action: {action}"

        reward = 1.0 if action == self.correct_action() else -1.0
        done = self.current_step >= 10

        self.state = np.random.rand(3).astype(np.float32)
        self.current_step += 1

        info = {"action_mask": self.compute_action_mask()}
        return self.state, reward, done, False, info

    def compute_action_mask(self):
        # Let's make action 0 always illegal, and randomly make others illegal
        mask = np.ones(self.action_space.n, dtype=bool)
        mask[0] = False  # Always mask action 0
        for i in range(1, self.action_space.n):
            if np.random.rand() < 0.3:  # 30% chance to disable an action
                mask[i] = False
        return mask

    def correct_action(self):
        # Just a dummy way to define a correct action
        # Choose the smallest legal action
        mask = self.compute_action_mask()
        legal_actions = np.where(mask)[0]
        return np.min(legal_actions) if len(legal_actions) > 0 else 1  # fallback to action 1


In [2]:
from sb3_contrib import MaskablePPO
from sb3_contrib.common.maskable.utils import get_action_masks
from sb3_contrib.common.wrappers import ActionMasker

def mask_fn(env):
    # this tells MaskablePPO how to get the action mask from the env
    return env.compute_action_mask()

env = SimpleMaskableEnv()
env = ActionMasker(env, mask_fn)

model = MaskablePPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=5000)




AssertionError: 