# Procedural Depth Map Generation

This notebook demonstrates the procedural generation of depth maps for synthetic wave data using physics-based wave parameters.

## Purpose
- Generate realistic depth maps from wave parameters
- Visualize the effect of different wave characteristics
- Understand the depth generation pipeline
- Test parameter variations interactively

In [None]:
import numpy as np
import cv2
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import ipywidgets as widgets
from IPython.display import display
from typing import Dict, Any, Tuple
import seaborn as sns

## Wave Parameter Sampling

In [None]:
def sample_wave_params(rng: np.random.Generator) -> Dict[str, Any]:
    """Draw one random set of wave parameters for synthetic depth generation.

    Values are drawn from ``rng`` in a fixed order (height, wave type,
    direction, wavelength, angle, phase), so the same generator state always
    produces the same parameter set.

    Args:
        rng: NumPy random generator that supplies every draw.

    Returns:
        Dict with the keys ``height_meters`` (uniform over 0.2..2.0),
        ``wave_type``, ``direction``, ``wavelength`` (uniform over 10..28),
        ``angle_deg`` (uniform over -20..20), ``phase`` (uniform over
        0..2*pi) and ``occlusion_mode`` (always ``"none"``).
    """
    wave_type_options = ["beach_break", "reef_break", "point_break", "closeout", "a_frame"]
    direction_options = ["left", "right", "both"]

    sampled: Dict[str, Any] = {}
    sampled["height_meters"] = float(rng.uniform(0.2, 2.0))
    # .item() converts the NumPy string scalar into a plain Python str.
    sampled["wave_type"] = rng.choice(wave_type_options).item()
    sampled["direction"] = rng.choice(direction_options).item()
    sampled["wavelength"] = float(rng.uniform(10.0, 28.0))
    sampled["angle_deg"] = float(rng.uniform(-20.0, 20.0))
    sampled["phase"] = float(rng.uniform(0.0, 2.0 * np.pi))
    # Occlusions stay off: this prevents circles in ControlNet outputs.
    sampled["occlusion_mode"] = "none"
    return sampled

## Depth Map Generation Functions

In [None]:
def _add_soft_irregular_occlusions(depth: np.ndarray, rng: np.random.Generator, max_count: int = 2) -> np.ndarray:
    """Add very soft, irregular occlusions to avoid obvious circles."""
    out = depth.copy()
    H, W = out.shape

    count = int(rng.integers(0, max_count + 1))
    if count == 0:
        return out

    # Set occlusions to slightly farther depth, but not extreme
    far_val = float(out.max()) + 0.12

    for _ in range(count):
        cx = int(rng.integers(int(W * 0.20), int(W * 0.80)))
        cy = int(rng.integers(int(H * 0.20), int(H * 0.80)))

        ax = int(rng.integers(int(W * 0.03), int(W * 0.07)))
        ay = int(rng.integers(int(H * 0.03), int(H * 0.07)))
        angle = float(rng.uniform(0.0, 180.0))

        mask = np.zeros((H, W), dtype=np.float32)
        cv2.ellipse(mask, (cx, cy), (ax, ay), angle, 0, 360, 1.0, thickness=-1)

        # Heavy blur to remove hard edges
        mask = cv2.GaussianBlur(mask, (0, 0), sigmaX=9.0, sigmaY=9.0)

        # Make the effect weak so it does not "stamp" artifacts into RGB
        strength = float(rng.uniform(0.08, 0.18))
        mask = mask * strength

        out = out * (1.0 - mask) + far_val * mask

    return out

In [None]:
def generate_param_depth_map_v2(
    params: Dict[str, Any],
    size: Tuple[int, int] = (768, 768),
    seed: int = 0,
) -> np.ndarray:
    """
    Synthesize a float32 depth map meant as ControlNet Depth conditioning.

    The result combines:
    - a base depth that grows with the normalized row coordinate y (as y ** 1.7)
    - sinusoidal wave relief whose spatial frequency rises with y
    - a narrow breaker band whose line depends on wave_type and direction
    - a run-up band centred at y = 0.10 plus a quadratic shore slope
    - blurred random noise so the stripes are not perfectly regular
    - soft occlusions, only when params["occlusion_mode"] == "soft_irregular"

    Args:
        params: Reads "height_meters", "wave_type", "direction", "wavelength",
            "phase" and "occlusion_mode"; missing keys fall back to defaults.
            No other key is read here.
        size: Output shape as (H, W).
        seed: Seed for the local random generator (noise, breaker position,
            irregularity, occlusions).

    Returns:
        Array of shape (H, W), dtype float32.

    NOTE(review): row 0 corresponds to y = 0 (smallest base depth) and the
    last row to y = 1, so with the usual row-0-at-top image convention the
    near end is at the top of the array. Confirm whether consumers expect the
    horizon at the top and flip if so.
    NOTE(review): "angle_deg" is present in the sampled params but is never
    read here; the approach angle comes from "direction" only.
    """
    rng = np.random.default_rng(seed)
    H, W = size

    # Coordinate grids: x spans [-1, 1] across columns, y spans [0, 1] down
    # the rows (y=0 is near camera, y=1 is far/horizon).
    col_axis = np.linspace(-1.0, 1.0, W, dtype=np.float32)
    row_axis = np.linspace(0.0, 1.0, H, dtype=np.float32)
    x, y = np.meshgrid(col_axis, row_axis)

    wave_height = float(params.get("height_meters", 1.0))
    wave_type = str(params.get("wave_type", "beach_break"))
    direction = str(params.get("direction", "both"))
    wavelength = float(params.get("wavelength", 18.0))
    phase0 = float(params.get("phase", 0.0))

    # Base depth: larger y means larger depth, with a power-law profile.
    gamma = 1.7
    y_p = y ** gamma
    base_depth = 0.6 + 3.5 * y_p

    # Direction selects the approach angle (degrees) and the breaker slant;
    # anything other than "left"/"right" is treated as head-on.
    direction_table = {
        "left": (18.0, 0.14),
        "right": (-18.0, -0.14),
    }
    approach_deg, slant = direction_table.get(direction, (0.0, 0.0))
    theta = np.deg2rad(approach_deg)

    # Wavenumber grows with y to mimic foreshortening toward the horizon.
    k0 = (2.0 * np.pi) / max(wavelength, 1e-3)
    k = k0 * (1.0 + 2.8 * y_p)

    u = np.cos(theta) * x + np.sin(theta) * (y_p - 0.55)
    phase = k * (u * 6.0) + phase0

    # Relief amplitude is bounded so it stays small next to the base depth.
    amp = 0.08 + 0.18 * np.clip(wave_height / 2.5, 0.0, 1.0)
    wave_relief = amp * np.sin(phase)

    # Blurred Gaussian noise, rescaled to [0, 1], breaks up perfect stripes.
    n = rng.normal(0.0, 1.0, size=(H, W)).astype(np.float32)
    n = cv2.GaussianBlur(n, (0, 0), sigmaX=3.0, sigmaY=3.0)
    n_lo, n_hi = n.min(), n.max()
    n = (n - n_lo) / (n_hi - n_lo + 1e-8)
    noise_gain = 0.04 + 0.03 * float(rng.random())
    noise_relief = (n - 0.5) * noise_gain

    # Breaker line: jittered vertical position, shape picked by wave type.
    break_y = 0.22 + 0.04 * float(rng.uniform(-1.0, 1.0))

    # wave_type -> (curvature, irregularity, slant multiplier); unknown types
    # use the beach-break values.
    shape_table = {
        "closeout": (0.0, 0.0, 0.2),
        "a_frame": (0.10, 0.01, 0.3),
        "point_break": (0.03, 0.01, 1.2),
        "reef_break": (0.02, 0.03, 0.9),
    }
    curvature, irregular, slant_gain = shape_table.get(wave_type, (0.02, 0.015, 0.7))
    slant *= slant_gain

    # An A-frame bends as |x| (a peak); every other type bends as x^2.
    bend = np.abs(x) if wave_type == "a_frame" else (x ** 2)
    line = break_y + slant * x + curvature * bend

    # Smooth per-column jitter of the breaker line.
    if irregular > 0:
        jitter = rng.normal(0.0, 1.0, size=(1, W)).astype(np.float32)
        jitter = cv2.GaussianBlur(jitter, (0, 0), sigmaX=10.0)
        j_lo, j_hi = jitter.min(), jitter.max()
        jitter = (jitter - j_lo) / (j_hi - j_lo + 1e-8)
        jitter = (jitter - 0.5) * irregular
        # The (1, W) row broadcasts over all H rows.
        line = line + jitter

    def _gaussian_band(center, sigma):
        # Gaussian ridge in y around `center` with standard deviation `sigma`.
        return np.exp(-((y - center) ** 2) / (2.0 * (sigma ** 2))).astype(np.float32)

    breaker_relief = 0.35 * _gaussian_band(line, 0.012)
    runup_relief = 0.22 * _gaussian_band(0.10, 0.020)

    depth = base_depth - wave_relief - noise_relief - breaker_relief - runup_relief

    # Gentle shoreline slope, strongest at y = 0.
    shore_slope = 0.25 * (1.0 - y) ** 2
    depth = depth - shore_slope

    # Occlusions are opt-in; the default mode leaves the map untouched.
    if str(params.get("occlusion_mode", "none")) == "soft_irregular":
        depth = _add_soft_irregular_occlusions(depth, rng=rng, max_count=2)

    return depth.astype(np.float32)

## Visualization Functions

In [None]:
def plot_depth_map(depth_map, title="Depth Map", figsize=(10, 8)):
    """Plot a depth map with an ocean-like colormap and near/far annotations.

    The image is drawn with ``origin='lower'`` so that row 0 appears at the
    bottom. This matches the convention of ``generate_param_depth_map_v2`` in
    this file, where row 0 is y = 0 (near camera) and the last row is y = 1
    (horizon). With the default ``origin='upper'`` the "Near" label landed on
    the far rows and the "Far" label on the near rows.

    Args:
        depth_map: 2-D array of depth values.
        title: Figure title.
        figsize: Figure size passed to ``plt.figure``.
    """
    # Create ocean-like colormap (shallow=light, deep=dark)
    colors = ['#87CEEB', '#4682B4', '#191970', '#000080']
    ocean_cmap = LinearSegmentedColormap.from_list('ocean', colors)

    plt.figure(figsize=figsize)
    im = plt.imshow(depth_map, cmap=ocean_cmap, aspect='equal', origin='lower')
    plt.colorbar(im, label='Depth (relative units)', shrink=0.8)
    plt.title(title)
    plt.xlabel('X (shore-parallel)')
    plt.ylabel('Y (shore-perpendicular, 0=near camera, 1=horizon)')

    # Annotate the near and far ends; coordinates are in data (pixel) units,
    # so low row indices are the near side and high row indices the far side.
    H, W = depth_map.shape
    label_box = dict(boxstyle='round', facecolor='white', alpha=0.8)
    plt.text(W*0.05, H*0.05, 'Near\n(Camera)', fontsize=10, va='bottom',
             bbox=label_box)
    plt.text(W*0.05, H*0.95, 'Far\n(Horizon)', fontsize=10, va='top',
             bbox=label_box)

    plt.tight_layout()
    plt.show()


def plot_depth_analysis(depth_map, params):
    """Show a 2x2 diagnostic figure for one depth map.

    Panels: the depth image, the depth profile down the centre column, the
    depth profile along the row at 25% of the height (approximate breaking
    line), and a histogram of all depth values with the mean marked.

    Args:
        depth_map: 2-D array of depth values.
        params: Parameter dict; "wave_type", "direction" and "height_meters"
            are read for the image title and must be present.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    (ax_map, ax_center), (ax_break, ax_hist) = axes

    n_rows, n_cols = depth_map.shape

    def _draw_profile(ax, coords, values, fmt, xlabel, panel_title):
        # Shared styling for both line profiles; the y axis is inverted so
        # that shallow values sit at the top.
        ax.plot(coords, values, fmt, linewidth=2)
        ax.set_xlabel(xlabel)
        ax.set_ylabel('Depth')
        ax.set_title(panel_title)
        ax.grid(True, alpha=0.3)
        ax.invert_yaxis()

    # Panel 1: the depth image itself
    ocean_cmap = LinearSegmentedColormap.from_list(
        'ocean', ['#87CEEB', '#4682B4', '#191970', '#000080']
    )
    image = ax_map.imshow(depth_map, cmap=ocean_cmap, aspect='equal')
    ax_map.set_title(f'Depth Map\n{params["wave_type"]} - {params["direction"]} - {params["height_meters"]:.1f}m')
    plt.colorbar(image, ax=ax_map, shrink=0.8)

    # Panel 2: shore-perpendicular profile through the centre column
    _draw_profile(
        ax_center,
        np.linspace(0, 1, n_rows),
        depth_map[:, n_cols//2],
        'b-',
        'Y (0=near, 1=far)',
        'Depth Profile (Center Line)',
    )

    # Panel 3: shore-parallel profile along the approximate breaking line
    breaker_row = int(n_rows * 0.25)
    _draw_profile(
        ax_break,
        np.linspace(-1, 1, n_cols),
        depth_map[breaker_row, :],
        'r-',
        'X (shore-parallel)',
        'Depth Profile (Breaking Line)',
    )

    # Panel 4: distribution of depth values with the mean marked
    mean_depth = depth_map.mean()
    ax_hist.hist(depth_map.flatten(), bins=50, alpha=0.7, edgecolor='black')
    ax_hist.set_xlabel('Depth Value')
    ax_hist.set_ylabel('Frequency')
    ax_hist.set_title('Depth Distribution')
    ax_hist.axvline(mean_depth, color='red', linestyle='--',
                    label=f'Mean: {mean_depth:.3f}')
    ax_hist.legend()

    plt.tight_layout()
    plt.show()

## Example Depth Map Generation

In [None]:
# Generate example depth maps for different wave types
wave_types = ["beach_break", "reef_break", "point_break", "closeout", "a_frame"]
directions = ["left", "right", "both"]

# NOTE: this generator is not used by the loop below; reproducibility of the
# example maps comes from the explicit `seed=42 + i` passed to the generator
# function. The name is kept in case other cells rely on it.
rng = np.random.default_rng(42)

print("Generating example depth maps for different wave types...")

# Size the grid from the number of wave types (3 columns) instead of assuming
# exactly five entries; for five types this is the same 2x3, 18x12 figure.
n_cols = 3
n_rows = max(1, (len(wave_types) + n_cols - 1) // n_cols)
fig, axes = plt.subplots(n_rows, n_cols, figsize=(18, 6 * n_rows), squeeze=False)
axes = axes.flatten()

colors = ['#87CEEB', '#4682B4', '#191970', '#000080']
ocean_cmap = LinearSegmentedColormap.from_list('ocean', colors)

for i, wave_type in enumerate(wave_types):
    # Everything except the wave type is held fixed so the panels are comparable.
    params = {
        "height_meters": 1.5,
        "wave_type": wave_type,
        "direction": "both",
        "wavelength": 20.0,
        "angle_deg": 0.0,
        "phase": 0.0,
        "occlusion_mode": "none"
    }

    depth_map = generate_param_depth_map_v2(params, size=(256, 256), seed=42 + i)

    im = axes[i].imshow(depth_map, cmap=ocean_cmap, aspect='equal')
    axes[i].set_title(f'{wave_type.replace("_", " ").title()}')
    axes[i].set_xticks([])
    axes[i].set_yticks([])

# Remove every subplot that did not receive a wave type
for unused_ax in axes[len(wave_types):]:
    unused_ax.remove()

plt.tight_layout()
plt.show()

## Interactive Parameter Exploration

In [None]:
# Interactive widget for exploring parameters
def interactive_depth_generation():
    """Build the ipywidgets control panel that drives depth-map generation.

    Returns:
        The ``widgets.interactive`` container that binds the controls to a
        callback which generates a 512x512 depth map and plots its analysis.
    """

    def _float_slider(description, value, lo, hi, step):
        # Thin wrapper so each slider is declared on one line.
        return widgets.FloatSlider(
            value=value, min=lo, max=hi, step=step,
            description=description
        )

    # Keyword names must match the callback's parameter names.
    controls = {
        "height": _float_slider('Height (m):', 1.5, 0.2, 3.0, 0.1),
        "wave_type": widgets.Dropdown(
            options=wave_types,
            value='beach_break',
            description='Wave Type:'
        ),
        "direction": widgets.Dropdown(
            options=directions,
            value='both',
            description='Direction:'
        ),
        "wavelength": _float_slider('Wavelength:', 20.0, 8.0, 35.0, 1.0),
        "angle": _float_slider('Angle (deg):', 0.0, -30.0, 30.0, 2.0),
        "phase": _float_slider('Phase:', 0.0, 0.0, 6.28, 0.1),
        "seed": widgets.IntSlider(
            value=42, min=0, max=1000, step=1,
            description='Seed:'
        ),
    }

    def _render(height, wave_type, direction, wavelength, angle, phase, seed):
        # Assemble the parameter dict from the current control values and
        # redraw the analysis figure.
        current = {
            "height_meters": height,
            "wave_type": wave_type,
            "direction": direction,
            "wavelength": wavelength,
            "angle_deg": angle,
            "phase": phase,
            "occlusion_mode": "none"
        }
        rendered = generate_param_depth_map_v2(current, size=(512, 512), seed=seed)
        plot_depth_analysis(rendered, current)

    return widgets.interactive(_render, **controls)

# Display interactive widget
print("Interactive Depth Map Generator:")
print("Adjust the parameters below to see how they affect the depth map generation.")
interactive_widget = interactive_depth_generation()
display(interactive_widget)

## Parameter Effect Analysis

In [None]:
# Analyze the effect of different parameters
def analyze_parameter_effects():
    """Plot one-parameter sweeps of the depth generator.

    Two figures are shown: one row of maps for several wave heights, then one
    row of maps for each entry of the module-level ``directions`` list. All
    other parameters and the seed (42) are held fixed.
    """
    baseline = {
        "height_meters": 1.5,
        "wave_type": "beach_break",
        "direction": "both",
        "wavelength": 20.0,
        "angle_deg": 0.0,
        "phase": 0.0,
        "occlusion_mode": "none"
    }

    ocean_cmap = LinearSegmentedColormap.from_list(
        'ocean', ['#87CEEB', '#4682B4', '#191970', '#000080']
    )

    def _sweep(key, values, figsize, panel_title, figure_title):
        # One figure: a single row of maps, varying only `key`.
        fig, axes = plt.subplots(1, len(values), figsize=figsize)
        for ax, value in zip(axes, values):
            variant = {**baseline, key: value}
            rendered = generate_param_depth_map_v2(variant, size=(256, 256), seed=42)

            ax.imshow(rendered, cmap=ocean_cmap, aspect='equal')
            ax.set_title(panel_title(value))
            ax.set_xticks([])
            ax.set_yticks([])

        plt.suptitle(figure_title, fontsize=16)
        plt.tight_layout()
        plt.show()

    # Test different heights
    _sweep(
        "height_meters",
        [0.5, 1.0, 1.5, 2.0, 2.5],
        (20, 4),
        lambda height: f'Height: {height}m',
        'Effect of Wave Height on Depth Map',
    )

    # Test different directions
    _sweep(
        "direction",
        directions,
        (15, 4),
        lambda direction: f'Direction: {direction}',
        'Effect of Wave Direction on Depth Map',
    )

analyze_parameter_effects()

## Batch Generation Example

In [None]:
# Generate a batch of random depth maps
def generate_batch_depth_maps(n_samples=9, seed=42):
    """Generate and plot a batch of depth maps with random parameters.

    Parameters for sample ``i`` are drawn from a generator seeded with
    ``seed``; the depth map itself is generated with ``seed + i``.

    Args:
        n_samples: Number of depth maps to generate and plot. The grid has
            three columns and as many rows as needed.
        seed: Base seed for both parameter sampling and map generation.

    Returns:
        List of the parameter dicts that were used for the plotted maps, in
        plotting order. (Previously a second, unrelated set of parameters was
        sampled for the return value, so it did not describe the figure.)
    """
    rng = np.random.default_rng(seed)
    n_plots = max(int(n_samples), 0)

    # Derive the grid from the sample count; a fixed 3x3 grid raised
    # IndexError for more than nine samples. Nine samples still give the
    # original 3x3, 15x15 figure.
    n_cols = 3
    n_rows = max(1, (n_plots + n_cols - 1) // n_cols)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)
    axes = axes.flatten()

    colors = ['#87CEEB', '#4682B4', '#191970', '#000080']
    ocean_cmap = LinearSegmentedColormap.from_list('ocean', colors)

    used_params = []
    for i in range(n_plots):
        # Sample random parameters and remember them for the caller
        params = sample_wave_params(rng)
        used_params.append(params)

        # Generate depth map
        depth_map = generate_param_depth_map_v2(params, size=(256, 256), seed=seed + i)

        # Plot
        axes[i].imshow(depth_map, cmap=ocean_cmap, aspect='equal')
        axes[i].set_title(
            f'{params["wave_type"]}\n'
            f'{params["direction"]} - {params["height_meters"]:.1f}m',
            fontsize=10
        )
        axes[i].set_xticks([])
        axes[i].set_yticks([])

    # Drop grid cells that received no sample.
    for spare_ax in axes[n_plots:]:
        spare_ax.remove()

    plt.suptitle('Batch of Randomly Generated Depth Maps', fontsize=16)
    plt.tight_layout()
    plt.show()

    return used_params

print("Generating batch of random depth maps...")
batch_params = generate_batch_depth_maps(n_samples=9, seed=123)

## Depth Processing Utilities

In [None]:
def robust_normalize_to_u8(depth_float: np.ndarray, invert: bool = True) -> np.ndarray:
    """Normalize a float depth map to uint8 using percentile clipping.

    Values are clipped to the 2nd..98th percentile range, rescaled to [0, 1],
    optionally inverted, and quantized to 0..255 by rounding to the nearest
    integer.

    Args:
        depth_float: Depth map with floating-point values.
        invert: When True, the smallest depths map to 255 and the largest to
            0; when False, the mapping is the other way round.

    Returns:
        uint8 array with the same shape as ``depth_float``.
    """
    # Robust normalization using percentiles
    p_low, p_high = np.percentile(depth_float, [2, 98])
    depth_clipped = np.clip(depth_float, p_low, p_high)

    # Normalize to [0, 1]; the epsilon avoids division by zero on flat maps.
    depth_norm = (depth_clipped - p_low) / (p_high - p_low + 1e-8)

    if invert:
        depth_norm = 1.0 - depth_norm

    # Round before casting. A bare astype(uint8) truncates, and because the
    # epsilon keeps depth_norm just below 1.0 the top level (255) was never
    # produced for invert=False and every value was biased downward.
    return np.rint(depth_norm * 255.0).astype(np.uint8)


def add_realism_to_depth(depth_u8: np.ndarray, seed: int = 0) -> np.ndarray:
    """Add mild realism (slight blur plus grain) without stamping shapes.

    Args:
        depth_u8: Grayscale depth image; converted to uint8 if it is not
            already. The input array is not modified.
        seed: Seed for the grain noise.

    Returns:
        uint8 array with the same shape as the input.
    """
    rng = np.random.default_rng(seed)

    # Ensure uint8 grayscale
    out = depth_u8.copy()
    if out.dtype != np.uint8:
        out = out.astype(np.uint8)

    # Very mild blur to remove banding
    out = cv2.GaussianBlur(out, (0, 0), sigmaX=0.6, sigmaY=0.6)

    # Add tiny grain noise (no shapes): zero mean, standard deviation 2 levels
    noise = rng.normal(0.0, 2.0, size=out.shape).astype(np.float32)
    out_f = out.astype(np.float32) + noise

    # Round to nearest before the cast. Truncating after adding zero-mean
    # noise shifts the whole image down by about half a gray level.
    out_f = np.clip(np.rint(out_f), 0, 255)
    return out_f.astype(np.uint8)


# Run the processing pipeline once on a fixed parameter set
test_params = {
    "height_meters": 1.8,
    "wave_type": "reef_break",
    "direction": "left",
    "wavelength": 22.0,
    "angle_deg": 10.0,
    "phase": 1.5,
    "occlusion_mode": "none"
}

# Stage 1: float depth -> Stage 2: uint8 (inverted) -> Stage 3: blur + grain
depth_float = generate_param_depth_map_v2(test_params, size=(512, 512), seed=42)
depth_u8 = robust_normalize_to_u8(depth_float, invert=True)
depth_processed = add_realism_to_depth(depth_u8, seed=42)

# Show the three stages side by side
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

colors = ['#87CEEB', '#4682B4', '#191970', '#000080']
ocean_cmap = LinearSegmentedColormap.from_list('ocean', colors)

# (image, colormap, title) for each stage, left to right
pipeline_panels = [
    (depth_float, ocean_cmap, 'Original Float Depth'),
    (depth_u8, 'gray', 'Normalized uint8'),
    (depth_processed, 'gray', 'Processed with Realism'),
]
for panel_ax, (panel_image, panel_cmap, panel_title) in zip(axes, pipeline_panels):
    panel_ax.imshow(panel_image, cmap=panel_cmap)
    panel_ax.set_title(panel_title)
    panel_ax.set_xticks([])
    panel_ax.set_yticks([])

plt.suptitle('Depth Processing Pipeline', fontsize=16)
plt.tight_layout()
plt.show()

# Report the value range after each stage
print(f"Original depth range: {depth_float.min():.3f} to {depth_float.max():.3f}")
print(f"Normalized depth range: {depth_u8.min()} to {depth_u8.max()}")
print(f"Processed depth range: {depth_processed.min()} to {depth_processed.max()}")

## Summary

This notebook demonstrates the procedural generation of depth maps for synthetic wave data. Key features:

1. **Physics-based Generation**: Depth maps incorporate realistic wave physics including foreshortening, breaking patterns, and shore effects

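2. **Parameter Control**: Wave height, type, direction, wavelength, and phase directly influence the depth map structure. Note that `angle_deg` is sampled and exposed as a slider, but `generate_param_depth_map_v2` does not currently read it; the approach angle is derived from `direction` alone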

3. **Wave Type Specificity**: Different wave types (beach break, reef break, etc.) produce characteristic depth patterns

4. **Processing Pipeline**: Float depth maps are normalized and processed for use with ControlNet

5. **Interactive Exploration**: Widgets allow real-time parameter adjustment to understand their effects

The generated depth maps serve as conditioning input for SDXL + ControlNet to produce realistic synthetic wave images with known ground truth labels.