In [None]:
import copy
import utils
import torch
import numpy as np
import matplotlib.pyplot as plt

from env import Env
from agent import Agent
from torchsummary import summary

# Initialise Environment

In [None]:
# Workspace limits: x and y share the same half-width around their centres,
# z spans 0 .. 0.4. Units are presumably metres (TODO confirm against Env).
ws_center_x, ws_center_y, ws_half_xy = -0.110, 0.535, 0.175

min_x, max_x = ws_center_x - ws_half_xy, ws_center_x + ws_half_xy
min_y, max_y = ws_center_y - ws_half_xy, ws_center_y + ws_half_xy
min_z, max_z = 0, 0.4

# one row per axis: [[min_x, max_x], [min_y, max_y], [min_z, max_z]]
workspace_lim = np.asarray([
    [min_x, max_x],
    [min_y, max_y],
    [min_z, max_z],
])
print(f"workspace space: \n{workspace_lim}")

# directory of object assets and how many objects the environment is built with
obj_dir = 'objects/blocks/'
N_obj = 4

env = Env(obj_dir, N_obj, workspace_lim)

# Test Environment Reset

In [None]:
# Reset the environment with reset_obj disabled (exact semantics live in Env.reset — TODO confirm).
env.reset(reset_obj=False)

# Initialise Agent

In [None]:
# Build the agent on top of the environment; N_batch is presumably the batch size (TODO confirm in Agent).
agent = Agent(env, N_batch=8)

# Check Guidance

In [None]:
# Generate push guidance and print the gripper positions it would visit.
agent.is_debug = True
delta_move = agent.env.push_guidance_generation(max_move = 0.05)

# Fixed start position used to accumulate the guidance deltas.
# NOTE(review): hard-coded, not read from the environment — the printed
# positions only match reality if the gripper actually starts here.
GRIPPER_START_POS = np.array([-0.11122626281611381, 0.4855598136140757, 0.2684023847637833])

# Each guidance step: the first three entries are added to the position,
# the last entry is printed as the step `type`.
gripper_pos = GRIPPER_START_POS.copy()  # copy so the constant is never mutated in place
for move in delta_move:
    gripper_pos += np.array(move[0:3])
    print(f'gripper_pos: {gripper_pos}, type: {move[-1]}')

In [14]:
# Run a single guided episode with grasp guidance disabled.
agent.interact_by_guidance(max_episode=1, grasp_guidance=False)

# Check Encoder

In [None]:
# Layer-by-layer summary of the encoder for a (1, 128, 128) input
# (torchsummary's input_size excludes the batch dimension).
# torchsummary defaults to device="cuda", which crashes when the encoder is on CPU,
# so pass the device type the encoder's parameters actually live on.
encoder_device = next(agent.encoder.parameters()).device.type
summary(agent.encoder, input_size=(1, 128, 128), device=encoder_device)

In [None]:
# Grab one RGB-D frame, preprocess it, and run the depth image through the encoder.
color_img, depth_img = agent.env.get_rgbd_data()
print(f'dmin: {np.min(depth_img)}, dmax: {np.max(depth_img)}')

# preprocess raw data into tensors
in_color_img, in_depth_img = agent.preprocess_input(color_img, depth_img)
print(in_color_img.shape)
print(in_depth_img.shape)

# add a leading batch dimension
in_color_img = in_color_img.unsqueeze(0)
in_depth_img = in_depth_img.unsqueeze(0)
print(in_color_img.shape)
print(in_depth_img.shape)

# only the depth image is fed to the encoder; it returns (latent vector, reconstruction)
with torch.no_grad():
    latent_vector, reconstructed = agent.encoder(in_depth_img)

print(f'dmin: {torch.min(in_depth_img)}, dmax: {torch.max(in_depth_img)}')
print('latent vector shape: ', latent_vector.shape)
print('reconstructed shape: ', reconstructed.shape)

# Show the preprocessed depth image. Move it to CPU first: matplotlib converts via
# numpy, which fails for tensors on a CUDA device. permute assumes a channel-first
# (C, H, W) tensor (TODO confirm preprocess_input layout) and gives imshow (H, W, C).
fig, ax = plt.subplots()
ax.imshow(in_depth_img[0].detach().cpu().permute(1, 2, 0))
ax.set_title('preprocessed depth (encoder input)')
plt.show()

# Check Actor

In [None]:
# Query the actor with the encoder's latent vector; only the action and its type are inspected.
with torch.no_grad():
    actor_outputs = agent.actor.get_actions(latent_vector)
a, a_type, z, normal, a_type_probs = actor_outputs

print(f"action: {a}, action_type: {a_type}")

# Check Critic

In [None]:
# Feed the sampled action into both online critics and both target critics.
with torch.no_grad():

    # The critics take the discrete action type as a one-hot vector.
    # NOTE(review): 3 is assumed to be the number of action types — TODO confirm against Agent.
    a_type_onehot = torch.nn.functional.one_hot(a_type.long(), num_classes = 3).float()

    # all four critics receive exactly the same inputs
    critic_inputs = dict(state = latent_vector, action = a, action_type = a_type_onehot)

    q1 = agent.critic1(**critic_inputs)
    q2 = agent.critic2(**critic_inputs)

    tq1 = agent.critic1_target(**critic_inputs)
    tq2 = agent.critic2_target(**critic_inputs)

print(f"q1: {q1}, q2: {q2}, tq1: {tq1}, tq2: {tq2}")

# Test Raw Data and Input Preprocessing

In [None]:
# Fetch a fresh raw colour/depth frame from the environment.
color_img, depth_img = agent.env.get_rgbd_data()

In [None]:
# Convert the raw frame into the agent's model inputs (no batch dimension is added here).
in_color_img, in_depth_img = agent.preprocess_input(color_img, depth_img)

In [None]:
# shapes of the preprocessed colour and depth inputs, one per line
print(in_color_img.shape, in_depth_img.shape, sep="\n")

In [None]:
# Side-by-side view of the preprocessed depth and colour inputs.
# Tensors are moved to CPU first (matplotlib converts via numpy, which fails for
# CUDA tensors) and permuted to (H, W, C) for imshow — assumes preprocess_input
# returns channel-first (C, H, W) tensors (TODO confirm).
fig, ax = plt.subplots(1, 2)
ax[0].imshow(in_depth_img.detach().cpu().permute((1, 2, 0)))
ax[0].set_title('depth')
ax[1].imshow(in_color_img.detach().cpu().permute((1, 2, 0)))
ax[1].set_title('color')
plt.show()

# Test Agent Interaction

In [None]:
# Run the agent's own interaction loop (behaviour defined in Agent.interact).
agent.interact()