In [1]:
import os
os.environ["PYTORCH_JIT_USE_NNC_NOT_NVFUSER"] = "1"
from typing import List, Tuple, Optional, Union, Iterable
import warnings
import math

from beartype import beartype
import datetime
import plotly.graph_objects as go
from tqdm import tqdm
import yaml

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from e3nn import o3

from edf_interface import data
from diffusion_edf.gnn_data import FeaturedPoints
from diffusion_edf import train_utils
from diffusion_edf.trainer import DiffusionEdfTrainer
from diffusion_edf.visualize import visualize_pose
from diffusion_edf.agent import DiffusionEdfAgent

torch.set_printoptions(precision=4, sci_mode=False)

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
import torch_cluster, torch_scatter
from edf_interface.utils.collision_utils import _pcd_energy, check_pcd_collision
from edf_interface.data.pcd_utils import transform_points

In [3]:
# Experiment configuration.
device = 'cuda:0'
task_type = 'place'  # selects the 'pick' or 'place' half of each demonstration
config_root_dir = 'configs/sapien'

# Demonstration dataset used throughout this notebook.
testset = data.DemoDataset(dataset_dir='demo/sapien_demo_20230625')

In [4]:
# Select one demonstration and the pick/place step matching `task_type`.
# Unknown task types are rejected instead of silently falling back to 'place'.
idx = 0
_step_index = {'pick': 0, 'place': 1}
if task_type not in _step_index:
    raise ValueError(f"task_type must be either 'pick' or 'place', got {task_type!r}")
demo = testset[idx][_step_index[task_type]]
scene_pcd, grasp_pcd, target_poses = demo.scene_pcd, demo.grasp_pcd, demo.target_poses

In [5]:
# x: scene points.
# y: grasp points transformed to each target pose, stacked along a new leading
#    (pose) dimension.
x = scene_pcd.points
y = torch.stack(
    [posed.points for posed in grasp_pcd.transformed(target_poses)],
    dim=0,
)

In [6]:
# Collision energy between the scene points x and the posed grasp points y,
# together with its analytic gradient (one 6-vector per pose, per the output below).
energy, grad = _pcd_energy(
    x,
    y,
    cutoff_r=0.05,
    eps=0.001,
    max_num_neighbor=100,
    cluster_method='knn',
)
energy, grad

(tensor([29598.8633]),
 tensor([[  38627.4297,  137114.0938,  -15166.8398,  144367.3906,  -51188.8672,
          -114749.8594]]))

In [None]:
# Finite-difference step size and the six scaled basis vectors (rows of `lie`)
# used to perturb each of the 6 gradient coordinates in turn.
torch.set_printoptions(sci_mode=False, precision=4)
dt = 0.001
lie = dt * torch.eye(6)
lie

In [None]:
# Finite-difference check of the analytic gradient from _pcd_energy:
# for each of the 6 coordinates, perturb y by data.se3._exp_map(dt * e_axis) and
# compare (E(perturbed) - E) / dt against grad[0, axis].
# The loop variable is `axis` (not `idx`) so the demo index defined earlier is not clobbered.
for axis in range(6):
    T = data.se3._exp_map(lie[axis].unsqueeze(0))
    y_new = transform_points(points=y, Ts=T, batched_pcd=True)
    energy_new, _ = _pcd_energy(x, y_new, cutoff_r=0.05, eps=0.001, max_num_neighbor=100, cluster_method='knn')
    num_grad = (energy_new - energy) / dt
    print(f"analytic_grad: {grad[0, axis].item()} || num_grad: {num_grad.item()}")

In [None]:

