DQN with experience replay.

In [1]:
import os
import logging
import sys
import random
import numpy as np
import random
import inspect
import cv2

from datetime import datetime
from abc import ABC, abstractmethod
from collections import deque

import tensorflow as tf

from keras.callbacks import TensorBoard
from keras.models import load_model
from keras.layers import Dense, Conv2D, Flatten
from keras.models import Sequential
from keras.optimizers import RMSprop

from atari_py import ALEInterface, get_game_path, list_games


Using TensorFlow backend.


In [2]:
class EnvManager(ABC):
    @abstractmethod
    def get_legal_actions(self):
        pass

    @abstractmethod
    def get_random_action(self):
        pass

    @abstractmethod
    def initialize_input_sequence(self):
        pass

    @abstractmethod
    def execute_action(self, action):
        pass

    @abstractmethod
    def is_game_over(self):
        pass

    @abstractmethod
    def get_observation_shape(self):
        pass


In [3]:

class ALEManager(EnvManager):

    def __init__(self, rom_name='Space_Invaders.bin', display_screen=False, frame_skip=3, color_averaging=True):
        self.logger = logging.getLogger(__name__)

        self.ale = ALEInterface()
        self.ale.setBool(b'display_screen', display_screen)
        self.ale.setInt(b'frame_skip', frame_skip)
        self.ale.setBool(b'color_averaging', color_averaging)
        self._load_rom(rom_name)
        self.actions = self.ale.getMinimalActionSet()
        self.sequence = np.empty(shape=(84, 84, 4), dtype=np.uint8)

    def _load_rom(self, rom_name):
        if rom_name in list_games():
            self.ale.loadROM(get_game_path(rom_name))
            return

        rom_path = os.path.join(os.path.dirname(os.path.abspath('__file__')), 'ROMs', rom_name)
        if not os.path.exists(rom_path):
            self.logger.error("Invalid ROM path")
            sys.exit(1)

        self.ale.loadROM(bytes(rom_path, encoding='utf-8'))

    def _map_action(self, action):
        return self.actions[action]

    def get_legal_actions(self):
        return np.arange(len(self.actions), dtype=np.int32)

    def get_random_action(self):
        return random.choice(self.get_legal_actions())

    def initialize_input_sequence(self):
        self.ale.reset_game()
        screen = np.empty((210, 160), dtype=np.uint8)
        for i in range(4):
            self.ale.act(self._map_action(self.get_random_action()))
            self.ale.getScreenGrayscale(screen)
            preprocessed_screen = self.preprocess_screen(screen)
            self.sequence[:, :, i] = preprocessed_screen
        return self.sequence

    @staticmethod
    def preprocess_screen(screen):
        resized_screen = cv2.resize(screen, dsize=(84, 110), interpolation=cv2.INTER_AREA)
        cropped_screen = resized_screen[17:110 - 9, :]
        return cropped_screen

    def execute_action(self, action):
        """Executes the action given as parameter and returns a
        reward and a sequence of length 4 containing preprocessed screens."""
        screen = np.empty((210, 160), dtype=np.uint8)
        reward = self.ale.act(self._map_action(action))
        self.ale.getScreenGrayscale(screen)
        preprocessed_screen = self.preprocess_screen(screen)
        self.sequence[:, :, :3] = self.sequence[:, :, 1:]
        self.sequence[:, :, -1] = preprocessed_screen

        return reward, self.sequence

    def is_game_over(self):
        return self.ale.game_over()

    def get_observation_shape(self):
        return (84, 84, 4)



