# Mesh + Acoustic Optimization with PyTorch and JAX

This notebook demonstrates joint optimization of:
1. **Mesh geometry** (vertex positions)
2. **Acoustic impedance** (boundary conditions)

Using:
- JAX-FEM for acoustic Helmholtz solver
- PyTorch3D mesh losses for mesh regularization
- PyTorch optimizers, fed with gradients from JAX autodiff (impedance) and finite differences (mesh geometry)

This follows the pattern from the Tesseract-JAX fem-shapeopt example, where we:
- Wrap JAX functions to work with PyTorch
- Use PyTorch optimizers for the optimization loop
- Compute gradients in JAX and convert to PyTorch tensors

In [None]:
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from jax_fem.solver import ad_wrapper, solver
from jax_fem.generate_mesh import Mesh

# Import local modules
from problems import AcousticHelmholtzImpedance, Source
from losses import compute_acoustic_loss
from mesh_setup import create_square_mesh_triangular  # Square mesh with triangular elements

# PyTorch for optimization and mesh losses
import torch
from pytorch3d.structures import Meshes
from pytorch3d.loss import (
    mesh_edge_loss,
    mesh_laplacian_smoothing,
    mesh_normal_consistency,
)

print("All imports successful!")

## 1. Setup Initial Mesh and Problem

In [None]:
# Physical parameters
side_length = 2.0  # Square side length
c = 343.0          # Speed of sound (m/s)
f_max = 1000       # Maximum frequency (Hz)
ppw = 5.0          # Points per wavelength

# Create initial square mesh with triangular elements
mesh, location_fns, ele_type = create_square_mesh_triangular(side_length, c, f_max, ppw)

# Store initial mesh for reference
initial_points = np.array(mesh.points)
cells = np.array(mesh.cells)

print(f"Mesh created with {len(initial_points)} vertices and {len(cells)} cells")
print(f"Element type: {ele_type}")
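The helper `create_square_mesh_triangular` lives in the local `mesh_setup` module, so its sizing rule is not shown here. A common rule, assumed in the sketch below, sets the element size to h = c / (f_max · ppw), i.e. `ppw` elements per shortest wavelength. The hypothetical function `structured_square_mesh` builds a structured triangulation of the square [-L/2, L/2]² with that spacing using only NumPy. It is not the local helper.

```python
import math

import numpy as np


def structured_square_mesh(side_length, c, f_max, ppw):
    """Hypothetical stand-in for create_square_mesh_triangular (not the local helper).

    Assumes element size h = c / (f_max * ppw) and a square centred at the origin.
    Returns (points, cells) with shapes (N, 2) and (M, 3).
    """
    h = c / (f_max * ppw)                    # target element size (m)
    n = math.ceil(side_length / h)           # divisions per side
    coords = np.linspace(-side_length / 2, side_length / 2, n + 1)
    xx, yy = np.meshgrid(coords, coords)     # x varies along columns (i), y along rows (j)
    points = np.column_stack([xx.ravel(), yy.ravel()])

    cells = []
    for j in range(n):
        for i in range(n):
            v00 = j * (n + 1) + i            # lower-left corner of grid cell (i, j)
            v10, v01 = v00 + 1, v00 + (n + 1)
            v11 = v01 + 1
            # Split each square into two counter-clockwise triangles
            cells.append([v00, v10, v11])
            cells.append([v00, v11, v01])
    return points, np.array(cells, dtype=np.int64)


points, tri = structured_square_mesh(2.0, 343.0, 1000, 5.0)
print(points.shape, tri.shape)  # (961, 2) (1800, 3)
```

With the notebook's parameters this rule gives h ≈ 0.0686 m and 30 divisions per side. The local helper may size its elements differently, so its vertex count can differ from this sketch.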

In [None]:
# Visualize initial mesh
fig, ax = plt.subplots(figsize=(8, 8))
ax.triplot(initial_points[:, 0], initial_points[:, 1], cells, 'k-', linewidth=0.5)
ax.set_aspect('equal')
ax.set_title('Initial Square Mesh (Triangular Elements)')
ax.set_xlabel('x')
ax.set_ylabel('y')
plt.show()

## 2. Setup Acoustic Problem and Reference Solution

In [None]:
def create_acoustic_problem(mesh_points, cells, k, source_params, location_fns, ele_type):
    """Create an acoustic Helmholtz problem with given mesh."""
    mesh = Mesh(mesh_points, cells)
    problem = AcousticHelmholtzImpedance(
        mesh=mesh,
        k=k,
        source_params=source_params,
        vec=1,
        dim=2,
        ele_type=ele_type,
        location_fns=location_fns,
        gauss_order=1
    )
    return problem

# Setup acoustic parameters
frequency = 500  # Hz
k = 2 * jnp.pi * frequency / c
source_params = Source(k_max=k, center=[0.0, 0.0], amplitude=1000.0)

# True impedance for synthetic data
Z_true = 1.5 + 0.3j

# Create initial problem and generate reference measurements
problem_ref = create_acoustic_problem(
    initial_points, cells, k, source_params, location_fns, ele_type
)
fwd_ref = ad_wrapper(problem_ref)
measurements = fwd_ref(Z_true)

print(f"Reference solution computed at f={frequency} Hz")
print(f"True impedance: Z = {Z_true}")

## 3. Define Loss Functions

We define two losses: the acoustic loss, computed in JAX with JAX-FEM, and a mesh-regularization loss, computed in PyTorch with PyTorch3D.

In [None]:
def compute_acoustic_loss_jax(mesh_points, Z):
    """
    Compute acoustic loss using JAX.
    
    Note: by default this function is NOT differentiated w.r.t. mesh_points with
    jax.grad, because JAX-FEM's problem setup cannot trace through the mesh
    coordinates. Mesh gradients use finite differences instead; only Z is traced.
    """
    # Solve acoustic problem
    problem = create_acoustic_problem(
        mesh_points, cells, k, source_params, location_fns, ele_type
    )
    fwd = ad_wrapper(problem)
    prediction = fwd(Z)
    
    # Acoustic loss
    loss_acoustic = compute_acoustic_loss(
        problem, prediction, measurements,
        w_mag=0.5, w_phase=0.5, w_rel=0.0
    )
    
    return loss_acoustic


def compute_mesh_regularization_torch(mesh_points_torch):
    """
    Compute PyTorch3D mesh regularization losses.
    This returns PyTorch tensors with gradients.
    """
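    # Add z=0 for 2D mesh (same dtype/device as the input vertices)
    z_coords = torch.zeros(
        (mesh_points_torch.shape[0], 1),
        dtype=mesh_points_torch.dtype, device=mesh_points_torch.device
    )
    verts = torch.cat([mesh_points_torch, z_coords], dim=1).unsqueeze(0)  # (1, N, 3)
    
    faces = torch.from_numpy(cells).long().unsqueeze(0).to(mesh_points_torch.device)  # (1, M, 3)
    
    # Create PyTorch3D mesh
    mesh_pt3d = Meshes(verts=verts, faces=faces)
    
    # Compute losses
    loss_edge = mesh_edge_loss(mesh_pt3d)
    loss_laplacian = mesh_laplacian_smoothing(mesh_pt3d, method="uniform")
    # On a planar mesh adjacent face normals are parallel, so this term is ~0
    # unless a triangle inverts (its normal flips to -z)
    loss_normal = mesh_normal_consistency(mesh_pt3d)
    
    # Unweighted sum of the three terms (the overall weight is w_mesh_reg)
    total_loss = loss_edge + loss_laplacian + loss_normal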
    
    return total_loss, {
        'edge': loss_edge.item(),
        'laplacian': loss_laplacian.item(),
        'normal': loss_normal.item()
    }

print("Loss functions defined")
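To see what the two active regularizers reward on a flat 2D mesh, here is a NumPy re-implementation of the definitions as we understand them from PyTorch3D: `mesh_edge_loss` (default `target_length=0`) is the mean squared edge length, and the uniform `mesh_laplacian_smoothing` averages, over vertices, the norm of (mean of neighbours minus vertex). This is an illustrative sketch for a single mesh, not PyTorch3D's code; the library's per-mesh weighting for batches is not reproduced.

```python
import numpy as np


def unique_edges(cells):
    """Undirected edges of a triangle mesh as a sorted (E, 2) array."""
    e = np.concatenate([cells[:, [0, 1]], cells[:, [1, 2]], cells[:, [2, 0]]])
    return np.unique(np.sort(e, axis=1), axis=0)


def edge_loss(points, cells, target_length=0.0):
    """Mean of (edge length - target_length)**2 over unique edges."""
    e = unique_edges(cells)
    lengths = np.linalg.norm(points[e[:, 0]] - points[e[:, 1]], axis=1)
    return np.mean((lengths - target_length) ** 2)


def uniform_laplacian_loss(points, cells):
    """Mean over vertices of |mean(neighbours) - vertex| (uniform weights)."""
    e = unique_edges(cells)
    n = len(points)
    nbr_sum = np.zeros_like(points)
    deg = np.zeros(n)
    # Accumulate neighbour positions in both directions of every edge
    np.add.at(nbr_sum, e[:, 0], points[e[:, 1]])
    np.add.at(nbr_sum, e[:, 1], points[e[:, 0]])
    np.add.at(deg, e[:, 0], 1)
    np.add.at(deg, e[:, 1], 1)
    delta = nbr_sum / deg[:, None] - points
    return np.mean(np.linalg.norm(delta, axis=1))


# Unit square split into two triangles
pts = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
tris = np.array([[0, 1, 2], [0, 2, 3]])
print(edge_loss(pts, tris))               # about 1.2 (four unit edges, one diagonal of length sqrt(2))
print(uniform_laplacian_loss(pts, tris))  # about 0.825
```

In a regular (structured) triangulation every interior vertex sits at the mean of its neighbours, so the uniform Laplacian term is driven by the boundary vertices. Like the edge loss with zero target length, it therefore tends to pull the square inward, which is worth keeping in mind when choosing `w_mesh_reg`.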

## 4. Combined Loss Function with JAX and PyTorch

We support two methods for computing mesh gradients:

1. **Finite differences** (default): slow, but works with the current JAX-FEM
2. **JAX autodiff** (experimental): would be much faster, but currently fails during JAX-FEM's problem setup (`jnp.take` with `mode='raise'`)

You can toggle between them using the `use_jax_mesh_grad` parameter.
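Finite differences need one extra loss evaluation per coordinate, which is where the cost of the default method comes from. Below is a minimal NumPy sketch of the forward-difference scheme that `compute_mesh_gradient_fd` uses, with a cheap analytic loss standing in for the FEM solve and a counter recording how many "solves" it takes. The names `forward_diff_grad`, `counted` and `toy_loss` are illustrative only and are not part of the notebook.

```python
import numpy as np


def forward_diff_grad(loss_fn, x, eps=1e-6):
    """Forward-difference gradient of a scalar loss w.r.t. every entry of x."""
    base = loss_fn(x)
    grad = np.zeros_like(x)
    for idx in np.ndindex(x.shape):          # one perturbed evaluation per coordinate
        x_pert = x.copy()
        x_pert[idx] += eps
        grad[idx] = (loss_fn(x_pert) - base) / eps
    return grad, base


def counted(fn):
    """Wrap fn so that wrapper.calls records how many times it is evaluated."""
    def wrapper(x):
        wrapper.calls += 1
        return fn(x)
    wrapper.calls = 0
    return wrapper


TARGET = np.array([0.5, -0.25])


@counted
def toy_loss(points):
    """Stand-in for one FEM solve + acoustic loss: squared distance to TARGET."""
    return float(np.sum((points - TARGET) ** 2))


pts = np.random.default_rng(0).normal(size=(50, 2))
g_fd, _ = forward_diff_grad(toy_loss, pts)
g_exact = 2 * (pts - TARGET)
print(toy_loss.calls)                     # 101 = 1 base + 2 * 50 perturbed evaluations
print(np.max(np.abs(g_fd - g_exact)))     # on the order of eps
```

For N two-dimensional vertices this is 2N + 1 loss evaluations, each a full FEM solve in the notebook. Reverse-mode autodiff would instead cost roughly a few solves regardless of N.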

In [None]:
class CombinedLoss:
    """Combined loss function with configurable gradient computation methods."""
    
    def __init__(self, w_acoustic=1.0, w_mesh_reg=0.1, fd_epsilon=1e-5, use_jax_mesh_grad=False):
        """
        Args:
            w_acoustic: Weight for acoustic loss
            w_mesh_reg: Weight for mesh regularization
            fd_epsilon: Finite difference step size
            use_jax_mesh_grad: If True, attempt JAX autodiff for mesh (will fail with current JAX-FEM)
                              If False, use finite differences (slow but works)
        """
        self.w_acoustic = w_acoustic
        self.w_mesh_reg = w_mesh_reg
        self.fd_epsilon = fd_epsilon
        self.use_jax_mesh_grad = use_jax_mesh_grad
        
        # Gradient w.r.t. Z only. Deliberately not wrapped in jax.jit: jit would
        # trace mesh_points too, and JAX-FEM builds the Problem from concrete
        # NumPy coordinates. Always defined so the autodiff path can fall back to it.
        self.jax_grad_fn_Z = jax.value_and_grad(
            compute_acoustic_loss_jax, argnums=1  # Only differentiate w.r.t. Z
        )
        
        if use_jax_mesh_grad:
            print("WARNING: JAX autodiff for mesh will likely fail with current JAX-FEM")
            print("  Error: 'The 'raise' mode to jnp.take is not supported.'")
            print("  This happens during Problem initialization when JAX traces through mesh creation")
            # Gradient function for both mesh and Z (expected to fail when called)
            self.jax_grad_fn_both = jax.value_and_grad(
                compute_acoustic_loss_jax, argnums=(0, 1)  # Both mesh and Z
            )
        
    def compute_mesh_gradient_fd(self, mesh_points_np, Z_complex):
        """Compute acoustic loss gradient w.r.t. mesh using forward differences.
        
        Every coordinate of every vertex is perturbed in turn, so one call costs
        1 + n_points * n_dims FEM solves (2N + 1 for a 2D mesh with N vertices).
        """
        n_points, n_dims = mesh_points_np.shape
        grad_mesh = np.zeros_like(mesh_points_np)
        
        # Compute base loss
        base_loss = float(compute_acoustic_loss_jax(mesh_points_np, Z_complex))
        
        # Compute gradient for each coordinate using finite differences
        # This is the bottleneck, but unavoidable given JAX-FEM's limitations
        for i in range(n_points):
            for d in range(n_dims):
                # Perturb this coordinate
                mesh_perturbed = mesh_points_np.copy()
                mesh_perturbed[i, d] += self.fd_epsilon
                
                # Compute loss with perturbation
                loss_perturbed = float(compute_acoustic_loss_jax(mesh_perturbed, Z_complex))
                
                # Finite difference gradient
                grad_mesh[i, d] = (loss_perturbed - base_loss) / self.fd_epsilon
        
        return grad_mesh, base_loss
    
    def __call__(self, mesh_points_torch, Z_torch):
        """
        Compute total loss and return it as PyTorch tensor.
        
        Args:
            mesh_points_torch: PyTorch tensor (N, 2)
            Z_torch: PyTorch tensor (2,) representing [real, imag]
        """
        # Convert to numpy/JAX for acoustic loss
        mesh_points_np = mesh_points_torch.detach().cpu().numpy()
        Z_complex = complex(Z_torch[0].item(), Z_torch[1].item())
        
        if self.use_jax_mesh_grad:
            # Attempt JAX autodiff for both mesh and Z (will likely fail)
            try:
                mesh_points_jax = jnp.array(mesh_points_np)
                acoustic_loss_jax, (grad_mesh_jax, grad_Z_jax) = self.jax_grad_fn_both(
                    mesh_points_jax, Z_complex
                )
                grad_mesh_np = np.array(grad_mesh_jax)
                acoustic_loss_val = float(acoustic_loss_jax)
                print("  SUCCESS: JAX autodiff worked for mesh! (unexpected)")
            except (NotImplementedError,
                    jax.errors.TracerArrayConversionError,
                    jax.errors.ConcretizationTypeError) as e:
                print(f"  ERROR: JAX autodiff failed as expected: {e}")
                print("  Falling back to finite differences...")
                # Fallback to finite differences
                grad_mesh_np, acoustic_loss_val = self.compute_mesh_gradient_fd(mesh_points_np, Z_complex)
                # Still need Z gradient
                _, grad_Z_jax = self.jax_grad_fn_Z(mesh_points_np, Z_complex)
        else:
            # Use finite differences for mesh (safe but slow)
            grad_mesh_np, acoustic_loss_val = self.compute_mesh_gradient_fd(mesh_points_np, Z_complex)
            
            # Use JAX autodiff for Z
            _, grad_Z_jax = self.jax_grad_fn_Z(mesh_points_np, Z_complex)
        
        # Convert acoustic loss to a constant PyTorch tensor; its gradients
        # (mesh and Z) are applied manually in the optimization loop
        acoustic_loss = torch.tensor(float(acoustic_loss_val), dtype=torch.float32)
        
        # Compute mesh regularization in PyTorch (with native gradients)
        mesh_loss, mesh_metrics = compute_mesh_regularization_torch(mesh_points_torch)
        
        # Combined loss
        total_loss = self.w_acoustic * acoustic_loss + self.w_mesh_reg * mesh_loss
        
        # Store gradients (we'll apply them manually)
        self.jax_grad_mesh = grad_mesh_np
        self.jax_grad_Z = grad_Z_jax
        
        return total_loss, {
            'acoustic': float(acoustic_loss_val),
            'mesh': mesh_loss.item(),
            'mesh_metrics': mesh_metrics
        }

print("Combined loss class defined")
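One caveat on `fd_epsilon=1e-5`: JAX computes in single precision unless 64-bit mode is enabled (`jax.config.update("jax_enable_x64", True)`), and whether the local `problems`/`losses` modules enable it is not shown here. When a loss value is large compared with the change caused by a 1e-5 perturbation, a float32 forward difference can lose all of its digits. The NumPy sketch below uses a hypothetical toy loss with a large constant offset to make the effect visible; whether this matters for the acoustic loss depends on its magnitude relative to its sensitivity.

```python
import numpy as np


def offset_loss(x, dtype):
    """Toy loss with a large offset; its exact derivative at x is 2 * x."""
    x = np.asarray(x, dtype=dtype)
    return np.asarray(1000.0, dtype=dtype) + x * x


def forward_diff(x, eps, dtype):
    """Forward difference evaluated entirely in the given floating-point dtype."""
    x = np.asarray(x, dtype=dtype)
    step = np.asarray(eps, dtype=dtype)
    return float((offset_loss(x + step, dtype) - offset_loss(x, dtype)) / step)


x0, eps = 0.5, 1e-5
exact = 2 * x0
err32 = abs(forward_diff(x0, eps, np.float32) - exact)
err64 = abs(forward_diff(x0, eps, np.float64) - exact)
print(f"float32 error: {err32:.3g}, float64 error: {err64:.3g}")
```

Here the change in the loss (about 1e-5) is smaller than the float32 spacing near 1000, so the float32 difference collapses to zero, while float64 recovers the derivative to about 1e-5.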

## 5. Optimization Loop with PyTorch

We use PyTorch optimizers but manually apply gradients from both JAX (impedance) and finite differences (mesh).
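Because the hand-assembled gradients go into `torch.optim.Adam`, their scale matters less than one might expect. Adam divides each coordinate's first-moment estimate by the square root of its second-moment estimate, so on the first step every coordinate with a nonzero gradient moves by roughly `lr`, whatever the gradient's magnitude. The NumPy sketch below re-implements a single bias-corrected Adam step with PyTorch's default hyperparameters (betas=(0.9, 0.999), eps=1e-8) to show this. It is an illustration, not PyTorch's code.

```python
import numpy as np


def adam_first_step(grad, lr, beta1=0.9, beta2=0.999, eps=1e-8):
    """Parameter update of the first Adam step (zero-initialised moments)."""
    m = (1 - beta1) * grad              # first moment after one step
    v = (1 - beta2) * grad ** 2         # second moment after one step
    m_hat = m / (1 - beta1)             # bias correction, t = 1
    v_hat = v / (1 - beta2)
    return -lr * m_hat / (np.sqrt(v_hat) + eps)


lr_mesh = 1e-4
grads = np.array([1e-6, 1e-2, 1e3])     # e.g. tiny FD noise vs. a strong signal
print(adam_first_step(grads, lr_mesh))  # each entry is close to -1e-4
```

With `lr_mesh=0.0001` this means a vertex whose finite-difference gradient is essentially noise can still move by about 1e-4 per iteration, so noisy mesh gradients translate directly into vertex jitter.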

In [None]:
def optimize_mesh_and_impedance(
    initial_points, cells, measurements,
    n_iterations=100,
    lr_mesh=0.001,
    lr_impedance=0.01,
    w_acoustic=1.0,
    w_mesh_reg=0.1,
    use_jax_mesh_grad=False
):
    """
    Joint optimization using PyTorch optimizers with hybrid gradients.
    
    Args:
        initial_points: Initial mesh vertex positions (N, 2)
        cells: Mesh connectivity (M, 3)
        measurements: Reference measurements for inverse problem
        n_iterations: Number of optimization iterations
        lr_mesh: Learning rate for mesh vertices
        lr_impedance: Learning rate for impedance parameter
        w_acoustic: Weight for acoustic loss
        w_mesh_reg: Weight for mesh regularization
        use_jax_mesh_grad: If True, attempt JAX autodiff for mesh (will fail with current JAX-FEM)
                          If False, use finite differences (slow but works)
    """
    # Initialize parameters as PyTorch tensors
    mesh_points = torch.tensor(
        initial_points, dtype=torch.float32, requires_grad=True
    )
    Z_params = torch.tensor(
        [1.0, 0.1], dtype=torch.float32, requires_grad=True  # [real, imag]
    )
    
    # Create optimizers
    optimizer_mesh = torch.optim.Adam([mesh_points], lr=lr_mesh)
    optimizer_Z = torch.optim.Adam([Z_params], lr=lr_impedance)
    
    # Create loss function
    loss_fn = CombinedLoss(
        w_acoustic=w_acoustic, 
        w_mesh_reg=w_mesh_reg,
        use_jax_mesh_grad=use_jax_mesh_grad
    )
    
    # History
    history = {
        'total_loss': [],
        'acoustic_loss': [],
        'mesh_loss': [],
        'Z_history': [],
        'mesh_history': []
    }
    
    print("Starting optimization...")
    print(f"Initial Z guess: {Z_params[0].item():.4f} + {Z_params[1].item():.4f}j")
    print(f"True Z: {Z_true}")
    print(f"Mesh gradient method: {'JAX autodiff (will likely fail)' if use_jax_mesh_grad else 'Finite differences'}")
    if not use_jax_mesh_grad:
        print("WARNING: This will be slow due to finite difference mesh gradients")
        print(f"  (~{2 * len(initial_points) + 1} FEM solves per iteration for the mesh gradient alone)")
    print()
    
    for i in range(n_iterations):
        # Zero gradients
        optimizer_mesh.zero_grad()
        optimizer_Z.zero_grad()
        
        # Compute loss (this computes gradients internally)
        print(f"Iter {i}: Computing loss and gradients...")
        total_loss, loss_dict = loss_fn(mesh_points, Z_params)
        
        # Backward pass for PyTorch components (mesh regularization)
        total_loss.backward()
        
        # Add finite difference gradients for mesh (acoustic part)
        grad_mesh_fd = torch.tensor(
            loss_fn.jax_grad_mesh, dtype=torch.float32
        )
        mesh_points.grad += w_acoustic * grad_mesh_fd
        
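        # For impedance: add JAX gradients.
        # For a real loss L and complex Z = x + iy, jax.grad returns dL/dx - i*dL/dy;
        # conjugating gives dL/dx + i*dL/dy, whose real and imaginary parts are the
        # gradients w.r.t. Z_params[0] and Z_params[1].
        grad_Z_complex = loss_fn.jax_grad_Z
        grad_Z_fixed = jnp.real(grad_Z_complex) - 1j * jnp.imag(grad_Z_complex)
        grad_Z_torch = torch.tensor(
            [float(jnp.real(grad_Z_fixed)), float(jnp.imag(grad_Z_fixed))],
            dtype=torch.float32
        )
        # Z_params is not in the PyTorch graph, so its .grad is normally None here;
        # apply the acoustic weight in both cases
        if Z_params.grad is None:
            Z_params.grad = w_acoustic * grad_Z_torch
        else:
            Z_params.grad = Z_params.grad + w_acoustic * grad_Z_torch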
        
        # Update parameters
        optimizer_mesh.step()
        optimizer_Z.step()
        
        # Store history
        Z_current = complex(Z_params[0].item(), Z_params[1].item())
        history['total_loss'].append(total_loss.item())
        history['acoustic_loss'].append(loss_dict['acoustic'])
        history['mesh_loss'].append(loss_dict['mesh'])
        history['Z_history'].append(Z_current)
        if i % 5 == 0:  # Store mesh less frequently to save memory
            history['mesh_history'].append(mesh_points.detach().cpu().numpy())
        
        # Print progress
        print(f"Iter {i:3d}: Loss={total_loss.item():.6f} "
              f"(acoustic={loss_dict['acoustic']:.6f}, mesh={loss_dict['mesh']:.6f}) "
              f"Z={Z_current:.4f}")
    
    print("\nOptimization complete!")
    print(f"Final Z: {Z_current}")
    print(f"True Z:  {Z_true}")
    print(f"Error: {np.abs(Z_current - Z_true):.6f}")
    
    return mesh_points.detach().cpu().numpy(), Z_current, history

print("Optimization function defined")
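The conjugation in the impedance update relies on JAX's convention for real-valued functions of a complex argument, which, as we understand the complex-differentiation section of the JAX Autodiff Cookbook, returns ∂L/∂x - i ∂L/∂y for Z = x + iy. The NumPy sketch below checks the resulting mapping onto `Z_params = [Re Z, Im Z]` against central finite differences. `jax_style_grad` is a hypothetical helper that produces the JAX-convention value analytically for a toy loss, so JAX itself is not needed.

```python
import numpy as np

Z_TRUE = 1.5 + 0.3j


def loss(z):
    """Toy real-valued loss |z - Z_TRUE|^2 (a stand-in for the acoustic loss)."""
    return abs(z - Z_TRUE) ** 2


def jax_style_grad(z):
    """Hypothetical helper: dL/dx - i dL/dy for the toy loss (JAX's convention)."""
    d = z - Z_TRUE
    return 2 * d.real - 2j * d.imag


def to_real_pair(g):
    """Conjugate a JAX-convention gradient and split it into [dL/dx, dL/dy]."""
    g = np.conj(g)
    return np.array([g.real, g.imag])


z0, h = 1.0 + 0.1j, 1e-6
fd = np.array([
    (loss(z0 + h) - loss(z0 - h)) / (2 * h),            # dL/dx (central difference)
    (loss(z0 + 1j * h) - loss(z0 - 1j * h)) / (2 * h),  # dL/dy
])
print(to_real_pair(jax_style_grad(z0)), fd)
```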

In [None]:
# Run optimization
# 
# Option 1: Use finite differences for mesh gradients (SLOW but WORKS)
# Option 2: Use JAX autodiff for mesh gradients (FAST but FAILS with current JAX-FEM)
#
# Set use_jax_mesh_grad=True to see the error and attempt JAX autodiff
# Set use_jax_mesh_grad=False to use finite differences (default, recommended)

USE_JAX_MESH_GRAD = False  # Change to True to attempt JAX autodiff (will fail)

if USE_JAX_MESH_GRAD:
    print("=" * 70)
    print("ATTEMPTING JAX AUTODIFF FOR MESH (will likely fail)")
    print("=" * 70)
    print()

optimized_mesh, optimized_Z, history = optimize_mesh_and_impedance(
    initial_points, cells, measurements,
    n_iterations=10,  # Reduced for speed - increase for better results
    lr_mesh=0.0001,   # Small learning rate for stability
    lr_impedance=0.01,
    w_acoustic=1.0,
    w_mesh_reg=0.05,
    use_jax_mesh_grad=USE_JAX_MESH_GRAD  # Set to True to attempt JAX autodiff
)

## 6. Visualize Results

In [None]:
# Plot loss curves
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Total loss
axes[0, 0].semilogy(history['total_loss'])
axes[0, 0].set_xlabel('Iteration')
axes[0, 0].set_ylabel('Total Loss')
axes[0, 0].set_title('Total Loss')
axes[0, 0].grid(True)

# Acoustic vs Mesh loss
axes[0, 1].semilogy(history['acoustic_loss'], label='Acoustic')
axes[0, 1].semilogy(history['mesh_loss'], label='Mesh Reg')
axes[0, 1].set_xlabel('Iteration')
axes[0, 1].set_ylabel('Loss')
axes[0, 1].set_title('Individual Losses')
axes[0, 1].legend()
axes[0, 1].grid(True)

# Impedance convergence - Real part
Z_real = [np.real(z) for z in history['Z_history']]
axes[1, 0].plot(Z_real, label='Estimated')
axes[1, 0].axhline(y=np.real(Z_true), color='r', linestyle='--', label='True')
axes[1, 0].set_xlabel('Iteration')
axes[1, 0].set_ylabel('Re(Z)')
axes[1, 0].set_title('Impedance - Real Part')
axes[1, 0].legend()
axes[1, 0].grid(True)

# Impedance convergence - Imaginary part
Z_imag = [np.imag(z) for z in history['Z_history']]
axes[1, 1].plot(Z_imag, label='Estimated')
axes[1, 1].axhline(y=np.imag(Z_true), color='r', linestyle='--', label='True')
axes[1, 1].set_xlabel('Iteration')
axes[1, 1].set_ylabel('Im(Z)')
axes[1, 1].set_title('Impedance - Imaginary Part')
axes[1, 1].legend()
axes[1, 1].grid(True)

plt.tight_layout()
plt.show()

In [None]:
# Compare initial vs optimized mesh
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Initial mesh
axes[0].triplot(initial_points[:, 0], initial_points[:, 1], cells, 'b-', linewidth=0.5)
axes[0].set_aspect('equal')
axes[0].set_title('Initial Mesh')
axes[0].set_xlabel('x')
axes[0].set_ylabel('y')

# Optimized mesh
axes[1].triplot(optimized_mesh[:, 0], optimized_mesh[:, 1], cells, 'r-', linewidth=0.5)
axes[1].set_aspect('equal')
axes[1].set_title('Optimized Mesh')
axes[1].set_xlabel('x')
axes[1].set_ylabel('y')

plt.tight_layout()
plt.show()

In [None]:
# Visualize mesh evolution
n_snapshots = len(history['mesh_history'])
fig, axes = plt.subplots(1, min(4, n_snapshots), figsize=(16, 4))

snapshot_indices = np.linspace(0, n_snapshots-1, min(4, n_snapshots), dtype=int)

for idx, snap_idx in enumerate(snapshot_indices):
    ax = axes[idx] if n_snapshots > 1 else axes
    mesh_snap = history['mesh_history'][snap_idx]
    ax.triplot(mesh_snap[:, 0], mesh_snap[:, 1], cells, 'k-', linewidth=0.5)
    ax.set_aspect('equal')
    ax.set_title(f'After iteration {snap_idx * 5}')  # snapshots are stored after each update
    ax.set_xlabel('x')
    ax.set_ylabel('y')

plt.tight_layout()
plt.show()

In [None]:
# Compute and visualize displacement field
displacement = optimized_mesh - initial_points
displacement_mag = np.linalg.norm(displacement, axis=1)

fig, ax = plt.subplots(figsize=(10, 8))
scatter = ax.scatter(
    initial_points[:, 0], initial_points[:, 1],
    c=displacement_mag, cmap='viridis', s=50
)
ax.quiver(
    initial_points[:, 0], initial_points[:, 1],
    displacement[:, 0], displacement[:, 1],
    scale=0.1, alpha=0.5
)
plt.colorbar(scatter, ax=ax, label='Displacement magnitude')
ax.set_aspect('equal')
ax.set_title('Vertex Displacement Field')
ax.set_xlabel('x')
ax.set_ylabel('y')
plt.show()

print(f"Max displacement: {displacement_mag.max():.6f}")
print(f"Mean displacement: {displacement_mag.mean():.6f}")

## 7. Visualize Acoustic Solution

In [None]:
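# Solve on the optimized mesh with the optimized impedance
final_mesh = optimized_mesh  # vertex positions returned by optimize_mesh_and_impedance
problem_opt = create_acoustic_problem(
    final_mesh, cells, k, source_params, location_fns, ele_type
)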
fwd_opt = ad_wrapper(problem_opt)
solution_opt = fwd_opt(optimized_Z)

# Extract pressure field
pressure_opt = solution_opt[0][:, 0]
pressure_ref = measurements[0][:, 0]

# Visualize
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Reference solution
sc0 = axes[0].tripcolor(
    initial_points[:, 0], initial_points[:, 1], cells,
    np.abs(pressure_ref), shading='gouraud', cmap='viridis'
)
axes[0].set_aspect('equal')
axes[0].set_title('Reference |p|')
plt.colorbar(sc0, ax=axes[0])

# Optimized solution
sc1 = axes[1].tripcolor(
    final_mesh[:, 0], final_mesh[:, 1], cells,
    np.abs(pressure_opt), shading='gouraud', cmap='viridis'
)
axes[1].set_aspect('equal')
axes[1].set_title('Optimized |p|')
plt.colorbar(sc1, ax=axes[1])

# Error
error = np.abs(pressure_opt - pressure_ref)
sc2 = axes[2].tripcolor(
    final_mesh[:, 0], final_mesh[:, 1], cells,
    error, shading='gouraud', cmap='hot'
)
axes[2].set_aspect('equal')
axes[2].set_title('Absolute Error')
plt.colorbar(sc2, ax=axes[2])

plt.tight_layout()
plt.show()

print(f"Max error: {error.max():.6e}")
print(f"Mean error: {error.mean():.6e}")
print(f"Relative error: {error.mean() / np.abs(pressure_ref).mean():.6%}")

## Summary

This notebook demonstrates:

1. **Hybrid optimization approach** for acoustic mesh optimization:
   - **Finite differences** for mesh gradients (default, works but slow)
   - **JAX autodiff** for impedance gradients (fast)
   - **PyTorch3D** for mesh regularization losses
   - **PyTorch optimizers** for parameter updates

2. **Configurable gradient computation**:
   - Set `use_jax_mesh_grad=False` (default): Use finite differences for mesh
   - Set `use_jax_mesh_grad=True`: Attempt JAX autodiff (will fail with current JAX-FEM)

3. **Key limitation**: Mesh gradient computation via FD is slow (~2N forward passes per iteration where N=number of mesh points). For the square mesh with ~1100 points, this is ~2200 FEM solves per iteration.

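4. **Why JAX autodiff fails**: JAX-FEM doesn't support autodiff through mesh coordinates because:
   - Problem initialization calls `numpy.take`, whose default is `mode='raise'`
   - When the mesh coordinates are JAX tracers, NumPy forwards that call to the tracer's `take` method (i.e. `jnp.take`), which does not support `mode='raise'` and raises `NotImplementedError`
   - Finite differences avoid this because every forward solve receives concrete NumPy coordinates, so the mesh is never traced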

5. **Better alternatives** for large-scale mesh optimization:
   - Parametrize mesh deformation with low-dimensional parameters (e.g., RBF deformation)
   - Use a coarser mesh for optimization
   - Use the Tesseract approach (separate mesh deformation module with finite difference Jacobian)
   - Wait for JAX-FEM to support mesh differentiation
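As an illustration of the first alternative, a Gaussian radial-basis-function (RBF) deformation drives all N vertices from K control-point displacements, so a forward-difference gradient needs 2K + 1 solves instead of 2N + 1. The NumPy sketch below is one possible parametrization under our own assumptions (Gaussian kernel, width `sigma`, hypothetical function name `rbf_deform`, control points on the boundary of [-1, 1]²). It is not part of JAX-FEM or Tesseract.

```python
import numpy as np


def gaussian_kernel(a, b, sigma):
    """Matrix of exp(-|a_i - b_j|^2 / sigma^2) for point sets a (P, 2) and b (Q, 2)."""
    d2 = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return np.exp(-d2 / sigma ** 2)


def rbf_deform(points, controls, control_disp, sigma=0.5):
    """Displace all mesh points so that each control point moves by control_disp."""
    phi_cc = gaussian_kernel(controls, controls, sigma)   # (K, K), symmetric positive definite
    weights = np.linalg.solve(phi_cc, control_disp)       # (K, 2) interpolation weights
    return points + gaussian_kernel(points, controls, sigma) @ weights


# 8 control points on the boundary of the [-1, 1]^2 square
controls = np.array([[-1, -1], [0, -1], [1, -1], [1, 0],
                     [1, 1], [0, 1], [-1, 1], [-1, 0]], dtype=float)
disp = np.zeros_like(controls)
disp[3] = [0.1, 0.0]                       # push the middle of the right edge outward

grid = np.linspace(-1, 1, 11)
pts = np.stack(np.meshgrid(grid, grid), axis=-1).reshape(-1, 2)   # 121 vertices
new_pts = rbf_deform(np.vstack([controls, pts]), controls, disp)
print(new_pts[:8] - controls)              # reproduces disp at the control points
```

Here an optimizer would update the 16 entries of `disp` rather than 2 × 121 vertex coordinates, and the mesh connectivity stays unchanged. Choosing the control points and `sigma`, and keeping elements from inverting, is left open.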

The current implementation works but is computationally expensive. It successfully demonstrates joint optimization of mesh geometry and acoustic impedance using PyTorch3D for mesh quality.