In [None]:
import sys
import random
import numpy as np
import os
from PIL import Image
import json
import torch
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from src.env.env import RILAB_OMY_ENV

In [None]:
# Local directory passed to LeRobotDataset as `root` for the
# 'Jeongeun/deep_learning_2025' dataset.
root = './dataset/demo_data'
dataset = LeRobotDataset(
    repo_id='Jeongeun/deep_learning_2025',
    root=root,
)

In [3]:

class EpisodeSampler(torch.utils.data.Sampler):
    """Yield, in order, the dataset indices that belong to a single episode.

    The span is taken from ``dataset.episode_data_index``: it starts at
    ``["from"][episode_index]`` (inclusive) and ends at
    ``["to"][episode_index]`` (exclusive).
    """

    def __init__(self, dataset: LeRobotDataset, episode_index: int):
        bounds = dataset.episode_data_index
        start = bounds["from"][episode_index].item()
        stop = bounds["to"][episode_index].item()
        self.frame_ids = range(start, stop)

    def __len__(self) -> int:
        return len(self.frame_ids)

    def __iter__(self):
        yield from self.frame_ids

In [4]:
# Episode to replay: a position into dataset.episode_data_index["from"/"to"].
episode_index = 1

# Iterate over only that episode's frames, in recorded order, one frame per
# batch, loaded by a single background worker.
episode_sampler = EpisodeSampler(dataset, episode_index)
dataloader = torch.utils.data.DataLoader(
    dataset,
    sampler=episode_sampler,
    batch_size=1,
    num_workers=1,
)


In [5]:

# Pull just the first frame of the episode; the next cell reads the env
# config path stored with it.
data = next(iter_dataloader := iter(dataloader))

In [None]:
# Rebuild the simulation environment for this episode. The first frame may
# carry the path of its env config under 'config_file_name'; when the dataset
# has no such field, fall back to a bundled default config.
DEFAULT_CONFIG_PATH = 'configs/train2.json'
try:
    config_file_path = data['config_file_name'][0]
except (KeyError, IndexError, TypeError):
    # Narrowed from a bare `except:`. The bare form also hid a NameError when
    # the previous cell had not been run, and it swallowed KeyboardInterrupt.
    config_file_path = DEFAULT_CONFIG_PATH
with open(config_file_path, encoding='utf-8') as f:
    env_conf = json.load(f)
omy_env = RILAB_OMY_ENV(cfg=env_conf, seed=0, action_type='joint', vis_mode='teleop')

In [7]:
# geom_idx = omy_env.env.geom_names.index('box_1_geom')
# omy_env.env.model.geom(geom_idx).rgba = np.array([0.5, 1.0, 1.0, 1])

In [None]:
# Replay the selected episode in the simulator. Every REPLAY_HZ tick:
# - take the next recorded frame;
# - restore the object/receptacle state from that frame;
# - show the recorded camera images in the viewer;
# - apply the recorded action.
# When the episode runs out, it restarts from its first frame. Frames and
# actions are collected into `clips` / `actions` during the first pass only.
REPLAY_HZ = 20
SIDE_VIEW_SHAPE = (480, 640, 3)


def _chw_float_to_hwc_uint8(chw):
    """Convert a (C, H, W) float image tensor to an (H, W, C) uint8 array.

    Assumes values are scaled to [0, 1] (TODO confirm against the dataset's
    image features). Values are rounded, not truncated: in float32,
    ``k / 255 * 255`` can land just below ``k``, and a plain ``astype`` would
    floor it to ``k - 1``. Clipping keeps out-of-range values from wrapping
    around when cast to uint8.
    """
    scaled = np.clip(np.rint(chw.numpy() * 255), 0, 255).astype(np.uint8)
    return np.transpose(scaled, (1, 2, 0))


step = 0
iter_dataloader = iter(dataloader)
omy_env.reset()
action = omy_env.get_full_joint_state()
clips = []
actions = []
save_flag = True
while omy_env.env.is_viewer_alive():
    omy_env.step_env()
    if omy_env.env.loop_every(HZ=REPLAY_HZ):
        data = next(iter_dataloader)

        # Restore the object and receptacle state recorded for this frame.
        obj_pose = data['obj_pose'][0].numpy()
        obj_names = data['obj_names'][0].split(',')
        recp_q_states = data['obj_q_states'][0].numpy()
        recp_q_names = data['obj_q_names'][0].split(',')
        omy_env.set_object_pose(obj_pose, obj_names, recp_q_states, recp_q_names)

        language_instruction = data['task'][0]
        # batch_size=1, so index 0 selects this frame's action.
        action = data['action'].numpy()
        img = _chw_float_to_hwc_uint8(data['image'][0])
        if save_flag:
            # Copy, so the saved clip never aliases the array handed to the renderer.
            clips.append(img.copy())
            actions.append(action[0])

        # Show the recorded camera images in the viewer overlay.
        omy_env.rgb_agent = img
        omy_env.rgb_ego = _chw_float_to_hwc_uint8(data['wrist_image'][0])
        omy_env.rgb_side = np.zeros(SIDE_VIEW_SHAPE, dtype=np.uint8)
        omy_env.render(language_instruction)

        step += 1
        omy_env.step(action[0])
        if step == len(episode_sampler):
            # Episode finished: rewind the data and the simulator, and stop
            # recording so clips/actions hold exactly one pass.
            iter_dataloader = iter(dataloader)
            step = 0
            omy_env.env.reset()
            omy_env.reset()
            save_flag = False
    omy_env.env.sync_sim_wall_time()