In [None]:
%load_ext dotenv
%dotenv

from ros_edf.ros_interface import EdfRosInterface
from ros_edf.pc_utils import pcd_from_numpy, draw_geometry, reconstruct_surface
from edf.data import PointCloud, SE3, TargetPoseDemo, DemoSequence, save_demos
from edf.pc_utils import check_pcd_collision, optimize_pcd_collision

from edf.data import PointCloud, SE3, TargetPoseDemo, DemoSequence, DemoSeqDataset, save_demos
from edf.pc_utils import optimize_pcd_collision, draw_geometry, check_pcd_collision
from edf.preprocess import Rescale, NormalizeColor, Downsample, ApplySE3
from edf.agent import PickAgent, PlaceAgent

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import Compose

torch.set_printoptions(precision= 3, sci_mode=False, linewidth=120)

# Initialize EDF

In [None]:
##### Global settings #####
device = 'cuda:0'
# device = 'cpu'
# Scale between raw observations and model units: point clouds are multiplied by
# 1/unit_len before entering the agents, and sampled poses are scaled back by unit_len.
# (Presumably raw data is in metres, making one model unit 1 cm -- TODO confirm.)
unit_len = 0.01


##### Initialize Pick Agent #####
pick_agent_config_dir = "config/agent_config/pick_agent.yaml"
pick_agent_param_dir = "checkpoint/mug_10_demo/pick/model_iter_600.pt"
max_N_query_pick = 1
langevin_dt_pick = 0.001

pick_agent = PickAgent(
    config_dir=pick_agent_config_dir,
    device=device,
    max_N_query=max_N_query_pick,
    langevin_dt=langevin_dt_pick,
)
# Inference only: gradients are switched off for all parameters.
pick_agent = pick_agent.requires_grad_(False)
pick_agent.load(pick_agent_param_dir)
# Warm-up pass; these sizes apply only to the warm-up, not to real queries.
pick_agent.warmup(warmup_iters=10, N_poses=100, N_points_scene=2000)


##### Initialize Place Agent #####
place_agent_config_dir = "config/agent_config/place_agent_dev.yaml"
place_agent_param_dir = "checkpoint/mug_1_demo_dev/place_dev/model_iter_300.pt"
max_N_query_place = 3
langevin_dt_place = 0.001

place_agent = PlaceAgent(
    config_dir=place_agent_config_dir,
    device=device,
    max_N_query=max_N_query_place,
    langevin_dt=langevin_dt_place,
)
place_agent = place_agent.requires_grad_(False)
# NOTE(review): strict=False presumably tolerates checkpoint/model key mismatches
# (this agent uses the *_dev config and checkpoint); confirm nothing needed is skipped.
place_agent.load(place_agent_param_dir, strict=False)
place_agent.warmup(warmup_iters=10, N_poses=100, N_points_scene=1500, N_points_grasp=900)


##### Initialize Preprocessing functions #####
def _make_proc_fn(voxel_size):
    """Build a rescale -> voxel-downsample -> colour-normalise pipeline.

    Each call creates fresh transform objects, so the scene and grasp pipelines
    share no state.
    """
    return Compose([
        Rescale(rescale_factor=1/unit_len),
        Downsample(voxel_size=voxel_size, coord_reduction="average"),
        NormalizeColor(color_mean=torch.tensor([0.5, 0.5, 0.5]), color_std=torch.tensor([0.5, 0.5, 0.5])),
    ])

scene_proc_fn = _make_proc_fn(voxel_size=1.7)
grasp_proc_fn = _make_proc_fn(voxel_size=1.4)
# Inverse of the Rescale step above: maps model-unit poses back to raw units.
recover_scale = Rescale(rescale_factor=unit_len)

# Initialize Robot Interface

In [None]:
# Connect to the robot, working in the "scene" reference frame.
# "AnytimePathShortening" is presumably a MoveIt/OMPL planner id -- confirm.
env_interface = EdfRosInterface(
    reference_frame="scene",
    planner_id="AnytimePathShortening",
)
env_interface.set_planning_time(seconds=5.)
env_interface.reset()

