# Install dependencies


In [None]:
  # System packages installed through apt: xvfb and ffmpeg.
  # NOTE(review): presumably xvfb backs the pyvirtualdisplay display started
  # later and ffmpeg is used for the mp4 output - confirm.
  !sudo apt-get install -y xvfb ffmpeg
  # Python packages. gym, imageio and pyglet are pinned to exact versions;
  # the others install whatever version pip resolves.
  !pip install -q 'gym==0.10.11'
  !pip install -q 'imageio==2.4.0'
  !pip install -q PILLOW
  !pip install -q 'pyglet==1.3.2'
  !pip install -q pyvirtualdisplay
  !pip install -q --upgrade tensorflow-probability
  !pip install -q tf-agents

# Imports

In [None]:
import base64
import os
import shutil
import imageio
import IPython
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import PIL.Image
import pyvirtualdisplay

import tensorflow as tf

from tf_agents.agents.dqn import dqn_agent
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import suite_atari, suite_gym, tf_py_environment, batched_py_environment, parallel_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import q_network
from tf_agents.policies import random_py_policy, policy_saver, random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.trajectories import trajectory
from tf_agents.utils import common

from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts

In [None]:
# Start a 1400x900 virtual display (visible=0) and keep the started handle.
# NOTE(review): presumably required so env.render() works on a headless
# machine - confirm.
display = pyvirtualdisplay.Display(visible=0, size=(1400,900)).start()

In [None]:
# Mount Google Drive at /gdrive; the checkpoint and policy directories
# configured in this notebook are paths underneath that mount point.
from google.colab import drive
drive.mount('/gdrive')

In [None]:
# Output locations on the mounted Google Drive:
#   checkpoint_dir - handed to the training checkpointer,
#   policy_dir     - target of the policy saver.
checkpoint_dir = '/gdrive/My Drive/Atari_DQN_Data/checkpoint/'
policy_dir = '/gdrive/My Drive/Atari_DQN_Data/policy/'

# exist_ok=True creates missing directories (including parents) and is a
# no-op when they already exist, which avoids the check-then-create race of
# a separate os.path.exists() test.
os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(policy_dir, exist_ok=True)

# Hyperparameters

In [None]:
# Number of iterations of the training loop (one collect + one train call each).
num_iterations = 250000

# Steps gathered by the random policy before training starts.
initial_collect_steps = 200
# Steps gathered by the collect driver on every training iteration.
collect_steps_per_iteration = 10
# Maximum length passed to the replay buffer.
replay_buffer_max_length = 100000

# Sample batch size used when reading the replay buffer as a dataset.
batch_size = 32
# Learning rate handed to the RMSProp optimizer.
learning_rate = 2.5e-3
# The loss is printed whenever the train step is a multiple of this value.
log_interval = 5000

# Episodes averaged per evaluation.
num_eval_episodes = 10
# An evaluation runs whenever the train step is a multiple of this value.
eval_interval = 25000

In [None]:
# Gym id of the Atari game to load.
env_name = 'Pong-v0'

# Divisor that converts the frame budget below into an agent-step limit.
ATARI_FRAME_SKIP = 4

# Per-episode budget expressed in frames.
max_episode_frames = 108000

# Environment used only for the preview in the next cell.
# max_episode_steps uses true division, so the limit passed is a float (27000.0).
env = suite_atari.load(
    env_name,
    max_episode_steps = max_episode_frames / ATARI_FRAME_SKIP,
    gym_env_wrappers = suite_atari.DEFAULT_ATARI_GYM_WRAPPERS_WITH_STACKING
)

# Take a peek at the environment

In [None]:
# Reset the preview environment, then render one frame. The bare expression on
# the last line is what the notebook displays as the cell output.
time_step = env.reset()
PIL.Image.fromarray(env.render())

# Create environments for training and evaluation.

