In [1]:
import os

import gym
from habitat_baselines.common import env_spec
import numpy as np

import habitat
import habitat.gym
from habitat_baselines.common.env_spec import EnvironmentSpec
from habitat.utils.visualizations.utils import (
    observations_to_image,
    overlay_frame,
)
from habitat_baselines.rl.ddppo.ddp_utils import (
    EXIT,
    get_distrib_size,
    init_distrib_slurm,
    is_slurm_batch_job,
    load_resume_state,
    rank0_only,
    requeue_job,
    save_resume_state,
)
import hydra
from habitat_baselines.common.baseline_registry import baseline_registry

import torch
from habitat_baselines.utils.info_dict import extract_scalars_from_info


from habitat_baselines.common.obs_transformers import (
    apply_obs_transforms_batch,
    apply_obs_transforms_obs_space,
    get_active_obs_transforms,
)
from habitat_baselines.utils.common import (
    batch_obs,
    generate_video,
    get_action_space_info,
    inference_mode,
    is_continuous_action_space,
)

from lmnav.emb_transfer.old_eai_policy import OldEAIPolicy


# Quiet the Habitat simulator logging
os.environ["MAGNUM_LOG"] = "quiet"
os.environ["HABITAT_SIM_LOG"] = "quiet"

# The relative paths used later in this notebook (config YAML, checkpoint,
# video directory) are resolved against the project root, so switch to it.
# Set LMNAV_ROOT to run from a different checkout; the default is the
# original author's location.
PROJECT_ROOT = os.environ.get("LMNAV_ROOT", '/srv/flash1/pputta7/projects/lm-nav')
os.chdir(PROJECT_ROOT)


[2023-09-03 13:03:40,032] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [2]:

def _init_envs(config=None, is_eval: bool = False):
    """Build the vectorized environments and describe their spaces.

    The environment factory is instantiated with hydra from
    ``config.habitat_baselines.vector_env_factory``.

    Args:
        config: Habitat config object. Although it defaults to ``None``, the
            body reads ``config.habitat_baselines`` immediately, so a real
            config must be supplied.
        is_eval: Forwarded as ``enforce_scenes_greater_eq_environments``.

    Returns:
        A ``(envs, env_spec)`` pair, where ``env_spec`` is an
        ``EnvironmentSpec`` built from the spaces of the first environment.
    """
    factory = hydra.utils.instantiate(config.habitat_baselines.vector_env_factory)

    ignore_signals = is_slurm_batch_job()

    # When torch.distributed is not initialised this process counts as rank 0.
    if torch.distributed.is_initialized():
        on_first_rank = torch.distributed.get_rank() == 0
    else:
        on_first_rank = True

    vector_envs = factory.construct_envs(
        config,
        workers_ignore_signals=ignore_signals,
        enforce_scenes_greater_eq_environments=is_eval,
        is_first_rank=on_first_rank,
    )

    # Only the first environment's spaces are used to describe the batch.
    spec = EnvironmentSpec(
        observation_space=vector_envs.observation_spaces[0],
        action_space=vector_envs.action_spaces[0],
        orig_action_space=vector_envs.orig_action_spaces[0],
    )
    return vector_envs, spec

def _create_obs_transforms(config, env_spec):
    """Look up the active observation transforms and apply them to the spec.

    Note that ``env_spec`` is modified in place: its ``observation_space`` is
    replaced by the transformed space. The same object is also returned.

    Args:
        config: Config passed to ``get_active_obs_transforms``.
        env_spec: Object exposing a writable ``observation_space`` attribute.

    Returns:
        A ``(obs_transforms, env_spec)`` pair.
    """
    active_transforms = get_active_obs_transforms(config)
    transformed_space = apply_obs_transforms_obs_space(
        env_spec.observation_space, active_transforms
    )
    env_spec.observation_space = transformed_space
    return active_transforms, env_spec

