# Import Libraries

In [None]:
import copy
import utils
import torch
import constants
import numpy as np
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

from env import Env
from agent import Agent
from torchsummary import summary
from torch.distributions import Normal, Categorical

# Initialise Environment

In [None]:
# --- Environment setup -------------------------------------------------------
# Axis-aligned workspace box, one [min, max] row per axis (x, y, z).
# The x and y ranges are written as centre -/+ half-width.
min_x, max_x = -0.110 - 0.150, -0.110 + 0.150
min_y, max_y = 0.560 - 0.125, 0.560 + 0.125
min_z, max_z = 0, 0.4

workspace_lim = np.asarray([
    (min_x, max_x),
    (min_y, max_y),
    (min_z, max_z),
])
print(f"workspace space: \n{workspace_lim}")

# Object directory and object count, passed positionally to Env together with
# the workspace box above.
obj_dir = 'objects/blocks/'
N_obj = 5

env = Env(obj_dir, N_obj, workspace_lim, cluttered_mode=True, is_debug=False)

# Initialise Agent

In [None]:
# Root directory for all agent/model checkpoints and experience buffers.
# This path is machine-specific: change it here once rather than in every
# directory argument below.
log_root = "/media/ryan/Seagate/research_proj_backup/research_2.0/logs"

agent = Agent(
    env,
    max_memory_size=100000,
    max_memory_size_rl=200000,
    max_memory_size_hld=50000,
    is_debug=True,
    N_batch=512,
    N_batch_hld=512,
    lr=1e-4,
    hld_lr=1e-4,
    tau=0.05,
    tau_hld=0.05,
    max_action_taken=50,
    max_result_window=500,
    max_result_window_hld=250,
    max_result_window_eval=100,
    max_stage1_episode=200,
    N_grasp_step=25,  # define the maximum step for low-level grasping network
    N_push_step=25,  # define the maximum step for low-level pushing network
    success_rate_threshold=0.7,
    # Each directory is a fixed sub-folder of log_root; the resulting strings
    # are identical to the previously hard-coded absolute paths.
    checkpt_dir_agent=f"{log_root}/agent",
    checkpt_dir_models=f"{log_root}/models",
    exp_dir_expert=f"{log_root}/exp_expert",
    exp_dir_rl=f"{log_root}/exp_rl",
    exp_dir_hld=f"{log_root}/exp_hld",
)

# Interact

In [4]:
# from buffer import BufferReplay
# from buffer_hld import BufferReplay_HLD
# import os
# import pickle

In [5]:
# buffer_replay = BufferReplay(max_memory_size = agent.max_memory_size_rl, checkpt_dir = agent.exp_dir_rl)

In [6]:
# buffer_replay_expert = BufferReplay(max_memory_size = agent.max_memory_size, checkpt_dir = agent.exp_dir_expert)

In [7]:
# buffer_replay_hld = BufferReplay_HLD(max_memory_size = int(agent.max_memory_size_hld), checkpt_dir = agent.exp_dir_hld)

In [8]:
# #save memory counter
# data_dict_mem_counter = {'memory_cntr': 97178}
# file_name = os.path.join(buffer_replay.checkpt_dir, "memory_cntr.pkl")

# with open(file_name, 'wb') as file:
#     pickle.dump(data_dict_mem_counter, file)

In [None]:
# Run the agent's interaction loop for 1800 episodes, outside evaluation mode
# (is_eval=False), using the HLD and BC_RL modes defined in `constants`.
agent.interact(
    max_episode=1800,
    hld_mode=constants.HLD_MODE,
    lla_mode=constants.BC_RL,
    is_eval=False,
)

In [10]:
# Leave evaluation mode first, then reload the agent's saved data.
# NOTE(review): the flag is set before loading, presumably because
# load_agent_data() uses is_eval to pick which records to restore — confirm
# in Agent.load_agent_data.
agent.is_eval = False
agent.load_agent_data()

In [None]:
# Scatter-plot the raw training completion record, one 'o' marker per entry
# (presumably one entry per training episode — confirm against Agent).
plt.plot(np.array(agent.complete_record_train), 'o')

In [None]:
# Mean "actions taken for completion" (ATC) over the training record.
# Only strictly positive entries are averaged; presumably non-positive
# entries mark episodes without a completion — confirm against Agent.
# (The previous version summed *all* entries but divided by the count of
# positive ones; the two agree whenever the record has no negative values.)
atc_records = np.asarray(agent.action_taken_record_train)
positive_atc = atc_records[atc_records > 0]
# Explicit guard: with no positive entries the old 0/0 yielded nan together
# with a numpy RuntimeWarning.
ATC_mean_train = positive_atc.mean() if positive_atc.size else float('nan')
print(f"[HLD ATC] ATC mean: {ATC_mean_train}/{agent.best_ATC_mean_train} [{agent.max_result_window_hld}]")

plt.plot(agent.action_taken_record_train, '-.')

In [None]:
# Side-by-side training curves: completion rate (left, in %) and
# actions taken for completion (right). figsize is (width, height) in inches,
# equivalent to the previous set_figwidth(10) / set_figheight(5) calls.
fig, ax = plt.subplots(1, 2, figsize=(10, 5))

ax[0].plot(np.array(agent.CR_train) * 100., '-.')
ax[0].set_xlabel('Number of Episodes in stage 3 training')
ax[0].set_ylabel('Moving Average of Completion Rate (%)')  # fixed "Moveing" typo

ax[1].plot(agent.ATC_train, '-.')
ax[1].set_xlabel('Number of Episodes in stage 3 training')
ax[1].set_ylabel('Moving Average of Actions Taken for Completion')

fig.tight_layout()
# print(f'max ATC: {np.array(agent.ATC_train[250:]).max()}')
# print(f'min ATC: {np.array(agent.ATC_train[250:]).min()}')

In [None]:
# Side-by-side success-rate histories (in %) for the grasping and pushing
# networks, plus the best value seen in each history.
fig, ax = plt.subplots(1, 2, figsize=(10, 5))

success_histories = (
    ('grasp', 'Grasping', agent.grasp_success_rate_hist),
    ('push', 'Pushing', agent.push_success_rate_hist),
)
for axis, (short_name, network_name, history) in zip(ax, success_histories):
    rates_pct = np.array(history) * 100.
    axis.plot(rates_pct)
    axis.set_xlabel(f'Attempts Made by {network_name} Network')
    axis.set_ylabel('Moving Average Success Rate (%)')
    # .max() on an empty array raises ValueError; report "n/a" instead so the
    # cell still renders when a network has not been used yet.
    best = f'{rates_pct.max()}%' if rates_pct.size else 'n/a (no attempts recorded)'
    print(f'best {short_name} success rate: {best}')

In [None]:
# Display agent.episode as the cell output (presumably the number of episodes
# run so far — confirm against Agent).
agent.episode