In [2]:
# ForAlign: An optimization algorithm for forest point cloud registration (Notebook Version)
# Coarse Alignment algorithm

import open3d as o3d
import numpy as np
from scipy.interpolate import griddata
from scipy.ndimage import gaussian_filter

import pulp
from sklearn.neighbors import NearestNeighbors
from itertools import combinations
import rasterio
import hdbscan
from sklearn.decomposition import PCA

In [3]:
# ==========================
# I/O Handling Module
# ==========================
def load_pcd(file_path):
    """
    Read a point cloud from disk.

    Parameters:
    - file_path: path of the point-cloud file to read.

    Returns:
    - The Open3D PointCloud returned by ``o3d.io.read_point_cloud``.
    """
    pcd = o3d.io.read_point_cloud(file_path)
    return pcd

def save_pcd(file_path, pcd):
    """
    Write a point cloud to disk with ``o3d.io.write_point_cloud``.

    Parameters:
    - file_path: destination path.
    - pcd: the Open3D PointCloud to write.

    Returns:
    - None. The writer's return value is discarded.
    """
    writer = o3d.io.write_point_cloud
    writer(file_path, pcd)

In [4]:
# ==========================
# DTM Processing Module
# ==========================
def generate_dtm(ground_pcd, resolution=0.5, sigma=1):
    """
    Generate a smoothed Digital Terrain Model (DTM) from ground points.

    Ground elevations are linearly interpolated onto a regular XY grid that
    spans the bounding box of the ground points, with both edges included.
    Linear interpolation is undefined outside the convex hull of the ground
    points, where it yields NaN. Those cells are filled with the elevation of
    the nearest ground point before smoothing. Otherwise ``gaussian_filter``
    would spread the NaNs over neighbouring valid cells and corrupt the whole
    border region of the DTM.

    Parameters:
    - ground_pcd: object with a ``points`` attribute convertible to an (N, 3)
      array, e.g. an Open3D PointCloud of ground points.
    - resolution: grid spacing in the point-cloud units (default: 0.5).
    - sigma: Gaussian filter sigma, in grid cells, for smoothing (default: 1).

    Returns:
    - x_mesh, y_mesh: mesh grid (rows follow y, columns follow x).
    - z_values_smoothed: smoothed elevation per grid cell, same shape as the meshes.

    Raises:
    - ValueError: if the ground point cloud is empty.
    """
    ground_points = np.asarray(ground_pcd.points, dtype=np.float64).reshape(-1, 3)
    if len(ground_points) == 0:
        raise ValueError("generate_dtm: the ground point cloud is empty")

    xy = ground_points[:, :2]
    z = ground_points[:, 2]

    # Compute min/max bounds
    x_min, y_min = xy.min(axis=0)
    x_max, y_max = xy.max(axis=0)

    # Extend the stop value by one step so the grid also covers the max edge;
    # np.arange excludes its stop value.
    x_grid = np.arange(x_min, x_max + resolution, resolution, dtype=np.float64)
    y_grid = np.arange(y_min, y_max + resolution, resolution, dtype=np.float64)
    x_mesh, y_mesh = np.meshgrid(x_grid, y_grid)

    # Interpolate terrain data; cells outside the convex hull come back as NaN.
    z_values = griddata(xy, z, (x_mesh, y_mesh), method='linear')
    holes = np.isnan(z_values)
    if holes.any():
        z_values[holes] = griddata(xy, z, (x_mesh[holes], y_mesh[holes]), method='nearest')

    z_values_smoothed = gaussian_filter(z_values, sigma=sigma)

    return x_mesh, y_mesh, z_values_smoothed

