In [1]:
import jax
import jax.numpy as jnp
from flax import struct

from jax.flatten_util import ravel_pytree

In [2]:
@struct.dataclass
class Trajectory:
    observation: jnp.ndarray
    action: jnp.ndarray
    value_target: jnp.ndarray
    policy_target: jnp.ndarray
    reward_target: jnp.ndarray
    priority: jnp.ndarray

In [15]:
batch = 32
time = 10
trajectory = Trajectory(
    observation= jnp.zeros((batch, 128, 84, 84)),
    action= jnp.zeros((batch)),
    value_target=jnp.zeros((batch)),
    policy_target=jnp.zeros((batch, 4)),
    reward_target=jnp.zeros((batch)),
    priority=jnp.zeros((batch))
)


In [28]:
flatten_fn = jax.vmap(lambda x: ravel_pytree(x)[0]) # to be used on a single trajectory
_, unflatten_fn = ravel_pytree(trajectory)
unflatten_fn = jax.vmap(unflatten_fn) # to be used on entire replay_buffer

# Try brax uniform sampling buffer with your trajectory to better understand how it works 
# then introduce PER

In [None]:
from typing import Any, Tuple
from jax.random import PRNGKey

State = Any
Sample = Any

@struct.dataclass
class ReplayBufferState:
  """Contains data related to a replay buffer."""
  data: jnp.ndarray
  current_position: jnp.ndarray
  current_size: jnp.ndarray
  key: PRNGKey

class ReplayBuffer:
  """
  Priotized experience replay buffer.
  
  Modified port of brax.training.replay_buffers
  https://github.com/google/brax/blob/b373f5a45e62189a4a260131c17b10181ccda96a/brax/training/replay_buffers.py
  
  """

  def __init__(self, max_replay_size: int, dummy_data_sample: Sample,
               sample_batch_size: int) -> State:
    """Init the replay buffer."""
    self._flatten_fn = jax.vmap(lambda x: ravel_pytree(x)[0])

    dummy_flatten, self._unflatten_fn = ravel_pytree(dummy_data_sample)
    self._unflatten_fn = jax.vmap(self._unflatten_fn)
    data_size = len(dummy_flatten)

    self._data_shape = (max_replay_size, data_size)
    self._data_dtype = dummy_flatten.dtype
    self._sample_batch_size = sample_batch_size
    
  def init(self, key: PRNGKey) -> ReplayBufferState:
      return ReplayBufferState(
        data=jnp.zeros(self._data_shape, self._data_dtype),
        current_size=jnp.zeros((), jnp.int32),
        current_position=jnp.zeros((), jnp.int32),
        key=key)

  def insert(self, buffer_state: State, samples: Sample) -> State:
    """Insert data in the replay buffer.

    Args:
      buffer_state: Buffer state
      samples: Sample to insert with a leading batch size.

    Returns:
      New buffer state.
    """
    if buffer_state.data.shape != self._data_shape:
      raise ValueError(
          f'buffer_state.data.shape ({buffer_state.data.shape}) '
          f'doesn\'t match the expected value ({self._data_shape})')

    update = self._flatten_fn(samples)
    data = buffer_state.data

    # Make sure update is not larger than the maximum replay size.
    if len(update) > len(data):
      raise ValueError(
          'Trying to insert a batch of samples larger than the maximum replay '
          f'size. num_samples: {len(update)}, max replay size {len(data)}')

    # If needed, roll the buffer to make sure there's enough space to fit
    # `update` after the current position.
    position = buffer_state.current_position
    roll = jnp.minimum(0, len(data) - position - len(update))
    data = jax.lax.cond(roll, lambda: jnp.roll(data, roll, axis=0),
                        lambda: data)
    position = position + roll

    # Update the buffer and the control numbers.
    data = jax.lax.dynamic_update_slice_in_dim(data, update, position, axis=0)
    position = (position + len(update)) % len(data)
    size = jnp.minimum(buffer_state.current_size + len(update), len(data))

    return buffer_state.replace(
        data=data, current_position=position, current_size=size)

  def sample(self, buffer_state: State) -> Tuple[State, Sample]:
    """Sample a batch of data according to prioritized experience replay."""
    
    

  def size(self, buffer_state: State) -> int:
    """Total amount of elements that are sampleable."""

