In [2]:
import jax
import jax.numpy as jnp
from flax import struct


In [3]:
@struct.dataclass
class Rollout:
    """A class for storing batched rollout data with methods for padding"""
    obs: jnp.ndarray
    actions: jnp.ndarray
    value_targets: jnp.ndarray
    policy_targets: jnp.ndarray
    rewards: jnp.ndarray
    dones: jnp.ndarray
    priorities: jnp.ndarray

@struct.dataclass
class GameHistory:
    observations: jnp.ndarray
    actions: jnp.ndarray
    values: jnp.ndarray
    policies: jnp.ndarray
    rewards: jnp.ndarray
    dones: jnp.ndarray


local_num_envs = 100
num_steps = 500
num_actions = 4

In [4]:
import zstandard as zstd
from jax import flatten_util

In [32]:
# flatten pytree to a single array
rollout = Rollout(
    obs = jnp.zeros((local_num_envs, num_steps, 3, 84, 84)),
    actions = jnp.zeros((local_num_envs, num_steps)),
    value_targets = jnp.zeros((local_num_envs, num_steps)),
    policy_targets = jnp.zeros((local_num_envs, num_steps, num_actions)),
    rewards = jnp.zeros((local_num_envs, num_steps)),
    dones = jnp.zeros((local_num_envs, num_steps)),
    priorities = jnp.zeros((local_num_envs, num_steps)),
)

In [35]:
50_000 * 200

10000000

In [36]:
flatten_fn = jax.vmap(lambda x: flatten_util.ravel_pytree(x)[0]) # batched rollout to flattened rollouts
flat_example, unflatten_fn = flatten_util.ravel_pytree(rollout)
unflatten_fn = jax.vmap(unflatten_fn) #

In [37]:
flat_rollout = flatten_fn(rollout)
flat_rollout_bytes = flat_rollout.tobytes()
print("type: ", type(flat_rollout))
print("shape: ", flat_rollout.shape)
print("nbytes: ", flat_rollout.nbytes)
print("tobytes nbytes: ", len(flat_rollout_bytes))
print("ngbytes: ", flat_rollout.nbytes / 1028**3)

type:  <class 'jaxlib.xla_extension.ArrayImpl'>
shape:  (100, 10588500)
nbytes:  4235400000
tobytes nbytes:  4235400000
ngbytes:  3.8986575407139363


In [21]:
compressor = zstd.ZstdCompressor()
decompressor = zstd.ZstdDecompressor()

In [22]:
compressed_flat_rollout = compressor.compress(flat_rollout_bytes)
decompressed_flat_rollout_bytes = decompressor.decompress(compressed_flat_rollout)
flat_rollout_reconstruciton = jnp.frombuffer(decompressed_flat_rollout_bytes, dtype=jnp.float32)

25871

In [28]:
rollout = unflatten_fn(flat_rollout_reconstruciton)

In [None]:
# put rollout on learner device and add rollout to queue
# transfer rollout to (game_history, priority) tuple
# add (game_history, priority) to replay buffer

### Prototyping 

In [1]:
import jax
import jax.numpy as jnp
from flax import struct
import zstandard as zstd
from jax import flatten_util

local_num_envs = 100
num_steps = 500
num_actions = 4

@struct.dataclass
class Batch:
    observation: jnp.ndarray
    action: jnp.ndarray
    value: jnp.ndarray
    policy: jnp.ndarray
    reward: jnp.ndarray
    weight: jnp.ndarray
    index: jnp.ndarray


@struct.dataclass
class GameHistory:
    observations: jnp.ndarray
    actions: jnp.ndarray
    values: jnp.ndarray
    policies: jnp.ndarray
    rewards: jnp.ndarray
    dones: jnp.ndarray
    
    def index(self, i):
        return Batch(
            observations=self.observations[i],
            actions=self.actions[i],
            values=self.values[i],
            policies=self.policies[i],
            rewards=self.rewards[i],
            dones=self.dones[i],
        )

