In [None]:
import sys
import random
import numpy as np
import os
from PIL import Image
import json
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from src.env.env import RILAB_OMY_ENV
from src.mujoco_helper.transforms import rpy2r, r2rpy
import torch
from src.dataset.utils import *
from src.mujoco_helper import MuJoCoParserClass

# Source data: the teleoperation demonstrations, read from the local copy under ./dataset/demo_data.
dataset = LeRobotDataset(
    'Jeongeun/deep_learning_2025',
    root='./dataset/demo_data',
)
# Dataset-level metadata; the replay loop reads `total_episodes` and `tasks` from it.
metadata = dataset.meta

In [None]:
'''
Define the action and observation space for the environment
'''
action_type = 'delta_eef_pose'  # Options: 'joint', 'delta_joint', 'delta_eef_pose', 'eef_pose'
proprio_type = 'eef_pose'  # Options: 'joint_pos', 'eef_pose'
observation_type = 'object_pose'  # Options: 'image', 'object_pose'
# Number of extra replays of every source episode with randomized object colors.
# Only used when observation_type == 'image'.
image_aug_num = 2
transformed_dataset_path = './dataset/transformed_data_notebook'

# Fail fast on typos: later cells branch on these strings, and an unknown value would
# silently take the wrong path (any observation_type other than 'image' records object poses).
assert action_type in {'joint', 'delta_joint', 'delta_eef_pose', 'eef_pose'}, f"unknown action_type: {action_type!r}"
assert proprio_type in {'joint_pos', 'eef_pose'}, f"unknown proprio_type: {proprio_type!r}"
assert observation_type in {'image', 'object_pose'}, f"unknown observation_type: {observation_type!r}"
assert isinstance(image_aug_num, int) and image_aug_num >= 0, "image_aug_num must be a non-negative int"

## Transform to the target action space and augment images

In [None]:
'''
Create transformed dataset.
Any output previously written to `transformed_dataset_path` is deleted first,
so every run of this cell starts from an empty directory.
'''
import shutil

store_images = (observation_type == 'image')
if os.path.exists(transformed_dataset_path):
    shutil.rmtree(transformed_dataset_path)
transformed_dataset = create_dataset(transformed_dataset_path, add_images=store_images)

In [None]:

'''
Load environment configuration and initialize environments
'''
config_file_path = './configs/train2.json'
with open(config_file_path) as config_file:
    env_conf = json.load(config_file)

# Replay environment, configured with the action / proprio spaces chosen in the config cell.
omy_env = RILAB_OMY_ENV(
    cfg=env_conf,
    seed=0,
    action_type=action_type,
    obs_type=proprio_type,
    vis_mode='teleop',
)
# Standalone MuJoCo model of the table scene, handed to transform_action during replay
# (the name suggests it is used for inverse kinematics — not verified here).
ik_env = MuJoCoParserClass(name='IK_env', rel_xml_path='./asset/scene_table.xml')

In [None]:

def _resize_image(image, size=(256, 256)):
    """Resize a camera frame (anything PIL.Image.fromarray accepts) to `size` and return a numpy array."""
    return np.array(Image.fromarray(image).resize(size))


def _pack_named_poses(entries, slots, width):
    """Pack the non-'pad' entries of a {'names', 'poses'} dict into a zero-padded float32 vector.

    Entries are ordered by name so the layout is identical across episodes, and real
    entries are packed contiguously (skipped 'pad' entries do not leave holes).
    Returns an array of shape (slots * width,). Raises ValueError if there are more
    than `slots` non-pad entries.
    """
    names = entries['names']
    kept = [idx for idx in np.argsort(names) if 'pad' not in names[idx]]
    if len(kept) > slots:
        raise ValueError(f"expected at most {slots} non-pad entries, got {len(kept)}: {[names[i] for i in kept]}")
    packed = np.zeros((slots, width), dtype=np.float32)
    for row, idx in enumerate(kept):
        packed[row] = entries['poses'][idx]
    return packed.reshape(-1)


def _flatten_environment_state(obj_states, recp_q_poses, max_objects=4, obj_pose_dim=6, max_receptacles=3):
    """Build the fixed-size `observation.environment_state` vector.

    Layout: `max_objects` object poses of `obj_pose_dim` values each (24 by default),
    followed by `max_receptacles` receptacle values (3 by default), all float32.
    Both inputs are the dicts returned by `omy_env.get_object_pose`.
    """
    return np.concatenate([
        _pack_named_poses(obj_states, max_objects, obj_pose_dim),
        _pack_named_poses(recp_q_poses, max_receptacles, 1),
    ])


