# Dataset API Demo

This simple notebook demonstrates how to use OG-MARL's dataset API, which is underpinned by [Flashbax](https://github.com/instadeepai/flashbax)'s Vault utility.

For this example, we'll download the `3m` dataset from the `smac_v1` environment. The zipped file is about 1.3GB in size.

In [None]:
!wget https://s3.kao.instadeep.io/offline-marl-dataset/vaults/3m.zip --show-progress

In [2]:
%%capture
!unzip 3m.zip -d vaults

We should now have a `vaults` directory containing the `3m.vlt` vault, which itself holds three datasets: `Good`, `Medium`, and `Poor`.

In [None]:
!ls -la vaults/3m.vlt

We'll take a look at the `Good` dataset in this example, but the methodology will apply to any of OG-MARL's Vault-style datasets.

Before continuing, we need to install Flashbax, which is the only necessary dependency. For our example, we'll also use `jax` and `jax.numpy`.

In [14]:
%%capture
! pip install flashbax~=0.1.2

import jax
import jax.numpy as jnp
import flashbax as fbx
from flashbax.vault import Vault

Now we can load the Vault. Notice how the keyword arguments map to the dataset's location on disk (`vaults/3m.vlt/Good`): `rel_dir` is the root directory of all vaults; `vault_name` is the vault holding all datasets for one scenario (here `3m`); `vault_uid` is the unique identifier of a dataset within that vault.
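Since each dataset lives at `<rel_dir>/<vault_name>/<vault_uid>`, you can discover which UIDs a vault offers before loading one. The helper below is hypothetical (`list_vault_uids` is not part of Flashbax or OG-MARL) and assumes, as in this notebook, that every dataset UID is a subdirectory of the vault directory.

```python
import os


def list_vault_uids(rel_dir: str, vault_name: str) -> list:
    """Return the dataset UIDs (e.g. "Good", "Medium", "Poor") stored in a vault.

    Assumes the on-disk layout <rel_dir>/<vault_name>/<vault_uid>/ used in this notebook.
    """
    vault_path = os.path.join(rel_dir, vault_name)
    # Keep only subdirectories: each one is assumed to be a dataset UID.
    return sorted(
        entry
        for entry in os.listdir(vault_path)
        if os.path.isdir(os.path.join(vault_path, entry))
    )


# Usage in this notebook: list_vault_uids("vaults", "3m.vlt")
```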

In [None]:
vlt = Vault(rel_dir="vaults", vault_name="3m.vlt", vault_uid="Good")

We can read this Vault using `.read()`. By default, we read the entire dataset.

In [9]:
all_data = vlt.read()

The read data has the structure of a Flashbax `TrajectoryBufferState`, with auxiliary fields `.current_index` and `.is_full`. In this example we only need `.experience`, which holds the experience data itself.

In [38]:
offline_data = all_data.experience

Let's look at the structure of this dataset.

In [39]:
jax.tree_util.tree_map(lambda x: x.shape, offline_data)

{'actions': (1, 996366, 3),
 'infos': {'legals': (1, 996366, 3, 9), 'state': (1, 996366, 48)},
 'observations': (1, 996366, 3, 30),
 'rewards': (1, 996366, 3),
 'terminals': (1, 996366, 3),
 'truncations': (1, 996366, 3)}

This data is stored with shape $(B, T, N, *E)$, where $B$ is a stored batch dimension (useful for storing independent trajectories synchronously), $T$ is the time axis, $N$ is the number of agents, and $*E$ denotes one or more experience dimensions. For example, `observations` has a batch dimension of size `1`, `996366` timesteps and `3` agents, each with an observation of size `30`. The global `state` in `infos` is shared by all agents, so it has no $N$ axis: its shape is $(B, T, 48)$.
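Because $B = 1$ here, it can be convenient to drop the stored batch axis so that every leaf is indexed as `[t, agent, ...]`. The sketch below does this in NumPy on a synthetic nested dict with the same layout (`squeeze_batch_axis` is a hypothetical helper, and $T$ is shortened to 100). With JAX you could equally write `jax.tree_util.tree_map(lambda x: x[0], offline_data)`.

```python
import numpy as np


def squeeze_batch_axis(tree):
    """Recursively drop the leading stored-batch axis (size 1) from every array in a nested dict."""
    if isinstance(tree, dict):
        return {k: squeeze_batch_axis(v) for k, v in tree.items()}
    if tree.shape[0] != 1:
        raise ValueError(f"expected a stored batch dimension of size 1, got {tree.shape[0]}")
    return tree[0]


# Synthetic data with the (B, T, N, *E) layout; T is shortened to 100 for illustration.
T, N = 100, 3
data = {
    "observations": np.zeros((1, T, N, 30), dtype=np.float32),
    "rewards": np.zeros((1, T, N), dtype=np.float32),
    "infos": {"state": np.zeros((1, T, 48), dtype=np.float32)},
}
flat = squeeze_batch_axis(data)
# Agent 2's observation at timestep 10 is now flat["observations"][10, 2].
print(flat["observations"].shape, flat["infos"]["state"].shape)  # (100, 3, 30) (100, 48)
```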

As another illustrative example, let's look at the first `25` timesteps of the `terminals`.

In [40]:
offline_data['terminals'][:, 0:25, ...]

Array([[[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [1., 1., 1.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]]], dtype=float32)

At index 19 (the 20th timestep) every agent's terminal flag is `1.`, which marks the end of the first episode. We can then, for example, compute the return of this first episode. The `for` loop below is for illustration only; a vectorised approach would be much faster.

In [62]:
# Per-agent running return, shape (1, 3)
returns = jnp.zeros_like(offline_data['rewards'][:, 0, ...])
for t in range(offline_data['rewards'].shape[1]):
    reward = offline_data['rewards'][:, t, ...]
    print(f"Reward at {t}th step: {reward}")
    returns += reward
    # Stop once every agent's terminal flag is set: the episode has ended
    terminal_flag = offline_data['terminals'][:, t, ...]
    if terminal_flag.all():
        break
print(f"Episode return: {returns}")

Reward at 0th step: [[0. 0. 0.]]
Reward at 1th step: [[0. 0. 0.]]
Reward at 2th step: [[0. 0. 0.]]
Reward at 3th step: [[0. 0. 0.]]
Reward at 4th step: [[0.32876712 0.32876712 0.32876712]]
Reward at 5th step: [[0.98630136 0.98630136 0.98630136]]
Reward at 6th step: [[0.32876712 0.32876712 0.32876712]]
Reward at 7th step: [[0.65753424 0.65753424 0.65753424]]
Reward at 8th step: [[0.32876712 0.32876712 0.32876712]]
Reward at 9th step: [[0.65753424 0.65753424 0.65753424]]
Reward at 10th step: [[0.7123288 0.7123288 0.7123288]]
Reward at 11th step: [[0.65753424 0.65753424 0.65753424]]
Reward at 12th step: [[0.32876712 0.32876712 0.32876712]]
Reward at 13th step: [[0.65753424 0.65753424 0.65753424]]
Reward at 14th step: [[1.0410959 1.0410959 1.0410959]]
Reward at 15th step: [[0.32876712 0.32876712 0.32876712]]
Reward at 16th step: [[0.98630136 0.98630136 0.98630136]]
Reward at 17th step: [[0.32876712 0.32876712 0.32876712]]
Reward at 18th step: [[0. 0. 0.]]
Reward at 19th step: [[11.671233 1
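The loop above handles one episode at a time. A vectorised alternative computes every episode's per-agent return in one pass: flag the timesteps where all agents are terminal or truncated, give each timestep an episode index via a cumulative sum, then scatter-add the rewards. This is a NumPy sketch for a single stored batch row (e.g. `np.asarray(offline_data['rewards'][0])`); `compute_episode_returns` is hypothetical, treating truncations as episode ends is our assumption, and the trailing unfinished episode is dropped.

```python
import numpy as np


def compute_episode_returns(rewards, terminals, truncations):
    """Per-agent return of every complete episode.

    rewards, terminals, truncations: arrays of shape (T, N) for one stored batch row.
    Returns an array of shape (num_complete_episodes, N).
    """
    # An episode ends at t when every agent is terminal or truncated (assumption).
    done = np.logical_or(terminals, truncations).all(axis=-1).astype(np.int64)  # (T,)
    # Episode index of each timestep = number of episode ends strictly before t.
    episode_id = np.cumsum(done) - done
    num_complete = int(done.sum())
    # One extra slot collects the trailing, unfinished episode (if any).
    totals = np.zeros((num_complete + 1, rewards.shape[-1]), dtype=np.float64)
    # Unbuffered scatter-add: totals[episode_id[t]] += rewards[t] for every t.
    np.add.at(totals, episode_id, rewards)
    return totals[:num_complete]


# Usage with the Vault data (converted to NumPy first):
# compute_episode_returns(np.asarray(offline_data["rewards"][0]),
#                         np.asarray(offline_data["terminals"][0]),
#                         np.asarray(offline_data["truncations"][0]))
```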

We can also inspect a single timestep easily:

In [45]:
jax.tree_util.tree_map(lambda x: x[:, 19, ...], offline_data)

{'actions': Array([[6, 0, 6]], dtype=int32),
 'infos': {'legals': Array([[[0., 1., 1., 1., 1., 1., 1., 0., 0.],
          [1., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 1., 1., 1., 1., 1., 1., 0., 0.]]], dtype=float32),
  'state': Array([[ 0.2       ,  0.        , -0.04060582,  0.05044992,  0.        ,
           0.        ,  0.        ,  0.        ,  0.2       ,  0.        ,
          -0.06469727,  0.        ,  0.06666667,  0.11004639,  0.002485  ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  1.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  1.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           1.        ,  0.        ,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        ]], dtype=float32)},
 'observations': Array
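`legals` looks like a per-agent 0/1 mask over the 9 discrete actions, so a natural sanity check is that every recorded action was available when it was taken. The NumPy sketch below uses the values from timestep 19 above; `actions_were_legal` is a hypothetical helper, and reading `legals` as an availability mask is our interpretation of the field.

```python
import numpy as np


def actions_were_legal(actions, legals):
    """actions: integer array (..., N); legals: 0/1 mask (..., N, num_actions).

    Returns a bool array (..., N): True where the chosen action was marked legal.
    """
    # Pick, for each agent, the mask entry at the index of the action it took.
    chosen = np.take_along_axis(legals, actions[..., None], axis=-1)[..., 0]
    return chosen == 1


actions = np.array([[6, 0, 6]], dtype=np.int32)  # shape (1, 3)
legals = np.array([[[0, 1, 1, 1, 1, 1, 1, 0, 0],
                    [1, 0, 0, 0, 0, 0, 0, 0, 0],
                    [0, 1, 1, 1, 1, 1, 1, 0, 0]]], dtype=np.float32)  # shape (1, 3, 9)
print(actions_were_legal(actions, legals).all())  # True
```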

Using the Vault data directly works well, but Flashbax adds further conveniences on top. Because the read data is already a `TrajectoryBufferState`, we can pass it straight to Flashbax's pure buffer functions. Here we build a trajectory buffer and use its `sample` function to draw a batch of sequences from the offline data.

In [51]:
BATCH_SIZE = 32
SEQUENCE_LENGTH = 20

buffer = fbx.make_trajectory_buffer(
    # Sampling parameters
    sample_batch_size=BATCH_SIZE,
    sample_sequence_length=SEQUENCE_LENGTH,
    period=1,
    # Not important in this example, as we are not adding to the buffer
    max_length_time_axis=1_000_000,
    min_length_time_axis=SEQUENCE_LENGTH,
    add_batch_size=1,
)

buffer_sample = jax.jit(buffer.sample)
seed = 0
key = jax.random.PRNGKey(seed)

samples = buffer_sample(all_data, key)

jax.tree_util.tree_map(lambda x: x.shape, samples.experience)



{'actions': (32, 20, 3),
 'infos': {'legals': (32, 20, 3, 9), 'state': (32, 20, 48)},
 'observations': (32, 20, 3, 30),
 'rewards': (32, 20, 3),
 'terminals': (32, 20, 3),
 'truncations': (32, 20, 3)}

Each leaf now has shape `(BATCH_SIZE, SEQUENCE_LENGTH, ...)`, i.e. 32 sequences of 20 consecutive timesteps, in place of the stored `(1, 996366, ...)`.
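As far as we understand Flashbax's trajectory buffer, sampled sequences are contiguous windows of the time axis (with `period=1`, a window may start at any timestep), so a window can straddle an episode boundary. If your learner needs within-episode sequences, one option is to mask out every step after the first episode end in each window. A minimal NumPy sketch; `within_episode_mask` is hypothetical, and treating an all-agent `terminals` or `truncations` flag as the episode end is an assumption.

```python
import numpy as np


def within_episode_mask(terminals, truncations):
    """Mask for sampled sequences that may cross an episode boundary.

    terminals, truncations: arrays of shape (B, S, N) from a sampled batch.
    Returns a bool array of shape (B, S): True up to and including the first
    episode end in each sequence, False afterwards.
    """
    # A step ends an episode when every agent is terminal or truncated (assumption).
    done = np.logical_or(terminals, truncations).all(axis=-1).astype(np.int64)  # (B, S)
    # Count of episode ends strictly before each step; > 0 means a new episode began.
    ends_before = np.cumsum(done, axis=1) - done
    return ends_before == 0


# Usage with the sample above (converted to NumPy first):
# mask = within_episode_mask(np.asarray(samples.experience["terminals"]),
#                            np.asarray(samples.experience["truncations"]))
```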

Although Vaults integrate tightly with a JAX-oriented ecosystem (Flashbax etc.), it is straightforward to read in the dataset and convert it to your array type of choice, for example NumPy or TensorFlow:

In [58]:
import numpy as np
all_data_np = jax.tree_util.tree_map(lambda x: np.array(x), all_data)

print(all_data_np.experience)

{'actions': array([[[4, 2, 4],
        [4, 4, 4],
        [4, 4, 4],
        ...,
        [2, 7, 2],
        [2, 7, 2],
        [2, 7, 2]]], dtype=int32), 'infos': {'legals': array([[[[0., 1., 1., ..., 0., 0., 0.],
         [0., 1., 1., ..., 0., 0., 0.],
         [0., 1., 1., ..., 0., 0., 0.]],

        [[0., 1., 1., ..., 0., 0., 0.],
         [0., 1., 1., ..., 0., 0., 0.],
         [0., 1., 1., ..., 0., 0., 0.]],

        [[0., 1., 1., ..., 0., 0., 0.],
         [0., 1., 1., ..., 0., 0., 0.],
         [0., 1., 1., ..., 0., 0., 0.]],

        ...,

        [[0., 1., 1., ..., 0., 1., 0.],
         [0., 1., 1., ..., 0., 1., 0.],
         [0., 1., 1., ..., 0., 1., 0.]],

        [[0., 1., 1., ..., 0., 1., 0.],
         [0., 1., 1., ..., 0., 1., 0.],
         [0., 1., 1., ..., 0., 1., 0.]],

        [[0., 1., 1., ..., 0., 1., 0.],
         [0., 1., 1., ..., 0., 1., 0.],
         [0., 1., 1., ..., 0., 1., 0.]]]], dtype=float32), 'state': array([[[ 1.        ,  0.        , -0.25      , ..., 

In [61]:
import tensorflow as tf
all_data_tf = jax.tree_util.tree_map(lambda x: tf.convert_to_tensor(x), all_data)

print(all_data_tf.experience)

{'actions': <tf.Tensor: shape=(1, 996366, 3), dtype=int32, numpy=
array([[[4, 2, 4],
        [4, 4, 4],
        [4, 4, 4],
        ...,
        [2, 7, 2],
        [2, 7, 2],
        [2, 7, 2]]], dtype=int32)>, 'infos': {'legals': <tf.Tensor: shape=(1, 996366, 3, 9), dtype=float32, numpy=
array([[[[0., 1., 1., ..., 0., 0., 0.],
         [0., 1., 1., ..., 0., 0., 0.],
         [0., 1., 1., ..., 0., 0., 0.]],

        [[0., 1., 1., ..., 0., 0., 0.],
         [0., 1., 1., ..., 0., 0., 0.],
         [0., 1., 1., ..., 0., 0., 0.]],

        [[0., 1., 1., ..., 0., 0., 0.],
         [0., 1., 1., ..., 0., 0., 0.],
         [0., 1., 1., ..., 0., 0., 0.]],

        ...,

        [[0., 1., 1., ..., 0., 1., 0.],
         [0., 1., 1., ..., 0., 1., 0.],
         [0., 1., 1., ..., 0., 1., 0.]],

        [[0., 1., 1., ..., 0., 1., 0.],
         [0., 1., 1., ..., 0., 1., 0.],
         [0., 1., 1., ..., 0., 1., 0.]],

        [[0., 1., 1., ..., 0., 1., 0.],
         [0., 1., 1., ..., 0., 1., 0.],
       

Notice that the code above does not depend on OG-MARL at all: data stored in Vaults is not locked into our ecosystem. That said, OG-MARL provides further tightly integrated functionality on top. For example, we can analyse the distribution of episode returns in a given vault.

In [16]:
%%capture
! pip install git+https://github.com/instadeepai/og-marl.git

In [None]:
from og_marl.offline_dataset import analyse_vault

episode_returns = analyse_vault("3m.vlt", visualise=True)
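For a quick numerical summary to complement the plot, something like the following works on any 1-D array of per-episode returns. It is a small NumPy sketch (`summarise_returns` is hypothetical); we make no assumption about the exact type `analyse_vault` returns, so convert its output to a 1-D array of returns first if needed.

```python
import numpy as np


def summarise_returns(returns):
    """Basic statistics of a 1-D array of episode returns."""
    returns = np.asarray(returns, dtype=np.float64)
    return {
        "episodes": int(returns.size),
        "mean": float(returns.mean()),
        "std": float(returns.std()),  # population standard deviation (ddof=0)
        "min": float(returns.min()),
        "median": float(np.median(returns)),
        "max": float(returns.max()),
    }
```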