# Imports

In [1]:
import keras
import numpy as np
import gymnasium as gym
import tensorflow as tf
import tqdm.notebook as tqdm
import matplotlib.pyplot as plt
from typing import Optional, Union, Tuple
import ale_py
import numpy as np 
from PIL import Image
import gymnasium.utils.save_video

2025-03-27 07:40:41.536941: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1743025241.553115  225830 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1743025241.557659  225830 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1743025241.569817  225830 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1743025241.569846  225830 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1743025241.569848  225830 computation_placer.cc:177] computation placer alr

In [2]:
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
              tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        print(e)

In [3]:
def get_name(prefix: Optional[str] = None, suffix: Optional[str] = None, separator: str = '_') -> Optional[str]:
    return prefix and prefix + separator + suffix or suffix or None

# Environment

Create the environment. You can use any ATARI environment from [here](https://gymnasium.farama.org/environments/atari/), but prefer to use environments with discrete action space with fewer actions.

In [4]:
batch_size = 64

In [5]:
gym.register_envs(ale_py)

In [6]:
train_env = gym.make_vec("ALE/Boxing-v5", render_mode='rgb_array', num_envs=batch_size)
valid_env = gym.make_vec("ALE/Boxing-v5", render_mode='rgb_array', num_envs=1)

A.L.E: Arcade Learning Environment (version 0.10.2+c9d4b19)
[Powered by Stella]


# Replay Buffer

Create a replay buffer to hold game history

In [7]:
class ReplayBuffer:

    def __init__(self, max_size: int, observation_space: gym.spaces.Space, action_space: gym.spaces.Space, seed: Optional[int] = None):
        """Stores the replay history with a maximum of `max_size` entries, removing old entries as needed.

        Parameters:
            max_size: maximal number of entries to keep
            observation_space: specification of the observation space
            action_space: specification of the action space
            seed: seed to initialize the internal random number generator for reproducibility"""

        self.current_observations = np.zeros((max_size, *observation_space.shape), dtype=observation_space.dtype)
        self.next_observations = np.zeros((max_size, *observation_space.shape), dtype=observation_space.dtype)
        self.actions = np.zeros((max_size, *action_space.shape), dtype=action_space.dtype)
        self.rewards = np.zeros((max_size,), dtype=np.float32)
        self.dones = np.zeros((max_size,), dtype=np.float32)
        
        self.max_size = max_size
        self.rng = np.random.default_rng(seed=seed)
        self.buffer_pointer = 0
        self.current_size = 0
        
    def add(self, current_observations: np.ndarray, actions: np.ndarray, rewards: np.ndarray, next_observations: np.ndarray, dones: np.ndarray) -> None:
        """Add a new entry to the buffer.

        Parameters:
            current_observations: environment state observed at the current step
            actions: action taken by the model
            rewards: reward received after taking the action
            next_observations: environment state obversed after taking the action
            dones: whether the episode has ended or not"""

        batch_size = current_observations.shape[0]
        idxs = (np.arange(batch_size) + self.buffer_pointer) % self.max_size

        self.current_observations[idxs] = current_observations
        self.actions[idxs] = actions
        self.rewards[idxs] = rewards
        self.next_observations[idxs] = next_observations
        self.dones[idxs] = dones

        self.buffer_pointer = (idxs[-1] + 1) % self.max_size
        self.current_size = min(self.max_size, self.current_size + batch_size)
    
    def sample(self, n_samples: int, replace: bool = True) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Randomly samples `n_samples` from the buffer.

        Parameters:
            n_samples: number of samples to select
            replace: sample with or without replacement

        Returns:
            current observations, actions, rewards, next observations, dones"""

        return self[self.rng.choice(self.current_size, size=n_samples, replace=replace)]

    def clear(self) -> None:
        """Clears the buffer"""

        self.buffer_pointer = 0
        self.current_size = 0

    def __getitem__(self, index: Union[int , np.ndarray]) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Gets a sample at `index`

        Parameters:
            index: index of the sample to get

        Returns:
            current observations, actions, rewards, next observations, dones"""

        return (
            self.current_observations[index],
            self.actions[index],
            self.rewards[index],
            self.next_observations[index],
            self.dones[index]
        )
        
    def __len__(self) -> int:
        """Returns the number of entries in the buffer"""

        return self.current_size

# Model

Implement your model. Most if not all ATARI environments have an image observation

In [8]:
 def get_model(
    observation_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    features: int,
    blocks: int, 
    activation: str | keras.layers.Activation | None = 'silu',
    dropout_rate: float = 0.,
    multiply_freq: int = 1,
    name: str | None = None
 ) -> keras.Model:
    inputs = keras.layers.Input(observation_space.shape, name=get_name(name, "Input"), dtype=observation_space.dtype)

    x = keras.layers.Rescaling(1.0 / 255, name=get_name(name, "Rescaling"))(inputs)

    num_filters = features

    for idx in range(blocks):
        num_filters = (idx + 1) * num_filters

        x = keras.layers.Conv2D(num_filters, 3, padding="same", name=get_name(name, f"Conv2D_{idx}"))(x)
        x = keras.layers.MaxPooling2D((2, 2), name=get_name(name, f"MaxPooling2D_{idx}"))(x)
        
    x = keras.layers.BatchNormalization(name=get_name(name, f"BatchNormalization_{idx}"))(x)
    x = keras.layers.GlobalAveragePooling2D(name=get_name(name, "GlobalAveragePooling2D"))(x)
        
    outputs = keras.layers.Dense(action_space.n, name=get_name(name, "Output"))(x)
    
    return keras.Model(inputs, outputs, name=name)

# Sampler

Implement the sampler

In [9]:
class Sampler:
    
    def __init__(self, epsilon: float, seed: int | None = None, greedy: bool = False):
        """Selects a random action with probability `epsilon` otherwise if `greedy` selects the most probable action given by the model,
        if not `greedy` samples the action from the distribution given by the model.

        Parameters:
            epsilon: the probability to select a random action
            seed: seed to initialize the internal random number generator for reproducibility
            greedy: select the most probable action"""
        self.epsilon = epsilon
        self.greedy = greedy
        self.rng = np.random.default_rng(seed=seed)

    def _sample(self, probabilities: np.ndarray) -> np.ndarray:
        cumsum = np.cumsum(probabilities, axis=-1)
        probs = self.rng.uniform(size=(cumsum.shape[0], 1))
        msk = cumsum > probs
        return np.argmax(msk, axis=-1)

    def __call__(self, probabilities: np.ndarray) -> np.ndarray:
        """Select an action given the `probabilities`

        Parameters:
            probabilities: batch of probabilities for each action

        Returns:
            batch of indices of the selected actions"""
        batch, *_ = probabilities.shape
        probs = self.rng.uniform(size=batch)

        return np.where(
            probs < self.epsilon,
            self.rng.choice(probabilities.shape[1], size=batch),
            np.argmax(probabilities, axis=-1) if self.greedy else self._sample(probabilities),
        )

# Play the game

Implement interacting with the environment and storing entries to the replay buffer

In [10]:
def play_game(
    model: keras.Model,
    buffer: ReplayBuffer,
    env: gym.Env,
    steps: int,
    sampler: Sampler,
    observations: np.ndarray | None = None,
    one_episode: bool = False
) -> np.ndarray:
    """Plays the environment `env` using model `model` for a total of `steps` steps.

    Arguments:
        model: model to use
        buffer: buffer to store results to
        env: environment to use
        steps: total number of steps to record
        sampler: sampler to use
        observations: observation to start from
        one_episode: exist as soon as one of the environments finishes

    Returns:
        the last observations
    """
    if observations is None:
        observations, _ = env.reset()

    for _ in range(steps // env.num_envs):
        res = model(observations, training=False).numpy()
        actions = sampler(res)
        new_observations, rewards, terminateds, truncated, _ = env.step(actions) # new_observations, rewards, terminateds, truncated, info

        dones = terminateds | truncated

        buffer.add(
            current_observations=observations,
            actions=actions,
            rewards=rewards,
            next_observations=new_observations,
            dones=dones,
        )
        
        if one_episode and np.any(dones):
            observations = None
            break
        
        observations = new_observations
    
    return observations

# Loss

Implement double q learning loss. Don't forget to stop the gradient for q_ref

In [11]:
class QDataset(keras.utils.PyDataset):

    def __init__(
        self,
        steps_per_epoch: int,
        batch_size: int,
        buffer: ReplayBuffer,
        action_space: gym.spaces.Space,
    ):
        super().__init__()
        self.buffer = buffer
        self.batch_size = batch_size
        self.steps_per_epoch = steps_per_epoch

        self._eye = np.eye(action_space.n)
        self._answer = np.zeros(batch_size, dtype=np.float32)

    def __len__(self):
        return self.steps_per_epoch

    def __getitem__(self, idx: int) -> np.ndarray:
        a, b, c, d, e = self.buffer.sample(self.batch_size)
        b = self._eye[b]

        return (a, b, c, d, e), self._answer

In [12]:
class QLoss(keras.layers.Layer):

    def __init__(self, gamma: float, **kwargs):
        super().__init__(**kwargs)
        
        self.gamma = gamma

    def call(self, q_current: keras.KerasTensor, q_next: keras.KerasTensor, rewards: keras.KerasTensor, actions: keras.KerasTensor, dones: keras.KerasTensor) -> keras.KerasTensor:
        q_ref = rewards + self.gamma * (1 - dones) * keras.ops.max(q_next, axis=-1)

        return keras.ops.square(keras.ops.sum(q_current * actions, axis=-1) - q_ref)

    def get_config(self) -> dict:
        config = super().get_config()
        config['gamma'] = self.gamma
        return config

In [13]:
def get_combined_model(
    observation_space: gym.spaces.Space,
    action_space: gym.spaces.Space,    
    features: int,
    blocks: int, 
    activation: str | keras.layers.Activation | None = 'silu',
    dropout_rate: float = 0.,
    multiply_freq: int = 1,
    name: str | None = None,
    gamma: float = 0.99
) -> keras.Model:
    """Creates a combined model for Q-learning training

    Arguments:
        input_features: model input vector size
        features: initial model embedding size
        out_features: model output vector size
        block: number of perceptron layers
        activation: intermediate model activation
        dropout_rate: nuff said
        multifly_freq: doubles embedding size every `multiply_freq` blocks
        name: model name,
        gamma: rewards discount
    Returns:
        A q-model, a target q-model, a combined model
    """
    model = get_model(
        observation_space=observation_space,
        action_space=action_space,
        features=features,
        blocks=blocks,
        activation=activation,
        dropout_rate=dropout_rate,
        multiply_freq=multiply_freq,
        name=get_name(name, "model"),
    )
    
    target_model = get_model(
        observation_space=observation_space,
        action_space=action_space,
        features=features,
        blocks=blocks,
        activation=activation,
        multiply_freq=multiply_freq,
        name=get_name(name, "target_model"),
    )

    target_model.trainable = False
    target_model.set_weights(model.get_weights())

    current_observation = keras.layers.Input(observation_space.shape, dtype=observation_space.dtype, name=get_name(name, "curr_observ"))
    next_observation = keras.layers.Input(observation_space.shape, dtype=observation_space.dtype, name=get_name(name, "next_observ"))
    current_action = keras.layers.Input((action_space.n,), dtype="float32", name=get_name(name, "curr_action"))
    rewards = keras.layers.Input((), dtype="float32", name=get_name(name, "rewards"))
    dones = keras.layers.Input((), dtype="float32", name=get_name(name, "dones"))

    model_res = model(current_observation)
    target_model_res = target_model(next_observation)

    loss = QLoss(gamma=gamma, name=get_name(name, "Q_loss"))(model_res, target_model_res, rewards, current_action, dones)

    return (
        model, 
        target_model, 
        keras.Model(
            inputs=[
                current_observation,
                current_action,
                rewards,
                next_observation,
                dones,
            ],
            outputs=loss,
            name=name
        ),
    )

In [14]:
def play_game(
    model: keras.Model,
    buffer: ReplayBuffer,
    env: gym.Env,
    steps: int,
    sampler: Sampler,
    observations: np.ndarray | None = None,
    one_episode: bool = False
) -> np.ndarray:
    """Plays the environment `env` using model `model` for a total of `steps` steps.

    Arguments:
        model: model to use
        buffer: buffer to store results to
        env: environment to use
        steps: total number of steps to record
        sampler: sampler to use
        observations: observation to start from
        one_episode: exist as soon as one of the environments finishes

    Returns:
        the last observations
    """
    if observations is None:
        observations, _ = env.reset()

    for _ in range(steps // env.num_envs):
        res = model(observations, training=False).numpy()
        actions = sampler(res)
        new_observations, rewards, terminateds, truncated, _ = env.step(actions) # new_observations, rewards, terminateds, truncated, info

        dones = terminateds | truncated

        buffer.add(
            current_observations=observations,
            actions=actions,
            rewards=rewards,
            next_observations=new_observations,
            dones=dones,
        )
        
        if one_episode and np.any(dones):
            observations = None
            break
        
        observations = new_observations
    
    return observations

In [15]:
class ModelUpdateCallback(keras.callbacks.Callback):

    def __init__(self, frequency: int, **kwargs):
        super().__init__(**kwargs)
        self.frequency = frequency
        
    def on_train_batch_end(self, batch: int, logs=None):
        if (batch + 1) % self.frequency == 0:
            self.model.get_layer(
                f'{self.model.name}_target_model'
            ).set_weights(
                self.model.get_layer(
                    f'{self.model.name}_model'
                ).get_weights()
            )

In [16]:
class ModelEvalCallback(keras.callbacks.Callback):

    def __init__(self, env: gym.Env, max_steps: int, **kwargs):
        super().__init__(**kwargs)
        self.env = env
        self.max_steps = max_steps

        self.buffer = ReplayBuffer(
            max_size=max_steps,
            observation_space=env.single_observation_space,
            action_space=env.single_observation_space,
        )

        self.sampler = Sampler(epsilon=0., greedy=True)

    def on_epoch_end(self, epoch: int, logs=None):
        if logs is None:
            return

        model = self.model.get_layer(
            f'{self.model.name}_model'
        )

        self.buffer.clear()

        play_game(
            model=model,
            buffer=self.buffer,
            env=self.env,
            steps=self.max_steps,
            sampler=self.sampler,
            one_episode=True,
        )

        rewards = self.buffer.rewards[:len(self.buffer)].reshape((-1, self.env.num_envs))
        logs['score'] = np.mean(np.sum(rewards, axis=0))

In [17]:
class SometimesPlayCallback(keras.callbacks.Callback):
    def __init__(
        self,
        env: gym.Env,
        buffer: ReplayBuffer,
        sampler: Sampler,
        frequency: int,
        steps_per_play: int,
        epsilon_decay: keras.optimizers.schedules.PolynomialDecay | None = None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.env = env
        self.buffer = buffer
        self.sampler = sampler
        self.frequency = frequency
        self.steps_per_play = steps_per_play
        self.epsilon_decay = epsilon_decay

        self.observations = None

    def on_epoch_begin(self, epoch: int, logs=None):
        if self.epsilon_decay is not None:
            self.sampler.epsilon = self.epsilon_decay(epoch)
        if epoch == 0:
            model = self.model.get_layer(
                f'{self.model.name}_model'
            )

            self.buffer.clear()
            
            self.observations = play_game(
                model=model,
                buffer=self.buffer,
                env=self.env,
                steps=self.steps_per_play,
                sampler=self.sampler,
                observations=self.observations
            )

    def on_train_batch_start(self, batch_num: int, logs=None):
        if (batch_num + 1) % self.frequency:
            return
        
        model = self.model.get_layer(
            f'{self.model.name}_model'
        )

        self.observations = play_game(
            model=model,
            buffer=self.buffer,
            env=self.env,
            steps=self.steps_per_play,
            sampler=self.sampler,
            observations=self.observations
        )

# Training

Create models, replay buffers, sampler, optimizer, epsilon decay etc. Implement training loop, show training progress and perform model evaluation once in a while

In [18]:
model, target_model, combined_model = get_combined_model(
    observation_space=train_env.single_observation_space,
    action_space=train_env.single_action_space,
    features=32,
    blocks=4,
    activation='relu',
    dropout_rate=0.2,
    multiply_freq=1,
    name='enduro',
    gamma=0.99
)

I0000 00:00:1743025290.628208  225830 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 5660 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3050, pci bus id: 0000:09:00.0, compute capability: 8.6


In [19]:
buffer_size = 2 ** 15     # 32_768
steps_per_play = 2 ** 15  # 2048
steps_per_epoch = 2 ** 10 # 1024

end_epsilon = 0.1
epochs = 128
decay_epochs = epochs // 2
update_frequency = 512
play_frequency = 256

In [20]:
sampler = Sampler(epsilon=1, seed=42)

In [21]:
epsilon_decay = keras.optimizers.schedules.PolynomialDecay(1., decay_epochs, end_learning_rate=end_epsilon)

In [22]:
train_buffer = ReplayBuffer(
    max_size=buffer_size,
    observation_space=train_env.single_observation_space,
    action_space=train_env.single_action_space
)

In [23]:
dataset = QDataset(
    steps_per_epoch=steps_per_epoch,
    batch_size=train_env.num_envs,
    buffer=train_buffer,
    action_space=train_env.single_action_space
)

In [24]:
callbacks = [
    ModelUpdateCallback(update_frequency),
    ModelEvalCallback(valid_env, 1000),
    SometimesPlayCallback(
        env=train_env,
        buffer=train_buffer,
        sampler=sampler,
        frequency=play_frequency,
        steps_per_play=steps_per_play,
        epsilon_decay=epsilon_decay,
    )
]

In [25]:
combined_model.compile(
    loss=(lambda y_true, y_pred: keras.ops.mean(y_pred)),
    optimizer=keras.optimizers.Adam(1e-3, clipnorm=5, weight_decay=2e-5),
)

In [26]:
_ = play_game(
    model=model,
    buffer=train_buffer,
    env=train_env,
    steps=steps_per_play,
    sampler=sampler,
)

I0000 00:00:1743025310.487942  225830 cuda_dnn.cc:529] Loaded cuDNN version 90300


In [27]:
combined_model.fit(dataset, callbacks=callbacks, epochs=epochs)

Epoch 1/128


I0000 00:00:1743025409.320450  225928 service.cc:152] XLA service 0x7fbdb8004670 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1743025409.320490  225928 service.cc:160]   StreamExecutor device (0): NVIDIA GeForce RTX 3050, Compute Capability 8.6
2025-03-27 07:43:29.369065: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.



I0000 00:00:1743025425.223512  225928 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m1024/1024[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m166s[0m 145ms/step - loss: 1.7980 - score: -24.0000
Epoch 2/128
[1m1024/1024[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m150s[0m 146ms/step - loss: 0.4306 - score: -28.0000
Epoch 3/128
[1m1024/1024[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m149s[0m 145ms/step - loss: 0.2641 - score: -39.0000
Epoch 4/128
[1m1024/1024[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m150s[0m 147ms/step - loss: 0.1900 - score: -31.0000
Epoch 5/128
[1m1024/1024[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m154s[0m 150ms/step - loss: 3.4870 - score: -24.0000
Epoch 6/128
[1m1024/1024[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m150s[0m 146ms/step - loss: 0.2120 - score: -2.0000
Epoch 7/128
[1m1024/1024[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m150s[0m 147ms/step - loss: 0.5903 - score: -28.0000
Epoch 8/128
[1m1024/1024[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m151s[0m 147ms/step - loss: 11.8502 - score: -10.0000
Epoc

<keras.src.callbacks.history.History at 0x7fbe707ebcd0>

# Testing

Test the model on the environment and get a cool video

In [33]:
loaded_target_model = keras.saving.load_model('./models/target_model-v2.keras')

In [42]:
saving_steps = 1750

In [43]:
saving_env = gym.make_vec("ALE/Boxing-v5", render_mode='rgb_array_list', num_envs=1)

In [44]:
def save_gameplay(
    model: tf.keras.Model,
    max_steps: int = 1000,
    env: gym.Env | None = None,
):
    save_buffer = ReplayBuffer(
        max_size=buffer_size,
        observation_space=train_env.single_observation_space,
        action_space=train_env.single_action_space
    )
    
    save_sampler = Sampler(1, greedy=False)
    
    _ = play_game(
        model=model,
        env=env,
        steps=max_steps,
        sampler=save_sampler,
        buffer=save_buffer
    )
    
    render = env.render()

    gym.utils.save_video.save_video(
        frames=render[0],
        video_folder="videos",
        fps=env.metadata["render_fps"],
    )

In [45]:
save_gameplay(loaded_target_model, max_steps=saving_steps, env=saving_env)

In [32]:
model.save('./models/model-v2.keras')
target_model.save('./models/target_model-v2.keras')

In [64]:
model_test = get_model(
    observation_space=train_env.single_observation_space,
    action_space=train_env.single_action_space,
    features=32,
    blocks=4,
    activation='relu',
    dropout_rate=0.2,
    multiply_freq=1,
    name='enduro',
)

In [65]:
model_test.summary()