In [None]:
from plot import PlotEvalData, plot_eval
from agent import Agent
from config import Config
import matplotlib.pyplot as plt
import shutil
import os
import torch
# Report whether torch can see a CUDA device before any training starts.
print(torch.cuda.is_available())

# Environment ids used by the cells below: EMPTY_ENV for the base model,
# GO_TO_OBJ_ENV for the control run and both transfer-learning runs.
EMPTY_ENV = "MiniGrid-Empty-Random-6x6-v0"
GO_TO_OBJ_ENV = "MiniGrid-GoToObject-6x6-N2-v0"

In [None]:
# Control run: A2C on the go-to-object task with no transfer learning.
agent = Agent(GO_TO_OBJ_ENV, "a2c_fetch_control", num_envs=16)

# Take a baseline evaluation only if this model has not been trained at all.
if agent.frames_trained() == 0:
    agent.eval(100)

# 50 rounds; the value passed to train() grows by 1e4 each round
# (presumably a cumulative frame target -- confirm against Agent.train).
# An evaluation is run after every round for which train() returns a truthy value.
for round_no in range(1, 51):
    frame_target = 1e4 * round_no
    did_train = agent.train(frame_target, 'a2c')
    if did_train:
        agent.eval(100)

In [None]:
# Base model for transfer learning: A2C on the empty environment.
# No evaluations are recorded for this run; only its checkpoint is reused later.
agent = Agent(EMPTY_ENV, "a2c_empty", num_envs=2)
for round_no in range(1, 51):
    # Same 1e4-per-round schedule as the control run (1e4 ... 5e5).
    agent.train(1e4 * round_no, 'a2c')

In [None]:
# Setup TL models: seed each transfer-learning run with the base model's
# checkpoint.
#
# A destination that already exists is left untouched, so re-running this cell
# (e.g. Restart Kernel -> Run All) does not silently replace a TL run that has
# already been trained with the untrained base checkpoint. To reset a TL run,
# delete its directory under storage/ and run the cell again.
base_status_path = "storage/a2c_empty/status.pt"
for tl_model_name in ("a2c_fetch_tl", "a2c_fetch_tl_w_rs"):
    tl_status_path = f"storage/{tl_model_name}/status.pt"
    if os.path.exists(tl_status_path):
        print(f"{tl_status_path} already exists; leaving it untouched")
        continue
    os.makedirs(os.path.dirname(tl_status_path), exist_ok=True)
    # copy2 also preserves file metadata such as timestamps.
    shutil.copy2(base_status_path, tl_status_path)

In [None]:
import numpy as np
import minigrid.core.constants as constants

# Upper bound applied to the shaped reward returned by reshape_reward.
MAX_REWARD = 1

# Vocabulary that reshape_reward searches for inside the mission string.
COLOR_NAMES = constants.COLOR_NAMES
OBJECT_NAMES = sorted(constants.OBJECT_TO_IDX)

# High-level action vocabulary; not referenced by the other cells visible here.
ACTION_NAMES = ["get"]
ACTION_TO_ACTIONS = {"get": [0, 1, 2]}

