In [1]:
from bc_algos.dataset.robomimic import RobomimicDataset
from bc_algos.utils.constants import Modality, GoalMode
import bc_algos.utils.obs_utils as ObsUtils
import matplotlib.pyplot as plt
import numpy as np

def display(img):
    """Show one image, or several side by side, in a single matplotlib row.

    Note: this name shadows IPython's built-in ``display`` for the rest of the
    notebook.

    Args:
        img: a single image, or a list of images. Each image must support
            ``.astype`` (e.g. a numpy array); it is cast to ``int`` before being
            passed to ``imshow``.
            NOTE(review): ``astype(int)`` truncates float pixels, which is fine for
            0-255 images but would wipe out values in [0, 1] — confirm what the
            dataset returns when built with ``normalize=True``.
    """
    images = img if isinstance(img, list) else [img]
    # squeeze=False keeps `axs` a 2-D array even for a single panel; without it,
    # plt.subplots(1, 1) returns a bare Axes and indexing it raises TypeError.
    _, axs = plt.subplots(1, len(images), squeeze=False)
    for ax, image in zip(axs[0], images):
        ax.imshow(image.astype(int))
    plt.show()

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Peek at the layout of a single raw run file.
# NOTE(review): judging by its name, load_gzip_pickle presumably unpickles the file,
# so only point it at trusted, locally generated data.
from bc_algos.utils.misc import load_gzip_pickle
test = load_gzip_pickle("../datasets/dataset_v4/run_0.pkl.gzip")

# Top-level keys first, then the keys of each section, in the same order as before.
print(test.keys())
for section in ("state", "obs", "policy", "metadata"):
    print(test[section].keys())

print(test["metadata"]["num_steps"])
print(test["obs"]["images"].shape)

### Normalization Utils Unit Tests

In [2]:
# Synthetic trajectories for the normalization-utility tests. The generator is
# seeded so the cell produces the same data under Restart & Run All.
SEED = 0
rng = np.random.default_rng(SEED)

T = 10
state_shape = [2, 2]

# 2 * N(0, 1) - 1 gives samples ~ N(-1, 4): the true mean and standard deviation
# are far from the 0 / 1 a broken normalizer might return by accident.
traj0_dict = {
    "state": 2 * rng.standard_normal((T, *state_shape)) - 1,
}
traj1_dict = {
    "state": 2 * rng.standard_normal((T, *state_shape)) - 1,
}

In [3]:
traj0_stats = ObsUtils.compute_traj_stats(traj0_dict)
traj1_stats = ObsUtils.compute_traj_stats(traj1_dict)
merged_stats = ObsUtils.aggregate_traj_stats(traj0_stats, traj1_stats)
traj0_norm_stats = ObsUtils.compute_normalization_stats(traj0_stats)
traj1_norm_stats = ObsUtils.compute_normalization_stats(traj1_stats)
merged_norm_stats = ObsUtils.compute_normalization_stats(merged_stats)

traj0_state = traj0_dict["state"]
merged_state = np.concatenate((traj0_dict["state"], traj1_dict["state"]), axis=0)

# Each set of normalization stats must match numpy's mean and population standard
# deviation (np.std's default ddof=0) over the time axis. The "merged" case checks
# that aggregating the two per-trajectory stats gives the same result as computing
# stats on the concatenated trajectories.
# np.allclose is kept (rather than np.testing.assert_allclose) so the default
# tolerances and shape broadcasting are unchanged, e.g. if the utilities return
# stats with a leading singleton axis.
for label, norm_stats, reference in (
    ("traj0", traj0_norm_stats, traj0_state),
    ("traj1", traj1_norm_stats, traj1_dict["state"]),
    ("merged", merged_norm_stats, merged_state),
):
    assert np.allclose(norm_stats["state"]["mean"], reference.mean(axis=0)), f"{label}: mean mismatch"
    assert np.allclose(norm_stats["state"]["stdv"], np.std(reference, axis=0)), f"{label}: stdv mismatch"

### Dataset Init

In [45]:
# Observation keys and their modalities. The "obs" group uses every key, in this order.
obs_key_to_modality = {
    "robot0_eef_pos": Modality.LOW_DIM,
    "robot0_eef_quat": Modality.LOW_DIM,
    "agentview_image": Modality.RGB,
}
obs_group_to_key = {
    "obs": list(obs_key_to_modality),
    # The goal group carries only the camera image.
    "goal": ["agentview_image"],
}
dataset_keys = ["actions"]
dataset_path = "../datasets/test/square_ph.hdf5"
demos = ["demo_0"]
frame_stack = 1
seq_length = 2

