In [None]:
import datasets
import os
import logging
from openvla.prismatic.vla.datasets import (
    RLDSDataset,
    EpisodicRLDSDataset,
    RLDSBatchTransform,
)
from dataclasses import dataclass
from pathlib import Path
from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig
from transformers import AutoConfig, AutoImageProcessor
from openvla.prismatic.models.backbones.llm.prompting import (
    PurePromptBuilder,
    VicunaV15ChatPromptBuilder,
)
from openvla.prismatic.vla.action_tokenizer import ActionTokenizer


# Hugging Face Hub dataset repo id and the RLDS dataset names used in this notebook.
hf_base = "jxu124/OpenX-Embodiment"
ucsd_kitchen_dataset_name = "ucsd_kitchen_dataset_converted_externally_to_rlds"
cmu_play_fusion_dataset_name = "cmu_play_fusion"
# To load a dataset from the Hub repo instead of the local RLDS directory:
#     datasets.load_dataset(hf_base, ucsd_kitchen_dataset_name, split="train",
#                           cache_dir="/workspaces/ares/data")


@dataclass
class EpisodicConfig:
    """Settings for building the episodic Open-X RLDS dataset and OpenVLA processor."""

    # Local directory holding the Open-X RLDS datasets
    # (previously used alternative: "datasets/open-x-embodiment").
    data_root_dir: Path = Path("/workspaces/ares/data")
    # Dataset to load from data_root_dir; use ucsd_kitchen_dataset_name to switch datasets.
    dataset_name: str = cmu_play_fusion_dataset_name
    # Resize target for observation images, a pair of ints
    # (presumably (height, width) — TODO confirm against EpisodicRLDSDataset).
    image_sizes: tuple[int, int] = (224, 224)
    vla_path: str = "openvla/openvla-7b"  # OpenVLA model id on the Hugging Face Hub
    # Kept small for quick interactive inspection; the original note suggests
    # 256_000 was the full-scale value.
    shuffle_buffer_size: int = 256
    image_aug: bool = False  # toggles dataset-side image augmentation


cfg = EpisodicConfig()

# OpenVLA processor and action tokenizer for cfg.vla_path. They are unused while
# batch_transform is the identity, but RLDSBatchTransform (see below) needs them.
processor = AutoProcessor.from_pretrained(cfg.vla_path, trust_remote_code=True)
action_tokenizer = ActionTokenizer(processor.tokenizer)


def batch_transform(rlds_batch):
    """Return each RLDS element unchanged so raw episodes can be inspected.

    To feed OpenVLA-formatted examples instead, rebind ``batch_transform``
    before the dataset is built:

        batch_transform = RLDSBatchTransform(
            action_tokenizer,
            processor.tokenizer,
            image_transform=processor.image_processor.apply_transform,
            prompt_builder_fn=(
                PurePromptBuilder if "v01" not in cfg.vla_path else VicunaV15ChatPromptBuilder
            ),
        )
    """
    return rlds_batch


vla_dataset = EpisodicRLDSDataset(
    cfg.data_root_dir,
    cfg.dataset_name,
    batch_transform,
    resize_resolution=tuple(cfg.image_sizes),
    shuffle_buffer_size=cfg.shuffle_buffer_size,
    image_aug=cfg.image_aug,
)

In [21]:
# Take the first element yielded by a fresh iterator over vla_dataset; later cells
# treat such an element as one episode (a sized sequence of per-step dicts).
rlds_batch_episode = next(iter(vla_dataset))

In [22]:
# Peek at the first raw element of the underlying dataset as numpy arrays
# (as_numpy_iterator() already returns an iterator, so next() applies directly).
out = next(vla_dataset.dataset.as_numpy_iterator())

In [None]:
# Inspect the primary image array's shape, then display the observation and task timesteps.
print(out["observation"]["image_primary"].shape)
(out["observation"]["timestep"].T, out["task"]["timestep"])

In [24]:
# task, dataset_name = rlds_batch_i["task"], rlds_batch_i["dataset_name"]
# # print(dataset_name)
# # print("task:   ")
# # # lang instruction, image primary, timestemp, mask
# for k,v in task.items():
#     print("\t", k, type(v))

