In [None]:
from ros_edf.ros_interface import EdfRosInterface
from edf.data import PointCloud, SE3, TargetPoseDemo, DemoSequence, DemoSeqDataset, save_demos
from edf.pc_utils import optimize_pcd_collision, draw_geometry
from edf.preprocess import Rescale, NormalizeColor, Downsample, ApplySE3
from edf.agent import PickAgent, PlaceAgent

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import Compose

torch.set_printoptions(precision= 3, sci_mode=False, linewidth=120)

# Initialize EDF

In [None]:
# Compute device for the agents and preprocessed point clouds: the first GPU when
# CUDA is available, otherwise CPU (instead of a hard-coded 'cuda:0' that fails on
# CUDA-less machines and a commented-out manual switch).
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Size of one model unit in the original point-cloud scale (presumably metres, so
# the models work in centimetres): clouds are rescaled by 1/unit_len before the
# models and outputs are rescaled back by unit_len.
unit_len = 0.01

# Constant transform whose inverse is right-multiplied onto the sampled pick poses
# before they are sent to the robot. Presumably laid out as [qw, qx, qy, qz, x, y, z]
# (quaternion first, translation last). If so, this is a ~90 deg rotation about z with
# a 0.15 offset along z.
# NOTE(review): 0.707 only approximates sqrt(0.5), so the quaternion is not exactly
# unit-norm. Confirm that SE3 normalizes it.
T_eg = SE3([0.707, 0.000, 0.000, 0.707,   # rotation part
            0.000, 0.000, 0.150])         # translation part

In [None]:
# --- Pick agent configuration -------------------------------------------------
pick_agent_config_dir = "config/agent_config/pick_agent.yaml"             # model/architecture config
pick_agent_param_dir = "checkpoint/mug_10_demo/pick/model_iter_600.pt"    # trained weights
max_N_query_pick = 1
langevin_dt_pick = 0.001

# Build the agent on `device` with gradients disabled on its parameters (inference only).
pick_agent = PickAgent(
    config_dir=pick_agent_config_dir,
    device=device,
    max_N_query=max_N_query_pick,
    langevin_dt=langevin_dt_pick,
).requires_grad_(False)

# Load the checkpoint into the freshly built agent.
pick_agent.load(pick_agent_param_dir)

In [None]:
# --- Place agent configuration ------------------------------------------------
place_agent_config_dir = "config/agent_config/place_agent.yaml"            # model/architecture config
place_agent_param_dir = "checkpoint/mug_10_demo/place/model_iter_600.pt"   # trained weights
max_N_query_place = 3
langevin_dt_place = 0.001

# Build the agent on `device` with gradients disabled on its parameters (inference only).
place_agent = PlaceAgent(
    config_dir=place_agent_config_dir,
    device=device,
    max_N_query=max_N_query_place,
    langevin_dt=langevin_dt_place,
).requires_grad_(False)

# NOTE(review): strict=False presumably tolerates missing/unexpected checkpoint keys
# (as torch's load_state_dict does). The pick agent is loaded without this flag,
# so confirm the relaxed load is intended here.
place_agent.load(place_agent_param_dir, strict=False)

In [None]:
def make_pcd_proc_fn(voxel_size):
    """Build the point-cloud preprocessing pipeline used for model inputs.

    Steps, applied in order:
      1. rescale into model units (divide coordinates by ``unit_len``),
      2. downsample with the given ``voxel_size`` (expressed in model units, since it
         runs after the rescale), averaging coordinates within each voxel,
      3. normalize colors with per-channel mean 0.5 and std 0.5.

    Args:
        voxel_size: downsampling voxel size in model units.

    Returns:
        A ``Compose`` transform that maps a raw point cloud to a model-ready one.
    """
    return Compose([
        Rescale(rescale_factor=1/unit_len),
        Downsample(voxel_size=voxel_size, coord_reduction="average"),
        NormalizeColor(color_mean=torch.tensor([0.5, 0.5, 0.5]),
                       color_std=torch.tensor([0.5, 0.5, 0.5])),
    ])


# Scene and grasp (end-effector) clouds share the pipeline; only the voxel size differs.
scene_proc_fn = make_pcd_proc_fn(voxel_size=1.7)
grasp_proc_fn = make_pcd_proc_fn(voxel_size=1.4)
# Inverse of the Rescale step: maps model-unit outputs (e.g. sampled poses) back to the original scale.
recover_scale = Rescale(rescale_factor=unit_len)

# Initialize Robot Interface

In [None]:
# Connect to the robot/ROS interface with poses expressed in the "scene" frame, then reset it.
env_interface = EdfRosInterface(reference_frame="scene")
env_interface.reset()

# Give the MoveIt planner for the arm up to 5 seconds per planning request.
arm_group = env_interface.moveit_interface.arm_group
arm_group.set_planning_time(seconds=5)

# Pick

### Observe

In [None]:
# Capture fresh point clouds: first the end-effector (grasp) observation, then the scene.
grasp_raw = env_interface.observe_eef(obs_type='pointcloud', update=True)
scene_raw = env_interface.observe_scene(obs_type='pointcloud', update=True)

# Preprocess each cloud with its pipeline and move it to the compute device
# (scene first, then grasp).
scene_proc, grasp_proc = (
    proc_fn(raw_pcd).to(device)
    for proc_fn, raw_pcd in ((scene_proc_fn, scene_raw), (grasp_proc_fn, grasp_raw))
)

### Sample Pick Pose

In [None]:
# Sampling / optimization hyper-parameters for the pick policy.
T_seed = 100
pick_policy = 'sorted'
pick_mh_iter = 1000          # Metropolis-Hastings iterations
pick_langevin_iter = 300     # Langevin iterations
pick_dist_temp = 1.0         # passed as `temperature`
pick_policy_temp = 1.0       # passed as `policy_temperature`
pick_optim_iter = 100
pick_optim_lr = 0.005

# Sample candidate pick poses for the preprocessed scene (in model units).
Ts, edf_outputs, logs = pick_agent.forward(
    scene=scene_proc,
    T_seed=T_seed,
    policy=pick_policy,
    mh_iter=pick_mh_iter,
    langevin_iter=pick_langevin_iter,
    temperature=pick_dist_temp,
    policy_temperature=pick_policy_temp,
    optim_iter=pick_optim_iter,
    optim_lr=pick_optim_lr,
)

In [None]:
# Take the first sampled pose set, bring it back to the original scale, and apply
# the inverse of T_eg to get the robot target poses.
pick_poses = SE3.multiply(recover_scale(SE3(Ts[0].cpu())), T_eg.inv())

# Pre-pick poses: identity rotation with a -0.1 offset along the pose's local z,
# presumably a retreat from the grasp along the approach axis.
pre_pick_offset = SE3([1., 0., 0., 0., 0., 0., -0.1])
pre_pick_poses = SE3.multiply(pick_poses, pre_pick_offset)

In [None]:
# FIXME(review): this call was left incomplete (`positions=` had no value). That is a
# SyntaxError which aborts Restart & Run All, so the call is disabled until the
# intended waypoint argument (presumably derived from `pre_pick_poses` / `pick_poses`)
# is filled in.
# env_interface.moveit_interface.plan_pose_waypoints(positions=...)

In [None]:
# NOTE(review): this overwrites, IN PLACE, the last column of `pick_poses.poses` with 0.3.
# It assumes `poses` is 2-D (N, 7), with translation z last (TODO confirm); if so, every
# sampled pick height becomes 0.3 in the original scale.
# `pre_pick_poses` was computed earlier from the un-overridden poses and is presumably
# not updated. The cells below (move_to_target_pose, move_cartesian, pick) all consume
# the overridden `pick_poses`.
# FIXME(review): this looks like a debugging override (hover height?). On Restart & Run
# All the pick would target this fixed height instead of the sampled grasp. Confirm the
# intent, or apply the override to a copy.
pick_poses.poses[:,-1] = 0.3

In [None]:
# Move the arm to `pick_poses` (presumably plans and executes with the interface's
# default collision settings). Note that `pick_poses` may carry the height override above.
env_interface.move_to_target_pose(poses=pick_poses)

In [None]:
# Cartesian move toward `pick_poses` with collision avoidance explicitly DISABLED
# (avoid_collision=False). Only run this with a clear workspace.
env_interface.move_cartesian(
    poses=pick_poses,
    cartesian_step=0.01,
    cspace_step_thr=10,
    avoid_collision=False,
)

In [None]:
# Clear the interface's octomap before picking (presumably the occupancy map MoveIt uses
# for collision checking, so the target object does not block the grasp. Confirm).
env_interface.clear_octomap()

In [None]:
# Execute the pick toward `pick_poses` and unpack the five-part result.
result = env_interface.pick(target_poses=pick_poses)
(pre_grasp_results, grasp_pose, grasp_result,
 post_grasp_results, final_pose) = result

# The pick counts as successful whenever a final pose was returned.
pick_success = not (final_pose is None)

In [None]:
# Call the interface's release action (presumably opens the gripper to let go of the object).
env_interface.release()