In [1]:
import os
os.environ["PYTORCH_JIT_USE_NNC_NOT_NVFUSER"] = "1"
from typing import List, Tuple, Optional, Union, Iterable, NamedTuple, Any, Sequence

from tqdm import tqdm
import yaml

import torch
from torchvision.transforms import Compose
from torch.utils.data import DataLoader

from e3nn import o3

from diffusion_edf.data import DemoSeqDataset, DemoSequence, TargetPoseDemo, PointCloud
from diffusion_edf.gnn_data import FeaturedPoints, merge_featured_points, GraphEdge
from diffusion_edf import train_utils
from diffusion_edf import preprocess
from diffusion_edf.feature_extractor import UnetFeatureExtractor

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
# --- Notebook configuration ---
device = 'cuda:0'
task = 'pick'
# NOTE(review): `eval` and `compile` shadow the Python builtins for the rest of the
# notebook; the names are kept because later cells read them as flags.
eval = True      # call `.eval()` on the feature extractor
compile = True   # TorchScript-compile the feature extractor with torch.jit.script

# Paths to the YAML config files (despite the `_dir` suffix, these are file paths).
model_configs_dir = 'configs/test/model_configs.yaml'
train_configs_dir = 'configs/test/train_configs.yaml'

# Seed torch's RNGs (CPU and CUDA) so the dataloader shuffling and the random test
# transform are reproducible under Restart & Run All.
seed = 0
torch.manual_seed(seed)

# Load configs, preprocessor, and dataloader

In [3]:
# Parse the model and training YAML configs into plain dicts.
# NOTE(review): FullLoader is fine for trusted, in-repo configs; prefer yaml.safe_load
# if these files could ever come from an untrusted source.
with open(model_configs_dir) as model_cfg_fh:
    model_configs = yaml.load(model_cfg_fh, Loader=yaml.FullLoader)

with open(train_configs_dir) as train_cfg_fh:
    train_configs = yaml.load(train_cfg_fh, Loader=yaml.FullLoader)

In [4]:
# Instantiate every configured preprocessing step and chain them into one callable.
# Each entry supplies 'name' (an attribute of the `preprocess` module) and 'configs'
# (the keyword arguments used to construct it).
proc_fn = Compose([
    getattr(preprocess, step['name'])(**step['configs'])
    for step in train_configs['preprocess_config']
])

In [5]:
# Task-specific collate function that also applies the preprocessing pipeline.
collate_fn = train_utils.get_collate_fn(task=task, proc_fn=proc_fn)

# NOTE(review): `device` is forwarded to the dataset; presumably demos are loaded
# onto it directly — confirm in DemoSeqDataset.
trainset = DemoSeqDataset(
    dataset_dir=train_configs['dataset_dir'],
    annotation_file=train_configs['annotation_file'],
    device=device,
)

# The config key is named 'n_batches', but it is passed as `batch_size`
# (number of samples per batch).
train_dataloader = DataLoader(
    trainset,
    shuffle=True,
    collate_fn=collate_fn,
    batch_size=train_configs['n_batches'],
)

# Load Model

In [6]:
# Build the scene feature extractor from its config and move it to `device`.
feature_extractor = UnetFeatureExtractor(
    **model_configs['scene']['feature_extractor_configs']
).to(device)

# `compile` / `eval` are the boolean flags from the config cell (not the builtins):
# optionally TorchScript-compile, then optionally switch to eval mode.
feature_extractor = torch.jit.script(feature_extractor) if compile else feature_extractor
feature_extractor = feature_extractor.eval() if eval else feature_extractor



# Fetch one batch from the dataloader

In [7]:
# Take exactly one (shuffled) batch from the training loader to use as the model input.
# An empty loader fails here with a clear message instead of leaving `input_pcd`
# undefined for the next cell.
try:
    demo_batch = next(iter(train_dataloader))
except StopIteration:
    raise RuntimeError(
        "train_dataloader yielded no batches; check 'dataset_dir' and "
        "'annotation_file' in the train config."
    ) from None

B = len(demo_batch)
scene_pcd, grasp_pcd, target_poses = train_utils.flatten_batch(demo_batch=demo_batch) # target_poses: (Nbatch, Ngrasps, 7)

