In [1]:
from pathlib import Path
from typing import Any, Literal, NewType, TypedDict

from pathlib import Path

import torch

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.utils import dataset_to_policy_features
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.configs.types import FeatureType
from beartype.door import die_if_unbearable, infer_hint

%load_ext jaxtyping
%jaxtyping.typechecker beartype.beartype

# Prerequisites before training

In [2]:
 # Create a directory to store the training checkpoint.
output_directory = Path("outputs/train/example_pusht_diffusion")
output_directory.mkdir(parents=True, exist_ok=True)

# Select the compute device. Prefer CUDA, but fall back to the CPU when no GPU is visible so
# that the later `policy.to(device)` call does not fail on CPU-only machines.
device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Runtime check (beartype) that `device` really is a `torch.device`.
die_if_unbearable(device, torch.device)

### Set hyperparameters and instantiate the policy

In [3]:
from lerobot.configs.types import PolicyFeature

# Offline training budget (this example only does offline training).
# Tune as you like; about 5000 steps are needed to get something worth evaluating.
training_steps = 500
log_freq = 1

# A policy created from scratch (i.e. not loaded from a pretrained checkpoint) needs two pieces
# of dataset information before it can be built:
#   - the input/output shapes, to size the policy properly
#   - the dataset statistics, to normalize inputs and denormalize outputs
dataset_metadata = LeRobotDatasetMetadata("lerobot/pusht")
features: dict[str, PolicyFeature] = dataset_to_policy_features(dataset_metadata.features)
die_if_unbearable(features, dict[str, PolicyFeature])

# Features typed as ACTION are the policy outputs; every other feature is a policy input.
output_features: dict[str, PolicyFeature] = {
    name: feature for name, feature in features.items() if feature.type is FeatureType.ACTION
}
die_if_unbearable(output_features, dict[str, PolicyFeature])
input_features: dict[str, PolicyFeature] = {
    name: feature for name, feature in features.items() if name not in output_features
}
die_if_unbearable(input_features, dict[str, PolicyFeature])

# `DiffusionConfig` is the configuration class for this policy. The defaults are used here,
# so only the input/output features have to be supplied.
cfg: DiffusionConfig = DiffusionConfig(
    input_features=input_features,
    output_features=output_features,
)

# Build the policy from the config and the dataset statistics, switch it to training mode
# and move it to the selected device.
policy = DiffusionPolicy(cfg, dataset_stats=dataset_metadata.stats)
policy.train()
policy.to(device)

DiffusionPolicy(
  (normalize_inputs): Normalize(
    (buffer_observation_image): ParameterDict(
        (mean): Parameter containing: [torch.cuda.FloatTensor of size 3x1x1 (cuda:0)]
        (std): Parameter containing: [torch.cuda.FloatTensor of size 3x1x1 (cuda:0)]
    )
    (buffer_observation_state): ParameterDict(
        (max): Parameter containing: [torch.cuda.FloatTensor of size 2 (cuda:0)]
        (min): Parameter containing: [torch.cuda.FloatTensor of size 2 (cuda:0)]
    )
  )
  (normalize_targets): Normalize(
    (buffer_action): ParameterDict(
        (max): Parameter containing: [torch.cuda.FloatTensor of size 2 (cuda:0)]
        (min): Parameter containing: [torch.cuda.FloatTensor of size 2 (cuda:0)]
    )
  )
  (unnormalize_outputs): Unnormalize(
    (buffer_action): ParameterDict(
        (max): Parameter containing: [torch.cuda.FloatTensor of size 2 (cuda:0)]
        (min): Parameter containing: [torch.cuda.FloatTensor of size 2 (cuda:0)]
    )
  )
  (diffusion): Diff

### Compute the delta timestamps for the frames the policy expects

In [4]:
# Another policy-dataset interaction is the `delta_timestamps`. Each policy expects a given number
# of frames, which can differ for inputs, outputs and rewards (if there are some). The config gives
# these frames as indices relative to the current frame; dividing by the dataset fps turns them into
# time offsets in seconds.
#
# The computed dict is the one handed to the dataset, so it keeps following `cfg` and the dataset fps
# if either of them is changed (a hard-coded copy would silently go stale).
delta_timestamps = {
    "observation.image": [i / dataset_metadata.fps for i in cfg.observation_delta_indices],
    "observation.state": [i / dataset_metadata.fps for i in cfg.observation_delta_indices],
    "action": [i / dataset_metadata.fps for i in cfg.action_delta_indices],
}

# For reference only: with the standard configuration for Diffusion Policy, the dict above is
# equivalent to this:
#
# {
#     # Load the previous image and state at -0.1 seconds before current frame,
#     # then load current image and state corresponding to 0.0 second.
#     "observation.image": [-0.1, 0.0],
#     "observation.state": [-0.1, 0.0],
#     # Load the previous action (-0.1), the next action to be executed (0.0),
#     # and 14 future actions with a 0.1 seconds spacing. All these actions will be
#     # used to supervise the policy.
#     "action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
# }

In [5]:
import rerun as rr
import rerun.blueprint as rrb

# Instantiate the dataset with the `delta_timestamps` configuration.
dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)

# Optimizer and dataloader for offline training.
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)
dataloader = torch.utils.data.DataLoader(
    dataset,
    num_workers=4,
    batch_size=64,
    shuffle=True,
    # Pinned host memory only helps when batches are copied to an accelerator.
    pin_memory=device.type != "cpu",
    drop_last=True,
)

# Start ONE Rerun recording for the whole training run. Calling `rr.init` a second time (as was
# done before, with the unrelated id "evaluate_policy") replaces the active recording, so anything
# logged in between never reaches the viewer and the loss curve ends up under the wrong name.
rr.init("train policy")

blueprint = rrb.Blueprint(
    # Optional view that only shows the last 100 steps around the time cursor:
    # rrb.TimeSeriesView(
    #     time_ranges = [
    #         rrb.VisibleTimeRange(
    #             "step",
    #             start=rrb.TimeRangeBoundary.cursor_relative(seq=-100),
    #             end=rrb.TimeRangeBoundary.cursor_relative()),
    #     ]
    # )
)

# Embed the viewer in the notebook; the loss logged below streams into it.
rr.notebook_show(blueprint=blueprint)

# Run training loop: keep cycling over the dataloader until `training_steps` updates are done.
step = 0
done = False
while not done:
    for batch in dataloader:
        # Move tensors to the training device; non-tensor entries are passed through unchanged.
        batch = {k: (v.to(device, non_blocking=True) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
        output_dict = policy.forward(batch)
        loss = output_dict["loss"]
        loss.backward()
        optimizer.step()
        # Clear gradients right after the update so the next backward pass starts from zero.
        optimizer.zero_grad()

        if step % log_freq == 0:
            # Log the loss on the "step" timeline so it is plotted as a time series.
            rr.set_time_sequence("step", step)
            rr.log("loss", rr.Scalar(loss.item()))
        step += 1
        if step >= training_steps:
            done = True
            break

# Save a policy checkpoint.
policy.save_pretrained(output_directory)

Resolving data files:   0%|          | 0/206 [00:00<?, ?it/s]

Viewer()