In [None]:
import robomimic.utils.tensor_utils as TensorUtils
import robomimic.utils.obs_utils as ObsUtils
import robomimic.utils.train_utils as TrainUtils
import robomimic.utils.torch_utils as TorchUtils
from bc_benchmark_algos.dataset.robomimic import RobomimicDataset
from bc_benchmark_algos.rollout_env.robomimic import RobomimicRolloutEnv
from robomimic.config import config_factory
from robomimic.algo import algo_factory, RolloutPolicy
import matplotlib.pyplot as plt
import torch
import numpy as np
import json

In [None]:
# setup config: load the external BC-RNN JSON config, merge it into the
# algorithm's default config, point it at the test dataset, and override the
# temporal settings used by the tests in this notebook.
config_path = "../config/bc_rnn.json"
dataset_path = "../../datasets/test/square_ph.hdf5"
output_dir = "output"
# Read the JSON inside a context manager so the file handle is closed right
# away instead of whenever the garbage collector gets to it.
with open(config_path, "r") as config_file:
    ext_cfg = json.load(config_file)
config = config_factory(ext_cfg["algo_name"])
# Merge the external settings while the config is temporarily unlocked.
with config.unlocked():
    config.update(ext_cfg)
config.train.data = dataset_path
config.train.output_dir = output_dir
# Temporal settings: the shape checks further down expect frame_stack + 1
# frames per observation/goal input.
config.train.frame_stack = 2
config.train.seq_length = 1
# Freeze the config so accidental writes later in the notebook raise.
config.lock()

ObsUtils.initialize_obs_utils_with_config(config)

In [None]:
# test tensor utils: push one fake 84x84x3 "agentview" image through the
# TensorUtils helpers and check the resulting shape after every step.
def _agentview(batch):
    """Return the ``obs -> agentview_image`` entry of a nested batch dict."""
    return batch["obs"]["agentview_image"]


x = TensorUtils.to_tensor({"obs": {"agentview_image": np.random.randn(84, 84, 3)}})
assert isinstance(_agentview(x), torch.Tensor) 

# Each step below changes exactly one dimension, as the shape checks pin down.
x = TensorUtils.to_batch(x)
assert _agentview(x).shape == (1, 84, 84, 3)

x = TensorUtils.to_sequence(x)
assert _agentview(x).shape == (1, 1, 84, 84, 3)

x = TensorUtils.repeat_seq(x=x, k=10)
assert _agentview(x).shape == (1, 10, 84, 84, 3)

x = TensorUtils.slice(x=x, dim=1, start=0, end=5)
assert _agentview(x).shape == (1, 5, 84, 84, 3)

# NOTE(review): x was built by repeating a single image along dim 1, so its
# frames are presumably all identical; these shift checks therefore cannot
# distinguish a correct shift from a no-op. Consider a sequence with distinct
# frames to make them meaningful.
y = TensorUtils.shift_seq(x=x, k=1)
assert torch.equal(_agentview(x)[:, 0, :], _agentview(y)[:, 1, :])

y = TensorUtils.shift_seq(x=x, k=-1)
assert torch.equal(_agentview(x), _agentview(y))

In [None]:
# create validset: the demos selected by the "valid" filter key.
# The dataset path is not re-declared here: it was already written to
# config.train.data in the config cell, and a second copy of the path was
# never passed to dataset_factory. It could only drift out of sync.
# NOTE(review): presumably dataset_factory reads the path from config.train.data — confirm.
validset = RobomimicDataset.dataset_factory(
    config=config,
    obs_group_to_keys=ObsUtils.OBS_GROUP_TO_KEYS,
    filter_by_attribute="valid",
)
# The rollout tests index validset.demos[0]; fail early with a clear message
# if the filter key selected nothing.
assert len(validset.demos) > 0, "filter key 'valid' selected no demos"

In [None]:
# create rollout env: resolve the experiment directories and build the
# rollout environment around the validation set.
# NOTE(review): the first two entries returned by get_exp_dir are presumably
# the log / checkpoint directories (unused here). Upstream robomimic's
# get_exp_dir may create directories and prompt via input() when they already
# exist — confirm this fork's behavior before relying on Restart & Run All.
log_dir, ckpt_dir, video_dir = TrainUtils.get_exp_dir(config)
rollout_env = RobomimicRolloutEnv(config=config, validset=validset)

