In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from isaacgym import gymtorch
from isaacgym import gymapi

from bc_algos.envs.isaacgym_simple import IsaacGymEnvSimple
# from bc_algos.dataset.robomimic import RobomimicDataset
from bc_algos.dataset.isaac_gym import IsaacGymDataset
# from bc_algos.envs.robosuite import RobosuiteEnv
from bc_algos.utils.constants import Modality
import bc_algos.utils.constants as Constants
import matplotlib.pyplot as plt
import json
import numpy as np

# Simulator backend exercised by this notebook; change this to test different
# simulator environments (e.g. Constants.EnvType.ROBOSUITE).
# NOTE(review): the name `type` shadows the Python builtin. Every later cell
# compares against it, so a rename must update all cells together.
type = Constants.EnvType.ISAAC_GYM

### Dataset Init

In [None]:
# Observation keys to load, and the modality used to process each one.
obs_key_to_modality = {
    "robot0_eef_pos": Modality.LOW_DIM,
    "robot0_eef_quat": Modality.LOW_DIM,
    "agentview_image": Modality.RGB
}
# Grouping of the keys above into named observation groups
# (later cells read e.g. dataset[i]["obs"]["agentview_image"]).
obs_group_to_key = {
    "obs": ["robot0_eef_pos", "robot0_eef_quat", "agentview_image"],
    "goal": ["agentview_image"]
}

if type == Constants.EnvType.ROBOSUITE:
    path = "../datasets/test/square_ph.hdf5"
    action_key = "actions"
    demo = "demo_0"

elif type == Constants.EnvType.ISAAC_GYM:
    # Compare against EnvType (not DatasetType) so this check agrees with every
    # other cell. Mixing constant families can make it fail silently and leave
    # `path` / `action_key` / `demo` undefined.
    # NOTE(review): machine-specific absolute path; adjust for your setup.
    path = "/home/markvdm/Documents/IsaacGym/mental_models_envs/out/mm_simple/dataset_v13_diff"
    action_key = "actions"
    demo = 40

else:
    raise ValueError(f"Unsupported environment type: {type}")

In [None]:
# Arguments shared by both dataset backends: a single demo, seq_length=1 with
# no frame stacking, and no padding, pad masks or normalization.
dataset_kwargs = dict(
    path=path,
    obs_key_to_modality=obs_key_to_modality,
    obs_group_to_key=obs_group_to_key,
    dataset_keys=[action_key],
    frame_stack=0,
    seq_length=1,
    pad_frame_stack=False,
    pad_seq_length=False,
    get_pad_mask=False,
    demos=[demo],
    normalize=False,
)

if type == Constants.EnvType.ROBOSUITE:
    # Imported here because the top-level import is commented out (presumably so
    # the Isaac Gym path does not need the robomimic/robosuite dependencies).
    from bc_algos.dataset.robomimic import RobomimicDataset
    dataset = RobomimicDataset(**dataset_kwargs)
elif type == Constants.EnvType.ISAAC_GYM:
    dataset = IsaacGymDataset(**dataset_kwargs)
else:
    raise ValueError(f"Unsupported environment type: {type}")

### Environment Init

In [None]:
if type == Constants.EnvType.ROBOSUITE:
    # Imported here because the top-level import is commented out (presumably so
    # the Isaac Gym path does not need robosuite installed).
    from bc_algos.envs.robosuite import RobosuiteEnv

    # The environment name and its constructor kwargs are stored as JSON in the
    # dataset's "data" attributes.
    env_meta = json.loads(dataset.hdf5_file["data"].attrs["env_args"])
    env = RobosuiteEnv(
        env_name=env_meta["env_name"],
        obs_key_to_modality=obs_key_to_modality,
        render=False,
        use_image_obs=True,
        use_depth_obs=False,
        **env_meta["env_kwargs"],
    )
elif type == Constants.EnvType.ISAAC_GYM:
    env_cfg_file = "../config/isaac_gym_env.json"
    # Context manager so the config file handle is closed after reading.
    with open(env_cfg_file, "r") as cfg_fp:
        cfg = json.load(cfg_fp)
    env = IsaacGymEnvSimple(
        # NOTE(review): "gyn" looks like a typo for "gym"; left unchanged in case
        # this name is matched elsewhere.
        env_name="isaac_gyn_env_simple",
        obs_key_to_modality=obs_key_to_modality,
        render=False,
        use_image_obs=True,
        use_depth_obs=False,
        cfg=cfg,
    )
else:
    raise ValueError(f"Unsupported environment type: {type}")

### Test Environment

In [None]:
# Reset the environment to the demo's initial state, then compare its rendering
# with the first frame stored in the dataset.
if type == Constants.EnvType.ROBOSUITE:
    xml = dataset.hdf5_file[f"data/{demo}"].attrs["model_file"]
    env.load_env(xml=xml)
    init_state = dataset.hdf5_file[f"data/{demo}/states"][0]
    env.reset_to(state=init_state)
elif type == Constants.EnvType.ISAAC_GYM:
    demo_metadata = dataset.dataset[demo]["metadata"]
    env.reset_to(state=demo_metadata)

ref_image = dataset[0]["obs"]["agentview_image"][0]
curr_image = env.render()

# Subtract in float: for uint8 images `curr_image - ref_image` wraps around
# (e.g. 3 - 5 == 254), which made the original difference panel meaningless.
# The absolute difference of the first three channels is averaged into a 2D map,
# which imshow colour-scales automatically whatever the input value range.
pixel_diff = np.abs(
    np.asarray(curr_image, dtype=np.float32) - np.asarray(ref_image, dtype=np.float32)
)[..., :3].mean(axis=-1)

fig, axs = plt.subplots(1, 3, figsize=(12, 4))
axs[0].imshow(ref_image)
axs[0].set_title("Reference")
axs[1].imshow(curr_image)
axs[1].set_title("Environment")
diff_artist = axs[2].imshow(pixel_diff)
axs[2].set_title("|Environment - Reference|")
fig.colorbar(diff_artist, ax=axs[2])
plt.show()


In [None]:
import os

from matplotlib.animation import FuncAnimation, PillowWriter

# Directory receiving one side-by-side comparison image per step.
ROLLOUT_OUT_DIR = "../out/test_rollout"
# savefig does not create missing directories, so make sure it exists first.
os.makedirs(ROLLOUT_OUT_DIR, exist_ok=True)

# Replay the dataset's recorded actions in the environment. Before each action
# is applied, save the dataset frame next to the environment's current render.
for i in range(len(dataset)):
    frame = dataset[i]
    action = frame[action_key][0]

    exec_image = env.render()
    ref_image = frame["obs"]["agentview_image"][0]

    fig, axs = plt.subplots(1, 2)
    axs[0].imshow(ref_image)
    axs[0].set_title("Reference")
    axs[1].imshow(exec_image)
    axs[1].set_title("Execution")
    fig.savefig(os.path.join(ROLLOUT_OUT_DIR, f"{i}.png"))
    # Close this specific figure so long rollouts don't accumulate open figures.
    plt.close(fig)

    env.step(action=action)

# TODO: assemble the saved frames into ROLLOUT_OUT_DIR/rollout.gif
# (FuncAnimation + PillowWriter at 8 fps).