In [1]:
import os
# os.environ["PYTORCH_JIT_USE_NNC_NOT_NVFUSER"] = "1"
from typing import List, Tuple, Optional, Union, Iterable, NamedTuple, Any, Sequence

from tqdm import tqdm
import yaml

import torch
from torchvision.transforms import Compose
from torch.utils.data import DataLoader

from e3nn import o3

from diffusion_edf.data import DemoSeqDataset, DemoSequence, TargetPoseDemo, FeaturedPoints, merge_featured_points
from diffusion_edf import train_utils
from diffusion_edf import preprocess

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
# --- Notebook configuration -------------------------------------------------
# Paths of the two YAML config files (despite the `_dir` suffix, these are
# opened as files in the next cell).
train_configs_dir = 'configs/test/train_configs.yaml'
model_configs_dir = 'configs/test/model_configs.yaml'

# device: torch device string passed to the dataset, the model and the tensors
#         created in later cells.
# task:   forwarded to train_utils.get_collate_fn.
# eval:   when True the model is switched to eval mode after construction.
#         NOTE(review): this name shadows Python's builtin `eval` for the rest
#         of the notebook; it is kept because a later cell reads it.
device, task, eval = 'cuda:0', 'pick', True

In [3]:
def _read_yaml(path):
    """Parse the YAML file at ``path`` with PyYAML's FullLoader and return the result."""
    # NOTE(review): FullLoader is kept to match the original behaviour. If the
    # config files only contain plain scalars/lists/dicts, yaml.safe_load would
    # be the more conservative choice -- confirm before switching.
    with open(path) as stream:
        return yaml.load(stream, Loader=yaml.FullLoader)


# Model and training settings, read from the paths set in the configuration cell.
model_configs = _read_yaml(model_configs_dir)
train_configs = _read_yaml(train_configs_dir)

In [4]:
# Preprocessing pipeline: a single Downsample step whose voxel size comes from
# the model config. The pipeline is handed to the collate-function factory.
downsample = preprocess.Downsample(
    coord_reduction="average",
    voxel_size=model_configs['input_voxel_size'],
)
proc_fn = Compose([downsample])
collate_fn = train_utils.get_collate_fn(proc_fn=proc_fn, task=task)

# Demonstration dataset read from "test_data" / "data.yaml", created on `device`.
trainset = DemoSeqDataset(
    dataset_dir="test_data",
    annotation_file="data.yaml",
    device=device,
)

# Shuffled loader over the demonstrations.
# Note: the batch size is read from the 'n_batches' key of the train config.
train_dataloader = DataLoader(
    dataset=trainset,
    batch_size=train_configs['n_batches'],
    shuffle=True,
    collate_fn=collate_fn,
)

In [5]:
# Fetch exactly one batch from the training dataloader to experiment with.
#
# A `for ... break` loop would silently do nothing on an empty dataloader and
# leave `demo_batch`, `B`, `scene_pcd`, ... undefined, so the failure would only
# show up later as an unrelated NameError. Pulling the first item explicitly
# makes that case fail here with a clear message.
try:
    demo_batch = next(iter(train_dataloader))
except StopIteration:
    raise RuntimeError(
        "train_dataloader yielded no batches; check the dataset directory and annotation file."
    ) from None

# len() of the batch object produced by the collate function.
B = len(demo_batch)

# Split the batch into scene point cloud, grasp point cloud and target poses.
scene_pcd, grasp_pcd, target_poses = train_utils.flatten_batch(demo_batch=demo_batch)

In [None]:
# NOTE(review): this cell used to contain only the stray token `sfdafdsa`,
# which aborts "Run All" with a NameError. Presumably that was a deliberate
# stop: the cells below use names (EDF, load_demos, scene_proc_fn,
# e3nn_script, ...) that are not imported or defined in the visible part of
# this notebook. The stop is kept, but made explicit and self-explanatory.
raise RuntimeError(
    "Execution stopped on purpose: the cells below are scratch code that relies on "
    "names not imported in this notebook (e.g. EDF, load_demos)."
)

