In [None]:
import os
os.environ["PYTORCH_JIT_USE_NNC_NOT_NVFUSER"] = "1"
from typing import List, Tuple, Optional, Union, Iterable
import warnings
import math

from beartype import beartype
import datetime
import plotly.graph_objects as go
from tqdm import tqdm
import yaml

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from e3nn import o3

from edf_interface import data
from diffusion_edf.gnn_data import FeaturedPoints
from diffusion_edf import train_utils
from diffusion_edf.trainer import DiffusionEdfTrainer
from diffusion_edf.visualize import visualize_pose
from diffusion_edf.agent import DiffusionEdfAgent

torch.set_printoptions(precision=4, sci_mode=False)

In [None]:
import torch_cluster, torch_scatter
from edf_interface.utils.collision_utils import _pcd_energy, check_pcd_collision, _se3_adjoint_lie_grad
from edf_interface.data.pcd_utils import transform_points

In [None]:
# Run configuration. `task_type` selects the pick (index 0) or place (index 1)
# entry of each demo in `testset`.
device, task_type, config_root_dir = 'cuda:0', 'place', 'configs/sapien'
testset = data.DemoDataset(dataset_dir='demo/sapien_demo_20230625')

In [None]:
# Pick episode `idx` and take its pick (0) or place (1) demonstration.
# Fail loudly on a typo in `task_type` instead of silently falling back to 'place'.
idx = 0
if task_type not in ('pick', 'place'):
    raise ValueError(f"task_type must be either 'pick' or 'place', got {task_type!r}")
demo = testset[idx][0 if task_type == 'pick' else 1]
scene_pcd, grasp_pcd, target_poses = demo.scene_pcd, demo.grasp_pcd, demo.target_poses

In [None]:
# Scene points, grasp points with a leading (pose-batch) axis added, and the target poses.
# Alternative kept for reference:
# y = torch.stack([pcd.points for pcd in grasp_pcd.transformed(target_poses)], dim=0)
x, y, Ts = scene_pcd.points, grasp_pcd.points[None], target_poses.poses

In [None]:
# Collision energy of the grasp cloud placed at the target poses against the scene,
# its Lie-algebra gradient, and the adjoint-transformed gradient.
cutoff_r = 0.03

energy, grad = _pcd_energy(
    x,
    transform_points(y, Ts, batched_pcd=True),
    cutoff_r=cutoff_r,
    eps=0.001,
    max_num_neighbor=100,
    cluster_method='knn',
)
adj_grad = _se3_adjoint_lie_grad(target_poses.poses, grad)
energy, grad, adj_grad

In [None]:
# Finite-difference step size and the six Lie-algebra basis vectors scaled by it (one per row).
torch.set_printoptions(sci_mode=False, precision=4)
dt = 0.001
lie = dt * torch.eye(6)
lie

In [None]:
# Finite-difference check of `grad` under a LEFT perturbation T <- exp(dt * e_axis) · T,
# one Lie-algebra axis at a time (rx, ry, rz, vx, vy, vz).
# The loop variable is `axis`, not `idx`, so the demo index `idx` chosen earlier is not overwritten.
for axis in range(6):
    new_Ts = data.se3._multiply(data.se3._exp_map(lie[axis].unsqueeze(0)), Ts)
    y_new = transform_points(points=y, Ts=new_Ts, batched_pcd=True)
    # Only the energy is needed here; the gradient output is discarded.
    energy_new = _pcd_energy(x, y_new, cutoff_r=cutoff_r, eps = 0.001, max_num_neighbor=100, cluster_method='knn')[0]
    # Forward difference, so agreement is only expected up to O(dt).
    num_grad = (energy_new - energy) / dt
    print(f"analytic_grad: {grad[0,axis].item()} || num_grad: {num_grad.item()}")

