In [None]:
import pickle
import numpy as np
import os
import json
from PIL import Image

# ---- Paths ------------------------------------------------------------------
# Root folder containing one sub-directory of per-episode pickles per task.
base_dir = "/root/onethingai-tmp/my/archive/data"
# Destination for exported image pairs and the train/test/val JSON indices.
# !!! change the directory to the true directory
output_base = "/root/onethingai-tmp/my/data_long_interval"

# ---- Tasks to export ----------------------------------------------------------
tasks = [
    "blocks_stack_easy_sf50_D435_pkl",
    "block_hammer_beat_sf50_D435_pkl",
    "block_handover_sf50_D435_pkl",
]

# A pair is kept only when the mean absolute pixel difference between source
# and target exceeds this value.
threshold = 7

# ---- Running state, mutated by save_image_pair and the export loop ----------
counter = 0  # index of the next output folder (shared across all tasks)
train_entries, test_entries, val_entries = [], [], []

# Task-name substring -> instruction text stored as the entry's "prompt".
# Checked in insertion order; the first key contained in the task name wins.
_TASK_PROMPTS = {
    'block_hammer_beat': 'There is a hammer and a block in the middle of the table. If the block is closer to the left robotic arm, it uses the left arm to pick up the hammer and strike the block; otherwise, it does the opposite.',
    'blocks_stack_easy': 'Red and black cubes are placed randomly on the table. The robotic arm stacks the cubes in order, placing the red cubes first, followed by the black cubes, in the designated target location.',
    'block_handover': 'A long block is placed on the left side of the table. The left arm grasps the upper side of the block and then hands it over to the right arm, which places the block on the blue mat on the right side of the table.',
}

# Splits an entry can be recorded in.
_VALID_MODES = ("train", "test", "val")


def save_image_pair(src_img, tgt_img, out_dir, prompt, mode, *, path_prefix=None):
    """Write a source/target image pair to *out_dir* and record it in a split.

    Args:
        src_img: Source frame, any array ``PIL.Image.fromarray`` accepts
            (presumably HxWx3 uint8 RGB -- TODO confirm against the pickles).
        tgt_img: Target frame, same format as *src_img*.
        out_dir: Directory under ``output_base`` that receives ``source.png``
            and ``target.png``; created if missing.
        prompt: Task name. A known substring of it selects the instruction
            text stored in the entry. The parameter name is kept for backward
            compatibility.
        mode: Split whose entry list receives the record: "train", "test"
            or "val".
        path_prefix: Leading component of the image paths written into the
            entry. Defaults to the last path component of ``output_base``
            (e.g. "data_long_interval"), so the JSON paths follow
            ``output_base`` instead of a second hard-coded copy of its name.

    Raises:
        NotImplementedError: No instruction text is known for *prompt*.
        ValueError: *mode* is not one of the known splits.

    Both checks run before anything touches the filesystem, so a rejected
    pair leaves no orphan files behind.
    """
    pt = next((text for key, text in _TASK_PROMPTS.items() if key in prompt), None)
    if pt is None:
        raise NotImplementedError(f"Prompt not implemented for {prompt}")
    if mode not in _VALID_MODES:
        raise ValueError(f"Unknown split mode {mode!r}; expected one of {_VALID_MODES}")

    os.makedirs(out_dir, exist_ok=True)
    Image.fromarray(src_img).save(os.path.join(out_dir, "source.png"))
    Image.fromarray(tgt_img).save(os.path.join(out_dir, "target.png"))

    # Paths stored in JSON always use "/" (no-op on POSIX, fixes Windows).
    rel = os.path.relpath(out_dir, output_base).replace(os.sep, "/")
    if path_prefix is None:
        path_prefix = os.path.basename(os.path.normpath(output_base))

    entry = {
        "image": f"{path_prefix}/{rel}/source.png",
        "edited_image": f"{path_prefix}/{rel}/target.png",
        "prompt": pt,
    }

    if mode == "train":
        train_entries.append(entry)
    elif mode == "test":
        test_entries.append(entry)
    else:  # "val" -- validated above
        val_entries.append(entry)

# Loop over tasks and episodes.
# Each kept sample pairs frame i with frame i + pair_gap.
pair_gap = 5        # lapping interval length (frames between source and target)
frame_stride = 2    # step between consecutive source frames
num_episodes = 20   # episodes 0..17 -> train, 18 -> test, 19 -> val


def _split_for_episode(episode):
    """Return (output sub-directory, split mode) for *episode*."""
    if episode <= 17:
        return "data", "train"
    if episode == 18:
        return "test", "test"
    return "val", "val"


def _load_head_rgb(path):
    """Load one frame pickle and return its head-camera RGB image."""
    # NOTE: pickle.load can execute arbitrary code -- only run on trusted data.
    with open(path, 'rb') as fh:
        return pickle.load(fh)["observation"]["head_camera"]["rgb"]


for task in tasks:
    task_path = os.path.join(base_dir, task)
    for episode in range(num_episodes):
        ep_dir = os.path.join(task_path, f'episode{episode}')
        if not os.path.exists(ep_dir):
            continue

        # Frames are named "<index>.pkl"; sort numerically so 10 follows 9.
        pkl_files = sorted(
            (name for name in os.listdir(ep_dir) if name.endswith('.pkl')),
            key=lambda name: int(name.split('.')[0]),
        )

        # Stop before i + pair_gap would run past the last frame.
        for i in range(0, len(pkl_files) - pair_gap, frame_stride):
            try:
                src_img = _load_head_rgb(os.path.join(ep_dir, pkl_files[i]))
                tgt_img = _load_head_rgb(os.path.join(ep_dir, pkl_files[i + pair_gap]))

                diff = np.abs(tgt_img.astype(np.float32) - src_img.astype(np.float32))
                mean_diff = np.mean(diff)

                # Keep only pairs with enough visual change between the frames.
                if mean_diff > threshold:
                    split_dir, mode = _split_for_episode(episode)
                    out_dir = os.path.join(output_base, split_dir, f"{counter:04d}")
                    save_image_pair(src_img, tgt_img, out_dir, task, mode)
                    counter += 1

            except NotImplementedError:
                # A task without a prompt is a configuration error that would
                # hit every frame. Fail fast instead of printing a skip each time.
                raise
            except Exception as e:
                print(f"Skipping index {i} in episode {episode} of {task} due to error: {e}")

# Save one JSON index per split.
# output_base is otherwise only created by save_image_pair when a pair is
# saved. If no pair passed the threshold, the directory would not exist yet
# and open(..., 'w') would fail, so create it here first.
os.makedirs(output_base, exist_ok=True)

for split_name, split_entries in (
    ("train", train_entries),
    ("test", test_entries),
    ("val", val_entries),
):
    with open(os.path.join(output_base, f"{split_name}.json"), 'w', encoding='utf-8') as f:
        json.dump(split_entries, f, indent=2)
