/
env.py
108 lines (91 loc) · 3.39 KB
/
env.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
"""
@author: Viet Nguyen <nhviet1009@gmail.com>
"""
import gym_super_mario_bros
from gym.spaces import Box
from gym import Wrapper
from nes_py.wrappers import JoypadSpace
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT, COMPLEX_MOVEMENT, RIGHT_ONLY
import cv2
import numpy as np
import subprocess as sp
class Monitor:
    """Stream raw RGB frames to an ``ffmpeg`` subprocess that encodes a video.

    Args:
        width: Frame width in pixels, passed to ffmpeg's ``-s`` option.
        height: Frame height in pixels, passed to ffmpeg's ``-s`` option.
        saved_path: Output file path given to ffmpeg (overwritten, ``-y``).

    Attributes:
        command: The ffmpeg argument list that is (or would have been) run.
        pipe: The ``subprocess.Popen`` handle, or ``None`` when the ``ffmpeg``
            executable could not be found, in which case recording is disabled.
    """

    def __init__(self, width, height, saved_path):
        self.command = ["ffmpeg", "-y", "-f", "rawvideo", "-vcodec", "rawvideo", "-s", "{}X{}".format(width, height),
                        "-pix_fmt", "rgb24", "-r", "80", "-i", "-", "-an", "-vcodec", "mpeg4", saved_path]
        # Always define the attribute so record()/close() can test it even
        # when ffmpeg is not installed.
        self.pipe = None
        try:
            # stderr is discarded rather than piped: nothing ever reads it, and
            # an undrained pipe can fill up and block the encoder (and then
            # our writes to its stdin).
            self.pipe = sp.Popen(self.command, stdin=sp.PIPE, stderr=sp.DEVNULL)
        except FileNotFoundError:
            # Best effort: without ffmpeg we simply do not record.
            pass

    def record(self, image_array):
        """Write one frame's raw bytes to the encoder; no-op if recording is disabled."""
        if self.pipe is None:
            return
        # tobytes() replaces the deprecated ndarray.tostring() alias.
        self.pipe.stdin.write(image_array.tobytes())

    def close(self):
        """Close the encoder's stdin and wait for it to finish writing the file."""
        if self.pipe is None:
            return
        self.pipe.stdin.close()
        self.pipe.wait()
        self.pipe = None
def process_frame(frame):
    """Turn an RGB frame into a normalized grayscale array of shape (1, 84, 84).

    Args:
        frame: Image array accepted by ``cv2.cvtColor`` with
            ``COLOR_RGB2GRAY``, or ``None``.

    Returns:
        The grayscale frame resized to 84x84, given a leading axis of length 1
        and divided by 255. When ``frame`` is ``None`` an all-zero array of
        shape (1, 84, 84) is returned instead.
    """
    if frame is None:
        return np.zeros((1, 84, 84))
    gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
    small = cv2.resize(gray, (84, 84))
    # Add the leading channel axis, then scale pixel values by 1/255.
    return small[None, :, :] / 255.
class CustomReward(Wrapper):
    """Wrapper that preprocesses frames and reshapes the reward signal.

    Each step's frame is optionally forwarded to a monitor, then passed
    through ``process_frame``. The reward is increased by the change in
    ``info["score"]`` divided by 40, adjusted by +50 / -50 at episode end
    depending on ``info["flag_get"]``, and finally divided by 10.

    Args:
        env: The environment to wrap.
        monitor: Optional object with a ``record(frame)`` method that receives
            every raw frame returned by ``step``.
    """

    def __init__(self, env=None, monitor=None):
        super().__init__(env)
        # NOTE(review): bounds are declared as 0..255 although process_frame
        # divides pixel values by 255 - confirm which range consumers expect.
        self.observation_space = Box(low=0, high=255, shape=(1, 84, 84))
        self.curr_score = 0
        self.monitor = monitor if monitor else None

    def step(self, action):
        """Advance the wrapped env one step and return the shaped transition."""
        raw_frame, reward, done, info = self.env.step(action)
        if self.monitor:
            # The monitor gets the unprocessed frame.
            self.monitor.record(raw_frame)
        observation = process_frame(raw_frame)

        # Bonus proportional to the score gained since the previous step.
        score = info["score"]
        reward += (score - self.curr_score) / 40.
        self.curr_score = score

        if done:
            # Terminal bonus when the flag was reached, penalty otherwise.
            reward += 50 if info["flag_get"] else -50
        return observation, reward / 10., done, info

    def reset(self):
        """Reset the score tracker and return the processed initial frame."""
        self.curr_score = 0
        initial_frame = self.env.reset()
        return process_frame(initial_frame)