In [4]:
game_history = GameHistory(
    observations=jnp.zeros((local_num_envs, num_steps, 3, 84, 84)),
    actions = jnp.zeros((local_num_envs, num_steps)),
    values = jnp.zeros((local_num_envs, num_steps)),
    policies = jnp.zeros((local_num_envs, num_steps, num_actions)),
    rewards = jnp.zeros((local_num_envs, num_steps)),
    dones = jnp.zeros((local_num_envs, num_steps)),
)

test = flatten_util.ravel_pytree(game_history)[0]
bytes = test.tobytes()
print("GB: ", len(bytes) / 1024**3)

GB:  3.9443373680114746


In [3]:
class ReplayBuffer:
    """
    - priorities is of the shape (compression_dim, batch_dim, time_dim)
    """
    def __init__(self):
        self.batch_size
        self.timesteps_seen = 0

        self.compressor = zstd.ZstdCompressor()
        self.decompressor = zstd.ZstdDecompressor()

        self.buffer = []
        self.priorities = None

    def add(self, b_game_history, b_priority):
        if not self._started():
            self._setup(b_game_history)
            self.priorities = jnp.expand_dims(b_priority, axis=0) # add compression dim
        else:
            b_priority = jnp.expand_dims(b_priority, axis=0) # add compression dim
            self.priorities = jnp.concatenate((self.priorities, b_priority), axis=0) # concat along compression dim

        b_compressed = self._compress(b_game_history)
        self.buffer.append(b_compressed)
        del b_game_history

    def sample(self, key):
        index, weights = self._sample(key)
        buffer_index, batch_index, time_index = index
        decompression_index = jnp.unique(buffer_index)
        batch = self.emtpy_batch(index, weights)
        for i in decompression_index:
            b_game_history = self._decompress(self.buffer[i])
            b_index = jnp.where(buffer_index == i)
            b_batch_index = batch_index[b_index]
            b_time_index = time_index[b_index]
            batch = b_game_history.index(batch, b_batch_index, b_time_index)

        del b_game_history
        return batch
    
    # cant jit unless priorities shape is static
    def _sample(self, key):
        flat_priorities = self.priorities.ravel()
        probs /= flat_priorities.sum()
        flat_index = jax.random.choice(key, len(flat_priorities), shape=self.batch_size, replace=False, p=probs)
        sampled_weights = (1 / len(flat_priorities)) / probs[flat_index]
        sampled_index = jnp.unravel_index(flat_index, self.priorities.shape) 
        return sampled_index, sampled_weights

    def _compress(self, data):
        flat = self.flatten_fn(data)
        bytes = flat.tobytes()
        compressed = self.compressor.compress(bytes)
        del data, flat, bytes
        return compressed

    def _decompress(self, data):
        bytes = self.decompressor.decompress(data)
        flat = jnp.frombuffer(bytes, dtype=jnp.float32)
        d_data = self.unflatten_fn(flat)
        del data, bytes, flat
        return d_data

    def _setup(self, b_game_history):
        self.flatten_fn = lambda x: flatten_util.ravel_pytree(x)[0]
        self.flat_b_game_history_example, self.unflatten_fn = flatten_util.ravel_pytree(b_game_history)
        
        self.flatten_fn = jax.jit(self.flatten_fn)
        self.unflatten_fn = jax.jit(self.unflatten_fn)

    def emtpy_batch(self, index, weights):
        return Batch(
            observation=jnp.zeros((self.batch_size, 3, 84, 84)),
            action=jnp.zeros((self.batch_size)),
            value=jnp.zeros((self.batch_size)),
            policy=jnp.zeros((self.batch_size, self.num_actions)),
            reward=jnp.zeros((self.batch_size)),
            done=jnp.zeros((self.batch_size)),
            weights=weights,
            index=index,
        )

    def _started(self):
        return self.timesteps_seen > 0

In [5]:
replay_buffer = ReplayBuffer()

