In [None]:
from typing import List, Optional
import pickle

from beartype import beartype
import torch

from edf_interface.pyro import PyroServer, expose, PyroClientBase
from edf_interface.data import SE3, PointCloud, TargetPoseDemo, DemoSequence, DemoDataset

# Initialize a simple demonstration environment server

In [None]:
# Load one recorded demonstration episode and select the sub-demo (pick or place)
# whose scene/grasp point clouds the fake environment server below will serve.
DEMO_DIR = 'demo/test_unseen_demo'  # relative to the notebook's working directory
DEMO_INDEX = 3                      # which episode of the dataset to replay
target_task = 'pick'                # must be 'pick' or 'place'

# Validate the task name before doing any (potentially slow) dataset loading.
if target_task not in ('pick', 'place'):
    raise ValueError(f"Unknown target task: {target_task}")

dataset = DemoDataset(DEMO_DIR)
demo = dataset[DEMO_INDEX]
pick_demo, place_demo = demo  # episodes in this dataset unpack into a (pick, place) pair
target_demo = pick_demo if target_task == 'pick' else place_demo

# These module-level names are read by EnvService at call time; keep them as globals.
scene_pcd, grasp_pcd, target_poses = target_demo.scene_pcd, target_demo.grasp_pcd, target_demo.target_poses

# Hand-written initial pose served by the fake environment.
# NOTE(review): presumably a unit quaternion (w, x, y, z) = identity rotation followed by an
# (x, y, z) translation, concatenated into a (1, 7) tensor -- confirm against SE3's convention.
INITIAL_ORIENTATION = [[1.0, 0.0, 0.0, 0.0]]
INITIAL_POSITION = [[-0.3, -0.3, 0.3]]
current_poses = SE3(
    torch.cat([
        torch.tensor(INITIAL_ORIENTATION),
        torch.tensor(INITIAL_POSITION),
    ], dim=-1)
)

@beartype
class EnvService:
    """Stand-in environment exposed over Pyro for this demo notebook.

    Instead of talking to a real robot or simulator, it serves the module-level
    globals ``current_poses``, ``scene_pcd`` and ``grasp_pcd`` defined in the
    cell above. Those names must exist before any exposed method is called.
    """

    def __init__(self):
        """No per-instance state; everything is read from / written to module globals."""

    @expose
    def get_current_poses(self) -> SE3:
        """Return the pose(s) last stored by ``move_se3`` (initially the hand-written start pose)."""
        return current_poses

    @expose
    def observe_scene(self) -> PointCloud:
        """Return the scene point cloud held in the module-level ``scene_pcd`` at call time."""
        return scene_pcd

    @expose
    def observe_grasp(self) -> PointCloud:
        """Return the grasp point cloud held in the module-level ``grasp_pcd`` at call time."""
        return grasp_pcd

    @expose
    def move_se3(self, target_poses: SE3) -> bool:
        """Pretend to move: keep the first requested pose as the new current pose.

        Always reports success; there is no motion, collision check or failure path.
        """
        global current_poses  # rebinds the module-level name that get_current_poses returns
        print("Target poses received!")
        # NOTE(review): assumes indexing an SE3 yields an SE3 -- get_current_poses is
        # annotated ``-> SE3`` and @beartype presumably checks that on the next call; confirm.
        current_poses = target_poses[0]
        return True
    
# Serve EnvService under the name 'env'. Per the original note, init_nameserver=None
# starts a nameserver only if an existing one cannot be found.
# nonblocking=True presumably returns immediately, so later cells in this kernel can call the server.
server = PyroServer(
    server_name='env',
    init_nameserver=None,
)
server.register_service(service=EnvService())
server.run(nonblocking=True)
# Run `server.close()` to shut the server down when finished.

# Initialize the Diffusion-EDF client

