In [None]:
%load_ext autoreload
%autoreload 2

import os
# Exported before `torch` is imported further down in this cell.
# NOTE(review): presumably makes TorchScript use the NNC fuser rather than NVFuser — confirm against the PyTorch docs.
os.environ.update({"PYTORCH_JIT_USE_NNC_NOT_NVFUSER": "1"})
# os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
from typing import List, Tuple, Optional, Union, Iterable
import warnings
import math

from beartype import beartype
import datetime
import plotly.graph_objects as go
from tqdm import tqdm
import yaml

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from e3nn import o3

from edf_interface.data import PointCloud, SE3, DemoDataset, TargetPoseDemo
from diffusion_edf.gnn_data import FeaturedPoints
from diffusion_edf import train_utils
from diffusion_edf.trainer import DiffusionEdfTrainer
from diffusion_edf.visualize import visualize_pose
from diffusion_edf.agent import DiffusionEdfAgent

# Print tensors with four decimals and without scientific notation.
torch.set_printoptions(sci_mode=False, precision=4)

In [None]:
import random

import numpy as np

# One shared seed for the three random number generators used in this notebook
# (torch, numpy's legacy global generator, and Python's `random`).
seed = 2
for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    seed_fn(seed)

In [None]:
# --- Run configuration -------------------------------------------------------
# Preferred device. If a CUDA device is requested but CUDA is not available,
# fall back to the CPU (with a warning) so the notebook still runs instead of
# failing when the dataset and models are moved to the device.
device = 'cuda:0'
# device = 'cpu'
if device.startswith('cuda') and not torch.cuda.is_available():
    warnings.warn(f"CUDA is not available; falling back to 'cpu' instead of {device!r}.")
    device = 'cpu'

half_precision = False          # when True, the dataset is loaded as float16 instead of float32
task_type = 'pick'              # used below to select the model kwargs and the demo ('pick' or 'place')
config_root_dir = 'configs/ebm'  # directory that holds agent.yaml and preprocess.yaml

# Demonstration dataset, loaded on `device` with the dtype chosen above.
testset = DemoDataset(dataset_dir='demo/panda_mug',
                      device=device,
                      dtype=torch.float16 if half_precision else torch.float32)

In [None]:
# Model keyword arguments for the selected task type, read from agent.yaml.
agent_yaml_path = os.path.join(config_root_dir, 'agent.yaml')
with open(agent_yaml_path) as f:
    agent_yaml = yaml.load(f, Loader=yaml.FullLoader)
model_kwargs_list = agent_yaml['model_kwargs'][f"{task_type}_models_kwargs"]

# preprocess.yaml holds both the preprocessing and the un-processing settings.
preprocess_yaml_path = os.path.join(config_root_dir, 'preprocess.yaml')
with open(preprocess_yaml_path) as f:
    preprocess_yaml = yaml.load(f, Loader=yaml.FullLoader)
unprocess_config = preprocess_yaml['unprocess_config']
preprocess_config = preprocess_yaml['preprocess_config']

# Build the agent from the loaded configuration; score-head compilation is disabled.
agent = DiffusionEdfAgent(
    model_kwargs_list=model_kwargs_list,
    preprocess_config=preprocess_config,
    unprocess_config=unprocess_config,
    device=device,
    half_precision=half_precision,
    compile_score_head=False,
)

In [None]:
from diffusion_edf.emb_score_head import EbmScoreModelHead

In [None]:
# Use the first entry of the agent's model list; its key/query feature extractors
# are reused by the cells below.
# NOTE(review): what index 0 corresponds to (e.g. which stage) is not visible here — confirm.
model = agent.models[0]

In [None]:
# Hand-written settings for the key tensor field of the energy-based score head.
key_tensor_field_kwargs = dict(
    irreps_output=o3.Irreps('64x0e+32x1e+16x2e'),
    irreps_sh=o3.Irreps("1x0e+1x1e+1x2e"),
    num_heads=4,
    fc_neurons=[-1, 128, 64],
    length_emb_dim=64,
    r_cluster_multiscale=[6.,],
    n_layers=1,
    irreps_mlp_mid=3,
    cutoff_method='edge_attn',
    r_mincut_nonscalar_sh=0.1,
)

