# Play Atari games with DQN

This notebook shows how to play image-based Atari games with DQN agents. The RAM-based variants of the games are left out, since their processing pipeline can be very different.

In [1]:
import gym
import numpy as np
import tensorflow as tf

In [2]:
def play_episode(env, policy, collect=False, buffer=None, render=False):
    """Play a single episode in ``env``, choosing actions with ``policy``.

    Args:
        env: Environment exposing ``reset()``, ``render()`` and ``step(action)``;
            ``step`` must return the 6-tuple produced by ``GameEnv.step``:
            ``(state, reward, done, info, unsanitized_state, unnormalized_reward)``.
        policy: Callable mapping the current state to an action.
        collect: Reserved for replay-buffer collection; not implemented yet,
            so it currently has no effect.
        buffer: Replay buffer that transitions will be collected into once
            collection is implemented; currently unused.
        render: If True, call ``env.render()`` before every step.

    Returns:
        The sum of the (normalized) rewards received during the episode.
    """
    state = env.reset()
    total_reward = 0
    done = False
    while not done:
        if render:
            env.render()
        action = policy(state)
        next_state, reward, done, _info, _raw_state, _raw_reward = env.step(action)
        total_reward += reward
        # TODO: when `collect` is True, store the transition
        # (state, action, reward, next_state, done) in `buffer`.
        state = next_state
    return total_reward

In [3]:
class DqnAgent:
    """Skeleton DQN agent; only the random policy does real work so far."""

    def __init__(self, action_space_size):
        # Number of discrete actions the agent may choose from.
        self.action_space_size = action_space_size

    def random_policy(self, state):
        """Return an action drawn uniformly at random; ``state`` is ignored."""
        n_actions = self.action_space_size
        return np.random.randint(n_actions)

    def net_policy(self, state):
        """Placeholder for the network-driven policy; always picks action 0 for now."""
        chosen = 0
        return chosen

    def collect_policy(self, state):
        """Placeholder for the data-collection policy; always picks action 0 for now."""
        chosen = 0
        return chosen

    def train(self, replay_buffer):
        """Placeholder training step; currently a no-op."""
        return None

In [4]:
def sanitize_state(state):
    """Turn a raw environment observation into the agent's state representation.

    Currently the identity: the observation is handed back unchanged.
    """
    sanitized = state
    return sanitized

In [5]:
def normalize_reward(reward):
    """Map a raw environment reward onto the scale used for training.

    Currently the identity: the reward is handed back unchanged.
    """
    normalized = reward
    return normalized

In [10]:
class GameEnv:
    """
    Although the OpenAI gym library provides most of the
    things we need, the states and rewards it emits are
    too primitive to be used directly. To keep the state and
    reward conversion consistent, we wrap the OpenAI gym
    environment in a shim class that routes every state
    through `sanitize_state` and every reward through
    `normalize_reward`.
    """
    def __init__(self, game_id):
        # game_id is passed straight to gym.make (e.g. 'Breakout-v0').
        self.game_id = game_id
        self.env = gym.make(game_id)

    def get_action_space_size(self):
        """Number of available actions (assumes a Discrete action space exposing `.n`)."""
        return self.env.action_space.n

    def render(self):
        self.env.render()

    def reset(self):
        """Start a new episode and return its sanitized initial state.

        The initial observation goes through `sanitize_state` so it has the
        same representation as the states returned by `step`.
        """
        return sanitize_state(self.env.reset())

    def close(self):
        self.env.close()

    def step(self, action):
        """Advance the wrapped environment by one action.

        Expects the classic gym API where `env.step` returns the 4-tuple
        `(observation, reward, done, info)`.

        Returns:
            `(state, reward, done, info, unsanitized_state, unnormalized_reward)`;
            the raw observation and reward are returned as well for debugging.
        """
        unsanitized_state, unnormalized_reward, done, info = self.env.step(action)
        state = sanitize_state(unsanitized_state)
        reward = normalize_reward(unnormalized_reward)
        return state, reward, done, info, unsanitized_state, unnormalized_reward

In [7]:
def train_agent(max_eps=5, game_id='Breakout-v0', render=True):
    """Play `max_eps` episodes of `game_id` with the agent's random policy.

    Actual training is not wired up yet; this currently exercises the
    environment/agent plumbing end to end.

    Args:
        max_eps: Number of episodes to play.
        game_id: Gym environment id passed to `GameEnv`.
        render: Whether to render the environment while playing.
    """
    env = GameEnv(game_id)
    try:
        agent = DqnAgent(env.get_action_space_size())
        for _ in range(max_eps):
            play_episode(env, agent.random_policy, render=render)
    finally:
        # Always release the environment, even if an episode raises.
        env.close()
    print('Done')

In [8]:
# Smoke-test the pipeline: play five episodes with the random policy.
train_agent(max_eps=5)

Done