# ==========================
# Normalized Height Calculation
# ==========================
def extract_normalized_height(offground_pcd, dtm_data):
    """
    Compute the normalized height of off-ground points by subtracting DTM elevation.

    Each point is assigned the DTM cell whose x and y grid coordinates are
    nearest to it. Points outside the grid snap to the border cell. When a
    point lies exactly halfway between two grid lines, the lower index wins,
    which matches ``np.argmin``. The lookup is vectorised with binary search
    rather than scanning the whole grid for every point.

    Parameters:
    - offground_pcd: object with a ``points`` attribute convertible to an
      (N, 3) array, e.g. an Open3D PointCloud of off-ground points.
    - dtm_data: dict with keys 'x_mesh', 'y_mesh', 'z_values_smoothed', laid
      out as a meshgrid. x varies along columns and y along rows, both ascending.

    Returns:
    - normalized_elevation: float array of shape (N,), point z minus DTM z.
    """
    def _nearest_index(grid, values):
        # Nearest-entry index into an ascending 1-D grid for each value.
        # Ties resolve to the lower index, the same as np.argmin.
        if len(grid) == 1:
            return np.zeros(len(values), dtype=np.intp)
        right = np.clip(np.searchsorted(grid, values), 1, len(grid) - 1)
        left = right - 1
        take_left = (values - grid[left]) <= (grid[right] - values)
        return np.where(take_left, left, right)

    offground_points = np.asarray(offground_pcd.points, dtype=np.float64).reshape(-1, 3)
    x_grid = np.asarray(dtm_data["x_mesh"])[0]
    y_grid = np.asarray(dtm_data["y_mesh"])[:, 0]
    z_values_smoothed = np.asarray(dtm_data["z_values_smoothed"])

    idx_x = _nearest_index(x_grid, offground_points[:, 0])
    idx_y = _nearest_index(y_grid, offground_points[:, 1])

    # Relative height above the ground surface
    return offground_points[:, 2] - z_values_smoothed[idx_y, idx_x]

# ==========================
# Tree Trunk Extraction Module
# ==========================
def filter_points_by_height(offground_pcd, normalized_elevation, elevation_min=4.5, elevation_max=5):
    """
    Keep only the off-ground points whose normalized height lies inside a band.

    The band is inclusive at both ends. Point i survives when
    ``elevation_min <= normalized_elevation[i] <= elevation_max``. Only point
    coordinates are copied into the result; no other attribute of the input
    cloud is carried over.

    Parameters:
    - offground_pcd: Open3D PointCloud of off-ground points.
    - normalized_elevation: array with one height value per point.
    - elevation_min: lower bound of the band.
    - elevation_max: upper bound of the band.

    Returns:
    - A new Open3D PointCloud with the in-band points. Downstream these are
      used as tree-trunk slices.
    """
    in_band = np.logical_and(normalized_elevation >= elevation_min,
                             normalized_elevation <= elevation_max)
    kept_points = np.asarray(offground_pcd.points)[in_band]

    trunk_slice = o3d.geometry.PointCloud()
    trunk_slice.points = o3d.utility.Vector3dVector(kept_points)
    return trunk_slice

In [5]:
# ==========================
# Tree Trunk Clustering with HDBSCAN
# ==========================
def apply_hdbscan(points, min_cluster_size=3, min_samples=100, cluster_selection_epsilon=0.5):
    """
    Cluster ``points`` with HDBSCAN and return one label per point.

    Parameters:
    - points: np.array, the samples to cluster (one row per point).
    - min_cluster_size: smallest group HDBSCAN reports as a cluster (default=3).
    - min_samples: neighbourhood size HDBSCAN uses for its density estimate
      (default=100). See the hdbscan documentation for exact semantics.
    - cluster_selection_epsilon: distance threshold passed to HDBSCAN
      (default=0.5). Per the hdbscan documentation, clusters below this
      distance are merged. It is not a minimum spacing between clusters.

    Returns:
    - labels: np.array, cluster label for each point. Downstream code
      (analyze_clusters) treats -1 as noise.
    """
    hdbscan_params = dict(
        min_cluster_size=min_cluster_size,
        min_samples=min_samples,
        cluster_selection_epsilon=cluster_selection_epsilon,
    )
    return hdbscan.HDBSCAN(**hdbscan_params).fit_predict(points)

