In [1]:
import os

import mediapy
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

from diffusion_policy.dataset.tcl_dataset import TCLImageDataset
from robokit.debug_utils.images import concatenate_rgb_images, plot_action_wrt_time
from robokit.debug_utils.io import dataloader_speed_test

In [2]:
# Root of the extracted TCL data. Set the TCL_DATA_ROOT environment variable to
# point elsewhere instead of editing this cell; the default is the original path.
DATA_ROOT = os.environ.get("TCL_DATA_ROOT", "/home/geyuan/local_soft/TCL/collected_data_0425")

# Observation/action layout handed to TCLImageDataset: two RGB streams
# ("image" and "gripper", each [3, 240, 320]), a 6-dim low-dim joint state,
# and a 7-dim action.
shape_meta = {
    "obs": {
        "image": {"shape": [3, 240, 320], "type": "rgb"},
        "gripper": {"shape": [3, 240, 320], "type": "rgb"},
        "joint_state": {"shape": [6], "type": "low_dim"},
    },
    "action": {"shape": [7]},
}

dataset = TCLImageDataset(
    data_root=DATA_ROOT,
    horizon=16,  # presumably timesteps per sample (cf. the (T, ...) tensors inspected below)
    pad_before=0,
    pad_after=7,
    shape_meta=shape_meta,
)
dataloader_speed_test(dataset)

[TCLDataset] loaded key=rel_actions shape=(4169, 7) from /home/geyuan/local_soft/TCL/collected_data_0425/extracted/rel_actions.npy
[TCLDataset] total length: 4169
[TCLDataset] loading dataset statistics from: /home/geyuan/local_soft/TCL/collected_data_0425/statistics.json
[TCLImageDataset] dataset loaded, action_min=[-0.09999695 -0.1        -0.09999695 -0.09056091 -0.49610901 -0.49998474
  0.        ], action_max=[0.1        0.09999695 0.1        0.49998474 0.26126099 0.5
 1.        ]


In [3]:
global_idx = 0  # flat frame index into dataset.tcl_dataset

# Show the first task whose index is >= task_skip_idx; the loop stops after one task.
# Settings used during an earlier inspection: task_skip_idx = 35, fps = 5.
task_skip_idx = 1000
fps = 10

n_tasks = len(dataset.tcl_dataset.tasks)
print("Total tasks:", n_tasks)
if task_skip_idx >= n_tasks:
    # Without this message every task is skipped and the cell ends showing nothing.
    print(f"task_skip_idx={task_skip_idx} skips all {n_tasks} tasks; nothing will be shown.")

for task_idx, task in enumerate(dataset.tcl_dataset.tasks):
    task_length = dataset.tcl_dataset.task_lengths[task_idx]
    images_primary, images_gripper = [], []
    images_cat = []
    actions = []

    if task_idx < task_skip_idx:
        # Skipped tasks still advance the flat index. This assumes tasks are stored
        # back-to-back in task order. NOTE(review): confirm against TCLDataset.
        global_idx += task_length
        continue

    for frame_idx in tqdm(range(task_length)):
        frame_data = dataset.tcl_dataset[global_idx]
        global_idx += 1

        primary_rgb = frame_data['primary_rgb']
        gripper_rgb = frame_data['gripper_rgb']
        images_primary.append(primary_rgb)
        images_gripper.append(gripper_rgb)
        # Primary view stacked above the gripper view.
        images_cat.append(concatenate_rgb_images(primary_rgb, gripper_rgb, vertical=True))
        actions.append(frame_data['rel_actions'])

    # actions_vis is indexed per timestep below, one plot image per frame.
    actions_vis, fig, ax = plot_action_wrt_time(np.array(actions))
    all_vis = [
        concatenate_rgb_images(images_cat[frame_idx], actions_vis[frame_idx],
                               vertical=False, smaller_size=1)
        for frame_idx in range(task_length)
    ]

    mediapy.show_video(all_vis, fps=fps)

    break

Total tasks: 51


100%|███████████████████████████████████████████████████████████████████████████████████| 74/74 [00:01<00:00, 39.59it/s]


0
This browser does not support the video tag.


In [5]:
# Inspect one dataset sample: both image streams and the action chunk as videos.
SAMPLE_IDX = 10
VIDEO_FPS = 5
# Value range of the image tensors returned by dataset[...]. The original `* 255`
# scaling assumes [0, 1]; set (-1.0, 1.0) if the dataset returns normalized images.
# NOTE(review): an earlier comment claimed [-1, 1] -- confirm against TCLImageDataset.
IMAGE_VALUE_RANGE = (0.0, 1.0)


def _to_uint8_frames(frames, value_range=IMAGE_VALUE_RANGE):
    """Convert a (T, C, H, W) float tensor to a (T, H, W, C) uint8 numpy array.

    Values are mapped linearly from ``value_range`` to [0, 255], rounded and
    clipped, so out-of-range values saturate instead of wrapping around on the
    uint8 cast. Prints a warning when the data exceeds ``value_range``.
    """
    lo, hi = value_range
    frames_hwc = frames.detach().cpu().permute(0, 2, 3, 1).numpy()
    vmin, vmax = frames_hwc.min(), frames_hwc.max()
    if vmin < lo or vmax > hi:
        print(f"[warn] image values in [{vmin:.3f}, {vmax:.3f}] exceed {value_range}; clipping")
    scaled = (frames_hwc - lo) / (hi - lo) * 255.0
    return np.clip(np.rint(scaled), 0, 255).astype(np.uint8)


batch_data = dataset[SAMPLE_IDX]

obs_image_data = _to_uint8_frames(batch_data['obs']['image'])      # (T, H, W, C) uint8
obs_gripper_data = _to_uint8_frames(batch_data['obs']['gripper'])
obs_image = list(obs_image_data)
obs_gripper = list(obs_gripper_data)
mediapy.show_video(obs_image, fps=VIDEO_FPS)
mediapy.show_video(obs_gripper, fps=VIDEO_FPS)

action_data = batch_data['action']  # (T, 7)
action = list(action_data)  # one (7,) tensor per timestep
actions_vis, fig, ax = plot_action_wrt_time(action_data.detach().cpu().numpy())
mediapy.show_video(actions_vis, fps=VIDEO_FPS)



0
This browser does not support the video tag.


0
This browser does not support the video tag.