# # print(f"action: {rlds_batch_i['action'][0]}")
# from PIL import Image
# img_vals = rlds_batch_i["observation"]["image_primary"][0]
# print(type(img_vals), img_vals.shape, img_vals.min(), img_vals.max())
# img = Image.fromarray(img_vals)
# lang_instruction = rlds_batch_i["task"]["language_instruction"].decode().lower()
# timestep = rlds_batch_i["task"]["timestep"]



# import matplotlib.pyplot as plt
# plt.imshow(img)
# plt.title(f"{lang_instruction}\n{timestep}/{len(rlds_batch_episode)}")
# plt.show()


In [25]:
import matplotlib.pyplot as plt
from PIL import Image
from IPython import display

def _render_step_frame(step, num_steps):
    """Draw one episode step with its caption and return it as an RGB PIL image."""
    import numpy as np  # local import: only needed to wrap the canvas pixel buffer

    img = Image.fromarray(step["observation"]["image_primary"][0])
    lang_instruction = step["task"]["language_instruction"].decode().lower()
    # NOTE(review): the image is taken at index 0 but the timestep is shown as-is;
    # if it is an array the title will print the whole array — confirm intent.
    timestep = step["observation"]["timestep"]

    fig, ax = plt.subplots(figsize=(10, 8), dpi=100)
    try:
        fig.subplots_adjust(top=0.85)  # leave headroom above the image for the title
        ax.imshow(img)
        ax.axis('off')
        ax.set_title(f"{lang_instruction}\n{timestep}/{num_steps}",
                     wrap=True,
                     pad=20,
                     fontsize=12)

        fig.canvas.draw()
        # buffer_rgba() replaces tostring_rgb() (deprecated in Matplotlib 3.8, removed
        # in 3.10); its array shape is the real pixel size, so HiDPI canvases work too.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        return Image.fromarray(rgba).convert('RGB')
    finally:
        plt.close(fig)  # always release the figure, even if drawing fails


def create_episode_gif(rlds_batch_episode, output_path='episode.gif', frame_duration=500):
    """Render an RLDS episode to an animated GIF, one captioned frame per step.

    Each frame shows element 0 of the step's ``image_primary`` with the lower-cased
    language instruction and ``<timestep>/<episode length>`` as the title.

    Args:
        rlds_batch_episode: Sized iterable of step dicts. Each step provides
            ``step["observation"]["image_primary"]`` (element 0 must be accepted by
            ``PIL.Image.fromarray``), ``step["observation"]["timestep"]`` and
            ``step["task"]["language_instruction"]`` (bytes).
        output_path: Path the GIF is written to.
        frame_duration: Display time of each frame in milliseconds.

    Returns:
        ``IPython.display.Image`` for ``output_path``. It renders inline when it is
        the last expression of a cell.

    Raises:
        ValueError: If the episode contains no steps.
    """
    num_steps = len(rlds_batch_episode)
    if num_steps == 0:
        raise ValueError("Cannot create a GIF from an empty episode.")

    # Close figures left over from earlier cells.
    plt.close('all')

    print("Creating GIF...")

    frames = []
    for i, step in enumerate(rlds_batch_episode):
        # Overwrite the previous progress line instead of appending a new one.
        display.clear_output(wait=True)
        print(f"Processing frame {i+1}/{num_steps}...")
        frames.append(_render_step_frame(step, num_steps))

    # loop=0 makes the GIF repeat forever.
    frames[0].save(
        output_path,
        save_all=True,
        append_images=frames[1:],
        duration=frame_duration,
        loop=0,
    )

    display.clear_output(wait=True)
    print("GIF created successfully!")

    return display.Image(filename=output_path)

In [None]:
# Show the repr of the underlying dataset object (it exposes as_numpy_iterator(),
# so presumably a tf.data-style pipeline — confirm).
vla_dataset.dataset

In [32]:
# Hold one iterator so the next cell can be re-run to advance through episodes,
# rather than calling iter() anew each time.
iter_ds = iter(vla_dataset)

In [51]:
# Advance the shared iterator by one episode; raises StopIteration if it is exhausted.
rlds_batch_episode = next(iter_ds)

In [None]:
# Write the current episode to episode.gif and show it inline (the cell's last expression).
create_episode_gif(rlds_batch_episode, output_path='episode.gif', frame_duration=500)