# Import modules

In [None]:
%load_ext autoreload
%autoreload 2

from typing import *
from numbers import Real
import os, sys, argparse, time

import copy
from copy import deepcopy   
import math
import open3d as o3d
import open3d.visualization as viz
import numpy as np
import torch

from edf_interface import data
from edf_interface.utils import manipulation_utils
from edf_interface.data import preprocess
from edf_interface.data.preprocess import downsample

In [None]:
from typing import Union, Iterable, List, Tuple, Dict, Any

import torch

from edf_interface.pyro import PyroClientBase, remote
from edf_interface import data

class DiffusionEdfClient(PyroClientBase):
    """Pyro client for the remote service registered under the name ``'agent'``.

    Methods decorated with ``@remote`` have no local implementation (their
    body is ``...``); the ``remote`` decorator from ``edf_interface.pyro``
    presumably forwards the call to the server — confirm against
    edf_interface. Signatures and annotations are kept verbatim in case the
    decorator inspects them.
    """

    def __init__(self, ip = None, port = None):
        """Connect to the ``'agent'`` service.

        Args:
            ip: Nameserver host passed through as ``nameserver_host``.
            port: Nameserver port passed through as ``nameserver_port``.
        """
        super().__init__(
            service_names='agent',
            nameserver_host=ip,
            nameserver_port=port,
        )

    @remote
    def request_trajectories(self, scene_pcd: data.PointCloud, 
                                grasp_pcd: data.PointCloud,
                                current_poses: data.SE3,
                                task_name: str,
                                ) -> Tuple[List[data.SE3], Dict[str, Any]]:
        """Remote stub: request trajectories for ``task_name``.

        NOTE(review): return annotated as a list of SE3 plus an info dict;
        confirm the actual type produced by the server.
        """
        ...

    def update_system_msg(self, msg: str, **kwargs) -> bool:
        """Print ``msg`` locally and report success.

        Not decorated with ``@remote``, so it runs on this side. Any extra
        keyword arguments are accepted and ignored. Always returns ``True``.
        """
        print(msg)
        return True

    @remote
    def get_configs(self) -> Dict[str, Any]:
        """Remote stub: fetch the agent's configuration mapping."""
        ...

    @remote
    def reconfigure(self, name: str, value: Dict[str, Any]) -> bool:
        """Remote stub: replace config section ``name`` with ``value``.

        Annotated to return a bool, presumably indicating success.
        """
        ...

    @remote
    def denoise(self, scene_pcd: data.PointCloud, 
                    grasp_pcd: data.PointCloud,
                    current_poses: data.SE3,
                    task_name: str,
                    ) -> Tuple[List[data.SE3], Dict[str, Any]]:
        """Remote stub: run diffusion denoising starting from ``current_poses``.

        NOTE(review): annotated as returning ``List[data.SE3]``, but callers
        in this notebook index the first result with ``[:, 0]`` and call
        ``.cpu()`` on it, which a plain list does not support — the real
        return is likely a tensor; confirm and fix the annotation.
        """
        ...

In [None]:
# Nameserver address of the Diffusion-EDF agent; edit here for other setups.
NAMESERVER_IP = "192.168.0.6"

# NOTE(review): the purpose of this short delay is not visible in this
# notebook (possibly waiting for the server side to come up) — confirm it
# is still needed.
time.sleep(0.5)
client = DiffusionEdfClient(ip=NAMESERVER_IP)

# Config sections pushed to the agent, in the order they are applied.
# The pick and place diffusion settings differ only in step counts and
# step sizes.
_client_configs = {
    "pick_diffusion_configs": {
        "N_steps_list": [[100, 100], [100, 100]],
        "timesteps_list": [[0.02, 0.02], [0.02, 0.02]],
        "temperatures_list": [[1., 1.], [1., 1.]],
        "diffusion_schedules_list": [
            [[1., 0.1], [0.1, 0.1]],
            [[0.1, 0.03], [0.03, 0.01]],
        ],
        "log_t_schedule": True,
        "time_exponent_temp": 1.0,
        "time_exponent_alpha": 0.5,
    },
    "pick_trajectory_configs": {
        "n_steps": 5,
        "approach_len": 0.1,
    },
    "place_diffusion_configs": {
        "N_steps_list": [[200, 200], [300, 300]],
        "timesteps_list": [[0.04, 0.04], [0.02, 0.02]],
        "temperatures_list": [[1., 1.], [1., 1.]],
        "diffusion_schedules_list": [
            [[1., 0.1], [0.1, 0.1]],
            [[0.1, 0.03], [0.03, 0.01]],
        ],
        "log_t_schedule": True,
        "time_exponent_temp": 1.0,
        "time_exponent_alpha": 0.5,
    },
    "place_trajectory_configs": {
        "n_steps": 1,
        "dt": 0.0001,
        "cutoff_r": 0.03,
        "max_num_neighbors": 2,
        "eps": 0.0001,
        "cluster_method": "knn",
        "voxel_size": 0.03,
        "voxel_coord_reduction": "average",
    },
}

# reconfigure() is annotated to return a bool; previously the result was
# discarded, so a rejected config went unnoticed and later denoising ran
# with whatever settings the server already had. Apply every section, then
# fail loudly if any was rejected.
_failed_configs = []
for _name, _value in _client_configs.items():
    if not client.reconfigure(name=_name, value=_value):
        _failed_configs.append(_name)
if _failed_configs:
    raise RuntimeError(f"client.reconfigure returned False for: {_failed_configs}")


In [None]:
# Load the saved TargetPoseDemo named "fail_case" (presumably a recorded
# failure case being debugged in this notebook).
demo = data.TargetPoseDemo.load("fail_case")

In [None]:
n_samples = 100

# One initial pose row: [1, 0, 0, 0] (presumably a w-first identity
# quaternion — confirm against data.SE3) followed by [0, 0, 0.3]
# (presumably the translation). Tiled to shape (n_samples, 7).
_initial_pose_row = torch.tensor([[1., 0., 0., 0., 0., 0., 0.3]])
T0 = data.SE3(_initial_pose_row.repeat(n_samples, 1))

# Run the remote denoiser for the 'place' task from the tiled initial poses.
trajectories, info = client.denoise(
    scene_pcd=demo.scene_pcd,
    grasp_pcd=demo.grasp_pcd,
    current_poses=T0,
    task_name='place',
)

In [None]:
# Show the auxiliary info returned by client.denoise as this cell's output.
info

In [None]:
# Select index 0 along dim 1 of the denoiser output (presumably one sample
# across all denoising steps — confirm the layout), move it to CPU, scale
# the last three of seven components by 0.01 (presumably translation; the
# first four are left unchanged), and keep every 10th entry.
# NOTE(review): the reason for the 0.01 scaling is not visible here.
# The multiplier stays a float32 tensor so the result dtype and values
# match exactly.
Ts = data.SE3(
    (trajectories[:, 0].cpu() * torch.tensor([1., 1., 1., 1., 0.01, 0.01, 0.01]))[::10]
)

In [None]:
# Visualize the selected poses against voxel-downsampled copies of the demo
# point clouds (voxel_size=0.01). show() stays the cell's last expression.
_preview = data.TargetPoseDemo(
    scene_pcd=preprocess.downsample(demo.scene_pcd, voxel_size=0.01),
    grasp_pcd=preprocess.downsample(demo.grasp_pcd, voxel_size=0.01),
    target_poses=Ts,
)
_preview.show()