Use Stable Baselines3 to train an expert policy for the Simple Speaker Listener environment in PettingZoo.

In [17]:
import numpy as np
import torch
import pickle
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import DummyVecEnv
from pettingzoo.mpe import simple_speaker_listener_v4
from pettingzoo.utils import parallel_to_aec

In [30]:
import ray
from ray import tune
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv
from pettingzoo.mpe import simple_speaker_listener_v4
import pickle

In [None]:
def env_creator(config):
    env = simple_speaker_listener_v4.env()
    return PettingZooEnv(env)  # Wrap to make it Gymnasium-compatible

# Create a temporary environment to retrieve spaces.
temp_env = env_creator({})

speaker_obs_space = temp_env.observation_space["speaker_0"]
speaker_act_space = temp_env.action_space["speaker_0"]
listener_obs_space = temp_env.observation_space["listener_0"]
listener_act_space = temp_env.action_space["listener_0"]

def policy_mapping_fn(agent_id, episode, **kwargs):
    if "speaker" in agent_id:
        return "speaker_policy"
    else:
        return "listener_policy"


# Configure the PPO algorithm
from ray.rllib.algorithms.ppo import PPOConfig

config = PPOConfig() \
    .environment("simple_speaker_listener") \
    .framework("torch") \
    .multi_agent(
        policies={
            "speaker_policy": (None, speaker_obs_space, speaker_act_space, {}),
            "listener_policy": (None, listener_obs_space, listener_act_space, {}),
        },
        policy_mapping_fn=policy_mapping_fn,
    ) \
    .api_stack(enable_rl_module_and_learner=False, enable_env_runner_and_connector_v2=False) \
    .env_runners(num_env_runners=0)



# Define the checkpoint path (replace this with your actual path)
checkpoint_path = "C:/Users/wangy/ray_results/PPO_2025-02-23_14-39-33/PPO_simple_speaker_listener_9d625_00000_0_2025-02-23_14-39-33/checkpoint_000000"  # This could be the path to a checkpoint folder or a specific checkpoint file
# Restore the trainer from the best checkpoint
trainer = config.build()
trainer.restore(checkpoint_path)


def generate_expert_data(trainer, num_episodes=50):
    env = simple_speaker_listener_v4.parallel_env(continuous_actions=False, render_mode="rgb_array", max_cycles=25)
    
    obs, _ = env.reset()  # Unpack properly
    
    expert_data = {agent: {"states": [], "actions": []} for agent in env.agents}

    for _ in range(num_episodes):
        obs, _ = env.reset()  # Unpack here as well
        # print(f"Agents after reset: {env.agents}")  # Check if agents are populated after reset
        done = {agent: False for agent in env.agents}
        
        while not all(done.values()):
            actions = {}
            for agent in env.agents:
                policy_id = policy_mapping_fn(agent, None)  # Get correct policy
                policy = trainer.get_policy(policy_id)  # Fetch policy from trainer

                return_from_compute_single_action = policy.compute_single_action(obs[agent], explore=False)
                actions[agent] = return_from_compute_single_action[0]
            
            next_obs, rewards, done, infos, _ = env.step(actions)
            
            for agent in env.agents:
                if not done[agent]:
                    expert_data[agent]["states"].append(obs[agent])
                    expert_data[agent]["actions"].append(actions[agent])
            
            obs = next_obs  # Update observations
    return expert_data

# Generate expert data using the trained policy
expert_data = generate_expert_data(trainer, num_episodes=50)

expert_data

In [18]:
from pettingzoo.mpe import simple_speaker_listener_v4
from pettingzoo.utils.conversions import aec_to_parallel
import supersuit as ss

# Create your environment (using the parallel version)
env = simple_speaker_listener_v4.env(render_mode="rgb_array")

# Convert to a parallel environment if not already in parallel form
env = aec_to_parallel(env)

# Apply wrappers to pad the action and observation spaces
env = ss.pad_action_space_v0(env)
env = ss.pad_observations_v0(env)


# Convert to a vectorized environment compatible with SB3
vec_env = ss.pettingzoo_env_to_vec_env_v1(env)


In [22]:
import gymnasium as gym

class GymEnvWrapper(gym.Env):
    def __init__(self, env):
        self.env = env
        self.action_space = env.action_space
        self.observation_space = env.observation_space

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)

    def step(self, action):
        return self.env.step(action)

    def render(self, mode="human"):
        return self.env.render(mode)

    def close(self):
        return self.env.close()


## Train the PPO Model
Since it's a multi-agent setting, PPO will learn a joint policy for both agents. The observations are a tuple (speaker_obs, listener_obs), and actions are a tuple (speaker_action, listener_action).

In [23]:
wrapped_env = GymEnvWrapper(vec_env)

from stable_baselines3 import PPO
model = PPO("MlpPolicy", wrapped_env, verbose=1)
model.learn(total_timesteps=100000)


Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


Exception ignored in: <function ProcConcatVec.__del__ at 0x0000029EDF214B80>
Traceback (most recent call last):
  File "c:\Users\wangy\anaconda3\envs\Multi-Agent-General\Lib\site-packages\supersuit\vector\multiproc_vec.py", line 223, in __del__
    self.close()
  File "c:\Users\wangy\anaconda3\envs\Multi-Agent-General\Lib\site-packages\supersuit\vector\multiproc_vec.py", line 238, in close
    for pipe, proc in zip(self.pipes, self.procs):
                          ^^^^^^^^^^
AttributeError: 'ProcConcatVec' object has no attribute 'pipes'
Exception ignored in: <function ProcConcatVec.__del__ at 0x0000029EDF214B80>
Traceback (most recent call last):
  File "c:\Users\wangy\anaconda3\envs\Multi-Agent-General\Lib\site-packages\supersuit\vector\multiproc_vec.py", line 223, in __del__
    self.close()
  File "c:\Users\wangy\anaconda3\envs\Multi-Agent-General\Lib\site-packages\supersuit\vector\multiproc_vec.py", line 238, in close
    for pipe, proc in zip(self.pipes, self.procs):
           

ValueError: could not broadcast input array from shape (2,11) into shape (11,)