In [1]:
import os
os.environ["PYTORCH_JIT_USE_NNC_NOT_NVFUSER"] = "1"
from typing import List, Tuple, Optional, Union, Iterable, NamedTuple, Any, Sequence, Dict

from tqdm import tqdm
import yaml

import torch
from torchvision.transforms import Compose
from torch.utils.data import DataLoader

from e3nn import o3

from diffusion_edf.data import DemoSeqDataset, DemoSequence, TargetPoseDemo, PointCloud
from diffusion_edf.gnn_data import FeaturedPoints, merge_featured_points, GraphEdge, flatten_featured_points, set_featured_points_attribute, _featured_points_repr, detach_featured_points
from diffusion_edf import train_utils
from diffusion_edf import preprocess
from diffusion_edf import transforms
from diffusion_edf.feature_extractor import UnetFeatureExtractor
from diffusion_edf.radial_func import SinusoidalPositionEmbeddings
from diffusion_edf.equivariant_score_model import ScoreModel

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
device = 'cuda:0'
eval = True
compile = False

model_configs_dir = 'configs/test/model_configs.yaml'
train_configs_dir = 'configs/test/train_configs.yaml'
task_configs_dir = 'configs/test/task_configs.yaml'

with open(model_configs_dir) as file:
    model_configs = yaml.load(file, Loader=yaml.FullLoader)
with open(train_configs_dir) as file:
    train_configs = yaml.load(file, Loader=yaml.FullLoader)
with open(task_configs_dir) as file:
    task_configs = yaml.load(file, Loader=yaml.FullLoader)

train_configs['preprocess_config'].append({
    'name': 'Rescale',
    'kwargs': {'rescale_factor': 1/task_configs['unit_length']}
})

# Load configs, preprocessor, and dataloader

In [3]:
proc_fn = []
for proc in train_configs['preprocess_config']:
    proc_fn.append(
        getattr(preprocess, proc['name'])(**proc['kwargs'])
    )
proc_fn = Compose(proc_fn)

In [4]:
collate_fn = train_utils.get_collate_fn(task=train_configs['task_type'], proc_fn=proc_fn)
trainset = DemoSeqDataset(dataset_dir=train_configs['dataset_dir'], annotation_file=train_configs['annotation_file'], device=device)
train_dataloader = DataLoader(trainset, shuffle=True, collate_fn=collate_fn, batch_size=train_configs['n_batches'])

# Load Model

In [5]:
score_model = ScoreModel(**model_configs, deterministic=True).to(device=device)
if compile:
    raise NotImplementedError
if eval:
    score_model = score_model.eval()

ScoreModel: Initializing Score Head




ScoreModel: Initializing Key Feature Extractor


# Loop Example

In [6]:
N_T = 5

for demo_batch in train_dataloader:
    B = len(demo_batch)
    scene_pcd, grasp_pcd, target_poses = train_utils.flatten_batch(demo_batch=demo_batch) # target_poses: (Nbatch, Ngrasps, 7)

    q = transforms.random_quaternions(N_T, device=device)
    x = torch.randn(N_T,3,device=device)
    Ts = torch.cat([q, x], dim=-1)
    time = torch.rand(N_T,device=device)

    break

In [7]:
scene_out_multiscale = score_model.key_feature_extractor(pcd=scene_pcd)

In [8]:
query_points = detach_featured_points(scene_out_multiscale[-1])
if score_model.score_head.key_tensor_field.irreps_query is not None:
    query_points = set_featured_points_attribute(query_points, f=score_model.score_head.key_tensor_field.irreps_query.randn(len(query_points.x), -1, device=device))

In [9]:
if score_model.score_head.key_tensor_field.context_emb_dim is None:
    context_emb = None
else:
    context_emb = torch.randn(len(query_points.x), score_model.score_head.key_tensor_field.context_emb_dim)

In [10]:
out = score_model.score_head.key_tensor_field(query_points=query_points, input_points_multiscale=scene_out_multiscale, context_emb=context_emb)

In [11]:
out.f.isnan().any()

tensor(False, device='cuda:0')

In [12]:
sdfa

NameError: name 'sdfa' is not defined

In [None]:
scene_pcd.x.shape

In [None]:
sdaf

In [None]:
with torch.no_grad():
    (ang_score, lin_score), (scene_out, grasp_out) = score_model(Ts=Ts, time=time,
                                                                 key_pcd=scene_pcd, 
                                                                 query_pcd=grasp_pcd, 
                                                                 extract_features = True,
                                                                 debug = True)

In [None]:
# pcd = PointCloud(points=scene_out.x.detach().cpu(), colors=scene_out.w.detach().cpu(), cmap='magma')
# pcd.show(point_size=3., width=800, height=800)

# pcd = PointCloud(points=grasp_out.x.detach().cpu(), colors=grasp_out.w.detach().cpu(), cmap='magma')
# pcd.show(point_size=8., width=800, height=800)