In [4]:
class DQN(object):
    def __init__(self, input_shape, output_units, save_model_dir='models', save_model_name='model.h5',
                 load_model_dir=None, load_model_name=None):
        self.input_shape = input_shape
        self.output_units = output_units
        self.save_model_dir = save_model_dir
        self.save_model_name = save_model_name
        self.load_model_dir = load_model_dir
        self.load_model_name = load_model_name
        self.model = self._load_model()

        if not os.path.exists(self.save_model_dir):
            os.makedirs(self.save_model_dir)

    def _load_model(self):
        if self.load_model_dir is None or self.load_model_name is None:
            print("Creating new neural-network")
            return self.get_q_network()

        model_name = os.path.join(self.load_model_dir, self.load_model_name)

        if os.path.exists(model_name):
            print("Loading existing model, " + str(model_name))
            return load_model(model_name)

        raise Exception("Model could not be loaded.")

    @abstractmethod
    def get_q_network(self):
        pass

    def get_prediction(self, preprocessed_input):
        return self.model.predict(np.expand_dims(preprocessed_input, 0))[0]

    def get_predicted_action(self, preprocessed_input):
        return np.argmax(self.get_prediction(preprocessed_input))

    def prepare_minibatch(self, transitions_minibatch, gamma):
        expected_output_minibatch = []
        input_minibatch = []

        for current_input, action, reward, next_input, is_terminal_state in transitions_minibatch:
            q_value = reward
            if not is_terminal_state:
                q_value += gamma * np.amax(self.get_prediction(next_input))
            prediction = self.get_prediction(current_input)
            prediction[action] = q_value
            expected_output_minibatch.append(prediction)
            input_minibatch.append(current_input)

        expected_output_minibatch = np.array(expected_output_minibatch)
        input_minibatch = np.array(input_minibatch)

        return input_minibatch, expected_output_minibatch

    def perform_gradient_descent_step(self, _input, _output):
        self.model.fit(x=_input, y=_output, epochs=1, verbose=0)

    def save_model(self, step=''):
        model_name = os.path.join(self.save_model_dir, (str(step) + '--' + self.save_model_name))
        self.model.save(model_name)


In [5]:
class DQNSpaceInvaders(DQN):

    def __init__(self, input_shape, output_units, save_model_dir="models/space_invaders", save_model_name='model.h5',
                 load_model_dir=None, load_model_name=None):
        super().__init__(input_shape, output_units, save_model_dir=save_model_dir, save_model_name=save_model_name,
                         load_model_dir=load_model_dir, load_model_name=load_model_name)

    def get_q_network(self):
        model = Sequential()
        model.add(
            Conv2D(filters=16, kernel_size=(8, 8), strides=(4, 4), input_shape=self.input_shape, activation='relu'))
        model.add(Conv2D(filters=32, kernel_size=(4, 4), strides=(2, 2), activation='relu'))
        model.add(Flatten())
        model.add(Dense(units=256, activation='relu'))
        model.add(Dense(units=self.output_units))

        model.compile(loss="mse", optimizer=RMSprop())

        return model


In [6]:
class DQNBreakout(DQNSpaceInvaders):
    def __init__(self, input_shape, output_units, save_model_dir, save_model_name, load_model_dir, load_model_name):
        super().__init__(input_shape, output_units, save_model_dir, save_model_name, load_model_dir, load_model_name)

In [7]:

