In [None]:
import os

from edf.pc_utils import draw_geometry, voxel_filter
from edf.data import PointCloud, SE3, TargetPoseDemo, DemoSequence, DemoSeqDataset
from edf.preprocess import Rescale, NormalizeColor, Downsample
from edf.agent import PickAgent

import numpy as np
import yaml
import plotly as pl
import plotly.express as ple
import open3d as o3d

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import Compose

# Compact tensor printing: 3 decimals, no scientific notation, wide lines.
torch.set_printoptions(linewidth=120, sci_mode=False, precision=3)

In [None]:
# Use the first GPU when CUDA is available, otherwise fall back to CPU so a
# fresh kernel on a CPU-only machine can still run the notebook end to end.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Length unit. Rescale is built with rescale_factor=1/unit_len, which presumably
# converts coordinates into multiples of unit_len (centimetres for 0.01) — confirm
# against edf.preprocess.Rescale.
unit_len = 0.01

# Transforms applied when a demo is loaded.
load_transforms = Compose([
    Rescale(rescale_factor=1/unit_len),
])
# Transforms presumably used to build each sample's 'processed' entry — TODO confirm
# in DemoSeqDataset. They do voxel downsampling (voxel_size assumed to be in rescaled
# units) followed by per-channel colour normalisation with mean 0.5 / std 0.5.
transforms = Compose([
    Downsample(voxel_size=1.7, coord_reduction="average"),
    NormalizeColor(color_mean=torch.tensor([0.5, 0.5, 0.5]), color_std=torch.tensor([0.5, 0.5, 0.5])),
])

trainset = DemoSeqDataset(dataset_dir="demo/test_demo", annotation_file="data.yaml",
                          load_transforms=load_transforms, transforms=transforms, device=device)
# Identity collate_fn: each batch is the plain list of samples returned by the dataset.
train_dataloader = DataLoader(trainset, shuffle=False, collate_fn=lambda x: x)

In [None]:
# Pick-agent configuration and trained weights.
pick_agent_config_dir = "config/agent_config/pick_agent.yaml"
pick_agent_param_dir = "checkpoint/mug_10_demo/pick/model_iter_600.pt"
max_N_query_pick = 1
langevin_dt_pick = 0.001

# Build the agent from its YAML config on the selected device, then load the checkpoint.
pick_agent_kwargs = dict(
    config_dir=pick_agent_config_dir,
    device=device,
    max_N_query=max_N_query_pick,
    langevin_dt=langevin_dt_pick,
)
pick_agent = PickAgent(**pick_agent_kwargs)
pick_agent.load(pick_agent_param_dir)

In [None]:
# A single query point, shape (1, 3), presumably (x, y, z) in the rescaled frame — confirm.
query_points = torch.tensor([[0.0, 0.0, 15.0]])
query_model = pick_agent.query_model
query_model.query_points = query_points.to(device)

In [None]:
# Take the first sample of the first batch. train_dataloader uses an identity
# collate_fn, so each batch is a list of per-sample dicts with 'raw' and
# 'processed' DemoSequence entries.
train_batch = next(iter(train_dataloader), None)
if not train_batch:
    # Fail loudly instead of leaving demo_seq / demo_seq_raw undefined, or stale
    # from an earlier run, for the cells below.
    raise RuntimeError("train_dataloader yielded no demonstrations")
data = train_batch[0]
demo_seq_raw: DemoSequence = data['raw']
demo_seq: DemoSequence = data['processed']

In [None]:
# First step of the demonstration. The processed scene is what the agent is run on;
# the raw gripper cloud is used later to render candidate grasp poses.
first_step = demo_seq[0]
first_step_raw = demo_seq_raw[0]
scene_pc = first_step.scene_pc
scene_pc_raw, grasp_pc_raw = first_step_raw.scene_pc, first_step_raw.grasp_pc

In [None]:
# query_points_visual = []
# for query_point in pick_agent.query_model.query_points.cpu():
#     mesh_sphere = o3d.geometry.TriangleMesh.create_sphere(radius=0.5)
#     mesh_sphere.compute_vertex_normals()
#     mesh_sphere.paint_uniform_color([0.7, 0.1, 0.1])
#     mesh_sphere.translate(query_point)
#     query_points_visual.append(mesh_sphere)

# draw_geometry([grasp_pc_raw] + query_points_visual)

In [None]:
# Hyper-parameters forwarded to pick_agent.forward; meanings follow that
# method's keyword names (T_seed, mh_iter, langevin_iter, optim_iter, ...).
T_seed = 100
pick_policy = 'sorted'
pick_mh_iter, pick_langevin_iter, pick_optim_iter = 1000, 300, 100
pick_dist_temp = pick_policy_temp = 1.0   # sampling / policy temperatures
pick_optim_lr = 0.005

In [None]:
# Run the pick agent on the processed scene; returns candidate poses Ts plus
# auxiliary EDF outputs and logs.
pick_forward_kwargs = dict(
    T_seed=T_seed,
    policy=pick_policy,
    mh_iter=pick_mh_iter,
    langevin_iter=pick_langevin_iter,
    temperature=pick_dist_temp,
    policy_temperature=pick_policy_temp,
    optim_iter=pick_optim_iter,
    optim_lr=pick_optim_lr,
)
Ts, edf_outputs, logs = pick_agent.forward(pc=scene_pc, **pick_forward_kwargs)

