In [1]:
from pprint import pprint
from typing import Any, Literal, NewType, TypedDict

# for performing runtime typechking in a iPython environment.
import jaxtyping
import lerobot
import numpy as np
import rerun as rr
import torch
from beartype.door import die_if_unbearable
from einops import rearrange
from huggingface_hub import HfApi
from jaxtyping import Float32, UInt8
from lerobot.common.datasets.lerobot_dataset import (
    LeRobotDataset,
    LeRobotDatasetMetadata,
)

%load_ext jaxtyping
%jaxtyping.typechecker beartype.beartype

# Prereqs

The way I learn best is by understanding the inputs and outputs of code, which I check with beartype and jaxtyping. I'll lay out some motivating examples and then move on to LeRobot and its datasets.

Hopefully this will be much less verbose once the two items below land. For now, `die_if_unbearable` is needed to validate types on variable assignment:
1. https://github.com/beartype/ipython-beartype
2. https://github.com/beartype/beartype/issues/492

#### Start with basic types to show how `die_if_unbearable` works
It validates, at the point of assignment, that a variable actually has the type its annotation claims.

In [2]:
# An int annotated as int passes the runtime check silently.
sample_number: int = 10
die_if_unbearable(sample_number, int)

# Python does not enforce annotations on assignment, so this line runs without complaint...
sample_text: int = "hello"
try:
    # ...and only an explicit runtime check exposes the mismatch (str is not int).
    die_if_unbearable(sample_text, int)
except Exception as err:
    print(err)

Die_if_unbearable() value 'hello' violates type hint <class 'int'>, as str 'hello' not instance of int.


Using jaxtyping to validate the dtype and shape of torch tensors and NumPy arrays

### To start with dtype

In [3]:
# A float32 NumPy array satisfies a Float32 hint of any shape ("..."), so nothing is printed here.
sample_array: Float32[np.ndarray, "..."] = np.array([1, 2, 3], dtype=np.float32)
try:
    die_if_unbearable(sample_array, Float32[np.ndarray, "..."])
except Exception as err:
    print(err)

In [4]:
# float64 data checked against a Float32 hint: the dtype check fails and the violation is printed.
sample_tensor: Float32[torch.Tensor, "..."] = torch.tensor([1, 2, 3], dtype=torch.float64)
try:
    die_if_unbearable(sample_tensor, Float32[torch.Tensor, "..."])
except Exception as err:
    print(err)

Die_if_unbearable() value "tensor([1., 2., 3.], dtype=torch.float64)" violates type hint <class 'jaxtyping.Float32[Tensor, '...']'>, as this array has dtype float64, not float32 as expected by the type hint.


### Now tensor shape

In [5]:
# Right dtype (float32) and right fixed shape (3, 3): the check passes and nothing is printed.
sample_3x3_tensor: Float32[torch.Tensor, "3 3"] = torch.rand((3, 3), dtype=torch.float32)
try:
    die_if_unbearable(sample_3x3_tensor, Float32[torch.Tensor, "3 3"])
except Exception as err:
    print(err)


In [6]:
# Right dtype (float32) but wrong shape: a (4, 4) tensor against a "3 3" hint fails the shape check.
sample_4x4_tensor: Float32[torch.Tensor, "3 3"] = torch.rand((4, 4), dtype=torch.float32)
try:
    die_if_unbearable(sample_4x4_tensor, Float32[torch.Tensor, "3 3"])
except Exception as err:
    print(err)

Die_if_unbearable() value "tensor([[0.6003, 0.1025, 0.1948, 0.0770],
        [0.6433, 0.0538, 0.5310, 0.9717],
    ...]])" violates type hint <class 'jaxtyping.Float32[Tensor, '3 3']'>, as the dimension size 4 does not equal 3 as expected by the type hint.


### Type checking function inputs and outputs

In [7]:

def add_numbers(a: int, b: int) -> int:
    """Return ``a + b``.

    With the ``%jaxtyping.typechecker beartype.beartype`` magic active, functions defined in notebook
    cells are wrapped by a runtime checker, so a non-int argument is rejected at call time.
    """
    return a + b


def process_numbers(numbers: list[int]) -> Float32[np.ndarray, "..."]:
    """Convert a list of ints into a float32 NumPy array (any shape)."""
    array: Float32[np.ndarray, "..."] = np.array(numbers, dtype=np.float32)
    return array


