# Visualize Minari dataset

Offline datasets contain episodes.

Each episode has a length of `len(episode.rewards)`. `len(episode.actions)` is the sames as episode's.

But each of the keys (image, mission and direction) in `episode.observations` is of length `len(episode.actions) + 1`.

This is because the first set of observations is that before doing the first action and getting the first reward.

### Understand the flow in RL
At time $t$, the agent:

1. Observes the current state $s_t$

2. Chooses and executes an action $a_t$

3. Receives a reward $r_{t}$ upon reaching $s_{t+1}$

4. Observes $s_{t+1}$

The reward at time *t* is sometimes denoted $r_{t+1}$, indicating that it is a consequence of $a_t$ and $s_t$ and is revealed after the transition.

### GIF Generation with Minari datasets

In [6]:

import os
import imageio
import minari
from PIL import Image, ImageDraw, ImageFont

ACTION_NAMES = {
    0: "left",
    1: "right",
    2: "forward",
    3: "pickup",
    4: "drop",
    5: "toggle",
    6: "done"
}

DIRECTION_NAMES = {
    0: "right",
    1: "down",
    2: "left",
    3: "up"
}

def _space_at(values, index):
    if isinstance(values, dict):
        return {k: _space_at(v, index) for k, v in values.items()}
    elif isinstance(values, tuple):
        return tuple(_space_at(v, index) for v in values)
    else:
        return values[index]

def annotate_frame(frame, action=None, direction=None, mission=None, reward=None, font=None):
    img = Image.fromarray(frame)
    draw = ImageDraw.Draw(img)

    if font is None:
        font = ImageFont.load_default()

    if mission is not None:
        draw.text((8, 8), f"Mission: {mission}", font=font, fill="white")

    if action is not None:
        action_name = ACTION_NAMES.get(action, str(action))
        draw.text((8, 48), f"Action: {action_name}", font=font, fill="white")

    if direction is not None:
        direction_name = DIRECTION_NAMES.get(direction, str(direction))
        draw.text((8, 88), f"Direction: {direction_name}", font=font, fill="white")

    if reward is not None:
        draw.text((8, 128), f"Reward: {reward:.2f}", font=font, fill="white")

    return img

def generate_gif(dataset_id, path, num_frames=512, fps=32):
    """
    Generate a GIF from a Minari dataset.
    This function loads a Minari dataset, iterates through its episodes, and creates a GIF
    by rendering frames with annotations for actions, directions, missions, and rewards.
    The GIF is saved to the specified path.

    Args:
        dataset_id (str): The ID of the Minari dataset.
        path (str): The directory where the GIF will be saved.
        num_frames (int): The maximum number of frames to include in the GIF.
        fps (int): Frames per second for the GIF.
    Returns:
        str: The path to the generated GIF file.
    Raises:
        ValueError: If the dataset does not have enough steps or if the seed is unknown.

    """
    dataset = minari.load_dataset(dataset_id)
    env = dataset.recover_environment(render_mode="rgb_array")
    images = []

    try:
        font = ImageFont.truetype("arial.ttf", 38)
    except IOError:
        font = ImageFont.load_default(38)

    metadatas = dataset.storage.get_episode_metadata(dataset.episode_indices)
    for episode, episode_metadata in zip(dataset.iterate_episodes(), metadatas):
        seed = episode_metadata.get("seed")
        if episode.id == 0 and seed is None:
            raise ValueError("Cannot reproduce episodes with unknown seed.")

        mission_raw = episode.observations["mission"]
        if isinstance(mission_raw, bytes):
            mission = mission_raw.decode("utf-8")
        else:
            mission = str(mission_raw)

        env.reset(seed=seed, options=episode_metadata.get("options"))
        frame = env.render()
        direction = episode.observations["direction"][0]
        images.append(annotate_frame(frame, action=None, direction=direction, mission=mission, reward=None, font=font))

        for step_id in range(len(episode)):
            act = _space_at(episode.actions, step_id)
            obs, reward, terminated, truncated, info = env.step(act)
            frame = env.render()
            direction = episode.observations["direction"][step_id + 1]
            img = annotate_frame(frame, action=act, direction=direction, mission=mission, reward=reward, font=font)
            images.append(img)

            if len(images) > num_frames:
                env.close()
                safe_path = os.path.join(path, *dataset_id.split("/")[:-1])
                os.makedirs(safe_path, exist_ok=True)
                gif_file = os.path.join(safe_path, f"{dataset_id.split('/')[-1]}.gif")
                imageio.mimsave(gif_file, [img.convert("RGB") for img in images], fps=fps)
                return gif_file

    raise ValueError("There are not enough steps in the dataset.")


