In [1]:
from __future__ import annotations

import glob
import os
import time

from pettingzoo.test import api_test
import pettingzoo
import gymnasium as gym

from sb3_contrib import MaskablePPO
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
from sb3_contrib.common.wrappers import ActionMasker

from lib.briscola_env.briscola_env import BriscolaEnv

In [9]:
env = BriscolaEnv()
api_test(env, num_cycles=1000)


Starting API test
Passed API test


In [3]:
# To pass into other gymnasium wrappers, we need to ensure that pettingzoo's wrappper
# can also be a gymnasium Env. Thus, we subclass under gym.Env as well.
class SB3ActionMaskWrapper(pettingzoo.utils.BaseWrapper, gym.Env):
    """Wrapper to allow PettingZoo environments to be used with SB3 illegal action masking."""

    def reset(self, seed=None, options=None):
        """Gymnasium-like reset function which assigns obs/action spaces to be the same for each agent.

        This is required as SB3 is designed for single-agent RL and doesn't expect obs/action spaces to be functions
        """
        super().reset(seed, options)

        # Strip the action mask out from the observation space
        self.observation_space = super().observation_space(self.possible_agents[0])[
            "observation"
        ]
        self.action_space = super().action_space(self.possible_agents[0])

        # Return initial observation, info (PettingZoo AEC envs do not by default)
        return self.observe(self.agent_selection), {}

    def step(self, action):
        """Gymnasium-like step function, returning observation, reward, termination, truncation, info.

        The observation is for the next agent (used to determine the next action), while the remaining
        items are for the agent that just acted (used to understand what just happened).
        """
        current_agent = self.agent_selection

        super().step(action)

        next_agent = self.agent_selection
        return (
            self.observe(next_agent),
            self._cumulative_rewards[current_agent],
            self.terminations[current_agent],
            self.truncations[current_agent],
            self.infos[current_agent],
        )

    def observe(self, agent):
        """Return only raw observation, removing action mask."""
        return super().observe(agent)["observation"]

    def action_mask(self):
        """Separate function used in order to access the action mask."""
        return super().observe(self.agent_selection)["action_mask"]

In [None]:
def mask_fn(env):
    return env.action_mask()

def train(
    steps: int = 10_000, seed: int | None = 0, **env_kwargs
):
    # Train a single model to play as each agent in a cooperative Parallel environment
    env = BriscolaEnv()
    env = SB3ActionMaskWrapper(env)
    env.reset(seed=seed)
    env = ActionMasker(env, mask_fn)

    print(f"Starting training on {str(env.metadata)}.")
    model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=1)
    model.set_random_seed(seed)
    model.learn(total_timesteps=steps)
    model.save(f"{env.unwrapped.metadata.get('name')}_{time.strftime('%Y%m%d-%H%M%S')}")

    print("Model has been saved.")
    print(f"Finished training on {str(env.unwrapped.metadata['name'])}.\n")
    env.close()



In [5]:
train()

Starting training on {'render_modes': []}.
Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 40       |
|    ep_rew_mean     | 1.6e+03  |
| time/              |          |
|    fps             | 294      |
|    iterations      | 1        |
|    time_elapsed    | 6        |
|    total_timesteps | 2048     |
---------------------------------
-------------------------------------------
| rollout/                |               |
|    ep_len_mean          | 40            |
|    ep_rew_mean          | 1.67e+03      |
| time/                   |               |
|    fps                  | 243           |
|    iterations           | 2             |
|    time_elapsed         | 16            |
|    total_timesteps      | 4096          |
| train/                  |               |
|    approx_kl            | 0.00028958335 |
|    clip_fraction        | 0         

In [6]:
def eval_action_mask(player, num_games=100):
    # Evaluate a trained agent vs a random agent
    env = BriscolaEnv()

    print(
        f"Starting evaluation vs a random agent. Trained agent will play as {env.possible_agents[player]}."
    )

    try:
        latest_policy = max(
            glob.glob(f"{env.metadata['name']}*.zip"), key=os.path.getctime
        )
    except ValueError:
        print("Policy not found.")
        exit(0)
    print("using", latest_policy)
    model = MaskablePPO.load(latest_policy)

    scores = {agent: 0 for agent in env.possible_agents}
    total_rewards = {agent: 0 for agent in env.possible_agents}
    for i in range(num_games):
        env.reset(seed=i)

        for agent in env.agent_iter():
            obs, reward, termination, truncation, info = env.last()

            # Separate observation and action mask
            observation, action_mask = obs.values()

            if termination or truncation:
                winner = max(env.rewards, key=env.rewards.get)
                scores[winner] += env.rewards[
                    winner
                ]  # only tracks the largest reward (winner of game)
                # Also track negative and positive rewards (penalizes illegal moves)
                for a in env.possible_agents:
                    total_rewards[a] += env.rewards[a]
                # List of rewards by round, for reference
                break
            else:
                if agent != env.possible_agents[player]:
                    act = env.action_space(agent).sample(action_mask)
                else:
                    # Note: PettingZoo expects integer actions # TODO: change chess to cast actions to type int?
                    act = int(
                        model.predict(
                            observation, action_masks=action_mask, deterministic=True
                        )[0]
                    )
            env.step(act)
    env.close()

    print("Winrates:")
    for p in env.possible_agents:
        if sum(scores.values()) == 0:
            winrate = 0
        else:
            winrate = scores[p] / sum(scores.values())
        print(f"\t{p}: {winrate}")
        print(f"\t{total_rewards[p]}")
    print("Total rewards: ", total_rewards)
    print("Final scores: ", scores)

In [7]:
eval_action_mask(1)

Starting evaluation vs a random agent. Trained agent will play as player_1.
using briscola_20250426-232304.zip
Winrates:
	player_0: 0.3187653592843802
	19692
	player_1: 0.19136931092106557
	12501
	player_2: 0.2887840361741866
	18173
	player_3: 0.20108129362036764
	13175
Total rewards:  {'player_0': 19692, 'player_1': 12501, 'player_2': 18173, 'player_3': 13175}
Final scores:  {'player_0': 16214, 'player_1': 9734, 'player_2': 14689, 'player_3': 10228}
