In [1]:
import tensorflow_datasets as tfds
import numpy as np
from PIL import Image
from IPython import display
import imageio
import os

In [2]:
# TFDS names of the datasets this notebook can load. The cells below build the
# UCSD kitchen dataset; the CMU play-fusion name is kept for quick switching.
ucsd_kitchen_dataset_name = "ucsd_kitchen_dataset_converted_externally_to_rlds"
cmu_play_fusion_dataset_name = "cmu_play_fusion"

# TFDS data directory. Defaults to the dev-container path used so far; set the
# ARES_DATA_DIR environment variable to run the notebook elsewhere without edits.
data_dir = os.environ.get("ARES_DATA_DIR", "/workspaces/ares/data/")

In [None]:
# Look up the TFDS builder for the UCSD kitchen dataset under `data_dir` and
# download/prepare it so the splits can be loaded in the next cell.
builder = tfds.builder(name=ucsd_kitchen_dataset_name, data_dir=data_dir)
builder.download_and_prepare()


In [24]:
# Load every prepared split (indexed by split name below) and keep the
# training split as `ds`, the dataset the rendering cells iterate over.
TRAIN_SPLIT = "train"
datasets = builder.as_dataset()
ds = datasets[TRAIN_SPLIT]

In [None]:
!pip install imageio[ffmpeg]

In [26]:
import matplotlib.pyplot as plt
from PIL import Image
from IPython import display

def _select_frame_indices(n_steps):
    """Return the step indices rendered for an episode with ``n_steps`` steps.

    Steps before the last 10 are subsampled with a stride of ``n_steps // 10``
    (so at most 10 of them are kept); within the last 10 steps, every
    even-indexed step is kept.
    """
    tail_start = n_steps - 10  # first index that lies inside the last 10 steps
    selected = []
    for i in range(n_steps):
        if i < tail_start:
            # Only reachable when n_steps > 10, so the stride is always >= 1.
            if i % (n_steps // 10) == 0:
                selected.append(i)
        elif i % 2 == 0:
            selected.append(i)
    return selected


def _to_numpy(value):
    """Return ``value.numpy()`` for eager TensorFlow tensors, else ``value`` as-is."""
    return value.numpy() if hasattr(value, "numpy") else value


def _render_frame(image, title):
    """Draw ``image`` with ``title`` above it and return the figure as an RGB uint8 array."""
    fig, ax = plt.subplots(figsize=(10, 8), dpi=100)
    try:
        fig.subplots_adjust(top=0.85)  # leave room for the multi-line title
        ax.imshow(image)
        ax.axis("off")
        ax.set_title(title, wrap=True, pad=20, fontsize=12)
        fig.canvas.draw()
        # buffer_rgba() replaces tostring_rgb() (deprecated since Matplotlib 3.8);
        # its array shape is the real pixel size, which also holds on HiDPI canvases.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        # Copy so the pixels stay valid after the figure (and its buffer) is closed.
        return rgba[..., :3].copy()
    finally:
        plt.close(fig)  # avoid accumulating open figures in a long Jupyter session


def create_episode_gif(episode, output_path='episode.gif', frame_duration=500):
    """Render a subsampled, annotated episode to an MP4 video.

    The name is kept for backward compatibility; the output is an MP4 written
    with imageio (the GIF path was dropped). Each rendered frame shows the step
    image with the lower-cased language instruction, the step index, and the
    ``is_first`` / ``is_last`` / ``is_terminal`` flags as its title. See
    ``_select_frame_indices`` for which steps are rendered.

    Args:
        episode: Mapping whose ``"steps"`` entry is an iterable of step dicts.
            Each step must provide ``["observation"]["image"]``,
            ``"language_instruction"`` (bytes or str), ``"is_first"``,
            ``"is_last"`` and ``"is_terminal"``. Values may be eager TF tensors
            (``.numpy()`` is called when available) or plain values.
        output_path: Destination path. A trailing ``.gif`` or ``.mp4`` extension
            is stripped and ``.mp4`` is appended.
        frame_duration: Display time per frame in milliseconds; the video is
            written at ``1000 / frame_duration`` fps.

    Returns:
        The path of the MP4 file that was written.

    Raises:
        ValueError: If ``frame_duration`` is not positive or the episode has no steps.
    """
    if frame_duration <= 0:
        raise ValueError(f"frame_duration must be positive, got {frame_duration!r}")

    # Materialize once: len() on a tf.data dataset of steps fails when its
    # cardinality is unknown, and the steps are indexed below anyway.
    steps = list(episode["steps"])
    n_steps = len(steps)
    if n_steps == 0:
        raise ValueError("episode has no steps; nothing to render")

    root, ext = os.path.splitext(output_path)
    mp4_path = (root if ext.lower() in (".gif", ".mp4") else output_path) + ".mp4"

    indices = _select_frame_indices(n_steps)
    print(f"Rendering {len(indices)} of {n_steps} steps to {mp4_path}...")

    frames = []
    for i in indices:
        print(f"Processing frame {i+1}/{n_steps}...")
        step = steps[i]

        image = _to_numpy(step["observation"]["image"])
        instruction = _to_numpy(step["language_instruction"])
        if isinstance(instruction, bytes):
            instruction = instruction.decode()
        # Format the flag values themselves, not their tensor reprs.
        flags = ", ".join(
            f"{key}: {_to_numpy(step[key])}"
            for key in ("is_first", "is_last", "is_terminal")
        )
        title = f"{instruction.lower()}\n{i}/{n_steps}\n{flags}"
        frames.append(_render_frame(image, title))

    imageio.mimsave(
        mp4_path,
        frames,
        fps=1000 / frame_duration,  # frame duration in ms -> frames per second
        macro_block_size=None,  # keep the exact frame size instead of ffmpeg's 16-px blocks
    )
    print("MP4 created successfully! Saved to ", mp4_path)
    return mp4_path

In [27]:
# Shared, module-level iterator over the training split `ds`. Its position
# persists across cell executions, so every next(iter_ds) in later cells
# continues where the previous run stopped. Re-run this cell to get a fresh
# iterator over `ds`.
iter_ds = iter(ds)

In [None]:

# Render episodes FIRST_EPISODE .. FIRST_EPISODE + NUM_EPISODES - 1 of `ds`.
# Episodes are drawn from `ds` itself rather than the shared `iter_ds`. The
# previous `if i < 5: continue` never consumed any episodes, so on a fresh
# kernel it wrote episodes 0-4 as episode_5..episode_9. Here each file name
# matches the episode's position in the split, however often the cell is re-run.
FIRST_EPISODE = 5
NUM_EPISODES = 5

output_dir = "/tmp/episodes"
os.makedirs(output_dir, exist_ok=True)

for i, episode in enumerate(ds.skip(FIRST_EPISODE).take(NUM_EPISODES), start=FIRST_EPISODE):
    output_path = create_episode_gif(episode, output_path=os.path.join(output_dir, f"episode_{i}"))

In [29]:
# display.Image(filename=output_path)