# Import Libraries

In [None]:
import copy
import utils
import torch
import constants
import numpy as np
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.distributions import Normal, Categorical

from env import Env
from agent import Agent
from torchsummary import summary

# Initialise Environment

In [None]:
# Initialise the environment.
# Each bound is written as "centre -/+ half-extent" so the workspace centre
# and size can be read directly off the expression.
# NOTE(review): units are presumably metres — confirm against env.py.
min_x = -0.110 - 0.175
max_x = -0.110 + 0.175
min_y = 0.560 - 0.150
max_y = 0.560 + 0.150
min_z = 0
max_z = 0.4

# One row per axis (x, y, z); columns are (lower bound, upper bound).
workspace_lim = np.asarray((
    (min_x, max_x),
    (min_y, max_y),
    (min_z, max_z),
))

print(f"workspace space: \n{workspace_lim}")

# Object directory and object count handed to the environment.
# NOTE(review): N_obj is presumably the number of objects placed in the
# scene — confirm against Env.
obj_dir = 'objects/blocks/'
N_obj = 10

env = Env(obj_dir, N_obj, workspace_lim)

# Initialise Agent

In [None]:
# Build the agent on top of the environment created above.
# NOTE(review): `lr` and `hld_lr` are two separate learning rates; presumably
# `hld_lr` applies to the HLD network and `lr` to the remaining networks —
# confirm in agent.py.
agent = Agent(
    env,
    max_memory_size=100,
    is_debug=True,
    N_batch=128,
    N_batch_hld=32,
    lr=1e-4,
    hld_lr=1e-6,
)

# Gather Demonstration Experience

In [None]:
# Run the agent's guidance (demonstration) experience collection.
# NOTE(review): presumably this fills the replay buffers that are read when
# building the train/test loaders — confirm in agent.py.
agent.gather_guidance_experience()

# Get Train and Test Loaders

In [None]:
# Load stored experience into the low-level replay buffer.
# NOTE(review): only `buffer_replay` is loaded explicitly here, while
# `buffer_replay_hld` is read below without a matching load call — confirm
# that it is populated elsewhere.
agent.buffer_replay.load_buffer()

# Split the low-level experience by action type; HLD experience comes from
# its own buffer.
grasp_exp = agent.buffer_replay.get_experience_by_action_type(constants.GRASP)
push_exp = agent.buffer_replay.get_experience_by_action_type(constants.PUSH)
hld_exp = agent.buffer_replay_hld.get_experience()

# Report the sample count of each set (length of the first field).
print(f"N_grasp_exp: {len(grasp_exp[0])}")
print(f"N_push_exp: {len(push_exp[0])}")
print(f"N_HLD_exp: {len(hld_exp[0])}")

# Build train/test loader pairs. The call order (HLD, grasp, push) is kept
# as-is in case the splits consume shared random state.
hld_train_loader, hld_test_loader = agent.get_train_test_dataloader_hld_net(
    hld_exp,
    train_ratio=0.9,
)
grasp_train_loader, grasp_test_loader = agent.get_train_test_dataloader(
    grasp_exp,
    is_grasp=True,
    train_ratio=0.9,
)
push_train_loader, push_test_loader = agent.get_train_test_dataloader(
    push_exp,
    is_grasp=False,
    train_ratio=0.9,
)

# HLD-net clone

In [None]:
# # print(hld_exp[0].shape, hld_exp[1].shape, hld_exp[2].shape, hld_exp[3].shape)
# index = 1001
# # print(hld_exp[2][index])
# plt.imshow(grasp_exp[2][index])
# # print(hld_exp[1][5])

In [None]:
# Uncomment the next line to start from a saved HLD-net checkpoint instead of
# the freshly constructed network.
# agent.hld_net.load_checkpoint()

# Behaviour-clone the HLD network on the HLD train/test loaders for 500 epochs.
agent.behaviour_cloning_hld(
    hld_train_loader,
    hld_test_loader,
    agent.hld_net,
    agent.hld_net_target,
    num_epochs=500,
)

# Push Clone

In [None]:
# # agent.push_critic1.load_checkpoint()
# # agent.push_critic2.load_checkpoint()
# # agent.push_actor.load_checkpoint()
# agent.behaviour_cloning(push_train_loader, push_test_loader, 
#                         agent.push_critic1, agent.push_critic2, 
#                         agent.push_critic1_target, agent.push_critic2_target, 
#                         agent.push_actor, num_epochs = 500, is_grasp = False)

# Grasp Clone

In [None]:
# # agent.grasp_critic1.load_checkpoint()
# # agent.grasp_critic2.load_checkpoint()
# # agent.grasp_actor.load_checkpoint()
# agent.behaviour_cloning(grasp_train_loader, grasp_test_loader, 
#                         agent.grasp_critic1, agent.grasp_critic2, 
#                         agent.grasp_critic1_target, agent.grasp_critic2_target, 
#                         agent.grasp_actor, num_epochs = 500, is_grasp = True)