In [None]:
from bc_benchmark_algos.dataset.robomimic import RobomimicDataset
import robomimic.utils.obs_utils as ObsUtils
import matplotlib.pyplot as plt
import numpy as np

# Which observation keys belong to which modality; registered globally with
# robomimic's ObsUtils before the dataset is built.
modality_to_obs_keys = {
    "low_dim": ["robot0_eef_pos", "robot0_eef_quat", "actions"],
    "rgb": ["agentview_image"],
}
# Which keys are loaded into each sample group ("obs" and "goal").
obs_group_to_keys = {
    "obs": ["robot0_eef_pos", "robot0_eef_quat", "agentview_image"],
    "goal": ["agentview_image"],
}
dataset_keys = ["actions"]

ObsUtils.initialize_obs_modality_mapping_from_dict(modality_mapping=modality_to_obs_keys)

dataset_path = "../../datasets/test/square_ph.hdf5"
# pad_frame_stack / pad_seq_length are not passed, so RobomimicDataset's
# own defaults apply. The goal_mode / num_subgoal values set here are
# overridden per-check in the cells below.
dataset_kwargs = dict(
    hdf5_path=dataset_path,
    obs_group_to_keys=obs_group_to_keys,
    dataset_keys=dataset_keys,
    goal_mode="subgoal",
    num_subgoal=None,
    hdf5_cache_mode=None,
)
dataset = RobomimicDataset(**dataset_kwargs)

In [None]:
# verify gc on last frame only
# In "last" goal mode, two different samples (indices 0 and 99) are expected
# to share the same goal image, presumably the final frame of the demo.
print("gc on last frame only")
dataset.goal_mode = "last"
goal0, goal99 = (dataset[idx]["goal"]["agentview_image"][0] for idx in (0, 99))
assert np.all(np.equal(goal0, goal99))
fig, axs = plt.subplots(1, 2)
for ax, img in zip(axs, (goal0, goal99)):
    ax.imshow(img)
plt.show()

In [None]:
# verify dense subgoal
# With num_subgoal=None the subgoals are expected to be dense, so neighbouring
# frames 98 and 99 should receive different goal images.
print("dense subgoals")
dataset.goal_mode = "subgoal"
# Reset explicitly: the sparse-subgoal check mutates num_subgoal, so without
# this a re-run of this cell would silently test sparse subgoals instead.
dataset.num_subgoal = None
goal98 = dataset[98]["goal"]["agentview_image"][0]
goal99 = dataset[99]["goal"]["agentview_image"][0]
# Compare the two goals fetched here (previously compared against a stale
# goal0 left over from the "last"-mode cell).
assert np.any(np.not_equal(goal98, goal99)), "dense subgoals for frames 98 and 99 should differ"
fig, axs = plt.subplots(1, 2)
axs[0].imshow(goal98)
axs[1].imshow(goal99)
plt.show()

In [None]:
# verify sparse subgoal
print("sparse subgoals")
dataset.goal_mode = "subgoal"
dataset.num_subgoal = 15  # ~ every 10 frames
goal0, goal1, goal14 = (dataset[idx]["goal"]["agentview_image"][0] for idx in (0, 1, 14))
# adjacent frames 0 and 1 are expected to share a subgoal ...
assert np.all(np.equal(goal0, goal1))
# ... while frames 1 and 14 are expected to fall under different subgoals
assert np.any(np.not_equal(goal1, goal14))
fig, axs = plt.subplots(1, 3)
for ax, img in zip(axs, (goal0, goal1, goal14)):
    ax.imshow(img)
plt.show()