# ==========================
# Cluster Analysis and Scoring
# ==========================
def analyze_clusters(points, labels):
    """
    Summarise every non-noise cluster with simple geometric descriptors.

    The following are recorded for each label other than -1 (noise):
    - 'count': number of points carrying that label;
    - 'normalized_count': count divided by the number of non-noise points;
    - 'linearity': explained-variance ratio of the first PCA component;
    - 'vertical_alignment': |dot| of the first PCA axis with +Z.

    NOTE(review): for a single-point cluster there is no principal axis.
    Linearity is then set to 1, and vertical_alignment is 1 if the point's own
    z-coordinate is positive, else 0. That is a positional test, not a
    direction.

    Parameters:
    - points: np.array, point coordinates (one row per point).
    - labels: np.array, cluster label per point (e.g. from apply_hdbscan).

    Returns:
    - dict mapping cluster label -> dict of the properties above.
    """
    up = np.array([0, 0, 1])
    n_clustered = len(points[labels != -1])
    summary = {}

    for cluster_id in np.unique(labels):
        if cluster_id == -1:
            continue
        members = points[labels == cluster_id]
        n_members = len(members)

        if n_members <= 1:
            lin = 1
            vert = 1 if np.dot(members[0], up) > 0 else 0
        else:
            pca = PCA(n_components=min(3, n_members))
            pca.fit(members)
            lin = pca.explained_variance_ratio_[0]
            vert = abs(np.dot(pca.components_[0], up))

        summary[cluster_id] = {
            'count': n_members,
            'normalized_count': n_members / n_clustered,
            'linearity': lin,
            'vertical_alignment': vert,
        }

    return summary

In [6]:
def calculate_centroid(cluster_points):
    """
    Return the mean point of a cluster.

    A 1-D input is treated as a single point and returned unchanged (the same
    object). Otherwise the column-wise mean is returned.
    """
    return cluster_points if cluster_points.ndim == 1 else np.mean(cluster_points, axis=0)

def project_to_xy(centroid):
    """Keep only the first two coordinates (x, y) of a point, dropping z and beyond."""
    x_and_y = centroid[0:2]
    return x_and_y

def calculate_angle(v1, v2):
    """
    Angle in radians, within [0, pi], between vectors ``v1`` and ``v2``.

    The cosine is clipped to [-1, 1] before ``arccos`` so that rounding error
    cannot produce NaN for (anti)parallel vectors. Zero-length inputs are not
    handled: the division yields NaN.
    """
    denominator = np.linalg.norm(v1) * np.linalg.norm(v2)
    cosine = np.clip(np.dot(v1, v2) / denominator, -1, 1)
    return np.arccos(cosine)

# ==========================
# Graph Construction for Tree Matching
# ==========================
def construct_graph_features(clusters, k=5):
    """
    Construct a kNN-based graph from tree trunk cluster centroids.

    Node i is the centroid of ``clusters[i]`` projected onto the XY plane.
    Each node is linked to its (up to) ``k`` nearest other nodes. For every
    unordered pair of those neighbours, the counter-clockwise angle from the
    first to the second is stored, measured at the node and wrapped to
    [0, 2*pi). With ``k + 1`` or fewer clusters, ``k`` is reduced to
    ``len(clusters) - 1``; the original code failed inside NearestNeighbors.

    Parameters:
    - clusters: sequence of point arrays, one per tree trunk cluster
      (a single point may be given as a 1-D array).
    - k: int, maximum number of nearest neighbours per node (default=5).

    Returns:
    - graph_features: dict, node index -> {'combinations': list of neighbour
      index pairs, 'angles': np.array of angles, one per pair}.
    - node_to_node_distances: dict, (node, neighbour) -> XY distance.
    """
    graph_features = {}
    node_to_node_distances = {}

    if len(clusters) == 0:
        return graph_features, node_to_node_distances

    # Compute cluster centroids and project to XY
    centroids = np.array([np.atleast_2d(cluster).mean(axis=0) for cluster in clusters])
    nodes = centroids[:, :2]

    n_neighbors = min(k, len(nodes) - 1)
    if n_neighbors < 1:
        # A lone node has no neighbours, hence no edges or angles.
        graph_features[0] = {'combinations': [], 'angles': np.mod([], 2 * np.pi)}
        return graph_features, node_to_node_distances

    # Query one extra neighbour because each node is returned as its own nearest neighbour.
    knn = NearestNeighbors(n_neighbors=n_neighbors + 1)
    knn.fit(nodes)
    distances, indices = knn.kneighbors(nodes)

    for i, node in enumerate(nodes):
        # Drop the query node explicitly rather than assuming it is column 0;
        # duplicate centroids can tie at distance 0 and come first.
        not_self = indices[i] != i
        neighbor_indices = indices[i][not_self][:n_neighbors]
        neighbor_distances = distances[i][not_self][:n_neighbors]

        # Store distances
        for neighbor_idx, dist in zip(neighbor_indices, neighbor_distances):
            node_to_node_distances[(i, neighbor_idx)] = dist

        # Bearing of every neighbour as seen from this node, then pairwise angles
        bearings = {j: np.arctan2(nodes[j][1] - node[1], nodes[j][0] - node[0]) for j in neighbor_indices}
        neighbor_combinations = list(combinations(neighbor_indices, 2))
        angles = np.mod([bearings[n2] - bearings[n1] for n1, n2 in neighbor_combinations], 2 * np.pi)

        graph_features[i] = {
            'combinations': neighbor_combinations,
            'angles': angles
        }

    return graph_features, node_to_node_distances

    



