In [1]:
import os
import json
from tqdm import * 

# Dataset directory holding the part*.json files, the joint-set JSON files and
# train_test.json. The loaders below build paths as `data_folder + name`, so the
# trailing slash is required.
data_folder = "../data/zw3d-joinable-dataset/"

In [2]:
def read_parts():
    """Load every part JSON file found in ``data_folder``.

    A file is treated as a part file when its name starts with ``"part"`` and
    ends with ``".json"``.

    Returns:
        dict: maps the file name without its ``.json`` extension to the parsed
        JSON content of that file.
    """
    prefix = "part"
    part_files = [
        filename
        for filename in os.listdir(data_folder)
        if filename.startswith(prefix) and filename.endswith(".json")
    ]

    part_map = {}
    for part_file in tqdm(part_files):
        # os.path.join works whether or not data_folder ends with a separator.
        with open(os.path.join(data_folder, part_file), "r", encoding="utf-8") as file:
            data = json.load(file)
        # splitext strips only the final extension, so names that contain
        # additional dots no longer break the unpacking.
        part_name, _ = os.path.splitext(part_file)
        part_map[part_name] = data
    return part_map

# Load all part files once up front; predict() and load_label() read this
# module-level mapping by part name.
part_map = read_parts()

100%|██████████| 6462/6462 [00:32<00:00, 198.82it/s]


In [3]:
def load_data(split):
    """Load the joint sets that belong to one split of the dataset.

    Reads ``train_test.json`` from ``data_folder``, looks up the list of file
    stems stored under ``split`` and parses ``<stem>.json`` for each of them.

    Args:
        split: key in ``train_test.json`` (the notebook uses ``"test"``).

    Returns:
        list: parsed joint sets, in the order they are listed for the split.

    Raises:
        KeyError: if ``split`` is not a key of ``train_test.json``.
    """
    # os.path.join keeps the paths valid even when data_folder has no trailing
    # separator.
    with open(os.path.join(data_folder, "train_test.json"), "r", encoding="utf-8") as file:
        train_test_splits = json.load(file)

    joint_set_list = []
    for filename in tqdm(train_test_splits[split]):
        joint_set_path = os.path.join(data_folder, filename + ".json")
        with open(joint_set_path, "r", encoding="utf-8") as file:
            joint_set_list.append(json.load(file))
    return joint_set_list

test_data = load_data("test")

100%|██████████| 1296/1296 [00:00<00:00, 13636.80it/s]


In [49]:
from collections import defaultdict
from itertools import product
import heapq

def find_nodes_with_radius(part):
    """Group the cylindrical surface nodes of a part by rounded radius.

    Only nodes whose ``surface_type`` is exactly ``"CylinderSurfaceType"`` are
    considered; their ``param_1`` value (treated here as the radius) is rounded
    to one decimal place and used as the grouping key. Curve nodes are ignored.

    Args:
        part: part dict with a ``"nodes"`` list of node dicts.

    Returns:
        defaultdict(list): rounded radius -> indices of the matching nodes, in
        node order.
    """
    radius_to_nodes = defaultdict(list)
    for i, node in enumerate(part["nodes"]):
        # Use equality, not `in`: `x in "CylinderSurfaceType"` is a substring
        # test and would also accept "" or "Cylinder".
        if node.get("surface_type") == "CylinderSurfaceType":
            rounded_radius = round(node["param_1"], 1)
            radius_to_nodes[rounded_radius].append(i)
    return radius_to_nodes

def find_plane_nodes(part):
    """Collect the planar surface nodes of a part.

    Args:
        part: part dict with a ``"nodes"`` list of node dicts.

    Returns:
        list: ``(area, node_index)`` tuples, one per node whose
        ``surface_type`` is ``"PlaneSurfaceType"``, in node order.
    """
    return [
        (candidate["area"], position)
        for position, candidate in enumerate(part["nodes"])
        if "surface_type" in candidate
        and candidate["surface_type"] == "PlaneSurfaceType"
    ]