In [None]:
# Build the EDF model with the settings below and move it to `device`.
# NOTE(review): `EDF`, `irreps_input`, `irreps_node_embedding` and `irreps_sh`
# are not defined in the visible part of this notebook -- they must come from
# elsewhere for this cell to run.
edf = EDF(
    irreps_input=irreps_input,
    irreps_emb_init=irreps_node_embedding,
    irreps_sh=irreps_sh,
    fc_neurons_init=[32, 16, 16],
    num_heads=4,
    n_scales=4,
    pool_ratio=0.5,
    dim_mult=[1, 2, 3, 4],
    n_layers=2,
    gnn_radius=2.0,
    cutoff_radius=3.0,
    deterministic=True,
    compile_head=True,
).to(device)

# `eval` is the boolean flag from the configuration cell (it shadows the
# builtin of the same name); when set, switch the model to eval mode.
if eval:
    edf = edf.eval()

# Load demo

In [None]:
# Load the demonstrations stored under 'demo/test_demo' and pick one demo:
# the first sequence, then the element at index 1 of that sequence.
demo_list: List[DemoSequence] = load_demos(dir='demo/test_demo')
demo_seq: DemoSequence = demo_list[0]
demo: TargetPoseDemo = demo_seq[1]
print(demo)

# Raw inputs stored on the demo.
scene_raw: PointCloud = demo.scene_pc
grasp_raw: PointCloud = demo.grasp_pc
# NOTE(review): this rebinds `target_poses`, which an earlier cell already set
# from the dataloader batch. The TargetPoseDemo annotation is copied from the
# original and looks suspicious for an attribute *of* a TargetPoseDemo -- confirm.
target_poses: TargetPoseDemo = demo.target_poses

# Preprocess both point clouds and move the results to `device`.
scene_proc = scene_proc_fn(scene_raw).to(device)
grasp_proc = grasp_proc_fn(grasp_raw).to(device)

# Model inputs taken from the processed scene; every node gets batch index 0.
node_coord = scene_proc.points
node_feature = scene_proc.colors
batch = torch.zeros(len(node_coord), dtype=torch.long, device=device)

# Three random 3-D query points, all with batch index 0.
# NOTE(review): query_batch inherits query_coord's floating dtype, whereas
# `batch` above is torch.long -- confirm that this is what the model expects.
query_coord = torch.randn(3, 3, device=device)
query_batch = query_coord.new_zeros(query_coord.shape[:-1])

In [None]:
# Run the model on the original (untransformed) scene and query points.
field_val, edf_info = edf.forward(
    node_feature=node_feature,
    node_coord=node_coord,
    batch=batch,
    query_coord=query_coord,
    query_batch=query_batch,
)

# Rotate

In [None]:
from diffusion_edf.quaternion_utils import quaternion_apply, random_quaternions

# Sample one random quaternion and re-normalise it to unit length, then draw a
# random 3-vector translation.
rot = random_quaternions(1, device=device)
rot = rot.div(rot.norm(dim=-1, keepdim=True))
trans = torch.randn(3, device=device)

# Apply the same rotation + translation to the scene nodes and the query points.
node_coord_rot, query_coord_rot = (
    quaternion_apply(rot, points) + trans
    for points in (node_coord, query_coord)
)

# The node features are reused unchanged (same tensor object) for the
# transformed scene.
node_feature_rot = node_feature

In [None]:
# Run the model again, this time on the rotated + translated scene and query
# points; batch indices are the same as for the original inputs.
field_val_rot, edf_info_rot = edf.forward(
    node_feature=node_feature_rot,
    node_coord=node_coord_rot,
    batch=batch,
    query_coord=query_coord_rot,
    query_batch=query_batch,
)

In [None]:
# Scripted module that transforms features carrying the model's embedding
# irreps by a quaternion.
irrep_transform = e3nn_script(
    TransformFeatureQuaternion(irreps=o3.Irreps(edf.irreps_emb), device=device)
)

a = field_val_rot  # model output on the transformed inputs
b = field_val      # model output on the original inputs

# Equivariance check: transforming the original output by `rot` should
# reproduce the output obtained from the transformed inputs.
b_transformed = irrep_transform(b, rot)[0]
isclose = torch.isclose(b_transformed, a, atol=0.01, rtol=0.01)

# Fraction of entries that agree within tolerance.
# Original author's note: slight non-equivariance comes from the FPS
# downsampling algorithm.
equivariance_ratio = (isclose.sum() / a.numel()).item()
print(f"Equivariance ratio: {equivariance_ratio}")