# Scene point cloud with coordinates divided by the model's `unit_len`.
input_pcd = FeaturedPoints(
    x=scene_pcd.x / model_configs['unit_len'],
    f=scene_pcd.f,
    b=scene_pcd.b,
)

In [8]:
# Reference forward pass on the untransformed input (reused by the equivariance check).
# No gradients are needed here, so avoid building and retaining an autograd graph.
with torch.no_grad():
    output = feature_extractor(input_pcd)

# Equivariance check: transform the input vs. transform the output

In [9]:
# Optional: uncomment to inspect the raw per-node outputs of a single scale.
# scale = 3
# node_coord = output[scale].x
# node_feature = output[scale].f
# node_batch = output[scale].b

In [10]:
from diffusion_edf.transforms import quaternion_apply, random_quaternions
from diffusion_edf.gnn_data import TransformPcd

# TorchScript-compiled rigid transforms matching the feature extractor's input and
# output irreps (built in that order).
extractor_cfg = model_configs['scene']['feature_extractor_configs']
transform_input, transform_output = (
    torch.jit.script(TransformPcd(irreps=extractor_cfg[irreps_key], device=device))
    for irreps_key in ('irreps_input', 'irreps_output')
)

# One random rigid transform: a random quaternion concatenated with a standard-normal
# 3D translation (the RNG call order is significant for reproducibility).
rot = random_quaternions(1, device=device)
trans = torch.randn(1, 3, device=device)
Ts = torch.cat([rot, trans], dim=-1)

In [12]:
# Selects one entry along the leading dimension of TransformPcd's output; presumably
# that dimension enumerates the transforms in `Ts` (a single one here) — confirm.
T_idx = 0
scale_idx = 3  # which output scale of the feature extractor to compare

# Inference only: no autograd graph is needed for the comparison.
with torch.no_grad():
    # "pre": transform the input point cloud, then run the network.
    input_pcd_rot: FeaturedPoints = transform_input(input_pcd, Ts)
    input_pcd_rot = FeaturedPoints(x=input_pcd_rot.x[T_idx], f=input_pcd_rot.f[T_idx], b=input_pcd_rot.b[T_idx])
    pre_rot_out = feature_extractor(input_pcd_rot)[scale_idx]

    # "post": run the network on the original input (`output`), then transform its output.
    post_rot_out: FeaturedPoints = transform_output(output[scale_idx], Ts)
    post_rot_out = FeaturedPoints(x=post_rot_out.x[T_idx], f=post_rot_out.f[T_idx], b=post_rot_out.b[T_idx], w=post_rot_out.w[T_idx])

In [13]:
# Fraction of coordinate entries where "transform then run" matches "run then transform".
# Shapes must match; otherwise isclose would broadcast and the ratio would be meaningless.
assert pre_rot_out.x.shape == post_rot_out.x.shape, (
    f"coordinate shape mismatch: {tuple(pre_rot_out.x.shape)} vs {tuple(post_rot_out.x.shape)}"
)
isclose = torch.isclose(pre_rot_out.x, post_rot_out.x, atol=0.01, rtol=0.01)
# numel() counts the compared entries and, unlike .view(-1), works on non-contiguous tensors.
print(f"Equivariance ratio: {isclose.sum().item() / isclose.numel()}") # Slight non-equivariance comes from FPS downsampling algorithm.

Equivariance ratio: 1.0


In [14]:
# Fraction of feature entries where "transform then run" matches "run then transform".
# Shapes must match; otherwise isclose would broadcast and the ratio would be meaningless.
assert pre_rot_out.f.shape == post_rot_out.f.shape, (
    f"feature shape mismatch: {tuple(pre_rot_out.f.shape)} vs {tuple(post_rot_out.f.shape)}"
)
isclose = torch.isclose(pre_rot_out.f, post_rot_out.f, atol=0.001, rtol=0.001)
# numel() counts the compared entries and, unlike .view(-1), works on non-contiguous tensors.
print(f"Equivariance ratio: {isclose.sum().item() / isclose.numel()}") # Slight non-equivariance comes from FPS downsampling algorithm.

Equivariance ratio: 1.0