class DeepQLearningAgent(object):
    def __init__(self, env_manager=ALEManager, q_network=DQNSpaceInvaders, num_total_episode=10000,
                 episode_starts_from=0, epsilon_decay_rate=9.000000000000001e-07, save_model_step=100, epsilon=1.,
                 logdir=None):
        self.minibatch_size = 32
        self.experience_replay_memory = deque([], maxlen=1000000)
        self.env_manager = env_manager() if inspect.isclass(env_manager) else env_manager
        self.possible_actions = self.env_manager.get_legal_actions()
        self.input_shape = self.env_manager.get_observation_shape()
        self.output_units = len(self.possible_actions)
        self.DQN = q_network(input_shape=self.input_shape, output_units=self.output_units) if inspect.isclass(
            q_network) else q_network
        self.epsilon = float(epsilon)
        self.gamma = 0.9
        self.num_total_episode = num_total_episode
        self.n_episode = episode_starts_from
        self.epsilon_decay_rate = epsilon_decay_rate
        logdir = "logs/scalars/" + datetime.now().strftime("%Y%m%d-%H%M%S") if logdir is None else logdir
        self.file_writer = tf.summary.create_file_writer(logdir=logdir)
        self.save_model_step = save_model_step

    def update_epsilon(self):
        if self.epsilon < 0.1:
            self.epsilon = 0.1
            return
        elif self.epsilon == 0.1:
            return
        else:
            self.epsilon -= self.epsilon_decay_rate

    def e_greedy_select_action(self, preprocessed_input):
        if random.random() <= self.epsilon:
            action = self.env_manager.get_random_action()
        else:
            action = self.DQN.get_predicted_action(preprocessed_input)

        self.update_epsilon()

        return action

    def learn_with_experience_replay(self):
        """vanilla deep_q_learning_with_experience_replay"""
        while self.n_episode < self.num_total_episode:
            preprocessed_input = self.env_manager.initialize_input_sequence()
            cumulative_reward = 0
            episode_q_value_list = []
            while not self.env_manager.is_game_over():
                action = self.e_greedy_select_action(preprocessed_input)
                reward, next_preprocessed_input = self.env_manager.execute_action(action)

                cumulative_reward += reward
                q_value_for_selected_action = self.DQN.get_prediction(preprocessed_input)[action]
                episode_q_value_list.append(q_value_for_selected_action)

                self.experience_replay_memory.append(
                    (preprocessed_input, action, reward, next_preprocessed_input, self.env_manager.is_game_over()))

                preprocessed_input = next_preprocessed_input

                if len(self.experience_replay_memory) > self.minibatch_size:
                    sample_minibatch = random.sample(self.experience_replay_memory, k=self.minibatch_size)
                    _input, _output = self.DQN.prepare_minibatch(sample_minibatch, self.gamma)
                    self.DQN.perform_gradient_descent_step(_input, _output)

            avg_q_value_per_action = sum(episode_q_value_list) / float(len(episode_q_value_list))

            with self.file_writer.as_default():
                tf.summary.scalar('Return per episode', cumulative_reward, step=self.n_episode)
                tf.summary.scalar('Average q_value', avg_q_value_per_action, step=self.n_episode)
                tf.summary.scalar('epsilon', self.epsilon, step=self.n_episode)
                tf.summary.flush()

            if ((self.n_episode + 1) % self.save_model_step) == 0:
                self.DQN.save_model('-episode:' + str(self.n_episode + 1) + '-epsilon:' + str(self.epsilon))

            self.n_episode += 1



In [8]:
model_dir = "models/breakout/"
if not os.path.exists(model_dir):
    print(model_dir + " does not exist.")

In [9]:
logdir = "logs/scalars/breakout/"
if not os.path.exists(logdir):
    print(logdir + " does not exist.")

In [10]:
env_manager = ALEManager(rom_name='breakout', frame_skip=4)

save_model_dir = model_dir
save_model_name="breakout_dqn.h5"
load_model_dir = model_dir
load_model_name = '-episode_5100-epsilon_0.13606480000951018--breakout_dqn.h5'
episode_start_from = 5100 + 1
epsilon = 0.13606480000951018
input_shape=env_manager.get_observation_shape()
output_units = len(env_manager.get_legal_actions())
q_network = DQNBreakout(input_shape=input_shape, output_units=output_units, save_model_dir=save_model_dir, save_model_name=save_model_name, load_model_dir=load_model_dir, load_model_name=load_model_name)

num_total_episode = 10000
epsilon_decay_rate=9.000000000000001e-07
save_model_step=100


Loading existing model, models/breakout/-episode_5100-epsilon_0.13606480000951018--breakout_dqn.h5


In [11]:
breakout_agent = DeepQLearningAgent(env_manager=env_manager, q_network=q_network, num_total_episode=num_total_episode, epsilon_decay_rate=epsilon_decay_rate, save_model_step=save_model_step, epsilon=epsilon, logdir=logdir, episode_starts_from=episode_start_from)

In [12]:
breakout_agent.learn_with_experience_replay()

KeyboardInterrupt: 