In [None]:
# Two separate Python environments with identical settings: one for collecting
# training experience, one for evaluation.
train_py_env = suite_atari.load(
    env_name,
    max_episode_steps = max_episode_frames / ATARI_FRAME_SKIP,
    gym_env_wrappers = suite_atari.DEFAULT_ATARI_GYM_WRAPPERS_WITH_STACKING
)

eval_py_env = suite_atari.load(
    env_name,
    max_episode_steps = max_episode_frames / ATARI_FRAME_SKIP,
    gym_env_wrappers = suite_atari.DEFAULT_ATARI_GYM_WRAPPERS_WITH_STACKING
)

# Wrap each Python environment in a TFPyEnvironment; the agent, drivers and
# evaluation code below use these wrapped versions.
train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

# Create Neural Network


In [None]:
class AtariQNetwork(q_network.QNetwork):
    """QNetwork variant that rescales observations before the forward pass.

    The incoming observation is cast to float32 and divided by 255, then handed
    unchanged in every other respect to `QNetwork.call`.
    NOTE(review): presumably the observations are 8-bit pixel frames, making
    the result lie in [0, 1] - confirm against the environment spec.
    """

    def call(self, observation, step_type=None, network_state=(), training=False):
        """Scale `observation` and delegate to the parent network.

        Args:
            observation: input tensor; cast to float32 and divided by 255.
            step_type: forwarded to the parent `call` untouched.
            network_state: forwarded to the parent `call` untouched.
            training: forwarded to the parent `call` untouched.

        Returns:
            Whatever `QNetwork.call` returns for the scaled observation.
        """
        # Cast first so the division happens in floating point.
        as_float = tf.cast(observation, tf.float32)
        scaled = as_float / 255

        parent = super(AtariQNetwork, self)
        return parent.call(
            scaled,
            step_type=step_type,
            network_state=network_state,
            training=training,
        )

In [None]:
# One fully connected layer of 512 units after the convolutional stack.
fc_layer_params = (512,)
# Three convolution layers; each tuple is (filters, kernel_size, stride) as
# described in the QNetwork documentation.
conv_layer_params = ((32, (8,8), 4), (64, (4,4),2), (64, (3,3), 1))

# Q-network sized from the training environment's observation and action specs.
q_net = AtariQNetwork(
    input_tensor_spec = train_env.observation_spec(),
    action_spec = train_env.action_spec(),
    conv_layer_params=conv_layer_params,
    fc_layer_params=fc_layer_params
)

# Optimizer

In [None]:
# TF1-compat RMSProp optimizer that the DQN agent uses to update the Q-network.
# The step size comes from the `learning_rate` hyperparameter; the centered
# variant is selected and momentum is disabled (0.0).
optimizer = tf.compat.v1.train.RMSPropOptimizer(
    learning_rate=learning_rate,
    decay = 0.95,
    momentum=0.0,
    epsilon=0.00001,
    centered = True
)

# Agent

In [None]:
# Specs describing what the training environment emits and accepts.
observation_spec = train_env.observation_spec()
time_step_spec = train_env.time_step_spec()

action_spec = train_env.action_spec()

# Passed to the agent as target_update_period (train steps between target
# network updates).
target_update_period = 2000

# Shared step counter; the agent increments it on every train call and the
# checkpointer stores it.
global_step = tf.compat.v1.train.get_or_create_global_step()

# DQN agent built around the Q-network and optimizer defined above.
agent = dqn_agent.DqnAgent(
    time_step_spec = time_step_spec,
    action_spec = action_spec,
    q_network = q_net,
    optimizer = optimizer,
    epsilon_greedy = 0.01,
    # n_step_update is a count of steps, so pass an integer; a float here
    # would make the derived train sequence length a float as well.
    n_step_update = 1,
    # tau = 1.0 together with target_update_period means the target network
    # is overwritten with the online weights every `target_update_period`
    # train steps (see the DqnAgent documentation).
    target_update_tau = 1.0,
    target_update_period = target_update_period,
    td_errors_loss_fn = common.element_wise_huber_loss,
    gamma = 0.99,
    reward_scale_factor = 1.0,
    gradient_clipping = None,
    debug_summaries = False,
    summarize_grads_and_vars = False,
    train_step_counter = global_step
)