def iterate_episodes(dataset,transformed_dataset, omy_env, ik_env,q_init, start_idx_ori, end_idx_ori, language_instruction, img_aug = False):
    """Replay one source episode in `omy_env` and buffer the re-recorded frames.

    The arm is first stepped 10 times with `q_init` as a joint-space action. Then the
    source frames `dataset.hf_dataset[start_idx_ori:end_idx_ori]` are converted by
    `transform_action` into the notebook-level `action_type` and executed one per
    `omy_env.env.loop_every(HZ=20)` tick. Every executed step is appended to
    `transformed_dataset` with `add_frame`. Saving or discarding the buffered episode
    is left to the caller.

    Reads the notebook globals `action_type` and `observation_type`.

    Args:
        dataset: source dataset; frames are read from `dataset.hf_dataset`.
        transformed_dataset: output dataset that receives the new frames.
        omy_env: simulation environment used for the replay.
        ik_env: secondary MuJoCo model passed through to `transform_action`.
        q_init: joint-space action used for the warm-up steps.
        start_idx_ori: global index of the episode's first frame (inclusive).
        end_idx_ori: global index one past the episode's last frame (exclusive).
        language_instruction: task string stored with every frame.
        img_aug: if True, randomize object colors at the start and on every
            `loop_every(HZ=1)` tick.

    Returns:
        The `omy_env.check_success()` result from the tick after the last source frame
        was executed, or False if the viewer was closed before the replay finished.
    """
    omy_env.reset()
    ik_env.reset()

    # Warm-up: move the arm to the episode's initial configuration in joint space,
    # then switch back to the action space of the transformed dataset.
    omy_env.action_type = 'joint'
    for _ in range(10):
        omy_env.step(q_init)
        omy_env.step_env()
    omy_env.action_type = action_type

    current_step = start_idx_ori
    success = False            # defined even if the viewer closes before the first control tick
    replay_complete = False
    if img_aug:
        omy_env.agument_object_random_color()

    while omy_env.env.is_viewer_alive():
        omy_env.step_env()
        if img_aug and omy_env.env.loop_every(HZ = 1):
            omy_env.agument_object_random_color()
        if not omy_env.env.loop_every(HZ = 20):
            continue

        success = omy_env.check_success()
        if current_step >= end_idx_ori:
            replay_complete = True
            break

        data = dataset.hf_dataset[current_step]
        if current_step == start_idx_ori:
            # Initialize the scene objects from the first source frame.
            object_info = parse_object_info(data)
            omy_env.set_object_pose(*object_info)

        action = transform_action(data, omy_env, ik_env, action_type)
        observation = omy_env.step(action, gripper_mode='continuous')
        agent_image, wrist_image = omy_env.grab_image(return_side=False)
        frame = {
            "observation.state": observation,
            "action": action.astype(np.float32),
        }
        if observation_type == 'image':
            frame["observation.image"] = _resize_image(agent_image)
            frame["observation.wrist_image"] = _resize_image(wrist_image)
        else:
            obj_states, recp_q_poses = omy_env.get_object_pose(pad=10)
            frame["observation.environment_state"] = _flatten_environment_state(obj_states, recp_q_poses)

        transformed_dataset.add_frame(
            frame, task=language_instruction
        )
        omy_env.render()
        current_step += 1

    # A replay cut short by closing the viewer must not count as a success;
    # otherwise the caller would save a truncated episode.
    return success and replay_complete

In [None]:

'''
Iterate through all episodes in the original dataset.

Every source episode is replayed once without augmentation and, when observation_type
is 'image', `image_aug_num` more times with randomized object colors. A replay is kept
(`save_episode`) only if `iterate_episodes` reports success; otherwise its buffered
frames are discarded (`clear_episode_buffer`).
'''
# img_aug flag for each replay of a single source episode.
replay_img_aug_flags = [False] + ([True] * image_aug_num if observation_type == 'image' else [])

for episode_index in range(metadata.total_episodes):
    start_idx_ori = dataset.episode_data_index['from'][episode_index].item()
    end_idx_ori = dataset.episode_data_index['to'][episode_index].item()
    first_frame = dataset.hf_dataset[start_idx_ori]
    q_init = first_frame['action'].numpy()
    task_index = first_frame['task_index'].item()
    language_instruction = metadata.tasks[task_index]
    print(f"Episode {episode_index}, Instruction: {language_instruction}")

    for img_aug in replay_img_aug_flags:
        success = iterate_episodes(
            dataset, transformed_dataset, omy_env, ik_env, q_init,
            start_idx_ori, end_idx_ori, language_instruction, img_aug=img_aug,
        )
        print(f"  img_aug={img_aug}: success={success}")
        if success:
            transformed_dataset.save_episode()
        else:
            transformed_dataset.clear_episode_buffer()