In [2]:
import streamlit as st
import os
import tempfile
from vedo import load, screenshot, show


In [87]:
base_name ="21 LowerJawScan"
vtp_path = os.path.join("uploads/lower", base_name + ".vtp")
ply_path = os.path.join("uploads/lower/processed", base_name + ".ply")

In [88]:
import os
import numpy as np
import open3d as o3d
import tensorflow as tf
from vedo import load, write
import copy

INCISOR_LABELS = [ 6, 7, 8, 9]
TOOTH_NAMES = {
    #5: "Left Canine (5)",
    6: "Left Lateral (6)",
    7: "Left Central (7)",
    8: "Right Central (8)",
    9: "Right Lateral (9)",
    #10: "Right Canine (10)",
}

def load_model(model_path="predict/models/gumrecession_model_incisors.h5"):
    print(f"🔄 Loading model from: {model_path}")
    return tf.keras.models.load_model(model_path)

def get_toothface_gums(vtp_mesh, label_no):
    label = vtp_mesh.celldata['Label']
    faces = vtp_mesh.cells
    ind = np.where(label == label_no)[0].tolist()

    new_face = [faces[i] for i in ind]

    cutmesh = vtp_mesh.clone().threshold("Label", above=label_no - 0.1, below=label_no + 0.1, on='cells')
    if cutmesh.vertices.size == 0:
        return []

    xmax, ymax, zmax = cutmesh.vertices.max(axis=0)
    xmin, ymin, zmin = cutmesh.vertices.min(axis=0)

    m_xmax, m_ymax, m_zmax = vtp_mesh.vertices.max(axis=0)
    m_xmin, m_ymin, m_zmin = vtp_mesh.vertices.min(axis=0)

    x_b = (xmin, xmax)
    y_b = (m_ymin, m_ymax)
    z_b = (m_zmin, zmax)

    z_ids = vtp_mesh.find_cells_in_bounds(zbounds=z_b)
    y_ids = vtp_mesh.find_cells_in_bounds(ybounds=y_b)
    x_ids = vtp_mesh.find_cells_in_bounds(xbounds=x_b)

    inter = list(set(z_ids) & set(y_ids) & set(x_ids))
    new_face += [faces[i] for i in inter]
    return new_face

def extract_teeth_with_gums(vtp_path, ply_path, save_to="uploads/lower/debug"):
    print(f"🔍 Loading VTP: {vtp_path}")
    print(f"🔍 Loading PLY: {ply_path}")

    pred_mesh = load(vtp_path)
    full_mesh = o3d.io.read_triangle_mesh(ply_path)

    meshes_by_label = {}
    os.makedirs(save_to, exist_ok=True)

    for label in INCISOR_LABELS:
        print(f"✂️ Extracting Tooth {label} with surrounding tissue...")
        tooth_faces = get_toothface_gums(pred_mesh, label)
        if not tooth_faces:
            meshes_by_label[label] = None
            print(f"⚠️ Tooth {label} not found")
            continue

        mesh_cropped = copy.deepcopy(full_mesh)
        mesh_cropped.triangles = o3d.utility.Vector3iVector(np.asarray(tooth_faces))
        debug_path = os.path.join(save_to, f"tooth_{label}.ply")
        o3d.io.write_triangle_mesh(debug_path, mesh_cropped)
        print(f"💾 Saved to: {debug_path}")
        pcd = o3d.geometry.PointCloud()
        pcd.points = mesh_cropped.vertices
        # Optional: transfer colors if available
        if mesh_cropped.has_vertex_colors():
            pcd.colors = mesh_cropped.vertex_colors
        meshes_by_label[label] = pcd
        #meshes_by_label[label] = mesh_cropped

    return meshes_by_label


In [89]:
debug_folder = "uploads/lower/debug"
meshes = extract_teeth_with_gums(vtp_path, ply_path, save_to=debug_folder)

🔍 Loading VTP: uploads/lower\21 LowerJawScan.vtp
🔍 Loading PLY: uploads/lower/processed\21 LowerJawScan.ply
✂️ Extracting Tooth 6 with surrounding tissue...
💾 Saved to: uploads/lower/debug\tooth_6.ply
✂️ Extracting Tooth 7 with surrounding tissue...
💾 Saved to: uploads/lower/debug\tooth_7.ply
✂️ Extracting Tooth 8 with surrounding tissue...
💾 Saved to: uploads/lower/debug\tooth_8.ply
✂️ Extracting Tooth 9 with surrounding tissue...
💾 Saved to: uploads/lower/debug\tooth_9.ply