In [None]:
# Finite-difference check of `adj_grad` under a RIGHT perturbation T <- T · exp(dt * e_axis),
# one Lie-algebra axis at a time (rx, ry, rz, vx, vy, vz).
# The loop variable is `axis`, not `idx`, so the demo index `idx` chosen earlier is not overwritten.
for axis in range(6):
    new_Ts = data.se3._multiply(Ts, data.se3._exp_map(lie[axis].unsqueeze(0)))
    y_new = transform_points(points=y, Ts=new_Ts, batched_pcd=True)
    # Only the energy is needed here; the gradient output is discarded.
    energy_new = _pcd_energy(x, y_new, cutoff_r=cutoff_r, eps = 0.001, max_num_neighbor=100, cluster_method='knn')[0]
    # Forward difference, so agreement is only expected up to O(dt).
    num_grad = (energy_new - energy) / dt
    print(f"analytic_grad: {adj_grad[0,axis].item()} || num_grad: {num_grad.item()}")

In [None]:
from edf_interface.data import transforms

In [None]:
@torch.jit.script
def _se3_adjoint_lie_grad(Ts: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
    """Apply the transposed SE(3) adjoint of `Ts` to a Lie-algebra gradient.

    This notebook-local definition shadows the `_se3_adjoint_lie_grad` imported from
    `edf_interface.utils.collision_utils`.

    Args:
        Ts (torch.Tensor): (..., 7), (qw, qx, qy, qz, x, y, z)
        grad (torch.Tensor): (..., 6), (rx, ry, rz, vx, vy, vz)

    Returns:
        adjoint_grad (torch.Tensor): (..., 6), (rx, ry, rz, vx, vy, vz)

    Note:
    L_v f(g_0 g x) = L_{[Ad_g0]v} f(g g_0 x)
    => Grad_{g} f(g_0 g x) = [Ad_g0]^{Transpose} Grad_{g} f(g g_0 x)
    Note that the gradient takes the transpose of the adjoint matrix!
    [Ad_T]^{Transpose} = [
        [R^{-1},   -R^{-1} skew(p)],
        [     0,        R^{-1}    ]
    ]
    i.e. adj_grad_R = R^{-1} (grad_R - p x grad_v), adj_grad_v = R^{-1} grad_v.
    """
    assert Ts.shape[-1] == 7, f"{Ts.shape}"
    assert grad.shape[-1] == 6, f"{grad.shape}"
    # Compare the leading batch dimension only (the feature dims differ: 7 vs 6).
    # Slicing off the feature dim first lets unbatched (7,) / (6,) inputs pass.
    ts_batch = Ts.shape[:-1]
    grad_batch = grad.shape[:-1]
    assert ts_batch[:1] == grad_batch[:1], f"{Ts.shape}, {grad.shape}"

    qinv = transforms.quaternion_invert(Ts[..., :4]) # (..., 4)
    # dim=-1 is explicit: without it torch.cross uses the FIRST dim of size 3, which is
    # wrong whenever a batch dimension happens to have size 3.
    adj_grad_R = grad[..., :3] - torch.cross(Ts[..., 4:], grad[..., 3:], dim=-1) # (..., 3)
    adj_grad_R = transforms.quaternion_apply(qinv, adj_grad_R) # (..., 3)
    adj_grad_v = transforms.quaternion_apply(qinv, grad[..., 3:]) # (..., 3)

    adj_grad = torch.cat([adj_grad_R, adj_grad_v], dim=-1) # (..., 6)

    return adj_grad

In [None]:
# @torch.jit.script
def _optimize_pcd_collision_once(x: torch.Tensor, 
                                 y: torch.Tensor, 
                                 Ts: torch.Tensor,
                                 dt: float, 
                                 cutoff_r: float, 
                                 max_num_neighbors: int = 100,
                                 eps: float = 0.01,
                                 cluster_method: str = 'knn'):
    """Take one gradient step that moves the posed point cloud `y` to reduce its collision energy with `x`.

    The energy of `y` transformed by `Ts` against `x` is evaluated with `_pcd_energy`. Its
    gradient is passed through `_se3_adjoint_lie_grad`, and the poses are updated by
    right-multiplication: new_pose = Ts * exp(disp).

    Args:
        x: (nX, 3) static point cloud (e.g. the scene).
        y: (nPose, nY, 3) points to be moved, before applying `Ts`. A leading dim of 1 is
            broadcast to every pose in `Ts`.
        Ts: (nPose, 7) poses, (qw, qx, qy, qz, x, y, z).
        dt: step size.
        cutoff_r: cutoff radius forwarded to `_pcd_energy`; also used to scale the step (see below).
        max_num_neighbors, eps, cluster_method: forwarded to `_pcd_energy`.

    Returns:
        new_pose: (nPose, 7) updated poses.
        energy: (nPose,) energy evaluated at the INPUT poses `Ts` (i.e. before the step).
    """
    assert x.ndim == 2 and x.shape[-1] == 3, f"{x.shape}" # (nX, 3)
    assert y.ndim == 3 and y.shape[-1] == 3, f"{y.shape}" # (nPose, nY, 3)
    assert Ts.ndim == 2 and Ts.shape[-1] == 7, f"{Ts.shape}" # (nPose, 7)
    if len(y) == 1 and len(Ts) > 1:
        # The same local cloud is shared by every pose; materialize one copy per pose.
        y = y.expand(len(Ts), -1, -1).contiguous()
    assert len(Ts) == len(y), f"{Ts.shape}, {y.shape}"

    Ty = transform_points(y, Ts, batched_pcd=True) # (nPose, nY, 3)
    energy, grad = _pcd_energy(
        x=x, 
        y=Ty, 
        cutoff_r=cutoff_r, 
        eps = eps, 
        max_num_neighbor=max_num_neighbors, 
        cluster_method=cluster_method
    ) # (nPose,), (nPose, 6)
    assert isinstance(grad, torch.Tensor)
    # Map the gradient into the frame of the right-multiplied update applied below.
    grad = _se3_adjoint_lie_grad(Ts, grad) # (nPose, 6)

    # Translational components are scaled by cutoff_r here and the whole step by dt * cutoff_r,
    # so translation is effectively scaled by cutoff_r**2 and rotation by cutoff_r.
    # NOTE(review): presumably a unit-balancing heuristic — confirm. A normalized alternative
    # that was tried: disp = -grad / (grad.norm() + eps) * dt
    scale = torch.tensor([1., 1., 1., cutoff_r, cutoff_r, cutoff_r], device=grad.device, dtype=grad.dtype)
    disp = -(grad * scale) * dt * cutoff_r
    disp_pose = data.se3._exp_map(disp) # (nPose, 7)

    new_pose = data.se3._multiply(Ts, disp_pose)

    return new_pose, energy

In [None]:
# disp_pose, y_new, energy = _optimize_pcd_collision_once(x=x, y=y, dt=0.0001, cutoff_r=0.05)

In [None]:
# Run 30 collision-resolving steps from the demo target poses, keeping every intermediate pose.
poses = [target_poses.poses]
for i in range(30):
    new_pose, energy = _optimize_pcd_collision_once(
        x=scene_pcd.points,
        y=grasp_pcd.points.unsqueeze(0),
        Ts=poses[-1],
        dt=3e-5,
        cutoff_r=0.03,
    )
    poses.append(new_pose)
poses = torch.cat(poses, dim=0)

In [None]:
# Visualize the initial pose together with every optimization step.
data.TargetPoseDemo(
    target_poses=data.SE3(poses=poses),
    scene_pcd=scene_pcd,
    grasp_pcd=grasp_pcd,
).show()

In [None]:
# Deliberate stop so "Run All" halts after the collision-optimization experiments above.
# Run the cells below manually. This replaces an undefined name (`asdf`), whose NameError hid the intent.
raise RuntimeError("Execution intentionally stopped here; run the following cells manually.")

In [None]:
# Render the selected demonstration in a 600x600 view.
demo.show(height=600, width=600)

In [None]:
# Overlay the scene with the grasp cloud at the first pose of `y_new`.
# NOTE(review): `y_new` is whatever the most recent finite-difference loop left behind (hidden state).
perturbed_grasp_pcd = grasp_pcd.new(points=y_new[0])
data.PointCloud.merge(scene_pcd, perturbed_grasp_pcd).show()

In [None]:
# TODO(review): unfinished cell. It previously contained `for _ in range()`, with no colon,
# no iteration count and no body, which is a SyntaxError. Fill in the intended loop or delete this cell.

In [None]:
# Collision energy of `y` as-is (no pose applied) against the scene, with a larger cutoff.
# NOTE(review): unlike the earlier cells, `y` is not transformed by `Ts`, and `_pcd_energy`'s
# default neighbor/cluster settings are used; this also overwrites `energy`/`grad`. Confirm this is intended.
energy, grad = _pcd_energy(x, y, eps=0.001, cutoff_r=0.05)

In [None]:
# Build the diffusion agent from the YAML configs under `config_root_dir`:
# the task-specific model kwargs come from agent.yaml, and the (un)preprocessing settings
# come from preprocess.yaml.
with open(os.path.join(config_root_dir, 'agent.yaml')) as f:
    agent_yaml = yaml.load(f, Loader=yaml.FullLoader)
model_kwargs_list = agent_yaml['model_kwargs'][f"{task_type}_models_kwargs"]

with open(os.path.join(config_root_dir, 'preprocess.yaml')) as f:
    preprocess_yaml = yaml.load(f, Loader=yaml.FullLoader)
unprocess_config = preprocess_yaml['unprocess_config']
preprocess_config = preprocess_yaml['preprocess_config']

agent = DiffusionEdfAgent(
    device=device,
    model_kwargs_list=model_kwargs_list,
    preprocess_config=preprocess_config,
    unprocess_config=unprocess_config,
)

# Initialize Input Data and Initial Pose

In [None]:
# Load episode 0 of the selected task onto `device` and build the initial pose:
# identity rotation (qw=1) at translation (0, 0, 0.8).
# Classes are referenced through `data.`; the bare names TargetPoseDemo/PointCloud/SE3 were never imported.
if task_type not in ('pick', 'place'):
    raise ValueError(f"task_type must be either 'pick' or 'place', got {task_type!r}")
demo: data.TargetPoseDemo = testset[0][0 if task_type == 'pick' else 1].to(device)
scene_pcd: data.PointCloud = demo.scene_pcd
grasp_pcd: data.PointCloud = demo.grasp_pcd
T0 = torch.cat([
    torch.tensor([[1., 0., 0.0, 0.]], device=device),  # quaternion (qw, qx, qy, qz)
    torch.tensor([[0., 0., 0.8]], device=device)        # translation (x, y, z)
], dim=-1)
Ts_init = data.SE3(poses=T0).to(device)


In [None]:
# Sample poses from the agent starting at `Ts_init`, with the step counts, timesteps
# and temperatures given below.
Ts_out, scene_proc, grasp_proc = agent.sample(
    scene_pcd=scene_pcd,
    grasp_pcd=grasp_pcd,
    Ts_init=Ts_init,
    N_steps_list=[[500, 500], [500, 1000]],
    timesteps_list=[[0.02, 0.02], [0.02, 0.02]],
    temperature_list=[1., 1.],
)

In [None]:
# Show every 10th pose along the first axis of `Ts_out` (presumably the sampling steps)
# for one sample, plus the final pose. The result is passed through `agent.unprocess_fn`
# and moved to the CPU before rendering.
# Classes are referenced through `data.`; the bare names TargetPoseDemo/SE3 were never imported.
sample_idx = 0
visualization = data.TargetPoseDemo(
    target_poses=data.SE3(poses=torch.cat([Ts_out[::10, sample_idx], Ts_out[-1:, sample_idx]], dim=0)),
    scene_pcd=scene_proc,
    grasp_pcd=grasp_proc
)
visualization = agent.unprocess_fn(visualization).to('cpu')
visualization.show()