In [79]:
import gzip
import pickle
import os
import numpy as np
from habitat_baselines.utils.common import batch_obs, generate_video
from habitat.core.logging import logger
from habitat.core.utils import try_cv2_import
from habitat.utils.common import flatten_dict
from habitat.utils.visualizations import maps
import cv2

# Directory containing the per-episode `data.<episode_id>.pkl.gz` files that
# construct_vid() opens. Absolute, machine-specific path: adjust when running elsewhere.
datadir = "/srv/flash1/pputta7/projects/lm-nav/data/datasets/lmnav/offline_10envs_actprobs/"

In [86]:
def observations_to_image(observation, info):
    r"""Generate image of single frame from observation and info
    returned from a single environment step().

    Every entry of ``observation`` with more than one dimension is treated as
    a visual sensor: it is converted to a uint8 array with three channels and
    placed side by side with the other sensors. Optional overlays are then
    added depending on which keys are present in ``info``.

    Args:
        observation: observation returned from an environment step(); a
            mapping from sensor name to an array (``np.ndarray`` or an object
            exposing ``.shape`` and ``.cpu().numpy()``).
        info: info returned from an environment step(). Recognised keys:
            ``"collisions"`` (mapping with an ``"is_collision"`` entry),
            ``"top_down_map"``, ``"act_probs"`` (array-like with ``.tolist()``,
            paired positionally with the labels ``"SFLR"``) and
            ``"frame_idx"``.

    Returns:
        generated image of a single frame.

    Raises:
        AssertionError: if ``observation`` contains no visual sensor.
    """
    render_obs_images = []
    for sensor_name in observation:
        obs_k = observation[sensor_name]
        if len(obs_k.shape) <= 1:
            # Scalars / 1-D vectors are not images; skip them.
            continue
        if not isinstance(obs_k, np.ndarray):
            obs_k = obs_k.cpu().numpy()
        if obs_k.dtype != np.uint8:
            # Non-uint8 input is scaled by 255. Clip before casting so that
            # values above 1.0 saturate instead of wrapping around.
            obs_k = np.clip(obs_k * 255.0, 0.0, 255.0).astype(np.uint8)
        if obs_k.ndim == 2:
            # A bare (H, W) image has no channel axis; add one so the
            # single-channel expansion below applies instead of an IndexError.
            obs_k = obs_k[:, :, np.newaxis]
        if obs_k.shape[2] == 1:
            # Replicate the single channel so it can sit next to RGB images.
            obs_k = np.concatenate([obs_k for _ in range(3)], axis=2)
        render_obs_images.append(obs_k)

    assert (
        len(render_obs_images) > 0
    ), "Expected at least one visual sensor enabled."

    shapes_are_equal = len(set(x.shape for x in render_obs_images)) == 1
    if not shapes_are_equal:
        # Imported lazily: this helper is only needed on this branch and is
        # not among the names imported at the top of the notebook.
        from habitat.utils.visualizations.utils import tile_images

        render_frame = tile_images(render_obs_images)
    else:
        render_frame = np.concatenate(render_obs_images, axis=1)

    # draw collision
    collisions_key = "collisions"
    if collisions_key in info and info[collisions_key]["is_collision"]:
        # Imported lazily for the same reason as tile_images above.
        from habitat.utils.visualizations.utils import draw_collision

        render_frame = draw_collision(render_frame)

    top_down_map_key = "top_down_map"
    if top_down_map_key in info:
        top_down_map = maps.colorize_draw_agent_and_fit_to_height(
            info[top_down_map_key], render_frame.shape[0]
        )
        render_frame = np.concatenate((render_frame, top_down_map), axis=1)

    action_prob_keys = "act_probs"
    if action_prob_keys in info:
        ap = info[action_prob_keys].tolist()
        # One "<label>: <prob>" pair per action, labels taken from "SFLR".
        txt = " ".join([f"{a}: {p:.2f}" for a, p in zip("SFLR", ap)])
        render_frame = draw_text_box(render_frame, txt, 0, -200)

    frame_idx_key = "frame_idx"
    if frame_idx_key in info:
        render_frame = draw_text_box(render_frame, f"Frame: {info[frame_idx_key]}", 0, -150)

    return render_frame


import cv2
import numpy as np