# Entries that are filled in programmatically: the input irreps come from the
# already-built key model, and both point-attention switches are turned off.
# Each key is checked to be absent first so a hand-written value is never overwritten.
_derived_entries = {
    'irreps_input': model.key_model.irreps_output,
    'use_src_point_attn': False,
    'use_dst_point_attn': False,
}
for _name, _value in _derived_entries.items():
    assert _name not in key_tensor_field_kwargs.keys()
    key_tensor_field_kwargs[_name] = _value

# Freshly initialised head on the target device.
ebm_head = EbmScoreModelHead(
    max_time=1.0,
    time_emb_mlp=[512, 256, 128],
    key_tensor_field_kwargs=key_tensor_field_kwargs,
    irreps_query_edf=o3.Irreps('64x0e+32x1e+16x2e'),
    lin_mult=15.,
    ang_mult=2.5,
    edge_time_encoding=True,
    query_time_encoding=False,
)
ebm_head = ebm_head.to(device)

# Initialize Input Data and Initial Pose

In [None]:
# Choose which entry of the demo sequence to use for the configured task type.
# An unknown task type now raises a readable error; previously the error text
# itself was used as the index, which produced an unrelated indexing failure.
if task_type == 'pick':
    demo_idx = 0
elif task_type == 'place':
    demo_idx = 1
else:
    raise ValueError(f"task_type must be either 'pick' or 'place' (got {task_type!r})")

demo: TargetPoseDemo = testset[2][demo_idx].to(device)
scene_pcd: PointCloud = demo.scene_pcd
grasp_pcd: PointCloud = demo.grasp_pcd
Ts: SE3 = demo.target_poses

# Initial pose handed to `agent.sample(...)` in the sampling cell. It was only
# present as commented-out code, which left `Ts_init` undefined at that call.
# NOTE(review): the 7 numbers are presumably quaternion (first 4) + position (last 3),
# matching how the gradient cell treats the first four pose entries — confirm the values.
T0 = torch.cat([
    torch.tensor([[1., 0., 0.0, 0.]], device=device, dtype=scene_pcd.points.dtype),
    torch.tensor([[0., 0., 0.3]], device=device, dtype=scene_pcd.points.dtype)
], dim=-1)
Ts_init = SE3(poses=T0).to(device, dtype=scene_pcd.points.dtype)


In [None]:
from diffusion_edf import train_utils

# Step 1: run the agent's preprocessing on the scene, the grasp and the target poses
# (in that order).
scene_pcd_pre: PointCloud = agent.proc_fn(scene_pcd)
grasp_pcd_pre: PointCloud = agent.proc_fn(grasp_pcd)
Ts_pre: SE3 = agent.proc_fn(Ts)

# Step 2: convert the preprocessed point clouds to FeaturedPoints and keep the
# raw pose tensor of the preprocessed target poses.
scene_pcd_proc: FeaturedPoints = train_utils.pcd_to_featured_points(scene_pcd_pre)
grasp_pcd_proc: FeaturedPoints = train_utils.pcd_to_featured_points(grasp_pcd_pre)
Ts_proc: torch.Tensor = Ts_pre.poses

In [None]:
# Compute the key- and query-side inputs of the energy head once, without
# recording gradients for the feature extractors.
with torch.no_grad():
    # Key features from the processed scene (annotated as a list of FeaturedPoints).
    key_pcd_multiscale: List[FeaturedPoints] = model.get_key_pcd_multiscale(scene_pcd_proc)
    # Query features from the processed grasp point cloud.
    query_pcd: FeaturedPoints = model.get_query_pcd(grasp_pcd_proc)

In [None]:
# Evaluate the energy of the processed target poses at time 1.0, gradient-free.
energy_time = torch.tensor([1.], device=device)
with torch.no_grad():
    energy = ebm_head.compute_energy(
        Ts=Ts_proc, key_pcd_multiscale=key_pcd_multiscale, query_pcd=query_pcd, time=energy_time,
    )

In [None]:
# (A stray undefined name stood here; it raised NameError and stopped "Run All".)

