Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update ppo_pettingzoo_ma_atari.py #408

Open
wants to merge 21 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
d9b9b11
Update ppo_pettingzoo_ma_atari.py
elliottower Jul 12, 2023
edc79d6
Pre-commit
elliottower Jul 13, 2023
d39da5e
Update PZ version
elliottower Jul 13, 2023
2b2dfce
Update Super
elliottower Jul 13, 2023
6d37313
Run pre-commit --hook-stage manual --all-files
elliottower Jul 13, 2023
0168986
run poetry lock --no-update to fix inconsistencies with versions
elliottower Jul 13, 2023
b7bffe9
re-run pre-commit with --hook-stage manual
elliottower Jul 13, 2023
2c76bb1
Change torch.maximum to torch.logical_or for dones
elliottower Jul 17, 2023
025f491
Use np.logical_or instead of torch (allows subtraction)
elliottower Jul 18, 2023
09f7a7f
Merge remote-tracking branch 'upstream/master' into patch-1
elliottower Jan 18, 2024
16e0764
Finish merge with upstream master
elliottower Jan 18, 2024
928b7b3
Fix SuperSuit to most recent version
elliottower Jan 18, 2024
d7a2aa2
Fix SuperSuit version in poetry lockfile and tinyscaler in pettingzoo…
elliottower Jan 18, 2024
d77cca0
Fix pettingzoo-requirements export (pre-commit hooks)
elliottower Jan 18, 2024
afba4e8
Test updating pettingzoo to new version 1.24.3
elliottower Jan 18, 2024
8671154
Update ma_atari to match regular atari (tyro, minor code style changes)
elliottower Jan 18, 2024
d2cf1a5
pre-commit
elliottower Jan 18, 2024
981bc63
Revert accidentally changed files (zoo and ipynb, which randomly seem…
elliottower Jan 18, 2024
454364d
Revert ipynb change
elliottower Jan 18, 2024
06473b2
Update dead pettingzoo.ml links to Farama foundation links
elliottower Jan 18, 2024
1b725cf
Update to newly release SuperSuit 3.9.2 (minor bugfixes but best to k…
elliottower Jan 18, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 90 additions & 76 deletions cleanrl/ppo_pettingzoo_ma_atari.py
Original file line number Diff line number Diff line change
@@ -1,81 +1,83 @@
# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_pettingzoo_ma_ataripy
import argparse
import importlib
import os
import random
import time
from distutils.util import strtobool
from dataclasses import dataclass

import gym
import gymnasium as gym
import numpy as np
import supersuit as ss
import torch
import torch.nn as nn
import torch.optim as optim
import tyro
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter


def parse_args():
# fmt: off
parser = argparse.ArgumentParser()
parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"),
help="the name of this experiment")
parser.add_argument("--seed", type=int, default=1,
help="seed of the experiment")
parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="if toggled, `torch.backends.cudnn.deterministic=False`")
parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="if toggled, cuda will be enabled by default")
parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="if toggled, this experiment will be tracked with Weights and Biases")
parser.add_argument("--wandb-project-name", type=str, default="cleanRL",
help="the wandb's project name")
parser.add_argument("--wandb-entity", type=str, default=None,
help="the entity (team) of wandb's project")
parser.add_argument("--capture_video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to capture videos of the agent performances (check out `videos` folder)")
@dataclass
class Args:
exp_name: str = os.path.basename(__file__)[: -len(".py")]
"""the name of this experiment"""
seed: int = 1
"""seed of the experiment"""
torch_deterministic: bool = True
"""if toggled, `torch.backends.cudnn.deterministic=False`"""
cuda: bool = True
"""if toggled, cuda will be enabled by default"""
track: bool = False
"""if toggled, this experiment will be tracked with Weights and Biases"""
wandb_project_name: str = "cleanRL"
"""the wandb's project name"""
wandb_entity: str = None
"""the entity (team) of wandb's project"""
capture_video: bool = False
"""whether to capture videos of the agent performances (check out `videos` folder)"""

