In [1]:
import os 
import numpy as np
from scipy import spatial
import glob
from multiprocessing import Process
from tqdm import tqdm

## 一、自定义函数

### 1.获取模型信息

In [2]:
def get_edges(faces):
    """
    根据面得到相应的边
    @faces: 模型的所有面
    return: 模型的边
    """
    edge2key = dict()
    edges = []
    edges_count = 0
    for face_id, face in enumerate(faces):
        faces_edges = []
        for i in range(3):
            cur_edge = (face[i], face[(i + 1) % 3])
            faces_edges.append(cur_edge)
        for idx, edge in enumerate(faces_edges):
            edge = tuple(sorted(list(edge)))
            if edge not in edge2key:
                edge2key[edge] = edges_count
                edges_count += 1
                edges.append(list(edge))
    return edges


def parse_obje(obj_file):
    """
    解析obj文件， 获取点，边，面
    @obj_file: obj模型文件路径
    return: 模型的点，边，面信息
    """
    
    vs = []
    faces = []
    edges = []

    with open(obj_file) as f:
        for line in f:
            line = line.strip()
            splitted_line = line.split()
            if not splitted_line:
                continue
            elif splitted_line[0] == 'v':
                vs.append([float(v) for v in splitted_line[1:]])
            elif splitted_line[0] == 'f':
                try:
                    faces.append([int(c) - 1 for c in splitted_line[1:]])
                except ValueError:
                    faces.append([int(c.split('/')[0]) - 1 for c in splitted_line[1:]])                   
            elif splitted_line[0] == 'e':
                if len(splitted_line) >= 4:
                    edge_v = [int(c) - 1 for c in splitted_line[1:-1]]
                    edge_c = int(splitted_line[-1])
                    edge_v.append(edge_c)                 # class
                    edges.append(edge_v)           
            else:
                continue

    vs = np.array(vs)
    faces = np.array(faces, dtype=int)
    if len(edges) == 0:
        edges = get_edges(faces)
    edges = np.array(edges)
        
    return vs, faces, edges

### 2.根据边标记对面进行标记

In [3]:
def label_face_by_edge(faces, edges, edge_labels):
    """
    利用边标签对面进行标记
    @faces: 模型的面
    @edges: 模型的边
    @edge_labels: 模型边对应的标签
    return: 面的标签
    """
    edge_dict = {}    # key: str([pt1, pt2]) value: label
    for ei, edge in enumerate(edges):
        key = tuple(edge)
        edge_dict[key] = edge_labels[ei]
    # print(edge_dict)
    face_labels = np.array(len(faces) * [[-1, -1, -1]])
    for i, face in enumerate(faces):
        # faces_edges = []
        for j in range(3):
            cur_edge = [face[j], face[(j + 1) % 3]]
            cur_label = edge_dict[tuple(sorted(cur_edge))]
            face_labels[i][j] = cur_label
            
        # face_labels.append(faces_edges)
    face_labels = np.where(np.sum(face_labels, axis=1)<2, 1, 2)
    
    return face_labels

### 3.利用边对点进行标记

In [4]:
def label_pts_by_edges(vs, edges, edge_labels):
    """
    根据边标签，对点进行标注
    @vs: 模型的点
    @edge: 模型的边
    @edge_labels: 模型边对应的标签
    return: 模型点的标签
    """
    pts_labels = np.array(len(vs) * [[-1, -1]])
    for ei, edge in enumerate(edges):
        edge_label = edge_labels[ei]
        pt1 = edge[0]
        pt2 = edge[1]
        pts_labels[pt1][edge_label] = edge_label
        pts_labels[pt2][edge_label] = edge_label
    
    return pts_labels

### 4.边标签投影到原始模型

In [5]:
def label_origin_edge(predict_edges, predict_labels, predict_vs, origin_edges, origin_vs):
    """
    根据预测的边及标签，对原始模型的边进行标注
    @predict_edges: 预测模型对应的边
    @predict_labels: 预测模型对应的标签
    @origin_edges: 原始模型的边
    return: 原始模型边对应的标签
    """
    predict_edge_pts = predict_vs[predict_edges].reshape(-1, 6)

    tree = spatial.KDTree(predict_edge_pts)

    origin_edge_pts = origin_vs[origin_edges].reshape(-1, 6)

    origin_labels = []
    for i, edge in enumerate(origin_edge_pts):
