In [None]:
# Automatically reload edited imported modules (e.g. babyai) before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import logging
import csv
import json
import gym
import time
import datetime
import torch
import numpy as np
import subprocess

import babyai
import babyai.utils as utils
import babyai.rl
# from babyai.arguments import ArgumentParser
from babyai.model import ACModel
from babyai.evaluate import batch_evaluate
from babyai.utils.agent import ModelAgent
from collections import Counter

In [None]:
# Run configuration shared by the cells below. Keys are inserted in a fixed order.
args = dict(
    seed=666,  # RNG seed used while generating the instruction images
    procs=1,
    env="BabyAI-PickupLocTemplate-v0",
    instr_arch='gru',
)
# Hyphenated key, so it cannot be given as a keyword argument to dict().
args['no-mem'] = False
args.update(
    algo='ppo',
    arch='expert_filmcnn',
    pretrained_model="BabyAI-PickupLoc-v0_template",
    model=None,
)

utils.seed(args['seed'])

##### Environment

In [None]:
# Environment used for rollouts; `simulated_obs` will hold its most recent observation.
simulated_env, simulated_obs = gym.make(args['env']), None

##### Observation preprocessor

In [None]:
# Turns raw environment observations into the inputs fed to the pretrained agent.
# NOTE(review): the meaning of the second positional argument (passed as None) is defined
# by babyai's ObssPreprocessor — confirm there.
obss_preprocessor = utils.ObssPreprocessor(
    args['pretrained_model'],
    None,
)

##### Pretrained agent

In [None]:
# Load the pretrained policy by name (raising if it is missing) and put it in eval mode.
pretrained_agent = utils.load_model(
    args['pretrained_model'],
    raise_not_found=True,
)
pretrained_agent.eval()

##### Image generation

In [None]:
def full_obs(env):
    """One-hot encode the fully observable grid of ``env``, with the agent stamped in.

    ``env.grid.encode()`` is indexed ``[x, y, channel]`` with three channels
    (object index, color index, state). The agent's cell is overwritten with
    ``(10, 0, env.agent_dir)``. Each channel is then one-hot encoded into its own
    slice of a 21-wide feature axis:

    * channel 0 (object) -> features starting at 0
    * channel 1 (color)  -> features starting at 11
    * channel 2 (state)  -> features starting at 17

    Args:
        env: environment exposing ``grid.encode()``, ``agent_pos`` and ``agent_dir``.

    Returns:
        float ``np.ndarray`` of shape ``(encoded.shape[0], encoded.shape[1], 21)``
        containing 0/1 values.

    Raises:
        IndexError: if an encoded value lands past feature index 20.
    """
    encoded = env.grid.encode()

    # Stamp the agent: object=10, color=0, state=facing direction.
    # NOTE(review): 10 presumably equals gym_minigrid's OBJECT_TO_IDX['agent'] — confirm.
    agent_x, agent_y = env.agent_pos[0], env.agent_pos[1]
    encoded[agent_x, agent_y] = np.array([10, 0, env.agent_dir])

    # First feature index of each channel's one-hot slice.
    channel_offsets = np.array([0, 11, 17])
    n_features = 21

    one_hot = np.zeros(encoded.shape[:2] + (n_features,))
    xs, ys, _ = np.indices(encoded.shape)
    # Vectorized form of: one_hot[x, y, offset[ch] + encoded[x, y, ch]] = 1 for all x, y, ch.
    # Cast first so the addition cannot wrap in the encoder's small integer dtype.
    one_hot[xs, ys, encoded.astype(np.int64) + channel_offsets] = 1
    return one_hot

In [None]:
def rgb_obs(env):
    """Render ``env`` as an RGB array using 8-pixel tiles and ``highlight=False``.

    Args:
        env: environment whose ``render`` accepts ``mode``, ``highlight`` and
            ``tile_size`` keyword arguments.

    Returns:
        Whatever ``env.render`` returns for ``mode='rgb_array'``.
    """
    render_kwargs = dict(mode='rgb_array', highlight=False, tile_size=8)
    return env.render(**render_kwargs)

In [None]:
from gym_minigrid.window import Window
def redraw(env, win=None):
    """Render ``env`` with 32-pixel tiles and show the image in a window.

    Args:
        env: environment whose ``render('rgb_array', tile_size=...)`` returns an image.
        win: object with a ``show_img`` method (e.g. a gym_minigrid ``Window``).
            When omitted, the notebook-global ``window`` is used, as before; it must
            already exist or a ``NameError`` is raised.
    """
    img = env.render('rgb_array', tile_size=32)
    # Explicit argument avoids a hidden dependency on global notebook state.
    target = window if win is None else win
    target.show_img(img)

