# Breakout (Reinforcement Learning)

In [1]:
import sys

# Colab setup: install the environment packages only when running on Colab.
# NOTE(review): this installs `gymnasium`, while the notebook imports `gym`
# and `tf_agents` — confirm those packages are available on Colab.
if 'google.colab' in sys.modules:
    %pip install -q -U gymnasium
    %pip install -q -U gymnasium[classic_control,box2d,atari,accept-rom-license]
    
import logging
import time
import warnings

import numpy as np
import sklearn
import tensorflow as tf
from tensorflow import keras
from tf_agents.environments import suite_gym, suite_atari
from tf_agents.environments.atari_preprocessing import AtariPreprocessing
from tf_agents.environments.atari_wrappers import FrameStack4
from tf_agents.environments.tf_py_environment import TFPyEnvironment
from tf_agents.networks.q_network import QNetwork
from tf_agents.agents.dqn.dqn_agent import DqnAgent
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.metrics import tf_metrics
from tf_agents.eval.metric_utils import log_metrics
import gym
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import seaborn as sns

# Make notebook output stable across runs: the same seed is given to both
# NumPy's and TensorFlow's global random number generators.
random_state = 1000
np.random.seed(random_state)
tf.random.set_seed(random_state)

# Plot settings: inline figures, seaborn's default theme, then explicit
# matplotlib font sizes for text, axes, legend and tick labels.
%matplotlib inline
sns.set()
mpl.rc('font', size=14)
mpl.rc('axes', labelsize=14, titlesize=14)
mpl.rc('legend', fontsize=14)
mpl.rc('xtick', labelsize=10)
mpl.rc('ytick', labelsize=10)
# Animations (see plot_animation) are displayed with the 'jshtml' HTML repr.
mpl.rc('animation', html='jshtml')

In [2]:
# Utility functions

def show_env(env):
    """Render the environment's current frame and display it as an image.

    Args:
        env: An environment whose ``render(mode='rgb_array')`` returns an
            image array that ``plt.imshow`` can draw.
    """
    # The render mode is spelled 'rgb_array' (underscore), as used elsewhere
    # in this notebook; 'rgb-array' is not a valid mode name.
    img = env.render(mode='rgb_array')
    plt.figure(figsize=(6, 8))
    plt.imshow(img)
    plt.axis('off')
    plt.show()
    
    
def update_scene(num, frames, patch):
    """Animation callback: show frame number ``num`` in the image artist.

    Args:
        num: Index of the frame to display.
        frames: Indexable collection of image frames.
        patch: Object exposing ``set_data`` (here, the result of ``plt.imshow``).

    Returns:
        The same ``patch`` object, after its data has been replaced.
    """
    current_frame = frames[num]
    patch.set_data(current_frame)
    return patch


def plot_animation(frames, repeat=False, interval=100):
    """Build a matplotlib animation that plays ``frames`` in order.

    Args:
        frames: Sequence of images; the first one is drawn as the initial image.
        repeat: Forwarded to ``FuncAnimation`` as ``repeat``.
        interval: Forwarded to ``FuncAnimation`` as ``interval``.

    Returns:
        The ``animation.FuncAnimation`` object driving ``update_scene``.
    """
    figure = plt.figure()
    image_artist = plt.imshow(frames[0])
    plt.axis('off')
    animation_options = dict(
        fargs=(frames, image_artist),
        frames=len(frames),
        repeat=repeat,
        interval=interval,
    )
    movie = animation.FuncAnimation(figure, update_scene, **animation_options)
    # Close the current pyplot figure (presumably so that no static plot is
    # shown next to the animation — confirm).
    plt.close()
    return movie

## Running Breakout

In [3]:
# To install Breakout, you may need to run:
# pip install 'gym[atari, accept-rom-license]'

# Optional: uncomment to silence library warnings.
#warnings.filterwarnings('ignore')
# Create the raw Breakout environment; render_mode='human' is requested at
# construction time.
env = gym.make('BreakoutNoFrameskip-v4', render_mode='human')

A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]


