
Commit 98dc855

Invalid masking support (#16)

* Support for using action_masks in PPO. Updated the default microrts envs to use action_mask; added non-masked variants for comparison.
* Bump to v0.0.7

1 parent e68ca12 · commit 98dc855
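The masking code itself lives in the changed policy and vector-environment files, which are not reproduced on this page. As a rough, illustrative sketch of the underlying technique (invalid action masking): the policy replaces the logits of invalid actions with a very large negative number before building the action distribution, so invalid actions get near-zero probability and near-zero gradient. The helper below is hypothetical and is not taken from this commit.

# Illustrative only, not code from this commit: the standard way an
# action_mask is applied to a categorical policy head.
import torch
from torch.distributions import Categorical


def masked_categorical(logits: torch.Tensor, action_mask: torch.Tensor) -> Categorical:
    # logits: (batch, n_actions) raw policy outputs
    # action_mask: (batch, n_actions) bool, True where the action is valid
    neg_inf = torch.finfo(logits.dtype).min
    masked_logits = torch.where(action_mask, logits, torch.full_like(logits, neg_inf))
    return Categorical(logits=masked_logits)


# Usage during a rollout step (mask supplied by the environment):
# dist = masked_categorical(policy_logits, action_mask)
# action = dist.sample()             # never selects an invalid action
# logprob = dist.log_prob(action)    # used later in the PPO ratio

In practice the microrts envs expose multi-discrete, per-unit action spaces, so this masking is applied to each sub-distribution, but the principle is the same.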

File tree

14 files changed: +470 −269 lines changed


pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "rl_algo_impls"
-version = "0.0.6"
+version = "0.0.7"
 description = "Implementations of reinforcement learning algorithms"
 authors = [
     {name = "Scott Goodfriend", email = "goodfriend.scott@gmail.com"},

rl_algo_impls/a2c/a2c.py

Lines changed: 12 additions & 18 deletions
@@ -10,6 +10,7 @@
 
 from rl_algo_impls.shared.algorithm import Algorithm
 from rl_algo_impls.shared.callbacks.callback import Callback
+from rl_algo_impls.shared.gae import compute_advantages
 from rl_algo_impls.shared.policy.on_policy import ActorCritic
 from rl_algo_impls.shared.schedule import schedule, update_learning_rate
 from rl_algo_impls.shared.stats import log_scalars
@@ -84,7 +85,7 @@ def learn(
         obs = np.zeros(epoch_dim + obs_space.shape, dtype=obs_space.dtype)
         actions = np.zeros(epoch_dim + act_space.shape, dtype=act_space.dtype)
         rewards = np.zeros(epoch_dim, dtype=np.float32)
-        episode_starts = np.zeros(epoch_dim, dtype=np.byte)
+        episode_starts = np.zeros(epoch_dim, dtype=np.bool8)
         values = np.zeros(epoch_dim, dtype=np.float32)
         logprobs = np.zeros(epoch_dim, dtype=np.float32)
 
@@ -126,23 +127,16 @@ def learn(
                 clamped_action
             )
 
-        advantages = np.zeros(epoch_dim, dtype=np.float32)
-        last_gae_lam = 0
-        for t in reversed(range(self.n_steps)):
-            if t == self.n_steps - 1:
-                next_nonterminal = 1.0 - next_episode_starts
-                next_value = self.policy.value(next_obs)
-            else:
-                next_nonterminal = 1.0 - episode_starts[t + 1]
-                next_value = values[t + 1]
-            delta = (
-                rewards[t] + self.gamma * next_value * next_nonterminal - values[t]
-            )
-            last_gae_lam = (
-                delta
-                + self.gamma * self.gae_lambda * next_nonterminal * last_gae_lam
-            )
-            advantages[t] = last_gae_lam
+        advantages = compute_advantages(
+            rewards,
+            values,
+            episode_starts,
+            next_episode_starts,
+            next_obs,
+            self.policy,
+            self.gamma,
+            self.gae_lambda,
+        )
         returns = advantages + values
 
         b_obs = torch.tensor(obs.reshape((-1,) + obs_space.shape)).to(self.device)
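The new rl_algo_impls.shared.gae module is one of the 14 changed files but is not shown on this page. Reconstructing from the removed inline loop above and from the call site, a compatible compute_advantages would look roughly like the sketch below; argument names follow the call, and shapes of (n_steps, n_envs) are assumed.

# Sketch only: plausible body of compute_advantages, reconstructed from the
# inline GAE loop that this commit removes from a2c.py.
import numpy as np


def compute_advantages(
    rewards: np.ndarray,              # (n_steps, n_envs) rewards per step
    values: np.ndarray,               # (n_steps, n_envs) value estimates V(s_t)
    episode_starts: np.ndarray,       # (n_steps, n_envs) True where an episode began
    next_episode_starts: np.ndarray,  # (n_envs,) episode-start flags after the last step
    next_obs: np.ndarray,             # observation after the last step
    policy,                           # anything with a .value(obs) method
    gamma: float,
    gae_lambda: float,
) -> np.ndarray:
    advantages = np.zeros_like(rewards, dtype=np.float32)
    last_gae_lam = 0
    n_steps = rewards.shape[0]
    # Walk the rollout backwards, accumulating the GAE recursion.
    for t in reversed(range(n_steps)):
        if t == n_steps - 1:
            next_nonterminal = 1.0 - next_episode_starts
            next_value = policy.value(next_obs)
        else:
            next_nonterminal = 1.0 - episode_starts[t + 1]
            next_value = values[t + 1]
        delta = rewards[t] + gamma * next_value * next_nonterminal - values[t]
        last_gae_lam = delta + gamma * gae_lambda * next_nonterminal * last_gae_lam
        advantages[t] = last_gae_lam
    return advantages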

rl_algo_impls/hyperparams/ppo.yml

Lines changed: 16 additions & 2 deletions
@@ -218,6 +218,7 @@ _microrts: &microrts-defaults
   env_hyperparams: &microrts-env-defaults
     n_envs: 8
     vec_env_class: sync
+    mask_actions: true
   policy_hyperparams:
     <<: *atari-policy-defaults
     cnn_style: microrts
@@ -227,10 +228,23 @@ _microrts: &microrts-defaults
     clip_range_decay: none
     clip_range_vf: 0.1
 
-debug-MicrortsMining-v1:
+_no-mask-microrts: &no-mask-microrts-defaults
   <<: *microrts-defaults
+  env_hyperparams:
+    <<: *microrts-env-defaults
+    mask_actions: false
+
+MicrortsMining-v1-NoMask:
+  <<: *no-mask-microrts-defaults
   env_id: MicrortsMining-v1
-  device: cpu
+
+MicrortsAttackShapedReward-v1-NoMask:
+  <<: *no-mask-microrts-defaults
+  env_id: MicrortsAttackShapedReward-v1
+
+MicrortsRandomEnemyShapedReward3-v1-NoMask:
+  <<: *no-mask-microrts-defaults
+  env_id: MicrortsRandomEnemyShapedReward3-v1
 
 HalfCheetahBulletEnv-v0: &pybullet-defaults
   n_timesteps: !!float 2e6
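The diff above only toggles mask_actions in the hyperparameter config; the code that actually fetches masks from the microrts env lives in the env-construction files that are not reproduced here. As a hypothetical sketch of what such a layer could look like (all names are illustrative, including the underlying get_action_mask call, which is assumed rather than confirmed by this commit):

# Hypothetical sketch, not code from this commit: when mask_actions is true,
# a thin wrapper can expose the env's current action mask to the rollout loop.
import numpy as np


class MaskedEnvWrapper:
    def __init__(self, venv):
        # venv: a vectorized microrts-style env assumed to report action validity
        self.venv = venv

    def reset(self):
        return self.venv.reset()

    def step(self, actions):
        return self.venv.step(actions)

    def get_action_mask(self) -> np.ndarray:
        # Assumed underlying call; returns a bool array of shape
        # (n_envs, n_actions) with True marking currently valid actions.
        return np.asarray(self.venv.get_action_mask(), dtype=bool)

The rollout loop would call get_action_mask() right before sampling and pass the result to the policy; the *-NoMask variants added above simply leave mask_actions off so masked and unmasked runs can be compared.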

0 commit comments
