In [4]:
import os
import glob
import h5py
import datetime
import numpy as np
from tqdm import tqdm, trange

normalize_acts = True
normalize_obs = False

def shortest_angle(angles):
    """Wrap angles (in radians) into the half-open interval [-pi, pi).

    Args:
        angles: scalar or numpy array of angles in radians.

    Returns:
        Same shape as ``angles``; each value is the equivalent angle with the
        smallest magnitude (an input of exactly +pi maps to -pi).
    """
    full_turn = 2*np.pi
    # Shift by pi so the modulo lands in [0, 2*pi), then shift back.
    shifted = angles + np.pi
    return shifted % full_turn - np.pi

def normalize(arr, stats, key):
    """Linearly rescale ``arr`` from [min, max] to [-1, 1].

    Args:
        arr: array whose trailing dimension(s) broadcast against the stats.
        stats: mapping that holds ``f"{key}_min"`` and ``f"{key}_max"``.
        key: prefix selecting which min/max pair in ``stats`` to use.

    Returns:
        The rescaled array. Dimensions whose min equals their max (constant
        features) map to -1 instead of NaN/inf; ``unnormalize`` maps that
        value back to the constant.

    Raises:
        KeyError: if the min/max entries for ``key`` are missing from ``stats``.
    """
    min_val, max_val = stats[f"{key}_min"], stats[f"{key}_max"]
    span = np.asarray(max_val - min_val)
    # A constant dimension has zero range; dividing by it would write NaN/inf
    # into the dataset. Substitute 1 there so the result is a finite -1.
    safe_span = np.where(span == 0, np.ones_like(span), span)
    return 2 * (arr - min_val) / safe_span - 1

def unnormalize(arr, stats, key):
    """Map values from [-1, 1] back to the original [min, max] range.

    Args:
        arr: array of normalized values.
        stats: mapping that holds ``f"{key}_min"`` and ``f"{key}_max"``.
        key: prefix selecting which min/max pair in ``stats`` to use.

    Returns:
        Array in the original value range (-1 -> min, +1 -> max).
    """
    lower = stats[f"{key}_min"]
    upper = stats[f"{key}_max"]
    # First bring [-1, 1] to [0, 1], then stretch and offset.
    unit = 0.5 * (arr + 1)
    return unit * (upper - lower) + lower

# create dataset
dataset_path = "data/sim2real_sl_1k/"

print("Loading data and compute stats ...")
# Min/max statistics are computed on the training split only.
# sorted() makes the file order deterministic (glob order is unspecified).
file_names = sorted(glob.glob(os.path.join(dataset_path, "train", "episode_*.npy")))
if not file_names:
    raise FileNotFoundError(
        f"No training episodes found in {os.path.join(dataset_path, 'train')}"
    )

# Keep only the per-step arrays needed for the statistics instead of holding
# every fully loaded episode in memory.
stat_keys = ["action", "lowdim_ee", "lowdim_qpos"]
per_episode = {stat_key: [] for stat_key in stat_keys}
for file in tqdm(file_names):
    # NOTE(security): allow_pickle=True can execute arbitrary code while
    # loading; only use this on episode files from a trusted source.
    episode = np.load(file, allow_pickle=True)
    for stat_key in stat_keys:
        per_episode[stat_key].append(np.stack([step[stat_key] for step in episode]))

stats = {}

# compute action min and max over all training steps
actions = np.concatenate(per_episode["action"])
# Columns 3:6 are wrapped as angles before taking min/max, matching what is
# done to the actions when they are written to the dataset
# (presumably the rotation components of the action -- TODO confirm).
actions[..., 3:6] = shortest_angle(actions[..., 3:6])
stats["actions_min"] = actions.min(axis=0)
stats["actions_max"] = actions.max(axis=0)
del actions

for key in ["lowdim_ee", "lowdim_qpos"]:
    stack = np.concatenate(per_episode[key])
    stats[f"{key}_min"] = stack.min(axis=0)
    stats[f"{key}_max"] = stack.max(axis=0)
    del stack

del per_episode

