In [None]:
# echo "🎯 开始训练..."
!python3 -m lerobot.scripts.train \
    --dataset.root=./dataset/qnbot_data1 \
    --dataset.repo_id=bradley/qnbot_cherry_transfer_20250705 \
    --policy.type=act \
    --policy.repo_id=bradley/qnbot_cherry_transfer_act \
    --output_dir=outputs/qnbot_cherry_transfer_act \
    --batch_size=16 \
    --steps=100000 \
    --policy.device=cuda \
    --wandb.enable=true \
    --wandb.mode=offline \
    --wandb.project=lerobot_qnbot \
    --wandb.entity=breadlee1024 \
    --wandb.notes="QnBot樱桃传递任务 - 右手拿樱桃玩具传递给左手放到白盘子里" \
    --log_freq=200 \
    --save_freq=10000

In [None]:
from pathlib import Path

import torch

from lerobot.configs.types import FeatureType
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.utils import dataset_to_policy_features
import logging

In [None]:
# Parameters
FPS = 30
# Dataset parameters
import os
import lerobot

# Get parent directory of the current file

REPO_PATH = os.path.dirname(os.path.dirname(os.path.dirname(lerobot.__file__)))

dataset_root = os.path.join(REPO_PATH, "dataset/qnbot_data1")

repo_id = "bradley/qnbot_cherry_transfer_20250705"
logging.info(f"Using LeRobot repo path: {dataset_root}")

# Create  a directory to sotore the training checkpoints
output_dir = Path("outputs/train/qnbot_cherry_transfer_act")
output_dir.mkdir(parents=True, exist_ok=True)
logging.info(f"Using output directory: {output_dir}")

# Select device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Number of offline training steps (we'll only do offline training for this example.)
# Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
training_steps = 100000
log_freq = 1


In [None]:
# When starting from scratch (i.e. not from a pretrained policy), we need to specify 2 things before
# creating the policy:
#   - input/output shapes: to properly size the policy
#   - dataset stats: for normalization and denormalization of input/outputs
dataset_metadata = LeRobotDatasetMetadata(repo_id=repo_id, root=dataset_root)
features = dataset_to_policy_features(dataset_metadata.features)
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
input_features = {key: ft for key, ft in features.items() if key not in output_features}
from pprint import pformat
print(f"Dataset metadata:\n{pformat(input_features)}")
print(f"Output features:\n{pformat(output_features)}")
logging.info(f"Input features: {input_features}")
logging.info(f"Output features: {output_features}")

Input features: {'observation.state': PolicyFeature(type=<FeatureType.STATE: 'STATE'>, shape=(17,)), 'observation.images.left_wrist': PolicyFeature(type=<FeatureType.VISUAL: 'VISUAL'>, shape=(3, 480, 640)), 'observation.images.head': PolicyFeature(type=<FeatureType.VISUAL: 'VISUAL'>, shape=(3, 480, 640)), 'observation.images.right_wrist': PolicyFeature(type=<FeatureType.VISUAL: 'VISUAL'>, shape=(3, 480, 640))}
Output features: {'action': PolicyFeature(type=<FeatureType.ACTION: 'ACTION'>, shape=(17,))}


In [34]:
# =============================================
# 2. Dataset Setup
# =============================================
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy

cfg = DiffusionConfig(input_features=input_features, output_features=output_features)


delta_timestamps = {
    "observation.images.head": [i / dataset_metadata.fps for i in cfg.observation_delta_indices],
    "observation.images.left_wrist": [i / dataset_metadata.fps for i in cfg.observation_delta_indices],
    "observation.images.right_wrist": [i / dataset_metadata.fps for i in cfg.observation_delta_indices],
    "observation.state": [i / dataset_metadata.fps for i in cfg.observation_delta_indices],
    "action": [i / dataset_metadata.fps for i in cfg.action_delta_indices],
}

dataset = LeRobotDataset(
    repo_id=repo_id,
    root=dataset_root,
    delta_timestamps=delta_timestamps
)

# repo_ids = [f"agibotworld/{path.name}" for path in Path(dataset_path).glob("agibotworld/task_*")]
# multi_dataset = MultiLeRobotDataset(
#     repo_ids=repo_ids,
#     root=dataset_path,
#     delta_timestamps=delta_timestamps,
#     local_files_only=True
# )

dataloader = torch.utils.data.DataLoader(
    dataset,
    num_workers=4,
    batch_size=64,
    shuffle=True,
    pin_memory=device.type != "cpu",
    drop_last=True,
)




In [36]:
# =============================================
# 3. Policy Configuration and Initialization
# =============================================
cfg.input_shapes = {
    "observation.images.head": [3, 480, 640],
    "observation.images.left_wrist": [3, 480, 640],
    "observation.images.right_wrist": [3, 480, 640],
    "observation.state": [17],
}
cfg.input_normalization_modes = {
    "observation.images.head": "mean_std",
    "observation.images.left_wrist": "mean_std",
    "observation.images.right_wrist": "mean_std",
    "observation.state": "min_max",
}
cfg.output_shapes = {
    "action": [17],
}


policy = DiffusionPolicy(cfg, dataset_stats=dataset.meta.stats)
#policy = DiffusionPolicy(cfg, dataset_stats=multi_dataset.stats)
policy.train()

policy.to(device)

optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)

In [37]:

# =============================================
# 4. Training Loop
# =============================================
step = 0
done = False
while not done:
    for batch in dataloader:
        batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
        loss, _ = policy.forward(batch)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if step % log_freq == 0:
            print(f"step: {step} loss: {loss.item():.3f}")
        step += 1
        if step >= training_steps:
            done = True
            break

step: 0 loss: 1.204
step: 1 loss: 1.403
step: 2 loss: 1.125
step: 3 loss: 1.057
step: 4 loss: 1.064
step: 5 loss: 1.005
step: 6 loss: 1.015
step: 7 loss: 1.018
step: 8 loss: 1.023
step: 9 loss: 1.025
step: 10 loss: 1.009
step: 11 loss: 1.014
step: 12 loss: 1.016


KeyboardInterrupt: 

In [None]:
# =============================================
# 5. Save Policy Checkpoint
# =============================================
policy.save_pretrained(output_path)
print(f"Model saved to {output_path}")