# Pick

### Observe

In [None]:
# Capture the end-effector point cloud first, then the scene point cloud.
_obs_kwargs = dict(obs_type='pointcloud', update=True)
grasp_raw = env_interface.observe_eef(**_obs_kwargs)
scene_raw = env_interface.observe_scene(**_obs_kwargs)

### Sample Pick Pose

In [None]:
##### Preprocess Observations #####
scene_proc = scene_proc_fn(scene_raw).to(device)
# Not consumed by pick_agent.forward below; kept so grasp_proc tracks the latest observation.
grasp_proc = grasp_proc_fn(grasp_raw).to(device)

##### Sample Pick Poses #####
T_seed = 100
pick_policy = 'sorted'
pick_mh_iter = 1000
pick_langevin_iter = 300
pick_dist_temp = 1.
pick_policy_temp = 1.
pick_optim_iter = 100
pick_optim_lr = 0.005

pick_forward_kwargs = dict(
    T_seed=T_seed,
    policy=pick_policy,
    mh_iter=pick_mh_iter,
    langevin_iter=pick_langevin_iter,
    temperature=pick_dist_temp,
    policy_temperature=pick_policy_temp,
    optim_iter=pick_optim_iter,
    optim_lr=pick_optim_lr,
)
Ts, edf_outputs, logs = pick_agent.forward(scene=scene_proc, **pick_forward_kwargs)

# Fixed offset removed from every sampled pose. Presumably the end-effector/gripper
# offset, with the SE3 vector laid out as a quaternion followed by a translation
# -- confirm against edf.data.SE3.
T_eg = SE3([ 0.707,  0.000,  0.000,  0.707,  0.000,  0.000, 0.150])
# Scale the sampled poses back to raw units, then right-multiply by T_eg^-1.
pick_poses = SE3.multiply(recover_scale(SE3(Ts.cpu())), T_eg.inv())

##### Infer pre-pick and post-pick poses #####
# Presumably pushes each gripper pose away from the scene cloud to get a clear
# approach pose; the first return value is unused here.
_, pre_pick_poses = optimize_pcd_collision(
    x=scene_raw, y=grasp_raw, rel_pose=pick_poses,
    cutoff_r=0.05, dt=0.01, eps=1., iters=7,
)
# The gripper retreats to the same pose it approached from.
post_pick_poses = pre_pick_poses

### Execute Pick

In [None]:
# Try the first 10 candidates in order and keep the first one where the gripper
# cloud, placed at the pick pose, does not collide with the scene cloud.
colcheck_r = 0.003 # Should be similar to voxel filter size
for idx in range(10):
    pick_pose, pre_pick_pose, post_pick_pose = pick_poses[idx], pre_pick_poses[idx], post_pick_poses[idx]
    # transformed() returns a list of clouds; [0] is the one for this pose.
    col_check = check_pcd_collision(x=scene_raw, y=grasp_raw.transformed(pick_pose)[0], r=colcheck_r)
    print(f"Pick Pose_{idx} collision-free: {not col_check}")
    if not col_check:
        print("Found collision-free pick pose!")
        break
else:
    # Every candidate collided.
    raise NotImplementedError("No collision-free pick pose found!")

In [None]:
# DEBUG
# draw_geometry([scene_raw] + grasp_raw.transformed(post_pick_pose))

In [None]:
# Move through the pre-pick, pick and post-pick poses of the selected candidate.
pick_result = env_interface.pick(pre_pick_pose, pick_pose, post_pick_pose)
print(f"Pick result: {pick_result}")
if pick_result != "SUCCESS":
    raise NotImplementedError("Pick failed")

env_interface.detach()
env_interface.attach_placeholder() # To avoid collision with the grasped object

# Observe for Place