def predict(joint_set, result_num=50, radius_pair_num=3):
    """Rank candidate (node in body one, node in body two) pairs for a joint set.

    The ranking is a two-stage heuristic:

    1. Cylinder pairs: for the ``radius_pair_num`` pairs of rounded radii
       (one from each part) with the smallest absolute difference, emit the
       first node of each radius group, closest radii first.
    2. Plane pairs: every combination of planar nodes, ordered by the sum of
       the two areas, largest first.

    Args:
        joint_set: dict with ``"body_one"`` and ``"body_two"`` keys naming
            entries of the module-level ``part_map``.
        result_num: maximum number of pairs to return.
        radius_pair_num: maximum number of cylinder pairs placed at the head
            of the ranking.

    Returns:
        list: at most ``result_num`` ``(node_index_1, node_index_2)`` tuples,
        best candidate first.
    """
    if result_num <= 0:
        return []

    part_1 = part_map[joint_set["body_one"]]
    part_2 = part_map[joint_set["body_two"]]

    # Stage 1: entities with similar radii go into the result first.
    radius_to_nodes1 = find_nodes_with_radius(part_1)
    radius_to_nodes2 = find_nodes_with_radius(part_2)

    # Bounded heap keyed on the negated difference: popping the smallest entry
    # discards the pair with the largest difference, so the heap retains the
    # closest pairs. product() yields nothing if either mapping is empty.
    closest_pairs = []
    for radius1, radius2 in product(radius_to_nodes1, radius_to_nodes2):
        candidate = (
            -abs(radius1 - radius2),
            radius_to_nodes1[radius1],
            radius_to_nodes2[radius2],
        )
        heapq.heappush(closest_pairs, candidate)
        if len(closest_pairs) > radius_pair_num:
            heapq.heappop(closest_pairs)

    # Smallest difference first.
    closest_pairs.sort(key=lambda entry: -entry[0])
    res = [(nodes1[0], nodes2[0]) for _, nodes1, nodes2 in closest_pairs]

    # Enforce the cap explicitly. An equality check inside the plane loop is
    # never hit when the radius stage alone already reaches result_num.
    res = res[:result_num]
    remaining = result_num - len(res)
    if remaining == 0:
        return res

    # Stage 2: plane pairs, from the largest combined area to the smallest.
    plane_pair_list = [
        (area1 + area2, id1, id2)
        for area1, id1 in find_plane_nodes(part_1)
        for area2, id2 in find_plane_nodes(part_2)
    ]
    plane_pair_list.sort(key=lambda entry: entry[0], reverse=True)
    res.extend((id1, id2) for _, id1, id2 in plane_pair_list[:remaining])
    return res

def load_label(joint_set):
    """Build the ground-truth node-id sets for every joint of a joint set.

    For each joint, the ids of ``entity_one`` and of all
    ``entity_one_equivalents`` are collected separately for both sides
    (``geometry_or_origin_one`` / ``geometry_or_origin_two``).

    Args:
        joint_set: dict with ``"body_one"``, ``"body_two"`` and ``"joints"``.

    Returns:
        list: one ``(joint_type, ids_for_body_one, ids_for_body_two)`` tuple
        per joint; the two id collections are sets of ints.
    """

    def node_id(entity, body_name):
        # Entities without a "surface_type" key get the part's face_count added
        # to their index (presumably because faces precede the other entities
        # in the part's node list -- verify against the dataset format).
        index = entity["index"]
        if "surface_type" not in entity:
            index += part_map[body_name]["properties"]["face_count"]
        return index

    def side_ids(geometry, body_name):
        # The primary entity first, then every equivalent entity.
        collected = {node_id(geometry["entity_one"], body_name)}
        for equivalent in geometry["entity_one_equivalents"]:
            collected.add(node_id(equivalent, body_name))
        return collected

    body_1 = joint_set["body_one"]
    body_2 = joint_set["body_two"]

    labels = []
    for joint in joint_set["joints"]:
        ids_1 = side_ids(joint["geometry_or_origin_one"], body_1)
        ids_2 = side_ids(joint["geometry_or_origin_two"], body_2)
        labels.append((joint["joint_type"], ids_1, ids_2))
    return labels

In [50]:
def test(data_set):
    """Print top-1 / top-5 / top-50 hit rates of ``predict`` on ``data_set``.

    For every joint set, the ranked pairs from ``predict`` (called with its
    default ``result_num``) are compared with the labels from ``load_label``.
    A prediction is a hit when its first id is in a label's first id set and
    its second id is in the same label's second id set. Only the best-ranked
    hit of each joint set is counted: it always counts towards "Top 50",
    towards "Top 5" when its rank is below 5, and towards "Top 1" when it is
    the first prediction.

    Args:
        data_set: sized iterable of joint-set dicts.

    Returns:
        None. The three rates are printed as fractions of ``len(data_set)``.
    """
    total = len(data_set)
    if total == 0:
        # Hit rates are undefined without samples; report that instead of
        # failing with ZeroDivisionError.
        print("Empty data set: nothing to evaluate.")
        return

    top1 = top5 = top50 = 0
    for data in data_set:
        preds = predict(data)
        labels = load_label(data)
        # Rank of the first prediction that matches any label, or None.
        first_hit = next(
            (
                rank
                for rank, (id1, id2) in enumerate(preds)
                if any(id1 in ids1 and id2 in ids2 for _, ids1, ids2 in labels)
            ),
            None,
        )
        if first_hit is None:
            continue
        top50 += 1
        if first_hit < 5:
            top5 += 1
        if first_hit == 0:
            top1 += 1

    print("Top 1:\t %f" % (top1 / total))
    print("Top 5:\t %f" % (top5 / total))
    print("Top 50:\t %f" % (top50 / total))
    
# Print the top-1 / top-5 / top-50 hit rates of the heuristic on the test split.
test(test_data)

Top 1:	 0.168981
Top 5:	 0.337963
Top 50:	 0.683642
