In [None]:
from typing import *
import time

from beartype import beartype
import numpy as np
import torch

from edf_interface.web import DashEdfDemoServer, TERMINATE, RESET, SUCCESS
from edf_interface import data
from edf_interface.pyro import PyroClientBase

In [None]:
@beartype
class ExampleClient(PyroClientBase):
    """Client-side stub declaring the remote calls of the ``'env'`` Pyro service.

    Every method body is a ``...`` placeholder; only names and signatures are
    declared here. NOTE(review): presumably ``PyroClientBase`` forwards these
    calls to the remote ``'env'`` service — confirm in ``edf_interface.pyro``.
    """
    def __init__(self):
        # Register interest in the single remote service named 'env'.
        super().__init__(service_names=['env'])

    # Returns a data.SE3 (presumably the current end-effector pose).
    def get_current_pose(self) -> data.SE3: ...
    
    # Returns a data.PointCloud of the scene. voxel_size defaults to 0.01 (presumably a
    # downsampling voxel size); update_planner defaults to True here.
    def observe_scene(self, voxel_size: float = 0.01, update_planner = True) -> data.PointCloud: ...
    
    # Returns a data.PointCloud of the grasp; unlike observe_scene, update_planner defaults to False.
    def observe_grasp(self, voxel_size: float = 0.01, update_planner = False) -> data.PointCloud: ...

    # Moves to target_poses; returns a bool (presumably success). plan/block default to True.
    def move_se3(self, target_poses: data.SE3, plan: bool = True, block=True) -> bool: ...

    # Like move_se3 but without plan/block options; returns a bool.
    def teleport_se3(self, target_poses: data.SE3) -> bool: ...

    # Commands the gripper to gripper_val over `duration` (default 1.0); returns a bool.
    def control_gripper(self, gripper_val: float, duration: float = 1., block=True) -> bool: ...

    # Attaches the object named obj_name; returns a bool.
    def attach(self, obj_name: str) -> bool: ...

    # Detaches obj_name; the default '*' presumably means "all objects" — confirm server-side.
    def detach(self, obj_name: str = '*') -> bool: ...

    # Resets the remote environment; returns a bool.
    def reset(self) -> bool: ...

    # Returns the scene ranges as a torch.Tensor.
    def get_scene_ranges(self) -> torch.Tensor: ...

    # Disabled stub kept for reference. Note that it uses bare PointCloud/SE3 names,
    # which are not imported in this notebook (the rest of the file uses data.*).
    # def infer_target_poses(self, scene_pcd: PointCloud, 
    #                        task_name: str,
    #                        grasp_pcd: Optional[PointCloud] = None,
    #                        current_poses: Optional[SE3] = None, 
    #                        **kwargs) -> SE3: ...

# Module-level client instance. NOTE(review): construction presumably connects to the
# remote 'env' service, so this cell needs that service to be reachable.
client = ExampleClient()

In [None]:
# Visualization bounds for the demo server: one [min, max] row per axis
# (presumably x, y, z in meters — confirm against DashEdfDemoServer).
scene_ranges = np.array(
    [
        [-0.23, 0.23],
        [-0.23, 0.23],
        [-0.01, 0.4],
    ]
)
point_size = 3.

# Local Dash server used to show point clouds and collect user-selected poses.
demo_server = DashEdfDemoServer(
    scene_ranges=scene_ranges,
    name="dash_edf_demo_server",
    point_size=point_size,
    host_id='127.0.0.1',
    server_debug=False,
)

# (check_type, kwargs) pairs consumed by feasibility_check; 'colcheck_r' is the
# collision radius handed to check_collision (tighter for placing than for picking).
pick_checklist = [
    ('collision_check', {'colcheck_r': 0.003}),
]
place_checklist = [
    ('collision_check', {'colcheck_r': 0.0015}),
]


