Permalink
Browse files

DQN supports gym as well.

  • Loading branch information...
ppwwyyxx committed Sep 17, 2018
1 parent 3aab66f commit 0dbcbac7ae5dc62f89ecdbe0961731513c11f2ec
@@ -8,18 +8,19 @@
import cv2
import numpy as np
import tensorflow as tf
import gym
from tensorpack import *
from DQNModel import Model as DQNModel
from common import Evaluator, eval_model_multithread, play_n_episodes
from atari_wrapper import FrameStack, MapState, FireResetEnv
from atari_wrapper import FrameStack, MapState, FireResetEnv, LimitLength
from expreplay import ExpReplay
from atari import AtariPlayer
BATCH_SIZE = 64
IMAGE_SIZE = (84, 84)
IMAGE_CHANNEL = None # 3 in gym and 1 in our own wrapper
FRAME_HISTORY = 4
ACTION_REPEAT = 4 # aka FRAME_SKIP
UPDATE_FREQ = 4
@@ -33,24 +34,39 @@
EVAL_EPISODE = 50
NUM_ACTIONS = None
ROM_FILE = None
USE_GYM = False
ENV_NAME = None
METHOD = None
def resize_keepdims(im, size):
# Opencv's resize remove the extra dimension for grayscale images.
# We add it back.
ret = cv2.resize(im, size)
if im.ndim == 3 and ret.ndim == 2:
ret = ret[:, :, np.newaxis]
return ret
def get_player(viz=False, train=False):
env = AtariPlayer(ROM_FILE, frame_skip=ACTION_REPEAT, viz=viz,
live_lost_as_eoe=train, max_num_frames=60000)
if USE_GYM:
env = gym.make(ENV_NAME)
else:
env = AtariPlayer(ENV_NAME, frame_skip=ACTION_REPEAT, viz=viz,
live_lost_as_eoe=train, max_num_frames=60000)
env = FireResetEnv(env)
env = MapState(env, lambda im: cv2.resize(im, IMAGE_SIZE)[:, :, np.newaxis])
env = MapState(env, lambda im: resize_keepdims(im, IMAGE_SIZE))
if not train:
# in training, history is taken care of in expreplay buffer
env = FrameStack(env, FRAME_HISTORY)
if train and USE_GYM:
env = LimitLength(env, 60000)
return env
class Model(DQNModel):
def __init__(self):
super(Model, self).__init__(IMAGE_SIZE, 1, FRAME_HISTORY, METHOD, NUM_ACTIONS, GAMMA)
super(Model, self).__init__(IMAGE_SIZE, IMAGE_CHANNEL, FRAME_HISTORY, METHOD, NUM_ACTIONS, GAMMA)
def _get_DQN_prediction(self, image):
image = image / 255.0
@@ -86,7 +102,7 @@ def get_config():
expreplay = ExpReplay(
predictor_io_names=(['state'], ['Qvalue']),
player=get_player(train=True),
state_shape=IMAGE_SIZE + (1,),
state_shape=IMAGE_SIZE + (IMAGE_CHANNEL,),
batch_size=BATCH_SIZE,
memory_size=MEMORY_SIZE,
init_memory_size=INIT_MEMORY_SIZE,
@@ -126,18 +142,21 @@ def get_config():
parser.add_argument('--load', help='load model')
parser.add_argument('--task', help='task to perform',
choices=['play', 'eval', 'train'], default='train')
parser.add_argument('--rom', help='atari rom', required=True)
parser.add_argument('--env', required=True,
help='either an atari rom file (that ends with .bin) or a gym atari environment name')
parser.add_argument('--algo', help='algorithm',
choices=['DQN', 'Double', 'Dueling'], default='Double')
args = parser.parse_args()
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
ROM_FILE = args.rom
ENV_NAME = args.env
USE_GYM = not ENV_NAME.endswith('.bin')
IMAGE_CHANNEL = 3 if USE_GYM else 1
METHOD = args.algo
# set num_actions
NUM_ACTIONS = AtariPlayer(ROM_FILE).action_space.n
logger.info("ROM: {}, Num Actions: {}".format(ROM_FILE, NUM_ACTIONS))
NUM_ACTIONS = get_player().action_space.n
logger.info("ENV: {}, Num Actions: {}".format(ENV_NAME, NUM_ACTIONS))
if args.task != 'train':
assert args.load is not None
@@ -153,7 +172,7 @@ def get_config():
else:
logger.set_logger_dir(
os.path.join('train_log', 'DQN-{}'.format(
os.path.basename(ROM_FILE).split('.')[0])))
os.path.basename(ENV_NAME).split('.')[0])))
config = get_config()
if args.load:
config.session_init = get_model_loader(args.load)
@@ -26,6 +26,7 @@ Double-DQN with nature paper setting runs at 60 batches (3840 trained frames, 24
## How to use
### With ALE (paper's setting):
Install [ALE](https://github.com/mgbellemare/Arcade-Learning-Environment) and gym.
Download an [atari rom](https://github.com/openai/atari-py/tree/master/atari_py/atari_roms), e.g.:
@@ -35,15 +36,23 @@ wget https://github.com/openai/atari-py/raw/master/atari_py/atari_roms/breakout.
Start Training:
```
./DQN.py --rom breakout.bin
./DQN.py --env breakout.bin
# use `--algo` to select other DQN algorithms. See `-h` for more options.
```
Watch the agent play:
```
# Download pretrained models or use one you trained:
wget http://models.tensorpack.com/DeepQNetwork/DoubleDQN-Breakout.npz
./DQN.py --rom breakout.bin --task play --load DoubleDQN-Breakout.npz
./DQN.py --env breakout.bin --task play --load DoubleDQN-Breakout.npz
```
### With gym's Atari:
Install gym and atari_py.
```
./DQN.py --env BreakoutDeterministic-v4
```
A3C code and models for Atari games in OpenAI Gym are released in [examples/A3C-Gym](../A3C-Gym)
@@ -95,7 +95,7 @@ def __init__(self, rom_file, viz=0,
self.action_space = spaces.Discrete(len(self.actions))
self.observation_space = spaces.Box(
low=0, high=255, shape=(self.height, self.width), dtype=np.uint8)
low=0, high=255, shape=(self.height, self.width, 1), dtype=np.uint8)
self._restart_episode()
def get_action_meanings(self):
@@ -110,7 +110,7 @@ def _grab_raw_image(self):
def _current_state(self):
"""
:returns: a gray-scale (h, w) uint8 image
:returns: a gray-scale (h, w, 1) uint8 image
"""
ret = self._grab_raw_image()
# max-pooled over the last screen
@@ -121,7 +121,7 @@ def _current_state(self):
cv2.waitKey(int(self.viz * 1000))
ret = ret.astype('float32')
# 0.299,0.587.0.114. same as rgb2y in torch/image
ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)
ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)[:, :, np.newaxis]
return ret.astype('uint8') # to save some memory
def _restart_episode(self):
@@ -135,7 +135,7 @@ def __init__(self,
Args:
predictor_io_names (tuple of list of str): input/output names to
predict Q value from state.
player (RLEnvironment): the player.
player (gym.Env): the player.
state_shape (tuple): h, w, c
history_len (int): length of history frames to concat. Zero-filled
initial frames.
@@ -30,7 +30,7 @@ def fw(x):
@tf.custom_gradient
def _sign(x):
return tf.sign(x / E) * E, lambda dy: dy
return tf.where(tf.equal(x, 0), tf.ones_like(x), tf.sign(x / E)) * E, lambda dy: dy
return _sign(x)

0 comments on commit 0dbcbac

Please sign in to comment.