In [None]:
import os
import argparse
from typing import Tuple, Dict, Optional, Union, Any

import torch
from torchvision.transforms import Compose

from edf.pc_utils import draw_geometry, create_o3d_points, get_plotly_fig
from edf.data import PointCloud, SE3, TargetPoseDemo, DemoSequence, DemoSeqDataset, gzip_load
from edf.preprocess import Rescale, NormalizeColor, Downsample, PointJitter, ColorJitter
from edf.agent import PickAgent, PlaceAgent


# Compact tensor printing for the whole notebook: 3 decimals, no scientific
# notation, and lines wrapped at 120 characters.
torch.set_printoptions(
    linewidth=120,
    sci_mode=False,
    precision=3,
)

# 1. Define utility functions for data loading and visualization

In [None]:
def get_raw_pointcloud(**kwargs) -> Tuple[PointCloud, PointCloud]:
    """Load the raw scene and grasp point clouds of one demonstration step.

    Keyword Args:
        dir: Dataset directory passed to ``DemoSeqDataset``. Optional; defaults
            to ``"demo/test_demo"`` (the path that used to be hard-coded here).
        idx: Index of the demo sequence within the dataset.
        pick_or_place: ``'pick'`` selects element 0 of the sequence,
            ``'place'`` selects element 1.

    Returns:
        ``(scene_pcd, grasp_pcd)`` read from the selected ``TargetPoseDemo``.

    Raises:
        ValueError: if ``pick_or_place`` is neither ``'pick'`` nor ``'place'``.
    """
    ################### Write your custom codes here ###################
    # Previously the `dir` kwarg was read but ignored in favour of a hard-coded path.
    dataset_dir = kwargs.get('dir', "demo/test_demo")
    idx, pick_or_place = kwargs['idx'], kwargs['pick_or_place']

    # Validate before touching the disk so a typo fails fast.
    if pick_or_place == 'pick':
        step_idx = 0
    elif pick_or_place == 'place':
        step_idx = 1
    else:
        raise ValueError(f"Wrong value for pick_or_place argument: {pick_or_place}")

    demos = DemoSeqDataset(dataset_dir=dataset_dir, annotation_file="data.yaml")
    demo_seq: DemoSequence = demos[idx]
    demo: TargetPoseDemo = demo_seq[step_idx]

    scene_pcd: PointCloud = demo.scene_pc
    grasp_pcd: PointCloud = demo.grasp_pc
    ####################################################################

    return scene_pcd, grasp_pcd


def visualize(scene_pcd: PointCloud, grasp_pcd: PointCloud, pose: SE3, sampled_poses: Optional[SE3] = None, query: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, attention_power: float = 1.0):
    """Build two plotly figures: the grasp with its query points, and the placed scene.

    Args:
        scene_pcd: Scene point cloud.
        grasp_pcd: Grasp point cloud, drawn alone in the first figure.
        pose: Pose applied to ``grasp_pcd`` before it is merged into the scene.
        sampled_poses: Optional poses whose ``.points`` are drawn as markers
            in the placement figure.
        query: Optional ``(query_points, query_attention)`` pair. Query points are
            drawn on the grasp figure with opacity ``query_attention ** attention_power``
            divided by its maximum.
        attention_power: Exponent applied to the attention before normalisation.
            The default 1.0 reproduces the original behaviour.

    Returns:
        ``(fig_grasp, fig_sample)``.
    """
    # --- Grasp figure ---
    grasp_geometry = [grasp_pcd.plotly(point_size=1.0, name="grasp")]
    if query is not None:
        query_points, query_attention = query
        query_opacity = query_attention ** attention_power
        max_opacity = query_opacity.max()
        # Scale so the most-attended point is fully opaque; skip when the max is
        # not positive, otherwise the division would yield NaN/inf opacities.
        if max_opacity > 0:
            query_opacity = query_opacity / max_opacity
        query_pl = PointCloud.points_to_plotly(pcd=query_points, point_size=15.0, opacity=query_opacity)
        grasp_geometry.append(query_pl)
    fig_grasp = get_plotly_fig("Grasp").add_traces(grasp_geometry)

    # --- Placement figure: scene merged with the grasp transformed by `pose` ---
    # NOTE(review): `transformed(pose)` is indexed with [0]; presumably it returns
    # one cloud per pose — confirm.
    placed_grasp = grasp_pcd.transformed(pose)[0]
    placement_geometry = [PointCloud.merge(scene_pcd, placed_grasp).plotly(point_size=1.0)]
    if sampled_poses is not None:
        sample_pl = PointCloud.points_to_plotly(pcd=sampled_poses.points, point_size=7.0, colors=[0.2, 0.5, 0.8])
        placement_geometry.append(sample_pl)
    fig_sample = get_plotly_fig("Sampled Placement").add_traces(placement_geometry)

    return fig_grasp, fig_sample

