Rl/py: Add ExtendObservationWrapper & observe progress
pkel committed Oct 21, 2022
1 parent 931676c commit 5093252
Showing 3 changed files with 60 additions and 0 deletions.
40 changes: 40 additions & 0 deletions python/gym/cpr_gym/wrappers.py
@@ -113,6 +113,46 @@ def step(self, action):
        return obs, reward, done, info


class ExtendObservationWrapper(gym.Wrapper):
    """
    Adds fields from the info dict or elsewhere to the observation space.

    Each entry in `fields` is a tuple `(fn, low, high, default)`: `fn(self, info)`
    computes the value after each step, `low` and `high` extend the Box bounds,
    and `default` fills the entry on reset.
    """

    def __init__(self, env, fields):
        super().__init__(env)
        self.eow_fields = fields
        self.eow_n = len(fields)
        low = numpy.zeros(self.eow_n)
        high = numpy.zeros(self.eow_n)
        for i in range(self.eow_n):
            _fn, l, h, _default = fields[i]
            low[i] = l
            high[i] = h
        low = numpy.append(self.observation_space.low, low)
        high = numpy.append(self.observation_space.high, high)
        self.observation_space = gym.spaces.Box(low, high, dtype=numpy.float64)

    def reset(self):
        raw_obs = self.env.reset()
        obs = numpy.zeros(self.eow_n)
        for i in range(self.eow_n):
            _fn, _low, _high, default = self.eow_fields[i]
            obs[i] = default
        return numpy.append(raw_obs, obs)

    def step(self, action):
        raw_obs, reward, done, info = self.env.step(action)
        obs = numpy.zeros(self.eow_n)
        for i in range(self.eow_n):
            f, _low, _high, _default = self.eow_fields[i]
            obs[i] = f(self, info)
        return numpy.append(raw_obs, obs), reward, done, info

    def policy(self, obs, name="honest"):
        obs = obs[: -self.eow_n]
        return self.env.policy(obs, name)


class AlphaScheduleWrapper(gym.Wrapper):
"""
Reconfigures alpha on each reset.
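For context beyond the diff, here is a minimal usage sketch of the new wrapper. It reuses the environment id and info key that appear in the test below; the import path follows the package layout above, and the contents of the raw cpr_gym observation are not shown, only the appended entry.

import gym
from cpr_gym import wrappers

env = gym.make("cpr_gym:core-v0")

# Each field is (fn, low, high, default); fn(self, info) is evaluated on step.
fields = [((lambda self, info: info["episode_progress"]), 0, float("inf"), 0)]
env = wrappers.ExtendObservationWrapper(env, fields)

obs = env.reset()  # last entry holds the default (0) until the first step
obs, reward, done, info = env.step(env.action_space.sample())
obs[-1]  # now filled from info["episode_progress"]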
14 changes: 14 additions & 0 deletions python/gym/tests/test_envs.py
@@ -85,6 +85,20 @@ def test_alphaScheduleWrapper():
    assert len(alphas.keys()) > 30


def test_ExtendObservationWrapper():
    env = gym.make("cpr_gym:core-v0")
    was_n = len(env.observation_space.low)

    fields = []
    fields.append(((lambda self, info: info["episode_progress"]), 0, float("inf"), 0))
    env = wrappers.ExtendObservationWrapper(env, fields)

    n = len(env.observation_space.low)
    assert n == was_n + len(fields)

    check_env(env)


def test_EpisodeRecorderWrapper():
    env = gym.make("cpr_gym:core-v0")
    env = wrappers.EpisodeRecorderWrapper(env, n=10, info_keys=["head_height"])
6 changes: 6 additions & 0 deletions python/train/ppo.py
@@ -194,6 +194,12 @@ def env_fn(eval=False, n_recordings=42):
        info_keys=["alpha", "episode_chain_time", "episode_progress"],
    )

    fields = []
    fields.append(((lambda self, info: info["episode_progress"]), 0, float("inf"), 0))
    fields.append(((lambda self, info: info["episode_chain_time"]), 0, float("inf"), 0))
    fields.append(((lambda self, info: info["episode_n_steps"]), 0, float("inf"), 0))
    env = cpr_gym.wrappers.ExtendObservationWrapper(env, fields)

    env = cpr_gym.wrappers.ClearInfoWrapper(env)

    return env
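A note on the resulting observation layout in ppo.py (a sketch; `obs` below is illustrative): the wrapper appends the fields in list order, so the three extra entries sit at the end of the observation vector handed to the agent.

# obs = [raw observation ..., episode_progress, episode_chain_time, episode_n_steps]
progress, chain_time, n_steps = obs[-3], obs[-2], obs[-1]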