# Two different mechanisms are at work below:
# - bad *arguments* are caught by the cell magic's checker when the function is called, so the call
#   itself raises before `die_if_unbearable` ever runs;
# - the *returned value* is checked explicitly with `die_if_unbearable`.

# Valid arguments, int result: nothing is raised.
try:
    die_if_unbearable(add_numbers(5, 10), int)
except Exception as e:
    print(e)

# Bad argument: the parameter check on `b` fails inside the `add_numbers(...)` call.
try:
    die_if_unbearable(add_numbers(5, "10"), int)
except Exception as e:
    print(e)

# Valid list[int] in, float32 array out: nothing is raised.
try:
    result = process_numbers([1, 2, 3])
    die_if_unbearable(result, Float32[np.ndarray, "..."])
except Exception as e:
    print(e)

# Bad list element: the parameter check on `numbers` fails before any array is built,
# so the `die_if_unbearable` line is never reached.
try:
    result = process_numbers([1, 2, "3"])
    die_if_unbearable(result, Float32[np.ndarray, "..."])
except Exception as e:
    print(e)

Type-check error whilst checking the parameters of __main__.add_numbers.
The problem arose whilst typechecking parameter 'b'.
Actual value: '10'
Expected type: <class 'int'>.
----------------------
Called with parameters: {'a': 5, 'b': '10'}
Parameter annotations: (a: int, b: int) -> Any.

Type-check error whilst checking the parameters of __main__.process_numbers.
----------------------
Called with parameters: {'numbers': [1, 2, '3']}
Parameter annotations: (numbers: list[int]) -> Any.



# Explore Available Datasets
With type checking out of the way, let's take a look at LeRobot!