def get_pick(scene: data.PointCloud, grasp: data.PointCloud) -> Union[str, data.SE3]:
    """Show ``scene``/``grasp`` on the demo server and return the user's pick response.

    Returns:
        ``RESET`` or ``TERMINATE`` when the user sends one of those signals, otherwise
        the user-selected ``data.SE3`` pick pose.

    Raises:
        ValueError: if the server returns anything else.
    """
    demo_server.update_scene_pcd(pcd=scene)
    demo_server.update_grasp_pcd(pcd=grasp)
    response = demo_server.get_user_response()

    is_control_signal = response == RESET or response == TERMINATE
    if not (is_control_signal or isinstance(response, data.SE3)):
        raise ValueError(f"Unknown user response: {response}")
    return response


def get_place(scene: data.PointCloud, grasp: data.PointCloud) -> Union[str, data.SE3]:
    """Show ``scene``/``grasp`` on the demo server and return the user's place response.

    Returns:
        ``RESET`` or ``TERMINATE`` when the user sends one of those signals, otherwise
        the user-selected ``data.SE3`` place pose.

    Raises:
        ValueError: if the server returns anything else.
    """
    demo_server.update_scene_pcd(pcd=scene)
    demo_server.update_grasp_pcd(pcd=grasp)
    response = demo_server.get_user_response()

    is_control_signal = response == RESET or response == TERMINATE
    if not (is_control_signal or isinstance(response, data.SE3)):
        raise ValueError(f"Unknown user response: {response}")
    return response
    
def update_system_msg(msg: str, wait_sec: float = 0.):
    """Forward ``msg`` to ``demo_server.update_robot_state``; then sleep ``wait_sec`` seconds if non-zero."""
    demo_server.update_robot_state(msg)
    if not wait_sec:
        return
    time.sleep(wait_sec)

def cleanup():
    """Shut the demo server down by delegating to ``demo_server.close()``."""
    server_to_close = demo_server
    server_to_close.close()

In [None]:
# Start the Dash demo server configured above.
# NOTE(review): if run() blocks, later cells will not execute until it returns — confirm.
demo_server.run()

In [None]:
def move_robot_near_target(pose: data.SE3, env_interface: "EdfRosInterface",
                           rel_pos_xy: Tuple[float, float] = (-0.7, 0.),
                           max_base_x: float = -0.6):
    """Drive the robot base to a fixed x/y offset from a single target pose.

    Args:
        pose: Exactly one pose (``len(pose) == 1``). Columns ``4:6`` of ``pose.poses`` are
            read as the target's (x, y) — assumes the (qw, qx, qy, qz, x, y, z) layout.
        env_interface: Object exposing ``move_robot_base(pos=...)``. The annotation is a
            string because ``EdfRosInterface`` is not imported in this notebook; a bare
            name would raise NameError when this function is defined.
        rel_pos_xy: Base offset added to the target's (x, y). Defaults to the original (-0.7, 0.).
        max_base_x: Upper bound on the commanded base x. Defaults to the original -0.6.
    """
    assert len(pose) == 1

    rel_pos = torch.tensor(rel_pos_xy, device=pose.device, dtype=pose.poses.dtype)
    # The addition allocates a new tensor, so the in-place clamp below never touches `pose`.
    pos = pose.poses[0, 4:6] + rel_pos
    if pos[0] > max_base_x:
        pos[0] = max_base_x

    env_interface.move_robot_base(pos=pos)  # x, y

def check_collision(pose: data.SE3, 
                    scene: data.PointCloud, 
                    grasp: data.PointCloud, 
                    colcheck_r: float # Should be similar to voxel filter size
                    ) -> bool:
    """Check whether ``grasp``, moved to ``pose``, collides with ``scene``.

    Annotations use ``data.SE3``/``data.PointCloud``; the bare ``SE3``/``PointCloud``
    names are not imported in this notebook and raised NameError at definition time.

    Args:
        pose: Exactly one pose (``len(pose) == 1``).
        scene: Scene point cloud.
        grasp: Grasp point cloud; transformed by ``pose`` before the check.
        colcheck_r: Radius forwarded to ``check_pcd_collision`` as ``r``; should be
            comparable to the voxel-filter size used when the clouds were observed.

    Returns:
        The result of ``check_pcd_collision`` (truthy means collision, as consumed by
        ``feasibility_check``).
    """
    assert len(pose) == 1

    # transformed() is indexed with [0], i.e. the single pose asserted above.
    # NOTE(review): check_pcd_collision is not imported in the visible notebook — confirm its source.
    return check_pcd_collision(x=scene, y=grasp.transformed(pose)[0], r=colcheck_r)

