In [6]:
# [setup]
import os
from typing import TYPE_CHECKING, Union, cast

import matplotlib.pyplot as plt
import numpy as np

import habitat
from habitat.config.default_structured_configs import (
    CollisionsMeasurementConfig,
    FogOfWarConfig,
    TopDownMapMeasurementConfig,
)
from habitat.core.agent import Agent
from habitat.tasks.nav.nav import NavigationEpisode, NavigationGoal
from habitat.tasks.nav.shortest_path_follower import ShortestPathFollower
from habitat.utils.visualizations import maps
from habitat.utils.visualizations.utils import (
    images_to_video,
    observations_to_image,
    overlay_frame,
)
from habitat.core.registry import registry
from habitat_sim.utils import viz_utils as vut
from lmnav.config.default import get_config
import gzip
import json
import os
import pickle
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from habitat.config import read_write
from habitat.core.dataset import ALL_SCENES_MASK, Dataset
from habitat.core.registry import registry
from habitat.tasks.nav.nav import (
    NavigationEpisode,
    NavigationGoal,
    ShortestPathPoint,
)

if TYPE_CHECKING:
    from omegaconf import DictConfig

# Quiet the Habitat simulator logging
# NOTE(review): these are assigned after the habitat imports above have
# already run — confirm the libraries read them lazily rather than at import.
os.environ["MAGNUM_LOG"] = "quiet"
os.environ["HABITAT_SIM_LOG"] = "quiet"

if TYPE_CHECKING:
    from habitat.core.simulator import Observations
    from habitat.sims.habitat_simulator.habitat_simulator import HabitatSim


In [17]:
# Key looked up in a deserialized dataset JSON; when present, its value
# replaces the dataset's `content_scenes_path` template.
CONTENT_SCENES_PATH_FIELD = "content_scenes_path"
# Prefix stripped from an episode's `scene_id` before the id is joined onto
# `scenes_dir`.
DEFAULT_SCENE_PATH_PREFIX = "data/scene_datasets/"

