In [None]:
import numpy as np
from scipy.spatial import KDTree
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils import check_random_state
from stl import mesh

class ComplexSTLTransformer(BaseEstimator, TransformerMixin):
    """
    Turn STL mesh files into fixed-size point clouds for scikit-learn pipelines.

    Each input file is loaded with numpy-stl, optionally centered and scaled
    into the unit sphere, sampled on its triangle surface and, for
    ``feature_type='geometry'``, augmented with PCA-estimated normals.
    """

    def __init__(self, num_points=1024, normalize=True, feature_type='geometry',
                 random_state=None):
        """
        Initialize the ComplexSTLTransformer.

        Parameters:
        - num_points (int): The number of points to sample from the mesh surface.
        - normalize (bool): Whether to center the mesh on its mean vertex and
          scale it so the farthest vertex lies on the unit sphere.
        - feature_type (str): The type of feature to extract. 'geometry'
          appends estimated normals (6 columns per point); any other value
          returns only the xyz coordinates (3 columns per point).
        - random_state (None, int or np.random.RandomState): Source of
          randomness for sampling. None uses NumPy's global random state.
        """
        # scikit-learn convention: store constructor arguments unchanged so
        # get_params / set_params / clone work.
        self.num_points = num_points
        self.normalize = normalize
        self.feature_type = feature_type
        self.random_state = random_state

    def fit(self, X, y=None):
        """
        No-op fit: this transformer learns nothing from the data.

        Defined because TransformerMixin.fit_transform and sklearn Pipeline
        call fit() before transform().

        Returns:
        - self
        """
        return self

    def transform(self, X):
        """
        Transform the STL files into a numerical feature matrix (point cloud).

        Parameters:
        - X (iterable): Paths to STL files.

        Returns:
        - np.ndarray: Shape (len(X), num_points, 6) when feature_type is
          'geometry' (xyz + normal), otherwise (len(X), num_points, 3).
        """
        # One generator for the whole call, so that different files get
        # different draws while an integer random_state stays reproducible.
        rng = check_random_state(self.random_state)
        point_clouds = []

        for stl_file_path in X:
            # Load the STL file
            dental_mesh = mesh.Mesh.from_file(stl_file_path)
            # Flatten the per-triangle corner blocks into an (N, 3) array.
            # Every 3 consecutive rows are the corners of one triangle, and
            # vertices shared by several triangles are repeated.
            vertices = dental_mesh.vectors.reshape(-1, 3)

            if self.normalize:
                vertices = self._normalize_mesh(vertices)

            point_cloud = self._sample_points(vertices, rng=rng)

            if self.feature_type == 'geometry':
                point_cloud = self._augment_with_normals(point_cloud)

            point_clouds.append(point_cloud)

        return np.array(point_clouds)

    def _normalize_mesh(self, vertices):
        """
        Center the vertices on their mean and scale them into the unit sphere.

        A new array is returned; the input is not modified. When called from
        ``transform``, shared vertices appear once per incident triangle, so
        the center is weighted by corner count rather than being the mean of
        unique vertices.

        Parameters:
        - vertices (np.ndarray): (N, 3) array of vertices.

        Returns:
        - np.ndarray: Normalized vertices. An empty input is returned as an
          empty copy, and a mesh whose vertices all coincide is only centered.
        """
        vertices = np.asarray(vertices)
        if len(vertices) == 0:
            return vertices.copy()

        # Out-of-place subtraction: works for integer input and leaves the
        # caller's array untouched.
        centered = vertices - vertices.mean(axis=0)

        max_extent = np.max(np.linalg.norm(centered, axis=1))
        # Avoid 0/0 -> NaN for a degenerate mesh collapsed to a single point.
        if max_extent > 0:
            centered = centered / max_extent

        return centered

    def _sample_points(self, vertices, rng=None):
        """
        Sample ``self.num_points`` points uniformly (by area) over the mesh surface.

        ``vertices`` is expected to be a flattened triangle list as built in
        ``transform``: rows 3*i, 3*i+1 and 3*i+2 are the corners of triangle i.
        Each triangle is chosen with probability proportional to its area, and
        a point is drawn uniformly inside it with barycentric coordinates. If
        the row count is not a multiple of 3, the input cannot be a triangle
        list, so existing rows are sampled with replacement instead.

        Parameters:
        - vertices (np.ndarray): (N, 3) array of triangle corners.
        - rng (np.random.RandomState, optional): Generator to draw from.
          Defaults to ``check_random_state(self.random_state)``.

        Returns:
        - np.ndarray: (num_points, 3) sampled point cloud.

        Raises:
        - ValueError: If ``vertices`` is empty.
        """
        if rng is None:
            rng = check_random_state(self.random_state)

        vertices = np.asarray(vertices)
        if len(vertices) == 0:
            raise ValueError("Cannot sample points from an empty mesh.")

        if len(vertices) % 3 != 0:
            # Not a triangle list: fall back to picking existing vertices.
            indices = rng.choice(len(vertices), self.num_points, replace=True)
            return vertices[indices]

        triangles = vertices.reshape(-1, 3, 3)
        origin = triangles[:, 0]
        edge1 = triangles[:, 1] - origin
        edge2 = triangles[:, 2] - origin

        # Compute areas in float64 so the normalized probabilities pass
        # choice()'s sum-to-one tolerance check even for float32 meshes.
        areas = 0.5 * np.linalg.norm(
            np.cross(edge1, edge2).astype(np.float64), axis=1
        )
        total_area = areas.sum()
        if np.isfinite(total_area) and total_area > 0:
            probs = areas / total_area
        else:
            # All triangles are degenerate (or the data is non-finite):
            # pick triangles uniformly.
            probs = None

        chosen = rng.choice(len(triangles), self.num_points, replace=True, p=probs)
        u = rng.uniform(size=self.num_points)
        v = rng.uniform(size=self.num_points)
        # (u, v) is uniform on the unit square. Reflecting the half with
        # u + v > 1 maps it onto the triangle while keeping the density uniform.
        outside = u + v > 1
        u[outside] = 1.0 - u[outside]
        v[outside] = 1.0 - v[outside]

        points = (origin[chosen]
                  + u[:, None] * edge1[chosen]
                  + v[:, None] * edge2[chosen])

        # Keep the mesh's floating dtype (e.g. float32) for the output.
        if np.issubdtype(vertices.dtype, np.floating):
            points = points.astype(vertices.dtype, copy=False)
        return points

    def _augment_with_normals(self, point_cloud, k=10):
        """
        Augment the point cloud with PCA-estimated surface normals.

        For every point, its ``k`` nearest sampled points are found with a
        KD-tree. The normal is the eigenvector of the neighbors' scatter
        matrix that has the smallest eigenvalue. The normal's sign is
        arbitrary.

        Parameters:
        - point_cloud (np.ndarray): (n, 3) array of sampled points.
        - k (int): Neighborhood size; clipped to the number of points.

        Returns:
        - np.ndarray: (n, 6) array of [x, y, z, nx, ny, nz].
        """
        point_cloud = np.asarray(point_cloud)
        n_points = len(point_cloud)
        if n_points == 0:
            return np.hstack((point_cloud, np.zeros_like(point_cloud)))

        # KDTree.query marks missing neighbors with index n_points when
        # k > n_points, which would index out of bounds below.
        k = max(1, min(k, n_points))
        _, indices = KDTree(point_cloud).query(point_cloud, k=k)
        # query() drops the neighbor axis when k == 1; restore it.
        indices = np.asarray(indices).reshape(n_points, k)

        neighbors = point_cloud[indices]  # (n, k, 3)
        centered = neighbors - neighbors.mean(axis=1, keepdims=True)
        # Unnormalized covariance: scaling does not change the eigenvectors.
        scatter = np.einsum('nki,nkj->nij', centered, centered)

        # eigh handles symmetric matrices: eigenvalues are real and ascending,
        # so column 0 belongs to the smallest eigenvalue. (eig may return a
        # complex dtype and does not sort its output.)
        _, eigenvectors = np.linalg.eigh(scatter)
        normals = eigenvectors[:, :, 0]

        # Concatenate point coordinates with normals
        return np.hstack((point_cloud, normals))