## BabyAI-Pickup Optimal fullobs
Agent task is to pick up target objects on scene.

The agent is a hard-coded planner, which solves all the tasks optimally.


In [7]:
dataset_id = "minigrid/BabyAI-Pickup/optimal-fullobs-v0"
path = '.'
num_frames = 512
fps = 1

generate_gif(dataset_id, path, num_frames, fps)

gif_path = os.path.join(path, f"{dataset_id.split('/')}.gif")

Sampling rejected: unreachable object at (15, 5)
Sampling rejected: unreachable object at (5, 7)
Sampling rejected: unreachable object at (7, 2)
Sampling rejected: unreachable object at (16, 11)
Sampling rejected: unreachable object at (9, 3)


In [8]:
from IPython.display import HTML
import uuid

gif_path = os.path.join(os.getcwd(), *dataset_id.split('/')[:-1], dataset_id.split('/')[-1] + ".gif")

# === Display in notebook ===
cache_buster = uuid.uuid4().hex
HTML(f'<img src="{gif_path}?v={cache_buster}" width="400">')

## Minigrid's FourRooms Random
Agent's goal is reach the goal in a four room minigrid scenario.

This agent was generated sampling random action from action space.  


In [9]:
dataset_id = "D4RL/minigrid/fourrooms-random-v0"
path = '.'
num_frames = 512
fps = 1

generate_gif(dataset_id, path, num_frames, fps)

gif_path = os.path.join(path, f"{dataset_id.split('/')}.gif")

In [10]:
from IPython.display import HTML
import uuid

gif_path = os.path.join(os.getcwd(), *dataset_id.split('/')[:-1], dataset_id.split('/')[-1] + ".gif")

# === Display in notebook ===
cache_buster = uuid.uuid4().hex
HTML(f'<img src="{gif_path}?v={cache_buster}" width="400">')

## Fourroms optimal
Same scenario but agent behaves optimally

In [11]:
dataset_id = "D4RL/minigrid/fourrooms-v0"
path = '.'  
num_frames = 512
fps = 1
generate_gif(dataset_id, path, num_frames, fps)

gif_path = os.path.join(path, f"{dataset_id.split('/')[-1]}.gif")

In [12]:
from IPython.display import HTML
import uuid
gif_path = os.path.join(os.getcwd(), *dataset_id.split('/')[:-1], dataset_id.split('/')[-1] + ".gif")
# === Display in notebook ===
cache_buster = uuid.uuid4().hex
HTML(f'<img src="{gif_path}?v={cache_buster}" width="400">')


## Pen Human

In [None]:
dataset_id = "D4RL/pen/human-v1"
path = '.'  
num_frames = 512
fps = 1
generate_gif(dataset_id, path, num_frames, fps)

gif_path = os.path.join(path, f"{dataset_id.split('/')[-1]}.gif")

IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices

In [None]:
from IPython.display import HTML
import uuid
gif_path = os.path.join(os.getcwd(), *dataset_id.split('/')[:-1], dataset_id.split('/')[-1] + ".gif")
# === Display in notebook ===
cache_buster = uuid.uuid4().hex
HTML(f'<img src="{gif_path}?v={cache_buster}" width="400">')