Add NormalizeObservation and NormalizeReward
1 parent 214ad21 · commit 0bdc4f0
Showing 5 changed files with 142 additions and 0 deletions.
@@ -0,0 +1,27 @@
import numpy as np
import gym

from lagom.transform import RunningMeanVar


class NormalizeObservation(gym.ObservationWrapper):
    def __init__(self, env, clip=5., constant_moments=None):
        super().__init__(env)
        self.clip = clip
        self.constant_moments = constant_moments
        self.eps = 1e-8
        if constant_moments is None:
            # Track a running mean/variance of observations online
            self.obs_moments = RunningMeanVar(shape=env.observation_space.shape)
        else:
            # Use fixed moments, e.g. frozen statistics for evaluation
            self.constant_mean, self.constant_var = constant_moments

    def observation(self, observation):
        if self.constant_moments is None:
            # Update the running estimates with the new observation
            self.obs_moments([observation])
            mean = self.obs_moments.mean
            std = np.sqrt(self.obs_moments.var + self.eps)
        else:
            mean = self.constant_mean
            std = np.sqrt(self.constant_var + self.eps)
        # Standardize, then clip to [-clip, clip]
        observation = np.clip((observation - mean)/std, -self.clip, self.clip)
        return observation
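A minimal usage sketch (not part of the commit; it assumes the classic Gym reset/step API used in the code above, with CartPole-v1 purely as an example environment):

import gym

env = gym.make('CartPole-v1')
env = NormalizeObservation(env, clip=5.)

observation = env.reset()
for _ in range(100):
    # Observations returned here are standardized by the running
    # mean/std and clipped to [-5, 5]
    observation, reward, done, info = env.step(env.action_space.sample())
    if done:
        observation = env.reset()

To freeze the statistics, e.g. at evaluation time, pass constant_moments=(mean, var) and the wrapper uses those instead of updating a running estimate.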
@@ -0,0 +1,42 @@
import numpy as np
import gym

from lagom.transform import RunningMeanVar


class NormalizeReward(gym.RewardWrapper):
    def __init__(self, env, clip=10., gamma=0.99, constant_var=None):
        super().__init__(env)
        self.clip = clip
        assert 0.0 < gamma < 1.0, 'a discount factor of 1.0 is not allowed: the discounted return buffer could grow without bound.'
        self.gamma = gamma
        self.constant_var = constant_var
        self.eps = 1e-8
        if constant_var is None:
            # Track the running variance of the discounted return
            self.reward_moments = RunningMeanVar(shape=())

        # Buffer to save discounted returns from each environment
        self.all_returns = 0.0

    def reset(self):
        # Reset the returns buffer
        self.all_returns = 0.0
        return super().reset()

    def step(self, action):
        observation, reward, done, info = super().step(action)
        # Reset the discounted return buffer when the episode terminates
        if done:
            self.all_returns = 0.0
        return observation, reward, done, info

    def reward(self, reward):
        if self.constant_var is None:
            self.all_returns = reward + self.gamma*self.all_returns
            self.reward_moments([self.all_returns])
            std = np.sqrt(self.reward_moments.var + self.eps)
        else:
            std = np.sqrt(self.constant_var + self.eps)
        # Do NOT subtract the mean; only divide by the standard deviation,
        # so the sign of the reward is preserved
        reward = np.clip(reward/std, -self.clip, self.clip)
        return reward
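A sketch of how the two wrappers compose (again not part of the commit; Pendulum-v0 is just an illustrative environment, and the classic 4-tuple step API is assumed):

import gym

env = gym.make('Pendulum-v0')
env = NormalizeObservation(env, clip=5.)
env = NormalizeReward(env, clip=10., gamma=0.99)

observation = env.reset()
for _ in range(100):
    observation, reward, done, info = env.step(env.action_space.sample())
    # reward has been divided by the running std of the discounted return
    # and clipped to [-10, 10]; its mean is deliberately not shifted
    if done:
        observation = env.reset()

Passing constant_var instead skips the running estimate, mirroring the constant_moments option of NormalizeObservation.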