In [None]:
def is_within_threshold(value1, value2, threshold):
    """Return True when ``value1`` and ``value2`` differ by at most ``threshold`` (inclusive)."""
    gap = abs(value1 - value2)
    return gap <= threshold

def angle_difference(angle1, angle2):
    """
    Smallest separation between two angles in radians, treating them as points on a circle.

    The result is the shorter of the direct gap and the gap going the other
    way round (2*pi minus the direct gap).
    """
    direct = abs(angle1 - angle2)
    around = 2 * np.pi - direct
    return min(direct, around)

# ==========================
# Graph Matching
# ==========================
def match_sets(dls_features, tls_features, dls_node_to_node_distances, tls_node_to_node_distances, distance_threshold, angle_threshold_degrees):
    """
    Match local triangles (a node plus two of its neighbours) between the DLS and TLS graphs.

    Every (DLS node, TLS node) pair is examined. The i-th neighbour pair of
    the DLS node is compared with the i-th neighbour pair of the TLS node, in
    the order the graph features list them. A triangle pair matches when:
    - both node-to-neighbour edge lengths agree within ``distance_threshold``, and
    - the angle stored for *that* neighbour pair agrees within
      ``angle_threshold_degrees``, measured on the circle (359 deg and 1 deg
      differ by 2 deg).
    The original code instead required every angle of the node to agree,
    using a plain (non-circular) difference.

    Parameters:
    - dls_features, tls_features: dict, graph features for the DLS and TLS datasets
      (node -> {'combinations': [...], 'angles': [...]}).
    - dls_node_to_node_distances, tls_node_to_node_distances: dict,
      (node, neighbour) -> distance.
    - distance_threshold: float, distance matching threshold.
    - angle_threshold_degrees: float, angle matching threshold in degrees.

    Returns:
    - matches: list of (dls_axis, tls_axis, dls_combination, tls_combination) tuples.
    """
    angle_threshold = np.radians(angle_threshold_degrees)
    two_pi = 2 * np.pi
    matches = []

    for dls_axis, dls_data in dls_features.items():
        for tls_axis, tls_data in tls_features.items():
            paired = zip(dls_data['combinations'], dls_data['angles'],
                         tls_data['combinations'], tls_data['angles'])
            for dls_combination, dls_angle, tls_combination, tls_angle in paired:
                # Edge lengths from the axis node to each of the two neighbours
                distance_match = all(
                    abs(dls_node_to_node_distances[(dls_axis, d)]
                        - tls_node_to_node_distances[(tls_axis, t)]) <= distance_threshold
                    for d, t in zip(dls_combination, tls_combination)
                )
                if not distance_match:
                    continue

                # Circular difference of the angle subtended by this neighbour pair
                delta = abs(dls_angle - tls_angle) % two_pi
                if min(delta, two_pi - delta) <= angle_threshold:
                    matches.append((dls_axis, tls_axis, dls_combination, tls_combination))

    return matches
    

def standardize_match(match):
    """
    Convert one raw triangle match into a canonical, hashable form.

    Two input forms are accepted:
    - the 4-tuple produced by ``match_sets``:
      ``(dls_axis, tls_axis, dls_combination, tls_combination)``;
    - a dict with keys 'dls_axis', 'tls_axis', 'dls_combination', 'tls_combination'.
    The original only accepted the dict form, so feeding it the output of
    ``match_sets`` raised TypeError.

    Returns:
    - ``((dls_axis, tls_axis), pair_a, pair_b)``. The two branch pairs
      ``(dls_node, tls_node)`` are sorted, so the same correspondence gives the
      same tuple regardless of neighbour order.
    """
    if isinstance(match, dict):
        dls_axis, tls_axis = match['dls_axis'], match['tls_axis']
        dls_combination, tls_combination = match['dls_combination'], match['tls_combination']
    else:
        dls_axis, tls_axis, dls_combination, tls_combination = match

    branch_pairs = tuple(sorted((dls_combination[i], tls_combination[i]) for i in range(2)))
    return ((dls_axis, tls_axis),) + branch_pairs