To explore the available datasets, we can:
1. look directly at those provided by lerobot,
2. query the Hub with `HfApi`, or
3. simply browse [Hugging Face LeRobot datasets](https://huggingface.co/datasets?other=LeRobot) directly.


In [8]:
# lerobot ships a curated list of datasets it has ported itself.
available_datasets: list[str] = lerobot.available_datasets
die_if_unbearable(available_datasets, list[str])

print(f"Total number of available datasets through lerobot: {len(available_datasets)}\n")

# Peek at the first ten entries of that list.
print("First ten available datasets through lerobot:")
pprint(available_datasets[:10])

# Community-contributed datasets live on the Hugging Face Hub; list the ones tagged "LeRobot".
hub_api = HfApi()
repo_ids: list[str] = [
    ds_info.id
    for ds_info in hub_api.list_datasets(task_categories="robotics", tags=["LeRobot"])
]
die_if_unbearable(repo_ids, list[str])

print(f"\nTotal number of datasets in the hub with the tag 'LeRobot': {len(repo_ids)}\n")


Total number of available datasets through lerobot: 91

First ten available datasets through lerobot:
['lerobot/aloha_mobile_cabinet',
 'lerobot/aloha_mobile_chair',
 'lerobot/aloha_mobile_elevator',
 'lerobot/aloha_mobile_shrimp',
 'lerobot/aloha_mobile_wash_pan',
 'lerobot/aloha_mobile_wipe_wine',
 'lerobot/aloha_sim_insertion_human',
 'lerobot/aloha_sim_insertion_human_image',
 'lerobot/aloha_sim_insertion_scripted',
 'lerobot/aloha_sim_insertion_scripted_image']

Total number of datasets in the hub with the tag 'LeRobot': 1856



# Taking a look at an example dataset
I use type annotations extensively on variable assignments to help me understand exactly what I'm working with.

In [9]:
# The dataset to explore. Any id from `available_datasets` or `repo_ids` above can be swapped in here.
repo_id: str = "lerobot/aloha_static_cups_open"
die_if_unbearable(repo_id, str)
print(repo_id)

# Fetch the dataset's metadata to learn more about it. Instantiating just this class gives quick access
# to the content and structure of the dataset without downloading the actual data yet (only the
# lightweight metadata files).
ds_meta = LeRobotDatasetMetadata(repo_id)

total_episodes: int = ds_meta.total_episodes
die_if_unbearable(total_episodes, int)
print(f"Total number of episodes: {total_episodes}")

# Guard the division so a dataset with no episodes reports 0 instead of raising ZeroDivisionError.
avg_frames_per_episode: float = ds_meta.total_frames / total_episodes if total_episodes else 0.0
print(f"Average number of frames per episode: {avg_frames_per_episode:.3f}")

fps: int = ds_meta.fps
die_if_unbearable(fps, int)
print(f"Frames per second used during data collection: {fps}")

# NOTE(review): lerobot annotates `robot_type` as `str | None` (not every dataset records it),
# so the hint allows None — confirm for your lerobot version.
robot_type: str | None = ds_meta.robot_type
die_if_unbearable(robot_type, str | None)
print(f"Robot type: {robot_type}")

camera_keys: list[str] = ds_meta.camera_keys
die_if_unbearable(camera_keys, list[str])
print(f"keys to access images from cameras: {camera_keys=}\n")

print(ds_meta)

lerobot/aloha_static_cups_open
Total number of episodes: 50
Average number of frames per episode: 400.000
Frames per second used during data collection: 50
Robot type: aloha
keys to access images from cameras: camera_keys=['observation.images.cam_high', 'observation.images.cam_left_wrist', 'observation.images.cam_low', 'observation.images.cam_right_wrist']

LeRobotDatasetMetadata({
    Repository ID: 'lerobot/aloha_static_cups_open',
    Total episodes: '50',
    Total frames: '20000',
    Features: '['observation.images.cam_high', 'observation.images.cam_left_wrist', 'observation.images.cam_low', 'observation.images.cam_right_wrist', 'observation.state', 'action', 'episode_index', 'frame_index', 'timestamp', 'next.done', 'index', 'task_index']',
})',



In [10]:
# Give task ids and task descriptions their own nominal types, so the annotation on `tasks`
# says what each key and value means rather than just "int" and "str".
TaskID = NewType("TaskID", int)
TaskDescription = NewType("TaskDescription", str)

tasks: dict[TaskID, TaskDescription] = ds_meta.tasks
die_if_unbearable(tasks, dict[TaskID, TaskDescription])
print(f"Tasks:\n{tasks}")

Tasks:
{0: 'Pick up the plastic cup and open its lid.'}


# Features include things like observations, videos, timestamps, etc.
```
├ observation.images.cam_high (VideoFrame):
│  │   VideoFrame = {'path': path to a mp4 video, 'timestamp' (float32): timestamp in the video}
│  ├ observation.state (list of float32): positions of the arm joints (for instance)
│  ... (more observations)
│  ├ action (list of float32): goal positions of the arm joints (for instance)
│  ├ episode_index (int64): index of the episode for this sample
│  ├ frame_index (int64): index of the frame for this sample in the episode; starts at 0 for each episode
│  ├ timestamp (float32): timestamp in the episode
│  ├ next.done (bool): indicates the end of an episode; True for the last frame in each episode
│  └ index (int64): general index in the whole dataset
```

In [11]:
class BaseFeature(TypedDict, total=False):
    """Schema of a non-video entry in ``ds_meta.features``.

    'dtype' is a scalar/array dtype, or "image" for camera frames stored as images rather than video
    (e.g. the ``*_image`` datasets in ``available_datasets``).
    NOTE(review): confirm the dtype list against your lerobot version.
    'shape' is a tuple of ints. 'names' can be None, a list of names, or a dict mapping a group
    (e.g. "motors") to its element names.
    """

    dtype: Literal["float32", "int64", "bool", "string", "image"]
    shape: tuple[int, ...]
    names: list[str] | dict[str, list[str]] | None


class VideoFeature(TypedDict, total=False):
    """Schema of a video entry in ``ds_meta.features`` (dtype == "video").

    Expected to carry 'video_info' (a dict of video metadata). Because the class is declared with
    ``total=False``, a missing 'video_info' is not reported as a type error.
    """

    dtype: Literal["video"]
    shape: tuple[int, ...]
    names: list[str] | dict[str, list[str]] | None
    video_info: dict[str, Any]


# A feature is either kind; features are keyed by their name (e.g. "action", "observation.state").
Feature = BaseFeature | VideoFeature
FeatureName = NewType("FeatureName", str)
FeaturesDict = dict[FeatureName, Feature]

features: FeaturesDict = ds_meta.features
die_if_unbearable(features, FeaturesDict)
print("Features:\n")
pprint(features)

Features:

{'action': {'dtype': 'float32',
            'names': {'motors': ['left_waist',
                                 'left_shoulder',
                                 'left_elbow',
                                 'left_forearm_roll',
                                 'left_wrist_angle',
                                 'left_wrist_rotate',
                                 'left_gripper',
                                 'right_waist',
                                 'right_shoulder',
                                 'right_elbow',
                                 'right_forearm_roll',
                                 'right_wrist_angle',
                                 'right_wrist_rotate',
                                 'right_gripper']},
            'shape': (14,)},
 'episode_index': {'dtype': 'int64', 'names': None, 'shape': (1,)},
 'frame_index': {'dtype': 'int64', 'names': None, 'shape': (1,)},
 'index': {'dtype': 'int64', 'names': None, 'shape': (1,)},
 'next.done': {'d

# Load the full dataset, not just the metadata

In [12]:
# Load actual data from the hub. A subset of episodes can be requested explicitly:
episodes: list[int] = [0, 4, 9]
die_if_unbearable(episodes, list[int])
dataset = LeRobotDataset(repo_id, episodes=episodes)

# Report the size of the subset that was loaded.
print(f"Selected episodes: {dataset.episodes}")
print(f"Number of episodes selected: {dataset.num_episodes}")
print(f"Number of frames selected: {dataset.num_frames}")

# Leaving out `episodes` loads every episode of the dataset.
full_dataset = LeRobotDataset(repo_id)
print(f"Number of episodes selected: {full_dataset.num_episodes}")
print(f"Number of frames selected: {full_dataset.num_frames}")

# The metadata object explored earlier is exposed as `.meta`...
print(full_dataset.meta)

# ...and the wrapped Hugging Face dataset as `.hf_dataset`
# (see https://huggingface.co/docs/datasets for more information).
print(full_dataset.hf_dataset)

Selected episodes: [0, 4, 9]
Number of episodes selected: 3
Number of frames selected: 1200


Resolving data files:   0%|          | 0/50 [00:00<?, ?it/s]

Number of episodes selected: 50
Number of frames selected: 20000
LeRobotDatasetMetadata({
    Repository ID: 'lerobot/aloha_static_cups_open',
    Total episodes: '50',
    Total frames: '20000',
    Features: '['observation.images.cam_high', 'observation.images.cam_left_wrist', 'observation.images.cam_low', 'observation.images.cam_right_wrist', 'observation.state', 'action', 'episode_index', 'frame_index', 'timestamp', 'next.done', 'index', 'task_index']',
})',

Dataset({
    features: ['observation.state', 'action', 'episode_index', 'frame_index', 'timestamp', 'next.done', 'index', 'task_index'],
    num_rows: 20000
})


In [13]:
# LeRobot datasets also subclass PyTorch datasets, so you can do everything you know and love from working
# with the latter, like iterating through the dataset.
# __getitem__ indexes individual frames of the dataset. Since our datasets are also structured by
# episodes, you can look up the frame range of any episode with `episode_data_index`. Here, we access
# the frame indices of the first episode:
episode_index: int = 0
die_if_unbearable(episode_index, int)

from_idx: int = dataset.episode_data_index["from"][episode_index].item()
to_idx: int = dataset.episode_data_index["to"][episode_index].item()
die_if_unbearable(from_idx, int)
die_if_unbearable(to_idx, int)
print(f"Episode {episode_index} starts at frame {from_idx} and ends at frame {to_idx}")

# Then we grab image frames from the first camera:
camera_key: str = dataset.meta.camera_keys[0]
die_if_unbearable(camera_key, str)

# Loading frames one by one like this can take a while, so only load up to `max_frames` of them.
# The range starts at `from_idx` and is clamped to `to_idx` (the exclusive end of the episode), so it
# stays inside this episode whichever `episode_index` is chosen.
max_frames: int = 150
stop_idx: int = min(from_idx + max_frames, to_idx)
frames: list[Float32[torch.Tensor, "3 H W"]] = [dataset[idx][camera_key] for idx in range(from_idx, stop_idx)]
die_if_unbearable(frames, list[Float32[torch.Tensor, "3 H W"]])

# Show that the frames are float32 tensors with shape (3, H, W), in this case (3, 480, 640)
print(type(frames[0]))
print(frames[0].shape)
print(frames[0].dtype)

# Visualize these images with rerun
rr.init("lerobot images")
for offset, frame in enumerate(frames):
    # Use the frame's real index in the dataset as the timeline value, not its position in `frames`.
    rr.set_time_sequence("frame_idx", from_idx + offset)

    # (3, H, W) -> (H, W, 3)
    rgb_tensor: Float32[torch.Tensor, "H W 3"] = rearrange(frame, "C H W -> H W C")
    die_if_unbearable(rgb_tensor, Float32[torch.Tensor, "H W 3"])

    rgb_float: Float32[np.ndarray, "H W 3"] = rgb_tensor.numpy(force=True)
    die_if_unbearable(rgb_float, Float32[np.ndarray, "H W 3"])

    # Scale 0-1 floats to 0-255 and convert to uint8; clip first so out-of-range values
    # cannot wrap around in the integer cast.
    rgb_uint8: UInt8[np.ndarray, "H W 3"] = (np.clip(rgb_float, 0.0, 1.0) * 255).astype(np.uint8)
    die_if_unbearable(rgb_uint8, UInt8[np.ndarray, "H W 3"])

    rr.log("image", rr.Image(rgb_uint8).compress(jpeg_quality=70))

rr.notebook_show()

Episode 0 starts at frame 0 and ends at frame 400
<class 'torch.Tensor'>
torch.Size([3, 480, 640])
torch.float32


Viewer()

In [14]:
# For many machine learning applications we need to load the history of past observations or trajectories of
# future actions. Our datasets can load previous and future frames for each key/modality, using timestamp
# differences (in seconds) relative to the currently loaded frame. For instance:
delta_timestamps: dict[str, list[float | int]] = {
    # loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
    camera_key: [-1, -0.5, -0.20, 0],
    # loads 6 state vectors: 1.5 seconds before, 1 second before, 500 ms, 200 ms, 100 ms, and current frame
    "observation.state": [-1.5, -1, -0.5, -0.20, -0.10, 0],
    # loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future
    "action": [t / dataset.fps for t in range(64)],
}
die_if_unbearable(delta_timestamps, dict[str, list[float | int]])

# Note that in any case, these delta_timestamps values need to be multiples of (1/fps) so that added to any
# timestamp, you still get a valid timestamp. Check that up front, with a small tolerance in seconds to
# absorb float rounding (e.g. -0.10 * 50 is not exactly -5.0).
for key, deltas in delta_timestamps.items():
    off_grid = [d for d in deltas if abs(d - round(d * dataset.fps) / dataset.fps) > 1e-4]
    assert not off_grid, f"{key}: {off_grid} are not multiples of 1/fps ({1 / dataset.fps:.4f}s)"

In [15]:
dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)

# The number of timesteps returned per key is set by `delta_timestamps`, so derive the expected
# leading dimension from it instead of hard-coding 4 / 6 / 64.
n_frame_steps: int = len(delta_timestamps[camera_key])
n_state_steps: int = len(delta_timestamps["observation.state"])
n_action_steps: int = len(delta_timestamps["action"])

# Indexing the dataset decodes video frames, which is slow, so fetch the sample once and reuse it.
sample = dataset[0]

# images from `camera_key` at each requested time offset
delta_frames: Float32[torch.Tensor, f"{n_frame_steps} 3 H W"] = sample[camera_key]
die_if_unbearable(delta_frames, Float32[torch.Tensor, f"{n_frame_steps} 3 H W"])

# positions of the arm joints at each requested time offset
delta_states: Float32[torch.Tensor, f"{n_state_steps} num_motors"] = sample["observation.state"]
die_if_unbearable(delta_states, Float32[torch.Tensor, f"{n_state_steps} num_motors"])

# actions to be taken
delta_actions: Float32[torch.Tensor, f"{n_action_steps} num_motors"] = sample["action"]
die_if_unbearable(delta_actions, Float32[torch.Tensor, f"{n_action_steps} num_motors"])

print(f"\n{delta_frames.shape=}")  # (4, c, h, w)
print(f"{delta_states.shape=}")  # (6, num_motors)
print(f"{delta_actions.shape=}\n")  # (64, num_motors)

# Finally, our datasets are fully compatible with PyTorch dataloaders and samplers because they are just
# PyTorch datasets.
batch_size: int = 32
dataloader = torch.utils.data.DataLoader(
    dataset,
    num_workers=0,
    batch_size=batch_size,
    shuffle=True,
)

# Inspect a single batch. The batch axis uses the named dim "batch" rather than a literal 32, because a
# batch can be smaller than `batch_size` (e.g. when the dataset has fewer samples).
for batch in dataloader:
    batch_frames: Float32[torch.Tensor, f"batch {n_frame_steps} 3 H W"] = batch[camera_key]
    die_if_unbearable(batch_frames, Float32[torch.Tensor, f"batch {n_frame_steps} 3 H W"])

    batch_states: Float32[torch.Tensor, f"batch {n_state_steps} num_motors"] = batch["observation.state"]
    die_if_unbearable(batch_states, Float32[torch.Tensor, f"batch {n_state_steps} num_motors"])

    batch_actions: Float32[torch.Tensor, f"batch {n_action_steps} num_motors"] = batch["action"]
    die_if_unbearable(batch_actions, Float32[torch.Tensor, f"batch {n_action_steps} num_motors"])
    break

Resolving data files:   0%|          | 0/50 [00:00<?, ?it/s]


delta_frames.shape=torch.Size([4, 3, 480, 640])
delta_states.shape=torch.Size([6, 14])
delta_actions.shape=torch.Size([64, 14])



# Finally, let's visualize the dataset using rerun

In [17]:
import rerun.blueprint as rrb
import tqdm
from jaxtyping import Int64
from lerobot.scripts.visualize_dataset import EpisodeSampler

rr.init("Final notebook visualization")

# Reload without `delta_timestamps` so each sample carries a single frame per key.
dataset = LeRobotDataset(repo_id)

episode_index: int = 0
die_if_unbearable(episode_index, int)
# Presumably samples only the frames of `episode_index`. The 13 batches of 32 reported by tqdm match one
# 400-frame episode. NOTE(review): confirm against lerobot's EpisodeSampler.
episode_sampler = EpisodeSampler(dataset, episode_index)

batch_size: int = 32
die_if_unbearable(batch_size, int)
dataloader = torch.utils.data.DataLoader(
    dataset,
    num_workers=0,
    batch_size=batch_size,
    sampler=episode_sampler,
)

# Layout: one 2D view per camera in a grid, above a time-series view for the scalar channels.
blueprint = rrb.Blueprint(
    rrb.Vertical(
        rrb.Grid(contents=[rrb.Spatial2DView(origin=key) for key in dataset.meta.camera_keys]),
        rrb.TimeSeriesView(),
    ),
    collapse_panels=True,
)
rr.notebook_show(blueprint=blueprint, height=500, width=1000)

for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
    batch_index: Int64[torch.Tensor, "batch_size"] = batch["index"]
    die_if_unbearable(batch_index, Int64[torch.Tensor, "batch_size"])

    for i in range(len(batch_index)):
        # Everything logged for this sample lands on both the frame-index and timestamp timelines.
        frame_index: int = batch["frame_index"][i].item()
        die_if_unbearable(frame_index, int)
        rr.set_time_sequence("frame_index", frame_index)

        timestamp: float = batch["timestamp"][i].item()
        die_if_unbearable(timestamp, float)
        rr.set_time_seconds("timestamp", timestamp)

        # Display each camera image: (3, H, W) floats in 0-1 -> (H, W, 3) uint8, JPEG-compressed.
        for key in dataset.meta.camera_keys:
            rgb_tensor: Float32[torch.Tensor, "3 H W"] = batch[key][i]
            rgb_float: Float32[np.ndarray, "H W 3"] = rearrange(rgb_tensor, "C H W -> H W C").numpy(force=True)
            # Clip first so out-of-range values cannot wrap around in the integer cast.
            rgb_uint8: UInt8[np.ndarray, "H W 3"] = (np.clip(rgb_float, 0.0, 1.0) * 255).astype(np.uint8)
            rr.log(key, rr.Image(rgb_uint8).compress(jpeg_quality=95))

        # Display each dimension of the action space (e.g. actuator commands).
        if "action" in batch:
            for dim_idx, val in enumerate(batch["action"][i]):
                rr.log(f"action/{dim_idx}", rr.Scalar(val.item()))

        # Display each dimension of the observed state space (e.g. agent position in joint space).
        if "observation.state" in batch:
            for dim_idx, val in enumerate(batch["observation.state"][i]):
                rr.log(f"state/{dim_idx}", rr.Scalar(val.item()))

        # Optional per-frame scalar signals, logged only when the dataset provides them.
        if "next.done" in batch:
            rr.log("next.done", rr.Scalar(batch["next.done"][i].item()))

        if "next.reward" in batch:
            rr.log("next.reward", rr.Scalar(batch["next.reward"][i].item()))

        if "next.success" in batch:
            rr.log("next.success", rr.Scalar(batch["next.success"][i].item()))

Resolving data files:   0%|          | 0/50 [00:00<?, ?it/s]

Viewer()

[2025-02-11T00:14:20Z ERROR re_log::result_extensions] rerun_py/src/python_bridge.rs:835 ZMQError: Too many open files                                                                                      | 1/13 [00:03<00:45,  3.77s/it]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:45<00:00,  3.46s/it]
