In [8]:
import os
from functools import partial
from types import SimpleNamespace as SN

import jax
import jax.numpy as jnp
import numpy as np
import torch as th
import yaml
from flashbax.vault import Vault
from tqdm import tqdm

from components.episode_buffer import EpisodeBatch
from components.offline_buffer import DataSaver
from components.transforms import OneHot
from envs import REGISTRY as env_REGISTRY

In [None]:
# This notebook converts an og-marl (https://github.com/instadeepai/og-marl) vault dataset
# into .h5 files usable by the offpymarl framework.
# To download the og-marl dataset, use the Google Drive folder
# (https://drive.google.com/drive/folders/1lw-e5VwIdCtmsGWgQG902yZRArU69TrH)
# or follow https://github.com/instadeepai/og-marl/blob/main/examples/download_dataset.py
# Create an 'ogmarl_dataset' folder in the offpymarl root and place the downloaded .vlt
# vaults inside it (e.g. ogmarl_dataset/3m.vlt).

# Extra package requirements:
# jax==0.4.28
# flashbax==0.1.2

In [6]:
# Root folder holding the og-marl .vlt vaults: `<parent of cwd>/ogmarl_dataset`.
# os.path.dirname replaces the hand-rolled "/".join(split('/')[:-1]), which only worked on POSIX.
dataset_path = os.path.join(os.path.dirname(os.getcwd()), "ogmarl_dataset")

# og-marl vault uid -> offpymarl quality label (used in the output folder name).
vault_uid2quality = {
    "Good": "expert",
    "Medium": "medium",
    "Poor": "poor",
}

# You can change the following parameters according to your needs
map_name = "3m"
og_quality = "Good"        # must be one of vault_uid2quality's keys
num_traj_per_file = 10000  # number of episodes written to each output .h5 part file

if og_quality not in vault_uid2quality:
    raise ValueError(
        f"og_quality must be one of {sorted(vault_uid2quality)}, got {og_quality!r}"
    )
offpymarl_quality = vault_uid2quality[og_quality]



In [7]:
# Load the og-marl vault `<dataset_path>/<map_name>.vlt/<og_quality>` and keep its experience pytree.
vlt = Vault(rel_dir=dataset_path, vault_name=f"{map_name}.vlt", vault_uid=og_quality)
all_data = vlt.read()
offline_data = all_data.experience
# Show the shape of every leaf. jax.tree_util.tree_map replaces the deprecated jax.tree_map alias.
jax.tree_util.tree_map(lambda x: x.shape, offline_data)

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


Loading vault found at /home/zzq/Project/GitProject/offpymarl/ogmarl_dataset/3m.vlt/Good


  jax.tree_map(lambda x: x.shape, offline_data)


{'actions': (1, 996366, 3),
 'infos': {'legals': (1, 996366, 3, 9), 'state': (1, 996366, 48)},
 'observations': (1, 996366, 3, 30),
 'rewards': (1, 996366, 3),
 'terminals': (1, 996366, 3),
 'truncations': (1, 996366, 3)}

In [10]:
# Load the sc2 environment config and point it at the selected map.
# yaml.safe_load: calling yaml.load without an explicit Loader is deprecated and
# raises TypeError on PyYAML >= 6.0.
with open("config/envs/sc2.yaml", "r") as f:
    env_config = yaml.safe_load(f)
env_args = SN(**env_config)
env_args.env_args['map_name'] = map_name

  env_config = yaml.load(f)


In [12]:
# env_args.env_args

In [13]:
# Instantiate the registered environment from the config, then mirror every
# env_info entry onto env_args so later cells can read e.g. env_args.n_agents.
env = env_REGISTRY[env_args.env](**env_args.env_args)
env_info = env.get_env_info()
for info_key in env_info:
    setattr(env_args, info_key, env_info[info_key])

In [14]:
# Per-field layout of an EpisodeBatch: "vshape" is the per-step shape; fields in the
# "agents" group get an extra per-agent dimension of size n_agents (see `groups`).
scheme = dict(
    state=dict(vshape=env_info["state_shape"]),
    obs=dict(vshape=env_info["obs_shape"], group="agents"),
    actions=dict(vshape=(1,), group="agents", dtype=th.long),
    avail_actions=dict(vshape=(env_info["n_actions"],), group="agents", dtype=th.int),
    reward=dict(vshape=(1,)),
    terminated=dict(vshape=(1,), dtype=th.uint8),
    corrected_terminated=dict(vshape=(1,), dtype=th.uint8),
)
groups = dict(agents=env_args.n_agents)
# Derive a one-hot "actions_onehot" field from the integer "actions" field.
preprocess = dict(actions=("actions_onehot", [OneHot(out_dim=env_args.n_actions)]))

In [15]:
# Display the scheme to sanity-check its vshapes/dtypes against the og-marl array shapes shown above.
scheme

{'state': {'vshape': 48},
 'obs': {'vshape': 30, 'group': 'agents'},
 'actions': {'vshape': (1,), 'group': 'agents', 'dtype': torch.int64},
 'avail_actions': {'vshape': (9,), 'group': 'agents', 'dtype': torch.int32},
 'reward': {'vshape': (1,)},
 'terminated': {'vshape': (1,), 'dtype': torch.uint8},
 'corrected_terminated': {'vshape': (1,), 'dtype': torch.uint8}}