def record_matches(matched_sets):
    """Canonicalise every raw match with ``standardize_match``, keeping the input order."""
    recorded = []
    for raw_match in matched_sets:
        recorded.append(standardize_match(raw_match))
    return recorded

# ==========================
# Pattern Recognition
# ==========================
def find_complete_patterns(recorded_matches):
    """
    Keep only triangle correspondences that were matched from all three vertices.

    Matches are grouped by the *set* of their three (dls, tls) node pairs.
    The same triangle recorded with a different vertex as the axis therefore
    lands in the same group. A group is complete when it holds exactly three
    matches.

    Parameters:
    - recorded_matches: list of canonical matches (tuples of node pairs).

    Returns:
    - complete_patterns: list with the first match of every complete group,
      in order of each group's first appearance.
    """
    groups = {}
    for match in recorded_matches:
        key = frozenset(match)
        if key not in groups:
            groups[key] = []
        groups[key].append(match)

    complete = []
    for members in groups.values():
        # Every member already shares the group's frozenset, so size is the only test.
        if len(members) == 3:
            complete.append(members[0])
    return complete

In [7]:
# ==========================
# Graph Matching Optimization via Linear Programming
# ==========================
def optimize_graph_matching(complete_patterns):
    """
    Perform graph matching optimization using binary linear programming.

    Goal: select as many complete triangle patterns as possible, subject to
    one-to-one correspondence between matched nodes.

    Variables:
    - x[m, n]: node pair (DLS m, TLS n) is matched.
    - y[m, n, c]: pattern c uses pair (m, n). All y of one pattern are forced
      equal, so a pattern is taken whole or not at all.

    Parameters:
    - complete_patterns: list of patterns, each an iterable of (m, n) node pairs.

    Returns:
    - matched_x_pairs: list of (m, n) node pairs selected by the solver.
    - matched_y_patterns: list of (m, n, c) tuples that are switched on.
    """
    if not complete_patterns:
        # Nothing to select; skip building and solving an empty LP.
        return [], []

    prob = pulp.LpProblem("GraphMatching", pulp.LpMaximize)

    # Binary variables for node matching
    unique_mn_pairs = set(pair for pattern in complete_patterns for pair in pattern)
    x_vars = {(m, n): pulp.LpVariable(f"x_{m}_{n}", cat='Binary') for m, n in unique_mn_pairs}

    # Binary variables for pattern selection
    y_vars = {}
    for c, pattern in enumerate(complete_patterns):
        for m, n in pattern:
            y_vars[(m, n, c)] = pulp.LpVariable(f"y_{m}_{n}_{c}", cat='Binary')

    # Objective: maximize the total number of selected pattern memberships
    prob += pulp.lpSum(y_vars.values())

    # y can only be 1 if its node pair is matched; group y by pair in the same pass.
    y_by_pair = {}
    for (m, n, c), y_var in y_vars.items():
        prob += y_var <= x_vars[(m, n)]
        y_by_pair.setdefault((m, n), []).append(y_var)

    # A pair may only be matched if at least one selected pattern uses it
    for pair, pair_ys in y_by_pair.items():
        prob += pulp.lpSum(pair_ys) >= x_vars[pair]

    # Each DLS node and each TLS node is matched at most once
    by_dls, by_tls = {}, {}
    for (m, n), x_var in x_vars.items():
        by_dls.setdefault(m, []).append(x_var)
        by_tls.setdefault(n, []).append(x_var)
    for group in list(by_dls.values()) + list(by_tls.values()):
        prob += pulp.lpSum(group) <= 1

    # All-or-nothing per pattern: chain the pattern's y variables together
    for c, pattern in enumerate(complete_patterns):
        pattern_ys = [y_vars[(m, n, c)] for m, n in pattern]
        for first, second in zip(pattern_ys, pattern_ys[1:]):
            prob += first == second

    prob.solve()

    def _selected(var):
        # Solvers report binaries as floats (e.g. 0.9999999), so round instead of testing == 1.
        return var.varValue is not None and round(var.varValue) == 1

    matched_x_pairs = [k for k, v in x_vars.items() if _selected(v)]
    matched_y_patterns = [k for k, v in y_vars.items() if _selected(v)]

    return matched_x_pairs, matched_y_patterns




