In [None]:
from bc_algos.dataset.robomimic import RobomimicDataset
from bc_algos.utils.constants import Modality
import matplotlib.pyplot as plt
import numpy as np

def display(img):
    """Show one image, or several side by side, in a single matplotlib row.

    Note: this shadows IPython's built-in ``display``; within this notebook the
    name refers to this helper.

    Args:
        img: a single image array, or a list of image arrays. Each image is cast
            with ``astype(int)`` before ``imshow`` (presumably HxWx3 pixels in
            [0, 255] -- TODO confirm against the dataset's RGB output).
    """
    images = img if isinstance(img, list) else [img]
    if not images:
        return  # nothing to draw; plt.subplots(1, 0) would raise
    # squeeze=False keeps `axs` a 2-D array of Axes even for a single image;
    # with the default squeeze=True, plt.subplots(1, 1) returns a bare Axes and
    # indexing it would fail.
    _, axs = plt.subplots(1, len(images), squeeze=False)
    for ax, image in zip(axs[0], images):
        ax.imshow(image.astype(int))
    plt.show()

# Observation keys to load, and the modality each one is read as.
obs_key_to_modality = {
    "robot0_eef_pos": Modality.LOW_DIM,
    "robot0_eef_quat": Modality.LOW_DIM,
    "agentview_image": Modality.RGB,
}

# Grouping of observation keys inside each sample: "obs" holds the policy
# inputs, "goal" holds the goal-conditioning image.
obs_group_to_key = {
    "obs": [
        "robot0_eef_pos",
        "robot0_eef_quat",
        "agentview_image",
    ],
    "goal": [
        "agentview_image",
    ],
}

# Additional (non-obs) dataset keys to load per sample.
dataset_keys = ["actions"]

# Relative to the notebook's working directory.
dataset_path = "../datasets/test/square_ph.hdf5"

# Restrict loading to this demo only.
demos = ["demo_0"]

In [None]:
# Build the dataset with goal conditioning disabled (goal_mode=None); later
# cells enable it by mutating the dataset and calling cache_index().
# NOTE(review): the next cell indexes three frames per sample (frames0[2]),
# which implies frame_stack=1 adds one history frame on top of seq_length=2 --
# confirm against RobomimicDataset.
dataset = RobomimicDataset(
    # data source
    path=dataset_path,
    demos=demos,
    preprocess=False,
    # what each sample contains
    obs_key_to_modality=obs_key_to_modality,
    obs_group_to_key=obs_group_to_key,
    dataset_keys=dataset_keys,
    # temporal windowing, with padding disabled
    frame_stack=1,
    seq_length=2,
    pad_frame_stack=False,
    pad_seq_length=False,
    # goal conditioning (off for now)
    goal_mode=None,
    num_subgoal=None,
)

### Verify Sequence Fetching

In [None]:
# Each sample holds a window of consecutive frames. Within a window, adjacent
# frames must differ, and sample 1's window must be sample 0's window shifted
# forward by exactly one frame.
frames0 = dataset[0]["obs"]["agentview_image"]
frames1 = dataset[1]["obs"]["agentview_image"]
assert len(frames0) == len(frames1) >= 2, "expected equal-length windows of at least two frames"
for t in range(len(frames0) - 1):
    assert not np.array_equal(frames0[t], frames0[t + 1]), f"frames {t} and {t + 1} of sample 0 are identical"
# np.array_equal also checks shapes, unlike np.all(np.equal(...)), which would broadcast.
assert np.array_equal(frames0[1:], frames1[:-1]), "sample 1 is not sample 0 shifted by one frame"

### Goal Conditioning on the Last Frame

In [None]:
# With goal_mode="last", the first and last samples must share the same goal image.
# num_subgoal is reset to its constructor value so that re-running this cell
# after the sparse-subgoal cell (which sets it to 10) reproduces the same state.
dataset.goal_mode = "last"
dataset.num_subgoal = None
dataset.cache_index()
goalF = dataset[0]["goal"]["agentview_image"][0]
goalL = dataset[-1]["goal"]["agentview_image"][0]
assert np.array_equal(goalF, goalL), "first and last samples should share the same goal"
display([goalF, goalL])

### Dense Subgoals

In [None]:
# Dense subgoals: consecutive samples should get different goal images.
# num_subgoal must be reset explicitly. The sparse-subgoal cell sets it to 10,
# and without this reset a re-run would silently test sparse subgoals, where
# consecutive goals are equal, so the assert below would fail.
dataset.goal_mode = "subgoal"
dataset.num_subgoal = None
dataset.cache_index()
goal0 = dataset[0]["goal"]["agentview_image"][0]
goal1 = dataset[1]["goal"]["agentview_image"][0]
assert not np.array_equal(goal0, goal1), "dense subgoals: consecutive samples should have different goals"
display([goal0, goal1])

### Sparse Subgoals

In [None]:
# Sparse subgoals: with num_subgoal=10, adjacent samples should share a goal
# image, while the final sample should target a different one.
dataset.goal_mode = "subgoal"
dataset.num_subgoal = 10
dataset.cache_index()
goal0 = dataset[0]["goal"]["agentview_image"][0]
goal1 = dataset[1]["goal"]["agentview_image"][0]
goalL = dataset[-1]["goal"]["agentview_image"][0]
assert np.array_equal(goal0, goal1), "sparse subgoals: adjacent samples should share a goal"
assert not np.array_equal(goal1, goalL), "sparse subgoals: the last sample should have a different goal"
display([goal0, goal1, goalL])