In [None]:
import os
os.environ["PYTORCH_JIT_USE_NNC_NOT_NVFUSER"] = "1"
from typing import List, Tuple, Optional, Union, Iterable
import warnings
import math

from beartype import beartype
import datetime
import plotly.graph_objects as go
from tqdm import tqdm
import yaml

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from e3nn import o3

from edf_interface.data import PointCloud, SE3, DemoDataset, TargetPoseDemo
from diffusion_edf.gnn_data import FeaturedPoints
from diffusion_edf import train_utils
from diffusion_edf.trainer import DiffusionEdfTrainer
from diffusion_edf.visualize import visualize_pose
from diffusion_edf.agent import DiffusionEdfAgent

torch.set_printoptions(precision=4, sci_mode=False)

In [None]:
# Notebook-wide settings read by the cells below.
# Torch device string; passed to the agent and used to move the demo data.
device = 'cuda:0'
# Selects which model kwargs are read from agent.yaml and which demo entry is used
# ('pick' or 'place').
task_type = 'place'
# Directory that holds agent.yaml and preprocess.yaml.
config_root_dir = 'configs/sapien'
# Demonstration dataset from which the input point clouds are taken.
testset = DemoDataset(dataset_dir='demo/sapien_demo_20230625')

In [None]:
# --- Model configuration for the selected task -------------------------------
# NOTE(review): yaml.load with FullLoader is kept as-is; if these configs are
# plain data, yaml.safe_load would be the more conservative choice.
agent_config_path = os.path.join(config_root_dir, 'agent.yaml')
with open(agent_config_path) as agent_file:
    agent_yaml = yaml.load(agent_file, Loader=yaml.FullLoader)
    model_kwargs_list = agent_yaml['model_kwargs'][f"{task_type}_models_kwargs"]

# --- Pre-/un-processing configuration -----------------------------------------
preprocess_config_path = os.path.join(config_root_dir, 'preprocess.yaml')
with open(preprocess_config_path) as preprocess_file:
    preprocess_yaml = yaml.load(preprocess_file, Loader=yaml.FullLoader)
    unprocess_config = preprocess_yaml['unprocess_config']
    preprocess_config = preprocess_yaml['preprocess_config']

# --- Agent built from the loaded configuration --------------------------------
agent = DiffusionEdfAgent(
    model_kwargs_list=model_kwargs_list,
    preprocess_config=preprocess_config,
    unprocess_config=unprocess_config,
    device=device,
)

# Initialize Input Data and Initial Pose

In [None]:
# Select the demo for the current task: index 0 is used for 'pick', index 1 for 'place'.
demo_index_by_task = {'pick': 0, 'place': 1}
if task_type not in demo_index_by_task:
    # Fail with a readable error instead of indexing the demo sequence with the
    # error-message string itself.
    raise ValueError("task_type must be either 'pick' or 'place'")
demo: TargetPoseDemo = testset[1][demo_index_by_task[task_type]].to(device)
scene_pcd: PointCloud = demo.scene_pcd
grasp_pcd: PointCloud = demo.grasp_pcd

# Initial pose as a single (1, 7) row: four components followed by three.
# NOTE(review): presumably (quaternion, position) -- i.e. identity rotation at
# z = 1.2 -- since the sampler in this notebook splits poses as T[..., :4] and
# T[..., 4:]; confirm the quaternion convention against SE3.
T0 = torch.cat([
    torch.tensor([[1., 0., 0.0, 0.]], device=device),
    torch.tensor([[0., 0., 1.2]], device=device)
], dim=-1)

# Every sample starts from the same initial pose.
N_samples = 10
Ts_init = SE3(poses=T0.repeat(N_samples,1)).to(device)

In [None]:
# Sampler settings, gathered in one place so they are easy to tweak.
# NOTE(review): the nested lists presumably hold one inner list per model and one
# entry per diffusion schedule -- confirm against DiffusionEdfAgent.sample.
sampling_kwargs = dict(
    N_steps_list=[[250, 250], [250, 250]],
    timesteps_list=[[0.02, 0.02], [0.02, 0.02]],
    temperatures_list=[[1., 1.], [1., 1.]],
    log_t_schedule=True,      # Original Behavior: False
    time_exponent_temp=1.0,   # Original Behavior: 0.5
    time_exponent_alpha=0.5,  # Original Behavior: 0.5
)

