In [1]:
# First cell - Setup paths
import sys
import os

# Resolve the repository root as two directories above the directory the kernel
# was started in. `__file__` is not defined inside a notebook, and
# os.path.dirname('__file__') (on the *string* '__file__') always returns '',
# so the current working directory is used explicitly instead.
# The root is computed only once per kernel: because of the os.chdir below,
# re-running this cell would otherwise climb two more directory levels.
if "lerobot_root" not in globals():
    notebook_dir = os.getcwd()
    lerobot_root = os.path.abspath(os.path.join(notebook_dir, "../.."))
# Avoid stacking duplicate sys.path entries on re-runs.
if lerobot_root not in sys.path:
    sys.path.append(lerobot_root)
os.chdir(lerobot_root)

In [None]:
import cv2
import numpy as np
from PIL import Image
import torch
from tqdm.notebook import tqdm
from typing import List, Optional, Dict, Union, Tuple

from pathlib import Path
import modal
import gc
import json
import time
from google import genai
from google.genai import types
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.caferacer.scripts.image_utils import reorder_tensor_dimensions, tensor_to_pil, display_images
from lerobot.caferacer.scripts.aug_utils import flip_frame, apply_color, get_mask
from lerobot.caferacer.scripts.gemini_utils import parse_json
from lerobot.caferacer.scripts.inpaint_utils import *

In [3]:
#episodes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251]

In [4]:
# Source dataset repo id. It is also the write target: create_dataset is later
# called with this same repo_id and create_new=False, so augmented episodes are
# appended to it.
repo_id = "shreyasgite/so100_base_left"
# Alternative source dataset (currently unused):
#repo_id0 = "shreyasgite/so100_base_env"

In [None]:
# Handle to the deployed Modal function "GroundedSam.run" of app "grounded-sam"
# in the "prod" environment; passed as `gsam` to create_dataset.
gsam = modal.Function.lookup(
    "grounded-sam",
    "GroundedSam.run",
    environment_name='prod',
)

In [11]:
def _open_target_dataset(repo_id, dataset0, create_new):
    """Return the dataset that episodes are written to.

    With ``create_new`` a new dataset is created that copies ``dataset0``'s fps,
    robot type, features and video setting. Otherwise the existing ``repo_id``
    dataset is opened and its image writer is started so frames can be appended.
    """
    if create_new:
        return LeRobotDataset.create(
            repo_id,
            fps=dataset0.fps,
            robot_type=dataset0.meta.robot_type,
            features=dataset0.features,
            use_videos=len(dataset0.meta.video_keys) > 0,
            image_writer_threads=8)
    dataset = LeRobotDataset(repo_id)
    dataset.start_image_writer(num_threads=8)
    return dataset


def _load_frame(dataset0, frame_idx, image_keys):
    """Read frame ``frame_idx`` of ``dataset0`` and split it into ``(obs, action)`` dicts."""
    frame_data = dataset0[frame_idx]
    obs = {key: reorder_tensor_dimensions(frame_data[key]) for key in image_keys}
    obs['observation.state'] = frame_data['observation.state']
    action = {"action": frame_data['action']}
    return obs, action


