# Heightmap Subdivision

This notebook demonstrates how to generate a 3D mesh from a 2D heightmap image using recursive subdivision.
The mesh density adapts to the curvature of the surface: areas with high curvature are subdivided more than flat areas.

In [None]:
import os
from PIL import Image
import compas.datastructures as cd
import compas.geometry as cg
from compas_notebook.viewer import Viewer

# Create a viewer
viewer = Viewer()

## Configuration & Image Loading

Define the parameters for the mesh generation and the path to the heightmap image. The image itself is loaded when a `HeightmapMesher` is created.
The level of detail is controlled later by the `max_depth` argument passed to the mesher.

In [None]:
# Define the area to cover
start_x = -20
start_y = -20
total_size = 40
z_scale = 15  # Maximum height

# Image path
image_filename = "heigh_map.jpg"
image_path = os.path.abspath(image_filename)
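Each level of recursion halves the side length of a quad. With `total_size = 40`, the finest quad possible at depth `d` therefore has side `40 / 2**d`, and a uniformly refined grid would contain `4**d` quads. The hypothetical helper `finest_cell_stats` below only tabulates these numbers. It does not read the image, and the adaptive mesh can contain fewer quads than the uniform count because only quads that fail the curvature test are split.

```python
def finest_cell_stats(total_size, max_depth):
    """Return (depth, finest quad side, quad count of a uniform grid) for each depth.

    Hypothetical helper for illustration only: it mirrors the halving done in
    recursive_subdivide but never looks at the heightmap.
    """
    rows = []
    for d in range(max_depth + 1):
        side = total_size / 2 ** d   # each subdivision halves the side length
        uniform_quads = 4 ** d       # each subdivision splits one quad into four
        rows.append((d, side, uniform_quads))
    return rows


for d, side, n in finest_cell_stats(total_size=40, max_depth=6):
    print(f"depth {d}: finest quad side = {side:.4f}, uniform grid = {n} quads")
```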

## Heightmap Mesher Class

We define a `HeightmapMesher` class to encapsulate the mesh generation logic. 

**Why use a class?**
1.  **State Management**: The class holds the image data (`pixels`, `width`, `height`) and the configuration (`start_x`, `z_scale`, etc.) internally. This avoids passing these parameters to every recursive function call.
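2.  **Vertex Tracking**: It maintains a `vertex_map`, keyed by the rounded `(x, y)` position, so that quads created in different recursive calls share (weld) their common corner vertices instead of creating duplicates at the same location. Welding only merges coincident corners: where a finer quad borders a coarser one, the extra vertex on the shared edge is not part of the coarse face (a T-junction).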
3.  **Encapsulation**: It keeps the global namespace clean and allows us to easily create multiple meshers with different settings if needed.
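The welding idea from point 2 can be illustrated without compas. The sketch below is a simplified stand-in (the class name `VertexWelder` is hypothetical): it stores vertices in a plain list instead of a `Mesh` and uses the same rounded `(x, y)` key as `get_or_create_vertex`. Two quads that share an edge end up with 6 vertices instead of 8.

```python
class VertexWelder:
    """Minimal stand-in for the vertex_map logic; vertices live in a plain list."""

    def __init__(self, ndigits=4):
        self.ndigits = ndigits
        self.vertices = []     # list of (x, y, z); the list index is the vertex id
        self.vertex_map = {}   # rounded (x, y) -> vertex id

    def get_or_create_vertex(self, x, y, z):
        key = (round(x, self.ndigits), round(y, self.ndigits))
        if key not in self.vertex_map:
            self.vertex_map[key] = len(self.vertices)
            self.vertices.append((x, y, z))
        return self.vertex_map[key]


welder = VertexWelder()
faces = []
# Two unit quads side by side, sharing the edge x = 1
for x0 in (0.0, 1.0):
    corners = [(x0, 0.0), (x0 + 1.0, 0.0), (x0 + 1.0, 1.0), (x0, 1.0)]
    faces.append([welder.get_or_create_vertex(x, y, 0.0) for x, y in corners])

print(len(welder.vertices))  # 6
print(faces)                 # [[0, 1, 2, 3], [1, 4, 5, 2]]
```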

In [None]:
class HeightmapMesher:
    def __init__(self, img_path, start_x, start_y, total_size, z_scale):
        self.start_x = start_x
        self.start_y = start_y
        self.total_size = total_size
        self.z_scale = z_scale
        self.vertex_map = {}
        
        # Load image
        try:
            img = Image.open(img_path).convert('L')
            self.width, self.height = img.size
            self.pixels = img.load()
            print(f"Loaded heightmap: {self.width}x{self.height}")
        except Exception as e:
            # Fall back to a uniform mid-gray image, which produces a flat surface
            print(f"Error loading image: {e}. Using a flat 100x100 fallback image.")
            img = Image.new('L', (100, 100), color=128)
            self.width, self.height = img.size
            self.pixels = img.load()

    def get_height(self, x, y):
        """Returns the Z height for a given X, Y coordinate based on the image."""
        u = (x - self.start_x) / self.total_size
        v = 1.0 - ((y - self.start_y) / self.total_size)
        
        u = max(0.0, min(1.0, u))
        v = max(0.0, min(1.0, v))
        
        px = int(u * (self.width - 1))
        py = int(v * (self.height - 1))
        
        brightness = self.pixels[px, py]
        return (brightness / 255.0) * self.z_scale

    def get_or_create_vertex(self, mesh, x, y, z):
        """Returns an existing vertex key if (x,y) exists, otherwise creates a new one."""
        key = (round(x, 4), round(y, 4))
        if key not in self.vertex_map:
            v_id = mesh.add_vertex(x=x, y=y, z=z)
            self.vertex_map[key] = v_id
        return self.vertex_map[key]

    def recursive_subdivide(self, mesh, x, y, size, level, max_level):
        """Recursively subdivides a quad patch based on surface curvature."""
        half = size / 2
        
        # Corners
        p1 = (x, y)
        p2 = (x + size, y)
        p3 = (x + size, y + size)
        p4 = (x, y + size)
        pc = (x + half, y + half) # Center
        
        # Heights
        z1 = self.get_height(*p1)
        z2 = self.get_height(*p2)
        z3 = self.get_height(*p3)
        z4 = self.get_height(*p4)
        zc = self.get_height(*pc)
        
        # Curvature check
        z_avg = (z1 + z2 + z3 + z4) / 4
        curvature_metric = abs(zc - z_avg)
        threshold = 0.15 
        
        if level < max_level and curvature_metric > threshold:
            # Subdivide
            self.recursive_subdivide(mesh, x, y, half, level + 1, max_level)
            self.recursive_subdivide(mesh, x + half, y, half, level + 1, max_level)
            self.recursive_subdivide(mesh, x + half, y + half, half, level + 1, max_level)
            self.recursive_subdivide(mesh, x, y + half, half, level + 1, max_level)
        else:
            # Create Face
            a = self.get_or_create_vertex(mesh, p1[0], p1[1], z1)
            b = self.get_or_create_vertex(mesh, p2[0], p2[1], z2)
            c = self.get_or_create_vertex(mesh, p3[0], p3[1], z3)
            d = self.get_or_create_vertex(mesh, p4[0], p4[1], z4)
            mesh.add_face([a, b, c, d])

    def generate(self, max_depth):
        """Generates the mesh with the specified maximum recursion depth."""
        self.vertex_map = {} # Reset for new generation
        mesh = cd.Mesh()
        print(f"Generating recursive mesh with max_depth={max_depth}...")
        self.recursive_subdivide(mesh, self.start_x, self.start_y, self.total_size, 0, max_depth)
        return mesh
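`get_height` maps a world coordinate to a pixel in two steps. It first normalizes `(x, y)` to `(u, v)` in [0, 1] over the square starting at `(start_x, start_y)` with side `total_size`, flipping `v` because image rows are counted downward from the top edge, and then scales `(u, v)` to pixel indices. The sketch below reproduces only that mapping as a standalone function (the name `world_to_pixel` is hypothetical) so it can be checked without loading an image.

```python
def world_to_pixel(x, y, start_x, start_y, total_size, width, height):
    """Map a world (x, y) to (column, row) pixel indices the same way get_height does."""
    u = (x - start_x) / total_size
    v = 1.0 - (y - start_y) / total_size  # flip: image row 0 is the top edge
    u = max(0.0, min(1.0, u))              # clamp points outside the square
    v = max(0.0, min(1.0, v))
    return int(u * (width - 1)), int(v * (height - 1))


# A 100 x 100 image covering the square [-20, 20] x [-20, 20]
print(world_to_pixel(-20, -20, -20, -20, 40, 100, 100))  # (0, 99): bottom-left maps to the last row
print(world_to_pixel(20, 20, -20, -20, 40, 100, 100))    # (99, 0): top-right maps to the first row
print(world_to_pixel(0, 0, -20, -20, 40, 100, 100))      # (49, 49): centre
```

Because `int()` truncates and no interpolation is performed, each height comes from a single pixel. A bilinear lookup between neighbouring pixels would give smoother heights, but that would be a change to the method rather than what this notebook does.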

## Step-by-Step Visualization

We now demonstrate the effect of increasing the recursion depth. We start with the decision rule for a single face, then define a helper function that generates and visualizes the mesh for a given maximum depth.

### Concept: Curvature Metric Visualization

Before running the full recursion, let's visualize how the algorithm decides to subdivide a single face.

The algorithm checks a single quad patch defined by 4 corners.
1.  It calculates the **Average Height ($Z_{avg}$)** of the 4 corners. This is the height the center would have if the patch were flat, i.e. bilinearly interpolated from its corners.
2.  It samples the **Actual Height ($Z_{center}$)** of the surface at the center point from the heightmap.
3.  The **Curvature Metric** is the absolute difference: $|Z_{center} - Z_{avg}|$.
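Written out, with $z_1, \dots, z_4$ the corner heights, $\ell$ the current level, $\ell_{max}$ the maximum depth and $\tau$ the threshold, the test performed in `recursive_subdivide` is:

$$
Z_{avg} = \frac{z_1 + z_2 + z_3 + z_4}{4}, \qquad
\delta = \left| Z_{center} - Z_{avg} \right|, \qquad
\text{subdivide} \iff \ell < \ell_{max} \ \text{and} \ \delta > \tau .
$$

One consequence follows directly: for any bilinear surface $z = a + bx + cy + dxy$ (which includes every plane), the height at the center of a square equals the average of its four corner heights, so $\delta = 0$. Steep but planar slopes are therefore never subdivided; only departure from the bilinear patch triggers refinement.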

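If this deviation exceeds the `threshold` (hard-coded to `0.15` in `recursive_subdivide`, in the same height units as `z_scale`) and the maximum depth has not yet been reached, the face is considered "curved" and is split into four equal sub-quads.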
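The decision can be reproduced in a few lines of plain Python. The function name `should_subdivide` is hypothetical; its default threshold of `0.15` matches the notebook's value, and the corner order `z1..z4` follows `p1..p4` in `recursive_subdivide`.

```python
def should_subdivide(z1, z2, z3, z4, zc, level, max_level, threshold=0.15):
    """Return (decision, metric) for one quad, mirroring recursive_subdivide."""
    z_avg = (z1 + z2 + z3 + z4) / 4   # centre height of the flat (bilinear) patch
    metric = abs(zc - z_avg)          # deviation of the sampled surface from it
    return level < max_level and metric > threshold, metric


# Tilted plane z = x over corners at x = 0 and x = 2: the centre lies exactly on the patch
print(should_subdivide(0.0, 2.0, 2.0, 0.0, 1.0, level=0, max_level=6))  # (False, 0.0)
# A bump in the middle of a flat patch
print(should_subdivide(1.0, 1.0, 1.0, 1.0, 3.0, level=0, max_level=6))  # (True, 2.0)
# Same bump, but the depth limit has been reached
print(should_subdivide(1.0, 1.0, 1.0, 1.0, 3.0, level=6, max_level=6))  # (False, 2.0)
```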

In the visualization below:
*   **Gray Mesh**: The flat quad connecting the corners.
*   **Red Point**: The average height ($Z_{avg}$) on the flat approximation.
*   **Green Point**: The actual surface height ($Z_{center}$).
*   **Yellow Line**: The deviation (metric) between them.

In [None]:
from compas_viewer import Viewer as AppViewer
from compas.colors import Color

# Initialize a mesher for this demonstration
demo_mesher = HeightmapMesher(image_path, start_x, start_y, total_size, z_scale)

# Define a single large quad (the root quad) to visualize
x, y = start_x, start_y
size = total_size
half = size / 2

# 1. Get Geometry of the Quad
p1 = (x, y)
p2 = (x + size, y)
p3 = (x + size, y + size)
p4 = (x, y + size)
pc = (x + half, y + half) # Center (x, y)

# 2. Get Heights
z1 = demo_mesher.get_height(*p1)
z2 = demo_mesher.get_height(*p2)
z3 = demo_mesher.get_height(*p3)
z4 = demo_mesher.get_height(*p4)
zc = demo_mesher.get_height(*pc) # Actual surface height at center

# 3. Calculate Metric
z_avg = (z1 + z2 + z3 + z4) / 4
curvature_metric = abs(zc - z_avg)

print(f"Corner Heights: {z1:.2f}, {z2:.2f}, {z3:.2f}, {z4:.2f}")
print(f"Average Corner Height (Flat): {z_avg:.2f}")
print(f"Actual Center Height (Surface): {zc:.2f}")
print(f"Curvature Metric (Deviation): {curvature_metric:.2f}")

# 4. Visualization
# We use AppViewer (compas_viewer) here for better point color support in this demo
viewer = AppViewer()

# Draw the flat quad
mesh = cd.Mesh()
a = mesh.add_vertex(x=p1[0], y=p1[1], z=z1)
b = mesh.add_vertex(x=p2[0], y=p2[1], z=z2)
c = mesh.add_vertex(x=p3[0], y=p3[1], z=z3)
d = mesh.add_vertex(x=p4[0], y=p4[1], z=z4)
mesh.add_face([a, b, c, d])
viewer.scene.add(mesh, opacity=0.5, name="Flat Quad")

# Draw the points
pt_surface = cg.Point(pc[0], pc[1], zc)
pt_avg = cg.Point(pc[0], pc[1], z_avg)

viewer.scene.add(pt_surface, pointcolor=Color.green(), pointsize=20, name="Actual Surface Point (Green)")
viewer.scene.add(pt_avg, pointcolor=Color.red(), pointsize=20, name="Average Point (Red)")

# Draw the deviation line
line = cg.Line(pt_surface, pt_avg)
viewer.scene.add(line, linecolor=Color.yellow(), linewidth=5, name="Deviation")

viewer.show()

In [None]:
# Initialize the mesher once
mesher = HeightmapMesher(image_path, start_x, start_y, total_size, z_scale)

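def generate_and_visualize(max_depth):
    """Generate a mesh with the given maximum recursion depth and display it."""
    mesh = mesher.generate(max_depth)

    # Report the size of the result
    print(f"Vertices: {mesh.number_of_vertices()}, Faces: {mesh.number_of_faces()}")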

    # Visualize the mesh
    viewer = Viewer()
    viewer.scene.add(mesh)
    viewer.show()

### Level 0: No Subdivision
With `max_depth=0` the recursion stops immediately, so the mesh is always a single quad spanning the whole area, regardless of the curvature metric.

In [None]:
generate_and_visualize(0)

### Level 2: Low Detail
At most $4^2 = 16$ quads; the basic shape begins to emerge.

In [None]:
generate_and_visualize(2)

### Level 4: Medium Detail
At most $4^4 = 256$ quads; more features are captured, while regions with a small curvature metric stay coarse.

In [None]:
generate_and_visualize(4)

### Level 6: High Detail (Final)
At most $4^6 = 4096$ quads; the mesh closely follows the heightmap, with the smallest quads concentrated where the curvature metric is large.

In [None]:
generate_and_visualize(6)
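To see the adaptivity in numbers without an image, the same recursion can be run on an analytic height function. This sketch rests on stated assumptions: `bump` (a Gaussian hill) is a hypothetical stand-in for `get_height`, and `count_faces` only counts leaf quads instead of building a compas mesh. The subdivision rule is the one used in `recursive_subdivide`, so each count can be compared with the `4**d` quads of a uniform grid.

```python
import math


def bump(x, y, z_scale=15.0, radius=5.0):
    """Hypothetical stand-in for get_height: a Gaussian hill centred at the origin."""
    return z_scale * math.exp(-(x * x + y * y) / (2.0 * radius * radius))


def count_faces(height, x, y, size, level, max_level, threshold=0.15):
    """Count the quads produced by the same rule as recursive_subdivide."""
    half = size / 2
    z_avg = (height(x, y) + height(x + size, y)
             + height(x + size, y + size) + height(x, y + size)) / 4
    metric = abs(height(x + half, y + half) - z_avg)
    if level < max_level and metric > threshold:
        # Recurse into the four sub-quads, in the same order as the notebook
        children = ((x, y), (x + half, y), (x + half, y + half), (x, y + half))
        return sum(count_faces(height, cx, cy, half, level + 1, max_level, threshold)
                   for cx, cy in children)
    return 1  # leaf: one quad face


for depth in (0, 2, 4, 6):
    n = count_faces(bump, -20.0, -20.0, 40.0, 0, depth)
    print(f"max_depth={depth}: {n} faces (uniform grid: {4 ** depth})")
```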