In [4]:
# Inspect the set of actions the agent can choose from.
env.action_space

Discrete(4)

In [5]:
# Inspect the bounds, shape and dtype of the raw observations.
env.observation_space

Box(0, 255, (210, 160, 3), uint8)

In [6]:
# Start a new episode; the return value is not needed here.
_ = env.reset()

In [10]:
# Take one step with action 1 (fire). With the installed gym version `step`
# returns five values (observation, reward, terminated, truncated, info);
# unpacking into four names raised
# "ValueError: too many values to unpack (expected 4)".
_, _, _, _, _ = env.step(1)

In [None]:
print("Observation Space: ", env.observation_space)
print("Action Space       ", env.action_space)

# Play 1000 random actions so the game can be watched in the render window.
# With the installed gym version, `reset` returns (observation, info) and
# `step` returns (observation, reward, terminated, truncated, info).
obs, info = env.reset()

for i in range(1000):
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated
    if done:
        # Stepping a finished episode is not allowed; start a new one.
        obs, info = env.reset()
    time.sleep(0.01)  # slow the loop down so the rendering is watchable
env.close()

## How An Agent Can Play Breakout

In [None]:
# Settings passed to suite_atari.load: the episode step limit and the
# environment id.
max_episode_steps = 27000 # 108,000 frames, since 1 step is 4 frames
environment_name = 'BreakoutNoFrameskip-v4'


class AtariPreprocessingWithAutoFire(AtariPreprocessing):
    """``AtariPreprocessing`` variant that presses fire (action 1) automatically.

    Fire is sent once right after every reset, and again whenever a step
    costs a life while the episode is still running.
    """

    def reset(self, **kwargs):
        """Reset the wrapped environment, then send fire once to start.

        Returns:
            The observation produced by the reset itself; the result of the
            extra fire step is discarded.
        """
        initial_obs = super().reset(**kwargs)
        super().step(1)  # Action fire to start
        return initial_obs

    def step(self, action):
        """Apply ``action`` and re-launch the ball if a life was just lost.

        Returns:
            The ``(obs, rewards, done, info)`` values of the step taken with
            ``action``; the result of any extra fire step is discarded.
        """
        lives_at_start = self.ale.lives()
        transition = super().step(action)
        obs, rewards, done, info = transition
        life_lost = self.ale.lives() < lives_at_start
        if life_lost and not done:
            super().step(1)  # Action fire to start after life lost
        return obs, rewards, done, info

    
# Load Breakout through tf_agents' Atari suite with the episode step limit,
# passing the auto-fire preprocessing wrapper and FrameStack4 as gym wrappers.
# This rebinds `env`.
env = suite_atari.load(
    environment_name,
    max_episode_steps=max_episode_steps,
    gym_env_wrappers=[AtariPreprocessingWithAutoFire, FrameStack4]
)

In [None]:
# Visualize a series of moves

def plot_observation(obs):
    """Display a 4-frame stacked observation as a single RGB image.

    Since there are only 3 color channels, 4 frames cannot each get their own
    primary color. The first three frames are used as the red, green and blue
    channels; the positive part of the difference between the current
    (fourth) frame and the mean of the other three is then added to the red
    and blue channels, so the current frame shows up in pink.

    Args:
        obs: Array whose last axis holds the frames; indices 0-3 are read.
    """
    stacked = obs.astype(np.float32)  # astype copies, so `obs` is not modified
    rgb, newest = stacked[..., :3], stacked[..., 3]
    # Computed once, before `rgb` is modified below.
    pink_boost = np.maximum(newest - rgb.mean(axis=-1), 0.)
    for channel in (0, 2):  # red and blue
        rgb[..., channel] += pink_boost
    # Scale by 150 and clip to [0, 1] for display.
    plt.imshow(np.clip(rgb / 150, 0, 1))
    plt.axis("off")
    
    
# Seed the environment, start an episode and take the same action four times,
# keeping only the last time step for display.
env.seed(random_state)
env.reset()
for _ in range(4):
    time_step = env.step(3) # Action 3 is "move left"
    