#         if i % 50000 == 0:
#             print(i, "is finded!")
        dist, idx = tree.query(edge)
        origin_labels.append(predict_labels[idx])
    
    return origin_labels

### 5.点投影到原模型

In [6]:
def project_points(predict_pts, origin_vs):
    """
    根据预测的边，筛选出边界点，将点投影回原模型
    @predict_pts: 边界点
    @origin_vs: 原始模型所有点
    return: 返回原始模型的边界点
    """
    tree = spatial.KDTree(origin_vs)
    
    origin_pts = []
    for i, pt in enumerate(predict_pts):
        dist, idx = tree.query(pt)
        origin_pts.append(origin_vs[idx])
        
    origin_pts = np.asarray(origin_pts)
    return origin_pts
    

### 6.分开保存模型 便于显示

In [7]:
def save_model_part(save_path, vs, faces, face_labels, model1_name="mesh1.obj", model2_name="mesh2.obj"):
    """
    根据标签将模型标记的部分分别保存
    @obj_vs: 模型的顶点
    @obj_faces: 模型的面
    @face_labels: 面的标签 
    return: None
    """
    mesh1 = open(os.path.join(save_path, model1_name), "w")
    mesh2 = open(os.path.join(save_path, model2_name), "w")
    for v in vs:
        mesh1.write("v " + str(v[0]) + " " + str(v[1]) + " " + str(v[2]) + "\n")
        mesh2.write("v " + str(v[0]) + " " + str(v[1]) + " " + str(v[2]) + "\n")

    for idx, face in enumerate(faces):
        if face_labels[idx] == 1:
            mesh1.write("f " + str(face[0]+1) + " " + str(face[1]+1) + " " + str(face[2]+1) + "\n")
        if face_labels[idx] == 2:
            mesh2.write("f " + str(face[0]+1) + " " + str(face[1]+1) + " " + str(face[2]+1) + "\n")

    mesh1.close()
    mesh2.close()

### 7. 导出边界点

In [8]:
def save_pts_to_vtk(pts, save_path="./test.vtk"):
    """
    将牙龈点pts格式转为vtk格式
    @pts: 点集 [[x, y, z], [x, y, z], ...]
    @save_path: 保存路径
    return: None
    """
    import vtkplotter as vtkp
    vtk_point = vtkp.Points(pts.reshape(-1, 3))
    vtkp.write(vtk_point, save_path, binary=False)
#     print("vtk file is saved in ", save_path)

### 8.导出模型

In [9]:
def export(file, vs, edges, faces, labels):
    """
    将模型的点，边，面及边的标签保存
    @file: 保存路径
    @vs: 模型的点
    @edges: 模型的边
    @faces: 模型的面
    @labels: 模型边的标签
    """
    with open(file, 'w+') as f:
        for vi, v in enumerate(vs):
            vcol = ' %f %f %f' % (vcolor[vi, 0], vcolor[vi, 1], vcolor[vi, 2]) if vcolor is not None else ''
            f.write("v %f %f %f%s\n" % (v[0], v[1], v[2], vcol))
        for face_id in range(len(faces) - 1):
            f.write("f %d %d %d\n" % (faces[face_id][0] , faces[face_id][1], faces[face_id][2] ))
        for ei, edge in enumerate(edges):
            f.write("\ne %d %d %d" % (edge[0] , edge[1], labels[ei]))

## 二、主函数


In [10]:
def show_predict(predict_model, origin_model, save_path):
    """
    对预测的模型进行分析，找出牙龈线点，并对原始模型进行分割
    @predict_model: 预测的模型
    @origin_model: 预测对应的原始模型
    @save_path: 结果保存路径
    return: None
    """
    # ------加载模型 获取信息------
    ## 预测模型
    predict_vs, predict_faces, predict_edges = parse_obje(predict_model)
    predict_labels = predict_edges[:, -1]
    predict_edges = predict_edges[:, :-1]
    
    ## 原始模型
    origin_vs, origin_faces, origin_edges = parse_obje(origin_model)
    
    # ------处理预测模型------
    ## 找预测的牙龈线点
    predict_pts_labels = label_pts_by_edges(predict_vs, predict_edges, predict_labels)
    predict_gum_pt_ids = np.where((predict_pts_labels[:,0]==0) & (predict_pts_labels[:,1]==1))[0]
    predict_gum_pts = predict_vs[predict_gum_pt_ids]
    
    save_pts_to_vtk(predict_gum_pts, os.path.join(save_path, "predict.vtk"))
    
    ## 标记预测的面
    predict_face_labels = label_face_by_edge(predict_faces, predict_edges, predict_labels)
    save_model_part(save_path, predict_vs, predict_faces, predict_face_labels, "predict1.obj", "predict2.obj")
    