In [90]:
type(meshes[6])

open3d.cpu.pybind.geometry.PointCloud

In [91]:
len(meshes)

4

In [92]:
meshes[6]

PointCloud with 5331 points.

In [93]:

mesh=meshes[6]

In [94]:
meshes[6].colors

std::vector<Eigen::Vector3d> with 5331 elements.
Use numpy.asarray() to access data.

In [95]:
#mesh.compute_vertex_normals()

# Visualize it
o3d.visualization.draw_geometries([meshes[6]], window_name="PLY Viewer")

In [96]:
def load_model(model_path="models/gumrecession_model_incisors.h5"):
    print(f"🔄 Loading model from: {model_path}")
    return tf.keras.models.load_model(model_path)

In [97]:
model = load_model()

🔄 Loading model from: models/gumrecession_model_incisors.h5


In [98]:
def preprocess_mesh_for_classification(mesh, num_points=5000):
    #pcd = o3d.geometry.PointCloud()
    #pcd.points = mesh.vertices if hasattr(mesh, 'vertices') else mesh.points
    pcd = mesh
    #print(mesh)

    if len(pcd.points) < num_points:
        print(f"⚠️ Not enough points in mesh: {len(pcd.points)}")
        return None

    #downpcd = pcd.uniform_down_sample(max(1, len(pcd.points) // num_points))
    downpcd = pcd.farthest_point_down_sample(5000)
    #if len(downpcd.points) > num_points:
        #downpcd = o3d.geometry.PointCloud()
        #downpcd.points = o3d.utility.Vector3dVector(np.asarray(pcd.points)[:num_points])

    colors=np.asarray(downpcd.colors)
    print(colors.shape)
    points = downpcd.points - np.mean(downpcd.points, axis=0)# Centralise points
    #points = np.asarray(downpcd.points)
    points = points / np.max(np.linalg.norm(points, axis=1)) # Nomralise points 
    points=np.asarray(points)
    print(points.shape)
    #points = points - points.mean(axis=0)
    #points = points / np.max(np.linalg.norm(points, axis=1))
    #colors = np.zeros_like(points)
    array = np.concatenate((colors, points), axis=1)
    return array

In [99]:
def predict_gum_recession(meshes_by_label, model):
    results = {}
    for label in INCISOR_LABELS:
        print(f"\n🦷 Predicting for tooth {label} ({TOOTH_NAMES[label]})...")
        mesh = meshes_by_label.get(label)
        if mesh is None:
            results[label] = "❌ Not found"
            continue

        print("🔍 Preprocessing tooth mesh...")
        array = preprocess_mesh_for_classification(mesh)
        if array is None or array.shape[0] < 500:
            results[label] = "⚠️ Insufficient points"
            continue

        prediction = model.predict(array[np.newaxis, :, :])[0][0]
        print(prediction)
        results[label] = "✅ Present" if prediction >= 0.5 else "🟢 Absent"
    return results

In [100]:
preprocess_mesh_for_classification(meshes[7], num_points=5000)

(5000, 3)
(5000, 3)


array([[ 0.56470588,  0.41960784,  0.39215686,  0.42891206, -0.35342451,
        -0.43977056],
       [ 0.54901961,  0.40392157,  0.38039216,  0.43505485, -0.34872041,
        -0.45144927],
       [ 0.57254902,  0.41568627,  0.38431373,  0.43671856, -0.34144004,
        -0.42170593],
       ...,
       [ 0.63529412,  0.47843137,  0.48235294,  0.3856192 , -0.27489826,
        -0.20898518],
       [ 0.63529412,  0.47843137,  0.48235294,  0.48063763, -0.24835366,
        -0.68181503],
       [ 0.63529412,  0.47843137,  0.48235294, -0.25216674, -0.14307584,
        -0.15214299]])

In [101]:
array = preprocess_mesh_for_classification(meshes[6])

(5000, 3)
(5000, 3)


In [102]:
prediction = model.predict(array[np.newaxis, :, :])[0][0]
print(prediction)

0.9977261