In [None]:
# Use the GPU when available and fall back to the CPU, so the notebook also runs on
# CPU-only machines (a hard-coded 'cuda' crashes there).
# NOTE(review): pretrained_agent is never moved with .to(device); this assumes
# utils.load_model already placed it on the matching device — confirm.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Number of successful episodes saved so far, keyed by mission string.
visited_mission = Counter()

In [None]:
# Roll out the pretrained agent on freshly reset episodes. For every mission string keep
# up to MAX_IMAGES_PER_MISSION successful episodes (reward > 0); each saved sample stacks
# the one-hot full-grid observation at the start and at the end of the episode.
MAX_IMAGES_PER_MISSION = 100
N_RESETS = 1_000_000
OUTPUT_DIR = 'instruction_images'

for _ in range(N_RESETS):
    simulated_obs = simulated_env.reset()
    mission = simulated_obs['mission']

    # Skip saturated missions; a Counter returns 0 for unseen keys.
    if visited_mission[mission] >= MAX_IMAGES_PER_MISSION:
        continue

    # To visualize an episode: window = Window('gym_minigrid - ' + args['env']); redraw(simulated_env)
    img_instr = [full_obs(simulated_env)]

    # Simulate until done; the recurrent memory starts from zeros every episode.
    # (No done-mask is needed: the loop exits as soon as the episode ends.)
    memory = torch.zeros(1, pretrained_agent.memory_size, device=device)
    done = False
    while not done:
        preprocessed_obs = obss_preprocessor([simulated_obs], device=device)
        with torch.no_grad():
            model_results = pretrained_agent(preprocessed_obs, memory)
        action = model_results['dist'].sample()
        # Pass a plain Python scalar rather than a 1-element ndarray.
        simulated_obs, reward, done, _ = simulated_env.step(action.item())
        memory = model_results['memory']

    img_instr.append(full_obs(simulated_env))
    img_instr = np.stack(img_instr, axis=0)

    if reward > 0:
        visited_mission[mission] += 1
        mission_dir = os.path.join(OUTPUT_DIR, mission)
        os.makedirs(mission_dir, exist_ok=True)
        np.save(os.path.join(mission_dir, f'{visited_mission[mission]}.npy'), img_instr)

        print(f"save image for '{mission}': {visited_mission[mission]}/{MAX_IMAGES_PER_MISSION}")

### Combine pairs of saved episode images for each instruction

In [None]:
import os
import numpy as np

In [None]:
# One sub-directory per mission under instruction_images. Sorted so processing order is
# reproducible (os.listdir order is arbitrary); non-directory entries are skipped so the
# next cell does not try to list a plain file.
inst_dirs = sorted(
    name for name in os.listdir("instruction_images")
    if os.path.isdir(os.path.join("instruction_images", name))
)
# Peek at the first mission (a slice, so an empty list does not raise).
inst_dirs[:1]

In [None]:
%%time
# For each mission, concatenate every ordered pair of distinct saved episodes (i != j)
# along the last (feature) axis and save them as 1.npy, 2.npy, ...
SRC_ROOT = "instruction_images"
DST_ROOT = "instruction_images_multiple"

for instr_dir in inst_dirs:
    src_dir = os.path.join(SRC_ROOT, instr_dir)
    dst_dir = os.path.join(DST_ROOT, instr_dir)

    files = sorted(f for f in os.listdir(src_dir) if f.endswith(".npy"))
    if len(files) < 2:
        continue  # no pair of distinct episodes exists

    # makedirs also creates DST_ROOT itself; os.mkdir would raise if it were missing.
    os.makedirs(dst_dir, exist_ok=True)

    # Load each episode once instead of twice per pair.
    images = [np.load(os.path.join(src_dir, f)) for f in files]

    cnt = 1
    for i, im1 in enumerate(images):
        # Distinct loop variables: previously both arrays came from the same file.
        for j, im2 in enumerate(images):
            if i == j:
                continue
            im_cat = np.concatenate([im1, im2], axis=-1)
            np.save(os.path.join(dst_dir, f"{cnt}.npy"), im_cat)
            cnt += 1
    print(f"{instr_dir}: saved {cnt - 1} pairs")