In [1]:
import os
# Set before `import torch` below; keep this ordering (presumably read when torch's JIT initializes).
os.environ["PYTORCH_JIT_USE_NNC_NOT_NVFUSER"] = "1"
from typing import List, Tuple, Optional, Union, Iterable, NamedTuple, Any, Sequence

from tqdm import tqdm
import yaml

import torch
from torchvision.transforms import Compose
from torch.utils.data import DataLoader

from e3nn import o3

from diffusion_edf.data import DemoSeqDataset, DemoSequence, TargetPoseDemo, PointCloud
from diffusion_edf.gnn_data import FeaturedPoints, merge_featured_points, GraphEdge, flatten_featured_points
from diffusion_edf.gnn_data import TransformPcd
from diffusion_edf.transforms import quaternion_apply, random_quaternions
from diffusion_edf import train_utils
from diffusion_edf import preprocess
from diffusion_edf.feature_extractor import UnetFeatureExtractor
from diffusion_edf.tensor_field import TensorField

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
# Run configuration.
device = 'cuda:0'
task = 'pick'
# NOTE: `eval` and `compile` shadow the Python builtins of the same name. The names are kept
# because the model-building cell branches on them (`if compile:` / `if eval:`).
eval = True      # call .eval() on the models after construction
compile = False  # wrap the models with torch.jit.script

# Paths to the YAML config files (files, despite the `_dir` suffix).
model_configs_dir = 'configs/test/model_configs.yaml'
train_configs_dir = 'configs/test/train_configs.yaml'

# The notebook draws random numbers (dataloader shuffling, query noise, time embeddings,
# the random rigid transform); seed torch's global RNG so a fresh-kernel run is reproducible.
SEED = 0
torch.manual_seed(SEED)

# Load configs, preprocessor, and dataloader

In [3]:
def _read_yaml(path):
    """Open `path` and parse it with PyYAML's FullLoader, returning the parsed object."""
    with open(path) as stream:
        return yaml.load(stream, Loader=yaml.FullLoader)


model_configs = _read_yaml(model_configs_dir)
train_configs = _read_yaml(train_configs_dir)

In [4]:
# Instantiate each preprocessing step listed in the train config (looked up by name in
# `preprocess`, constructed with its `configs` kwargs), in order, and chain them.
proc_fn = Compose([
    getattr(preprocess, step['name'])(**step['configs'])
    for step in train_configs['preprocess_config']
])

In [5]:
# Collate function built for this task from the preprocessing pipeline
# (presumably applies proc_fn per batch — see train_utils.get_collate_fn).
collate_fn = train_utils.get_collate_fn(task=task, proc_fn=proc_fn)

trainset = DemoSeqDataset(
    dataset_dir=train_configs['dataset_dir'],
    annotation_file=train_configs['annotation_file'],
    device=device,
)

# NOTE(review): the config key is 'n_batches' but it is used as the batch size.
train_dataloader = DataLoader(
    trainset,
    batch_size=train_configs['n_batches'],
    shuffle=True,
    collate_fn=collate_fn,
)

# Load Model

In [6]:
# Build the scene feature extractor and the tensor field from the model config.
fe_kwargs = model_configs['scene']['feature_extractor_configs']
# Shallow copy: 'irreps_input' is injected below, and writing it into the loaded
# model_configs dict would silently alter the config object other cells read from.
tf_kwargs = dict(model_configs['scene']['tensor_field_configs'])

feature_extractor = UnetFeatureExtractor(**fe_kwargs, deterministic=True).to(device)
# The tensor field consumes the feature extractor's output, so its input irreps come from there.
tf_kwargs['irreps_input'] = str(feature_extractor.irreps_output)
tf = TensorField(**tf_kwargs).to(device)

if compile:  # notebook flag from the config cell (shadows the builtin)
    feature_extractor = torch.jit.script(feature_extractor)
    tf = torch.jit.script(tf)
if eval:  # notebook flag from the config cell (shadows the builtin)
    feature_extractor = feature_extractor.eval()
    tf = tf.eval()

# torch.jit.script returns a new module object, so an alias taken before scripting would keep
# pointing at the unscripted (and, under compile=True, never eval()'d) module. Bind it last.
tensor_field = tf



# Fetch one example batch

In [7]:
# Take a single batch from the (shuffled) dataloader as a fixed example for the checks below.
try:
    demo_batch = next(iter(train_dataloader))
except StopIteration:
    raise RuntimeError(
        "train_dataloader yielded no batches; check dataset_dir / annotation_file in the train config."
    ) from None

B = len(demo_batch)  # number of demos in this batch
scene_pcd, grasp_pcd, target_poses = train_utils.flatten_batch(demo_batch=demo_batch)  # target_poses: (Nbatch, Ngrasps, 7)

# Rescale scene coordinates by the configured unit length; features and batch indices pass through.
input_pcd = FeaturedPoints(
    x=scene_pcd.x / model_configs['unit_len'],
    f=scene_pcd.f,
    b=scene_pcd.b,
)

