# Import Libraries

In [None]:
import copy
import utils
import torch
import constants
import numpy as np
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

from env import Env
from agent import Agent
from torchsummary import summary
from torch.distributions import Normal, Categorical

# Initialise Environment

In [None]:
# Build the workspace limits and create the environment.
# x and y limits are written as a centre value plus/minus a half-extent;
# z is given directly as a [min, max] pair.
x_mid, x_half = -0.110, 0.150
y_mid, y_half = 0.560, 0.125

min_x, max_x = x_mid - x_half, x_mid + x_half
min_y, max_y = y_mid - y_half, y_mid + y_half
min_z, max_z = 0, 0.4

# One row per axis (x, y, z), columns are [min, max].
workspace_lim = np.array([
    [min_x, max_x],
    [min_y, max_y],
    [min_z, max_z],
])
print(f"workspace space: \n{workspace_lim}")

# Object assets and how many objects to pass to the environment.
obj_dir = 'objects/blocks/'
N_obj = 5

env = Env(
    obj_dir,
    N_obj,
    workspace_lim,
    cluttered_mode=True,
    is_debug=False,
)

# Initialise Agent

In [None]:
# Root directory shared by every checkpoint / experience folder below.
# Kept in one place so moving the logs only requires editing this line.
log_root = "/media/ryan/Seagate/research_proj_backup/research_2.0/logs"

# Create the agent bound to the environment built above.
# NOTE(review): the meaning of each hyper-parameter is defined by
# agent.Agent; only the values are visible here.
agent = Agent(env,
              # replay-memory capacities
              max_memory_size = 100000,
              max_memory_size_rl = 200000,
              max_memory_size_hld = 50000,
              is_debug = True,
              # batch sizes, learning rates and tau values
              N_batch = 512,
              N_batch_hld = 512,
              lr = 1e-4,
              hld_lr = 1e-4,
              tau = 0.05,
              tau_hld = 0.05,
              max_action_taken = 50,
              # result-window lengths (the eval window is also used as the
              # denominator of the completion rate computed later on)
              max_result_window = 500,
              max_result_window_hld = 250,
              max_result_window_eval = 100,
              max_stage1_episode = 200,
              N_grasp_step = 25, #define the maximum step for low-level grasping network
              N_push_step = 25, #define the maximum step for low-level pushing network
              success_rate_threshold = 0.7,
              # output locations, all under log_root
              checkpt_dir_agent=f"{log_root}/agent",
              checkpt_dir_models=f"{log_root}/models",
              exp_dir_expert=f"{log_root}/exp_expert",
              exp_dir_rl=f"{log_root}/exp_rl",
              exp_dir_hld=f"{log_root}/exp_hld")

# Interact

In [None]:
# agent.interact(max_episode = 100, 
#                lla_mode = constants.BC_RL,
#                hld_mode = constants.DEMO_MODE,
#                is_eval = True)

In [5]:
# Configure the agent for evaluation and load its previously saved data.
# The two attributes are set before the load call.
# NOTE(review): presumably load_agent_data() reads hld_mode / is_eval to
# decide what to load - confirm against agent.Agent before reordering.
agent.hld_mode = constants.HLD_MODE
agent.is_eval  = True
agent.load_agent_data()

In [None]:
# Completion rate over the evaluation window: the summed CR entries divided
# by the configured evaluation window length (not by len(agent.CR_eval)).
cr_total = np.sum(agent.CR_eval)
CR_rate = cr_total / agent.max_result_window_eval
print(f"completion rate (eval): {CR_rate*100.}%")

# Per-episode CR values.
plt.plot(agent.CR_eval)

In [None]:
# Mean AGS over the evaluation episodes, counting only episodes whose CR
# entry is positive: AGS values are weighted by CR and the sum is divided by
# the number of episodes with CR > 0.
cr_eval = np.asarray(agent.CR_eval)
n_completed = (cr_eval > 0).sum()

if n_completed > 0:
    AGS_rate = np.sum(np.asarray(agent.AGS_eval) * cr_eval) / n_completed
else:
    # No episode has CR > 0, so the average is undefined. Report NaN
    # explicitly instead of triggering a 0/0 RuntimeWarning.
    AGS_rate = float("nan")
print(f"AGS rate (eval): {AGS_rate*100.}%")

# Per-episode AGS values (all episodes, not only those with CR > 0).
plt.plot(agent.AGS_eval)

In [None]:
# Mean ATC over the evaluation episodes, counting only episodes whose CR
# entry is positive: ATC values are weighted by CR and the sum is divided by
# the number of episodes with CR > 0.
cr_eval = np.asarray(agent.CR_eval)
n_completed = (cr_eval > 0).sum()

if n_completed > 0:
    ATC_rate = np.sum(np.asarray(agent.ATC_eval) * cr_eval) / n_completed
else:
    # No episode has CR > 0, so the average is undefined. Report NaN
    # explicitly instead of triggering a 0/0 RuntimeWarning.
    ATC_rate = float("nan")
print(f"ATC rate (eval): {ATC_rate}")

# Per-episode ATC values (all episodes), with a fixed y-range for display.
plt.plot(agent.ATC_eval, '--')
plt.ylim([0, 25])

In [None]:
# Number of evaluation episodes whose CR entry is positive. This is the
# denominator used by the AGS / ATC averages.
completed_mask = np.asarray(agent.CR_eval) > 0
completed_mask.sum()