In [None]:
# Run the head's forward pass on the processed target poses at time 1.0; it
# returns a pair that is unpacked as (ang_vel, lin_vel).
score_inputs = dict(
    Ts=Ts_proc,
    key_pcd_multiscale=key_pcd_multiscale,
    query_pcd=query_pcd,
    time=torch.tensor([1.], device=device),
)
ang_vel, lin_vel = ebm_head.forward(**score_inputs)

In [None]:
# (A stray undefined name stood here; it raised NameError and stopped "Run All".)

In [None]:
from diffusion_edf import transforms

# Fresh leaf copy of the processed poses so that autograd can differentiate w.r.t. them.
T = Ts_proc.detach().requires_grad_(True)
time = torch.tensor([1.], device=device)

# Negative energy, differentiated with respect to the pose tensor only.
logP = -ebm_head.compute_energy(
    Ts=T, key_pcd_multiscale=key_pcd_multiscale, query_pcd=query_pcd, time=time,
)
logP.sum().backward(inputs=T)
pose_grad = T.grad

# Matrix built from the head's quaternion index/factor tables applied to the pose entries.
L = T.detach()[..., ebm_head.q_indices] * ebm_head.q_factor

# First term: gradient of the pose entries from index 4 onwards, rotated by the
# inverse of the quaternion stored in the first four entries.
quat_inv = transforms.quaternion_invert(T[..., :4].detach())
rest_term = transforms.quaternion_apply(quat_inv, pose_grad[..., 4:])
# Second term: gradient of the first four (quaternion) entries contracted with L.
quat_term = torch.einsum('...ia,...i', L, pose_grad[..., :4])

grad = torch.cat([rest_term, quat_term], dim=-1)

In [None]:
# Display the gradient assembled in the previous cell.
grad

In [None]:
# (A stray undefined name stood here; it raised NameError and stopped "Run All".)

In [None]:
# Sampler settings, kept together so they are easy to tweak. The three nested
# lists all have the same 2x2 layout.
sampling_kwargs = dict(
    N_steps_list=[[250, 250], [250, 250]],
    timesteps_list=[[0.02, 0.02], [0.02, 0.02]],
    temperatures_list=[[1., 1.], [1., 1.]],
    log_t_schedule=True,
    time_exponent_temp=1.0,
    time_exponent_alpha=0.5,
)

# Sample poses starting from Ts_init; the call also returns the scene and grasp
# point clouds that are reused for visualization.
Ts_out, scene_proc, grasp_proc = agent.sample(
    scene_pcd=scene_pcd,
    grasp_pcd=grasp_pcd,
    Ts_init=Ts_init,
    **sampling_kwargs,
)

# Ts_out, scene_proc, grasp_proc = agent.sample(scene_pcd=scene_pcd, grasp_pcd=grasp_pcd, Ts_init=Ts_init,
#                                               N_steps_list = [[250, 250], [500, 500]],
#                                               timesteps_list = [[0.04, 0.04], [0.04, 0.06]],
#                                               temperatures_list = [[1., 1.], [0.5, 0.1]],
#                                               diffusion_schedules_list=[
#                                                   [[1., 0.1], [0.1, 0.1]],
#                                                   [[0.1, 0.1], [0.03, 0.03] ],
#                                                   ],
#                                               log_t_schedule = False,
#                                               time_exponent_temp = 1.0,
#                                               time_exponent_alpha = 0.5,)

In [None]:
sample_idx = 0

# For the chosen sample, keep every 10th entry along the first axis of Ts_out
# and append the last entry so the final pose is always shown.
strided_poses = Ts_out[::10, sample_idx]
final_pose = Ts_out[-1:, sample_idx]
poses_to_show = SE3(poses=torch.cat([strided_poses, final_pose], dim=0))

# Bundle the poses with the point clouds returned by the sampler, undo the
# agent's preprocessing, and move the result to the CPU for display.
visualization = agent.unprocess_fn(
    TargetPoseDemo(
        target_poses=poses_to_show,
        scene_pcd=scene_proc,
        grasp_pcd=grasp_proc,
    )
).to('cpu')
visualization.show(bg_color=(0.3, 0.3, 0.3))