<a href="https://colab.research.google.com/github/sam-spears/ML_NeckSpace/blob/main/ML.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup and Import

In [1]:
# Run these in Colab first:
!pip install trimesh
!pip install torch-geometric
!pip install pyvista
!pip install scikit-learn
!pip install pyglet==1.5.27  # for visualization

Collecting trimesh
  Downloading trimesh-4.8.1-py3-none-any.whl.metadata (18 kB)
Downloading trimesh-4.8.1-py3-none-any.whl (728 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m728.5/728.5 kB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: trimesh
Successfully installed trimesh-4.8.1
Collecting torch-geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m30.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch-geometric
Successfully installed torch-geometric-2.6.1
Collecting pyvista
  Downloading pyvista-0.46.3-py3-none-any.whl.metadata (15 kB)
Collecting vtk!=9.4.0 (from pyvista)
  Downloading vtk-9.5.1-cp312-cp312-manylinux2014_x86_

In [3]:
!git clone https://github.com/sam-spears/ML_NeckSpace.git

Cloning into 'ML_NeckSpace'...
remote: Enumerating objects: 61, done.[K
remote: Counting objects: 100% (61/61), done.[K
remote: Compressing objects: 100% (58/58), done.[K
remote: Total 61 (delta 3), reused 0 (delta 0), pack-reused 0 (from 0)[K
Receiving objects: 100% (61/61), 8.67 MiB | 12.89 MiB/s, done.
Resolving deltas: 100% (3/3), done.


In [2]:
import numpy as np
import trimesh
import os
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.data import Data, DataLoader
import pandas as pd
from tqdm import tqdm
import pickle

In [7]:
class MeshDataProcessor:
    def __init__(self, raw_dir, training_dir):
        self.raw_dir = raw_dir
        self.training_dir = training_dir
        self.scaler = StandardScaler()

    def load_stl(self, filepath):
        """Load STL file using trimesh"""
        mesh = trimesh.load(filepath)
        return mesh

    def extract_vertex_features(self, mesh):
        """Extract features for each vertex in the mesh"""
        features = []

        # 1. Vertex coordinates (normalized)
        vertices = mesh.vertices
        centroid = vertices.mean(axis=0)
        vertices_centered = vertices - centroid
        max_dist = np.max(np.linalg.norm(vertices_centered, axis=1))
        vertices_normalized = vertices_centered / max_dist

        # 2. Vertex normals
        vertex_normals = mesh.vertex_normals

        # 3. Curvature features (mean and Gaussian curvature)
        # Approximate curvature using neighborhood
        curvatures = self.compute_curvatures(mesh)

        # 4. Distance from centroid
        distances = np.linalg.norm(vertices_centered, axis=1).reshape(-1, 1)

        # 5. Local density (number of neighbors within radius)
        density = self.compute_local_density(mesh)

        # Combine all features
        features = np.hstack([
            vertices_normalized,  # 3 features
            vertex_normals,       # 3 features
            curvatures,          # 2 features
            distances,           # 1 feature
            density             # 1 feature
        ])  # Total: 10 features per vertex

        return features

    def compute_curvatures(self, mesh):
        """Compute approximate curvature for each vertex"""
        # Simplified curvature computation
        # In practice, you might want to use more sophisticated methods
        n_vertices = len(mesh.vertices)
        mean_curvature = np.zeros(n_vertices)
        gaussian_curvature = np.zeros(n_vertices)

        for i, vertex in enumerate(mesh.vertices):
            # Get neighboring vertices
            neighbors = mesh.vertex_neighbors[i]
            if len(neighbors) > 0:
                # Approximate mean curvature using normal variation
                neighbor_normals = mesh.vertex_normals[neighbors]
                normal_variation = np.std(neighbor_normals, axis=0).mean()
                mean_curvature[i] = normal_variation

                # Approximate Gaussian curvature using angle deficit
                angles = []
                for j in range(len(neighbors)):
                    v1 = mesh.vertices[neighbors[j]] - vertex
                    v2 = mesh.vertices[neighbors[(j+1)%len(neighbors)]] - vertex
                    angle = np.arccos(np.clip(np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2)), -1, 1))
                    angles.append(angle)
                gaussian_curvature[i] = 2 * np.pi - sum(angles) if angles else 0

        return np.column_stack([mean_curvature, gaussian_curvature])

    def compute_local_density(self, mesh, radius=0.1):
        """Compute local vertex density"""
        density = np.zeros(len(mesh.vertices))
        for i, vertex in enumerate(mesh.vertices):
            distances = np.linalg.norm(mesh.vertices - vertex, axis=1)
            density[i] = np.sum(distances < radius)
        return density.reshape(-1, 1) / len(mesh.vertices)

    def create_labels(self, raw_mesh, isolated_mesh, threshold=0.01):
        """Create binary labels for vertices (1 = in isolated region, 0 = not)"""
        labels = np.zeros(len(raw_mesh.vertices))

        # For each vertex in raw mesh, check if it's close to any vertex in isolated mesh
        isolated_vertices = isolated_mesh.vertices
        for i, vertex in enumerate(raw_mesh.vertices):
            distances = np.linalg.norm(isolated_vertices - vertex, axis=1)
            if np.min(distances) < threshold:
                labels[i] = 1

        return labels

    def prepare_dataset(self):
        """Prepare complete dataset from all available meshes"""
        X_all = []
        y_all = []
        mesh_ids = []

        # Get list of training samples (with isolated regions)
        training_files = [f for f in os.listdir(self.training_dir) if f.endswith('_isolated.stl')]

        print(f"Processing {len(training_files)} training samples...")

        for training_file in tqdm(training_files):
            # Extract base name (e.g., nhp_0001)
            base_name = training_file.replace('_isolated.stl', '')
            raw_file = f"{base_name}.stl"

            # Load meshes
            raw_path = os.path.join(self.raw_dir, raw_file)
            isolated_path = os.path.join(self.training_dir, training_file)

            if not os.path.exists(raw_path):
                print(f"Warning: {raw_path} not found, skipping...")
                continue

            raw_mesh = self.load_stl(raw_path)
            isolated_mesh = self.load_stl(isolated_path)

            # Extract features and labels
            features = self.extract_vertex_features(raw_mesh)
            labels = self.create_labels(raw_mesh, isolated_mesh)

            X_all.append(features)
            y_all.append(labels)
            mesh_ids.extend([base_name] * len(features))

        # Combine all data
        X = np.vstack(X_all)
        y = np.hstack(y_all)

        # Normalize features
        X = self.scaler.fit_transform(X)

        return X, y, mesh_ids

# ============================================
# STEP 3: Machine Learning Models
# ============================================

class RegionDetector:
    def __init__(self):
        self.models = {}
        self.results = {}

    def train_random_forest(self, X_train, y_train, X_test, y_test):
        """Train Random Forest classifier"""
        print("\nTraining Random Forest...")
        rf = RandomForestClassifier(
            n_estimators=100,
            max_depth=20,
            min_samples_split=5,
            min_samples_leaf=2,
            class_weight='balanced',  # Handle imbalanced classes
            n_jobs=-1,
            random_state=42
        )

        rf.fit(X_train, y_train)

        # Evaluate
        train_score = rf.score(X_train, y_train)
        test_score = rf.score(X_test, y_test)

        # Feature importance
        feature_importance = rf.feature_importances_

        self.models['random_forest'] = rf
        self.results['random_forest'] = {
            'train_accuracy': train_score,
            'test_accuracy': test_score,
            'feature_importance': feature_importance
        }

        print(f"Random Forest - Train: {train_score:.3f}, Test: {test_score:.3f}")
        return rf

    def train_gradient_boosting(self, X_train, y_train, X_test, y_test):
        """Train Gradient Boosting classifier"""
        from sklearn.ensemble import GradientBoostingClassifier

        print("\nTraining Gradient Boosting...")
        gb = GradientBoostingClassifier(
            n_estimators=100,
            learning_rate=0.1,
            max_depth=5,
            random_state=42
        )

        gb.fit(X_train, y_train)

        train_score = gb.score(X_train, y_train)
        test_score = gb.score(X_test, y_test)

        self.models['gradient_boosting'] = gb
        self.results['gradient_boosting'] = {
            'train_accuracy': train_score,
            'test_accuracy': test_score
        }

        print(f"Gradient Boosting - Train: {train_score:.3f}, Test: {test_score:.3f}")
        return gb

    def train_neural_network(self, X_train, y_train, X_test, y_test):
        """Train a simple neural network"""
        print("\nTraining Neural Network...")

        # Convert to PyTorch tensors
        X_train_torch = torch.FloatTensor(X_train)
        y_train_torch = torch.LongTensor(y_train)
        X_test_torch = torch.FloatTensor(X_test)
        y_test_torch = torch.LongTensor(y_test)

        # Define model
        class SimpleNN(nn.Module):
            def __init__(self, input_dim):
                super(SimpleNN, self).__init__()
                self.fc1 = nn.Linear(input_dim, 128)
                self.fc2 = nn.Linear(128, 64)
                self.fc3 = nn.Linear(64, 32)
                self.fc4 = nn.Linear(32, 2)
                self.dropout = nn.Dropout(0.2)

            def forward(self, x):
                x = F.relu(self.fc1(x))
                x = self.dropout(x)
                x = F.relu(self.fc2(x))
                x = self.dropout(x)
                x = F.relu(self.fc3(x))
                x = self.fc4(x)
                return x

        model = SimpleNN(X_train.shape[1])
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

        # Training loop
        epochs = 100
        batch_size = 256

        for epoch in range(epochs):
            model.train()
            for i in range(0, len(X_train_torch), batch_size):
                batch_X = X_train_torch[i:i+batch_size]
                batch_y = y_train_torch[i:i+batch_size]

                optimizer.zero_grad()
                outputs = model(batch_X)
                loss = criterion(outputs, batch_y)
                loss.backward()
                optimizer.step()

            if (epoch + 1) % 20 == 0:
                model.eval()
                with torch.no_grad():
                    train_outputs = model(X_train_torch)
                    train_pred = torch.argmax(train_outputs, dim=1)
                    train_acc = (train_pred == y_train_torch).float().mean()

                    test_outputs = model(X_test_torch)
                    test_pred = torch.argmax(test_outputs, dim=1)
                    test_acc = (test_pred == y_test_torch).float().mean()

                print(f"Epoch {epoch+1}/{epochs} - Train: {train_acc:.3f}, Test: {test_acc:.3f}")

        self.models['neural_network'] = model
        self.results['neural_network'] = {
            'train_accuracy': train_acc.item(),
            'test_accuracy': test_acc.item()
        }

        return model

# ============================================
# STEP 4: Prediction and Extraction
# ============================================

class MeshRegionExtractor:
    def __init__(self, model, scaler):
        self.model = model
        self.scaler = scaler
        self.processor = MeshDataProcessor(None, None)

    def predict_region(self, mesh_path, confidence_threshold=0.5):
        """Predict which vertices belong to the target region"""
        mesh = trimesh.load(mesh_path)

        # Extract features
        features = self.processor.extract_vertex_features(mesh)
        features = self.scaler.transform(features)

        # Predict
        if isinstance(self.model, torch.nn.Module):
            self.model.eval()
            with torch.no_grad():
                features_torch = torch.FloatTensor(features)
                outputs = self.model(features_torch)
                probabilities = F.softmax(outputs, dim=1)[:, 1].numpy()
        else:
            probabilities = self.model.predict_proba(features)[:, 1]

        # Apply threshold
        predicted_mask = probabilities > confidence_threshold

        return mesh, predicted_mask, probabilities

    def extract_region(self, mesh_path, output_path=None, confidence_threshold=0.5):
        """Extract and save the predicted region"""
        mesh, mask, probabilities = self.predict_region(mesh_path, confidence_threshold)

        # Get faces that have at least 2 vertices in the predicted region
        face_mask = np.zeros(len(mesh.faces), dtype=bool)
        for i, face in enumerate(mesh.faces):
            if np.sum(mask[face]) >= 2:
                face_mask[i] = True

        # Create new mesh with selected faces
        extracted_mesh = mesh.submesh([face_mask], append=True)

        if output_path:
            extracted_mesh.export(output_path)
            print(f"Extracted region saved to {output_path}")

        return extracted_mesh, probabilities

# ============================================
# STEP 5: Main Training Pipeline
# ============================================

def main_pipeline(raw_dir, training_dir):
    """Complete training pipeline"""

    # 1. Prepare data
    print("=" * 50)
    print("STEP 1: Loading and processing data...")
    print("=" * 50)

    processor = MeshDataProcessor(raw_dir, training_dir)
    X, y, mesh_ids = processor.prepare_dataset()

    print(f"\nDataset shape: {X.shape}")
    print(f"Positive samples: {np.sum(y)} ({np.sum(y)/len(y)*100:.1f}%)")
    print(f"Negative samples: {len(y) - np.sum(y)} ({(len(y)-np.sum(y))/len(y)*100:.1f}%)")

    # 2. Split data by mesh (not by vertex) to avoid data leakage
    unique_meshes = list(set(mesh_ids))
    train_meshes, test_meshes = train_test_split(unique_meshes, test_size=0.2, random_state=42)

    train_mask = np.isin(mesh_ids, train_meshes)
    test_mask = np.isin(mesh_ids, test_meshes)

    X_train, y_train = X[train_mask], y[train_mask]
    X_test, y_test = X[test_mask], y[test_mask]

    print(f"\nTrain set: {X_train.shape[0]} vertices from {len(train_meshes)} meshes")
    print(f"Test set: {X_test.shape[0]} vertices from {len(test_meshes)} meshes")

    # 3. Train models
    print("\n" + "=" * 50)
    print("STEP 2: Training models...")
    print("=" * 50)

    detector = RegionDetector()

    # Train different models
    rf_model = detector.train_random_forest(X_train, y_train, X_test, y_test)
    gb_model = detector.train_gradient_boosting(X_train, y_train, X_test, y_test)
    nn_model = detector.train_neural_network(X_train, y_train, X_test, y_test)

    # 4. Evaluate and compare
    print("\n" + "=" * 50)
    print("STEP 3: Model Comparison")
    print("=" * 50)

    results_df = pd.DataFrame(detector.results).T
    print("\n", results_df[['train_accuracy', 'test_accuracy']])

    # 5. Save best model
    best_model_name = results_df['test_accuracy'].idxmax()
    best_model = detector.models[best_model_name]

    print(f"\nBest model: {best_model_name}")

    # Save model and scaler
    with open('best_model.pkl', 'wb') as f:
        pickle.dump(best_model, f)
    with open('scaler.pkl', 'wb') as f:
        pickle.dump(processor.scaler, f)

    print("\nModel saved successfully!")

    return best_model, processor.scaler

# ============================================
# STEP 6: Usage Example
# ============================================

if __name__ == "__main__":
    # Set your GitHub repository paths
    RAW_DIR = "/content/ML_NeckSpace/dsa"  # Update with your path
    TRAINING_DIR = "/content/ML_NeckSpace/TRAINING"  # Update with your path

    # Train the model
    model, scaler = main_pipeline(RAW_DIR, TRAINING_DIR)

STEP 1: Loading and processing data...
Processing 20 training samples...


100%|██████████| 20/20 [02:21<00:00,  7.09s/it]



Dataset shape: (181295, 10)
Positive samples: 13327.0 (7.4%)
Negative samples: 167968.0 (92.6%)

Train set: 144558 vertices from 16 meshes
Test set: 36737 vertices from 4 meshes

STEP 2: Training models...

Training Random Forest...
Random Forest - Train: 0.993, Test: 0.978

Training Gradient Boosting...
Gradient Boosting - Train: 0.992, Test: 0.979

Training Neural Network...
Epoch 20/100 - Train: 0.985, Test: 0.967
Epoch 40/100 - Train: 0.989, Test: 0.974
Epoch 60/100 - Train: 0.991, Test: 0.977
Epoch 80/100 - Train: 0.992, Test: 0.977
Epoch 100/100 - Train: 0.993, Test: 0.976

STEP 3: Model Comparison

                   train_accuracy test_accuracy
random_forest           0.992965      0.978033
gradient_boosting       0.991733      0.978714
neural_network          0.993075      0.975991

Best model: gradient_boosting

Model saved successfully!


In [24]:
# Test on a new mesh
test_mesh_path = "/content/ML_NeckSpace/Test_Mesh/nhp_0021.stl"  # An unseen mesh (used nhp_0020 which is in dsa folder)
output_path = "/content/predicted_region.stl"

extractor = MeshRegionExtractor(model, scaler)
extracted_mesh, confidence_scores = extractor.extract_region(
    test_mesh_path,
    output_path,
    confidence_threshold=0.5
)

print(f"\nExtraction complete! Vertices in region: {len(extracted_mesh.vertices)}")

Extracted region saved to /content/predicted_region.stl

Extraction complete! Vertices in region: 809


In [25]:
import trimesh
import plotly.graph_objects as go
import numpy as np

# Load the extracted mesh
extracted_mesh_path = "/content/predicted_region.stl"
extracted_mesh = trimesh.load(extracted_mesh_path)

# Extract vertices and faces
vertices = extracted_mesh.vertices
faces = extracted_mesh.faces

# Create the Plotly figure
fig = go.Figure(data=[go.Mesh3d(
    x=vertices[:, 0],
    y=vertices[:, 1],
    z=vertices[:, 2],
    i=faces[:, 0],
    j=faces[:, 1],
    k=faces[:, 2],
    color='red',
    opacity=0.50)])

# Update layout for better visualization
fig.update_layout(
    scene = dict(
        xaxis = dict(visible=False),
        yaxis = dict(visible=False),
        zaxis = dict(visible=False),
        aspectmode='data' # important for keeping the aspect ratio
    ),
    margin=dict(l=0, r=0, t=0, b=0)
)

fig.show()