Ts_out, scene_proc, grasp_proc = agent.sample(
    scene_pcd=scene_pcd,
    grasp_pcd=grasp_pcd,
    Ts_init=Ts_init,
    **sampling_kwargs,
)

In [None]:
# Show the last entry of the sampled trajectory (all samples) together with the
# processed point clouds, mapped back through the agent's unprocess function.
final_poses = SE3(poses=Ts_out[-1])
visualization = agent.unprocess_fn(
    TargetPoseDemo(
        target_poses=final_poses,
        scene_pcd=scene_proc,
        grasp_pcd=grasp_proc,
    )
).to('cpu')
visualization.show()

In [None]:
sample_idx = 0

# Every 10th entry along the first axis for the chosen sample, followed by the
# last entry so the end of the trajectory is always included.
trajectory_poses = torch.cat(
    [Ts_out[::10, sample_idx], Ts_out[-1:, sample_idx]],
    dim=0,
)
visualization = agent.unprocess_fn(
    TargetPoseDemo(
        target_poses=SE3(poses=trajectory_poses),
        scene_pcd=scene_proc,
        grasp_pcd=grasp_proc,
    )
).to('cpu')
visualization.show()

In [None]:
def sample_model(self, T_seed: torch.Tensor,
            scene_pcd_multiscale: List[FeaturedPoints], 
            grasp_pcd: FeaturedPoints,
            diffusion_schedules: List[Union[
                                    List[float], 
                                    Tuple[float, float]]
                                ],
            N_steps: List[int], 
            timesteps: List[float],
            ang_noise_mult: Union[int, float] = 1.0,
            lin_noise_mult: Union[int, float] = 1.0,
            temperatures = 1.0,
            linear_noise_schedule: bool = True,
            time_exponent_temp: float = 1.0,
            time_exponent_alpha: float = 0.5) -> torch.Tensor:
    """Iteratively refine poses using the model's score head plus Gaussian noise.

    For every entry of ``diffusion_schedules`` a log-spaced sequence of times is
    built, and at each time the poses are moved by a score-driven displacement
    plus random noise. All pose arithmetic is done in float64; the score head is
    evaluated in float32.

    Args:
        self: Score model. This function reads ``score_head``, ``ang_mult``,
            ``lin_mult``, ``q_indices`` and ``q_factor`` from it.
        T_seed: Initial poses; viewed as ``(-1, 7)`` for the score head and split
            as ``T[..., :4]`` / ``T[..., 4:]`` (presumably quaternion / position).
        scene_pcd_multiscale: Forwarded to the score head as ``key_pcd_multiscale``.
        grasp_pcd: Forwarded to the score head as ``query_pcd``.
        diffusion_schedules: One ``(t_start, t_end)`` pair per sampling stage.
        N_steps: Number of steps for each stage.
        timesteps: Step-size multiplier for each stage.
        ang_noise_mult: Extra multiplier on the angular noise.
        lin_noise_mult: Extra multiplier on the linear noise.
        temperatures: Either one number (used for every stage) or one value per
            stage.
        linear_noise_schedule: Currently ignored; the time schedule is always
            log-spaced. Kept so existing callers keep working.
        time_exponent_temp: Exponent applied to ``t`` when scaling the temperature.
        time_exponent_alpha: Exponent applied to ``t`` when scaling the step size.

    Returns:
        float64 tensor stacking, along dim 0, the initial poses, the poses after
        every step, and the final poses once more (the last two entries are
        identical), i.e. ``sum(N_steps) + 2`` entries.
    """
    from diffusion_edf import transforms

    # Broadcast a scalar temperature to one value per schedule. The previous code
    # used len(t_schedule) here, which is not defined yet at this point and made
    # the default (scalar) temperature raise UnboundLocalError.
    if isinstance(temperatures, (int, float)):
        temperatures = [float(temperatures) for _ in range(len(diffusion_schedules))]
    
    device = T_seed.device
    T = T_seed.clone().detach().type(torch.float64)

    # Trajectory of poses, starting with the seed.
    Ts = [T.clone().detach()]

    diffusion_schedules = torch.tensor(diffusion_schedules, device=device, dtype=torch.float64)

    for n, schedule in enumerate(diffusion_schedules):
        temperature_base = float(temperatures[n])
        # Times are spaced uniformly in log(t) between the two schedule endpoints.
        t_schedule = torch.logspace(torch.log(schedule[0]), torch.log(schedule[1]), steps = N_steps[n], base=torch.e, device=device, dtype=torch.float64)
        # Linear alternative (not used):
        # t_schedule = torch.linspace(schedule[0], schedule[1], steps = N_steps[n], device=device, dtype=torch.float64)

        t_schedule = t_schedule.unsqueeze(-1)
        print(f"{self.__class__.__name__}: sampling with (temperature: {temperature_base} || t_schedule: {schedule})")
        for i in tqdm(range(len(t_schedule))):
            t = t_schedule[i]
            # Time-dependent temperature and step sizes.
            temperature = temperature_base * torch.pow(t, time_exponent_temp)
            alpha_ang = (self.ang_mult **2) * torch.pow(t, time_exponent_alpha) * timesteps[n]
            alpha_lin = (self.lin_mult **2) * torch.pow(t, time_exponent_alpha) * timesteps[n]

            with torch.no_grad():
                (ang_score_dimless, lin_score_dimless) = self.score_head(Ts=T.view(-1,7).float(), 
                                                                         key_pcd_multiscale=scene_pcd_multiscale,
                                                                         query_pcd=grasp_pcd,
                                                                         time = t.repeat(len(T)).float())
            ang_score_dimless, lin_score_dimless = ang_score_dimless.type(torch.float64), lin_score_dimless.type(torch.float64)
            # Rescale the score-head outputs by the model multipliers and sqrt(t).
            ang_score = ang_score_dimless / (self.ang_mult * torch.sqrt(t))
            lin_score = lin_score_dimless / (self.lin_mult * torch.sqrt(t))

            ang_noise = float(ang_noise_mult) * torch.sqrt(temperature*alpha_ang) * torch.randn_like(ang_score, dtype=torch.float64) 
            lin_noise = float(lin_noise_mult) * torch.sqrt(temperature*alpha_lin) * torch.randn_like(lin_score, dtype=torch.float64)

            # Displacement = drift along the score + random noise.
            ang_disp = (alpha_ang/2) * ang_score + ang_noise
            lin_disp = (alpha_lin/2) * lin_score + lin_noise

            # Map the angular displacement to a change of the first four pose
            # components via a matrix assembled from the current pose entries.
            L = T.detach()[...,self.q_indices] * (self.q_factor.type(torch.float64))
            q, x = T[...,:4], T[...,4:]
            dq = torch.einsum('...ij,...j->...i', L, ang_disp)
            # NOTE(review): presumably rotates lin_disp by q (body -> world frame);
            # confirm against transforms.quaternion_apply.
            dx = transforms.quaternion_apply(q, lin_disp)
            q = transforms.normalize_quaternion(q + dq)
            T = torch.cat([q, x+dx], dim=-1)

            # Exponential-map alternative (not used):
            # dT = transforms.se3_exp_map(torch.cat([lin_disp, ang_disp], dim=-1))
            # dT = torch.cat([transforms.matrix_to_quaternion(dT[..., :3, :3]), dT[..., :3, 3]], dim=-1)
            # T = transforms.multiply_se3(T, dT)
            Ts.append(T.clone().detach())

    # The final pose is appended once more; kept so the output length stays the
    # same for existing callers.
    Ts.append(T.clone().detach())
    Ts = torch.stack(Ts, dim=0).detach()

    return Ts