# 2. Load and warm up the model

In [None]:
# Use the first GPU when available; fall back to CPU so the notebook still runs
# on machines without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Raw clouds are scaled by 1/unit_len before entering the model, and model
# outputs are scaled back by unit_len (see the preprocessing / recover_scale cell).
unit_len = 0.01

##### Initialize Place Agent #####
place_agent_config_dir = "config/agent_config/place_agent.yaml"
place_agent_param_dir = "checkpoint/mug_10_demo/place/model_iter_600.pt"
max_N_query_place = 3
langevin_dt_place = 0.001

place_agent = PlaceAgent(config_dir=place_agent_config_dir,
                         device=device,
                         max_N_query=max_N_query_place,
                         langevin_dt=langevin_dt_place).requires_grad_(False)

# NOTE(review): strict=False presumably tolerates missing/unexpected checkpoint
# keys — confirm the checkpoint matches the config.
place_agent.load(place_agent_param_dir, strict=False)
place_agent.warmup(warmup_iters=10, N_poses=100, N_points_scene=1500, N_points_grasp=900)


# Preprocessing: rescale raw clouds by 1/unit_len, voxel-downsample, then
# normalise colours with per-channel mean/std of 0.5.
# The grasp cloud uses a smaller voxel_size than the scene.
scene_proc_fn = Compose([
    Rescale(rescale_factor=1/unit_len),
    Downsample(voxel_size=1.7, coord_reduction="average"),
    NormalizeColor(
        color_mean=torch.tensor([0.5, 0.5, 0.5]),
        color_std=torch.tensor([0.5, 0.5, 0.5]),
    ),
])
grasp_proc_fn = Compose([
    Rescale(rescale_factor=1/unit_len),
    Downsample(voxel_size=1.4, coord_reduction="average"),
    NormalizeColor(
        color_mean=torch.tensor([0.5, 0.5, 0.5]),
        color_std=torch.tensor([0.5, 0.5, 0.5]),
    ),
])
# Inverse of the Rescale step above, used to map model outputs back to raw units.
recover_scale = Rescale(rescale_factor=unit_len)

# 3. Inference

### 3.1 Get Point Clouds

In [None]:
# Which demo sequence to run on; 'place' selects the placement step of that sequence.
file_idx = 0
scene_raw, grasp_raw = get_raw_pointcloud(
    dir='demo/test_demo',
    idx=file_idx,
    pick_or_place='place',
)

### 3.2 Inference: Sample placement poses

In [None]:
# NOTE(review): the MH / Langevin sampling in place_agent.forward is presumably
# stochastic; seed the RNG (all devices) so re-running this cell reproduces the
# same poses.
sampling_seed = 0
torch.manual_seed(sampling_seed)

# Bring the raw clouds into model units and move them to the inference device.
scene_proc = scene_proc_fn(scene_raw).to(device)
grasp_proc = grasp_proc_fn(grasp_raw).to(device)

# Sampling / optimisation hyper-parameters for the place agent.
T_seed = 100
place_policy = 'sorted'
place_mh_iter = 1000
place_langevin_iter = 300
place_dist_temp = 1.
place_policy_temp = 1.
place_optim_iter = 100
place_optim_lr = 0.005
place_query_temp = 1.

Ts, edf_outputs, logs = place_agent.forward(
    scene=scene_proc, T_seed=T_seed, grasp=grasp_proc, policy=place_policy,
    mh_iter=place_mh_iter, langevin_iter=place_langevin_iter,
    temperature=place_dist_temp, policy_temperature=place_policy_temp,
    optim_iter=place_optim_iter, optim_lr=place_optim_lr,
    query_temperature=place_query_temp,
)

# Poses come out in model units; rescale them back to raw-cloud units.
place_poses = recover_scale(SE3(Ts.cpu()))

# 4. Visualize results

In [None]:
# Draw the grasp placed at the first returned pose (presumably the best-ranked
# one with policy='sorted'), and show every sampled pose as a marker.
fig_grasp, fig_sample = visualize(
    scene_pcd=scene_raw,
    grasp_pcd=grasp_raw,
    pose=place_poses[0],
    sampled_poses=place_poses,
    # Query points are in model units; multiply by unit_len to overlay them on grasp_raw.
    query=(edf_outputs['query_points'] * unit_len, edf_outputs['query_attention']),
)

In [None]:
# Scene merged with the grasp cloud transformed by the first pose, plus sampled-pose markers.
fig_sample.show()

In [None]:
# Grasp point cloud with query points whose opacity is the max-normalised attention.
fig_grasp.show()