# Comparing the Performance of Different Implementations

This notebook was used to compare the performance of alternative implementations of the same operation, in order to pick the fastest one and speed up the main code.

In [1]:
import os.path

import numpy as np

## Transposing the sample extracted from a replay buffer

In [23]:
# Synthetic replay-buffer batch of (state, action, reward, next_state, done) tuples.
# Row i has state [i, i, i], action = reward = i, next_state [i, i, 1], and only row 4
# is terminal. The 6 rows are repeated 1000 times (6000 entries sharing the same
# 6 tuple objects) to give the timing cells a larger workload.
sample_batch = [
    (np.full(3, idx), idx, idx, np.array([idx, idx, 1]), idx == 4)
    for idx in range(6)
] * 1000

In [24]:
%%timeit

states1, actions1, rewards1, next_states1, dones1 = [], [], [], [], []
for s, a, r, n, d in sample_batch:
    states1.append(np.array(s, copy=False))
    actions1.append(np.array(a, copy=False))
    rewards1.append(np.array(r, copy=False))
    next_states1.append(np.array(n, copy=False))
    dones1.append(np.array(d, copy=False))

7.45 ms ± 483 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [25]:
%%timeit
states2, actions2, rewards2, next_states2, dones2 = list(map(list, zip(*sample_batch)))

213 µs ± 1.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


## Can Atari observations be stored as uint8 instead of int16?

In [9]:
import gymnasium as gym
from tqdm.auto import tqdm
import numpy as np

In [10]:
env = gym.make("BreakoutNoFrameskip-v4")
# Atari preprocessing. The env id requests no built-in frame skipping, so the wrapper
# does it (frame_skip=4). It also applies up to 30 no-op actions on reset, resizes to
# 84x84 and converts to grayscale without a trailing channel axis. scale_obs=False keeps
# raw pixel values (not rescaled floats), which is what the uint8 check below relies on.
env = gym.wrappers.AtariPreprocessing(
    env,
    noop_max=30,
    frame_skip=4,
    screen_size=84,
    terminal_on_life_loss=False,
    grayscale_obs=True,
    grayscale_newaxis=False,
    scale_obs=False,
)
# Frame stacking: each observation holds the last 4 preprocessed frames.
# NOTE(review): Gymnasium 1.0 renamed this wrapper to `FrameStackObservation`
# (argument `stack_size`); update this line when upgrading.
env = gym.wrappers.FrameStack(env, 4)

In [13]:
# Storing frames as uint8 is only safe if the cast is lossless. Compare each observation
# with its own uint8 cast. Comparing a uint8 cast against an int16 cast of the same data
# would miss values that both casts truncate the same way (e.g. fractional floats).
def fits_in_uint8(observation) -> bool:
    """Return True if every value of `observation` survives a cast to uint8 unchanged."""
    frames = np.asarray(observation)
    return np.array_equal(frames, frames.astype(np.uint8))


state, _ = env.reset(seed=0)  # seeded so the run is reproducible
env.action_space.seed(0)  # the random policy below samples from this space
assert fits_in_uint8(state), "Failed!"
for _ in tqdm(range(5000)):
    next_state, _reward, terminated, truncated, _ = env.step(env.action_space.sample())
    assert fits_in_uint8(next_state), "Failed!"
    if terminated or truncated:
        # Reset observations are stored too, so check them as well.
        next_state, _ = env.reset()
        assert fits_in_uint8(next_state), "Failed!"

  0%|          | 0/5000 [00:00<?, ?it/s]