# SwellSight Wave Image Generation (Standalone)

This notebook generates realistic wave images using AI (ControlNet + Stable Diffusion).
**This is a standalone version that includes all necessary code and doesn't require external files.**

## Overview
- Generate depth maps from wave parameters
- Create photorealistic wave images using AI
- Different wave types: beach_break, reef_break, point_break, closeout, a_frame
- Customizable wave height, direction, and characteristics
- Export high-quality images for training or analysis

In [None]:
# Install required packages if not available
import subprocess
import sys

def install_if_missing(package, import_name=None):
    """Install ``package`` with pip when its module cannot be imported.

    The name given to ``pip install`` is not always the name used in an
    ``import`` statement (``pillow`` is imported as ``PIL``,
    ``opencv-python`` as ``cv2``). Probing the pip name for such packages
    always raises ``ImportError``, so pip would be invoked on every run even
    when the package is already present.

    Args:
        package: Distribution name handed to ``pip install``.
        import_name: Module name to probe with ``__import__``. When omitted,
            a small table of known mismatches is consulted and ``package``
            itself is the fallback.

    Raises:
        subprocess.CalledProcessError: If ``pip install`` exits with a
            non-zero status.
    """
    # Distributions whose importable module name differs from the pip name.
    known_import_names = {
        "pillow": "PIL",
        "opencv-python": "cv2",
        "opencv-python-headless": "cv2",
    }
    module_name = import_name or known_import_names.get(package.lower(), package)

    try:
        __import__(module_name)
    except ImportError:
        print(f"Installing {package}...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])

# Install required packages.
# tqdm is listed as well because the import cell below pulls in
# `tqdm.notebook`, which would otherwise fail on a fresh environment.
install_if_missing("torch")
install_if_missing("torchvision")
install_if_missing("matplotlib")
install_if_missing("numpy")
install_if_missing("pillow")
install_if_missing("opencv-python")
install_if_missing("diffusers")
install_if_missing("transformers")
install_if_missing("accelerate")
install_if_missing("tqdm")

print("✓ All required packages are available")

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2
import torch
from typing import Dict, Any, Tuple
import json
from tqdm.notebook import tqdm

# Report the runtime environment
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

# Pick the compute device: GPU when CUDA is usable, otherwise CPU
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using device: {device}")

## Wave Parameter Generation

In [None]:
def sample_wave_params(rng=None):
    """Draw one random set of wave parameters.

    Args:
        rng: Optional ``numpy.random.Generator``. A fresh, unseeded generator
            is created when this is ``None``.

    Returns:
        A dict with the keys ``height_meters``, ``wave_type``, ``direction``,
        ``wavelength``, ``angle_deg``, ``phase`` and ``occlusion_mode``
        (always ``"none"``).
    """
    generator = np.random.default_rng() if rng is None else rng

    wave_type_options = ["beach_break", "reef_break", "point_break", "closeout", "a_frame"]
    direction_options = ["left", "right", "both"]

    # The draws happen in a fixed order (height, wave type, direction,
    # wavelength, angle, phase) so a seeded generator yields a stable result.
    sampled = {}
    sampled["height_meters"] = float(generator.uniform(0.2, 2.5))
    sampled["wave_type"] = generator.choice(wave_type_options).item()
    sampled["direction"] = generator.choice(direction_options).item()
    sampled["wavelength"] = float(generator.uniform(10.0, 28.0))
    sampled["angle_deg"] = float(generator.uniform(-20.0, 20.0))
    sampled["phase"] = float(generator.uniform(0.0, 2.0 * np.pi))
    sampled["occlusion_mode"] = "none"
    return sampled

# Smoke test: draw one parameter set and list every field
test_params = sample_wave_params()
print("Sample wave parameters:")
for key in test_params:
    value = test_params[key]
    print(f"  {key}: {value}")

## Depth Map Generation

In [None]:
def generate_param_depth_map(params: Dict[str, Any], size: Tuple[int, int] = (768, 768), seed: int = 0) -> np.ndarray:
    """Generate a synthetic wave depth map from wave parameters.

    The map is a smooth base ramp with several relief terms subtracted from
    it: a sinusoidal wave train, blurred random noise, a narrow band along a
    breaker line, a narrow run-up band at y = 0.10, and a quadratic shore
    slope.

    Args:
        params: Wave description. The keys read here are ``height_meters``
            (default 1.0), ``wave_type`` (default ``"beach_break"``),
            ``direction`` (default ``"both"``), ``wavelength`` (default 18.0)
            and ``phase`` (default 0.0). Other keys, such as ``angle_deg``
            and ``occlusion_mode``, are not read by this function.
        size: Output size as ``(H, W)``, i.e. rows first.
        seed: Seed for the NumPy generator that supplies every random draw
            made here.

    Returns:
        A float32 array of shape ``(H, W)``. Values are not normalized.
    """
    rng = np.random.default_rng(seed)
    H, W = size

    # Normalized coordinates: x in [-1, 1], y in [0, 1]
    # x varies along columns, y along rows (row 0 is y = 0).
    x = np.linspace(-1.0, 1.0, W, dtype=np.float32)[None, :].repeat(H, axis=0)
    y = np.linspace(0.0, 1.0, H, dtype=np.float32)[:, None].repeat(W, axis=1)

    height = float(params.get("height_meters", 1.0))
    wave_type = str(params.get("wave_type", "beach_break"))
    direction = str(params.get("direction", "both"))

    # Base depth increases with y (farther is larger depth)
    # The ramp goes from 0.6 at y = 0 to 4.1 at y = 1, bent by the exponent gamma.
    gamma = 1.7
    y_p = y ** gamma
    base_depth = 0.6 + 3.5 * y_p

    # Direction controls breaker slant and approach angle
    # "left" and "right" are mirror images; any other value means no tilt.
    if direction == "left":
        theta = np.deg2rad(18.0)
        slant = 0.14
    elif direction == "right":
        theta = np.deg2rad(-18.0)
        slant = -0.14
    else:
        theta = np.deg2rad(0.0)
        slant = 0.0

    # Wave frequency increases toward horizon (foreshortening)
    # max() guards against division by zero for a zero or negative wavelength.
    wavelength = float(params.get("wavelength", 18.0))
    k0 = (2.0 * np.pi) / max(wavelength, 1e-3)
    k = k0 * (1.0 + 2.8 * y_p)

    # u is the coordinate along the wave direction given by theta.
    u = np.cos(theta) * x + np.sin(theta) * (y_p - 0.55)
    phase0 = float(params.get("phase", 0.0))
    phase = k * (u * 6.0) + phase0

    # Keep relief smaller than base depth
    # amp ranges from 0.08 to 0.26 as height goes from 0 to 2.5 (clipped beyond).
    amp = 0.08 + 0.18 * np.clip(height / 2.5, 0.0, 1.0)
    wave_relief = amp * np.sin(phase)

    # Soft spatial noise to avoid perfect stripes
    # White noise is blurred, rescaled to [0, 1], centred on zero, then scaled
    # by a random amplitude between 0.04 and 0.07.
    n = rng.normal(0.0, 1.0, size=(H, W)).astype(np.float32)
    n = cv2.GaussianBlur(n, (0, 0), sigmaX=3.0, sigmaY=3.0)
    n = (n - n.min()) / (n.max() - n.min() + 1e-8)
    noise_relief = (n - 0.5) * (0.04 + 0.03 * float(rng.random()))

    # Breaking line position and shape by wave_type
    # break_y is jittered within 0.22 +/- 0.04.
    break_y = 0.22 + 0.04 * float(rng.uniform(-1.0, 1.0))

    # Each wave type sets the line curvature, its irregularity along x, and a
    # multiplier on the slant. Unrecognized types use the final branch.
    if wave_type == "closeout":
        curvature = 0.0
        irregular = 0.0
        slant *= 0.2
    elif wave_type == "a_frame":
        curvature = 0.10
        irregular = 0.01
        slant *= 0.3
    elif wave_type == "point_break":
        curvature = 0.03
        irregular = 0.01
        slant *= 1.2
    elif wave_type == "reef_break":
        curvature = 0.02
        irregular = 0.03
        slant *= 0.9
    else:
        curvature = 0.02
        irregular = 0.015
        slant *= 0.7

    # Breaker line equation
    # a_frame uses |x| (a V shape with its vertex at x = 0); the rest use x^2.
    if wave_type == "a_frame":
        line = break_y + slant * x + curvature * np.abs(x)
    else:
        line = break_y + slant * x + curvature * (x ** 2)

    # Irregularity along x
    # One blurred noise row is shared by every image row, so the perturbation
    # depends on x only. Skipped (no random draw) when irregular is 0.
    if irregular > 0:
        ix = rng.normal(0.0, 1.0, size=(1, W)).astype(np.float32)
        ix = cv2.GaussianBlur(ix, (0, 0), sigmaX=10.0)
        ix = (ix - ix.min()) / (ix.max() - ix.min() + 1e-8)
        ix = (ix - 0.5) * irregular
        line = line + ix.repeat(H, axis=0)

    # Gaussian bands in y: one centred on the breaker line (sigma 0.012) and
    # one centred on y = 0.10 (sigma 0.020).
    breaker_band = np.exp(-((y - line) ** 2) / (2.0 * (0.012 ** 2))).astype(np.float32)
    runup_band = np.exp(-((y - 0.10) ** 2) / (2.0 * (0.020 ** 2))).astype(np.float32)

    breaker_relief = 0.35 * breaker_band
    runup_relief = 0.22 * runup_band

    # Every relief term is subtracted, i.e. it reduces the depth value.
    depth = base_depth - wave_relief - noise_relief - breaker_relief - runup_relief

    # Gentle shoreline slope near camera
    # The term is largest at y = 0 and vanishes at y = 1.
    shore_slope = 0.25 * (1.0 - y) ** 2
    depth = depth - shore_slope

    return depth.astype(np.float32)

# Confirmation message; the check mark was stored as mis-decoded UTF-8.
print("✓ Depth map generation function defined")

## Depth Map Processing

In [None]:
def robust_normalize_to_u8(depth: np.ndarray, invert: bool = False, eps: float = 1e-8) -> np.ndarray:
    """Convert a float depth map to uint8 using robust percentile clipping.

    The 2nd and 98th percentiles of the finite values become 0 and 255;
    everything outside that range saturates. Non-finite pixels (NaN, +/-inf)
    are written as 0 in the output.

    Args:
        depth: Array-like depth values of any shape.
        invert: When true, flip the scale so the low percentile maps to 255.
        eps: Minimum percentile spread. A smaller spread (e.g. a constant
            image) is treated as a spread of 1.0 to avoid dividing by zero.

    Returns:
        A uint8 array with the same shape as ``depth``.
    """
    d = np.asarray(depth, dtype=np.float32)

    finite_mask = np.isfinite(d)
    if not np.any(finite_mask):
        return np.zeros(d.shape, dtype=np.uint8)

    finite_values = d[finite_mask]
    lo = float(np.percentile(finite_values, 2.0))
    hi = float(np.percentile(finite_values, 98.0))
    if (hi - lo) < eps or hi <= lo:
        hi = lo + 1.0

    # Park non-finite pixels at `lo` before any arithmetic: casting NaN to
    # uint8 is undefined and raises a RuntimeWarning. They are reset to 0 below.
    d = np.where(finite_mask, d, np.float32(lo))

    # The spread is guaranteed positive here, so no eps is added to the
    # denominator; adding it kept the top of the range just below 1.0.
    d = np.clip((d - lo) / (hi - lo), 0.0, 1.0)

    if invert:
        d = 1.0 - d

    # Round to nearest instead of truncating so full scale reaches 255.
    d_u8 = np.rint(d * 255.0).astype(np.uint8)

    # Pixels that were NaN/inf in the input are reported as 0.
    if not np.all(finite_mask):
        d_u8[~finite_mask] = 0

    return d_u8


def to_control_image(depth_u8: np.ndarray, size: tuple = (1024, 1024)) -> Image.Image:
    """Convert a depth map into a 3-channel PIL image, resized for ControlNet.

    Args:
        depth_u8: Depth map, either ``(H, W)`` or ``(H, W, C)``. For a
            3-D input only channel 0 is used. Input that is not already
            uint8 is saturated to the 0..255 range before conversion.
        size: Target size passed to ``cv2.resize``, which takes
            ``(width, height)``.

    Returns:
        An RGB ``PIL.Image.Image`` whose three channels are identical copies
        of the resized depth map.
    """
    d = np.asarray(depth_u8)
    if d.dtype != np.uint8:
        # Saturate first: a bare astype() wraps out-of-range values
        # (e.g. 256 becomes 0) instead of clamping them.
        d = np.clip(d, 0, 255).astype(np.uint8)

    # Ensure HxW
    if d.ndim == 3:
        d = d[..., 0]

    # Resize with good quality
    d_resized = cv2.resize(d, tuple(size), interpolation=cv2.INTER_CUBIC)

    # ControlNet expects RGB-like input, replicate grayscale into 3 channels
    rgb = np.stack([d_resized] * 3, axis=-1)

    # The mode is inferred from the (H, W, 3) uint8 array; passing `mode`
    # explicitly triggers a DeprecationWarning on Pillow 11.3.
    return Image.fromarray(rgb)

# Confirmation message; the check mark was stored as mis-decoded UTF-8.
print("✓ Depth processing functions defined")

## Generate Sample Depth Maps

In [None]:
# Generate sample depth maps for different wave types
wave_types = ["beach_break", "reef_break", "point_break", "closeout", "a_frame"]
directions = ["left", "right", "both"]

print("🌊 Generating sample depth maps...")

# One sample per subplot, so the sample count always matches the grid.
n_rows, n_cols = 2, 3
fig, axes = plt.subplots(n_rows, n_cols, figsize=(18, 12))
axes = axes.flatten()

# List of (uint8 depth map, parameter dict) pairs, reused by later cells.
sample_data = []

for i, ax in enumerate(axes):
    # Cycle through the wave types and directions for variety
    wave_type = wave_types[i % len(wave_types)]
    direction = directions[i % len(directions)]
    height = float(np.random.uniform(0.8, 2.2))

    # float() keeps the values plain Python numbers so they stay JSON-serializable.
    params = {
        "height_meters": height,
        "wave_type": wave_type,
        "direction": direction,
        "wavelength": float(np.random.uniform(12.0, 25.0)),
        "angle_deg": float(np.random.uniform(-15.0, 15.0)),
        "phase": float(np.random.uniform(0, 2*np.pi)),
        "occlusion_mode": "none"
    }

    # Generate depth map (the noise seed is fixed per sample index)
    depth_float = generate_param_depth_map(params, size=(512, 512), seed=i*42)
    depth_u8 = robust_normalize_to_u8(depth_float, invert=True)

    # Store for later use
    sample_data.append((depth_u8, params))

    # Visualize
    ax.imshow(depth_u8, cmap='ocean', aspect='equal')
    title = f"{wave_type.replace('_', ' ').title()}\n{height:.1f}m, {direction}"
    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.axis('off')

    print(f"✓ Generated: {wave_type} - {height:.1f}m - {direction}")

plt.suptitle('SwellSight Generated Wave Depth Maps', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

print(f"\n🎉 Generated {len(sample_data)} depth maps successfully!")

## AI Image Generation Setup (Optional)

**Note**: This section requires significant computational resources and may take time to run. The depth maps above are already useful for analysis!

In [None]:
# Check if we want to generate AI images (requires more resources)
GENERATE_AI_IMAGES = True  # Set to False to skip AI generation

if GENERATE_AI_IMAGES:
    try:
        # Only the imports can raise ImportError, so only they sit in the try.
        from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel
        from diffusers.utils import load_image
    except ImportError:
        print("⚠️  Diffusers not available. Skipping AI image generation.")
        print("   Install with: pip install diffusers")
        GENERATE_AI_IMAGES = False
    else:
        print("✓ Diffusers library available")
        print("⚠️  AI image generation will require downloading models (~6GB)")
        print("⚠️  This may take 10-15 minutes per image on CPU")

        # Ask user confirmation
        try:
            proceed = input("\nProceed with AI image generation? (y/n): ").lower().strip()
        except (EOFError, NotImplementedError):
            # Non-interactive runs have no stdin: plain Python raises
            # EOFError, and Jupyter frontends without input support raise a
            # NotImplementedError subclass. Fall back to the cheap option.
            print("⚠️  No interactive input available. Skipping AI image generation.")
            proceed = "n"
        GENERATE_AI_IMAGES = proceed in ("y", "yes")
else:
    print("ℹ️  AI image generation disabled. Using depth maps only.")

In [None]:
if GENERATE_AI_IMAGES:
    print("🤖 Setting up AI image generation pipeline...")

    # Half precision is only used on the GPU; the CPU path stays in float32.
    torch_dtype = torch.float16 if device == "cuda" else torch.float32

    # Load ControlNet model for depth
    controlnet = ControlNetModel.from_pretrained(
        "diffusers/controlnet-depth-sdxl-1.0",
        torch_dtype=torch_dtype,
    )

    # Load Stable Diffusion XL pipeline
    pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        controlnet=controlnet,
        torch_dtype=torch_dtype,
        use_safetensors=True,
    )

    # Place the whole pipeline on the selected device. The previous CPU branch
    # called enable_model_cpu_offload(), but that feature moves sub-models
    # from CPU RAM onto an accelerator for each forward pass, so it cannot
    # work on a machine that has no GPU.
    pipe = pipe.to(device)

    print("✓ AI pipeline ready!")

    # Prompt templates
    PROMPT_TEMPLATE = (
        "Ultra realistic beach camera photo, ocean water surface, waves breaking near shore, "
        "visible shoreline and run-up foam, visible horizon line, oblique angle from sand level, "
        "wave height about {height:.1f} meters, breaking type {wave_type}, peeling direction {direction}, "
        "sea spray, whitewater, foam patterns, natural daylight, sharp focus, high detail, photo"
    )

    NEG_PROMPT = (
        "sand, dunes, desert, ripple sand, seabed, underwater, top-down, aerial, drone, "
        "cartoon, illustration, painting, CGI, low detail, blurry, flat water, "
        "text, watermark, logo, people, surfers, boats, buildings, "
        "bokeh, lens flare, circular blur, circle artifact, ring, discs, spots"
    )
else:
    print("ℹ️  Skipping AI setup - using depth maps only")

## Generate AI Wave Images

In [None]:
if GENERATE_AI_IMAGES and 'pipe' in globals():
    print("🎨 Generating AI wave images...")

    # Create output directory
    output_dir = "generated_wave_images"
    os.makedirs(output_dir, exist_ok=True)

    # List of (PIL image, parameter dict, file name) tuples.
    generated_images = []

    # Only the first few samples are rendered (to save time)
    samples_to_render = sample_data[:3]
    total = len(samples_to_render)

    for i, (depth_u8, params) in enumerate(samples_to_render):
        print(f"\n🖼️  Generating image {i+1}/{total}...")

        # Prepare control image
        control_image = to_control_image(depth_u8, size=(1024, 1024))

        # Create prompt
        prompt = PROMPT_TEMPLATE.format(
            height=params['height_meters'],
            wave_type=params['wave_type'],
            direction=params['direction']
        )

        print(f"   Prompt: {prompt[:100]}...")

        # Generate image; a per-sample seed keeps each image reproducible
        with torch.no_grad():
            result = pipe(
                prompt=prompt,
                negative_prompt=NEG_PROMPT,
                image=control_image,
                num_inference_steps=20,  # Reduced for speed
                guidance_scale=6.0,
                controlnet_conditioning_scale=0.75,
                generator=torch.Generator(device=device).manual_seed(42 + i)
            )

        generated_image = result.images[0]

        # Save image
        filename = f"ai_wave_{i+1}_{params['wave_type']}_{params['height_meters']:.1f}m.png"
        filepath = os.path.join(output_dir, filename)
        generated_image.save(filepath)

        generated_images.append((generated_image, params, filename))
        print(f"   ✓ Saved: {filename}")

    # Display results (a figure with zero rows cannot be created)
    if generated_images:
        # squeeze=False keeps `axes` two-dimensional even for a single row.
        fig, axes = plt.subplots(
            len(generated_images), 2,
            figsize=(16, 6*len(generated_images)),
            squeeze=False,
        )

        for i, (ai_image, params, filename) in enumerate(generated_images):
            # Show depth map (generated_images[i] was built from sample_data[i])
            axes[i, 0].imshow(sample_data[i][0], cmap='ocean')
            axes[i, 0].set_title(f'Depth Map {i+1}', fontweight='bold')
            axes[i, 0].axis('off')

            # Show AI generated image
            axes[i, 1].imshow(ai_image)
            title = f"AI Generated Wave\n{params['wave_type'].replace('_', ' ').title()} - {params['height_meters']:.1f}m"
            axes[i, 1].set_title(title, fontweight='bold')
            axes[i, 1].axis('off')

        plt.tight_layout()
        plt.show()

    print(f"\n🎉 Generated {len(generated_images)} AI wave images!")
    print(f"📁 Images saved in: {output_dir}/")

else:
    print("ℹ️  AI image generation skipped")

## Save Depth Maps

In [None]:
# Save all depth maps
depth_dir = "depth_maps"
os.makedirs(depth_dir, exist_ok=True)

print("💾 Saving depth maps...")

# One metadata record per saved depth map; written to metadata.json below.
saved_data = []

for i, (depth_u8, params) in enumerate(sample_data):
    # Save depth map as image. The grayscale mode is inferred from the 2-D
    # uint8 array; passing `mode` explicitly is deprecated in Pillow 11.3.
    depth_filename = f"depth_{i+1}_{params['wave_type']}_{params['height_meters']:.1f}m.png"
    depth_path = os.path.join(depth_dir, depth_filename)
    Image.fromarray(depth_u8).save(depth_path)

    # Save parameters. float()/str() turn any NumPy scalars into plain
    # Python values, which json.dump can serialize.
    record = {
        "depth_path": depth_path,
        "height_meters": float(params['height_meters']),
        "wave_type": str(params['wave_type']),
        "direction": str(params['direction']),
        "wavelength": float(params['wavelength']),
        "angle_deg": float(params['angle_deg']),
        "source": "notebook_generated"
    }

    saved_data.append(record)
    print(f"✓ Saved: {depth_filename}")

# Save metadata (explicit encoding so the file does not depend on the locale)
metadata_path = os.path.join(depth_dir, "metadata.json")
with open(metadata_path, 'w', encoding='utf-8') as f:
    json.dump(saved_data, f, indent=2)

print(f"\n📋 Metadata saved: {metadata_path}")
print(f"📁 All depth maps saved in: {depth_dir}/")
print("\n🎉 Wave generation complete!")

## Summary

This notebook successfully generated:

### ✅ Depth Maps
- High-quality wave depth maps for different wave types
- Realistic wave characteristics (height, direction, breaking patterns)
- Saved as PNG images with metadata

### ✅ AI Images (Optional)
- Photorealistic wave images generated from depth maps
- Uses Stable Diffusion XL + ControlNet
- Customizable wave parameters

### 🎯 Use Cases
- **Training Data**: Use generated images to train SwellSight models
- **Testing**: Validate model performance on synthetic data
- **Augmentation**: Expand existing datasets with controlled variations
- **Research**: Study wave characteristics and breaking patterns

### 📁 Output Files
- `depth_maps/`: Depth map images and metadata
- `generated_wave_images/`: AI-generated photorealistic waves (if enabled)

**Next Steps**: Use these generated images with other SwellSight notebooks for training and analysis!