agent.initialize()

# Replay buffer

In [None]:
# Replay buffer holding items shaped like the agent's collect_data_spec. Its
# batch size follows the training environment and its maximum length comes from
# the replay_buffer_max_length hyperparameter.
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec = agent.collect_data_spec,
    batch_size = train_env.batch_size,
    max_length = replay_buffer_max_length
)

In [None]:
# Driver used inside the training loop: each run() steps train_env with the
# agent's collect policy for collect_steps_per_iteration steps and passes the
# results to replay_buffer.add_batch.
collect_driver = dynamic_step_driver.DynamicStepDriver(
    env = train_env,
    policy = agent.collect_policy,
    observers=[replay_buffer.add_batch],
    num_steps = collect_steps_per_iteration
)

In [None]:
# Random policy built from the training environment's specs. It seeds the
# replay buffer here and is reused later for the "random-agent" video.
random_policy = random_tf_policy.RandomTFPolicy(
    time_step_spec = train_env.time_step_spec(),
    action_spec = train_env.action_spec()
)


# Driver that fills the replay buffer with initial_collect_steps steps taken by
# the random policy.
initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
    env = train_env,
    policy = random_policy,
    observers=[replay_buffer.add_batch],
    num_steps = initial_collect_steps
)

In [None]:
# Populate the replay buffer with random-policy experience before the dataset
# over it is created and sampled.
initial_collect_driver.run()

# Checkpointer


In [None]:
# Checkpointer that tracks the agent, its policy, the replay buffer and the
# global step under checkpoint_dir, keeping at most one checkpoint.
# NOTE(review): nothing in the visible cells restores from it - confirm whether
# resuming from a checkpoint is intended.
train_checkpointer = common.Checkpointer(
    ckpt_dir=checkpoint_dir,
    max_to_keep=1,
    agent=agent,
    policy=agent.policy,
    replay_buffer=replay_buffer,
    global_step=global_step
)

# Policy saver

In [None]:
# Saver for the agent's policy (agent.policy, not the collect policy); the
# actual save to policy_dir happens in a later cell.
tf_policy_saver = policy_saver.PolicySaver(agent.policy)

# Evaluation metric: average return

In [None]:
def compute_avg_return(environment, policy, num_episodes=10):
    """Roll out `policy` for whole episodes and average the episode returns.

    Args:
        environment: object whose `reset()` and `step(action)` each return a
            time step exposing `is_last()` and `reward`.
        policy: object whose `action(time_step)` result has an `.action`
            attribute accepted by `environment.step`.
        num_episodes: how many episodes to run.

    Returns:
        Element 0 of the averaged return after converting it with `.numpy()`.
    """
    return_sum = 0.0

    for _episode in range(num_episodes):
        current = environment.reset()
        ep_return = 0.0

        # Step until the environment reports the final time step.
        while True:
            if current.is_last():
                break
            chosen = policy.action(current)
            current = environment.step(chosen.action)
            ep_return += current.reward

        return_sum += ep_return

    mean_return = return_sum / num_episodes
    # The mean is array-like at this point; unwrap its single entry.
    return mean_return.numpy()[0]


In [None]:
# Expose the replay buffer as a dataset that yields batches of `batch_size`
# samples, each spanning 2 consecutive steps (num_steps=2), with prefetching.
dataset = replay_buffer.as_dataset(
    num_parallel_calls = 3,
    sample_batch_size = batch_size,
    num_steps=2
).prefetch(3)

# Iterator consumed by the training loop, one batch per iteration.
iterator = iter(dataset)

# Agent training


In [None]:
# Wrap agent.train with common.function.
# NOTE(review): presumably this compiles it into a TF graph function for
# speed - confirm.
agent.train = common.function(agent.train)

# Start counting train steps from zero for this run.
agent.train_step_counter.assign(0)