@registry.register_dataset(name="OfflineDataset-v1")
class OfflineDatasetV1(Dataset):
    r"""Navigation-episode dataset registered as ``OfflineDataset-v1``.

    Episodes are read from the file named by ``config.data_path`` and, when a
    per-scene content directory sits next to that file, from one additional
    file per requested scene.
    """

    episodes: List[NavigationEpisode]
    # Template for per-scene files; may be overridden by the loaded JSON.
    content_scenes_path: str = "{data_path}/content/{scene}.json.gz"

    @staticmethod
    def check_config_paths_exist(config: "DictConfig") -> bool:
        r"""Tell whether both the split's dataset file and ``scenes_dir``
        exist on disk.
        """
        split_file = config.data_path.format(split=config.split)
        if not os.path.exists(split_file):
            return False
        return os.path.exists(config.scenes_dir)

    @classmethod
    def get_scenes_to_load(cls, config: "DictConfig") -> List[str]:
        r"""List the scenes available for the configured split.

        :raises FileNotFoundError: when ``check_config_paths_exist`` fails.
        """
        split_file = config.data_path.format(split=config.split)
        dataset_dir = os.path.dirname(split_file)
        if not cls.check_config_paths_exist(config):
            raise FileNotFoundError(
                f"Could not find dataset file `{dataset_dir}`"
            )

        cfg = config.copy()
        with read_write(cfg):
            # With no scenes requested the constructor reads the split file
            # but skips every per-scene file.
            cfg.content_scenes = []
            dataset = cls(cfg)
            content_dir = dataset.content_scenes_path.split("{scene}")[
                0
            ].format(data_path=dataset_dir)

            if not os.path.exists(content_dir):
                # Nothing is split into separate files: load the full
                # dataset and derive the scenes from its episodes.
                cfg.content_scenes = [ALL_SCENES_MASK]
                dataset = cls(cfg)
                return [
                    cls.scene_from_scene_path(scene_id)
                    for scene_id in dataset.scene_ids
                ]

            return cls._get_scenes_from_folder(
                content_scenes_path=dataset.content_scenes_path,
                dataset_dir=dataset_dir,
            )

    @staticmethod
    def _get_scenes_from_folder(
        content_scenes_path: str, dataset_dir: str
    ) -> List[str]:
        r"""Return the sorted scene names found in the content directory.

        A scene name is a file name with the template's suffix (the text
        after ``{scene}``) removed. A missing directory yields an empty list.
        """
        template_parts = content_scenes_path.split("{scene}")
        scene_dataset_ext = template_parts[1]
        content_dir = template_parts[0].format(data_path=dataset_dir)
        if not os.path.exists(content_dir):
            return []

        return sorted(
            filename[: -len(scene_dataset_ext)]
            for filename in os.listdir(content_dir)
            if filename.endswith(scene_dataset_ext)
        )

    def _load_from_file(self, fname: str, scenes_dir: str) -> None:
        """
        Append the episodes stored in `fname` to `self.episodes`.

        Files ending in `.pickle` go through `from_binary`; every other file
        is treated as gzip-compressed JSON text and goes through `from_json`.
        """
        if not fname.endswith(".pickle"):
            with gzip.open(fname, "rt") as f:
                self.from_json(f.read(), scenes_dir=scenes_dir)
            return

        # NOTE: `from_binary` is not implemented in this class.
        # NOTE(security): unpickling can run arbitrary code; only point this
        # at trusted files.
        with open(fname, "rb") as f:
            self.from_binary(pickle.load(f), scenes_dir=scenes_dir)

    def __init__(self, config: Optional["DictConfig"] = None, directory=None) -> None:
        """
        Build the dataset from `config`; with no config it stays empty.

        `directory` is accepted but not used by this constructor.
        """
        self.episodes = []
        if config is None:
            return

        datasetfile_path = config.data_path.format(split=config.split)
        # Loading the split file first matters: it may override
        # `content_scenes_path`, which is consulted right below.
        self._load_from_file(datasetfile_path, config.scenes_dir)

        dataset_dir = os.path.dirname(datasetfile_path)
        content_dir = self.content_scenes_path.split("{scene}")[0].format(
            data_path=dataset_dir
        )

        if not os.path.exists(content_dir):
            # Single-file dataset: keep only what the content-scenes filter
            # accepts.
            keep_episode = self.build_content_scenes_filter(config)
            self.episodes = [
                episode for episode in self.episodes if keep_episode(episode)
            ]
            return

        # Per-scene files exist: read one file for each requested scene.
        scenes = config.content_scenes
        if ALL_SCENES_MASK in scenes:
            scenes = self._get_scenes_from_folder(
                content_scenes_path=self.content_scenes_path,
                dataset_dir=dataset_dir,
            )

        for scene in scenes:
            self._load_from_file(
                self.content_scenes_path.format(
                    data_path=dataset_dir, scene=scene
                ),
                config.scenes_dir,
            )

    def to_binary(self) -> Dict[str, Any]:
        """Binary export is not supported; always raises."""
        raise NotImplementedError()

    def from_binary(
        self, data_dict: Dict[str, Any], scenes_dir: Optional[str] = None
    ) -> None:
        """Binary import is not supported; always raises."""
        raise NotImplementedError()

    def from_json(
        self, json_str: str, scenes_dir: Optional[str] = None
    ) -> None:
        """
        Parse `json_str` and append its episodes to `self.episodes`.

        When `scenes_dir` is given, each scene id has the default scene
        prefix removed (if present) and is then joined onto `scenes_dir`.
        """
        deserialized = json.loads(json_str)
        if CONTENT_SCENES_PATH_FIELD in deserialized:
            self.content_scenes_path = deserialized[CONTENT_SCENES_PATH_FIELD]

        prefix_len = len(DEFAULT_SCENE_PATH_PREFIX)
        for raw_episode in deserialized["episodes"]:
            episode = NavigationEpisode(**raw_episode)

            if scenes_dir is not None:
                if episode.scene_id.startswith(DEFAULT_SCENE_PATH_PREFIX):
                    episode.scene_id = episode.scene_id[prefix_len:]
                episode.scene_id = os.path.join(scenes_dir, episode.scene_id)

            # Swap the raw dicts for typed objects, in place.
            goals = episode.goals
            for idx in range(len(goals)):
                goals[idx] = NavigationGoal(**goals[idx])

            if episode.shortest_paths is not None:
                for path in episode.shortest_paths:
                    for idx in range(len(path)):
                        path[idx] = ShortestPathPoint(**path[idx])

            self.episodes.append(episode)

