In [1]:
import os
import glob
import h5py
import datetime
import numpy as np
from tqdm import tqdm, trange

# Processing options:
#   blocking_control -> replace commanded pose actions by observed end-effector deltas s_{t+1} - s_t
#   normalize_acts   -> intended to toggle min-max scaling of actions to [-1, 1]
#   normalize_obs    -> intended to toggle observation normalization (NOTE(review): not referenced in this notebook)
blocking_control, normalize_acts, normalize_obs = True, True, False

def shortest_angle(angles):
    """Wrap angles (radians) to the equivalent value in [-pi, pi).

    Works element-wise on numpy arrays as well as on scalars. Floating-point
    rounding can occasionally land exactly on the upper edge.
    """
    full_turn = 2 * np.pi
    # shift so that -pi maps to 0, wrap into [0, 2*pi), then shift back
    shifted = angles + np.pi
    return shifted % full_turn - np.pi

def normalize(arr, stats, key):
    """Min-max scale ``arr`` into [-1, 1] using stored statistics.

    Args:
        arr: array-like values to scale; broadcast against the statistics arrays
            (e.g. shape (T, D) against per-dimension stats of shape (D,)).
        stats: mapping containing ``f"{key}_min"`` and ``f"{key}_max"``.
        key: prefix selecting which statistics to use (e.g. ``"actions"``).

    Returns:
        ndarray in which the minimum maps to -1 and the maximum maps to 1.
        Dimensions with zero range (min == max) map to 0 instead of the
        NaN/inf a plain division would produce; ``unnormalize`` maps 0 back
        to ``min`` for such dimensions, so the round-trip stays consistent.
    """
    min_val = np.asarray(stats[f"{key}_min"])
    max_val = np.asarray(stats[f"{key}_max"])
    values = np.asarray(arr)  # also materializes array-likes such as h5py datasets

    value_range = max_val - min_val
    is_constant = value_range == 0
    # divide by 1 on constant dimensions to avoid 0/0, then overwrite those entries with 0
    safe_range = np.where(is_constant, 1, value_range)
    scaled = 2 * (values - min_val) / safe_range - 1
    return np.where(is_constant, 0.0, scaled)

def unnormalize(arr, stats, key):
    """Map values from [-1, 1] back to the [min, max] range stored in ``stats``.

    Args:
        arr: normalized values (array or scalar) to rescale.
        stats: mapping containing ``f"{key}_min"`` and ``f"{key}_max"``.
        key: prefix selecting which statistics to use (e.g. ``"actions"``).

    Returns:
        ``0.5 * (arr + 1) * (max - min) + min``.
    """
    lo, hi = stats[f"{key}_min"], stats[f"{key}_max"]
    span = hi - lo
    unit = 0.5 * (arr + 1)  # [-1, 1] -> [0, 1]
    return unit * span + lo

# create dataset: merge several directories of per-episode .npy files into one HDF5 file
data_dir = "data"
input_datasets = ["left_25", "right_25", "middle_25"]
splits = ["train", "eval"]
output_dataset = "left_right_middle_75"
dataset_paths = [os.path.join(data_dir, dataset_name) for dataset_name in input_datasets]

print("Processing data ...")

hdf5_path = os.path.join(data_dir, output_dataset)
os.makedirs(hdf5_path, exist_ok=True)

episodes = 0    # running demo index across all splits; also used to name demo groups
demo_keys = {}  # split name -> list of demo group names written for that split
stats = {}      # per-dimension action min/max computed from the train split