def create_dataset(repo_id, dataset0, gsam, GPU_POOR=True, DATA_AUG=None, create_new=False, batch_size=10,
                   task="Grasp the lego brick and drop it in the Yellow Container.",
                   mask_refresh_offsets=(0, 8 * 30, 10 * 30)):
    """Copy every episode of ``dataset0`` into ``repo_id``, optionally augmenting each frame.

    Args:
        repo_id: Target dataset repo id. Created when ``create_new`` is True,
            otherwise opened with ``LeRobotDataset(repo_id)`` and appended to.
        dataset0: Source ``LeRobotDataset`` that is read frame by frame.
        gsam: Remote segmentation handle forwarded to ``get_mask`` and
            ``get_inpaint_object``.
        GPU_POOR: If True, the color-change mask is recomputed only at the
            episode-relative frame offsets in ``mask_refresh_offsets`` and reused
            in between. If False, it is recomputed on every frame.
        DATA_AUG: Optional augmentation config. Recognized keys, applied in this
            order when their value is truthy:
              'inpaint_distractions': {'objects': str}
              'flip_frame': flag
              'change_color': {'object': str, 'target_color': str}
        create_new: Create a fresh dataset with ``dataset0``'s schema instead of
            appending to an existing one.
        batch_size: Number of episodes processed between ``gc.collect()`` calls.
        task: Task string attached to every written frame.
        mask_refresh_offsets: Frame offsets (from each episode's first frame) at
            which the color mask is recomputed when ``GPU_POOR`` is True. The
            default keeps the previous 0 / 8*30 / 10*30 schedule (presumably
            seconds at 30 fps; TODO confirm against ``dataset0.fps``).

    Returns:
        The target dataset, with one saved episode per source episode.
    """
    DATA_AUG = DATA_AUG or {}
    refresh_offsets = set(mask_refresh_offsets)
    dataset = _open_target_dataset(repo_id, dataset0, create_new)

    num_episodes = dataset0.num_episodes
    # Camera keys are taken from the first frame and assumed to be the same for every frame.
    image_keys = [key for key in dataset0[0].keys() if key.startswith("observation.images.")]

    for batch_start in range(0, num_episodes, batch_size):
        batch_end = min(batch_start + batch_size, num_episodes)
        print(f"Processing episodes {batch_start} to {batch_end}")

        for ep_idx in range(batch_start, batch_end):
            print(f"Processing episode {ep_idx}")

            # Get frame indices for this episode
            from_idx = dataset0.episode_data_index["from"][ep_idx].item()
            to_idx = dataset0.episode_data_index["to"][ep_idx].item()
            print(f"Processing frames {from_idx} to {to_idx}")

            # Color mask is per-episode state; never reuse the previous episode's mask.
            mask = None

            for frame_idx in range(from_idx, to_idx):
                obs, action = _load_frame(dataset0, frame_idx, image_keys)
                offset = frame_idx - from_idx

                inpaint_cfg = DATA_AUG.get('inpaint_distractions')
                if inpaint_cfg:
                    # Inpainting sources/masks come from the episode's first frame
                    # and are reused for the remaining frames.
                    if offset == 0:
                        source_top_im, source_front_im, top_im_msk, front_im_msk = get_inpaint_object(
                            gsam, obs, inpaint_cfg['objects'])
                    obs = create_inpainted_frame(obs, source_top_im, source_front_im, top_im_msk, front_im_msk)

                if DATA_AUG.get('flip_frame'):
                    obs, action = flip_frame(obs, action, image_keys)

                color_cfg = DATA_AUG.get('change_color')
                if color_cfg:
                    # Segmentation is expensive, so in GPU_POOR mode it runs only at
                    # the configured offsets (or if no mask exists yet this episode).
                    if not GPU_POOR or mask is None or offset in refresh_offsets:
                        mask = get_mask(gsam, obs, image_keys, object=color_cfg['object'])
                    obs = apply_color(obs, mask, image_keys, target_color=color_cfg['target_color'])

                frame = {**obs, **action, "task": task}
                dataset.add_frame(frame)

            print(f"Saving episode {ep_idx}")
            dataset.save_episode()

        # Force a collection pass between batches of episodes.
        gc.collect()

    return dataset

In [9]:
# Augmentation config consumed by create_dataset:
#   - flip every frame (observations and action),
#   - recolor the segmented 'blue container' to yellow,
#   - inpaint a 'rubik cube' distraction into each episode.
DATA_AUG = {
    'flip_frame': True,
    'change_color': {
        'object': 'blue container',
        'target_color': 'yellow',
    },
    'inpaint_distractions': {
        'objects': 'rubik cube',
    },
}

In [None]:
# NOTE(review): create_new is left at its default (False), so the augmented
# episodes are appended to the same repo_id that dataset_base was read from.
# Presumably re-running this cell appends another full copy; confirm before re-running.
dataset_base = LeRobotDataset(repo_id)
dataset = create_dataset(
    repo_id,
    dataset0=dataset_base,
    gsam=gsam,
    DATA_AUG=DATA_AUG,
)
#dataset.push_to_hub()