def _setup_teacher(teacher_ckpt, obs_space, action_space):
    """Instantiate the teacher policy and load its weights from a checkpoint.

    Side effect: gradient tracking is disabled globally via
    ``torch.set_grad_enabled(False)``; it stays disabled for everything that
    runs afterwards in this process.

    Args:
        teacher_ckpt: Path (or file-like) accepted by ``torch.load``. The
            loaded object must contain a ``'state_dict'`` entry.
        obs_space: Observation space passed to ``OldEAIPolicy.hardcoded``.
        action_space: Action space passed to ``OldEAIPolicy.hardcoded``.

    Returns:
        The policy with the checkpoint weights loaded. Weights are loaded
        via CPU; the caller is responsible for moving the policy to a device.
    """
    teacher = OldEAIPolicy.hardcoded(OldEAIPolicy, obs_space, action_space)
    torch.set_grad_enabled(False)

    # map_location='cpu' keeps loading independent of whichever device the
    # checkpoint was saved from (e.g. a GPU index that does not exist here).
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    ckpt_dict = torch.load(teacher_ckpt, map_location='cpu')

    # Checkpoint keys are namespaced under 'actor_critic.'; strip that prefix
    # only where it is actually present so other keys are left intact instead
    # of having their first characters chopped off.
    prefix = 'actor_critic.'
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in ckpt_dict['state_dict'].items()
    }

    teacher.load_state_dict(state_dict)
    return teacher


def _construct_state_tensors(num_environments, device):
    rnn_hx = torch.zeros((num_environments, 2, 512), device=device)
    prev_actions = torch.zeros(num_environments, 1, device=device, dtype=torch.long)
    not_done_masks = torch.ones(num_environments, 1, device=device, dtype=torch.bool)

    return rnn_hx, prev_actions, not_done_masks 
    
    
def collect_episodes(envs, teacher, obs_transform, device, deterministic=False, filter_f=None, N=None):
    """Roll out ``teacher`` in ``envs`` and collect finished episodes.

    Each collected episode is a list with one dict per step, holding
    ``'observation'`` (the observation the policy acted on), and the
    ``'reward'`` and ``'info'`` returned by the environment for that step.

    Args:
        envs: Vectorized environment exposing ``num_envs``, ``reset()`` and
            ``step(actions)``.
        teacher: Policy exposing ``to``, ``eval`` and ``act``.
        obs_transform: Observation transforms applied to every batch.
        device: Device (or device string) used for the policy and tensors.
        deterministic: Forwarded to ``teacher.act``.
        filter_f: Predicate called with a finished episode; the episode is
            kept only when it returns a truthy value. Defaults to keeping
            every episode.
        N: Stop once at least this many episodes have been kept. Several
            environments can finish on the same step, so the result may hold
            slightly more than ``N`` episodes. With ``N=None`` the loop never
            terminates.

    Returns:
        The list of kept episodes.
    """
    if filter_f is None:
        filter_f = lambda _: True

    device = torch.device(device)
    num_envs = envs.num_envs
    step = 0
    dataset = []
    episodes = [[] for _ in range(num_envs)]

    rnn_hx, prev_actions, not_done_masks = _construct_state_tensors(num_envs, device)

    teacher.to(device)
    teacher.eval()

    observations = envs.reset()

    while (N is None) or (len(dataset) < N):
        print(step)
        # roll out a step
        batch = batch_obs(observations, device)
        batch = apply_obs_transforms_batch(batch, obs_transform)

        # Pure inference: do not rely on gradient tracking having been
        # switched off globally somewhere else.
        with torch.no_grad():
            policy_result = teacher.act(batch,
                                        rnn_hx,
                                        prev_actions,
                                        not_done_masks,
                                        deterministic=deterministic)

        prev_actions.copy_(policy_result.actions)
        rnn_hx = policy_result.rnn_hidden_states

        step_data = [a.item() for a in policy_result.env_actions.cpu()]
        outputs = envs.step(step_data)
        next_observations, rewards_l, dones, infos = [list(x) for x in zip(*outputs)]

        # record this step in every environment's running episode
        for i, episode in enumerate(episodes):
            episode.append({'observation': observations[i],
                            'reward': rewards_l[i],
                            'info': infos[i]})

        # archive finished episodes and start fresh ones
        for i, done in enumerate(dones):
            if not done:
                continue

            if filter_f(episodes[i]):
                dataset.append(episodes[i])

            # Always start a new buffer and clear the recurrent state when an
            # episode ends, even if it was rejected by the filter; otherwise
            # the rejected steps would be glued onto the next episode.
            # NOTE(review): this presumes the vector env auto-resets so that
            # next_observations[i] already belongs to the new episode - confirm.
            episodes[i] = []
            rnn_hx[i] = 0
            prev_actions[i] = 0
            not_done_masks[i] = True

        observations = next_observations
        step += 1

    return dataset