In [8]:
# Sample Nq query points around the scene centroid, jittered by a fraction of the
# per-axis standard deviation of the scene coordinates.
Nq = 20
query_noise_scale = 0.1  # jitter size, as a fraction of the per-axis std of the scene

scene_x = input_pcd.x.detach()
query_x = scene_x.mean(dim=0).expand(Nq, -1)
query_x = query_x + query_noise_scale * torch.randn_like(query_x) * scene_x.std(dim=0)

query_points = FeaturedPoints(
    x=query_x,
    # Placeholder query features. zeros_like (not empty_like) so the values are deterministic
    # rather than uninitialized memory.
    f=torch.zeros_like(query_x),
    # All queries are assigned to batch index 0.
    b=torch.zeros(len(query_x), dtype=torch.long, device=device),
)
time_emb = torch.randn(B, tf.time_emb_dim, device=device)

In [9]:
# Forward pass on the original (untransformed) scene.
# NOTE(review): if the extractor is configured to return per-scale outputs, select one
# with `[scale_idx]` here.
output_points = feature_extractor(input_pcd)
field_output = tf(
    query_points,
    input_points=output_points,
    time_emb=time_emb,
)

# Equivariance check under a random rigid transform (rotation + translation)

In [10]:
# Build TorchScript-compiled transform operators for each kind of point cloud in the
# check, then sample one random rigid transform.
def _scripted_transform(irreps, device):
    """Return a torch.jit.script-compiled TransformPcd for `irreps`, moved to `device`."""
    return torch.jit.script(TransformPcd(irreps=irreps).to(device=device))


scene_configs = model_configs['scene']
transform_input = _scripted_transform(scene_configs['feature_extractor_configs']['irreps_input'], device)
transform_output = _scripted_transform(scene_configs['feature_extractor_configs']['irreps_output'], device)
transform_field = _scripted_transform(scene_configs['tensor_field_configs']['irreps_output'], device)

rot = random_quaternions(1, device=device)
trans = torch.randn(1, 3, device=device)
# One pose row: quaternion followed by translation — presumably shape (1, 7); confirm
# that random_quaternions returns (1, 4).
Ts = torch.cat([rot, trans], dim=-1)

In [11]:
# Transform the scene and the query points with the input-irreps transform.
# NOTE(review): query_points reuses transform_input even though its `f` is only a placeholder.
input_pcd_rot, query_pcd_rot = (
    flatten_featured_points(transform_input(points, Ts))
    for points in (input_pcd, query_points)
)
# Reference side of the check: transform the outputs computed from the untransformed scene.
post_rot_output_points = flatten_featured_points(transform_output(output_points, Ts))
post_rot_field_output = flatten_featured_points(transform_field(field_output, Ts))

In [12]:
# Feature-extractor equivariance: run it on the transformed scene and compare against
# transforming the outputs of the original pass.
# NOTE(review): if the extractor returns per-scale outputs, index the result with `[scale_idx]`.
pre_rot_output_points = feature_extractor(input_pcd_rot)

# Fraction of entries agreeing within tolerance. isclose.numel() counts every compared element
# without requiring a contiguous tensor (unlike `.view(-1)`); Python-float division avoids float32 rounding.
isclose = torch.isclose(pre_rot_output_points.x, post_rot_output_points.x, atol=0.01, rtol=0.01)
print(f"Position Equivariance ratio: {isclose.sum().item() / isclose.numel()}")  # Slight non-equivariance comes from FPS downsampling algorithm.
isclose = torch.isclose(pre_rot_output_points.f, post_rot_output_points.f, atol=0.001, rtol=0.001)
print(f"Feature Equivariance ratio: {isclose.sum().item() / isclose.numel()}")  # Slight non-equivariance comes from FPS downsampling algorithm.

Position Equivariance ratio: 0.9966805577278137
Feature Equivariance ratio: 0.9966943264007568


In [13]:
# Tensor-field forward pass on the transformed queries. It is fed post_rot_output_points
# (the transformed outputs of the untransformed pass) rather than pre_rot_output_points.
# This presumably isolates the tensor field's equivariance from the feature extractor's
# FPS-related mismatch — confirm this is intended.
pre_rot_field_output = tf(
    query_pcd_rot,
    time_emb=time_emb,
    input_points=post_rot_output_points,
)

In [14]:
# Tensor-field equivariance: compare the field evaluated on transformed inputs with the
# transformed field from the original pass. Ratios use isclose.numel() (works for any
# memory layout) and Python-float division.
# NOTE(review): this pass is fed post_rot_output_points, so FPS in the feature extractor cannot
# contribute here. Any residual mismatch would come from the tensor field itself — confirm
# whether it subsamples.
isclose = torch.isclose(pre_rot_field_output.x, post_rot_field_output.x, atol=0.01, rtol=0.01)
print(f"Position Equivariance ratio: {isclose.sum().item() / isclose.numel()}")
isclose = torch.isclose(pre_rot_field_output.f, post_rot_field_output.f, atol=0.001, rtol=0.001)
print(f"Feature Equivariance ratio: {isclose.sum().item() / isclose.numel()}")

Position Equivariance ratio: 1.0
Feature Equivariance ratio: 1.0