# Show the stacked observation of the last step.
plt.figure(figsize=(6, 6))
plot_observation(time_step.observation)
plt.show()

## Training an Agent to Play Breakout

In [None]:
# Wrap the Python environment so it can be driven from TensorFlow.
tf_env = TFPyEnvironment(env)

# Cast observations to float32 and divide by 255 before the network sees them.
preprocessing_layer = \
    keras.layers.Lambda(lambda obs: tf.cast(obs, np.float32) / 255.)

# Three convolutional layers; each tuple is presumably
# (filters, kernel size, stride) — confirm against the QNetwork docs.
conv_layer_params=[(32, (8, 8), 4), (64, (4, 4), 2), (64, (3, 3), 1)]

# One fully connected layer setting of 512.
fc_layer_params=[512]

# Q-network built from the environment's observation and action specs.
q_net = QNetwork(
    tf_env.observation_spec(),
    tf_env.action_spec(),
    preprocessing_layers=preprocessing_layer,
    conv_layer_params=conv_layer_params,
    fc_layer_params=fc_layer_params
)

In [None]:
# Counter handed to the agent as its train-step counter; the ε schedule below
# is evaluated at its current value.
train_step = tf.Variable(0)
update_period = 4 # run a training step every 4 collect steps

optimizer = keras.optimizers.RMSprop(learning_rate=2.5e-4, rho=0.95,
                                     momentum=0.0, epsilon=0.00001,
                                     centered=True)

# A learning-rate schedule class is reused here to decay ε (the exploration
# rate) from 1.0 to 0.01; its "learning_rate" arguments are really ε values.
epsilon_fn = keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=1.0,  # initial ε
    decay_steps=250000 // update_period,  # <=> 1,000,000 frames
    end_learning_rate=0.01  # final ε
) 

agent = DqnAgent(
    tf_env.time_step_spec(),
    tf_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    target_update_period=2000,  # <=> 32,000 frames
    td_errors_loss_fn=keras.losses.Huber(reduction='none'),
    gamma=0.99,  # Discount factor
    train_step_counter=train_step,
    # Callable, so ε is re-read from the schedule each time it is needed.
    epsilon_greedy=lambda: epsilon_fn(train_step)
)

agent.initialize()

In [None]:
# Replay buffer whose data spec comes from the agent and whose batch size
# matches the environment's.
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=tf_env.batch_size,
    max_length=100000  # Reduce if memory error
) 

# Observer handed to the collect driver: adds each collected batch to the buffer.
replay_buffer_observer = replay_buffer.add_batch

In [None]:
class ShowProgress:
    """Driver observer that counts trajectories and prints a progress line.

    Only trajectories that are not episode boundaries are counted. Whenever
    the running count is a multiple of 100 the text ``count/total`` is
    printed, preceded by a carriage return and without a trailing newline.
    """

    def __init__(self, total):
        """Start counting from zero; ``total`` is only used for display."""
        self.counter, self.total = 0, total

    def __call__(self, trajectory):
        """Count ``trajectory`` (unless it is a boundary) and maybe print."""
        is_real_step = not trajectory.is_boundary()
        if is_real_step:
            self.counter += 1
        if self.counter % 100:
            return
        print('\r{}/{}'.format(self.counter, self.total), end='')

In [None]:
# Metrics passed to the collect driver as observers and logged periodically
# by train_agent.
train_metrics = [
    tf_metrics.NumberOfEpisodes(),
    tf_metrics.EnvironmentSteps(),
    tf_metrics.AverageReturnMetric(),
    tf_metrics.AverageEpisodeLengthMetric(),
]

In [None]:
# Lower the root logger to INFO, then log the metrics once.
# NOTE(review): presumably log_metrics emits at INFO level, hence the level
# change — confirm.
logging.getLogger().setLevel(logging.INFO)
log_metrics(train_metrics)

In [None]:
from tf_agents.drivers.dynamic_step_driver import DynamicStepDriver

