# Joint Search Viewer
View the results of joint pose search

In [1]:
from pathlib import Path
import sys
import os
import random
import torch
import numpy as np
import meshplot as mp
import matplotlib.pyplot as plt
from scipy.spatial.transform import Rotation

# This notebook lives one directory below the repository root; put that
# root on sys.path (only once) so the project packages imported below resolve.
root_dir = Path.cwd().resolve().parent
if all(entry != str(root_dir) for entry in sys.path):
    sys.path.append(str(root_dir))

from utils import util
from joint.joint_environment import JointEnvironment
from joint.joint_prediction_set import JointPredictionSet

from search.search_simplex import SearchSimplex
from search.search_random import SearchRandom

from train import JointPrediction
from datasets.joint_graph_dataset import JointGraphDataset

  "`pytorch_lightning.metrics.*` module has been renamed to `torchmetrics.*` and split off to its own package"


## Load Network & Data
Load a pretrained checkpoint to use for inference.

Load the dataset and create an instance of the JointPredictionSet class.
We assume that the joint JSON files and the mesh part files are in the same directory.

In [2]:
def load_network(checkpoint_file, map_location="cpu"):
    """Load the pretrained JointPrediction network from a checkpoint.

    Args:
        checkpoint_file: Path (or path string) to the ``.ckpt`` file.
        map_location: Device the checkpoint tensors are mapped onto;
            anything ``torch.device`` accepts. Defaults to the CPU.

    Returns:
        The loaded ``JointPrediction`` model, or ``None`` (after printing a
        message) when the checkpoint file does not exist.
    """
    # Accept plain strings as well as Path objects; the original called
    # .exists() directly and failed on str input.
    checkpoint_file = Path(checkpoint_file)
    if not checkpoint_file.exists():
        print(f"Checkpoint file does not exist: {checkpoint_file}")
        return None
    return JointPrediction.load_from_checkpoint(
        checkpoint_file,
        map_location=torch.device(map_location),
    )

# Pretrained checkpoint bundled under the repository's pretrained/ folder.
checkpoint_file = root_dir.joinpath("pretrained/paper/last_run_0.ckpt")
model = load_network(checkpoint_file)

# Dataset location: point this at a Fusion 360 Gallery-style joint dataset,
# i.e. a directory holding the joint json files and their obj part files.
# Alternative directory: root_dir / "data/tester"
data_dir = root_dir.joinpath("data/zw3d-joinable-dataset")
dataset = JointGraphDataset(
    root_dir=data_dir,
    split="val",
    label_scheme="Joint,JointEquivalent",
)

Data cache loaded from: /home/fusiqiao/code/JoinABLe/data/zw3d-joinable-dataset/val.pickle


In [8]:
from random import randint

# Data sample in the dataset we want to visualize.
# Set pick_random_sample = True to visualize a random sample instead.
pick_random_sample = False
if pick_random_sample:
    # randint includes both endpoints, so the upper bound must be
    # len(dataset) - 1 to stay a valid index.
    index = randint(0, len(dataset) - 1)
    print("index = ", index)
else:
    index = 0

# Graphs for part one and two, and the densely connected joint graph
g1, g2, joint_graph = dataset[index]
# The joint file json
joint_file = data_dir / dataset.files[index]

# Create the prediction with the given data and model
jps = JointPredictionSet(
    joint_file,
    g1, g2, joint_graph,
    model
)

## Ground Truth
View the ground truth assembled state of the parts.

In [9]:
# Render both parts in the assembled pose given by the first joint in the
# file, with joint highlighting colors turned off.
v, f, c, e = jps.get_meshes(
    joint_index=0,
    show_joint_equivalent_colors=False,
    show_joint_entity_colors=False,
)
p = mp.plot(v, f, c=c)

Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(0.0, 0.0,…

## Joint Axis Prediction without Search
View the parts when assembled using only the joint axis prediction, without performing a search for the offset/rotation/flip parameters.

In [10]:
# Get the transform to move body1, to align with the static body2
# using default parameters, i.e. no offset, rotation, or flip
# Pose body 1 from the top-1 axis prediction alone: every search parameter
# (offset, rotation, flip) stays at its neutral value.
transform = JointEnvironment.get_transform_from_parameters(
    jps,
    prediction_index=0,  # index 0 = highest-ranked prediction
    flip=False,
    rotation_in_degrees=0,
    offset=0,
)
# Compose with body 2's transform so body 1 lines up with the static body 2.
transform = jps.get_transform(body=2) @ transform

# Draw the moved body 1 next to body 2 (plain part colors).
v, f, c, e, n, ni = jps.get_meshes(
    apply_transform=True,
    body_one_transform=transform,
    return_vertex_normals=True,
    show_joint_equivalent_colors=False,
    show_joint_entity_colors=False,
)
p = mp.plot(v, f, c=c)

Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(0.0, -2.2…

## Joint Pose Search
Perform joint pose search to find the offset, rotation, and flip parameters.

In [11]:
# Seeded random state handed to the search so a run can be repeated.
seed = 24
random_state = np.random.RandomState(seed)
# Nelder–Mead Simplex Search as used in the paper
search = SearchSimplex(random_state=random_state)
# Random search can also be used
# search = SearchRandom(random_state=random_state, budget=500)

result = search.search(jps)
# Returns a dict with:
# - prediction_index: Index of the prediction from the network, 0 being the highest probability
# - offset: Offset parameter
# - rotation: Rotation parameter
# - flip: Flip parameter
# - transform: Transform created from the axis and parameters. Compose it with
#   body 2's transform (jps.get_transform(body=2) @ transform) and apply the
#   result to body 1 to assemble the parts together
# - evaluation: Evaluation score where lower is better
# - overlap: Overlap between the parts
# - contact: Contact between the parts
result

{'prediction_index': 21,
 'offset': 0.0,
 'rotation': 0.0001875,
 'flip': True,
 'transform': array([[ 2.15704927e-17,  2.77556752e-01, -9.60709243e-01,
          5.39262319e-17],
        [ 7.46621065e-17,  9.60709243e-01,  2.77556752e-01,
          1.86655266e-16],
        [-1.00000000e+00,  7.77156117e-17,  0.00000000e+00,
         -4.50000000e+00],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          1.00000000e+00]]),
 'evaluation': -3.6459429215792545,
 'overlap': 0.0,
 'contact': 0.36459429215792544}

In [12]:
# Visualized the best result
# Visualize the best pose found by the search: compose the searched
# transform with body 2's transform and use it to move body 1.
transform = jps.get_transform(body=2) @ result["transform"]
v, f, c, e, n, ni = jps.get_meshes(
    apply_transform=True,
    body_one_transform=transform,
    return_vertex_normals=True,
    show_joint_equivalent_colors=False,
    show_joint_entity_colors=False,
)
p = mp.plot(v, f, c=c)

Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(0.0, 0.0,…