In [18]:
import os

# Work from the project root before loading the experiment config.
# NOTE(review): presumably the config's relative paths resolve against this
# directory — confirm.
os.chdir("/srv/flash1/pputta7/projects/lm-nav")
config = get_config("train/nav_llama/1env_karmesh/bc/lora+clip+karmesh")

# The config is only writable inside `read_write`; register the two extra
# measurements there.
with habitat.config.read_write(config):
    fog_of_war_cfg = FogOfWarConfig(
        draw=False,
        visibility_dist=5.0,
        fov=90,
    )
    top_down_map_cfg = TopDownMapMeasurementConfig(
        map_padding=0,
        map_resolution=128,
        draw_source=False,
        draw_border=False,
        draw_shortest_path=False,
        draw_view_points=False,
        draw_goal_positions=False,
        draw_goal_aabbs=False,
        fog_of_war=fog_of_war_cfg,
    )
    extra_measurements = {
        "top_down_map": top_down_map_cfg,
        "collisions": CollisionsMeasurementConfig(),
    }
    config.habitat.task.measurements.update(extra_measurements)

# Instantiate the dataset class registered under "OfflineDataset-v1".
dataset = habitat.make_dataset(
    id_dataset="OfflineDataset-v1",
    config=config.habitat.dataset,
)

2023-12-11 23:03:33,093 Initializing dataset OfflineDataset-v1


In [71]:
# episode dictionary
import multiprocessing
from tqdm import tqdm
import torch

# Index the loaded episodes by (scene_id, episode_id); for duplicate keys the
# last episode wins.
episodes_dict = dict(
    ((episode.scene_id, episode.episode_id), episode)
    for episode in dataset.episodes
)
# Receives one OfflineTrajectory per successfully matched trajectory file.
episodes = []

class OfflineTrajectory(NavigationEpisode):
    """A navigation episode paired with a recorded action sequence.

    The episode fields are taken from `base_episode` as-is (the values are
    passed through, not copied); `actions` and `trajectory_id` are stored as
    extra attributes on the instance.
    """

    def __init__(self, base_episode, actions, trajectory_id):
        # Fields forwarded unchanged to the NavigationEpisode initializer.
        forwarded_fields = (
            "episode_id",
            "scene_id",
            "scene_dataset_config",
            "additional_obj_config_paths",
            "start_position",
            "start_rotation",
            "info",
            "goals",
            "start_room",
            "shortest_paths",
        )
        super().__init__(
            **{
                field: getattr(base_episode, field)
                for field in forwarded_fields
            }
        )

        # Trajectory-specific data that the base episode does not carry.
        self.actions = actions
        self.trajectory_id = trajectory_id

def process_file(file_path):
    """
    Load one saved trajectory file and extract the fields needed to match it
    to an episode.

    Args:
        file_path: Path to a `.pkl` (pickle) or `.pt` (torch) file whose
            payload supports `.get` for the keys "action", "scene_id" and
            "episode_id".

    Returns:
        A 4-tuple `(actions, scene_id, episode_id, trajectory_id)`, where
        `trajectory_id` is the second dot-separated component of the file
        *name* (directory components are ignored). On any failure the error
        is printed and `(None, None, None, None)` is returned, so callers can
        always unpack four values.
    """
    try:
        if file_path.endswith(".pkl"):
            # NOTE(security): unpickling can execute arbitrary code; only use
            # this on trusted trajectory files. The same applies to torch.load.
            with open(file_path, 'rb') as f:
                data = pickle.load(f)
        elif file_path.endswith(".pt"):
            data = torch.load(file_path)
        else:
            raise ValueError("unsupported file extension")

        # Split the base name only, so dots in directory names cannot shift
        # the index.
        trajectory_id = os.path.basename(file_path).split(".")[1]
        return (
            data.get('action'),
            data.get("scene_id"),
            data.get("episode_id"),
            trajectory_id,
        )

    except Exception as e:
        print(f"Error processing file {file_path}: {e}")
        # Same arity as the success path.
        return None, None, None, None
        

