强化学习中会涉及很多的 wrapper 归一化处理过程

In [4]:
import gym
import numpy as np
from gym import spaces

# 创建一个简单的环境
env = gym.make('CartPole-v1')

# 定义一个状态归一化的 Wrapper
class NormalizeObservation(gym.ObservationWrapper):
    def __init__(self, env):
        super(NormalizeObservation, self).__init__(env)
        # 修改观察空间为归一化后的空间
        self.observation_space = spaces.Box(low=0, high=1, shape=env.observation_space.shape, dtype=np.float32)

    def observation(self, observation):
        # 将原始观察值归一化到 [0, 1]
        return (observation - self.observation_space.low) / (self.observation_space.high - self.observation_space.low)

# 定义一个奖励缩放的 Wrapper
class ScaleReward(gym.RewardWrapper):
    def __init__(self, env, scale=0.1):
        super(ScaleReward, self).__init__(env)
        self.scale = scale

    def reward(self, reward):
        # 缩放奖励
        return reward * self.scale

# 使用 Wrappers
wrapped_env = NormalizeObservation(env)
wrapped_env = ScaleReward(wrapped_env)

# 进行一次交互
observation = wrapped_env.reset()
for _ in range(1000):
    action = wrapped_env.action_space.sample()  # 随机选择一个动作
    state, reward, done, truncated, info = wrapped_env.step(action)
    # state, reward, done, truncated, info = envs.step(action)
    print(f"Observation: {state}, Reward: {reward}")
    if done:
        break

wrapped_env.close()

Observation: [-0.01550082  0.24244374  0.02131401 -0.2551575 ], Reward: 0.1
Observation: [-0.01065194  0.04702406  0.01621086  0.0441713 ], Reward: 0.1
Observation: [-0.00971146 -0.14832655  0.01709429  0.3419245 ], Reward: 0.1
Observation: [-0.01267799  0.04654808  0.02393278  0.05468074], Reward: 0.1
Observation: [-0.01174703 -0.14890872  0.02502639  0.35481754], Reward: 0.1
Observation: [-0.0147252  -0.3443774   0.03212274  0.6552857 ], Reward: 0.1
Observation: [-0.02161275 -0.53993154  0.04522846  0.95790803], Reward: 0.1
Observation: [-0.03241138 -0.34544593  0.06438661  0.67977065], Reward: 0.1
Observation: [-0.0393203  -0.15127464  0.07798203  0.4080338 ], Reward: 0.1
Observation: [-0.04234579 -0.34741068  0.0861427   0.7242473 ], Reward: 0.1
Observation: [-0.04929401 -0.15357901  0.10062765  0.45987245], Reward: 0.1
Observation: [-0.05236559  0.03998731  0.1098251   0.2005264 ], Reward: 0.1
Observation: [-0.05156584  0.23338115  0.11383563 -0.05559294], Reward: 0.1
Observation:

  if not isinstance(terminated, (bool, np.bool8)):
