In [None]:
import os
# os.environ["PYTORCH_JIT_USE_NNC_NOT_NVFUSER"] = "1"
from typing import List, Tuple, Optional, Union, Iterable, NamedTuple, Any, Sequence

from tqdm import tqdm
import yaml

import torch
from torchvision.transforms import Compose
from torch.utils.data import DataLoader

from e3nn import o3

from diffusion_edf.data import DemoSeqDataset, DemoSequence, TargetPoseDemo
from diffusion_edf.gnn_data import FeaturedPoints, merge_featured_points, GraphEdge
from diffusion_edf import train_utils
from diffusion_edf import preprocess
from diffusion_edf.feature_extractor import UnetFeatureExtractor
from diffusion_edf.transforms import quaternion_apply, random_quaternions

In [None]:
# Notebook configuration.
device = 'cuda:0'
task = 'pick'
# NOTE: `eval` and `compile` are boolean flags read by the model-loading cell; they shadow
# the Python builtins of the same name for the rest of this notebook.
eval = True      # call .eval() on the feature extractor
compile = True   # wrap the feature extractor with torch.jit.script

# Seed torch's global RNG so DataLoader shuffling and torch.randn-based sampling later in
# the notebook are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

model_configs_dir = 'configs/test/model_configs.yaml'
train_configs_dir = 'configs/test/train_configs.yaml'

# Load configs, preprocessor, and dataloader

In [None]:
def _read_yaml(path):
    """Open ``path`` and parse it with PyYAML's FullLoader, returning the parsed object."""
    # NOTE(review): yaml.safe_load would be stricter if the configs use only plain YAML types.
    with open(path) as stream:
        return yaml.load(stream, Loader=yaml.FullLoader)


model_configs = _read_yaml(model_configs_dir)
train_configs = _read_yaml(train_configs_dir)

In [None]:
# Build the preprocessing pipeline. Each entry of `preprocess_config` names an attribute of
# `diffusion_edf.preprocess` (`name`) and the keyword arguments to construct it with (`configs`).
# A missing `configs` key, or an empty one (YAML `configs:` parses as None), means no kwargs.
proc_fn = Compose([
    getattr(preprocess, proc['name'])(**(proc.get('configs') or {}))
    for proc in train_configs['preprocess_config']
])

In [None]:
# Collate function built from the task name and the preprocessing pipeline above.
collate_fn = train_utils.get_collate_fn(task=task, proc_fn=proc_fn)
trainset = DemoSeqDataset(
    dataset_dir=train_configs['dataset_dir'],
    annotation_file=train_configs['annotation_file'],
    device=device,
)
# NOTE: the config key is called 'n_batches' but it is used as the DataLoader batch size.
train_dataloader = DataLoader(
    trainset,
    batch_size=train_configs['n_batches'],
    shuffle=True,
    collate_fn=collate_fn,
)

# Load Model

In [None]:
# Scene feature extractor, built from the model config and moved to `device`.
feature_extractor = UnetFeatureExtractor(
    **model_configs['scene']['feature_extractor_configs']
).to(device)
# `compile` and `eval` are the boolean flags from the config cell (they shadow the builtins).
feature_extractor = torch.jit.script(feature_extractor) if compile else feature_extractor
feature_extractor = feature_extractor.eval() if eval else feature_extractor

# Fetch an example batch

In [None]:
# Take a single batch from the dataloader for the examples below.
# next(iter(...)) replaces a for-loop with an immediate `break`. It also fails loudly
# (StopIteration) on an empty dataloader instead of leaving `input_pcd` undefined.
demo_batch = next(iter(train_dataloader))
B = len(demo_batch)  # length of the collated batch (presumably the number of demos)
scene_pcd, grasp_pcd, target_poses = train_utils.flatten_batch(demo_batch=demo_batch)
# Scene coordinates are divided by the model's `unit_len` before the feature extractor sees them.
input_pcd = FeaturedPoints(
    x=scene_pcd.x / model_configs['unit_len'],
    f=scene_pcd.f,
    b=scene_pcd.b,
)

In [None]:
# Forward pass through the feature extractor. When the `eval` flag is set, autograd is
# disabled: nothing in this notebook backpropagates through `output`, so keeping the
# activation graph would only waste GPU memory.
with torch.set_grad_enabled(not eval):
    output = feature_extractor(input_pcd)

In [None]:
@torch.jit.script
def func(x: List[Tuple[int, Tuple[int, int]]]) -> List[Tuple[int, Tuple[int, int]]]:
    """TorchScript-compiled identity on ``List[Tuple[int, Tuple[int, int]]]``.

    Only exists to probe which Python objects TorchScript will accept for this
    argument type; the input is handed straight back.
    """
    passthrough = x
    return passthrough

In [None]:
# Irreps parsed from the scene feature extractor's `irreps_output` config entry.
irreps = o3.Irreps(
    model_configs['scene']['feature_extractor_configs']['irreps_output']
)

In [None]:
# Probe: pass an `o3.Irreps` to the TorchScript identity `func` to check whether TorchScript
# accepts it as List[Tuple[int, Tuple[int, int]]].
# NOTE(review): the expected outcome is not recorded; this may raise if the conversion is rejected.
func(irreps)

# Equivariance check under a random rotation and translation

In [None]:
# Select one scale of the feature-extractor output to test. This cell must run before the
# rotation cell, which reads `node_coord` and `node_feature`; leaving it commented out made
# that cell fail with NameError on a fresh kernel.
# NOTE(review): assumes `output` is indexable by scale and that index 3 exists. TODO confirm.
scale = 3
node_coord = output[scale].x
node_feature = output[scale].f
node_batch = output[scale].b

In [None]:
# Apply a random rigid transform (quaternion rotation + translation) to the node coordinates.
rot = random_quaternions(1, device=device)
rot = rot / rot.norm(dim=-1, keepdim=True)  # force unit norm
trans = torch.randn(3, device=device)
node_coord_rot = quaternion_apply(rot, node_coord) + trans

# Node features are passed through unchanged; the check below compares `field_val_rot`
# with `field_val` transformed by `rot`.
node_feature_rot = node_feature
# FIXME(review): the next cell also uses `query_coord_rot`, which is never defined in this
# notebook (presumably its transform was removed from this cell).

In [None]:
# Evaluate the field on the transformed node coordinates.
# FIXME(review): `edf`, `query_coord_rot`, `query_batch` and `batch` are not defined in any
# visible cell, so this raises NameError on a fresh kernel.
field_val_rot, edf_info_rot = edf.forward(
    query_coord=query_coord_rot,
    query_batch=query_batch,
    node_feature=node_feature_rot,
    node_coord=node_coord_rot,
    batch=batch,
)

In [None]:
# FIXME(review): `e3nn_script`, `TransformFeatureQuaternion`, `edf` and `field_val` are not
# defined in any visible cell, so this cell raises NameError on a fresh kernel.
irrep_transform = e3nn_script(TransformFeatureQuaternion(irreps = o3.Irreps(edf.irreps_emb), device=device))
a = field_val_rot
b = field_val
# Compare `field_val_rot` against `field_val` transformed by `rot` via `irrep_transform`.
isclose = torch.isclose(irrep_transform(b, rot)[0], a, atol=0.01, rtol=0.01)
# Fraction of matching entries. Averaging `isclose` itself avoids `.view(-1)`, which fails on
# non-contiguous tensors, and keeps numerator and denominator on the same (broadcast) shape.
equivariance_ratio = isclose.float().mean().item()
print(f"Equivariance ratio: {equivariance_ratio}") # Slight non-equivariance comes from FPS downsampling algorithm.