# Algorithm specific arguments
parser.add_argument("--env-id", type=str, default="pong_v3",
help="the id of the environment")
parser.add_argument("--total-timesteps", type=int, default=20000000,
help="total timesteps of the experiments")
parser.add_argument("--learning-rate", type=float, default=2.5e-4,
help="the learning rate of the optimizer")
parser.add_argument("--num-envs", type=int, default=16,
help="the number of parallel game environments")
parser.add_argument("--num-steps", type=int, default=128,
help="the number of steps to run in each environment per policy rollout")
parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="Toggle learning rate annealing for policy and value networks")
parser.add_argument("--gamma", type=float, default=0.99,
help="the discount factor gamma")
parser.add_argument("--gae-lambda", type=float, default=0.95,
help="the lambda for the general advantage estimation")
parser.add_argument("--num-minibatches", type=int, default=4,
help="the number of mini-batches")
parser.add_argument("--update-epochs", type=int, default=4,
help="the K epochs to update the policy")
parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="Toggles advantages normalization")
parser.add_argument("--clip-coef", type=float, default=0.1,
help="the surrogate clipping coefficient")
parser.add_argument("--clip-vloss", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="Toggles whether or not to use a clipped loss for the value function, as per the paper.")
parser.add_argument("--ent-coef", type=float, default=0.01,
help="coefficient of the entropy")
parser.add_argument("--vf-coef", type=float, default=0.5,
help="coefficient of the value function")
parser.add_argument("--max-grad-norm", type=float, default=0.5,
help="the maximum norm for the gradient clipping")
parser.add_argument("--target-kl", type=float, default=None,
help="the target KL divergence threshold")
args = parser.parse_args()
args.batch_size = int(args.num_envs * args.num_steps)
args.minibatch_size = int(args.batch_size // args.num_minibatches)
# fmt: on
return args
env_id: str = "pong_v3"
"""the id of the environment"""
total_timesteps: int = 20000000
"""total timesteps of the experiments"""
learning_rate: float = 2.5e-4
"""the learning rate of the optimizer"""
num_envs: int = 16
"""the number of parallel game environments"""
num_steps: int = 128
"""the number of steps to run in each environment per policy rollout"""
anneal_lr: bool = True
"""Toggle learning rate annealing for policy and value networks"""
gamma: float = 0.99
"""the discount factor gamma"""
gae_lambda: float = 0.95
"""the lambda for the general advantage estimation"""
num_minibatches: int = 4
"""the number of mini-batches"""
update_epochs: int = 4
"""the K epochs to update the policy"""
norm_adv: bool = True
"""Toggles advantages normalization"""
clip_coef: float = 0.1
"""the surrogate clipping coefficient"""
clip_vloss: bool = True
"""Toggles whether or not to use a clipped loss for the value function, as per the paper."""
ent_coef: float = 0.01
"""coefficient of the entropy"""
vf_coef: float = 0.5
"""coefficient of the value function"""
max_grad_norm: float = 0.5
"""the maximum norm for the gradient clipping"""
target_kl: float = None
"""the target KL divergence threshold"""

# to be filled in runtime
batch_size: int = 0
"""the batch size (computed in runtime)"""
minibatch_size: int = 0
"""the mini-batch size (computed in runtime)"""
num_iterations: int = 0
"""the number of iterations (computed in runtime)"""


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
Expand Down Expand Up @@ -118,7 +120,10 @@ def get_action_and_value(self, x, action=None):


if __name__ == "__main__":
args = parse_args()
args = tyro.cli(Args)
args.batch_size = int(args.num_envs * args.num_steps)
args.minibatch_size = int(args.batch_size // args.num_minibatches)
args.num_iterations = args.total_timesteps // args.batch_size
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
if args.track:
import wandb
Expand Down Expand Up @@ -156,11 +161,10 @@ def get_action_and_value(self, x, action=None):
env = ss.frame_stack_v1(env, 4)
env = ss.agent_indicator_v0(env, type_only=False)
env = ss.pettingzoo_env_to_vec_env_v1(env)
envs = ss.concat_vec_envs_v1(env, args.num_envs // 2, num_cpus=0, base_class="gym")
envs = ss.concat_vec_envs_v1(env, args.num_envs // 2, num_cpus=0, base_class="gymnasium")
envs.single_observation_space = envs.observation_space
envs.single_action_space = envs.action_space
envs.is_vector_env = True
envs = gym.wrappers.RecordEpisodeStatistics(envs)
if args.capture_video:
envs = gym.wrappers.RecordVideo(envs, f"videos/{run_name}")
assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"
Expand All @@ -173,27 +177,31 @@ def get_action_and_value(self, x, action=None):
actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device)
rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
terminations = torch.zeros((args.num_steps, args.num_envs)).to(device)
truncations = torch.zeros((args.num_steps, args.num_envs)).to(device)
values = torch.zeros((args.num_steps, args.num_envs)).to(device)

# TRY NOT TO MODIFY: start the game
global_step = 0
start_time = time.time()
next_obs = torch.Tensor(envs.reset()).to(device)
next_done = torch.zeros(args.num_envs).to(device)
next_obs, info = envs.reset(seed=args.seed)
next_obs = torch.Tensor(next_obs).to(device)
next_termination = torch.zeros(args.num_envs).to(device)
next_truncation = torch.zeros(args.num_envs).to(device)
num_updates = args.total_timesteps // args.batch_size

for update in range(1, num_updates + 1):
for iteration in range(1, args.num_iterations + 1):
# Annealing the rate if instructed to do so.
if args.anneal_lr:
frac = 1.0 - (update - 1.0) / num_updates
frac = 1.0 - (iteration - 1.0) / args.num_iterations
lrnow = frac * args.learning_rate
optimizer.param_groups[0]["lr"] = lrnow

for step in range(0, args.num_steps):
global_step += 1 * args.num_envs
global_step += args.num_envs
obs[step] = next_obs
dones[step] = next_done
terminations[step] = next_termination
truncations[step] = next_truncation

# ALGO LOGIC: action logic
with torch.no_grad():
Expand All @@ -203,10 +211,15 @@ def get_action_and_value(self, x, action=None):
logprobs[step] = logprob

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, reward, done, info = envs.step(action.cpu().numpy())
next_obs, reward, termination, truncation, info = envs.step(action.cpu().numpy())
rewards[step] = torch.tensor(reward).to(device).view(-1)
next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device)
next_obs, next_termination, next_truncation = (
torch.Tensor(next_obs).to(device),
torch.Tensor(termination).to(device),
torch.Tensor(truncation).to(device),
)

# TODO: fix this
for idx, item in enumerate(info):
player_idx = idx % 2
if "episode" in item.keys():
Expand All @@ -219,6 +232,8 @@ def get_action_and_value(self, x, action=None):
next_value = agent.get_value(next_obs).reshape(1, -1)
advantages = torch.zeros_like(rewards).to(device)
lastgaelam = 0
next_done = np.logical_or(next_termination, next_truncation)
Copy link

@KaleabTessera KaleabTessera Nov 2, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there is bug here. We should still bootstrap if next_truncation=True :

"you should bootstrap if infos[env_idx]["TimeLimit.truncated"] is True (episode over due to a timeout/truncation) or dones[env_idx] is False (episode not finished)." - stable baselines

So next_done=next_termination and dones=terminations (probs just use next_terminations and terminations directly e.g. nextnonterminal = 1.0 - next_termination ).

To implement this correctly we also need access to terminal_observation from pettingzoo_env_to_vec_env_v1 since we need access to the true terminal obs and not the obs returned by the next restart (the case currently -- so we need infos to provide access to the terminal obs). I have a PR out for this . Then we can implement something like this to do correct bootstrapping for truncating/timeout.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch @KaleabTessera would you be willing to update this branch with the changes? I can give you edit access, I currently have a lot of other obligations from work so don’t have much time for this

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh shoot it’s a patch-1 so I don’t know if you can be given access. But if you clone the repo you can make a new branch from this branch and make a new PR if it’s not possible to edit this branch? Or maybe make a PR to update this branch itself. Sorry I can’t help more

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI I am doing a refactor at #424 . Gonna try run a whole suite of benchmark soon.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh okay cool, sorry I remember you gave access to the WandB thing but I've not had time to do it. Probably simplest if you do it anyways, so thanks for that. It may be interesting to compare performance with the AgileRL multi agent atari example https://docs.agilerl.com/en/latest/tutorials/pettingzoo/maddpg.html

I see the issue linked in that PR mentions timeout handling, is that the same as mentioned below with termination vs truncation? Anyways there's anything needed from PettingZoo or SuperSuit's end let me know.

Copy link
Author

@elliottower elliottower Jan 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there is bug here. We should still bootstrap if next_truncation=True :

"you should bootstrap if infos[env_idx]["TimeLimit.truncated"] is True (episode over due to a timeout/truncation) or dones[env_idx] is False (episode not finished)." - stable baselines

So next_done=next_termination and dones=terminations (probs just use next_terminations and terminations directly e.g. nextnonterminal = 1.0 - next_termination ).

To implement this correctly we also need access to terminal_observation from pettingzoo_env_to_vec_env_v1 since we need access to the true terminal obs and not the obs returned by the next restart (the case currently -- so we need infos to provide access to the terminal obs). I have a PR out for this . Then we can implement something like this to do correct bootstrapping for truncating/timeout.

Btw, just as an update, the SuperSuit PR linked above has been merged. My only concern with this is that whatever bootstrapping behavior is done here should mirror what is done with the single agent PPO implementations, so this is a question for @vwxyzjn.

My inclination is to keep the logic as it currently is in this PR and address that bootstrapping issue in another PR (maybe @KaleabTessera is interested in doing that? I don't have a whole lot of time to look into it nor am I the best person to do it as I'm not an expert)

dones = np.logical_or(terminations, truncations)
for t in reversed(range(args.num_steps)):
if t == args.num_steps - 1:
nextnonterminal = 1.0 - next_done
Expand Down Expand Up @@ -289,9 +304,8 @@ def get_action_and_value(self, x, action=None):
nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
optimizer.step()

if args.target_kl is not None:
if approx_kl > args.target_kl:
break
if args.target_kl is not None and approx_kl > args.target_kl:
break

y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
var_y = np.var(y_true)
Expand Down
4 changes: 2 additions & 2 deletions docs/rl-algorithms/ppo.md
Original file line number Diff line number Diff line change
Expand Up @@ -1029,7 +1029,7 @@ Tracked experiments and game play videos:


## `ppo_pettingzoo_ma_atari.py`
[ppo_pettingzoo_ma_atari.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_pettingzoo_ma_atari.py) trains an agent to learn playing Atari games via selfplay. The selfplay environment is implemented as a vectorized environment from [PettingZoo.ml](https://www.pettingzoo.ml/atari). The basic idea is to create vectorized environment $E$ with `num_envs = N`, where $N$ is the number of players in the game. Say $N = 2$, then the 0-th sub environment of $E$ will return the observation for player 0 and 1-th sub environment will return the observation of player 1. Then the two environments takes a batch of 2 actions and execute them for player 0 and player 1, respectively. See "Vectorized architecture" in [The 37 Implementation Details of Proximal Policy Optimization](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/) for more detail.
[ppo_pettingzoo_ma_atari.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_pettingzoo_ma_atari.py) trains an agent to learn playing Atari games via selfplay. The selfplay environment is implemented as a vectorized environment from [PettingZoo](https://pettingzoo.farama.org/environments/atari/). The basic idea is to create vectorized environment $E$ with `num_envs = N`, where $N$ is the number of players in the game. Say $N = 2$, then the 0-th sub environment of $E$ will return the observation for player 0 and 1-th sub environment will return the observation of player 1. Then the two environments takes a batch of 2 actions and execute them for player 0 and player 1, respectively. See "Vectorized architecture" in [The 37 Implementation Details of Proximal Policy Optimization](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/) for more detail.

`ppo_pettingzoo_ma_atari.py` has the following features:

Expand Down Expand Up @@ -1064,7 +1064,7 @@ Tracked experiments and game play videos:
python cleanrl/ppo_pettingzoo_ma_atari.py --env-id surround_v2
```

See [https://www.pettingzoo.ml/atari](https://www.pettingzoo.ml/atari) for a full-list of supported environments such as `basketball_pong_v3`. Notice pettingzoo sometimes introduces breaking changes, so make sure to install the pinned dependencies via `poetry`.
See [https://pettingzoo.farama.org/environments/atari/](https://pettingzoo.farama.org/environments/atari/) for a full-list of supported environments such as `basketball_pong_v3`. Notice pettingzoo sometimes introduces breaking changes, so make sure to install the pinned dependencies via `poetry`.

### Explanation of the logged metrics

Expand Down
Loading
Loading