# Replay Buffer Logic
- ReplayBufferState
    - data : batch of flattened trajectories
    - current_size: 
    - current_position
    - key

- Flattening / Unflattening ops
    - flatten_trajectory: ```vmap(lambda x: ravel_pytree(x)[0])```
    - unflatten_trajectory: ```dummy_flatten, jax.vmap(self._unflatten_fn) = flatten_util.ravel_pytree(dummy_data_sample) ```

In [33]:
# example of how to use dynamic slice in dim

arr = jnp.arange(20).reshape((5, 4))
test = jax.lax.dynamic_slice_in_dim(arr, jnp.array([1, 2, 3, 4, 5]), 1, axis=0)

TypeError: start_indices arguments to dynamic_slice must be scalars,  got indices (ShapedArray(int32[5]), ShapedArray(int32[]))

In [3]:
!pip install zstandard

Collecting zstandard
  Downloading zstandard-0.20.0-cp39-cp39-macosx_10_9_x86_64.whl (456 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m456.0/456.0 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: zstandard
Successfully installed zstandard-0.20.0


In [24]:
import zlib
import zstandard as zstd
import numpy as np
import sys

# Create some sample data to compress
data = np.zeros((128, 3, 86, 86), dtype=np.float32).tobytes()
print(f"Data size: {sys.getsizeof(data)} bytes")

# Compress the data using Zstd
cctx = zstd.ZstdCompressor()
compressed_data = cctx.compress(data)
print(f"Compressed data size: {sys.getsizeof(compressed_data)} bytes")

Data size: 11360289 bytes
Compressed data size: 400 bytes


In [63]:
import numpy as np
import zstandard as zstd

# Create a large NumPy array
arr = np.zeros((1024, 64, 84, 84), dtype=np.float32)
arr = arr.tobytes()

In [64]:
# Compress the array using Zstd
cctx = zstd.ZstdCompressor()
%timeit compressed_data = cctx.compress(arr)

342 ms ± 16.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [62]:
# Decompress the data using Zstd
dctx = zstd.ZstdDecompressor()
%timeit decompressed_data = np.frombuffer(dctx.decompress(compressed_data), dtype=arr.dtype)

906 ms ± 82.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
# Verify that the decompressed data matches the original array
shaped_decompressed_data = decompressed_data.reshape(arr.shape)
assert np.array_equal(shaped_decompressed_data, arr)

In [59]:
# compression factor
original_data_size = arr.nbytes
compressed_data_size = len(compressed_data)
decompressed_data_size = decompressed_data.nbytes

print(f"original_data_size: {original_data_size}")
print(f"compressed_data_size: {compressed_data_size}")
print(f"decompressed_data: {decompressed_data_size}")

# % compression
percent_compression = (original_data_size - compressed_data_size) / original_data_size * 100
percent_compression

original_data_size: 1849688064
compressed_data_size: 56467
decompressed_data: 1849688064


99.99694721498727

In [46]:
1024 * 64 * 84 * 84 * 4 #/ (1024**3)

1849688064

In [None]:
# Note to self: it is faster to reuse the compressor and decompressor objects

In [77]:
from replay_buffer import GameHistory

history = GameHistory(500)

observations = np.zeros((500, 3, 84, 84), dtype=np.float32)
actions = np.zeros((500,), dtype=np.float32)
values = np.zeros((500,), dtype=np.float32)
policy = np.zeros((500, 18), dtype=np.float32)
reward = np.zeros((500,), dtype=np.float32)
done = np.zeros((500,), dtype=np.float32)

history.init(observations, actions, values, policy, reward, done)

In [79]:
import pickle 

history_bytes = pickle.dumps(history)
history_size = len(history_bytes)
history_size_in_gb = history_size / (1024**3)

print(f"Size of history: {history_size_in_gb} GB")

Size of history: 0.039531199261546135 GB


In [81]:
cctx = zstd.ZstdCompressor()
compressed_data = cctx.compress(history_bytes)

In [82]:
compressed_data_size = len(compressed_data)
compressed_data_size_in_gb = compressed_data_size / (1024**3)
compressed_data_size_in_gb

4.0102750062942505e-06

In [83]:
# % compression
compression_per = (history_size - compressed_data_size) / history_size * 100
compression_per

99.98985541779352

In [85]:
compressed_data_size * 50_000 / (1024**3)

0.20051375031471252