# context manager guarantees the HDF5 file is closed even if processing fails midway
with h5py.File(os.path.join(hdf5_path, "demos.hdf5"), "w") as f:

    # create data group
    grp = f.create_group("data")
    grp_mask = f.create_group("mask")

    for split in splits:

        demo_keys[split] = []

        for dataset_path in dataset_paths:

            print(f"Loading {dataset_path} {split} ...")

            # glob order is filesystem-dependent; sort so demo numbering is reproducible
            split_dir = os.path.join(dataset_path, split)
            file_names = sorted(glob.glob(os.path.join(split_dir, "episode_*.npy")))
            if not file_names:
                print(f"Warning: no episode files found in {split_dir}")

            for file_name in tqdm(file_names):

                # load data (allow_pickle unpickles arbitrary objects: only use on trusted files)
                data = np.load(file_name, allow_pickle=True)
                if len(data) == 0:
                    print(f"Warning: skipping empty episode {file_name}")
                    continue

                # stack the per-step dicts into one array per key
                # NOTE(review): every step dict has an "action" entry (read below), so "action"
                # is also part of obs_keys and gets written under obs/ — confirm that's intended.
                dic = {}
                obs_keys = data[0].keys()
                for key in obs_keys:
                    dic[key] = np.stack([d[key] for d in data])
                actions = np.stack([d["action"] for d in data])

                if blocking_control:
                    # compute actual deltas s_t+1 - s_t (keep gripper actions)
                    actions_tmp = actions.copy()
                    actions_tmp[:-1, ..., :6] = dic["lowdim_ee"][1:, ..., :6] - dic["lowdim_ee"][:-1, ..., :6]
                    actions = actions_tmp[:-1]

                    # remove last state s_T
                    for key in obs_keys:
                        dic[key] = dic[key][:-1]

                # a single-step episode has no transition left under blocking control;
                # skip it instead of crashing on dones[-1] below
                if len(actions) == 0:
                    print(f"Warning: skipping episode without transitions {file_name}")
                    continue

                # create demo group
                demo_key = f"demo_{episodes}"
                demo_keys[split].append(demo_key)
                ep_data_grp = grp.create_group(demo_key)

                # wrap rotation components (indices 3:6) to the shortest equivalent angle
                actions[..., 3:6] = shortest_angle(actions[..., 3:6])

                # add action dataset
                ep_data_grp.create_dataset("actions", data=actions)

                # add done dataset: only the final step is terminal
                dones = np.zeros(len(actions), dtype=bool)
                dones[-1] = True
                ep_data_grp.create_dataset("dones", data=dones)

                # create obs group and add one dataset per observation key
                ep_obs_grp = ep_data_grp.create_group("obs")
                for obs_key in obs_keys:
                    if obs_key == "language_instruction":
                        continue
                    ep_obs_grp.create_dataset(obs_key, data=dic[obs_key])

                ep_data_grp.attrs["num_samples"] = len(actions)

                episodes += 1

        # create mask dataset listing the demo groups belonging to this split
        grp_mask.create_dataset(split, data=np.array(demo_keys[split], dtype="S"))

    # write dataset attributes (metadata)
    grp.attrs["episodes"] = episodes
    grp.attrs["env_args"] = "blub"

    train_keys = demo_keys.get("train", [])
    if not train_keys:
        raise RuntimeError("No training episodes were written; cannot compute action statistics.")

    print("Computing training statistics ...")
    # [()] reads each dataset fully into a numpy array instead of relying on implicit conversion
    train_actions = np.concatenate([grp[demo_key]["actions"][()] for demo_key in train_keys])
    stats["actions_min"] = train_actions.min(axis=0)
    stats["actions_max"] = train_actions.max(axis=0)
    stats_grp = grp.create_group("stats")
    for key, value in stats.items():
        stats_grp.create_dataset(key, data=value)

    if normalize_acts:
        print("Normalizing actions ...")
        for split in splits:
            for demo_key in demo_keys[split]:
                actions_ds = grp[demo_key]["actions"]
                actions_ds[...] = normalize(actions_ds[()], stats, key="actions")

    now = datetime.datetime.now()
    grp.attrs["date"] = "{}-{}-{}".format(now.month, now.day, now.year)
    grp.attrs["time"] = "{}:{}:{}".format(now.hour, now.minute, now.second)
    grp.attrs["blocking_control"] = blocking_control
    # record whether stored actions are in [-1, 1] so consumers know to unnormalize
    grp.attrs["normalize_acts"] = normalize_acts

print("Saved at: {}".format(hdf5_path))

Processing data ...
Loading data/left_25 train ...


100%|██████████| 25/25 [00:00<00:00, 146.30it/s]


Loading data/right_25 train ...


100%|██████████| 25/25 [00:00<00:00, 160.20it/s]


Loading data/middle_25 train ...


100%|██████████| 25/25 [00:00<00:00, 152.03it/s]


Loading data/left_25 eval ...


100%|██████████| 2/2 [00:00<00:00, 157.00it/s]


Loading data/right_25 eval ...


100%|██████████| 2/2 [00:00<00:00, 153.66it/s]


Loading data/middle_25 eval ...


100%|██████████| 2/2 [00:00<00:00, 148.76it/s]


Computing training statistics ...
Normalizing actions ...
Saved at: data/left_right_middle_75


In [3]:
# Per-dimension action min/max computed from the train split by the dataset-building cell
stats

{'actions_min': array([-0.00635335, -0.00486431, -0.01493801, -0.02185639, -0.03439584,
        -0.02118944,  0.        ]),
 'actions_max': array([0.00385384, 0.01571385, 0.01000646, 0.01401203, 0.02496785,
        0.01949651, 1.        ])}