Skip to content

Commit

Permalink
Upgrade WrapperNormalizeActions to match Dict&Tuple action_spaces.
Browse files Browse the repository at this point in the history
  • Loading branch information
sharif1093 committed May 21, 2019
1 parent 91e0e65 commit 41a2bf4
Showing 1 changed file with 26 additions and 15 deletions.
41 changes: 26 additions & 15 deletions digideep/environment/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,26 +77,37 @@ def observation(self, observation):

class WrapperNormalizedActions(gym.ActionWrapper):
"""
This is taken from `RL-Adventure-2 <https://github.com/higgsfield/RL-Adventure-2>`__.
This is based on `RL-Adventure-2 <https://github.com/higgsfield/RL-Adventure-2>`__.
"""
def action(self, action):
low = self.action_space.low
high = self.action_space.high

action = low + (action + 1.0) * 0.5 * (high - low)
action = np.clip(action, low, high)

return action
# TODO: What if action is a Dict?
return self._get_action_direct(self.action_space, action, reversed=False)

def reverse_action(self, action):
low = self.action_space.low
high = self.action_space.high

action = 2 * (action - low) / (high - low) - 1
action = np.clip(action, low, high)
return self._get_action_direct(self.action_space, action, reversed=True)

def _get_action_direct(self, action_space, action, reversed = False):
if isinstance(action_space, spaces.Dict):
action_ret = {}
for key in action_space.spaces:
action_ret[key] = self._get_action_direct(action_space.spaces[key], action[key])
elif isinstance(action_space, spaces.Tuple):
action_ret = []
for index in range(len(action_space.spaces)):
action_ret += [self._get_action_direct(action_space.spaces[index], action[index])]
action_ret = tuple(action_ret)
else:
low = action_space.low
high = action_space.high

return actions

if reversed:
action_ret = 2 * (action - low) / (high - low) - 1
action_ret = np.clip(action_ret, low, high)
else:
action_ret = low + (action + 1.0) * 0.5 * (high - low)
action_ret = np.clip(action_ret, low, high)

return action_ret

class WrapperDummyMultiAgent(gym.ActionWrapper):
"""
Expand Down

0 comments on commit 41a2bf4

Please sign in to comment.