In [None]:
from bc_algos.dataset.robomimic import RobomimicDataset
from bc_algos.dataset.isaac_gym import IsaacGymDataset
from bc_algos.envs.robosuite import RobosuiteEnv
from bc_algos.utils.constants import Modality
import bc_algos.utils.constants as Constants
import matplotlib.pyplot as plt
import json

# Simulator backend under test; change this to test a different simulator.
# Later cells branch on this value: ROBOSUITE pairs RobomimicDataset with
# RobosuiteEnv, ISAAC_GYM uses IsaacGymDataset.
# NOTE(review): the name `type` shadows the Python builtin for the rest of the
# notebook; every later cell reads it under this name, so rename all at once.
type = Constants.EnvType.ROBOSUITE

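### Dataset Initialization

Load a single demonstration for the selected simulator, one timestep per item, with no frame stacking, padding, or normalization.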

In [None]:
# Modality of every observation key the dataset should load.
obs_key_to_modality = {
    "robot0_eef_pos": Modality.LOW_DIM,
    "robot0_eef_quat": Modality.LOW_DIM,
    "agentview_image": Modality.RGB,
}
# Observation groups exposed by the dataset; the "goal" group reuses the
# agent-view camera image.
obs_group_to_key = {
    "obs": ["robot0_eef_pos", "robot0_eef_quat", "agentview_image"],
    "goal": ["agentview_image"],
}

# Per-simulator dataset location, action key, and demo identifier.
# `type` holds a Constants.EnvType member, so every branch compares against
# EnvType (the dataset-construction cell does the same).
if type == Constants.EnvType.ROBOSUITE:
    path = "../datasets/test/square_ph.hdf5"
    action_key = "actions"
    demo = "demo_0"  # used as the hdf5 group name under "data/" when replaying
elif type == Constants.EnvType.ISAAC_GYM:
    path = "../datasets/dataset_v10_diff"
    action_key = "action"
    demo = 0  # presumably an integer demo index for IsaacGymDataset — confirm
else:
    # Fail here rather than leaving path/action_key/demo undefined for later cells.
    raise ValueError(f"Unsupported environment type: {type}")

In [None]:
# Choose the dataset class for the selected simulator. Both classes are
# constructed with the same keyword arguments, so build the dataset once.
if type == Constants.EnvType.ROBOSUITE:
    dataset_cls = RobomimicDataset
elif type == Constants.EnvType.ISAAC_GYM:
    dataset_cls = IsaacGymDataset
else:
    raise ValueError(f"Unsupported environment type: {type}")

dataset = dataset_cls(
    path=path,
    obs_key_to_modality=obs_key_to_modality,
    obs_group_to_key=obs_group_to_key,
    dataset_keys=[action_key],
    # No frame stacking and one step per item, without padding, so indexing
    # the dataset walks the demo one timestep at a time.
    frame_stack=0,
    seq_length=1,
    pad_frame_stack=False,
    pad_seq_length=False,
    get_pad_mask=False,
    demos=[demo],  # restrict the dataset to the single demo being replayed
    # Presumably raw (unnormalized) actions are required so they can be fed
    # directly to env.step during replay — confirm against the dataset class.
    normalize=False,
)

### Environment Initialization

Rebuild the simulator environment from the metadata stored alongside the dataset.

In [None]:
if type == Constants.EnvType.ROBOSUITE:
    # The robomimic hdf5 stores the environment name and constructor kwargs
    # as a JSON string in data.attrs["env_args"]; recreate the env from it.
    env_meta = json.loads(dataset.hdf5_file["data"].attrs["env_args"])
    env = RobosuiteEnv(
        env_name=env_meta["env_name"],
        obs_key_to_modality=obs_key_to_modality,
        render=False,
        use_image_obs=True,
        use_depth_obs=False,
        **env_meta["env_kwargs"],
    )
else:
    # No environment wrapper for other simulators is imported in this notebook.
    # Fail here with a clear message instead of a NameError on `env` later on.
    raise NotImplementedError(f"Environment replay is not implemented for {type}")

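### Test Environment

Replay the demo's recorded actions open-loop in the simulator, then display the final rendered frame.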

In [None]:
if type == Constants.EnvType.ROBOSUITE:
    # Load the model XML stored with this demo, then restore the demo's first
    # recorded simulator state so the replay starts where the demo started.
    xml = dataset.hdf5_file[f"data/{demo}"].attrs["model_file"]
    env.load_env(xml=xml)
    init_state = dataset.hdf5_file[f"data/{demo}/states"][0]
    env.reset_to(state=init_state)
else:
    # `env` is only created and reset for robosuite; stepping it otherwise
    # would fail with an unrelated NameError.
    raise NotImplementedError(f"Environment replay is not implemented for {type}")

# Open-loop replay: step the environment with every recorded action, in order.
for i in range(len(dataset)):
    frame = dataset[i]
    # The dataset was built with seq_length=1, so index 0 is presumably the
    # single action for this timestep — revisit if seq_length changes.
    action = frame[action_key][0]
    env.step(action=action)

# Show the final rendered frame after the replay finishes.
final_img = env.render()
fig, ax = plt.subplots(1, 1)
ax.imshow(final_img)
ax.set_title(f"Final frame after replaying {demo}")
ax.axis("off")
plt.show()