def reshape_reward(obs, action, reward, done, agent_pos=(3, 4)):
    """Add a small proximity bonus to ``reward`` while the mission target is visible.

    The target colour and object type are guessed from ``obs['mission']`` by
    substring match against ``COLOR_NAMES`` and ``OBJECT_NAMES`` (first match
    wins). The first cell of ``obs['image']`` equal to
    ``[object_idx, color_idx, 0]`` is taken as the target, and a bonus of
    ``1 / (50 * manhattan_distance)`` from ``agent_pos`` is added.

    Args:
        obs: Mapping with a ``'mission'`` entry (string or ``None``) and an
            ``'image'`` entry iterated as ``image[x][y] -> cell``.
        action: Unused by this function.
        reward: Reward for the current step.
        done: Whether the episode finished on this step.
        agent_pos: ``(x, y)`` cell the distance is measured from. Defaults to
            ``(3, 4)``.

    Returns:
        ``reward`` unchanged when ``done`` is true, when the mission is
        ``None``, or when no known colour/object name occurs in the mission;
        otherwise the (possibly increased) reward capped at ``MAX_REWARD``.
    """
    # no need to reshape if done
    if done:
        return reward

    # guess target tile
    mission = obs['mission']
    if mission is None:
        return reward
    target_color = next((color for color in COLOR_NAMES if color in mission), None)
    target_obj = next((obj for obj in OBJECT_NAMES if obj in mission), None)
    if target_color is None or target_obj is None:
        # The mission names no recognisable colour/object, so there is nothing
        # to search for; previously this path hit an unbound local variable.
        return reward
    # Third component is fixed to 0 -- presumably the cell's state channel.
    target = np.array([
        constants.OBJECT_TO_IDX[target_obj],
        constants.COLOR_TO_IDX[target_color],
        0,
    ])

    # find target tile (first match, scanning x then y)
    target_pos = None
    for x, column in enumerate(obs['image']):
        for y, cell in enumerate(column):
            if np.array_equal(cell, target):
                target_pos = np.array([x, y])
                break
        if target_pos is not None:
            break

    if target_pos is not None:
        # give reward based on distance to target
        # NOTE(review): the default agent_pos of (3, 4) is kept for backward
        # compatibility; MiniGrid's default 7x7 egocentric view is believed to
        # place the agent at (3, 6) -- confirm which cell is intended.
        man_dist = find_manhattan_distance(target_pos, np.array(agent_pos))
        if man_dist > 0:  # also guards the division below
            reward += 1/(man_dist*50)
    return MAX_REWARD if reward > MAX_REWARD else reward

def find_manhattan_distance(p1, p2):
    """Return the Manhattan (L1) distance between ``p1`` and ``p2``.

    The points are subtracted element-wise with ``-``, so both are expected
    to be NumPy arrays (or scalars) of compatible shape.
    """
    offsets = p1 - p2
    return np.sum(np.absolute(offsets))

In [None]:
# Transfer-learning-only run: continue from the copied base checkpoint on the
# go-to-object task.
agent = Agent(GO_TO_OBJ_ENV, "a2c_fetch_tl", num_envs=16)

# Evaluate once before any further training. Unlike the control cell this
# evaluation is unconditional, so it runs every time the cell is executed.
agent.eval(100)

# Rounds 51..100 of the 1e4-per-round schedule (5.1e5 ... 1e6), i.e. the
# schedule picks up where the base model's 50 rounds ended.
for round_no in range(51, 101):
    did_train = agent.train(1e4 * round_no, 'a2c')
    if did_train:
        agent.eval(100)

In [None]:
# Transfer learning + reward shaping: same schedule as the TL-only run, but
# every training round is configured with the reshape_reward callback.
agent = Agent(GO_TO_OBJ_ENV, "a2c_fetch_tl_w_rs", num_envs=16)

# Unconditional evaluation before further training.
agent.eval(100)

for round_no in range(51, 101):
    # A fresh Config is built for every round.
    shaping_config = Config(reshape_reward=reshape_reward)
    did_train = agent.train(1e4 * round_no, 'a2c', algo_config=shaping_config)
    if did_train:
        agent.eval(100)

In [None]:
# Compare the evaluation curves of the three go-to-object runs.
# Both TL curves are shifted by 500000 frames -- presumably to account for the
# 50 x 1e4 frames spent training the base model; confirm against plot_eval.
eval_series = [
    PlotEvalData(
        "storage/a2c_fetch_control/eval.csv",
        "Control",
        show_min_max_fill=False,
    ),
    PlotEvalData(
        "storage/a2c_fetch_tl/eval.csv",
        "TL",
        color='red',
        frame_offset=500000,
        show_min_max_fill=False,
    ),
    PlotEvalData(
        "storage/a2c_fetch_tl_w_rs/eval.csv",
        "TL w/ RS",
        color='pink',
        frame_offset=500000,
        show_min_max_fill=False,
    ),
]
plot_eval(eval_series, title="Go To Object Using A2C")
plt.show()