In [None]:
# Observe EEF: from the "init" pose, re-capture the end-effector (now holding the
# object), then attach that cloud in place of whatever was attached before.
result = env_interface.move_to_named_target("init")
print(f"Move to End-Effector observation pose: {result}")
if result != 'SUCCESS':
    raise NotImplementedError("Failed to move to End-Effector observation pose.")
env_interface.detach()
grasp_raw = env_interface.observe_eef(obs_type='pointcloud', update=True)
env_interface.attach(obj=grasp_raw)

# Observe Scene: capture the scene from the dedicated "observe" pose.
result = env_interface.move_to_named_target("observe")
print(f"Move to Scene observation pose: {result}")
if result != 'SUCCESS':
    raise NotImplementedError("Failed to move to Scene observation pose.")
scene_raw = env_interface.observe_scene(obs_type='pointcloud', update=True)

# Return to the default pose before sampling place poses.
result = env_interface.move_to_named_target("init")
print(f"Come back to default pose: {result}")
if result != 'SUCCESS':
    raise NotImplementedError("Failed to come back to default pose.")

# Place

### Sample Place Poses

In [None]:
##### Preprocess Observations #####
scene_proc = scene_proc_fn(scene_raw).to(device)
grasp_proc = grasp_proc_fn(grasp_raw).to(device)

##### Sample Place Poses #####
T_seed = 100
place_policy = 'sorted'
place_mh_iter = 1000
place_langevin_iter = 300
place_dist_temp = 1.
place_policy_temp = 1.
place_optim_iter = 100
place_optim_lr = 0.005
place_query_temp = 1.

place_forward_kwargs = dict(
    T_seed=T_seed,
    policy=place_policy,
    mh_iter=place_mh_iter,
    langevin_iter=place_langevin_iter,
    temperature=place_dist_temp,
    policy_temperature=place_policy_temp,
    optim_iter=place_optim_iter,
    optim_lr=place_optim_lr,
    query_temperature=place_query_temp,
)
Ts, edf_outputs, logs = place_agent.forward(scene=scene_proc, grasp=grasp_proc, **place_forward_kwargs)

# Back to raw units (no T_eg correction here, unlike the pick poses).
place_poses = recover_scale(SE3(Ts.cpu()))

##### Infer pre-place and post-place poses #####
_, pre_place_poses = optimize_pcd_collision(
    x=scene_raw, y=grasp_raw, rel_pose=place_poses,
    cutoff_r=0.03, dt=0.01, eps=1., iters=5,
)
# Post-place: apply the pick's relative approach offset (pick_pose^-1 * pre_pick_pose)
# to every place pose, so the gripper backs off the way it came in at pick time.
post_place_poses = place_poses * pick_pose.inv() * pre_pick_pose

### Execute Place

In [None]:
# Try the first 10 candidates in order and keep the first one where the grasped
# cloud, placed at the place pose, does not collide with the scene cloud.
colcheck_r = 0.0015 # Should be similar to voxel filter size
for idx in range(10):
    place_pose, pre_place_pose, post_place_pose = place_poses[idx], pre_place_poses[idx], post_place_poses[idx]
    # transformed() returns a list of clouds; [0] is the one for this pose.
    col_check = check_pcd_collision(x=scene_raw, y=grasp_raw.transformed(place_pose)[0], r=colcheck_r)
    print(f"Place Pose_{idx} collision-free: {not col_check}")
    if not col_check:
        print("Found collision-free place pose!")
        break
else:
    # Every candidate collided.
    raise NotImplementedError("No collision-free place pose found!")

In [None]:
# DEBUG
# draw_geometry([scene_raw] + grasp_raw.transformed(pre_place_pose))

In [None]:
# Move through the pre-place, place and post-place poses of the selected candidate.
place_result = env_interface.place(pre_place_pose, place_pose, post_place_pose)
print(f"Place result: {place_result}")
if place_result != "SUCCESS":
    raise NotImplementedError("Place failed")

env_interface.detach()
env_interface.release()