#     # ------原始模型处理------
#     ## 投影点
#     origin_gum_pts = project_points(predict_gum_pts, origin_vs)
#     save_pts_to_vtk(origin_gum_pts, os.path.join(save_path, "origin1.vtk"))
    
#     ## 投影边
#     origin_labels = label_origin_edge(predict_edges, predict_labels, predict_vs, origin_edges, origin_vs)
#     origin_face_labels = label_face_by_edge(origin_faces, origin_edges, origin_labels)
#     save_model_part(save_path, origin_vs, origin_faces, origin_face_labels,  "origin1.obj", "origin2.obj")
    
#     origin_pts_labels = label_pts_by_edges(origin_vs, origin_edges, origin_labels)
#     origin_gum_pt_ids = np.where((origin_pts_labels[:,0]==0) & (origin_pts_labels[:,1]==1))[0]
#     origin_gum_pts = origin_vs[origin_gum_pt_ids]
#     save_pts_to_vtk(origin_gum_pts, os.path.join(save_path, "origin2.vtk"))

### test

In [11]:
# predict_model = "./1R5DH_VS_SET_VSc1_Subsetup14_Mandibular_predict.obj"
# origin_model = "./1R5DH_VS_SET_VSc1_Subsetup14_Mandibular_origin.obj"
# show_predict(predict_model, origin_model, "./")

#### 导出模型

In [12]:
# export("./export.obj", origin_vs, origin_edges, origin_faces, origin_labels)

## 三、批量处理

In [11]:
def show_predict_batch(predict_model_list, predict_path, origin_path):
    """
    批量处理预测模型
    @predict_model_list: 预测的模型列表
    @predict_path: 预测模型存放路径
    @origin_path: 原始模型存放路径
    return: None
    """
    for i, predict_model in enumerate(tqdm(predict_model_list)):
        origin_model_basename = os.path.basename(predict_model)[:-6]
        origin_model_name = origin_model_basename + ".obj"
        origin_model = os.path.join(origin_path, origin_model_name)
        if not os.path.isfile(origin_model):
            print(origin_model, "not find!")
            continue

        save_dir = os.path.join(predict_path, origin_model_basename)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        show_predict(predict_model, origin_model, save_dir)

In [12]:

def parallel_show_predict(model_list, predict_path, origin_path, n_workers=8):
    """
    多进程处理
    """
    if len(model_list) < n_workers:
        n_workers = len(model_list)
    chunk_len = len(model_list) // n_workers

    chunk_lists = [model_list[i:i+chunk_len] for i in range(0, (n_workers-1)*chunk_len, chunk_len)]
    chunk_lists.append(model_list[(n_workers - 1)*chunk_len:])
    
    process_list = [Process(target=show_predict_batch, args=(chunk_list, predict_path, origin_path, )) for chunk_list in chunk_lists]
    for process in process_list:
        process.start()
    for process in process_list:
        process.join()

In [13]:
predict_path = "/home/heygears/work/Tooth_data_prepare/model_test/predict_mesh_obj_20201210_100/"
origin_path = "/home/heygears/work/Tooth_data_prepare/model_test/origin_mesh_obj/"

predict_model_list = glob.glob(os.path.join(predict_path, "*.obj"))
parallel_show_predict(predict_model_list, predict_path, origin_path, n_workers=8)

100%|██████████| 8/8 [00:20<00:00,  2.62s/it]]
100%|██████████| 8/8 [00:20<00:00,  2.62s/it]
100%|██████████| 8/8 [00:22<00:00,  2.79s/it]]
100%|██████████| 8/8 [00:23<00:00,  2.91s/it]
100%|██████████| 8/8 [00:23<00:00,  2.95s/it]
100%|██████████| 8/8 [00:24<00:00,  3.01s/it]
100%|██████████| 8/8 [00:25<00:00,  3.15s/it]]
100%|██████████| 14/14 [00:31<00:00,  2.22s/it]