In [None]:
def sample_agent(self, scene_pcd: PointCloud, 
            grasp_pcd: PointCloud, 
            Ts_init: SE3,
            N_steps_list: List[List[int]],
            timesteps_list: List[List[float]],
            temperatures_list,
            diffusion_schedules_list = None,
            ) -> Tuple[torch.Tensor, PointCloud, PointCloud]:
    """Run every model of the agent in sequence, each starting from the previous result.

    The inputs are passed through ``self.proc_fn``, converted to featured points,
    and then each model in ``self.models`` is sampled with ``sample_model``. The
    last poses produced by one model seed the next one.

    Args:
        self: Agent; this function reads ``models`` and ``proc_fn`` from it.
        scene_pcd: Scene point cloud.
        grasp_pcd: Grasp point cloud.
        Ts_init: Initial poses; after processing, ``poses`` must be ``(N, 7)``.
        N_steps_list: Per model, the number of steps for each schedule.
        timesteps_list: Per model, the step-size multiplier for each schedule.
        temperatures_list: Per model, the temperature setting.
        diffusion_schedules_list: Per model, the schedules to use; ``None``
            (globally or per entry) falls back to ``model.diffusion_schedules``.

    Returns:
        ``(Ts_out, scene_pcd, grasp_pcd)``: the float32 concatenation of all
        per-model trajectories along dim 0, and the processed point clouds.
    """
    from diffusion_edf.gnn_data import FeaturedPoints, pcd_to_featured_points

    n_models = len(self.models)
    if diffusion_schedules_list is None:
        diffusion_schedules_list = [None] * n_models

    # Every per-model argument list must have one entry per model.
    assert n_models == len(N_steps_list), f"{n_models} != {len(N_steps_list)}"
    assert n_models == len(timesteps_list), f"{n_models} != {len(timesteps_list)}"
    assert n_models == len(temperatures_list), f"{n_models} != {len(temperatures_list)}"
    assert n_models == len(diffusion_schedules_list), f"{n_models} != {len(diffusion_schedules_list)}"

    # Preprocess the raw inputs.
    scene_processed: PointCloud = self.proc_fn(scene_pcd)
    grasp_processed: PointCloud = self.proc_fn(grasp_pcd)
    init_processed: SE3 = self.proc_fn(Ts_init)

    scene_points: FeaturedPoints = pcd_to_featured_points(scene_processed)
    grasp_points: FeaturedPoints = pcd_to_featured_points(grasp_processed)
    current_poses: torch.Tensor = init_processed.poses
    assert current_poses.ndim == 2 and current_poses.shape[-1] == 7, f"{current_poses.shape}"

    trajectories = []
    per_model_args = zip(self.models, N_steps_list, timesteps_list, temperatures_list, diffusion_schedules_list)
    for model, N_steps, timesteps, temperatures, schedules in per_model_args:
        # Feature extraction for this model.
        with torch.no_grad():
            scene_features: List[FeaturedPoints] = model.get_key_pcd_multiscale(scene_points)
            grasp_features: FeaturedPoints = model.get_query_pcd(grasp_points)

        # Fall back to the schedules stored on the model.
        schedules = model.diffusion_schedules if schedules is None else schedules
        n_schedules = len(schedules)
        assert n_schedules == len(N_steps), f"{n_schedules} != {len(N_steps)}"
        assert n_schedules == len(timesteps), f"{n_schedules} != {len(timesteps)}"

        # Sampling; the last poses become the seed for the next model.
        with torch.no_grad():
            stage_trajectory = sample_model(
                model,
                T_seed=current_poses.clone().detach(),
                scene_pcd_multiscale=scene_features,
                grasp_pcd=grasp_features,
                diffusion_schedules=schedules,
                N_steps=N_steps,
                timesteps=timesteps,
                temperatures=temperatures,
                linear_noise_schedule=False,
            )
            current_poses = stage_trajectory[-1]
            trajectories.append(stage_trajectory)

    # (nTime, nSample, 7)
    all_poses = torch.cat(trajectories, dim=0).float()
    return all_poses, scene_processed, grasp_processed