# ==========================
# Elevation Retrieval from DTM
# ==========================
def get_elevation_at_points(dtm_file, points_2d):
    """
    Sample a DTM raster at given XY locations (nearest pixel, no interpolation).

    Parameters:
    - dtm_file: str, path of the DTM raster; band 1 is read.
    - points_2d: np.array (Nx2), x in column 0 and y in column 1, in the raster's CRS.

    Returns:
    - elevations: np.array (N,), band-1 value of the pixel containing each point.

    NOTE(review): points outside the raster are not checked. The row/column
    indices may then be out of range or negative, and negative indices wrap.
    Confirm the inputs lie inside the raster extent.
    """
    xs = points_2d[:, 0]
    ys = points_2d[:, 1]
    with rasterio.open(dtm_file) as raster:
        rows, cols = raster.index(xs, ys)
        band = raster.read(1)
        elevations = band[rows, cols]
    return elevations


    
def get_z_elevation_from_dtm(nodes, dtm_data, pattern_indices):
    """
    Lift selected XY nodes onto the DTM surface using the nearest grid cell.

    Parameters:
    - nodes: array-like of XY positions; only columns 0 and 1 are read.
    - dtm_data: dict with 'x_mesh', 'y_mesh', 'z_values_smoothed' in meshgrid
      layout (x along columns, y along rows).
    - pattern_indices: iterable of int, indices into ``nodes`` to lift.

    Returns:
    - points_3d: np.array (Nx3) of [x, y, dtm_z], in the order of ``pattern_indices``.
      An empty index list yields an empty 1-D array.
    """
    selected = np.array([nodes[i] for i in pattern_indices])
    x_axis = dtm_data['x_mesh'][0]
    y_axis = dtm_data['y_mesh'][:, 0]
    surface = dtm_data['z_values_smoothed']

    def lift(xy):
        col = np.abs(x_axis - xy[0]).argmin()
        row = np.abs(y_axis - xy[1]).argmin()
        return [xy[0], xy[1], surface[row, col]]

    return np.array([lift(xy) for xy in selected])

    

# ==========================
# Rigid Body Transformation Estimation
# ==========================
def find_rigid_body_transformation(src_points, dst_points):
    """
    Least-squares rigid transform (rotation + translation) mapping ``src_points`` onto ``dst_points``.

    Both sets are centred on their means. The cross-covariance of the
    centred sets is decomposed by SVD and the rotation is built from the
    singular vectors. If that rotation has a negative determinant (a
    reflection), the last right-singular vector is negated so that a proper
    rotation is returned.

    Parameters:
    - src_points: np.array (Nx3), source points.
    - dst_points: np.array (Nx3), corresponding target points (same order).

    Returns:
    - R: np.array (3x3), rotation matrix.
    - t: np.array (3,), translation vector, such that dst ~= R @ src + t.
    """
    mu_src = np.mean(src_points, axis=0)
    mu_dst = np.mean(dst_points, axis=0)

    cross_cov = (src_points - mu_src).T @ (dst_points - mu_dst)
    U, _, Vt = np.linalg.svd(cross_cov)
    V = Vt.T

    if np.linalg.det(V @ U.T) < 0:
        # Reflection detected: flip the axis associated with the smallest singular value.
        V[:, -1] *= -1
    rotation = V @ U.T

    translation = mu_dst - rotation @ mu_src
    return rotation, translation

    

def compute_transformation_matrix(dls_3d_points, tls_3d_points):
    """
    Build the 4x4 homogeneous matrix that maps DLS points onto TLS points.

    Parameters:
    - dls_3d_points: np.array (Nx3), source (DLS) points.
    - tls_3d_points: np.array (Nx3), corresponding target (TLS) points.

    Returns:
    - np.array (4x4): rotation in the upper-left 3x3 block, translation in the
      last column, bottom row [0, 0, 0, 1].
    """
    rotation, translation = find_rigid_body_transformation(dls_3d_points, tls_3d_points)

    homogeneous = np.identity(4)
    homogeneous[:3, :3] = rotation
    homogeneous[:3, 3] = translation
    return homogeneous

