In [None]:
import torch
from torch.utils.data import DataLoader

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

REPO_ID = "yilin404/pick_and_place"
DATA_ROOT = "/home/yilin/dataset/own_episode_data"

# Set up the dataset.
# Each observation stream is sampled twice per item: once 0.1 seconds before
# the current frame and once at the current frame itself (0.0 seconds).
delta_timestamps = {
    observation_key: [-0.1, 0.0]
    for observation_key in (
        "observation.images.colors_camera_top",
        "observation.images.colors_camera_wrist",
        "observation.state",
    )
}
# Actions span the previous action (-0.1), the next action to be executed
# (0.0) and 14 future actions spaced 0.1 seconds apart (0.1 .. 1.4). All of
# them are used to supervise the policy. Dividing integer tenths by 10 yields
# exactly the same floats as writing the decimal literals out by hand.
delta_timestamps["action"] = [tenths / 10 for tenths in range(-1, 15)]
dataset = LeRobotDataset(REPO_ID, root=DATA_ROOT, delta_timestamps=delta_timestamps)

# Create dataloader for offline training.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
dataloader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    drop_last=True,
    num_workers=0,
    # Pin host memory only when the target device is not the CPU.
    pin_memory=device.type != "cpu",
)

In [None]:
# Print the dataset statistics for inspection; the same object is later passed
# to the policy constructor as `dataset_stats`.
print(dataset.stats)

In [None]:
from lerobot.common.policies.rdt.configuration_rdt import RDTConfig
from lerobot.common.policies.rdt.modeling_rdt import RDTPolicy

import torch

# Set up the policy.
config = RDTConfig()
# Per-feature shapes handed to the policy. The image entries are presumably
# (channels, height, width) and the state/action vectors 7-dimensional --
# TODO confirm against the recorded dataset.
config.input_shapes = {
    "observation.images.colors_camera_top": [3, 480, 640],
    "observation.images.colors_camera_wrist": [3, 480, 640],
    "observation.state": [7],
}
config.output_shapes = {
    "action": [7],
}
# Normalization mode per feature: "mean_std" for both camera streams,
# "min_max" for the state input and the action output.
config.input_normalization_modes = {
    "observation.images.colors_camera_top": "mean_std",
    "observation.images.colors_camera_wrist": "mean_std",
    "observation.state": "min_max",
}
config.output_normalization_modes = {
    "action": "min_max",
}
# NOTE(review): presumably disables image cropping -- confirm in RDTConfig.
config.crop_shape = None
policy = RDTPolicy(config, dataset_stats=dataset.stats)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
policy.to(device)
for batch in dataloader:
    # Move the batch to the device. Only tensors have a `.to` method, so any
    # non-tensor entry in the batch is passed through unchanged instead of
    # raising AttributeError.
    batch = {
        key: value.to(device) if isinstance(value, torch.Tensor) else value
        for key, value in batch.items()
    }

    # Forward pass.
    loss = policy(batch)["loss"]
    print("==> loss is: ", loss.item())

In [None]:
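# Disabled one-off preprocessing script, kept for reference. For every episode
# under RAW_DIR it limits each arm action in data.json to within +/-0.1 of the
# corresponding arm state and rewrites the JSON file in place.
# from pathlib import Path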
# import json
# import re
# import tqdm

# import numpy as np

# from typing import List

# RAW_DIR = Path("/home/yilin/dataset/own_episode_data/raw_data/pick_and_place")
# JSON_FILE = "data.json"

# def get_episodes(raw_dir: Path) -> List[Path]:
#     return [path for path in raw_dir.iterdir() if path.is_dir()]

# episode_paths = get_episodes(RAW_DIR)
# print(f"Found {len(episode_paths)} episodes.")

# episode_paths = sorted(
#     episode_paths,
#     key=lambda path: int(re.search(r'(\d+)$', str(path)).group(1)) if re.search(r'(\d+)$', str(path)) else -1
# )
# num_episodes = len(episode_paths)

# for ep_path in tqdm.tqdm(episode_paths, desc=f"Processing {num_episodes} episodes"):
#     json_path = ep_path / JSON_FILE
#     if not json_path.exists():
#         print(f"Warning: {json_path} does not exist.")
#         continue

#     try:
#         with json_path.open('r+', encoding='utf-8') as jsonf:
#             # Load the JSON file
#             episode_data = json.load(jsonf)

#             # Modify the data
#             for sample_data in episode_data.get("data", []):
#                 arm_states = np.array(sample_data["states"].get("arm", {})["qpos"], dtype=np.float32)
#                 arm_actions = np.array(sample_data["actions"].get("arm", {})["qpos"], dtype=np.float32)

#                 arm_actions = arm_states + np.clip(arm_actions - arm_states, a_min=-0.1, a_max=0.1)

#                 sample_data["actions"]["arm"]["qpos"] = arm_actions.tolist()
            
#             # Write the modified data back
#             jsonf.seek(0)  # Move the file pointer to the start of the file
#             jsonf.truncate()  # Clear the file contents
#             json.dump(episode_data, jsonf, indent=4, ensure_ascii=False)  # Write the modified data

#     except json.JSONDecodeError as e:
#         print(f"Error decoding JSON in {json_path}: {e}")
#         continue
