In [None]:
import os

from edf.pc_utils import draw_geometry, create_o3d_points, get_plotly_fig
from edf.data import PointCloud, SE3, TargetPoseDemo, DemoSequence, DemoSeqDataset, gzip_load
from edf.preprocess import Rescale, NormalizeColor, Downsample, PointJitter, ColorJitter
from edf.agent import PickAgent, PlaceAgent

import numpy as np
import yaml
import plotly as pl
import plotly.express as ple
import open3d as o3d

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import Compose

# Tensor printing for the whole notebook: 3 decimal places, no scientific
# notation, and lines wrapped at 120 characters.
torch.set_printoptions(
    precision=3,
    sci_mode=False,
    linewidth=120,
)

In [None]:
# Relative paths used by this notebook.
# NOTE: despite the `_dir` suffix, the two config entries are paths to YAML files;
# only `agent_param_dir` is a folder (the training log is loaded from it).
agent_param_dir = "checkpoint/mug_10_demo/place_dev"
train_config_dir = "config/train_config/train_place_dev.yaml"
agent_config_dir = "config/agent_config/place_agent_dev.yaml"

In [None]:
# Load the training log written for one iteration and unpack the entries
# that the visualization cells below rely on.
log_iter = 1  # selects which `trainlog_iter_<N>.gzip` file is read
train_logs = gzip_load(dir=agent_param_dir, filename=f"trainlog_iter_{log_iter}.gzip")

# Raw point clouds stored in the log.
scene_raw: PointCloud = train_logs['scene_raw']
grasp_raw: PointCloud = train_logs['grasp_raw']

# Query points and their attention values from the logged EDF outputs.
edf_outputs = train_logs['edf_outputs']
query_points = edf_outputs['query_points']
query_attention = edf_outputs['query_attention']

# Logged transforms wrapped as SE3 objects.
# NOTE(review): 'best_neg_T' presumably is the best-scoring sampled (negative)
# pose and 'sampled_Ts' the full sample set -- confirm against the training code.
target_pose = SE3(train_logs['target_T'])
best_pose = SE3(train_logs['best_neg_T'])
sampled_poses = SE3(train_logs['sampled_Ts'])

In [None]:
# Plotly trace for the grasp point cloud.
grasp_pl = grasp_raw.plotly(point_size=1.0, name="grasp")

# Re-read the query points and attention from the log so this cell only
# needs `train_logs` and `grasp_raw` to be defined.
query_points, query_attention = (
    train_logs['edf_outputs']['query_points'],
    train_logs['edf_outputs']['query_attention'],
)

# The exponent is currently 1, so the opacity equals the attention values.
# NOTE(review): presumably a knob for sharpening/flattening the opacity -- confirm.
query_opacity = query_attention ** 1

# Query points drawn with opacity scaled by the maximum attention value.
query_pl = PointCloud.points_to_plotly(
    pcd=query_points,
    point_size=15.0,
    opacity=query_opacity / query_opacity.max(),
)

In [None]:
# Merge the scene with the grasp point cloud transformed by `target_pose`
# (only element [0] of the `transformed` result is used).
target_pcd = PointCloud.merge(
    scene_raw,
    grasp_raw.transformed(target_pose)[0],
)
target_pl = target_pcd.plotly(point_size=1.0)

# Build the figure; this cell ends in an assignment, so nothing is rendered here.
fig_target = get_plotly_fig("Target Placement").add_traces([target_pl])

In [None]:
# Merge the scene with the grasp point cloud transformed by `best_pose`
# (only element [0] of the `transformed` result is used).
best_sample_pcd = PointCloud.merge(
    scene_raw,
    grasp_raw.transformed(best_pose)[0],
)
best_sample_pl = best_sample_pcd.plotly(point_size=1.0)

# Points of all sampled poses, drawn in a single fixed color.
sample_pl = PointCloud.points_to_plotly(
    pcd=sampled_poses.points,
    point_size=7.0,
    colors=[0.2, 0.5, 0.8],
)

# Build the figure; this cell ends in an assignment, so nothing is rendered here.
fig_sample = get_plotly_fig("Sampled Placement").add_traces([best_sample_pl, sample_pl])