# Joint Prediction Viewer
Visualize the network's joint entity predictions for a pair of parts.

In [1]:
import sys
import json
from pathlib import Path
import torch
import numpy as np
import meshplot as mp

# Treat the parent of the current working directory as the repository root
# (assumed: the notebook is launched from a subdirectory of the repo) and
# make it importable so the project imports that follow can resolve.
root_dir = Path().resolve().parent
root_str = str(root_dir)
if root_str not in sys.path:
    sys.path.append(root_str)

from joint.joint_prediction_set import JointPredictionSet
from joint.joint_environment import JointEnvironment
from datasets.joint_graph_dataset import JointGraphDataset
from train import JointPrediction

## Load the Network
Load a pretrained checkpoint to use for inference

In [2]:
def load_network(checkpoint_file):
    """Load a pretrained JointPrediction network from a checkpoint.

    Args:
        checkpoint_file: Path (or path string) of the checkpoint to load.

    Returns:
        The loaded model, mapped onto the CPU and switched to eval mode,
        or None if the checkpoint file does not exist.
    """
    # Accept plain strings as well as Path objects
    checkpoint_file = Path(checkpoint_file)
    if not checkpoint_file.exists():
        print(f"Checkpoint file does not exist: {checkpoint_file}")
        return None
    model = JointPrediction.load_from_checkpoint(
        checkpoint_file,
        map_location=torch.device("cpu"),  # Just use the CPU
    )
    # The model is only used for inference here; Lightning recommends
    # calling eval() after load_from_checkpoint (e.g. disables dropout)
    model.eval()
    return model

checkpoint_file = root_dir / "results/00_baseline/last.ckpt"
model = load_network(checkpoint_file)
# load_network only prints and returns None when the checkpoint is missing;
# the cells below pass `model` to JointPredictionSet, so stop here instead
# of continuing with None.
if model is None:
    raise FileNotFoundError(f"Could not load checkpoint: {checkpoint_file}")

## Load the Data
Load the dataset and create an instance of the `JointPredictionSet` class.
We assume that the joint JSON files and the mesh part files are in the same directory.

In [3]:
# Point this at the Fusion 360 Gallery joint dataset; the directory is
# expected to hold both the joint JSON files and the OBJ part files
data_dir = root_dir.joinpath("data", "test")
dataset = JointGraphDataset(
    data_dir,
    label_scheme="Joint,JointEquivalent",
    delete_cache=True,  # discard any cached pickle and rebuild it from the raw files
)

Data cache deleted from: /home/fusiqiao/Projects/JoinABLe/data/test/train.pickle
Using new train test split
Loading 38 train data


100%|██████████| 38/38 [00:03<00:00, 10.52it/s]


Total graph load time: 3.613908529281616 sec
Skipped: 0 files
Done loading 38 files
Data cache written to: /home/fusiqiao/Projects/JoinABLe/data/test/train.pickle


In [4]:
from random import randint

# Data sample in the dataset we want to visualize.
# To pick a random sample instead, uncomment the two lines below.
# randint includes both endpoints, hence the "- 1" to stay in range.
# index = randint(0, len(dataset) - 1)
# print("index = ", index)
index = 0

# Graphs for part one and two, and the densely connected joint graph
g1, g2, joint_graph = dataset[index]
# Path of the joint JSON file for this sample
joint_file = data_dir / dataset.files[index]

# Bundle the sample's graphs, its joint file and the model into a
# JointPredictionSet, which the visualization cells below query
jps = JointPredictionSet(
    joint_file,
    g1, g2, joint_graph,
    model
)

## Assemble the Top-1 Prediction
We use `JointEnvironment` to compute the transform that aligns the two parts.

In [None]:
# Transform that moves body one into alignment with the static body two,
# using the default joint parameters: no offset, no rotation, no flip
transform = JointEnvironment.get_transform_from_parameters(
    jps,
    prediction_index=0,  # top-1 prediction
    offset=0,
    rotation_in_degrees=0,
    flip=False,
)

# Final placement of body one: body two's transform composed with the
# alignment transform computed above
body_one_transform = jps.get_transform(body=2) @ transform

