import gym
import collections
import cv2
import numpy as np


class RepeatActionandMaxFrame(gym.Wrapper):
    """Repeat each action `repeat` times and max-pool the last two frames.

    Standard Atari preprocessing: action repeat with flicker removal
    (element-wise max of the two most recent frames), optional reward
    clipping to [-1, 1], random no-op steps on reset, and an optional
    initial FIRE action for games that require it to start.
    """

    def __init__(self, env=None, repeat=4, clip_reward=False, no_ops=0,
                 fire_first=False):
        super(RepeatActionandMaxFrame, self).__init__(env)
        self.env = env
        self.repeat = repeat
        self.shape = env.observation_space.shape
        # BUG FIX: the original used np.zeros_like((2, shape)), which builds
        # an array shaped like the *tuple object* (and raises on modern
        # NumPy), not a (2, *obs_shape) two-frame buffer.
        self.frame_buffer = np.zeros((2, *self.shape),
                                     dtype=env.observation_space.dtype)
        self.clip_reward = clip_reward
        self.no_ops = no_ops
        self.fire_first = fire_first

    def step(self, action):
        """Repeat `action`, summing rewards; return the max of the last
        two frames, the summed (optionally clipped) reward, done, info."""
        total_reward = 0.0
        done = False
        info = {}  # BUG FIX: `info` was unbound if repeat == 0
        for i in range(self.repeat):
            state, reward, done, info = self.env.step(action)
            # BUG FIX: the clip_reward option was accepted in __init__ but
            # the clipping code was commented out, silently ignoring it.
            if self.clip_reward:
                reward = np.clip(np.array([reward]), -1, 1)[0]
            total_reward += reward
            # Alternate between the two buffer slots so the last two
            # observed frames are always available for the max below.
            self.frame_buffer[i % 2] = state
            if done:
                break
        # Element-wise max over the two most recent frames removes the
        # sprite flicker that Atari games exhibit on alternating frames.
        max_frame = np.maximum(self.frame_buffer[0], self.frame_buffer[1])
        return max_frame, total_reward, done, info

    def reset(self):
        """Reset the env, apply random no-ops and optional FIRE, and seed
        the frame buffer with the resulting observation."""
        state = self.env.reset()
        no_ops = np.random.randint(self.no_ops) + 1 if self.no_ops > 0 else 0
        for _ in range(no_ops):
            _, _, done, _ = self.env.step(0)
            if done:
                # BUG FIX: keep the fresh post-reset observation instead of
                # discarding it and returning a stale frame.
                state = self.env.reset()
        if self.fire_first:
            assert self.env.unwrapped.get_action_meanings()[1] == 'FIRE'
            # BUG FIX: the post-FIRE observation was discarded (`obs`) while
            # the stale pre-FIRE `state` was returned; use the new frame.
            state, _, _, _ = self.env.step(1)
        # BUG FIX: same np.zeros_like misuse as in __init__.
        self.frame_buffer = np.zeros((2, *self.shape),
                                     dtype=self.env.observation_space.dtype)
        self.frame_buffer[0] = state
        return state


class PreprocessFrame(gym.ObservationWrapper):
    """Grayscale, resize, and rescale observations to float32 in [0, 1].

    `shape` is given channels-last (H, W, C); the wrapper stores and emits
    channels-first (C, H, W) frames.
    """

    def __init__(self, shape, env=None):
        super(PreprocessFrame, self).__init__(env)
        # Convert incoming (H, W, C) to channels-first (C, H, W).
        self.shape = (shape[2], shape[0], shape[1])
        self.observation_space = gym.spaces.Box(low=0.0, high=1.0,
                                                shape=self.shape,
                                                dtype=np.float32)

    def observation(self, obs):
        """Return the grayscale, resized, [0, 1]-scaled frame."""
        new_frame = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)
        # cv2.resize takes dsize as (width, height) — (W, H) from (C, H, W).
        resized_screen = cv2.resize(new_frame,
                                    (self.shape[2], self.shape[1]),
                                    interpolation=cv2.INTER_AREA)
        # BUG FIX: was cast to uint8, so the /255 below produced float64,
        # inconsistent with the float32 Box declared in __init__.
        new_obs = np.array(resized_screen,
                           dtype=np.float32).reshape(self.shape)
        new_obs = new_obs / 255
        return new_obs


class StackFrames(gym.ObservationWrapper):
    """Stack the `repeat` most recent observations along the channel axis."""

    def __init__(self, env, repeat):
        super(StackFrames, self).__init__(env)
        # Channel axis 0 grows by a factor of `repeat` (C*repeat, H, W).
        self.observation_space = gym.spaces.Box(
            env.observation_space.low.repeat(repeat, axis=0),
            env.observation_space.high.repeat(repeat, axis=0),
            dtype=np.float32)
        self.stack = collections.deque(maxlen=repeat)

    def reset(self):
        """Reset the env and fill the stack with copies of the first frame."""
        self.stack.clear()
        state = self.env.reset()
        for _ in range(self.stack.maxlen):
            self.stack.append(state)
        return np.array(self.stack).reshape(self.observation_space.low.shape)

    def observation(self, state):
        """Push `state` onto the stack and return the stacked observation."""
        self.stack.append(state)
        return np.array(self.stack).reshape(self.observation_space.low.shape)


def make_env(env_name, shape=(84, 84, 1), repeat=4, clip_rewards=False,
             no_ops=0, fire_first=False):
    """Build an Atari env with action repeat/frame max, grayscale resize,
    and frame stacking.

    Args:
        env_name: Gym environment id passed to gym.make.
        shape: desired frame shape, channels-last (H, W, C).
        repeat: action-repeat count, also the number of stacked frames.
        clip_rewards: clip per-step rewards to [-1, 1] when True.
        no_ops: max random no-op steps applied on reset (0 disables).
        fire_first: press FIRE once after reset when True.

    Returns:
        The fully wrapped environment.
    """
    env = gym.make(env_name)
    env = RepeatActionandMaxFrame(env, repeat, clip_rewards, no_ops,
                                  fire_first)
    env = PreprocessFrame(shape, env)
    env = StackFrames(env, repeat)
    return env