In [None]:
# Ts = torch.tensor([[ 5.6906e-01,  2.5444e-01,  5.9052e-01,  5.1256e-01, -2.6001e+00,
#                         1.2955e+01, -3.2376e+00],
#                         [ 5.5517e-01, -5.3268e-02,  8.9534e-02,  8.2519e-01, -5.2198e+00,
#                         -5.1792e+00, -1.1604e+01],
#                         [ 5.7169e-01, -5.0451e-02,  6.9902e-02,  8.1593e-01, -5.0391e+00,
#                         -4.6316e+00, -1.0746e+01],
#                         [ 1.0343e-01,  6.9640e-03, -6.2160e-01,  7.7645e-01, -4.8065e+00,
#                         -6.4388e+00, -3.0775e+00],
#                         [ 6.7271e-01, -5.9370e-01,  1.4978e-01, -4.1539e-01,  1.0469e+01,
#                         -1.8646e-01, -3.4648e+00],
#                         [ 6.7084e-01, -5.9419e-01,  1.5086e-01, -4.1732e-01,  1.0421e+01,
#                         -1.3346e-01, -3.4283e+00],
#                         [ 7.5507e-01,  3.6726e-01,  4.3431e-01,  3.2615e-01,  5.1010e+00,
#                         1.8190e+01, -6.4841e+00],
#                         [ 6.5561e-01, -9.3542e-02,  9.3563e-02,  7.4342e-01, -5.0050e+00,
#                         -4.8485e+00, -1.1373e+01],
#                         [ 6.8008e-01, -2.1937e-03,  4.1542e-02, -7.3196e-01,  4.1787e+00,
#                         1.0144e+00, -9.0048e+00],
#                         [ 6.7905e-01, -7.4067e-03,  3.8115e-02, -7.3306e-01,  4.1386e+00,
#                         8.3174e-01, -9.0079e+00],
#                         [ 6.9599e-01, -1.4323e-02, -3.2042e-02,  7.1720e-01, -4.4354e+00,
#                         -1.9948e-01, -9.8192e+00],
#                         [ 9.3583e-01,  2.2031e-01,  9.0131e-02,  2.5992e-01, -4.5382e+00,
#                         1.0393e+01, -1.3425e+01],
#                         [ 6.9717e-01, -1.7813e-02, -3.2025e-02,  7.1596e-01, -4.3702e+00,
#                         -2.6679e-01, -9.7965e+00],
#                         [ 2.7117e-01,  1.9802e-01, -1.7277e-01, -9.2596e-01,  1.0302e+01,
#                         -8.5894e+00, -1.2280e+01],
#                         [ 9.3703e-01,  2.1456e-01,  9.1950e-02,  2.5977e-01, -4.5465e+00,
#                         1.0223e+01, -1.3487e+01],
#                         [ 9.3975e-01,  2.0704e-01,  9.3043e-02,  2.5563e-01, -4.5019e+00,
#                         1.0040e+01, -1.3580e+01],], device=pick_agent.device)

In [None]:
from pytorch3d.transforms import quaternion_apply, quaternion_multiply, axis_angle_to_quaternion

In [None]:
# Each row of Ts is [quaternion (4) | translation (3)], with pytorch3d's real-first
# quaternion convention. Left-multiply every orientation by a +90 degree rotation
# about the z axis; translations are left untouched.
# NOTE(review): this rebinds Ts, so re-running this cell without re-running the
# pick_agent.forward cell applies the rotation a second time.
yaw_quat = axis_angle_to_quaternion(axis_angle=torch.tensor([0, 0, torch.pi/2], device=Ts.device))
t = Ts[..., 4:]
q = quaternion_multiply(yaw_quat, Ts[..., :4])
Ts = torch.cat([q, t], dim=-1)

In [None]:
# Place the raw gripper cloud at each of the first few candidate poses (presumably
# the best-ranked ones when pick_policy == 'sorted'); the result is a list of clouds.
n_preview_grasps = 10
grasps = grasp_pc_raw.transformed(Ts[:n_preview_grasps].to(device))

In [None]:
# Render the scene together with the posed gripper clouds.
# NOTE(review): this draws the processed scene (colour-normalised to roughly [-1, 1])
# while scene_pc_raw is unused. Confirm whether the raw scene was intended here.
scene_and_grasps = [scene_pc] + grasps
draw_geometry(scene_and_grasps)

In [None]:
# query_points_visual = []
# for T in Ts:
#     mesh_sphere = o3d.geometry.TriangleMesh.create_sphere(radius=0.5)
#     mesh_sphere.compute_vertex_normals()
#     mesh_sphere.paint_uniform_color([0.7, 0.1, 0.1])
#     mesh_sphere.translate(T[4:].cpu())
#     query_points_visual.append(mesh_sphere)

# draw_geometry([scene_pc] + query_points_visual)