def _optimize_pcd_collision_once(x: torch.Tensor,
                                 y: torch.Tensor,
                                 dt: float,
                                 cutoff_r: float,
                                 max_num_neighbors: int = 100,
                                 eps: float = 0.01,
                                 cluster_method: str = 'knn',
                                 verbose: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Take one gradient step that moves the point cloud(s) ``y`` to reduce collision energy with ``x``.

    The energy and its gradient are computed with ``_pcd_energy``. The gradient is
    rescaled by ``cutoff_r**3``. Then ``-grad * dt`` is mapped through
    ``data.se3._exp_map`` to a per-pose displacement, which is applied to ``y``.

    Args:
        x: Reference (scene) points, shape ``(nX, 3)``.
        y: Points to move, shape ``(nPose, nY, 3)``. A single ``(nY, 3)`` cloud is
            promoted to ``(1, nY, 3)``.
        dt: Step size multiplying the (rescaled) negative gradient.
        cutoff_r: Cutoff radius forwarded to ``_pcd_energy``; also used to rescale
            the gradient.
        max_num_neighbors: Forwarded to ``_pcd_energy`` as ``max_num_neighbor``.
        eps: Forwarded to ``_pcd_energy``.
        cluster_method: Neighbor-search method forwarded to ``_pcd_energy``.
        verbose: If True, print the rescaled gradient (useful for debugging).

    Returns:
        ``(disp_pose, y_moved, energy)``:
        - ``disp_pose``: the displacement poses.
        - ``y_moved``: ``y`` transformed by ``disp_pose``.
        - ``energy``: the energy evaluated at the *input* ``y``, i.e. before the step.
    """
    assert x.ndim == 2 and x.shape[-1] == 3, f"{x.shape}"  # (nX, 3)
    if y.ndim == 2:
        y = y.unsqueeze(0)
    assert y.ndim == 3 and y.shape[-1] == 3, f"{y.shape}"  # (nPose, nY, 3)

    energy, grad = _pcd_energy(x, y, cutoff_r=cutoff_r, eps=eps,
                               max_num_neighbor=max_num_neighbors,
                               cluster_method=cluster_method)

    # Rescale the gradient by cutoff_r**3 before stepping. The rationale is not
    # documented here; presumably it normalizes the magnitude w.r.t. the cutoff (TODO confirm).
    grad = grad * (cutoff_r ** 3)
    if verbose:
        print(grad)
    disp = -grad * dt
    disp_pose = data.se3._exp_map(disp)  # (n_poses, 7)

    return disp_pose, transform_points(y, disp_pose, batched_pcd=True), energy

In [None]:
# disp_pose, y_new, energy = _optimize_pcd_collision_once(x=x, y=y, dt=0.0001, cutoff_r=0.05)

In [None]:
# Run 30 collision-resolution steps starting from the demo's target poses.
# Each step's displacement is combined with the previous pose via data.se3._multiply,
# building a trajectory: [initial pose, pose after step 1, ..., pose after step 30].
pose_history = [target_poses.poses]
y_new = y
for step in range(30):
    disp_pose, y_new, energy = _optimize_pcd_collision_once(x=x, y=y_new, dt=0.001, cutoff_r=0.03)
    prev_pose = pose_history[-1]
    pose_history.append(data.se3._multiply(disp_pose, prev_pose))  ## TODO: Make it right lie-deriv
poses = torch.cat(pose_history, dim=0)

In [None]:
# Display the collision-optimization trajectory: the original target pose followed
# by the pose after each optimization step.
optimized_demo = data.TargetPoseDemo(
    scene_pcd=scene_pcd,
    grasp_pcd=grasp_pcd,
    target_poses=data.SE3(poses=poses),
)
optimized_demo.show()

In [None]:
# The unmodified demonstration, for comparison with the optimized trajectory.
demo.show(height=600, width=600)

In [None]:
# Scene merged with the grasp cloud at its final optimized position (first pose only).
final_grasp_pcd = grasp_pcd.new(points=y_new[0])
data.PointCloud.merge(scene_pcd, final_grasp_pcd).show()

In [None]:
# TODO(review): unfinished cell. The original `for _ in range()` was a SyntaxError
# that aborted Restart & Run All. Complete the intended loop here or delete this cell.

In [None]:
# Recompute energy/gradient at the original (unoptimized) y.
# max_num_neighbor and cluster_method are left at _pcd_energy's defaults here.
energy, grad = _pcd_energy(x, y, eps=0.001, cutoff_r=0.05)

In [None]:
# Model kwargs for the selected task, read from agent.yaml.
with open(os.path.join(config_root_dir, 'agent.yaml')) as f:
    agent_yaml = yaml.load(f, Loader=yaml.FullLoader)
model_kwargs_list = agent_yaml['model_kwargs'][f"{task_type}_models_kwargs"]

# preprocess.yaml carries both the preprocessing config and the matching "unprocess" config.
with open(os.path.join(config_root_dir, 'preprocess.yaml')) as f:
    preprocess_yaml = yaml.load(f, Loader=yaml.FullLoader)
unprocess_config = preprocess_yaml['unprocess_config']
preprocess_config = preprocess_yaml['preprocess_config']

agent = DiffusionEdfAgent(
    model_kwargs_list=model_kwargs_list,
    preprocess_config=preprocess_config,
    unprocess_config=unprocess_config,
    device=device,
)

# Initialize Input Data and Initial Pose

In [None]:
# Load the demo step for `task_type` onto `device` and build the initial pose.
# Classes are referenced through the imported `data` module. The bare names
# TargetPoseDemo / PointCloud / SE3 are never imported, and module-level
# annotations are evaluated, so using them raises NameError.
_task_index = {'pick': 0, 'place': 1}
if task_type not in _task_index:
    raise ValueError(f"task_type must be either 'pick' or 'place', got {task_type!r}")
demo: data.TargetPoseDemo = testset[0][_task_index[task_type]].to(device)
scene_pcd: data.PointCloud = demo.scene_pcd
grasp_pcd: data.PointCloud = demo.grasp_pcd
# A single pose built from [1, 0, 0, 0] followed by [0, 0, 0.8]. Presumably an
# identity quaternion (w, x, y, z) plus a translation of 0.8 along z (TODO confirm layout).
T0 = torch.cat([
    torch.tensor([[1., 0., 0.0, 0.]], device=device),
    torch.tensor([[0., 0., 0.8]], device=device)
], dim=-1)
Ts_init = data.SE3(poses=T0).to(device)


In [None]:
# Sample poses with the diffusion agent, starting from Ts_init.
# Each *_list argument has two entries: presumably one per model/stage in
# model_kwargs_list (TODO confirm).
Ts_out, scene_proc, grasp_proc = agent.sample(
    scene_pcd=scene_pcd,
    grasp_pcd=grasp_pcd,
    Ts_init=Ts_init,
    N_steps_list=[[500, 500], [500, 1000]],
    timesteps_list=[[0.02, 0.02], [0.02, 0.02]],
    temperature_list=[1., 1.],
)

In [None]:
# Visualize one sampled trajectory: every 10th pose along the first axis of
# Ts_out, plus the final pose.
# Uses data.TargetPoseDemo / data.SE3, because the bare names were never imported (NameError).
sample_idx = 0
visualization = data.TargetPoseDemo(
    target_poses=data.SE3(poses=torch.cat([Ts_out[::10, sample_idx], Ts_out[-1:, sample_idx]], dim=0)),
    scene_pcd=scene_proc,
    grasp_pcd=grasp_proc
)
# Pass through the agent's unprocess function and move to CPU before rendering.
visualization = agent.unprocess_fn(visualization).to('cpu')
visualization.show()