# End of the coarse-alignment building blocks; main() below wires them into the full pipeline.

In [23]:
def calculate_weighted_score(cluster_data, weights=None):
    """
    Attach a reliability 'score' to every cluster summary produced by analyze_clusters.

    NOTE(review): main() calls this helper, but the notebook never defined it
    (probably lost with a deleted cell), so Restart & Run All failed with
    NameError. The weighting here is an assumption: an equally weighted sum of
    'normalized_count', 'linearity' and 'vertical_alignment'. Confirm it
    against the ForAlign method description.

    Parameters:
    - cluster_data: dict, cluster label -> properties dict from analyze_clusters.
    - weights: optional dict, property name -> weight (defaults to 1/3 each).

    Returns:
    - dict with the same keys. Each properties dict is copied and gains a 'score' entry.
    """
    if weights is None:
        weights = {'normalized_count': 1 / 3, 'linearity': 1 / 3, 'vertical_alignment': 1 / 3}
    scored = {}
    for label, props in cluster_data.items():
        score = sum(weight * props[name] for name, weight in weights.items())
        scored[label] = dict(props, score=score)
    return scored


def _dtm_to_dict(dtm_tuple):
    """Package generate_dtm's (x_mesh, y_mesh, z_values_smoothed) tuple as the dict the DTM lookup helpers index."""
    x_mesh, y_mesh, z_values_smoothed = dtm_tuple
    return {"x_mesh": x_mesh, "y_mesh": y_mesh, "z_values_smoothed": z_values_smoothed}


def _top_cluster_point_sets(points, labels, cluster_data, fraction=0.2, min_keep=3):
    """Return the point arrays of the best-scoring clusters, best first.

    Keeps ``fraction`` of the clusters, but never fewer than ``min_keep``
    (3 is the minimum needed to form a triangle), capped at the number of
    clusters available.
    """
    ranked = sorted(cluster_data.items(), key=lambda item: item[1]['score'], reverse=True)
    n_keep = max(min_keep, int(len(ranked) * fraction))
    return [points[labels == label] for label, _ in ranked[:n_keep]]