# Build the combined meshes and render them
v, f, c, e, n, ni = jps.get_meshes(
    apply_transform=True,
    body_one_transform=body_one_transform,
    show_joint_entity_colors=False,
    show_joint_equivalent_colors=False,
    return_vertex_normals=True,
)
p = mp.plot(v, f, c=c)


## Visualize Entity Predictions
Show the predictions as pink highlights on the joint bodies; edges returned with the predictions are drawn in red.

In [None]:
body = 1
# Untransformed mesh of body one, without the joint entity / equivalent colors
mesh_options = dict(
    apply_transform=False,
    show_joint_entity_colors=False,
    show_joint_equivalent_colors=False,
)
v1, f1, _, _ = jps.get_mesh(body=body, **mesh_options)

# Prediction values (c1) and edges (e1) for this body; the values are mapped
# through the "cool" colormap over [0, 1] (assumed to be in that range — confirm)
c1, e1 = jps.get_joint_predictions(body=body)
prediction_shading = {"colormap": "cool", "normalize": [0, 1]}
p = mp.plot(v1, f1, c=c1, shading=prediction_shading)

# TODO: Add support for edge colors
edge_shading = {"line_color": "red"}
if e1 is not None:
    p.add_edges(v1, e1, shading=edge_shading)

In [None]:
body = 2
# Untransformed mesh of body two, without the joint entity / equivalent colors
mesh_options = dict(
    apply_transform=False,
    show_joint_entity_colors=False,
    show_joint_equivalent_colors=False,
)
v2, f2, _, _ = jps.get_mesh(body=body, **mesh_options)

# Prediction values (c2) and edges (e2) for this body; the values are mapped
# through the "cool" colormap over [0, 1] (assumed to be in that range — confirm)
c2, e2 = jps.get_joint_predictions(body=body)
prediction_shading = {"colormap": "cool", "normalize": [0, 1]}
p = mp.plot(v2, f2, c=c2, shading=prediction_shading)

# TODO: Add support for edge colors
edge_shading = {"line_color": "red"}
if e2 is not None:
    p.add_edges(v2, e2, shading=edge_shading)

## Visualize Entity Axes
Show the joint axes derived from the predicted B-Rep faces/edges, drawn as green lines with green start points.

In [None]:
body = 1
# Redraw body one with its prediction colors and edges
# (reuses v1, f1, c1, e1 from the body-one prediction cell above)
prediction_shading = {"colormap": "cool", "normalize": [0, 1]}
p = mp.plot(v1, f1, c=c1, shading=prediction_shading)
if e1 is not None:
    p.add_edges(v1, e1, shading={"line_color": "red"})

# Axis line(s) for this body's predictions; limit=1 presumably keeps only
# the top prediction — confirm against get_joint_prediction_axis_lines
start_pts, end_pts = jps.get_joint_prediction_axis_lines(body=body, limit=1)
# Alternative: show a single prediction by index.
# NOTE(review): the body-two axis cell spells this keyword `index=`; confirm
# which name get_joint_prediction_axis_lines actually accepts.
# start_pts, end_pts = jps.get_joint_prediction_axis_lines(body=body, show_index=0)

axis_color = "green"
p.add_lines(start_pts, end_pts, shading={"line_color": axis_color})
p.add_points(start_pts, shading={"point_color": axis_color, "point_size": 1})

In [None]:
body = 2
# Redraw body two with its prediction colors and edges
# (reuses v2, f2, c2, e2 from the body-two prediction cell above)
prediction_shading = {"colormap": "cool", "normalize": [0, 1]}
p = mp.plot(v2, f2, c=c2, shading=prediction_shading)
if e2 is not None:
    p.add_edges(v2, e2, shading={"line_color": "red"})

# Axis line(s) for this body's predictions; limit=1 presumably keeps only
# the top prediction — confirm against get_joint_prediction_axis_lines
start_pts, end_pts = jps.get_joint_prediction_axis_lines(body=body, limit=1)
# Alternative: show a single prediction by index.
# NOTE(review): the body-one axis cell spells this keyword `show_index=`;
# confirm which name get_joint_prediction_axis_lines actually accepts.
# start_pts, end_pts = jps.get_joint_prediction_axis_lines(body=body, index=0)

axis_color = "green"
p.add_lines(start_pts, end_pts, shading={"line_color": axis_color})
p.add_points(start_pts, shading={"point_color": axis_color, "point_size": 1})