def draw_text_box(frame, text, x_offset, y_offset):
    """Draw ``text`` in white on a filled red box over a copy of ``frame``.

    The box is sized to the rendered text plus 10 px of padding on every side.
    It is centred on the frame and then shifted by ``(x_offset, y_offset)``
    pixels; parts falling outside the frame are simply not drawn.

    Args:
        frame: image array to annotate, in RGB channel order.
        text: string to render.
        x_offset: horizontal shift of the box from the frame centre, in pixels.
        y_offset: vertical shift of the box from the frame centre, in pixels
            (negative moves the box up).

    Returns:
        A new array in the same channel order as ``frame`` with the box and
        text drawn on it. The input array is not modified.
    """
    # OpenCV drawing calls write the channel values they are given, so no
    # colour-space conversion is needed. The earlier RGB->BGR conversion was
    # never undone: one call returned a channel-swapped image, and two chained
    # calls produced boxes of different colours.
    frame = np.array(frame, order="C")  # contiguous copy; caller's array stays untouched

    # Define font parameters
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1
    font_color = (255, 255, 255)  # White (identical in RGB and BGR)
    box_color = (255, 0, 0)  # Red, expressed in the frame's RGB channel order
    font_thickness = 2
    padding = 10  # Pixels between the text and each edge of the box

    # Get the size of the text
    text_width, text_height = cv2.getTextSize(text, font, font_scale, font_thickness)[0]

    # Calculate the size of the text box
    box_width = text_width + 2 * padding
    box_height = text_height + 2 * padding

    # Calculate the center of the frame
    center_x = frame.shape[1] // 2
    center_y = frame.shape[0] // 2

    # Calculate the top-left corner so the box is centred, then apply offsets
    box_x = center_x - (box_width // 2) + x_offset
    box_y = center_y - (box_height // 2) + y_offset

    # Draw the filled box (thickness -1 fills the rectangle)
    cv2.rectangle(frame, (box_x, box_y), (box_x + box_width, box_y + box_height), box_color, -1)

    # putText takes the bottom-left corner of the text as its origin
    text_x = box_x + padding
    text_y = box_y + text_height + padding

    # Put the text on the frame
    cv2.putText(frame, text, (text_x, text_y), font, font_scale, font_color, font_thickness)

    return frame



In [87]:
def construct_vid(episode_id, video_dir='/srv/flash1/pputta7/projects/lm-nav/tmpvids', fps=5):
    """Render one recorded episode to a video file on disk.

    Loads ``data.<episode_id>.pkl.gz`` from ``datadir``, builds one annotated
    frame per step with ``observations_to_image`` and hands the frames to
    ``generate_video``.

    The loaded object is indexed with the keys ``'rgb'``, ``'depth'`` and
    ``'action_probs'`` (one entry per frame along the first axis) and
    ``'imagegoal'`` (only its first entry is used, repeated on every frame).

    Args:
        episode_id: identifier used both to locate the input file and to name
            the output video.
        video_dir: directory passed to ``generate_video`` for the output file.
        fps: frame rate passed to ``generate_video``.
    """
    episode_path = os.path.join(datadir, f"data.{episode_id}.pkl.gz")

    # NOTE(security): pickle.load can execute arbitrary code while
    # deserialising. Only point this at dataset files you trust.
    with gzip.open(episode_path, "rb") as f:
        data = pickle.load(f)

    goal_image = data['imagegoal'][0]
    num_frames = data['rgb'].shape[0]

    frames = []
    for i in range(num_frames):
        # Key order fixes the left-to-right order of the panels in the frame.
        observation = {
            'rgb': data['rgb'][i],
            'imagegoal': goal_image,
            'depth': data['depth'][i],
        }
        info = {'act_probs': data['action_probs'][i], 'frame_idx': i}
        frames.append(observations_to_image(observation, info))

    generate_video(
        video_option=['disk'],
        video_dir=video_dir,
        images=frames,
        episode_id=episode_id,
        fps=fps,
        checkpoint_idx=0,
        metrics={},
        tb_writer=None
    )

In [88]:
# Build a video for each episode id 0..99. construct_vid() opens
# data.<episode_id>.pkl.gz without checking that it exists, so a missing
# file for any id in this range stops the loop with an exception.
for episode_id in range(100):
    construct_vid(episode_id)

2023-09-26 12:51:59,007 Video created: /srv/flash1/pputta7/projects/lm-nav/tmpvids/episode=0-ckpt=0-.mp4
100%|█████████████████████████████████████████████████████| 13/13 [00:00<00:00, 125.75it/s]
2023-09-26 12:51:59,636 Video created: /srv/flash1/pputta7/projects/lm-nav/tmpvids/episode=1-ckpt=0-.mp4
100%|█████████████████████████████████████████████████████| 27/27 [00:00<00:00, 165.84it/s]
2023-09-26 12:52:00,360 Video created: /srv/flash1/pputta7/projects/lm-nav/tmpvids/episode=2-ckpt=0-.mp4
100%|█████████████████████████████████████████████████████| 20/20 [00:00<00:00, 150.81it/s]
2023-09-26 12:52:01,084 Video created: /srv/flash1/pputta7/projects/lm-nav/tmpvids/episode=3-ckpt=0-.mp4
100%|█████████████████████████████████████████████████████| 27/27 [00:00<00:00, 166.17it/s]
2023-09-26 12:52:01,743 Video created: /srv/flash1/pputta7/projects/lm-nav/tmpvids/episode=4-ckpt=0-.mp4
100%|█████████████████████████████████████████████████████| 15/15 [00:00<00:00, 134.30it/s]
2023-09-26 12:5