# Rotate

In [None]:
from diffusion_edf.transforms import quaternion_apply, random_quaternions
from diffusion_edf.gnn_data import TransformPcd

transform_input = torch.jit.script(TransformPcd(irreps="3x0e").to(device=device))
transform_key = torch.jit.script(TransformPcd(irreps=model_configs['key_kwargs']['feature_extractor_configs']['irreps_output']).to(device=device))
transform_query = torch.jit.script(TransformPcd(irreps=model_configs['query_kwargs']['feature_extractor_configs']['irreps_output']).to(device=device))

In [None]:
rot1 = transforms.random_quaternions(1, device=device)
trans1 = torch.randn(1,3,device=device)
# rot1 = torch.tensor([1., 0., 0., 0.], device=device).expand(N_T,-1)
# trans1 = torch.zeros(N_T,3, device=device)
Ts_1 = torch.cat([rot1, trans1], dim=-1)
scene_pcd_rot: FeaturedPoints = flatten_featured_points(transform_input(scene_pcd, Ts_1))
scene_out_post_rot = transform_key(scene_out, Ts=Ts_1)

rot2 = transforms.random_quaternions(1, device=device)
trans2 = torch.randn(1,3,device=device)
# rot2 = torch.tensor([1., 0., 0., 0.], device=device).expand(N_T,-1)
# trans2 = torch.zeros(N_T,3, device=device)
Ts_2 = torch.cat([rot2, trans2], dim=-1)

grasp_pcd_rot: FeaturedPoints = flatten_featured_points(transform_input(grasp_pcd, Ts_2))
grasp_out_post_rot = transform_key(grasp_out, Ts=Ts_2)

lin_score_post_rot = quaternion_apply(rot2, lin_score)
ang_score_post_rot = quaternion_apply(rot2, ang_score) + torch.cross(trans2, lin_score_post_rot)

q_rot, x_rot = q, x
q_rot = transforms.quaternion_multiply(rot1, q_rot)
x_rot = transforms.quaternion_apply(rot1, x_rot) + trans1
q_rot = transforms.quaternion_multiply(q_rot, transforms.quaternion_invert(rot2))
x_rot = x_rot - transforms.quaternion_apply(q_rot, trans2)

Ts_rot = torch.cat([q_rot, x_rot], dim=-1)

In [None]:
with torch.no_grad():
    (ang_score_pre_rot, lin_score_pre_rot), (scene_out_pre_rot, grasp_out_pre_rot) = score_model(Ts=Ts_rot, time=time,
                                                                                                 key_pcd=scene_pcd_rot, 
                                                                                                 query_pcd=grasp_pcd_rot, 
                                                                                                 extract_features = True,
                                                                                                 debug = True)

In [None]:
print("=== Scene feature extractor equivariance ===")

isclose = torch.isclose(scene_out_pre_rot.x, scene_out_post_rot.x, atol=0.01, rtol=0.01)
print(f"Position Equivariance ratio: {(isclose.sum() / len(scene_out_pre_rot.x.view(-1))).item()}") # Slight non-equivariance comes from FPS downsampling algorithm.
isclose = torch.isclose(scene_out_pre_rot.f, scene_out_post_rot.f, atol=0.001, rtol=0.001)
print(f"Feature Equivariance ratio: {(isclose.sum() / len(scene_out_pre_rot.f.view(-1))).item()}") # Slight non-equivariance comes from FPS downsampling algorithm.

In [None]:
print("=== Grasp feature extractor equivariance ===")

isclose = torch.isclose(grasp_out_pre_rot.x, grasp_out_post_rot.x, atol=0.01, rtol=0.01)
print(f"Position Equivariance ratio: {(isclose.sum() / len(grasp_out_pre_rot.x.view(-1))).item()}") # Slight non-equivariance comes from FPS downsampling algorithm.
isclose = torch.isclose(grasp_out_pre_rot.f, grasp_out_post_rot.f, atol=0.001, rtol=0.001)
print(f"Feature Equivariance ratio: {(isclose.sum() / len(grasp_out_pre_rot.f.view(-1))).item()}") # Slight non-equivariance comes from FPS downsampling algorithm.

In [None]:
print("=== Score equivariance ===")

isclose = torch.isclose(ang_score_pre_rot, ang_score_post_rot, atol=0.01, rtol=0.01)
print(f"Angular Score Equivariance ratio: {(isclose.sum() / len(ang_score_pre_rot.view(-1))).item()}") # Slight non-equivariance comes from FPS downsampling algorithm.
isclose = torch.isclose(lin_score_pre_rot, lin_score_post_rot,  atol=0.001, rtol=0.001)
print(f"Linear Score Equivariance ratio: {(isclose.sum() / len(lin_score_pre_rot.view(-1))).item()}") # Slight non-equivariance comes from FPS downsampling algorithm.