In [5]:
import torch
from torch.utils.data import DataLoader

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

REPO_ID = "yilin404/pick_and_place"
DATA_ROOT = "/home/yilin/dataset/own_episode_data"

# Temporal context loaded for every sample, as offsets in seconds relative to
# the current frame. The 0.1 s step presumably matches the dataset's frame
# period (10 fps) -- NOTE(review): confirm against `dataset.fps`.
#
# Observations (both camera streams and the robot state): the previous frame
# (-0.1 s) followed by the current frame (0.0 s). The comprehension creates a
# separate list for each key.
_OBSERVATION_KEYS = (
    "observation.images.colors_camera_top",
    "observation.images.colors_camera_wrist",
    "observation.state",
)
delta_timestamps = {obs_key: [-0.1, 0.0] for obs_key in _OBSERVATION_KEYS}

# Actions: the previous action (-0.1 s), the action to execute now (0.0 s) and
# 14 future actions (0.1 s ... 1.4 s), 16 offsets in total; all of them are
# used to supervise the policy. Offsets are built as `step / 10` rather than
# `step * 0.1` so each one equals its decimal literal exactly
# (e.g. 3 * 0.1 != 0.3 in binary floating point, but 3 / 10 == 0.3).
delta_timestamps["action"] = [step / 10 for step in range(-1, 15)]
dataset = LeRobotDataset(REPO_ID, root=DATA_ROOT, delta_timestamps=delta_timestamps)


def _log_sample_info(title, sample, sized_keys):
    """Print every key of one sample/batch, plus size and type of selected entries.

    Shared by the single-sample (dataset) and batched (dataloader) inspections
    so both produce the same report format.

    Args:
        title: Label used in the banner lines, e.g. "dataset" or "dataloader".
        sample: Mapping of feature name to value, e.g. ``dataset[0]`` or a
            collated batch.
        sized_keys: Container of keys whose value's ``.size()`` and type are
            printed; every other key is only listed by name.
    """
    banner = f"===> logging {title} info..."
    print(banner)
    for key, value in sample.items():
        print("key name is: ", key)
        if key in sized_keys:
            print(value.size(), type(value))
    print(banner + "\n")


# Only the keys that were given temporal offsets get their shape printed.
_log_sample_info("dataset", dataset[0], delta_timestamps)

print("===> logging dataset stats...")
print(dataset.stats)
print("===> logging dataset stats...\n")

# Create dataloader for offline training.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataloader = DataLoader(
    dataset,
    num_workers=0,
    batch_size=32,
    shuffle=True,
    # Pinned host memory is what allows the `non_blocking=True` copies below
    # to overlap with GPU work; it has no benefit when training on CPU.
    pin_memory=device.type != "cpu",
    drop_last=True,
)

# Inspect a single batch only. With drop_last=True a dataset smaller than one
# batch yields nothing, so report that instead of silently printing nothing.
batch = next(iter(dataloader), None)
if batch is None:
    print(
        f"===> dataloader yielded no batches "
        f"(dataset has fewer than {dataloader.batch_size} frames and drop_last=True)"
    )
else:
    # Move tensors to the target device; any non-tensor entries are passed
    # through unchanged rather than failing on `.to`.
    batch = {
        k: v.to(device, non_blocking=True) if isinstance(v, torch.Tensor) else v
        for k, v in batch.items()
    }
    _log_sample_info("dataloader", batch, delta_timestamps)


===> logging dataset info...
key name is:  observation.state
torch.Size([2, 7]) <class 'torch.Tensor'>
key name is:  action
torch.Size([16, 7]) <class 'torch.Tensor'>
key name is:  episode_index
key name is:  frame_index
key name is:  timestamp
key name is:  next.done
key name is:  observation.images.colors_camera_top
torch.Size([2, 3, 480, 640]) <class 'torch.Tensor'>
key name is:  observation.images.colors_camera_wrist
torch.Size([2, 3, 480, 640]) <class 'torch.Tensor'>
key name is:  index
key name is:  observation.images.colors_camera_top_is_pad
key name is:  observation.images.colors_camera_wrist_is_pad
key name is:  observation.state_is_pad
key name is:  action_is_pad
===> logging dataset info...

===> logging dataset stats...
{'action': {'max': tensor([ 2.8798,  3.1023, -0.2012,  2.6279,  1.6581,  2.8798,  0.0300]), 'mean': tensor([ 1.6240,  1.7822, -1.0750,  1.3828, -0.7761, -0.0124,  0.0192]), 'min': tensor([-1.3824,  0.0721, -3.3088, -2.8740, -1.6581, -2.8798,  0.0000]), 'std'