# Inverse Design Workflow: Grating Coupler

This notebook demonstrates gradient-based inverse design of a silicon-on-insulator (SOI) grating coupler using the Hyperwave cloud API. The optimization uses the adjoint method to compute gradients efficiently, running forward and adjoint FDTD simulations on cloud GPUs.

**Steps:**
1. Set physical parameters
2. Define grid and layer stack
3. Create initial design
4. Configure Gaussian source
5. Solve waveguide mode
6. Configure API and run a forward simulation
7. Set up optimization
8. Estimate cost
9. Run optimization
10. Analyze results
11. Verify optimized design

## Overview

**Device:** Silicon grating coupler for fiber-to-chip coupling

**Method:** Gradient-based inverse design using the adjoint method (forward + adjoint FDTD on cloud GPU)

**Time:** 15-20 min (5-step demo) or 1-2 hours (50-step production)

**Cost:** ~0.30 credits (5-step demo) or ~3.0 credits (50-step) on B200 GPU

All structure setup, source generation, and analysis run locally.
Only the simulation and optimization steps use cloud GPU credits.

## Installation

In [None]:
# Install hyperwave-community
%pip install "hyperwave-community @ git+https://github.com/spinsphotonics/hyperwave-community.git" -q

In [None]:
import hyperwave_community as hwc

## Step 1: Physical Parameters

Standard 220 nm SOI at 1550 nm with a partial-etch grating coupler (110 nm etch depth).

In [None]:
import math
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

# Materials
n_si = 3.48
n_sio2 = 1.44
n_clad = 1.44
n_air = 1.0

# Wavelength
wavelength_um = 1.55

# Layer thicknesses (um)
h_dev = 0.220        # total silicon device layer
etch_depth = 0.110   # partial etch depth
h_box = 2.0          # buried oxide
h_clad = 0.78        # SiO2 cladding
h_sub = 0.8          # silicon substrate
h_air = 1.0          # air above cladding
pad = 3.0            # absorber padding (top and bottom)

# Grid resolution
dx = 0.035           # 35nm structure grid
pixel_size = dx / 2  # 17.5nm theta grid (2x for subpixel averaging in epsilon())
domain = 40.0        # um total domain

# Waveguide
wg_width = 0.5       # um
wg_length = 5.0      # um

# Fiber
beam_waist = 5.2     # um (SMF-28 mode field radius at 1550nm)
fiber_angle = 14.5   # degrees from vertical (standard SMF coupling angle)

# Structure grid dimensions (35nm)
Lx = int(domain / dx)
Ly = Lx

# Theta grid dimensions (17.5nm, 2x structure)
theta_Lx = 2 * Lx
theta_Ly = 2 * Ly

# Layer thicknesses in pixels (FLOAT for subpixel averaging, critical for
# thin layers like etch where rounding shifts interfaces by up to 0.5 cells)
h_p = pad / dx
h0 = h_air / dx
h1 = h_clad / dx
h2 = etch_depth / dx
h3 = (h_dev - etch_depth) / dx
h4 = h_box / dx
h5 = h_sub / dx
Lz = int(math.ceil(h_p + h0 + h1 + h2 + h3 + h4 + h5 + h_p))

# Key Z positions (rounded to nearest pixel for array indexing)
z_etch = int(round(h_p + h0 + h1))
z_slab = z_etch + int(round(h2))
z_box = z_slab + int(round(h3))

# Frequency
wl_px = wavelength_um / dx
freq = 2 * np.pi / wl_px
freq_band = (freq, freq, 1)

# Permittivity
eps_si = n_si**2
eps_sio2 = n_sio2**2
eps_clad = n_clad**2
eps_air = n_air**2

# Density filtering (matches reference GC design)
DENSITY_RADIUS = 3
DENSITY_ALPHA = 0.8
DESIGN_LAYER = 3  # etch layer index

print(f"Structure grid: {Lx} x {Ly} x {Lz} ({dx * 1000:.0f} nm)")
print(f"Theta grid: {theta_Lx} x {theta_Ly} ({pixel_size * 1000:.1f} nm)")
print(f"Layers (px): pad={h_p:.2f} air={h0:.2f} clad={h1:.2f} "
      f"etch={h2:.2f} slab={h3:.2f} BOX={h4:.2f} sub={h5:.2f} pad={h_p:.2f}")
print(f"Fiber angle: {fiber_angle} deg")
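
The comment about float thicknesses can be made concrete with a little arithmetic. The sketch below uses only the Step 1 values (`dx` and the layer thicknesses). The helper `rounding_error_nm` is a hypothetical name introduced for illustration and is not part of `hyperwave_community`. It shows how far each thickness would drift if it were rounded to whole cells.

```python
dx = 0.035  # um, structure grid pitch from Step 1


def rounding_error_nm(thickness_um, dx_um=dx):
    """Thickness error (nm) introduced by rounding a layer to whole cells."""
    t_px = thickness_um / dx_um          # float thickness in cells
    return (round(t_px) - t_px) * dx_um * 1000.0


# Layer thicknesses (um) copied from Step 1
stack_um = {"clad": 0.78, "etch": 0.110, "slab": 0.110, "box": 2.0, "sub": 0.8}

for name, t_um in stack_um.items():
    print(f"{name:5s} {t_um / dx:7.3f} px  rounding error {rounding_error_nm(t_um):+6.1f} nm")
```

With these values the 110 nm etch would shrink to 105 nm if rounded to 3 cells, which is the kind of interface shift the float thicknesses are meant to avoid.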

## Step 2: Grid and Layer Stack

8-layer SOI stack. The etch layer is controlled by theta with density filtering
(`radius=3`, `alpha=0.8`) for minimum feature size control.

Layers (top to bottom): pad (air, 3um) | air (1um) | cladding (SiO2, 0.78um) |
**etch (SiO2/Si, 0.11um)** | slab (Si, 0.11um) | BOX (SiO2, 2um) |
substrate (Si, 0.8um) | pad (Si, 3um)
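
As a cross-check of the stack listed above, the following sketch accumulates the layer thicknesses from top to bottom and converts two of the results to grid cells. It is plain arithmetic on the Step 1 numbers; the dictionary `interfaces` is an illustrative name, not something the library provides.

```python
import math

dx = 0.035  # um, structure grid pitch from Step 1

# (name, thickness in um), top to bottom, as listed above
stack = [
    ("pad (air)", 3.0), ("air", 1.0), ("cladding", 0.78), ("etch", 0.110),
    ("slab", 0.110), ("BOX", 2.0), ("substrate", 0.8), ("pad (Si)", 3.0),
]

z_top_um = 0.0
interfaces = {}  # layer name -> (top, bottom) in um, measured from the top of the domain
for name, t_um in stack:
    interfaces[name] = (z_top_um, z_top_um + t_um)
    z_top_um += t_um

total_um = z_top_um
Lz = math.ceil(total_um / dx)                 # number of z cells
z_etch = round(interfaces["etch"][0] / dx)    # first cell of the etch layer

for name, (top, bottom) in interfaces.items():
    print(f"{name:10s} {top:6.2f} -> {bottom:6.2f} um")
print(f"total = {total_um:.2f} um, Lz = {Lz}, z_etch = {z_etch}")
```

Note that z is counted from the top of the domain, so array index 0 is the upper air pad and the silicon substrate sits at the high-index end.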

In [None]:
# Waveguide in structure pixels (35nm grid)
wg_len = int(round(wg_length / dx))
wg_hw = int(round(wg_width / 2 / dx))

# Waveguide in theta pixels (17.5nm grid, 2x structure)
wg_len_theta = int(round(wg_length / pixel_size))
wg_hw_theta = int(round(wg_width / 2 / pixel_size))

# Rectangular design region (in theta pixels)
abs_margin = 80  # structure pixels
abs_margin_theta = 2 * abs_margin
design_region = {
    'x_start': wg_len_theta,
    'x_end': theta_Lx - abs_margin_theta,
    'y_start': abs_margin_theta,
    'y_end': theta_Ly - abs_margin_theta,
}

# Build initial theta at 2x resolution (17.5nm)
# theta=1 means unetched silicon, theta=0 means etched to cladding
theta_init = np.zeros((theta_Lx, theta_Ly), dtype=np.float32)
theta_init[:wg_len_theta, theta_Ly // 2 - wg_hw_theta : theta_Ly // 2 + wg_hw_theta] = 1.0
dr = design_region
theta_init[dr['x_start']:dr['x_end'], dr['y_start']:dr['y_end']] = 0.5

# Build 3D structure from initial theta
density_etch = hwc.density(jnp.array(theta_init), radius=DENSITY_RADIUS, alpha=DENSITY_ALPHA)
zero = jnp.zeros((theta_Lx, theta_Ly))
ones = jnp.ones((theta_Lx, theta_Ly))

# hwc.Layer(density_pattern, permittivity_values, layer_thickness)
# Float thicknesses enable subpixel averaging at layer interfaces.
# Density patterns are at 2x (theta) resolution; create_structure downsamples
# to structure resolution via epsilon() with magnification=1.
layers = [
    hwc.Layer(zero,         eps_air,            h_p),  # pad (top)
    hwc.Layer(zero,         eps_air,            h0),   # air
    hwc.Layer(zero,         eps_clad,           h1),   # cladding
    hwc.Layer(density_etch, (eps_clad, eps_si), h2),   # etch (designable)
    hwc.Layer(ones,         eps_si,             h3),   # slab (solid Si)
    hwc.Layer(zero,         eps_sio2,           h4),   # BOX
    hwc.Layer(zero,         eps_si,             h5),   # substrate
    hwc.Layer(zero,         eps_si,             h_p),  # pad (bottom)
]
structure = hwc.create_structure(layers=layers, vertical_radius=0)
eps_3d = np.array(structure.permittivity[0])

print(f"Structure shape: {eps_3d.shape}")
print(f"Theta shape: {theta_init.shape} ({pixel_size * 1000:.1f} nm)")

## Step 3: Initial Design

`theta` controls the etch layer (top 110nm of 220nm Si):
`theta=1` = unetched silicon, `theta=0` = etched to cladding.

Theta operates at 2x the structure resolution (17.5nm vs 35nm). The `create_structure`
function downsamples via subpixel averaging in the `epsilon()` function, giving smoother
gradients and finer geometric control.
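
To illustrate why a 2x theta grid helps, here is a minimal sketch of 2x2 block averaging in NumPy. It is a simplified stand-in for the subpixel averaging described above, not the library's `epsilon()` implementation, and `downsample_2x` is a hypothetical helper name.

```python
import numpy as np


def downsample_2x(fine):
    """Average each 2x2 block of a fine grid into one coarse cell."""
    nx, ny = fine.shape
    return fine.reshape(nx // 2, 2, ny // 2, 2).mean(axis=(1, 3))


# 4x4 theta patch whose silicon edge sits in the middle of a coarse cell
theta_fine = np.zeros((4, 4), dtype=np.float32)
theta_fine[:, :3] = 1.0

coarse = downsample_2x(theta_fine)
print(coarse)
```

The coarse cell that straddles the edge takes the intermediate value 0.5, so moving the edge by one fine pixel changes the coarse permittivity smoothly instead of flipping a whole cell.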

Fixed waveguide on the left, rectangular design region initialized to 0.5.

In [None]:
# Plot initial theta and permittivity cross-sections
fig, axes = plt.subplots(1, 4, figsize=(22, 5))

# Theta
extent_theta = [0, theta_Lx * pixel_size, 0, theta_Ly * pixel_size]
axes[0].imshow(theta_init.T, origin='lower', cmap='PuOr', vmin=0, vmax=1, extent=extent_theta)
rect = Rectangle(
    (dr['x_start'] * pixel_size, dr['y_start'] * pixel_size),
    (dr['x_end'] - dr['x_start']) * pixel_size,
    (dr['y_end'] - dr['y_start']) * pixel_size,
    linewidth=2, edgecolor='lime', facecolor='none', linestyle='--',
)
axes[0].add_patch(rect)
axes[0].set_xlabel('x (um)')
axes[0].set_ylabel('y (um)')
axes[0].set_title('Initial Theta')
plt.colorbar(axes[0].images[0], ax=axes[0], label='theta')

# Permittivity XY at mid-height of the etch layer
z_dev = z_etch + int(h2 // 2)
im1 = axes[1].imshow(eps_3d[:, :, z_dev].T, origin='lower', cmap='viridis',
                      extent=[0, Lx * dx, 0, Ly * dx])
axes[1].set_xlabel('x (um)')
axes[1].set_ylabel('y (um)')
axes[1].set_title(f'Permittivity XY (z = {z_dev * dx:.2f} um)')
plt.colorbar(im1, ax=axes[1])

# Permittivity XZ at y=center
im2 = axes[2].imshow(eps_3d[:, Ly // 2, :].T, origin='lower', cmap='viridis',
                      aspect='auto', extent=[0, Lx * dx, 0, Lz * dx])
axes[2].axhline(z_etch * dx, color='r', ls='--', lw=0.8)
axes[2].axhline(z_slab * dx, color='r', ls='--', lw=0.8)
axes[2].set_xlabel('x (um)')
axes[2].set_ylabel('z (um)')
axes[2].set_title('Permittivity XZ (y = center)')
plt.colorbar(im2, ax=axes[2])

# Permittivity YZ at x=center
im3 = axes[3].imshow(eps_3d[Lx // 2, :, :].T, origin='lower', cmap='viridis',
                      aspect='auto', extent=[0, Ly * dx, 0, Lz * dx])
axes[3].axhline(z_etch * dx, color='r', ls='--', lw=0.8)
axes[3].axhline(z_slab * dx, color='r', ls='--', lw=0.8)
axes[3].set_xlabel('y (um)')
axes[3].set_ylabel('z (um)')
axes[3].set_title('Permittivity YZ (x = center)')
plt.colorbar(im3, ax=axes[3])

plt.tight_layout()
plt.show()

print(f"Design region: {(dr['x_end'] - dr['x_start']) * pixel_size:.1f} x "
      f"{(dr['y_end'] - dr['y_start']) * pixel_size:.1f} um")

## Step 4: Gaussian Source

The source is a unidirectional Gaussian beam generated with the wave equation error
method. It propagates downward (-z) at 14.5 degrees from vertical, simulating fiber
illumination from above at the standard SMF coupling angle. A negative `theta` tilts
the beam toward the waveguide (in the -x direction).
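
The effect of the tilt sign can be seen in a simplified scalar picture. The sketch below evaluates a Gaussian envelope with a linear phase ramp on the source plane in air. It is only a paraxial illustration under an assumed `exp(i(k·r - wt))` convention, not the method used by `hwc.create_gaussian_source`, and `tilted_gaussian` is a hypothetical helper.

```python
import cmath
import math

wavelength_um = 1.55
beam_waist_um = 5.2
fiber_angle_deg = 14.5

k0 = 2 * math.pi / wavelength_um                     # free-space wavenumber (rad/um)
kx = k0 * math.sin(math.radians(-fiber_angle_deg))   # in-plane wavevector for theta = -14.5 deg


def tilted_gaussian(x_um, y_um):
    """Scalar field on the source plane, centered at the origin."""
    envelope = math.exp(-(x_um**2 + y_um**2) / beam_waist_um**2)
    return envelope * cmath.exp(1j * kx * x_um)


print(f"kx = {kx:.4f} rad/um (negative: in-plane propagation toward -x)")
print(f"|E| at one waist radius: {abs(tilted_gaussian(beam_waist_um, 0.0)):.4f}")
```

The amplitude falls to 1/e at one waist radius, and the negative `kx` corresponds to an in-plane component pointing toward the waveguide on the left.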

In [None]:
# Source position: in the air gap, 50nm above cladding surface
source_above_surface_um = 0.05
source_z = int(round((pad + h_air - source_above_surface_um) / dx))

# Grating center in structure pixels (convert from theta pixel design region)
grating_x = int(round((dr['x_start'] + dr['x_end']) / 2 * pixel_size / dx))
grating_y = Ly // 2
waist_px = beam_waist / dx

source_field, input_power = hwc.create_gaussian_source(
    sim_shape=(Lx, Ly, Lz),
    frequencies=jnp.array([freq]),
    source_pos=(grating_x, grating_y, source_z),
    waist_radius=waist_px,
    theta=-fiber_angle,  # negative tilts beam toward waveguide (-x direction)
    phi=0.0,             # tilt in XZ plane
    polarization='y',
    x_span=float(Lx),
    y_span=float(Ly),
    max_steps=5000,
    check_every_n=200,
    show_plots=False,
)

source_offset = (grating_x - Lx // 2, grating_y - Ly // 2, source_z)
input_power = float(jnp.mean(input_power))

# Plot source |Ey| and |Hx|
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
for ax, idx, name in [(axes[0], 1, '|Ey|'), (axes[1], 3, '|Hx|')]:
    im = ax.imshow(np.abs(np.array(source_field[0, idx, :, :, 0])).T,
                   origin='lower', cmap='hot', extent=[0, Lx * dx, 0, Ly * dx])
    ax.set_xlabel('x (um)')
    ax.set_ylabel('y (um)')
    ax.set_title(f'Source {name}')
    plt.colorbar(im, ax=ax)
plt.suptitle(f'Gaussian source (waist = {beam_waist} um, tilt = {fiber_angle} deg)', fontsize=13)
plt.tight_layout()
plt.show()

print(f"Source shape: {source_field.shape}")
print(f"Source offset: {source_offset}")
print(f"Input power: {input_power:.6f}")

## Step 5: Waveguide Mode

Solve for the fundamental TE eigenmode, then use `mode_converter` to get
the full E+H field (needed for mode overlap loss computation).

In [None]:
from hyperwave_community.sources import mode_converter

# Solve eigenmode at the waveguide center
mode_x_pos = wg_len // 2

source_mode, offset_mode, mode_info = hwc.create_mode_source(
    structure=structure,
    freq_band=freq_band,
    mode_num=0,
    propagation_axis='x',
    source_position=mode_x_pos,
    perpendicular_bounds=(Ly // 2 - 50, Ly // 2 + 50),
    z_bounds=(z_etch - 10, z_box + 10),
    visualize=False,
)

n_eff = float(mode_info['beta'][0]) / (2 * np.pi / wl_px)

# Convert E-only mode to full E+H via short waveguide propagation
eps_slice = structure.permittivity[:, mode_x_pos, :, :]
mode_E = mode_info['field']  # (1, 3, 1, y_crop, z_crop)

mode_full = mode_converter(
    mode_E_field=mode_E,
    freq_band=freq_band,
    permittivity_slice=eps_slice,
    propagation_axis='x',
    visualize=False,
)

# Compute P_mode_cross for mode overlap normalization
mode_np = np.array(mode_full)
E_mode = mode_np[0, 0:3, 0]  # (3, y_crop, z_crop)
H_mode = mode_np[0, 3:6, 0]  # (3, y_crop, z_crop)
cross = np.cross(E_mode, np.conj(H_mode), axis=0)
P_mode_cross = float(np.abs(np.real(np.sum(cross[0, :, :]))))

# Plot E and H components
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
mode_E_np = np.array(mode_E)
mode_full_np = np.array(mode_full)

for i, name in enumerate(['Ex', 'Ey', 'Ez']):
    im = axes[0, i].imshow(np.abs(mode_E_np[0, i, 0]).T, origin='lower',
                            cmap='viridis', aspect='auto')
    axes[0, i].set_title(f'|{name}|')
    axes[0, i].set_xlabel('y (px)')
    axes[0, i].set_ylabel('z (px)')
    plt.colorbar(im, ax=axes[0, i])

for i, name in enumerate(['Hx', 'Hy', 'Hz']):
    im = axes[1, i].imshow(np.abs(mode_full_np[0, 3 + i, 0]).T, origin='lower',
                            cmap='viridis', aspect='auto')
    axes[1, i].set_title(f'|{name}|')
    axes[1, i].set_xlabel('y (px)')
    axes[1, i].set_ylabel('z (px)')
    plt.colorbar(im, ax=axes[1, i])

plt.suptitle(f'Waveguide mode (n_eff = {n_eff:.4f})', fontsize=14)
plt.tight_layout()
plt.show()

print(f"n_eff = {n_eff:.4f}")
print(f"P_mode_cross = {P_mode_cross:.6f}")
print(f"Mode solve error: {float(mode_info['error'][0]):.2e}")
print(f"Mode full shape: {mode_full.shape}")
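
The `n_eff` value printed above is a unit conversion: the propagation constant is divided by the free-space wavenumber, both expressed per grid cell. A minimal sketch of that conversion follows; `beta_example` is a made-up number for illustration, not a solver output.

```python
import math

wavelength_um = 1.55
dx = 0.035                        # um per cell

wl_px = wavelength_um / dx        # wavelength in cells
k0_px = 2 * math.pi / wl_px       # free-space wavenumber in rad per cell


def n_eff_from_beta(beta_px):
    """Effective index from a propagation constant given in rad per cell."""
    return beta_px / k0_px


beta_example = 0.36               # hypothetical value, not a solver output
print(f"k0 = {k0_px:.5f} rad/cell")
print(f"beta = {beta_example} rad/cell -> n_eff = {n_eff_from_beta(beta_example):.3f}")
print(f"beta = {beta_example / dx:.3f} rad/um")
```

As a sanity check, a guided mode should give an effective index between the cladding index (1.44) and the silicon index (3.48).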

## Step 6: Forward Simulation (Verify Setup)

Run a forward simulation with the initial design to verify:
- Source illuminates the design region correctly
- Light propagates toward the waveguide
- Layer stack and monitors are correctly placed

This costs one GPU simulation but catches setup errors before the more expensive optimization run.

### API Key Setup

This notebook uses Colab Secrets to securely store your API key.
To set up your key:

1. Click the key icon in the left sidebar of this notebook
2. Add a new secret named `HYPERWAVE_API_KEY`
3. Paste your API key as the value
4. Toggle "Notebook access" to ON

If you don't have an API key, [sign up](https://spinsphotonics.com/signup) to get one for free.

In [None]:
from google.colab import userdata
hwc.configure_api(api_key=userdata.get('HYPERWAVE_API_KEY'))
hwc.get_account_info();

In [None]:
# Absorber (auto-scaled for resolution)
abs_params = hwc.get_optimized_absorber_params(
    resolution_nm=dx * 1000,
    wavelength_um=wavelength_um,
    structure_dimensions=(Lx, Ly, Lz),
)
abs_widths = abs_params['absorption_widths']
abs_coeff = abs_params['absorber_coeff']

# Set up monitors for field extraction
monitors = hwc.MonitorSet()
output_x = abs_widths[0] + 10

# XY slice at device layer
monitors.add(hwc.Monitor(shape=(Lx, Ly, 1), offset=(0, 0, z_dev)), name='xy_device')
# XZ slice at y=center
monitors.add(hwc.Monitor(shape=(Lx, 1, Lz), offset=(0, Ly // 2, 0)), name='xz_center')
# YZ slice at x=center
monitors.add(hwc.Monitor(shape=(1, Ly, Lz), offset=(Lx // 2, 0, 0)), name='yz_center')
# Waveguide output
monitors.add(hwc.Monitor(shape=(1, Ly, Lz), offset=(output_x, 0, 0)), name='wg_output')

# Run forward simulation
recipe = structure.extract_recipe()
fwd_results = hwc.simulate(
    structure_recipe=recipe,
    source_field=source_field,
    source_offset=source_offset,
    freq_band=freq_band,
    monitors_recipe=monitors.recipe,
    absorption_widths=abs_widths,
    absorption_coeff=abs_coeff,
)

print(f"Forward sim complete: {fwd_results['sim_time']:.1f}s GPU time")

In [None]:
# Extract fields and plot |E|^2 cross-sections
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

for ax, name, title, ext, xlabel, ylabel in [
    (axes[0], 'xy_device', f'|E|^2 XY (z = {z_dev * dx:.2f} um)',
     [0, Lx * dx, 0, Ly * dx], 'x (um)', 'y (um)'),
    (axes[1], 'xz_center', '|E|^2 XZ (y = center)',
     [0, Lx * dx, 0, Lz * dx], 'x (um)', 'z (um)'),
    (axes[2], 'yz_center', '|E|^2 YZ (x = center)',
     [0, Ly * dx, 0, Lz * dx], 'y (um)', 'z (um)'),
]:
    field = np.array(fwd_results['monitor_data'][name])
    E2 = np.sum(np.abs(field[0, 0:3])**2, axis=0).squeeze()
    vmax = np.percentile(E2, 95) * 4
    im = ax.imshow(E2.T, origin='lower', cmap='hot', extent=ext,
                   aspect='auto', vmax=vmax)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    plt.colorbar(im, ax=ax, extend='max')

plt.suptitle('Forward simulation (initial design)', fontsize=14)
plt.tight_layout()
plt.show()

# Check power at waveguide output
wg_field = np.array(fwd_results['monitor_data']['wg_output'])
S = hwc.S_from_slice(jnp.mean(jnp.array(wg_field), axis=2))
power = float(jnp.abs(jnp.sum(S[:, 0, :, :], axis=(1, 2))))
print(f"Waveguide output power: {power:.6f}")
print(f"Coupling (approx): {power / input_power * 100:.1f}%")

## Step 7: Optimization Setup

### How adjoint-method inverse design works

Each optimization step runs **two FDTD simulations** on the cloud GPU:

1. **Forward solve:** Inject the Gaussian source, propagate through the current
   structure, extract the field at the **loss monitor** (waveguide output).
2. **Adjoint solve:** Inject the gradient of the loss (w.r.t. the loss monitor
   field) back into the simulation *at the loss monitor position*, propagate
   again, and extract the adjoint field at the **design monitor**.

The permittivity gradient is then computed only inside the design monitor volume,
not the entire 3D domain. This is memory-efficient: instead of differentiating
through the full `(Lx, Ly, Lz)` grid, we only need fields in the thin design
region (the etch layer, ~3 pixels in z).
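
The forward/adjoint pairing described above can be demonstrated on a toy problem. The sketch below replaces the FDTD solve with a small linear system `A(p) x = b`, where the design parameters `p` sit on the diagonal of `A`; this is an assumption made for illustration and is unrelated to the Hyperwave internals. It compares the adjoint gradient with central finite differences.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 5
A0 = rng.normal(size=(n, n)) + n * np.eye(n)   # fixed part of the system matrix
b = rng.normal(size=n)                          # "source"
c = rng.normal(size=n)                          # "monitor" weights
p = rng.uniform(0.5, 1.5, size=n)               # design parameters (diagonal entries)


def loss(p):
    x = np.linalg.solve(A0 + np.diag(p), b)     # forward solve
    return 0.5 * (c @ x) ** 2


def adjoint_gradient(p):
    A = A0 + np.diag(p)
    x = np.linalg.solve(A, b)                   # forward solve
    dL_dx = (c @ x) * c                         # gradient of the loss at the monitor
    lam = np.linalg.solve(A.T, dL_dx)           # adjoint solve, driven by dL/dx
    return -lam * x                             # dL/dp_i = -lam_i * x_i for diagonal p


def finite_difference_gradient(p, h=1e-6):
    grad = np.zeros_like(p)
    for i in range(len(p)):                     # two extra solves per parameter
        step = np.zeros_like(p)
        step[i] = h
        grad[i] = (loss(p + step) - loss(p - step)) / (2 * h)
    return grad


g_adj = adjoint_gradient(p)
g_fd = finite_difference_gradient(p)
print("adjoint:          ", g_adj)
print("finite difference:", g_fd)
```

The adjoint route needs two linear solves regardless of `n`, whereas central finite differences need `2 * n`. The same economy is what lets each optimization step use only two FDTD runs.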

### Structure spec

For the forward simulation (Step 6), we passed a recipe extracted from a pre-built
`Structure`, so the permittivity was already fixed. For inverse design, the cloud GPU
needs to **rebuild** the permittivity from a new `theta` at every step. The
`structure_spec` provides the layer stack template (materials, thicknesses,
filter parameters) so the GPU knows how to reconstruct permittivity from any theta.
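
One way to picture "permittivity from theta" is a linear mix between the two values given for the etch layer. This is a common convention in density-based topology optimization; whether `hyperwave_community` uses exactly this mapping is an assumption here, and `interpolate_eps` is a hypothetical helper.

```python
n_si, n_clad = 3.48, 1.44
eps_si, eps_clad = n_si**2, n_clad**2


def interpolate_eps(density, eps_low=eps_clad, eps_high=eps_si):
    """Linear mix: density 0 -> eps_low (etched), density 1 -> eps_high (silicon)."""
    return eps_low + density * (eps_high - eps_low)


for d in (0.0, 0.5, 1.0):
    print(f"density = {d:.1f} -> eps = {interpolate_eps(d):.4f}")
```

Under this assumption the initial design value of 0.5 corresponds to a permittivity halfway between cladding and silicon, which is not a fabricable material and is why binarization matters later.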

### Mode overlap loss

We define a custom loss function that computes the mode coupling efficiency
(overlap integral between the simulated output field and the target waveguide
mode) and negates it. The optimizer minimizes the loss, so negating the
efficiency makes it maximize coupling into the waveguide.
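
The normalization of the overlap can be checked on synthetic data. The sketch below re-implements the same formula in NumPy for fields of shape `(3, ny, nz)` and feeds it a made-up Gaussian "mode". It assumes the input power follows the time-averaged convention `0.5 * Re(sum(E x H*))`; the function name and test fields are illustrative only.

```python
import numpy as np


def coupling_efficiency(e_sim, h_sim, e_mode, h_mode, input_power, mode_cross):
    """|Re(I1 * I2)| / 4 with the same normalization as the notebook's loss.

    All fields have shape (3, ny, nz); x is the propagation axis.
    """
    alpha = 1.0 / np.sqrt(input_power)
    beta = np.sqrt(2.0 / mode_cross)
    e1, h1 = e_sim * alpha, h_sim * alpha
    e0, h0 = e_mode * beta, h_mode * beta
    i1 = np.sum(np.cross(e0, np.conj(h1), axis=0)[0])   # x-component of E_mode x H_sim*
    i2 = np.sum(np.cross(e1, np.conj(h0), axis=0)[0])   # x-component of E_sim x H_mode*
    return float(np.abs(np.real(i1 * i2)) / 4.0)


# Synthetic "mode": Gaussian Ey with matching Hz, so (E x H*)_x = Ey * Hz
y, z = np.meshgrid(np.linspace(-1, 1, 21), np.linspace(-1, 1, 11), indexing="ij")
profile = np.exp(-(y**2 + z**2) / 0.2)
e_mode = np.zeros((3, 21, 11), dtype=complex)
h_mode = np.zeros((3, 21, 11), dtype=complex)
e_mode[1] = profile
h_mode[2] = profile

mode_cross = float(np.abs(np.real(np.sum(np.cross(e_mode, np.conj(h_mode), axis=0)[0]))))
input_power = 0.5 * mode_cross      # time-averaged power of the unit-amplitude mode

for amplitude in (1.0, 0.5):
    eff = coupling_efficiency(amplitude * e_mode, amplitude * h_mode,
                              e_mode, h_mode, input_power, mode_cross)
    print(f"amplitude {amplitude}: efficiency = {eff:.3f}")
```

When the simulated field is the mode itself the efficiency is 1, and halving the field amplitude quarters it, as expected for a power ratio.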

In [None]:
def make_structure_spec(layers, design_layer, density_radius, density_alpha,
                        vertical_radius=0):
    """Build structure_spec dict from a layers list for inverse design."""
    return {
        'layers_info': [
            {
                'permittivity_values': (
                    list(l.permittivity_values) if isinstance(l.permittivity_values, tuple)
                    else float(l.permittivity_values)
                ),
                'layer_thickness': float(l.layer_thickness),
                'density_radius': density_radius if i == design_layer else 0,
                'density_alpha': density_alpha if i == design_layer else 0,
            }
            for i, l in enumerate(layers)
        ],
        'construction_params': {'vertical_radius': vertical_radius},
    }

structure_spec = make_structure_spec(layers, DESIGN_LAYER, DENSITY_RADIUS, DENSITY_ALPHA)

# Loss monitor: where the output field is measured for the loss function.
# Placed at the waveguide output, just inside the absorber boundary.
loss_monitor_shape = (1, Ly, Lz)
loss_monitor_offset = (abs_widths[0] + 10, 0, 0)

# Design monitor: the 3D volume where gradients are computed.
# Covers the full XY domain but only the etch layer in Z (~3 pixels thick).
design_monitor_shape = (Lx, Ly, int(round(h2)))
design_monitor_offset = (0, 0, z_etch)

# Waveguide mask at theta resolution (forces theta=1 in waveguide region)
waveguide_mask = np.zeros((theta_Lx, theta_Ly), dtype=np.float32)
waveguide_mask[:wg_len_theta, theta_Ly // 2 - wg_hw_theta : theta_Ly // 2 + wg_hw_theta] = 1.0

# Custom mode overlap loss function (serialized via cloudpickle to GPU).
# Captures mode field (~125KB), input power, and P_mode_cross in closure.
_mode_np = np.array(mode_full)
_input_power = float(input_power)
_P_mode_cross = float(P_mode_cross)

def mode_overlap_loss(loss_field):
    """Negative mode coupling efficiency (minimize to maximize coupling).

    Computes: -|Re(integral(E_mode x H_sim) * integral(E_sim x H_mode))| / 4
    normalized by input power and mode self-overlap.
    """
    import jax.numpy as jnp

    alpha = 1.0 / jnp.sqrt(_input_power)
    beta = jnp.sqrt(2.0 / _P_mode_cross)

    # Average output field along propagation axis (x), normalize
    f = jnp.mean(loss_field * alpha, axis=2)
    m = jnp.array(_mode_np) * beta

    # Mode field components: frequency index 0, singleton x dimension dropped
    e0, h0 = m[0, 0:3, 0, :, :], m[0, 3:6, 0, :, :]
    # Output field components: x dimension already removed by the mean above
    e1, h1 = f[0, 0:3, :, :], f[0, 3:6, :, :]

    # Cross product overlap integrals (x-component for x-propagation)
    cross_e0h1 = jnp.sum(jnp.cross(e0, jnp.conj(h1), axis=0)[0, :, :])
    cross_e1h0 = jnp.sum(jnp.cross(e1, jnp.conj(h0), axis=0)[0, :, :])

    # Mode coupling efficiency: |Re(I1 * I2)| / 4
    eff = jnp.abs(jnp.real(cross_e0h1 * cross_e1h0)) / 4.0
    return -eff  # Negate: optimizer minimizes, so we maximize coupling

# Optimization parameters
NUM_STEPS = 5          # increase to 50-100 for production
LR = 0.01
GRAD_CLIP = 0.5

print(f"Optimizer: Adam + cosine decay (LR={LR}, clip={GRAD_CLIP}, {NUM_STEPS} steps)")
print(f"Loss: mode overlap (P_mode_cross={P_mode_cross:.6f}, input_power={input_power:.6f})")
print(f"Loss monitor at x={loss_monitor_offset[0]} ({loss_monitor_offset[0] * dx:.1f} um)")
print(f"Design monitor: {design_monitor_shape} at z={z_etch} "
      f"({design_monitor_shape[2] * dx * 1000:.0f} nm thick)")
print(f"Theta shape: {theta_init.shape}, waveguide_mask shape: {waveguide_mask.shape}")

## Step 8: Estimate Cost

Estimate the GPU cost before running optimization. Each optimization step runs
a forward and adjoint simulation, so the cost is roughly 2x a single simulation
per step.
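
The "roughly 2x per step" rule can be written out as plain arithmetic. The sketch below recomputes the grid size from the Step 1 values and scales a per-simulation cost by two simulations per step. The value `cost_per_sim` is a hypothetical placeholder, not a quoted price; use `hwc.estimate_cost` for real numbers.

```python
import math

dx = 0.035                                        # um
domain_um = 40.0
stack_um = 3.0 + 1.0 + 0.78 + 0.110 + 0.110 + 2.0 + 0.8 + 3.0   # full z extent

Lx = Ly = int(domain_um / dx)
Lz = math.ceil(stack_um / dx)
grid_points = Lx * Ly * Lz


def optimization_cost(cost_per_sim, num_steps, sims_per_step=2):
    """Forward + adjoint per step, so two simulations per optimization step."""
    return cost_per_sim * sims_per_step * num_steps


cost_per_sim = 0.03      # hypothetical placeholder, not a quoted price
print(f"grid: {Lx} x {Ly} x {Lz} = {grid_points:,} cells")
for steps in (5, 50):
    print(f"{steps:3d} steps -> ~{optimization_cost(cost_per_sim, steps):.2f}")
```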

In [None]:
grid_points = Lx * Ly * Lz
cost = hwc.estimate_cost(
    grid_points=grid_points,
    max_steps=10000,
)
# Inverse design runs forward + adjoint per step, roughly 2x
per_step_cost = cost['estimated_cost_usd'] * 2
total_estimate = per_step_cost * NUM_STEPS
print(f"Per step: ~${per_step_cost:.3f}")
print(f"Total ({NUM_STEPS} steps): ~${total_estimate:.2f}")

## Step 9: Run Optimization

`run_optimization()` runs the full loop on a cloud GPU and streams results
after each step. Interrupting the kernel cancels the GPU task
immediately. You are only charged for completed steps.

In [None]:
results = []

print(f"Running {NUM_STEPS}-step optimization on cloud GPU...")
print("(Interrupt kernel to stop early, GPU task will be cancelled)\n")

try:
    for step_result in hwc.run_optimization(
        theta=theta_init,
        source_field=source_field,
        source_offset=source_offset,
        freq_band=freq_band,
        structure_spec=structure_spec,
        loss_monitor_shape=loss_monitor_shape,
        loss_monitor_offset=loss_monitor_offset,
        design_monitor_shape=design_monitor_shape,
        design_monitor_offset=design_monitor_offset,
        loss_fn=mode_overlap_loss,
        num_steps=NUM_STEPS,
        learning_rate=LR,
        grad_clip_norm=GRAD_CLIP,
        waveguide_mask=waveguide_mask,
        absorption_widths=abs_widths,
        absorption_coeff=abs_coeff,
    ):
        results.append(step_result)
        step = step_result['step']
        loss = step_result['loss']
        eff_pct = abs(loss) * 100
        grad_max = step_result['grad_max']
        dt = step_result['step_time']
        print(f"Step {step:3d}/{NUM_STEPS}:  coupling = {eff_pct:.2f}%  "
              f"|grad|_max = {grad_max:.3e}  ({dt:.1f}s)")

except KeyboardInterrupt:
    print("\nStopped early. GPU task cancelled.")
    print(f"Results from {len(results)} completed steps saved.")

if results:
    efficiencies = [abs(r['loss']) * 100 for r in results]
    best_idx = int(np.argmax(efficiencies))
    best_eff = efficiencies[best_idx]
    print(f"\nBest coupling: {best_eff:.2f}% "
          f"({-10 * np.log10(max(best_eff / 100, 1e-10)):.2f} dB, step {best_idx + 1})")
    theta = results[-1]['theta']
else:
    print("No steps completed.")
    theta = theta_init
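
The dB figure printed with the best coupling is an insertion loss, `-10 * log10(efficiency)`, so smaller is better. A short reference sketch of that conversion (the helper name `insertion_loss_db` is illustrative):

```python
import math


def insertion_loss_db(efficiency):
    """Insertion loss in dB (positive number) for a linear efficiency in (0, 1]."""
    return -10.0 * math.log10(max(efficiency, 1e-10))   # floor avoids log10(0)


for eff in (0.01, 0.10, 0.50, 0.80):
    print(f"{eff * 100:5.1f}% -> {insertion_loss_db(eff):5.2f} dB")
```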

## Step 10: Results

Plot the optimization trajectory and compare initial vs optimized designs.

In [None]:
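# Requires at least one completed step from Step 9
if not results:
    raise RuntimeError("No optimization steps completed; rerun Step 9 before plotting.")

efficiencies = [abs(r['loss']) * 100 for r in results]
best_idx = int(np.argmax(efficiencies))
best_theta = results[best_idx]['theta']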

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Efficiency curve
axes[0].plot(range(1, len(efficiencies) + 1), efficiencies, 'b-o', markersize=3)
axes[0].set_xlabel('Step')
axes[0].set_ylabel('Coupling Efficiency (%)')
axes[0].set_title('Mode Coupling Efficiency')
axes[0].grid(True, alpha=0.3)

# Initial vs best theta
extent = [0, theta_Lx * pixel_size, 0, theta_Ly * pixel_size]
for ax, th, title in [(axes[1], theta_init, 'Initial'),
                       (axes[2], best_theta, f'Best (step {best_idx + 1})')]:
    im = ax.imshow(th.T, origin='lower', cmap='PuOr', vmin=0, vmax=1, extent=extent)
    ax.set_xlabel('x (um)')
    ax.set_ylabel('y (um)')
    ax.set_title(title)
    plt.colorbar(im, ax=ax)
    rect = Rectangle(
        (dr['x_start'] * pixel_size, dr['y_start'] * pixel_size),
        (dr['x_end'] - dr['x_start']) * pixel_size,
        (dr['y_end'] - dr['y_start']) * pixel_size,
        linewidth=1.5, edgecolor='lime', facecolor='none', linestyle='--',
    )
    ax.add_patch(rect)

plt.tight_layout()
plt.show()

## Step 11: Verify Optimized Design

Run a forward simulation with the best design to visualize field distribution
and verify coupling into the waveguide.

In [None]:
# Build structure with best theta
density_best = hwc.density(jnp.array(best_theta), radius=DENSITY_RADIUS, alpha=DENSITY_ALPHA)

layers_best = [
    hwc.Layer(zero,         eps_air,            h_p),
    hwc.Layer(zero,         eps_air,            h0),
    hwc.Layer(zero,         eps_clad,           h1),
    hwc.Layer(density_best, (eps_clad, eps_si), h2),
    hwc.Layer(ones,         eps_si,             h3),
    hwc.Layer(zero,         eps_sio2,           h4),
    hwc.Layer(zero,         eps_si,             h5),
    hwc.Layer(zero,         eps_si,             h_p),
]
structure_best = hwc.create_structure(layers=layers_best, vertical_radius=0)

# Forward simulation with same monitors
recipe_best = structure_best.extract_recipe()
opt_results = hwc.simulate(
    structure_recipe=recipe_best,
    source_field=source_field,
    source_offset=source_offset,
    freq_band=freq_band,
    monitors_recipe=monitors.recipe,
    absorption_widths=abs_widths,
    absorption_coeff=abs_coeff,
)

print(f"Verification sim complete: {opt_results['sim_time']:.1f}s GPU time")

In [None]:
# Plot |E|^2 cross-sections for optimized design
fig, axes = plt.subplots(2, 3, figsize=(18, 10))

# Row 0: |E|^2 XY, XZ, YZ
for ax, name, title, ext, xlabel, ylabel in [
    (axes[0, 0], 'xy_device', f'|E|^2 XY (z = {z_dev * dx:.2f} um)',
     [0, Lx * dx, 0, Ly * dx], 'x (um)', 'y (um)'),
    (axes[0, 1], 'xz_center', '|E|^2 XZ (y = center)',
     [0, Lx * dx, 0, Lz * dx], 'x (um)', 'z (um)'),
    (axes[0, 2], 'yz_center', '|E|^2 YZ (x = center)',
     [0, Ly * dx, 0, Lz * dx], 'y (um)', 'z (um)'),
]:
    field = np.array(opt_results['monitor_data'][name])
    E2 = np.sum(np.abs(field[0, 0:3])**2, axis=0).squeeze()
    vmax = np.percentile(E2, 95) * 4
    im = ax.imshow(E2.T, origin='lower', cmap='hot', extent=ext,
                   aspect='auto', vmax=vmax)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    plt.colorbar(im, ax=ax, extend='max')

# Row 1: Individual E-field components at device layer
field_xy = np.array(opt_results['monitor_data']['xy_device'])
for i, name in enumerate(['|Ex|', '|Ey|', '|Ez|']):
    comp = np.abs(field_xy[0, i, :, :, 0])
    im = axes[1, i].imshow(comp.T, origin='lower', cmap='hot',
                            extent=[0, Lx * dx, 0, Ly * dx])
    axes[1, i].set_xlabel('x (um)')
    axes[1, i].set_ylabel('y (um)')
    axes[1, i].set_title(name)
    plt.colorbar(im, ax=axes[1, i])

plt.suptitle('Optimized design fields', fontsize=14)
plt.tight_layout()
plt.show()

# Mode coupling efficiency using mode overlap integral
wg_field = np.array(opt_results['monitor_data']['wg_output'])
from hyperwave_community.monitors import mode_coupling_efficiency
eff_lin, eff_dB = mode_coupling_efficiency(
    output_field=jnp.array(wg_field),
    mode_field=jnp.array(mode_full),
    input_power=input_power,
    mode_cross_power=P_mode_cross,
    axis=0,
)
eff_pct = float(eff_lin[0]) * 100
loss_dB = float(eff_dB[0])
print(f"Mode coupling efficiency: {eff_pct:.2f}% ({loss_dB:.2f} dB)")

## Summary

This notebook demonstrated the inverse design workflow for a grating coupler:

1. Defined an SOI layer stack with a designable partial-etch layer
2. Created a Gaussian beam source simulating fiber illumination
3. Solved for the waveguide mode (used for the loss function)
4. Verified the setup with a forward simulation
5. Ran adjoint-method optimization on cloud GPUs
6. Verified the optimized design with a final forward simulation

**Next steps for production designs:**

- **More steps:** Increase `NUM_STEPS` to 50-100. Grating couplers typically
  need 50+ steps to converge.
- **Binarization:** Gradually increase `density_alpha` from 0 to 1 over the
  course of optimization to push the design toward fabricable binary values.
- **Fabrication constraints:** Enforce minimum feature size, minimum spacing,
  and etch design rules for the target foundry process.
- **Different starting points:** Try several initial theta values (e.g., 0.3,
  0.5, 0.7, or random patterns) and keep the best result. The loss landscape
  is non-convex, so different initializations can find different local optima.
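- **Finer grid:** Decrease `dx` (e.g., 25nm or 20nm) for higher accuracy.
  Finer grids resolve sub-wavelength features better, but cost rises steeply:
  the cell count of a 3D grid grows with the cube of the refinement factor,
  and the stable FDTD time step typically shrinks along with `dx`.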
- **Multi-frequency:** Expand `freq_band` to multiple wavelengths (e.g.,
  1530-1570nm) to optimize for bandwidth.
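
For the binarization suggestion above, a linear ramp is the simplest schedule to try. The sketch below only computes the per-step `density_alpha` values; how a changing alpha would be passed to `run_optimization` is not shown in this notebook and would need to be checked against the API. The function `alpha_schedule` is a hypothetical helper.

```python
def alpha_schedule(step, num_steps, alpha_start=0.0, alpha_end=1.0):
    """Linear ramp of density_alpha from alpha_start (step 0) to alpha_end (last step)."""
    if num_steps <= 1:
        return alpha_end
    frac = step / (num_steps - 1)
    return alpha_start + frac * (alpha_end - alpha_start)


num_steps = 50
for step in (0, 12, 25, 49):
    print(f"step {step:2d}: density_alpha = {alpha_schedule(step, num_steps):.3f}")
```

Starting near 0 keeps the early gradients smooth, and ending at 1 pushes the final pattern toward binary values.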