In [None]:
# Sample with the notebook-local sample_agent/sample_model instead of agent.sample.
Ts_out, scene_proc, grasp_proc = sample_agent(
    agent,
    scene_pcd=scene_pcd,
    grasp_pcd=grasp_pcd,
    Ts_init=Ts_init,
    N_steps_list=[[250, 250], [250, 250]],
    timesteps_list=[[0.02, 0.02], [0.02, 0.02]],
    temperatures_list=[[1., 1.], [1., 1.]],
)

# Library sampler with more steps, kept for comparison (disabled):
# Ts_out, scene_proc, grasp_proc = agent.sample(scene_pcd=scene_pcd, grasp_pcd=grasp_pcd, Ts_init=Ts_init,
#                                               N_steps_list = [[500, 500], [500, 1000]],
#                                               timesteps_list = [[0.02, 0.02], [0.02, 0.02]],
#                                               temperatures_list = [[1., 1.], [1., 1.]])

In [None]:
sample_idx = 0

# Sub-sample the trajectory of one sample (every 10th time index) and append its
# final pose so the end point is always shown.
subsampled = Ts_out[::10, sample_idx]
last_pose = Ts_out[-1:, sample_idx]
trajectory_poses = SE3(poses=torch.cat([subsampled, last_pose], dim=0))

visualization = TargetPoseDemo(
    target_poses=trajectory_poses,
    scene_pcd=scene_proc,
    grasp_pcd=grasp_proc,
)
visualization = agent.unprocess_fn(visualization).to('cpu')
visualization.show()