# Root directory that is searched recursively for saved trajectory files.
directory = "/srv/flash1/pputta7/projects/lm-nav/data/datasets/lmnav/offline_1env_karmesh_multipath"

# Collect every trajectory file: names start with "data" and end in
# ".pt" or ".pkl".
file_paths = [
    os.path.join(root, file)
    for root, dirs, files in os.walk(directory)
    for file in files
    if file.startswith("data") and file.endswith((".pt", ".pkl"))
]
print(len(file_paths))

# Load the files in parallel with a progress bar and turn each result into an
# OfflineTrajectory. The context manager tears the pool down even if the loop
# raises; every result has been consumed by the time it exits normally.
with multiprocessing.Pool(processes=16) as pool:
    for result in tqdm(pool.imap_unordered(process_file, file_paths), total=len(file_paths)):
        try:
            # Unpack inside the try so a malformed result is reported and
            # skipped like any other bad file.
            actions, scene_id, episode_id, trajectory_id = result
            base_episode = episodes_dict[(scene_id, episode_id)]
            episodes.append(OfflineTrajectory(base_episode, actions.tolist(), trajectory_id))
        except (AttributeError, KeyError, TypeError, ValueError) as err:
            # Results arrive unordered, so report the ids carried by the
            # result itself (everything except the actions payload).
            print("error", result[1:], repr(err))


5388


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5388/5388 [00:07<00:00, 757.11it/s]


In [72]:
# Wrap the collected trajectories in a plain Dataset so they can be serialized.
new_dataset = habitat.Dataset()
new_dataset.episodes = episodes

# The archive is named after the trajectory directory's base name.
new_dirpath = "/srv/flash1/pputta7/projects/lm-nav/data/datasets/lmnav_episodes/"
os.makedirs(new_dirpath, exist_ok=True)
archive_name = os.path.basename(directory) + ".json.gz"
new_path = os.path.join(new_dirpath, archive_name)

# Write the JSON text as gzip-compressed bytes.
with gzip.open(new_path, 'wb+') as f:
    serialized = new_dataset.to_json().encode()
    f.write(serialized)
    print(f'finished generating data for scene {new_path}')

finished generating data for scene /srv/flash1/pputta7/projects/lm-nav/data/datasets/lmnav_episodes/offline_1env_karmesh_multipath.json.gz


In [64]:
# Spot-check the first converted trajectory. The repr shown below lists only
# the base episode fields; `actions` and `trajectory_id` do not appear in it.
episodes[0]

OfflineTrajectory(episode_id=3892, scene_id='data/scene_datasets/hm3d/train/00744-1S7LAXRdDqK/1S7LAXRdDqK.basis.glb', scene_dataset_config='default', additional_obj_config_paths=[], start_position=[-8.531359672546387, 0.19160032272338867, -6.199461460113525], start_rotation=[0.0, 0.545838794638867, 0.0, -0.8378902137316014], info={'geodesic_distance': 2.936727523803711, 'difficulty': 'easy'}, _shortest_path_cache=None, goals=[NavigationGoal(position=[-9.719176292419434, 0.19160032272338867, -3.5400874614715576], radius=None)], start_room=None, shortest_paths=None)

In [55]:
# Base name of the trajectory directory; the same expression names the
# `.json.gz` archive written above.
os.path.basename(directory)

'offline_1env_karmesh_multipath'