priorities = jnp.ones((local_num_envs, num_steps))
replay_buffer.add(game_history, priorities)

# Testing 
- Goal
    - determine which is faster
        - performing more zstd operations on smaller arrays 
        - performing fewer zstd operations on larger arrays

    - this will determine how to build the replay buffer
        - hypothesis: fewer zstd operations on larger arrays is faster

In [4]:
import time

compressor = zstd.ZstdCompressor()
decompressor = zstd.ZstdDecompressor()

def simple_timit(fn, args, num_trials):
    times = []
    for _ in range(num_trials):
        start = time.time()
        fn(args)
        end = time.time()
        times.append(end - start)
    print("mean: ", sum(times) / len(times))
    print("min: ", min(times))
    print("max: ", max(times))
    print(times)

batched_flatten_fn = jax.vmap(lambda x: flatten_util.ravel_pytree(x)[0]) 

def more_smaller_ops(b_flat):
    compressed_list = []
    for i in range(len(b_flat)):
        flat = b_flat[i]
        bytes = flat.tobytes()
        compressed = compressor.compress(bytes)
        compressed_list.append(compressed)

    # for i in range(len(compressed_list)):
    #     compressed = compressed_list[i]
    #     bytes = decompressor.decompress(compressed)
    #     flat = jnp.frombuffer(bytes, dtype=jnp.float32)


flatten_fn = lambda x: flatten_util.ravel_pytree(x)[0]


def less_bigger_ops(arr):
    compressor = zstd.ZstdCompressor()
    decompressor = zstd.ZstdDecompressor()
    flat = flatten_fn(arr)
    bytes = flat.tobytes()
    compressed = compressor.compress(bytes)
    bytes = decompressor.decompress(compressed)
    flat = jnp.frombuffer(bytes, dtype=jnp.float32)
    return flat

In [12]:
# more smaller ops
compressor = zstd.ZstdCompressor()
decompressor = zstd.ZstdDecompressor()

def compress_batch(b_flat):
    compressed_list = []
    for i in range(len(b_flat)):
        flat = b_flat[i]
        bytes = flat.tobytes()
        compressed = compressor.compress(bytes)
        compressed_list.append(compressed)

b_flat = batched_flatten_fn(game_history)
simple_timit(compress_batch, b_flat, 10)

mean:  6.7116758823394775
min:  6.511224269866943
max:  7.019108295440674
[6.511224269866943, 6.511491298675537, 6.555873870849609, 6.587079048156738, 6.634561538696289, 6.790647745132446, 7.019108295440674, 6.890180349349976, 6.8210227489471436, 6.795569658279419]


In [19]:
compressor = zstd.ZstdCompressor()
decompressor = zstd.ZstdDecompressor()

flat = b_flat
bytes = flat.tobytes()
compressed_bytes = compressor.compress(bytes)

In [20]:

def decompress(bytes):
    #bytes = flat.tobytes()
    bytes = decompressor.decompress(bytes)
    flat = jnp.frombuffer(bytes, dtype=jnp.float32)

In [21]:
simple_timit(decompress, compressed_bytes, 3)

mean:  2.9494717121124268
min:  2.8094582557678223
max:  3.0224905014038086
[2.8094582557678223, 3.0224905014038086, 3.0164663791656494]


# special indexing

In [5]:
import jax.numpy as jnp
arr = jnp.arange(1000).reshape(20, 5, 10)

arr_flat = arr.ravel()

In [22]:
flat_index = jnp.array([500, 600, 700])

index = jnp.unravel_index(flat_index, arr.shape)
index

(Array([10, 12, 14], dtype=int32),
 Array([0, 0, 0], dtype=int32),
 Array([0, 0, 0], dtype=int32))

In [25]:
# take a slice of size 5 on the last dim
index = (index[0], index[1], slice(index[2], index[2] + 5))
test = arr[index]

TypeError: unsupported operand type(s) for +: 'slice' and 'int'

: 

In [24]:
test.shape

(3, 2)