In [2]:
from __future__ import annotations

import glob
import os
import time

from pettingzoo.test import api_test
import pettingzoo
import gymnasium as gym

from sb3_contrib import MaskablePPO
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
from sb3_contrib.common.wrappers import ActionMasker

from lib.briscola_env.briscola_env import BriscolaEnv

In [3]:
env = BriscolaEnv()
api_test(env, num_cycles=1000)


Starting API test
Passed API test




In [4]:
# To pass into other gymnasium wrappers, we need to ensure that pettingzoo's wrappper
# can also be a gymnasium Env. Thus, we subclass under gym.Env as well.
class SB3ActionMaskWrapper(pettingzoo.utils.BaseWrapper, gym.Env):
    """Wrapper to allow PettingZoo environments to be used with SB3 illegal action masking."""

    def reset(self, seed=None, options=None):
        """Gymnasium-like reset function which assigns obs/action spaces to be the same for each agent.

        This is required as SB3 is designed for single-agent RL and doesn't expect obs/action spaces to be functions
        """
        super().reset(seed, options)

        # Strip the action mask out from the observation space
        self.observation_space = super().observation_space(self.possible_agents[0])[
            "observation"
        ]
        self.action_space = super().action_space(self.possible_agents[0])

        # Return initial observation, info (PettingZoo AEC envs do not by default)
        return self.observe(self.agent_selection), {}

    def step(self, action):
        """Gymnasium-like step function, returning observation, reward, termination, truncation, info.

        The observation is for the next agent (used to determine the next action), while the remaining
        items are for the agent that just acted (used to understand what just happened).
        """
        current_agent = self.agent_selection

        super().step(action)

        next_agent = self.agent_selection
        return (
            self.observe(next_agent),
            self._cumulative_rewards[current_agent],
            self.terminations[current_agent],
            self.truncations[current_agent],
            self.infos[current_agent],
        )

    def observe(self, agent):
        """Return only raw observation, removing action mask."""
        return super().observe(agent)["observation"]

    def action_mask(self):
        """Separate function used in order to access the action mask."""
        return super().observe(self.agent_selection)["action_mask"]

In [5]:
def mask_fn(env):
    return env.action_mask()

def train(
    steps: int = 10_000, seed: int | None = 0, **env_kwargs
):
    # Train a single model to play as each agent in a cooperative Parallel environment
    env = BriscolaEnv()
    env = SB3ActionMaskWrapper(env)
    env.reset(seed=seed)
    env = ActionMasker(env, mask_fn)

    print(f"Starting training on {str(env.metadata)}.")
    model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=1)
    model.set_random_seed(seed)
    model.learn(total_timesteps=steps)
    model.save(f"{env.unwrapped.metadata.get('name')}_{time.strftime('%Y%m%d-%H%M%S')}")

    print("Model has been saved.")
    print(f"Finished training on {str(env.unwrapped.metadata['name'])}.\n")
    env.close()



In [None]:
train(steps=100_000)

In [None]:
def eval_action_mask(player, num_games=100):
    # Evaluate a trained agent vs a random agent
    env = BriscolaEnv()

    print(
        f"Starting evaluation vs random agents. Trained agent will play as {env.possible_agents[player]}."
    )

    try:
        latest_policy = max(
            glob.glob(f"{env.metadata['name']}*.zip"), key=os.path.getctime
        )
    except ValueError:
        print("Policy not found.")
        exit(0)
    print("using", latest_policy)
    model = MaskablePPO.load(latest_policy)

    scores = {agent: 0 for agent in env.possible_agents}
    total_rewards = {agent: 0 for agent in env.possible_agents}
    for i in range(num_games):
        env.reset(seed=i)

        for agent in env.agent_iter():
            obs, reward, termination, truncation, info = env.last()

            # Separate observation and action mask
            observation, action_mask = obs.values()

            if termination or truncation:
                winner = max(env.rewards, key=env.rewards.get)
                scores[winner] += env.rewards[
                    winner
                ]  # only tracks the largest reward (winner of game)
                # Also track negative and positive rewards (penalizes illegal moves)
                for a in env.possible_agents:
                    total_rewards[a] += env.rewards[a]
                # List of rewards by round, for reference
                break
            else:
                if agent != env.possible_agents[player]:
                    act = env.action_space(agent).sample(action_mask)
                else:
                    # Note: PettingZoo expects integer actions # TODO: change chess to cast actions to type int?
                    act = int(
                        model.predict(
                            observation, action_masks=action_mask, deterministic=True
                        )[0]
                    )
            env.step(act)
    env.close()

    print("Winrates:")
    for p in env.possible_agents:
        if sum(scores.values()) == 0:
            winrate = 0
        else:
            winrate = scores[p] / sum(scores.values())
        print(f"\t{p}: {winrate*100}%`")
        print(f"\t{total_rewards[p]}")
    print("Total rewards: ", total_rewards)
    print("Final scores: ", scores)

In [None]:
for position in range(4):
	print(f"--- Testing position {position} ---")
	eval_action_mask(position, num_games=1_000)

--- At position 0 ---
Starting evaluation vs random agents. Trained agent will play as player_0.
using briscola_20250427-142816.zip
Winrates:
	player_0: 32.322438717787556%`
	19939
	player_1: 16.905641106222504%`
	11726
	player_2: 30.768777498428662%`
	18708
	player_3: 20.003142677561282%`
	12786
Total rewards:  {'player_0': 19939, 'player_1': 11726, 'player_2': 18708, 'player_3': 12786}
Final scores:  {'player_0': 16456, 'player_1': 8607, 'player_2': 15665, 'player_3': 10184}
--- At position 1 ---
Starting evaluation vs random agents. Trained agent will play as player_1.
using briscola_20250427-142816.zip
Winrates:
	player_0: 28.867187500000004%`
	17753
	player_1: 21.064453125%`
	13870
	player_2: 33.09765625%`
	20363
	player_3: 16.970703125%`
	11796
Total rewards:  {'player_0': 17753, 'player_1': 13870, 'player_2': 20363, 'player_3': 11796}
Final scores:  {'player_0': 14780, 'player_1': 10785, 'player_2': 16946, 'player_3': 8689}
--- At position 2 ---
Starting evaluation vs random age