In [None]:
import os
os.environ["PYTORCH_JIT_USE_NNC_NOT_NVFUSER"] = "1"
from typing import List, Tuple, Optional, Union, Iterable, NamedTuple, Any, Sequence

from tqdm import tqdm
import yaml

import torch
from torchvision.transforms import Compose
from torch.utils.data import DataLoader

from e3nn import o3

from diffusion_edf.data import DemoSeqDataset, DemoSequence, TargetPoseDemo, PointCloud
from diffusion_edf.gnn_data import FeaturedPoints, merge_featured_points, GraphEdge, flatten_featured_points, set_featured_points_attribute, _featured_points_repr
from diffusion_edf import train_utils
from diffusion_edf import preprocess
from diffusion_edf import transforms
from diffusion_edf.feature_extractor import UnetFeatureExtractor
from diffusion_edf.tensor_field import TensorField
from diffusion_edf.radial_func import SinusoidalPositionEmbeddings
from diffusion_edf.equivariant_score_model import ScoreModel

In [None]:
# --- Run configuration ------------------------------------------------------
# YAML files describing the model and the training/data setup.
train_configs_dir = 'configs/test/train_configs.yaml'
model_configs_dir = 'configs/test/model_configs.yaml'

# Device string handed to the dataset, the model and the tensors created below.
device = 'cuda:0'
# Task name forwarded to train_utils.get_collate_fn.
task = 'pick'

# NOTE: `eval` and `compile` shadow the Python builtins of the same name for the
# rest of the notebook. The names are kept because later cells read them.
compile = False  # when truthy, the model cell raises NotImplementedError
eval = True      # when truthy, the model cell calls score_model.eval()

# Load configs, preprocessor, and dataloader

In [None]:
def _load_yaml(path):
    """Parse the YAML file at ``path`` and return the resulting Python object."""
    # NOTE(review): yaml.FullLoader resolves more tags than yaml.safe_load.
    # Only point this at trusted config files, or switch to safe_load once it
    # is confirmed that the configs use plain YAML types only.
    with open(path) as stream:
        return yaml.load(stream, Loader=yaml.FullLoader)


model_configs = _load_yaml(model_configs_dir)
train_configs = _load_yaml(train_configs_dir)

In [None]:
# Build the preprocessing pipeline: each entry of train_configs['preprocess_config']
# names a callable in the `preprocess` module ('name') and the keyword arguments
# used to construct it ('configs'). The steps are chained in config order.
proc_fn = Compose([
    getattr(preprocess, step['name'])(**step['configs'])
    for step in train_configs['preprocess_config']
])

In [None]:
# Collate function for the selected task; it receives the preprocessing pipeline.
collate_fn = train_utils.get_collate_fn(task=task, proc_fn=proc_fn)

# Demonstration dataset described by the training config.
trainset = DemoSeqDataset(
    dataset_dir=train_configs['dataset_dir'],
    annotation_file=train_configs['annotation_file'],
    device=device,
)

# Note: the config key is called 'n_batches' but is used as the batch size.
train_dataloader = DataLoader(
    trainset,
    batch_size=train_configs['n_batches'],
    shuffle=True,
    collate_fn=collate_fn,
)

# Load Model

In [None]:
# Construct the score model from the 'score_model_configs' section of the model
# config, then move it to the configured device.
score_model = ScoreModel(**model_configs['score_model_configs'])
score_model = score_model.to(device=device)

# Compilation is not implemented in this notebook; fail loudly if requested.
if compile:
    raise NotImplementedError

# Honour the `eval` flag from the configuration cell.
if eval:
    score_model = score_model.eval()

# Loop Example

In [None]:
# Pull a single batch from the dataloader and prepare the model inputs.
# Point coordinates are divided by the model's unit length before being fed
# to the network; all other attributes of the point sets are left as they are.
unit_len = model_configs['unit_len']

for demo_batch in train_dataloader:
    B = len(demo_batch)
    scene_pcd, grasp_pcd, target_poses = train_utils.flatten_batch(demo_batch=demo_batch)  # target_poses: (Nbatch, Ngrasps, 7)
    scene_pcd_rescaled = set_featured_points_attribute(scene_pcd, x=scene_pcd.x / unit_len)
    grasp_pcd_rescaled = set_featured_points_attribute(grasp_pcd, x=grasp_pcd.x / unit_len)

    # Only the first batch is needed for this example.
    break