def main():
    """
    Main function to execute the graph matching and transformation process.
    The workflow consists of:

    1. Loading TLS and DLS point clouds (ground and off-ground).
    2. Generating Digital Terrain Models (DTM) from ground points.
    3. Computing normalized heights for off-ground points.
    4. Extracting tree trunks by filtering points within a height range.
    5. Performing clustering on tree trunks using HDBSCAN.
    6. Analyzing and scoring clusters based on geometric properties.
    7. Selecting top clusters and constructing a kNN-based graph.
    8. Matching trees using structural graph-based constraints.
    9. Optimizing matching results with linear programming.
    10. Computing the transformation matrix from matched trees.

    Returns:
    - transformation_matrix: np.array (4x4) mapping DLS coordinates onto TLS
      coordinates. It is also printed.

    Raises:
    - ValueError: if no clusters are found, or fewer than three trees are matched.
    """
    # Define paths for ground and off-ground point clouds
    tls_g = "your_path_to_TLS_ground_point_cloud.pcd"  # TLS ground points
    tls_og = "your_path_to_TLS_offground_point_cloud.pcd"  # TLS off-ground points
    dls_g = "your_path_to_DLS_ground_point_cloud.pcd"  # DLS ground points
    dls_og = "your_path_to_DLS_offground_point_cloud.pcd"  # DLS off-ground points

    # Load ground and off-ground point clouds for TLS and DLS
    tls_ground_pcd = load_pcd(tls_g)
    tls_offground_pcd = load_pcd(tls_og)
    dls_ground_pcd = load_pcd(dls_g)
    dls_offground_pcd = load_pcd(dls_og)

    # Generate DTMs. generate_dtm returns a tuple, but the lookup helpers index a dict.
    tls_dtm_data = _dtm_to_dict(generate_dtm(tls_ground_pcd))
    dls_dtm_data = _dtm_to_dict(generate_dtm(dls_ground_pcd))

    # Compute normalized heights by subtracting DTM elevation from off-ground points
    tls_normalized_height = extract_normalized_height(tls_offground_pcd, tls_dtm_data)
    dls_normalized_height = extract_normalized_height(dls_offground_pcd, dls_dtm_data)

    # Ensure the extracted heights match the input point count
    assert len(tls_normalized_height) == len(np.asarray(tls_offground_pcd.points)), "TLS: Mismatch in points and normalized elevation"
    assert len(dls_normalized_height) == len(np.asarray(dls_offground_pcd.points)), "DLS: Mismatch in points and normalized elevation"

    # Slice tree trunks at a fixed height band (note: the TLS and DLS bands differ)
    tls_stems_pcd = filter_points_by_height(tls_offground_pcd, tls_normalized_height, elevation_min=4.5, elevation_max=5)
    dls_stems_pcd = filter_points_by_height(dls_offground_pcd, dls_normalized_height, elevation_min=7.5, elevation_max=8.5)

    tls_stems_points = np.asarray(tls_stems_pcd.points)
    dls_stems_points = np.asarray(dls_stems_pcd.points)

    # Perform tree trunk clustering using HDBSCAN
    tls_clusters = apply_hdbscan(tls_stems_points)
    dls_clusters = apply_hdbscan(dls_stems_points)

    # Geometric properties plus a reliability score per cluster
    tls_cluster_data = calculate_weighted_score(analyze_clusters(tls_stems_points, tls_clusters))
    dls_cluster_data = calculate_weighted_score(analyze_clusters(dls_stems_points, dls_clusters))

    # Ensure valid clusters exist before proceeding
    if not tls_cluster_data:
        raise ValueError("No valid TLS clusters detected!")
    if not dls_cluster_data:
        raise ValueError("No valid DLS clusters detected!")

    # Top 20% of clusters by score (at least 3). Graph node i is the i-th selected cluster.
    tls_top_clusters = _top_cluster_point_sets(tls_stems_points, tls_clusters, tls_cluster_data)
    dls_top_clusters = _top_cluster_point_sets(dls_stems_points, dls_clusters, dls_cluster_data)

    # Construct tree trunk graphs using kNN (k=3) over per-cluster point sets
    tls_graph_features, tls_node_to_node_distances = construct_graph_features(tls_top_clusters, k=3)
    dls_graph_features, dls_node_to_node_distances = construct_graph_features(dls_top_clusters, k=3)

    # Define matching thresholds for distance and angle constraints
    distance_threshold = 0.5  # Euclidean distance threshold (meters)
    angle_threshold = 10      # Angular difference threshold (degrees)

    # Graph matching -> canonical matches -> triangles matched from all three vertices
    matched_sets = match_sets(dls_graph_features, tls_graph_features, dls_node_to_node_distances, tls_node_to_node_distances, distance_threshold, angle_threshold)
    recorded_matches = record_matches(matched_sets)
    complete_patterns = find_complete_patterns(recorded_matches)

    # Solve optimization problem to refine matching results
    matched_x_pairs, _ = optimize_graph_matching(complete_patterns)

    # Each selected pair is (DLS node index, TLS node index); sort for a deterministic order.
    matched_x_pairs = sorted(matched_x_pairs)
    if len(matched_x_pairs) < 3:
        raise ValueError(f"At least 3 matched trees are required to estimate the transformation, got {len(matched_x_pairs)}")
    dls_indices = [m for m, _ in matched_x_pairs]
    tls_indices = [n for _, n in matched_x_pairs]

    # Tree positions: cluster-centroid XY (same nodes as the graphs), lifted to DTM ground level
    dls_nodes = np.array([np.mean(cluster, axis=0)[:2] for cluster in dls_top_clusters])
    tls_nodes = np.array([np.mean(cluster, axis=0)[:2] for cluster in tls_top_clusters])
    dls_3d_points = get_z_elevation_from_dtm(dls_nodes, dls_dtm_data, dls_indices)
    tls_3d_points = get_z_elevation_from_dtm(tls_nodes, tls_dtm_data, tls_indices)

    # Compute the rigid body transformation matrix
    transformation_matrix = compute_transformation_matrix(dls_3d_points, tls_3d_points)

    # Output the final transformation matrix
    print("Final Transformation Matrix:")
    print(transformation_matrix)
    return transformation_matrix


In [None]:
if __name__ == "__main__":
    # Run the full coarse-alignment pipeline. The input paths in main() are
    # placeholders and must point to real point-cloud files first.
    main()