In [None]:
import h5py
import os
import json
from bc_algos.envs.robosuite import EnvRobosuite
from bc_algos.utils.obs_utils import Modality
import matplotlib.pyplot as plt

In [None]:
# Observation keys to request from the environment, each tagged with its modality.
# Insertion order matches the keyword order below.
obs_key_to_modality = dict(
    robot0_eef_pos=Modality.LOW_DIM,
    robot0_eef_quat=Modality.LOW_DIM,
    agentview_image=Modality.RGB,
)

# Demonstration dataset (path is relative to the notebook's working directory).
# The handle is intentionally left open: the following cells read from `f`.
dataset_path = "../datasets/test/square_ph.hdf5"
f = h5py.File(os.path.expanduser(dataset_path), mode="r")

In [None]:
# Environment metadata is stored as a JSON string in the "env_args" attribute of the
# top-level "data" group; its "env_name" and "env_kwargs" entries configure the env.
env_args_json = f["data"].attrs["env_args"]
env_meta = json.loads(env_args_json)

# Options chosen by this notebook come first; the dataset's own kwargs are forwarded as-is.
# NOTE(review): if env_meta["env_kwargs"] contains any of the keywords spelled out here,
# the call raises TypeError (duplicate keyword argument) — confirm they don't overlap.
env = EnvRobosuite(
    obs_key_to_modality=obs_key_to_modality,
    env_name=env_meta["env_name"],
    use_image_obs=True,
    use_depth_obs=False,
    render=False,
    **env_meta["env_kwargs"],
)

### Test Environment: Replay a Demonstration and Check RGB Observations

In [None]:
# Replay the first demonstration's recorded actions from its first recorded state and
# check that every requested RGB observation is returned as an (H, W, 3) array each step.
demo_key = "data/demo_0"

xml = f[demo_key].attrs["model_file"]
env.load_env(xml=xml)
init_state = f[f"{demo_key}/states"][0]
env.reset_to(state=init_state)

# Read all actions in a single HDF5 call instead of one read per row while iterating.
actions = f[f"{demo_key}/actions"][()]

# Check the RGB keys we asked for, not just whichever keys `obs` happens to contain,
# so a missing image observation fails the test instead of passing vacuously.
rgb_keys = [key for key, modality in obs_key_to_modality.items() if modality == Modality.RGB]

for t, action in enumerate(actions):
    obs = env.step(action)
    for obs_key in rgb_keys:
        assert obs_key in obs, f"step {t}: missing RGB observation '{obs_key}'"
        img_shape = obs[obs_key].shape
        assert len(img_shape) == 3 and img_shape[2] == 3, (
            f"step {t}: expected '{obs_key}' to have shape (H, W, 3), got {img_shape}"
        )

# Show the environment's rendering after the last replayed action.
final_img = env.render()
fig, axs = plt.subplots(1, 1)
axs.imshow(final_img)
axs.set_title(f"{demo_key}: final frame after replay")
axs.axis("off")
plt.show()