else:
    # Reached only when the loop body never ran. Without this guard an empty
    # dataloader would leave the variables above undefined (or stale from an
    # earlier kernel run) and the failure would surface in a later cell.
    raise RuntimeError("train_dataloader yielded no batches; check dataset_dir and annotation_file")

In [None]:
# Number of random poses to evaluate the score model on.
N_T = 5
# One random rotation per pose.
# NOTE(review): presumably unit quaternions of shape (N_T, 4) — confirm against
# transforms.random_quaternions.
q = transforms.random_quaternions(N_T, device=device)
# One random translation per pose, drawn from a standard normal: shape (N_T, 3).
x = torch.randn(N_T,3,device=device)
# Rotation and translation concatenated along the last dimension, one pose per row.
Ts = torch.cat([q, x], dim=-1)

# One uniform random value in [0, 1) per pose; passed to the model as `time`.
time = torch.rand(N_T,device=device)

In [None]:
# Forward pass without gradient tracking. With extract_features=True the model
# returns, next to the score, a pair of per-point outputs that is unpacked here
# as (scene_out, grasp_out).
with torch.no_grad():
    score, (scene_out, grasp_out) = score_model(
        Ts=Ts,
        time=time,
        key_pcd=scene_pcd_rescaled,
        query_pcd=grasp_pcd_rescaled,
        extract_features=True,
        debug=True,
    )

In [None]:
# Visualise the scene-side model output: positions come from `x`, colours from `w`.
scene_points = scene_out.x.detach().cpu()
scene_colors = scene_out.w.detach().cpu()
pcd = PointCloud(points=scene_points, colors=scene_colors)
pcd.show(point_size=3., width=800, height=800)

# Alternative view: swap in the grasp-side output instead.
# pcd = PointCloud(points=grasp_out.x.detach().cpu(), colors=grasp_out.w.detach().cpu())
# pcd.show(point_size=3., width=800, height=800)

In [None]:
# NOTE(review): `sdfa` is not defined anywhere in the visible notebook, so this
# cell raises NameError. It looks like a deliberate barrier that halts "Run All"
# at this point — confirm, and if so replace it with an explicit, self-describing
# `raise` (or remove it once the cells below run on a fresh kernel).
sdfa

In [None]:
import matplotlib

# `matplotlib.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9.
# The `matplotlib.colormaps` registry (available since 3.5) is the supported way
# to look up a colormap; fall back to the legacy call only on older installs
# that do not have the registry yet.
try:
    viridis_cmap = matplotlib.colormaps['viridis']
except AttributeError:
    viridis_cmap = matplotlib.cm.get_cmap('viridis')

# Linear map from [0, 255] to [0, 1].
# NOTE(review): with this range norm(0.3) is about 0.0012, i.e. almost the very
# start of the colormap — confirm that 0.3 is meant to be on a 0-255 scale.
norm = matplotlib.colors.Normalize(vmin=0, vmax=255)

# The colormap returns RGBA; `to_rgb` drops the alpha channel. The module-level
# function replaces the legacy `colorConverter` alias and returns the same tuple.
out = matplotlib.colors.to_rgb(viridis_cmap(norm(0.3)))

In [None]:
# Map the scene output's `w` values through the viridis colormap and display the
# result as the cell output.
# NOTE(review): the CPU tensor is passed to the colormap directly; presumably it
# is converted to an array internally — confirm, or convert explicitly.
viridis_cmap(scene_out.w.detach().cpu())

In [None]:
# NOTE(review): `sfda` is not defined anywhere in the visible notebook, so this
# cell raises NameError. It looks like a deliberate barrier that halts "Run All"
# at this point — confirm, and if so replace it with an explicit, self-describing
# `raise` (or remove it once the cells below run on a fresh kernel).
sfda

# Rotate: equivariance check under a random rigid transform

In [None]:
from diffusion_edf.transforms import quaternion_apply, random_quaternions
from diffusion_edf.gnn_data import TransformPcd