def feasibility_check(context: Dict[str, Any], check_list: List[Tuple[str, Dict[str, Any]]]) -> Tuple[str, str]:
    """Run the configured feasibility checks on a candidate pose.

    Args:
        context: Must contain ``'pose'``, ``'scene'`` and ``'grasp'``; these are forwarded
            to ``check_collision``.
        check_list: ``(check_type, kwargs)`` pairs such as ``pick_checklist``. Only
            ``'collision_check'`` is supported; its kwargs go to ``check_collision``.

    Returns:
        ``(INFEASIBLE, 'COLLISION_DETECTED')`` for the first failing collision check,
        otherwise ``(FEASIBLE, 'FEASIBLE')``.

    Raises:
        ValueError: if ``check_list`` names an unsupported check type. This is an explicit
            raise rather than ``assert``, because asserts are stripped under ``python -O``.
    """
    # NOTE(review): FEASIBLE / INFEASIBLE are not defined or imported in the visible notebook.
    available_check_types = ('collision_check',)

    for check_type, check_kwargs in check_list:
        if check_type not in available_check_types:
            raise ValueError(f"Unknown check type {check_type!r}; available: {available_check_types}")

        if check_type == 'collision_check':
            col_check = check_collision(pose=context['pose'], 
                                        scene=context['scene'], 
                                        grasp=context['grasp'], **check_kwargs)
            if col_check:
                return INFEASIBLE, 'COLLISION_DETECTED'

    return FEASIBLE, 'FEASIBLE'

def get_pre_post_pick(scene: data.PointCloud, grasp: data.PointCloud, pick_poses: data.SE3,
                      approach_dist: float = 0.05, lift_height: float = 0.1) -> Tuple[data.SE3, data.SE3]:
    """Derive approach (pre-pick) and lift (post-pick) poses from pick poses.

    Uses ``data.SE3``/``data.PointCloud``; the bare ``SE3``/``PointCloud`` names are not
    imported in this notebook and raised NameError.

    Args:
        scene: Unused; kept so the signature parallels ``get_pre_post_place``.
        grasp: Unused; kept so the signature parallels ``get_pre_post_place``.
        pick_poses: Pick poses; ``pick_poses.poses`` rows are 7-vectors.
        approach_dist: z back-off of the pre-pick pose. Defaults to the original 0.05.
        lift_height: Amount added to the z component for post-pick. Defaults to the original 0.1.

    Returns:
        ``(pre_pick_poses, post_pick_poses)``.
    """
    # Alternative kept from an earlier version: a collision-aware pre-pick via
    # optimize_pcd_collision(x=scene, y=grasp, cutoff_r=0.03, dt=0.01, eps=1., iters=50, rel_pose=pick_poses).
    device = pick_poses.device
    # Match the poses' dtype, as move_robot_near_target does, so float64 poses are not
    # combined with a default float32 offset.
    dtype = pick_poses.poses.dtype

    # Identity rotation plus a -z translation (assumes (qw, qx, qy, qz, x, y, z) layout).
    # It is right-multiplied, so presumably applied in each pick pose's own frame.
    approach_offset = data.SE3(torch.tensor([1., 0., 0., 0., 0., 0., -approach_dist], device=device, dtype=dtype))
    pre_pick_poses = pick_poses * approach_offset

    # Raw addition on the pose vectors: rotation untouched, last (z) component raised.
    lift_delta = torch.tensor([0., 0., 0., 0., 0., 0., lift_height], device=device, dtype=dtype)
    post_pick_poses = data.SE3(pick_poses.poses + lift_delta)

    return pre_pick_poses, post_pick_poses


