In [14]:
import gzip
import pickle
import os
import numpy as np
from habitat_baselines.utils.common import batch_obs, generate_video
from habitat.core.logging import logger
from habitat.core.utils import try_cv2_import
from habitat.utils.common import flatten_dict
from habitat.utils.visualizations import maps
import cv2
import torch

# Directory holding the offline episode dumps (`data.{episode_id}.pkl.gz`).
# Override with the LMNAV_OFFLINE_DATADIR environment variable on other machines;
# the default is the original cluster path.
datadir = os.environ.get(
    "LMNAV_OFFLINE_DATADIR",
    "/srv/flash1/pputta7/projects/lm-nav/data/datasets/lmnav/offline_00744_actprobs/",
)

In [15]:
def observations_to_image(observation, info):
    r"""Generate image of single frame from observation and info
    returned from a single environment step().

    Every sensor whose value has more than one dimension is treated as an
    image. Each image is moved to a numpy array and scaled to uint8 when it
    is not uint8 already. Single-channel images are expanded to three
    channels, whether given as ``(H, W)`` or ``(H, W, 1)``. The images are
    concatenated side by side, or tiled when their shapes differ. Optional
    overlays from ``info`` are then applied.

    Args:
        observation: mapping from sensor name to an array or tensor returned
            from an environment step(). Non-uint8 values are assumed to lie
            in [0, 1] (TODO confirm for every sensor); they are clipped to
            that range after scaling.
        info: info returned from an environment step(). Recognised keys:
            ``"collisions"``, ``"top_down_map"``, ``"act_probs"`` (up to four
            probabilities labelled S/F/L/R) and ``"frame_idx"``.

    Returns:
        generated image of a single frame (uint8, three channels).

    Raises:
        AssertionError: if no sensor in ``observation`` is image-like.
    """
    render_obs_images = []
    for sensor_name in observation:
        obs_k = observation[sensor_name]
        if len(obs_k.shape) <= 1:
            # Scalars / vectors are not drawable as images.
            continue
        if not isinstance(obs_k, np.ndarray):
            obs_k = obs_k.cpu().numpy()
        if obs_k.dtype != np.uint8:
            # Clip before casting so out-of-range values do not wrap around.
            obs_k = np.clip(obs_k * 255.0, 0, 255).astype(np.uint8)
        if obs_k.ndim == 2:
            # (H, W) image without an explicit channel axis.
            obs_k = obs_k[:, :, np.newaxis]
        if obs_k.shape[2] == 1:
            obs_k = np.repeat(obs_k, 3, axis=2)
        render_obs_images.append(obs_k)

    assert (
        len(render_obs_images) > 0
    ), "Expected at least one visual sensor enabled."

    shapes_are_equal = len(set(x.shape for x in render_obs_images)) == 1
    if not shapes_are_equal:
        # Only needed when sensors have different resolutions; imported here
        # because the notebook never brings it into scope.
        from habitat.utils.visualizations.utils import tile_images

        render_frame = tile_images(render_obs_images)
    else:
        render_frame = np.concatenate(render_obs_images, axis=1)

    # draw collision
    collisions_key = "collisions"
    if collisions_key in info and info[collisions_key]["is_collision"]:
        from habitat.utils.visualizations.utils import draw_collision

        render_frame = draw_collision(render_frame)

    top_down_map_key = "top_down_map"
    if top_down_map_key in info:
        top_down_map = maps.colorize_draw_agent_and_fit_to_height(
            info[top_down_map_key], render_frame.shape[0]
        )
        render_frame = np.concatenate((render_frame, top_down_map), axis=1)

    action_prob_keys = "act_probs"
    if action_prob_keys in info:
        ap = info[action_prob_keys].tolist()
        txt = " ".join([f"{a}: {p:.2f}" for a, p in zip("SFLR", ap)])
        render_frame = draw_text_box(render_frame, txt, 0, -200)

    frame_idx_key = "frame_idx"
    if frame_idx_key in info:
        render_frame = draw_text_box(render_frame, f"Frame: {info[frame_idx_key]}", 0, -150)

    return render_frame


import cv2
import numpy as np