def _scripted_transform(irreps):
    """Create a TransformPcd for ``irreps`` on ``device`` and compile it with TorchScript."""
    return torch.jit.script(TransformPcd(irreps=irreps).to(device=device))


# One transform module per feature type that has to be moved along with the points.
transform_input = _scripted_transform(model_configs['scene']['feature_extractor_configs']['irreps_input'])
transform_query = _scripted_transform(o3.Irreps(f"{model_configs['time_emb_dim']}x0e"))
transform_output = _scripted_transform(model_configs['scene']['feature_extractor_configs']['irreps_output'])
transform_field = _scripted_transform(model_configs['scene']['tensor_field_configs']['irreps_output'])

# A single random rigid transform: rotation quaternion followed by translation.
# NOTE: this rebinds `Ts`, replacing the batch of N_T poses created earlier.
rot = random_quaternions(1, device=device)
trans = torch.randn(1,3,device=device)
Ts = torch.cat([rot, trans], dim=-1)

In [None]:
# Apply the random transform `Ts` to each point set with the transform module
# matching its feature type, then flatten the result.
# NOTE(review): `input_pcd`, `query_points`, `output_points` and `field_output`
# are not defined in any cell visible in this notebook, so this cell raises
# NameError on a fresh kernel. They presumably came from cells that were since
# removed — restore their definitions before relying on this section.
input_pcd_rot: FeaturedPoints = flatten_featured_points(transform_input(input_pcd, Ts))
query_pcd_rot: FeaturedPoints = flatten_featured_points(transform_query(query_points, Ts))
post_rot_output_points: FeaturedPoints = flatten_featured_points(transform_output(output_points, Ts))
post_rot_field_output: FeaturedPoints = flatten_featured_points(transform_field(field_output, Ts))

In [None]:
# Equivariance check for the feature extractor.
#   pre_rot_*  : the input is transformed first, then the extractor is applied.
#   post_rot_* : the extractor output is transformed afterwards.
# For an equivariant extractor both routes agree up to numerical tolerance.
# pre_rot_output_points = feature_extractor(input_pcd_rot)[scale_idx]
pre_rot_output_points = feature_extractor(input_pcd_rot)

# The ratio is the fraction of elements that match within tolerance. The
# denominator is taken from the comparison result itself: `numel()` works for
# any memory layout (unlike `view(-1)`, which raises on non-contiguous tensors)
# and always matches the number of compared elements.
isclose = torch.isclose(pre_rot_output_points.x, post_rot_output_points.x, atol=0.01, rtol=0.01)
print(f"Position Equivariance ratio: {(isclose.sum() / isclose.numel()).item()}") # Slight non-equivariance comes from FPS downsampling algorithm.
isclose = torch.isclose(pre_rot_output_points.f, post_rot_output_points.f, atol=0.001, rtol=0.001)
print(f"Feature Equivariance ratio: {(isclose.sum() / isclose.numel()).item()}") # Slight non-equivariance comes from FPS downsampling algorithm.

In [None]:
# Evaluate the tensor field at the transformed query points, reading from the
# transformed extractor output. The result is compared against the transformed
# field output in the next cell.
pre_rot_field_output = tensor_field(
    query_pcd_rot,
    input_points=post_rot_output_points,
)

In [None]:
# Equivariance check for the tensor field: compare the field evaluated on
# transformed inputs (pre_rot_*) with the transformed field output (post_rot_*).
# The ratio is the fraction of elements that match within tolerance. The
# denominator is taken from the comparison result itself: `numel()` works for
# any memory layout (unlike `view(-1)`, which raises on non-contiguous tensors)
# and always matches the number of compared elements.
isclose = torch.isclose(pre_rot_field_output.x, post_rot_field_output.x, atol=0.01, rtol=0.01)
print(f"Position Equivariance ratio: {(isclose.sum() / isclose.numel()).item()}") # Slight non-equivariance comes from FPS downsampling algorithm.
isclose = torch.isclose(pre_rot_field_output.f, post_rot_field_output.f, atol=0.001, rtol=0.001)
print(f"Feature Equivariance ratio: {(isclose.sum() / isclose.numel()).item()}") # Slight non-equivariance comes from FPS downsampling algorithm.