In [None]:
@beartype
class DiffusionEdfClient(PyroClientBase):
    """Pyro client connected to both the environment server and the Diffusion-EDF agent server.

    The methods below are stubs whose bodies are ``...``. ``PyroClientBase`` is presumably
    what forwards each call to the exposed method of the same name on one of the
    ``service_names`` servers -- confirm against ``edf_interface.pyro``.

    NOTE: the keyword ``agent_sever_name`` is misspelled ("sever"), but it is kept as-is
    because callers pass it by name.
    """
    def __init__(self, env_server_name: str = 'env',
                 agent_sever_name: str = 'agent'):
        super().__init__(service_names=[env_server_name, agent_sever_name])

    # Environment-side calls (served by EnvService in this notebook).
    def get_current_poses(self, **kwargs) -> SE3: ...

    def observe_scene(self, **kwargs) -> PointCloud: ...

    def observe_grasp(self, **kwargs) -> PointCloud: ...

    def move_se3(self, target_poses: SE3, **kwargs) -> bool: ...

    # Agent-side call. `current_poses` was previously annotated `PointCloud`, but callers pass
    # the SE3 returned by get_current_poses(); under @beartype the wrong hint could reject valid calls.
    def infer_target_poses(self, scene_pcd: PointCloud,
                           task_name: str,
                           grasp_pcd: PointCloud,
                           current_poses: SE3,
                           N_steps_list: List[List[int]],
                           timesteps_list: List[List[float]],
                           temperature_list: List[float]) -> SE3: ...

# Connect to the 'env' server started above and to the 'agent' server
# (assumed to be already running elsewhere).
client = DiffusionEdfClient(
    env_server_name='env',
    agent_sever_name='agent',
)

In [None]:
# Query the environment server for its current state: gripper pose(s) plus scene and grasp clouds.
# NOTE(review): this rebinds the module-level `scene_pcd` / `grasp_pcd` that EnvService serves,
# so later server calls return these (presumably deserialized) copies instead of the originals.
Ts = client.get_current_poses()
scene_pcd = client.observe_scene()
grasp_pcd = client.observe_grasp()

# Wrap the observation in a TargetPoseDemo only so it can be rendered.
current_state = TargetPoseDemo(target_poses=Ts, scene_pcd=scene_pcd, grasp_pcd=grasp_pcd)
current_state.show(width=400, height=400, point_size=2.)

In [None]:
# Ask the agent server for target poses given the observation from the cell above
# (uses `Ts`, `scene_pcd`, `grasp_pcd` from that cell and `target_task` from the data cell).
# NOTE(review): the meaning of these schedules is inferred from the argument names only --
# presumably one inner list / entry per diffusion stage; confirm against the agent server.
N_STEPS_LIST = [[500, 500], [500, 1000]]       # number of diffusion steps
TIMESTEPS_LIST = [[0.02, 0.02], [0.02, 0.05]]  # timestep sizes
TEMPERATURE_LIST = [1., 1.]                    # sampling temperatures

target_Ts = client.infer_target_poses(scene_pcd=scene_pcd,
                                      task_name=target_task,
                                      grasp_pcd=grasp_pcd,
                                      current_poses=Ts,
                                      N_steps_list=N_STEPS_LIST,
                                      timesteps_list=TIMESTEPS_LIST,
                                      temperature_list=TEMPERATURE_LIST)
print(target_Ts)

In [None]:
# Send the inferred poses to the environment. Fail loudly rather than merely displaying `False`,
# because the next cell visualizes the state assuming the move succeeded.
move_succeeded = client.move_se3(target_poses=target_Ts)
if not move_succeeded:
    raise RuntimeError("Environment server reported that move_se3 failed")
move_succeeded

In [None]:
# Re-observe after the move. With EnvService, get_current_poses now returns the first pose
# passed to move_se3 (it stores target_poses[0]).
# NOTE(review): like the first observation cell, this rebinds the globals EnvService serves.
Ts = client.get_current_poses()
scene_pcd = client.observe_scene()
grasp_pcd = client.observe_grasp()

# Render the post-move state for visual comparison with the initial one.
current_state = TargetPoseDemo(target_poses=Ts, scene_pcd=scene_pcd, grasp_pcd=grasp_pcd)
current_state.show(width=400, height=400, point_size=2.)