print("Processing data ...")
hdf5_path = os.path.join(dataset_path, "demos.hdf5")

# The context manager guarantees the file is flushed and closed even if an
# episode fails to load or convert part-way through.
with h5py.File(hdf5_path, "w") as f:

    # "data" holds one group per episode, "mask" holds the demo keys per split
    grp = f.create_group("data")
    grp_mask = f.create_group("mask")

    # running counter so demo keys are unique across both splits
    episodes = 0

    for split in ["train", "eval"]:

        # gather filenames; sorted so the demo_<n> assignment is reproducible
        # (glob returns files in an unspecified order)
        file_names = sorted(glob.glob(os.path.join(dataset_path, split, "episode_*.npy")))
        demo_keys = []

        for i in trange(len(file_names)):

            # load data
            # NOTE(security): allow_pickle=True can execute arbitrary code
            # while loading; only use this on trusted episode files.
            data = np.load(file_names[i], allow_pickle=True)

            # every step is a mapping; the first step defines the keys
            obs_keys = data[0].keys()
            actions = np.stack([d["action"] for d in data])

            # create demo group
            demo_key = f"demo_{episodes}"
            demo_keys.append(demo_key)
            ep_data_grp = grp.create_group(demo_key)

            # compute shortest angle
            actions[..., 3:6] = shortest_angle(actions[..., 3:6])
            # normalize -> [-1,1]; the stats come from the train split, so
            # eval actions may fall slightly outside that range
            if normalize_acts:
                actions = normalize(actions, stats, key="actions")

            # add action dataset
            ep_data_grp.create_dataset("actions", data=actions)

            # add done dataset: only the last step of the episode is terminal
            dones = np.zeros(len(actions), dtype=bool)
            dones[-1] = True
            ep_data_grp.create_dataset("dones", data=dones)

            # create obs group (no separate next_obs group is written)
            ep_obs_grp = ep_data_grp.create_group("obs")

            # add obs datasets
            for obs_key in obs_keys:
                if obs_key == "language_instruction":
                    continue
                obs = np.stack([d[obs_key] for d in data])
                # normalize -> [-1,1]
                # stats is keyed by "<obs_key>_min" / "<obs_key>_max", so test
                # for that entry; testing for the bare obs_key never matches
                # and would silently skip normalization.
                if normalize_obs and f"{obs_key}_min" in stats:
                    obs = normalize(obs, stats, key=obs_key)
                ep_obs_grp.create_dataset(obs_key, data=obs)

            ep_data_grp.attrs["num_samples"] = len(actions)

            episodes += 1

        # create mask dataset listing the demo keys that belong to this split
        grp_mask.create_dataset(split, data=np.array(demo_keys, dtype="S"))

    # write dataset attributes (metadata)
    grp.attrs["episodes"] = episodes
    # placeholder value; no real environment arguments are recorded
    grp.attrs["env_args"] = "blub"

    # store the normalization statistics so values can be unnormalized later
    stats_grp = grp.create_group("stats")
    for key in stats.keys():
        stats_grp.create_dataset(key, data=stats[key])

    now = datetime.datetime.now()
    grp.attrs["date"] = "{}-{}-{}".format(now.month, now.day, now.year)
    grp.attrs["time"] = "{}:{}:{}".format(now.hour, now.minute, now.second)

print("Saved at: {}".format(hdf5_path))

Loading data and compute stats ...


100%|██████████| 1000/1000 [00:05<00:00, 179.24it/s]


Processing data ...


100%|██████████| 1000/1000 [00:08<00:00, 117.53it/s]
100%|██████████| 100/100 [00:01<00:00, 80.87it/s]


Saved at: data/sim2real_sl_1k/demos.hdf5


In [7]:
# Show the per-dimension action bounds used for normalization (min first, then max).
tuple(stats[f"actions_{bound}"] for bound in ("min", "max"))

(array([-0.01337784, -0.01347073, -0.02040082, -0.03218729, -0.0674465 ,
        -0.0447364 ,  0.        ]),
 array([0.01054632, 0.0194168 , 0.01450608, 0.04999172, 0.04921545,
        0.03100239, 1.        ]))