# Driver used during training: runs the agent's collect policy and notifies
# the replay buffer observer and every training metric.
collect_driver = DynamicStepDriver(
    tf_env,
    agent.collect_policy,
    observers=[replay_buffer_observer] + train_metrics,
    num_steps=update_period  # Collect 4 steps for each training iteration
) 

In [None]:
from tf_agents.policies.random_tf_policy import RandomTFPolicy

# Fill the replay buffer before training using a random policy.
initial_collect_policy = RandomTFPolicy(tf_env.time_step_spec(),
                                        tf_env.action_spec())

# The progress observer is given the same total as the driver's num_steps.
init_driver = DynamicStepDriver(
    tf_env,
    initial_collect_policy,
    observers=[replay_buffer.add_batch, ShowProgress(20000)],
    num_steps=20000  # <=> 80,000 frames
)  

final_time_step, final_policy_state = init_driver.run()

In [None]:
# Seed chosen so that the sample shows a trajectory at the end of an episode.
tf.random.set_seed(9) 

# Draw a single sample from the buffer: batch size 2, 3 steps each.
trajectories, buffer_info = next(iter(replay_buffer.as_dataset(
    sample_batch_size=2,
    num_steps=3,
    single_deterministic_pass=False
)))

In [None]:
# Draw the sampled observations on a 2x3 grid: one row per sampled
# trajectory, one column per step.
plt.figure(figsize=(10, 6.8))
for row in range(2):
    for col in range(3):
        plt.subplot(2, 3, row * 3 + col + 1)  # subplot indices are 1-based
        plot_observation(trajectories.observation[row, col].numpy())
# Remove the outer margins and almost all spacing between the images.
plt.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0, wspace=0.02)
plt.show()

In [None]:
# Training dataset read from the replay buffer: 64 samples per batch, each
# made of 2 steps, with prefetching.
dataset = replay_buffer.as_dataset(
    sample_batch_size=64,
    num_steps=2,
    num_parallel_calls=3
).prefetch(3)

In [None]:
from tf_agents.utils.common import function

# Replace both methods with versions wrapped by tf_agents' `function` helper
# (presumably compiling them into TensorFlow graphs for speed — confirm).
# NOTE(review): re-running this cell wraps the already-wrapped methods again.
collect_driver.run = function(collect_driver.run)
agent.train = function(agent.train)


def train_agent(n_iterations):
    """Alternate between collecting experience and training the agent.

    Each iteration runs the collect driver once, pulls the next batch from
    the dataset, trains the agent on it and prints the iteration number and
    loss on a single, continuously overwritten line. The training metrics
    are logged on every 1000th iteration, starting with iteration 0.

    Args:
        n_iterations: Number of collect-and-train iterations to run.
    """
    time_step = None
    policy_state = agent.collect_policy.get_initial_state(tf_env.batch_size)
    batches = iter(dataset)
    for step_index in range(n_iterations):
        # Resume collection from where the previous iteration stopped.
        time_step, policy_state = collect_driver.run(time_step, policy_state)
        experience, _unused_info = next(batches)
        loss_info = agent.train(experience)
        progress = "\r{} loss:{:.5f}".format(step_index,
                                             loss_info.loss.numpy())
        print(progress, end="")
        if step_index % 1000 != 0:
            continue
        log_metrics(train_metrics)

In [None]:
# Run 100 collect-and-train iterations.
train_agent(n_iterations=100)

In [None]:
frames = []


def save_frames(trajectory):
    """Driver observer: append the current rendered RGB frame to ``frames``.

    The ``trajectory`` argument is required by the observer interface but is
    not used; the frame is rendered from the first wrapped Python environment.
    """
    py_env = tf_env.pyenv.envs[0]
    frames.append(py_env.render(mode='rgb_array'))

# Run the agent's policy for 1000 steps, saving a frame at each observer
# call and printing progress.
watch_driver = DynamicStepDriver(
    tf_env,
    agent.policy,
    observers=[save_frames, ShowProgress(1000)],
    num_steps=1000)
final_time_step, final_policy_state = watch_driver.run()

# Last expression of the cell: the animation is displayed as the cell output.
plot_animation(frames)