print(video_dir)
env_meta = rollout_env.env_meta
print(env_meta)

In [None]:
# test inputs_from_initial_obs: reset the env to the initial simulator state of
# the first validation demo and check the inputs built from that observation.
demo_id = validset.demos[0]
initial_state = dict(states=validset.hdf5_file[f"data/{demo_id}/states"][0])
initial_state["model"] = validset.hdf5_file[f"data/{demo_id}"].attrs["model_file"]
rollout_env.env.reset()
obs = rollout_env.env.reset_to(initial_state)
inputs = rollout_env.inputs_from_initial_obs(obs=obs, demo_id=demo_id)

# Number of stacked frames per input. It is derived from the config rather than
# hard-coded as 3, so the checks and the figure follow any frame_stack setting.
n_frames = config.train.frame_stack + 1
goal_frames = inputs["goal"]["agentview_image"]
obs_frames = inputs["obs"]["agentview_image"]
assert goal_frames.shape == (1, n_frames, 84, 84, 3)
# obs frames are plotted below, so pin their shape before indexing into them.
assert obs_frames.shape == (1, n_frames, 84, 84, 3)
# On the initial observation every goal frame should be identical.
for t in range(1, n_frames):
    assert np.all(np.equal(goal_frames[0, t - 1, :], goal_frames[0, t, :])), f"goal frames {t - 1} and {t} differ"

# Top row: goal frames; bottom row: observation frames.
# squeeze=False keeps `axs` 2-D even when n_frames == 1.
fig, axs = plt.subplots(2, n_frames, squeeze=False)
for row, (group, frames) in enumerate((("goal", goal_frames), ("obs", obs_frames))):
    for t in range(n_frames):
        axs[row, t].imshow(frames[0, t, :])
        axs[row, t].set_title(f"{group} frame {t}")
plt.show()

In [None]:
# test inputs_from_new_obs: update the inputs from the previous cell with an
# observation at t=10 and check that the goal frames are no longer identical.
# The result is bound to a new name instead of overwriting `inputs`, so
# re-running this cell starts from the same input.
# NOTE(review): if inputs_from_new_obs mutates `x` in place, re-running still
# compounds — confirm and copy `inputs` first if so.
new_inputs = rollout_env.inputs_from_new_obs(x=inputs, obs=obs, demo_id=demo_id, t=10)

# Number of stacked frames per input, derived from config instead of hard-coded 3.
n_frames = config.train.frame_stack + 1
goal_frames = new_inputs["goal"]["agentview_image"]
obs_frames = new_inputs["obs"]["agentview_image"]
assert obs_frames.shape == (1, n_frames, 84, 84, 3)
# After the update, consecutive goal frames are expected to differ.
for t in range(1, n_frames):
    assert np.any(np.not_equal(goal_frames[0, t - 1, :], goal_frames[0, t, :])), f"goal frames {t - 1} and {t} are identical"

# Top row: goal frames; bottom row: observation frames.
# squeeze=False keeps `axs` 2-D even when n_frames == 1.
fig, axs = plt.subplots(2, n_frames, squeeze=False)
for row, (group, frames) in enumerate((("goal", goal_frames), ("obs", obs_frames))):
    for t in range(n_frames):
        axs[row, t].imshow(frames[0, t, :])
        axs[row, t].set_title(f"{group} frame {t}")
plt.show()

In [None]:
# create model: build the configured algorithm on the chosen torch device
# (CUDA only when config.train.cuda requests it).
device = TorchUtils.get_torch_device(try_to_use_cuda=config.train.cuda)
ac_dim = config.train.ac_dim
algo_kwargs = {
    "algo_name": config.algo_name,
    "config": config,
    "obs_key_shapes": ObsUtils.OBS_SHAPES,
    "ac_dim": ac_dim,
    "device": device,
}
model = algo_factory(**algo_kwargs)

In [None]:
# full rollout: wrap the freshly created (untrained) model for inference and
# roll it out from the first validation demo.
# Videos presumably go under video_dir.
rollout_model = RolloutPolicy(policy=model)
first_demo = validset.demos[0]
rollout_result = rollout_env.rollout_with_stats(
    policy=rollout_model,
    demo_id=first_demo,
    video_dir=video_dir,
)
rollout_result