In [8]:
# Load the ImageNav config, spin up the environments and derive the
# (transformed) observation/action spaces the teacher policy is built against.
config = habitat.get_config("lmnav/configs/habitat/imagenav_hm3d.yaml")
envs, env_spec = _init_envs(config=config)
obs_transform, env_spec = _create_obs_transforms(config=config, env_spec=env_spec)

# Restore the teacher policy from its checkpoint.
teacher_ckpt = "ckpts/uLHP.300.pth"
teacher = _setup_teacher(
    teacher_ckpt=teacher_ckpt,
    obs_space=env_spec.observation_space,
    action_space=env_spec.action_space,
)

2023-09-03 13:05:58,682 Initializing dataset PointNav-v1
2023-09-03 13:05:58,682 There are less scenes (1) than environments (2). Each environment will use all the scenes instead of using a subset.
2023-09-03 13:06:02,464 Initializing dataset PointNav-v1
2023-09-03 13:06:02,464 Initializing dataset PointNav-v1
2023-09-03 13:06:02,743 initializing sim Sim-v0
2023-09-03 13:06:02,744 initializing sim Sim-v0
2023-09-03 13:06:04,122 Initializing task Nav-v0
2023-09-03 13:06:04,122 Initializing task Nav-v0
2023-09-03 13:06:04,600 Resizing observation of depth: from (480, 640) to (120, 160)
2023-09-03 13:06:04,601 Resizing observation of imagegoal: from (480, 640) to (120, 160)
2023-09-03 13:06:04,601 Resizing observation of rgb: from (480, 640) to (120, 160)
2023-09-03 13:06:05,769 Using weights from /srv/flash1/rramrakhya6/summer_2022/mae-for-eai/data/visual_encoders/mae_vit_small_decoder_large_HGPS_RE10K_100.pth: _IncompatibleKeys(missing_keys=[], unexpected_keys=['mask_token', 'decoder_po

In [9]:
device = 'cuda:0'


def f(episodes):
    """Accept an episode whose final step reports distance_to_goal <= 1."""
    final_step = episodes[-1]
    return final_step['info']['distance_to_goal'] <= 1


# Sample actions stochastically and stop after two accepted episodes.
dataset = collect_episodes(
    envs,
    teacher,
    obs_transform,
    device,
    deterministic=False,
    filter_f=f,
    N=2,
)

0




1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230


In [67]:
# Index of the collected episode to render. Collection above stops after a
# couple of episodes, so keep this within range of len(dataset).
EPISODE_IDX = 0
episode_steps = dataset[EPISODE_IDX]

# Every step is a dict with 'observation', 'reward' and 'info' entries, so
# pull the fields out by key; transposing with zip(*episode) would iterate
# over the dict keys rather than the stored values.
observations = [record['observation'] for record in episode_steps]
rewards = [record['reward'] for record in episode_steps]
infos = [record['info'] for record in episode_steps]

# One rendered frame per step.
frames = [observations_to_image(obs, info) for obs, info in zip(observations, infos)]

# Regroup the per-step info dicts into a dict of per-key lists.
disp_info = {k: [info[k] for info in infos] for k in infos[0].keys()}

generate_video(
    video_option=['disk'],
    video_dir='videos/',
    images=frames,
    episode_id="test",
    checkpoint_idx=300,
    metrics=extract_scalars_from_info(disp_info),
    fps=config.habitat_baselines.video_fps,
    tb_writer=None,
    keys_to_include_in_name=config.habitat_baselines.eval_keys_to_include_in_name,
)

2023-09-01 11:12:54,403 Video created: videos/episode=test-ckpt=300-.mp4
100%|████████████████████████████████████████████████████████████████████████████████| 183/183 [00:02<00:00, 66.64it/s]