# Evaluate the policy once before the loop; this is the first entry of `returns`.
avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)
returns = [avg_return]

for _ in range(num_iterations):
    # Gather collect_steps_per_iteration steps into the replay buffer.
    collect_driver.run()

    # Draw one batch from the replay buffer and run one training update on it.
    experience, unused_info = next(iterator)
    train_loss = agent.train(experience).loss

    # Current value of the agent's train step counter.
    step = agent.train_step_counter.numpy()

    if step % log_interval == 0:
        print('step = {0}: loss = {1}'.format(step, train_loss))

    if step % eval_interval == 0:
        avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)
        print('step = {0}: Average Return = {1}'.format(step, avg_return))
        returns.append(avg_return)

        # Periodic checkpointing is disabled here; the checkpoint is written
        # once, in a separate cell after training.
        #global_step = tf.compat.v1.train.get_global_step()

        #train_checkpointer.save(global_step)

# Save the policy at the end of training so that it can be easily deployed.
# (Disabled here; the policy is saved in its own cell.)
#tf_policy_saver.save(policy_dir)


In [None]:
# Fetch the global step and write a checkpoint under checkpoint_dir.
global_step = tf.compat.v1.train.get_global_step()
train_checkpointer.save(global_step)

In [None]:
# Export the agent's policy to policy_dir via the policy saver.
tf_policy_saver.save(policy_dir)

In [None]:
# Build the x-axis from the evaluations actually recorded. `returns` holds one
# value taken before training plus one per eval_interval train steps, so after
# a complete run this equals range(0, num_iterations + 1, eval_interval), and
# it still matches len(returns) if training was stopped early.
iterations = range(0, len(returns) * eval_interval, eval_interval)
plt.plot(iterations, returns)
plt.ylabel('Average Return')
plt.xlabel('Iterations')
# A Pong game ends when one side reaches 21 points, so 21 is the highest
# possible episode return; a lower cap would clip a well-trained agent's curve.
plt.ylim(top=21)

In [None]:
def embed_mp4(filename):
    """Embeds an mp4 file in the notebook.

    Args:
        filename: path of the mp4 file to read.

    Returns:
        An `IPython.display.HTML` object containing a `<video>` tag whose
        source is the file's bytes as a base64 data URI.
    """
    # The context manager closes the file handle as soon as the bytes are read.
    with open(filename, 'rb') as video_file:
        video = video_file.read()
    b64 = base64.b64encode(video)
    tag = '''
  <video width="640" height="480" controls>
    <source src="data:video/mp4;base64,{0}" type="video/mp4">
  Your browser does not support the video tag.
  </video>'''.format(b64.decode())

    return IPython.display.HTML(tag)

def create_policy_eval_video(policy, filename, num_episodes=5, fps=30):
    """Record `policy` acting in the evaluation environment and embed the video.

    One frame is written after every reset and after every step, using the
    notebook-level `eval_env` for stepping and `eval_py_env` for rendering.

    Args:
        policy: object whose `action(time_step)` result has an `.action`.
        filename: output path without extension; ".mp4" is appended.
        num_episodes: number of episodes to record.
        fps: frame rate handed to the video writer.

    Returns:
        The result of `embed_mp4` for the written file.
    """
    video_path = filename + ".mp4"

    with imageio.get_writer(video_path, fps=fps) as writer:

        def record_frame():
            # Render from the underlying Python environment.
            writer.append_data(eval_py_env.render())

        for _episode in range(num_episodes):
            current_step = eval_env.reset()
            record_frame()
            while True:
                if current_step.is_last():
                    break
                chosen = policy.action(current_step)
                current_step = eval_env.step(chosen.action)
                record_frame()

    return embed_mp4(video_path)

In [None]:
# Record the agent's policy into trained-agent.mp4 and show it as the cell output.
create_policy_eval_video(agent.policy, "trained-agent")

In [None]:
# Record the random policy into random-agent.mp4 for comparison with the agent's video.
create_policy_eval_video(random_policy, "random-agent")