class CustomSkipFrame(Wrapper):
    """Repeat each action for ``skip`` steps and stack the resulting frames.

    Args:
        env: The environment to wrap. Its observations are concatenated along
            axis 0, so each one is expected to have a leading axis
            (the declared observation space assumes (1, 84, 84) frames).
        skip: Number of times each action is repeated, and therefore the
            number of frames stacked into one observation. Must be >= 1.

    Raises:
        ValueError: If ``skip`` is smaller than 1.
    """

    def __init__(self, env, skip=4):
        super(CustomSkipFrame, self).__init__(env)
        if skip < 1:
            raise ValueError("skip must be >= 1, got {}".format(skip))
        # The channel count follows `skip` instead of a hard-coded 4, so the
        # declared space matches what step()/reset() actually return.
        self.observation_space = Box(low=0, high=255, shape=(skip, 84, 84))
        self.skip = skip

    def step(self, action):
        """Apply ``action`` up to ``skip`` times.

        Returns:
            A tuple ``(states, total_reward, done, info)`` where ``states`` is
            a float32 array with a leading batch axis of length 1,
            ``total_reward`` is the sum of the rewards of every step that was
            actually executed, and ``done`` / ``info`` come from the last
            executed step.
        """
        total_reward = 0.0
        frames = []
        state, done, info = None, False, {}
        for _ in range(self.skip):
            if not done:
                # Exactly `skip` env steps are taken per call (fewer if the
                # episode ends early); every reward is accumulated.
                state, reward, done, info = self.env.step(action)
                total_reward += reward
            # After the episode has ended the last frame is repeated so the
            # stacked observation always has `skip` frames.
            frames.append(state)
        states = np.concatenate(frames, 0)[None, :, :, :]
        return states.astype(np.float32), total_reward, done, info

    def reset(self):
        """Reset the wrapped env and return its first frame repeated ``skip`` times."""
        state = self.env.reset()
        states = np.concatenate([state] * self.skip, 0)[None, :, :, :]
        return states.astype(np.float32)
def create_train_env(world, stage, action_type, output_path=None):
    """Build the wrapped Super Mario Bros environment for one world/stage.

    Args:
        world: World identifier inserted into the environment id.
        stage: Stage identifier inserted into the environment id.
        action_type: ``"right"`` selects ``RIGHT_ONLY``, ``"simple"`` selects
            ``SIMPLE_MOVEMENT``; any other value selects ``COMPLEX_MOVEMENT``.
        output_path: If truthy, a ``Monitor`` for 256x240 frames writing to
            this path is attached to the environment.

    Returns:
        A tuple ``(env, num_states, num_actions)`` where ``num_states`` is the
        first dimension of the wrapped observation space and ``num_actions``
        is the size of the selected action set.
    """
    base_env = gym_super_mario_bros.make("SuperMarioBros-{}-{}-v0".format(world, stage))
    monitor = Monitor(256, 240, output_path) if output_path else None

    actions = (
        RIGHT_ONLY if action_type == "right"
        else SIMPLE_MOVEMENT if action_type == "simple"
        else COMPLEX_MOVEMENT
    )

    # Wrapper order: restrict the joypad, shape rewards/frames, then frame-skip.
    env = CustomSkipFrame(CustomReward(JoypadSpace(base_env, actions), monitor))
    num_states = env.observation_space.shape[0]
    return env, num_states, len(actions)