In [46]:
# Build the dataset over the configured demos with both kinds of padding and the
# pad mask enabled. Goal conditioning starts disabled (goal_mode=None); the goal
# test cells later reassign goal_mode on this same object and call cache_index().
dataset_kwargs = dict(
    path=dataset_path,
    obs_key_to_modality=obs_key_to_modality,
    obs_group_to_key=obs_group_to_key,
    dataset_keys=dataset_keys,
    frame_stack=frame_stack,
    seq_length=seq_length,
    goal_mode=None,
    num_subgoal=None,
    pad_frame_stack=True,
    pad_seq_length=True,
    get_pad_mask=True,
    demos=demos,
    preprocess=False,
    normalize=True,
)
dataset = RobomimicDataset(**dataset_kwargs)

caching index: 100%|██████████| 127/127 [00:00<00:00, 33750.02demo/s]
loading dataset into memory: 100%|██████████| 1/1 [00:00<00:00, 106.55demo/s]
computing normalization stats: 100%|██████████| 1/1 [00:00<00:00, 2226.28demo/s]
normalizing data: 100%|██████████| 1/1 [00:00<00:00, 7810.62demo/s]


### Padding Unit Tests

In [47]:
# Expected pad masks (0 = padded step, 1 = real frame); each has
# frame_stack + seq_length entries.
#  - sample k (k < frame_stack) has frame_stack - k padded steps at the front;
#  - the k-th sample from the end (1 <= k < seq_length) has seq_length - k padded
#    steps at the back.
head_cases = [
    (k, [0] * (frame_stack - k) + [1] * (seq_length + k))
    for k in range(frame_stack)
]
tail_cases = [
    (-k, [1] * (frame_stack + k) + [0] * (seq_length - k))
    for k in range(1, seq_length)
]
for sample_idx, expected_mask in head_cases + tail_cases:
    assert np.all(np.equal(dataset[sample_idx]["pad_mask"], np.array(expected_mask)))

### Sequence Fetching Unit Tests

In [57]:
# Only samples whose window lies entirely inside the demo are checked, so padding
# cannot introduce repeated frames. Per the padding tests above, the first
# frame_stack samples are padded at the front and the last seq_length - 1 samples
# are padded at the back.
first_unpadded = frame_stack
last_unpadded = len(dataset) - seq_length  # inclusive

prev_obs = None
for sample_idx in range(first_unpadded, last_unpadded + 1):
    obs = dataset[sample_idx]["obs"]["agentview_image"]
    # Use a distinct name so the global T from the normalization cell is not clobbered.
    num_frames = obs.shape[0]

    # Consecutive frames within one window must differ.
    for j in range(num_frames - 1):
        assert np.any(np.not_equal(obs[j], obs[j + 1])), (
            f"sample {sample_idx}: frames {j} and {j + 1} are identical"
        )

    # Sliding-window check (the intent of the previously commented-out code):
    # sample k + 1 is expected to be sample k shifted forward by one timestep.
    if prev_obs is not None:
        assert np.array_equal(prev_obs[1:], obs[:-1]), (
            f"samples {sample_idx - 1} and {sample_idx} are not a one-step shift of each other"
        )
    prev_obs = obs

### Test Goal Condition on Last Frame

In [None]:
# With GoalMode.LAST the test expects the first and last samples to be
# conditioned on the same goal frame.
dataset.goal_mode = GoalMode.LAST
dataset.cache_index()

first_item = dataset[0]
last_item = dataset[-1]

# The goal sequence must have as many timesteps as the observation sequence.
obs_len = first_item["obs"]["agentview_image"].shape[0]
goal_len = first_item["goal"]["agentview_image"].shape[0]
assert obs_len == goal_len

goalF, goalL = (item["goal"]["agentview_image"][0] for item in (first_item, last_item))
assert np.all(np.equal(goalF, goalL))
display([goalF, goalL])

### Test Dense Subgoals

In [None]:
# Dense subgoals: the test expects neighbouring samples 0 and 1 to be conditioned
# on different goal frames.
# NOTE(review): this line originally read the incomplete `GoalMode.`, a SyntaxError
# that stopped Restart & Run All. "subgoal" is the value the sparse-subgoal cell
# uses. With num_subgoal=None it presumably treats every frame as a subgoal —
# confirm against RobomimicDataset, then switch to the matching GoalMode member.
dataset.goal_mode = "subgoal"
# Reset explicitly so the result does not depend on whether the sparse-subgoal
# cell (which sets num_subgoal = 10) has already run.
dataset.num_subgoal = None
dataset.cache_index()
goal0 = dataset[0]["goal"]["agentview_image"][0]
goal1 = dataset[1]["goal"]["agentview_image"][0]
assert np.any(np.not_equal(goal0, goal1))
display([goal0, goal1])

### Test Sparse Subgoals

In [None]:
# Sparse subgoals (num_subgoal = 10): the test expects samples 0 and 1 to share a
# goal frame while the final sample targets a different one.
dataset.goal_mode = "subgoal"
dataset.num_subgoal = 10
dataset.cache_index()

goal0, goal1, goalL = (dataset[idx]["goal"]["agentview_image"][0] for idx in (0, 1, -1))

assert np.all(np.equal(goal0, goal1))
assert np.any(np.not_equal(goal1, goalL))
display([goal0, goal1, goalL])