In [16]:

# Factory producing a fresh, empty single-episode batch on the CPU for each episode.
# NOTE(review): the positional 1 and episode_limit + 1 are presumably EpisodeBatch's
# batch_size and max_seq_length — confirm against its signature.
episode_limit = env.episode_limit
new_batch_fn = partial(
    EpisodeBatch,
    scheme,
    groups,
    1,
    episode_limit + 1,
    preprocess=preprocess,
    device="cpu",
)


In [17]:
# from jnp.array -> np.array
avail_actions = offline_data["infos"]["legals"]
states = offline_data["infos"]["state"]
terminated = jnp.maximum(offline_data["terminals"], offline_data["truncations"])[..., 0]
actions = offline_data["actions"]
observations = offline_data["observations"]
rewards = offline_data["rewards"][..., 0]
print(avail_actions.shape, states.shape, terminated.shape, actions.shape, observations.shape, rewards.shape)

avail_actions, states, terminated, actions, observations, rewards = np.asarray(avail_actions), np.asarray(states), np.asarray(terminated), np.asarray(actions), np.asarray(observations), np.asarray(rewards)

(1, 996366, 3, 9) (1, 996366, 48) (1, 996366) (1, 996366, 3) (1, 996366, 3, 30) (1, 996366)


In [18]:
# The arrays were already converted to numpy in the previous cell; the duplicate
# np.asarray conversion that used to live here was a no-op and has been removed.
# `terminated` is laid out as (1, T): np.nonzero returns (batch_indices, time_indices),
# and the time indices mark the last timestep of each episode.
assert terminated.ndim == 2 and terminated.shape[0] == 1, terminated.shape
episode_idxs = np.nonzero(terminated)[1]

In [19]:
# Split og-marl's flat (1, T, ...) timeline into one EpisodeBatch per episode and stream the
# episodes into .h5 part files of `num_traj_per_file` episodes each.
# Resolves to the same folder as "../ogmarl_dataset/sc2/..." relative to the cwd, but is
# derived from the configured dataset root instead of a second hard-coded path.
save_path = os.path.join(dataset_path, "sc2", map_name, offpymarl_quality)

offline_saver = DataSaver(save_path, None, num_traj_per_file)

start_idx = 0
for end_idx in tqdm(episode_idxs):
    # Timesteps start_idx..end_idx (inclusive) belong to this episode.
    episode_len = int(end_idx) - start_idx + 1
    episode_slice = slice(start_idx, end_idx + 1)

    episode_terminated = terminated[:, episode_slice]
    # Notice: unlike "episode_runner", there is no extra trailing step holding s_{t+1}.
    # For a terminated s_t, Q(s_t, a_t) would still be updated with Q(s_{t+1}, ...),
    # and s_{t+1} cannot be seen here, so a new terminal is forced on the second-to-last step.
    # max(..., 0) keeps a length-1 episode from raising IndexError (its only step is
    # already terminal).
    episode_corrected_terminated = episode_terminated.copy()
    episode_corrected_terminated[0, max(episode_len - 2, 0)] = 1

    transition_data = {
        "state": states[:, episode_slice],
        "obs": observations[:, episode_slice],
        "actions": actions[:, episode_slice],
        "avail_actions": avail_actions[:, episode_slice],
        "reward": rewards[:, episode_slice],
        "terminated": episode_terminated,
        "corrected_terminated": episode_corrected_terminated,
    }

    episode_batch = new_batch_fn()
    episode_batch.update(transition_data, ts=slice(0, episode_len))
    offline_saver.append(data={
        k: episode_batch[k].clone().cpu() for k in episode_batch.data.transition_data.keys()
    })

    start_idx = end_idx + 1

# Timesteps after the last terminal flag do not form a complete episode and are not saved.
n_trailing_steps = terminated.shape[1] - start_idx
if n_trailing_steps > 0:
    print(f"Dropped {n_trailing_steps} trailing timesteps with no terminal flag.")

# Writes the remaining (< num_traj_per_file) episodes as the final part file.
offline_saver.close()


 23%|██▎       | 10000/43559 [00:32<26:12, 21.34it/s] 

Save offline buffer to ../ogmarl_dataset/sc2/3m/expert/part_0.h5 with 10000 episodes


 46%|████▌     | 20000/43559 [01:05<15:27, 25.40it/s]  

Save offline buffer to ../ogmarl_dataset/sc2/3m/expert/part_1.h5 with 10000 episodes


 69%|██████▉   | 30000/43559 [01:38<11:49, 19.11it/s]  

Save offline buffer to ../ogmarl_dataset/sc2/3m/expert/part_2.h5 with 10000 episodes


 92%|█████████▏| 40000/43559 [02:08<02:10, 27.20it/s]  

Save offline buffer to ../ogmarl_dataset/sc2/3m/expert/part_3.h5 with 10000 episodes


100%|██████████| 43559/43559 [02:09<00:00, 335.25it/s] 


Save offline buffer to ../ogmarl_dataset/sc2/3m/expert/part_4.h5 with 3559 episodes