def get_pre_post_place(scene: data.PointCloud, grasp: data.PointCloud, place_poses: data.SE3,
                       pre_pick_pose: data.SE3, pick_pose: data.SE3) -> Tuple[data.SE3, data.SE3]:
    """Derive pre-place and post-place poses for the given place poses.

    Uses ``data.SE3``/``data.PointCloud``; the bare ``SE3``/``PointCloud`` names are not
    imported in this notebook and raised NameError at definition time.

    Args:
        scene: Scene point cloud, passed to ``optimize_pcd_collision`` as ``x``.
        grasp: Grasp point cloud, passed to ``optimize_pcd_collision`` as ``y``.
        place_poses: Target place poses.
        pre_pick_pose: Exactly one pre-pick pose.
        pick_pose: Exactly one pick pose.

    Returns:
        ``(pre_place_poses, post_place_poses)``. ``pre_place_poses`` is the second value
        returned by ``optimize_pcd_collision``. ``post_place_poses`` applies the
        pick -> pre-pick relative transform (``pick_pose.inv() * pre_pick_pose``) to each
        place pose.
    """
    assert len(pick_pose) == len(pre_pick_pose) == 1

    # NOTE(review): optimize_pcd_collision is not imported in the visible notebook — confirm its source.
    _, pre_place_poses = optimize_pcd_collision(x=scene, y=grasp, 
                                                cutoff_r = 0.03, dt=0.01, eps=1., iters=5,
                                                rel_pose=place_poses)
    # Retreat from the place pose the same way the gripper retreated from the pick pose.
    post_place_poses = place_poses * pick_pose.inv() * pre_pick_pose

    return pre_place_poses, post_place_poses


def _move_to_named_target_with_retry(env_interface, target_name: str, max_try: int):
    """Try ``env_interface.move_to_named_target(target_name)`` up to ``max_try`` times.

    Returns the result of the last attempt. That is ``SUCCESS`` if any attempt succeeded,
    or ``None`` if ``max_try <= 0`` and nothing was attempted.
    """
    move_result = None
    for _ in range(max_try):
        move_result, _info = env_interface.move_to_named_target(target_name)
        if move_result == SUCCESS:
            break
    return move_result


def observe(env_interface, max_try: int, attach: bool) -> Tuple[bool, Tuple[Any, Any]]:
    """Move to the observation pose, capture grasp and scene point clouds, then return home.

    Sequence:
        1. Move to "init".
        2. Move the base to (-1.5, 0).
        3. Move to "observe".
        4. Optionally detach, capture the eef cloud, and re-attach it.
        5. Capture the scene cloud.
        6. Move back to "init".
        7. Move the base to (-0.7, 0).

    Each named-target move is retried up to ``max_try`` times; the sequence stops at the
    first move that never succeeds.

    Args:
        env_interface: Robot/environment interface (move_to_named_target, move_robot_base,
            observe_eef, observe_scene, attach, detach).
        max_try: Maximum attempts per named-target move. ``max_try <= 0`` now reports
            failure instead of raising UnboundLocalError.
        attach: If True, detach before observing the end effector and attach the
            observed grasp cloud afterwards.

    Returns:
        ``(success, (scene_raw, grasp_raw))``. The clouds are ``None`` if the robot failed
        before observing. Previously that case raised UnboundLocalError; the old ``-> bool``
        annotation was also wrong, since a tuple is returned.
    """
    update_system_msg("Move to Observe...")
    scene_raw, grasp_raw = None, None

    # Move to default pose before moving to observation pose
    move_result = _move_to_named_target_with_retry(env_interface, "init", max_try)

    # Move to observation pose
    if move_result == SUCCESS:
        env_interface.move_robot_base(pos=torch.tensor([-1.5, 0.]))
        move_result = _move_to_named_target_with_retry(env_interface, "observe", max_try)

    # Observe, then come back to the default pose
    if move_result == SUCCESS:
        if attach:
            env_interface.detach()
        grasp_raw = env_interface.observe_eef(obs_type='pointcloud', update=True)
        if attach:
            env_interface.attach(obj=grasp_raw)
        scene_raw = env_interface.observe_scene(obs_type='pointcloud', update=True)

        move_result = _move_to_named_target_with_retry(env_interface, "init", max_try)

    if move_result == SUCCESS:
        env_interface.move_robot_base(pos=torch.tensor([-0.7, 0.0]))
        return True, (scene_raw, grasp_raw)

    update_system_msg(f"Cannot Move to Observation Pose ({move_result}). Resetting env...", wait_sec=2.0)
    return False, (scene_raw, grasp_raw)