def draw_text_box(frame, text, x_offset, y_offset, box_color=(255, 0, 0),
                  font_color=(255, 255, 255), font_scale=1, font_thickness=2,
                  padding=10):
    """Draw ``text`` inside a filled box placed relative to the frame centre.

    Colours are given in the frame's own channel order (RGB by default: red box,
    white text). No colour-space conversion is done, so the returned frame keeps
    the caller's channel order. Converting to BGR here would swap the colours of
    the entire image on every call.

    Args:
        frame: uint8 image array of shape (H, W, 3); it is not modified.
        text: string to render.
        x_offset, y_offset: pixel offsets of the box centre from the frame centre.
        box_color, font_color: fill and text colours.
        font_scale, font_thickness: OpenCV Hershey-simplex font parameters.
        padding: pixels of padding between the text and each box edge.

    Returns:
        A new array with the box and text drawn on it.
    """
    # cv2 drawing functions work in place; copy (C-contiguous) to leave the
    # caller's frame untouched.
    canvas = frame.copy()

    font = cv2.FONT_HERSHEY_SIMPLEX
    text_w, text_h = cv2.getTextSize(text, font, font_scale, font_thickness)[0]

    box_width = text_w + 2 * padding
    box_height = text_h + 2 * padding

    # Centre the box on the frame, then shift it by the requested offsets.
    box_x = int(canvas.shape[1] // 2 - box_width // 2 + x_offset)
    box_y = int(canvas.shape[0] // 2 - box_height // 2 + y_offset)

    cv2.rectangle(canvas, (box_x, box_y), (box_x + box_width, box_y + box_height), box_color, -1)

    # putText anchors at the text's bottom-left corner.
    text_org = (box_x + padding, box_y + text_h + padding)
    cv2.putText(canvas, text, text_org, font, font_scale, font_color, font_thickness)

    return canvas



In [16]:
def _load_episode(episode_id, data_dir=None):
    """Load the gzip-compressed pickle ``data.{episode_id}.pkl.gz`` from ``data_dir``.

    ``data_dir`` defaults to the notebook-level ``datadir``.
    NOTE(review): ``pickle.load`` can execute arbitrary code; only point this at
    files produced by our own data pipeline.
    """
    if data_dir is None:
        data_dir = datadir
    with gzip.open(os.path.join(data_dir, f"data.{episode_id}.pkl.gz"), "rb") as f:
        return pickle.load(f)


def construct_vid(episode_id, data_dir=None,
                  video_dir='/srv/flash1/pputta7/projects/lm-nav/tmpvids', fps=5):
    """Render one recorded episode to a video on disk via ``generate_video``.

    For each step, the ``rgb`` and ``depth`` observations are combined with the
    episode's first ``imagegoal`` image. Each frame is annotated with the
    step's action probabilities and its frame index.

    Args:
        episode_id: index used in the ``data.{episode_id}.pkl.gz`` file name.
        data_dir: directory holding the episode files (defaults to ``datadir``).
        video_dir: output directory for the video.
        fps: frames per second of the written video.
    """
    data = _load_episode(episode_id, data_dir)
    goal_image = data['imagegoal'][0]  # the same goal image is shown on every frame
    frames = [
        observations_to_image(
            {'rgb': data['rgb'][i], 'imagegoal': goal_image, 'depth': data['depth'][i]},
            {'act_probs': data['action_probs'][i], 'frame_idx': i},
        )
        for i in range(data['rgb'].shape[0])
    ]
    generate_video(
        video_option=['disk'],
        video_dir=video_dir,
        images=frames,
        episode_id=episode_id,
        fps=fps,
        checkpoint_idx=0,
        metrics={},
        tb_writer=None
    )


def compute_action_entropies(episode_id, data_dir=None):
    """Return the per-step Shannon entropy, in bits, of an episode's action distribution.

    Each row of ``action_probs`` is one step's distribution; the sum runs over
    dim 1. Presumably the shape is (steps, num_actions) — TODO confirm.
    Zero-probability actions contribute 0, which is the limit of
    ``p * log2(p)`` as p→0. The naive formula gives ``0 * -inf = nan`` for them.

    Returns:
        list of floats, one entropy per step.
    """
    data = _load_episode(episode_id, data_dir)
    probs = torch.as_tensor(data['action_probs'])
    if not probs.is_floating_point():
        probs = probs.float()
    # p * log2(p) is nan where p == 0; substitute the limit value 0 there.
    terms = torch.where(probs > 0, probs * torch.log2(probs), torch.zeros_like(probs))
    entropies = -torch.sum(terms, dim=1)
    return entropies.tolist()

In [17]:
# Write one video per recorded episode, for episodes 0-99.
episode_ids = range(100)
for episode_id in episode_ids:
    construct_vid(episode_id=episode_id)

2023-09-26 16:03:29,609 Video created: /srv/flash1/pputta7/projects/lm-nav/tmpvids/episode=0-ckpt=0-.mp4
100%|█████████████████████████████████████████████████████████████████████████████████| 18/18 [00:01<00:00,  9.13it/s]


KeyboardInterrupt: 

In [18]:
from tqdm import tqdm

# Concatenate the per-step action entropies of episodes 0-99 into one flat list.
es = []
for episode_id in tqdm(range(100)):
    es.extend(compute_action_entropies(episode_id))

100%|███████████████████████████████████████████████████████████████████████████████| 100/100 [02:13<00:00,  1.34s/it]


In [19]:
# Steps whose action distribution has at least 1 bit of entropy. The index is the
# position in the flattened `es` list (all episodes concatenated), not a
# per-episode frame index.
list(filter(lambda pair: pair[1] >= 1, enumerate(es)))

[(3, 1.445692777633667),
 (7, 1.02545166015625),
 (42, 1.0541306734085083),
 (53, 1.0275933742523193),
 (57, 1.0003399848937988),
 (70, 1.053717017173767),
 (74, 1.3822519779205322),
 (75, 1.4779868125915527),
 (77, 1.040473222732544),
 (130, 1.0397778749465942),
 (132, 1.0978426933288574),
 (134, 1.0525819063186646),
 (135, 1.034899353981018),
 (146, 1.0003974437713623),
 (252, 1.0367939472198486),
 (256, 1.052490234375),
 (259, 1.1045286655426025),
 (276, 1.0113164186477661),
 (279, 1.211916208267212),
 (296, 1.096377968788147),
 (309, 1.2561280727386475),
 (328, 1.4402052164077759),
 (329, 1.0472874641418457),
 (337, 1.0345808267593384),
 (348, 1.380676031112671),
 (388, 1.564596176147461),
 (396, 1.073052167892456),
 (418, 1.0275933742523193),
 (419, 1.005483865737915),
 (423, 1.0695948600769043),
 (427, 1.13426673412323),
 (430, 1.0559377670288086),
 (468, 1.4311978816986084),
 (469, 1.1061694622039795),
 (486, 1.0888633728027344),
 (497, 1.5309765338897705),
 (534, 1.431197881698