# Neural operator surrogate solvers in 2D

## Setup

In [None]:
# Interactive (ipywidgets-based) Matplotlib backend, so the figures below can
# be redrawn from the widget button callbacks.
%matplotlib ipympl
import torch
import ipywidgets as widgets
import matplotlib.pyplot as plt
from matplotlib.colors import CenteredNorm
from mpl_toolkits.axes_grid1 import make_axes_locatable

from mlatint import FNO, Sampler, FDFD

# Nothing in this notebook needs gradients, so turn off autograd tracking
# globally; the trailing ";" only suppresses the cell's displayed output.
torch.set_grad_enabled(False);

In [None]:
# Solvers: the pre-trained FNO surrogate and the regular FDFD Maxwell solver.
fno, fdfd = FNO(), FDFD()
# Random geometry generators for the two shape families.
blobs, triangles = Sampler("blobs"), Sampler("triangles")

# Simulation window in μm as (left, right, bottom, top), used as the imshow extent.
extent = (-2.56, 2.56, -2.56, 2.56)

## FDFD simulation

To get started, we run regular Maxwell simulations with the finite-difference frequency-domain (FDFD) solver on randomly sampled geometries.

In [None]:
# Draw one random "blobs" geometry and solve for its Ez field with FDFD.
ez = fdfd(geometry := blobs.sample())

In [None]:
# Plot one FDFD solution (geometry, Re(Ez), Im(Ez), |Ez|) and add a "Sample"
# button that replaces it with a newly sampled geometry and its FDFD solution.
plt.close("all")
fig, ax = plt.subplots(1, 4, figsize=(9, 3), sharey=True)

ax[0].set_title("Geometry")
ax[1].set_title("Re(Ez)")
ax[2].set_title("Im(Ez)")
ax[3].set_title("|Ez|")
ax[0].set_ylabel("y (μm)")
plt.setp(ax, xlabel="x (μm)")

im1 = ax[0].imshow(geometry, cmap="gray_r", extent=extent)
# CenteredNorm centres the diverging RdBu colormap on zero.
im2 = ax[1].imshow(ez.real, cmap="RdBu", norm=CenteredNorm(), extent=extent)
im3 = ax[2].imshow(ez.imag, cmap="RdBu", norm=CenteredNorm(), extent=extent)
im4 = ax[3].imshow(ez.abs(), cmap="magma", extent=extent)
plt.show()

radio = widgets.RadioButtons(
    options=["Blobs", "Triangles"],
    description="Geometry",
    disabled=False,
)


def update(
    _,
    radio=radio,
    fig=fig,
    geometry_image=im1,
    field_images=(im2, im3, im4),
):
    """Resample a geometry, solve it with FDFD and redraw this cell's figure.

    Used as the ``on_click`` callback of the "Sample" button; the clicked
    button is passed as the first argument and ignored.

    The widgets and images are bound as default arguments on purpose: later
    cells rebind the globals ``radio``, ``fig`` and ``im1``..``im4``. A
    callback that looked them up at click time would then read another
    cell's radio button and draw into another cell's figure.
    """
    sampler = blobs if radio.value == "Blobs" else triangles

    x = sampler.sample()
    ez = fdfd(x)

    geometry_image.set_data(x)
    for image, data in zip(field_images, (ez.real, ez.imag, ez.abs())):
        image.set_data(data)
        # set_data keeps the previous colour limits; rescale them to the new
        # field so larger amplitudes are not clipped (or small ones washed out).
        image.autoscale()
    fig.canvas.draw_idle()


button = widgets.Button(description="Sample")
button.on_click(update)

widgets.HBox([radio, button])

## FNO "simulation"

Now we do the same but replace the FDFD solver with our pre-trained FNO surrogate solver.

In [None]:
# Draw one random "blobs" geometry and predict its Ez field with the FNO surrogate.
ez = fno(geometry := blobs.sample())

In [None]:
# Plot one FNO prediction (geometry, Re(Ez), Im(Ez), |Ez|) and add a "Sample"
# button that replaces it with a newly sampled geometry and its FNO prediction.
plt.close("all")
fig, ax = plt.subplots(1, 4, figsize=(9, 3), sharey=True)

ax[0].set_title("Geometry")
ax[1].set_title("Re(Ez)")
ax[2].set_title("Im(Ez)")
ax[3].set_title("|Ez|")
ax[0].set_ylabel("y (μm)")
plt.setp(ax, xlabel="x (μm)")

im1 = ax[0].imshow(geometry, cmap="gray_r", extent=extent)
# CenteredNorm centres the diverging RdBu colormap on zero.
im2 = ax[1].imshow(ez.real, cmap="RdBu", norm=CenteredNorm(), extent=extent)
im3 = ax[2].imshow(ez.imag, cmap="RdBu", norm=CenteredNorm(), extent=extent)
im4 = ax[3].imshow(ez.abs(), cmap="magma", extent=extent)
plt.show()

radio = widgets.RadioButtons(
    options=["Blobs", "Triangles"],
    description="Geometry",
    disabled=False,
)


def update(
    _,
    radio=radio,
    fig=fig,
    geometry_image=im1,
    field_images=(im2, im3, im4),
):
    """Resample a geometry, predict its field with the FNO and redraw the figure.

    Used as the ``on_click`` callback of the "Sample" button; the clicked
    button is passed as the first argument and ignored.

    The widgets and images are bound as default arguments on purpose: other
    cells rebind the globals ``radio``, ``fig`` and ``im1``..``im4``. A
    callback that looked them up at click time would then read another
    cell's radio button and draw into another cell's figure.
    """
    sampler = blobs if radio.value == "Blobs" else triangles

    x = sampler.sample()
    ez = fno(x)

    geometry_image.set_data(x)
    for image, data in zip(field_images, (ez.real, ez.imag, ez.abs())):
        image.set_data(data)
        # set_data keeps the previous colour limits; rescale them to the new
        # field so larger amplitudes are not clipped (or small ones washed out).
        image.autoscale()
    fig.canvas.draw_idle()


button = widgets.Button(description="Sample")
button.on_click(update)

widgets.HBox([radio, button])

## FDFD vs FNO

Let's inspect the differences between the two in terms of accuracy and speed!

### Error

For the error, we focus on the field magnitude |Ez| for simplicity: it combines the real and imaginary parts, so it should be a decent benchmark.
We min-max normalize the field magnitudes of both solvers to the range [0, 1]. This is not strictly necessary, but it simplifies the comparison.

In [None]:
def normalize(x):
    """Min-max rescale ``x`` to the range [0, 1] without modifying the input.

    Args:
        x: Real-valued tensor or array, e.g. a field magnitude ``ez.abs()``.

    Returns:
        A new object equal to ``(x - x.min()) / (x.max() - x.min())``. If
        ``x`` is constant, the range is zero and an all-zeros result is
        returned instead of dividing by zero (which would yield NaNs).
    """
    # Out-of-place arithmetic, so the caller's tensor is left untouched; this
    # also lets integer inputs promote to floating point on division.
    shifted = x - x.min()
    span = shifted.max()
    if span == 0:
        return shifted
    return shifted / span

In [None]:
# Solve one shared geometry with both solvers so their fields are directly comparable.
geometry = blobs.sample()
ez_fdfd, ez_fno = fdfd(geometry), fno(geometry)

# Field magnitudes |Ez| of both solvers, each min-max rescaled by `normalize`.
ez_fdfd_norm, ez_fno_norm = (normalize(field.abs()) for field in (ez_fdfd, ez_fno))

In [None]:
# Show the geometry, the normalized |Ez| of FDFD and FNO, and their absolute
# difference (with a colorbar). A "Sample" button repeats this for a new geometry.
plt.close("all")
fig, ax = plt.subplots(1, 4, figsize=(10, 3), sharey=True)

ax[0].set_title("Geometry")
ax[1].set_title("|Ez| FDFD")
ax[2].set_title("|Ez| FNO")
ax[3].set_title("abs. err.")
ax[0].set_ylabel("y (μm)")
plt.setp(ax, xlabel="x (μm)")

im1 = ax[0].imshow(geometry, cmap="gray_r", extent=extent)
im2 = ax[1].imshow(ez_fdfd_norm, cmap="magma", extent=extent)
im3 = ax[2].imshow(ez_fno_norm, cmap="magma", extent=extent)
im4 = ax[3].imshow(torch.abs(ez_fdfd_norm - ez_fno_norm), cmap="viridis", extent=extent)
# Colorbar in its own axis to the right of the error plot.
divider = make_axes_locatable(ax[3])
cax = divider.append_axes("right", size="5%", pad=0.25)
fig.colorbar(im4, cax)
plt.show()

radio = widgets.RadioButtons(
    options=["Blobs", "Triangles"],
    description="Geometry",
    disabled=False,
)


def update(
    _,
    radio=radio,
    fig=fig,
    geometry_image=im1,
    field_images=(im2, im3, im4),
):
    """Resample a geometry, solve it with FDFD and FNO and redraw the comparison.

    Used as the ``on_click`` callback of the "Sample" button; the clicked
    button is passed as the first argument and ignored.

    The widgets and images are bound as default arguments on purpose: other
    cells rebind the globals ``radio``, ``fig`` and ``im1``..``im4``. A
    callback that looked them up at click time would then read another
    cell's radio button and draw into another cell's figure.
    """
    sampler = blobs if radio.value == "Blobs" else triangles

    x = sampler.sample()
    ez_fdfd = normalize(fdfd(x).abs())
    ez_fno = normalize(fno(x).abs())

    geometry_image.set_data(x)
    new_data = (ez_fdfd, ez_fno, torch.abs(ez_fdfd - ez_fno))
    for image, data in zip(field_images, new_data):
        image.set_data(data)
        # Rescale the colour limits to the new data. For the error image this
        # also refreshes the attached colorbar, so the axis (and its
        # "abs. err." title) never has to be cleared and rebuilt.
        image.autoscale()
    fig.canvas.draw_idle()


button = widgets.Button(description="Sample")
button.on_click(update)

widgets.HBox([radio, button])

### Speed

We expect the surrogate solver to be faster — but is it? Each cell below times ten solves, including the sampling of a new geometry for each one.

In [None]:
%%time
# Wall time of 10 FDFD solves. Each iteration also samples a new geometry,
# so the sampling cost is included in the measurement.
for _ in range(10):
    fdfd(blobs.sample())

In [None]:
%%time
# Wall time of 10 FNO forward passes. Each iteration also samples a new
# geometry, so the sampling cost is included in the measurement.
# NOTE(review): if the FNO runs on a GPU, CUDA kernels launch asynchronously
# and may not have finished when the timer stops — confirm the device and
# synchronize before the cell ends if needed.
for _ in range(10):
    fno(blobs.sample())

## "Breaking" the FNO

The surrogate solver has limits. One of them is the set of materials it has seen: since it was trained on only two permittivities, we expect its error to grow as other permittivity values enter the simulation. Change `eps_max` below and click "Sample" to see this.

In [None]:
# Solve one shared geometry with both solvers so their fields are directly comparable.
geometry = blobs.sample()
ez_fdfd, ez_fno = fdfd(geometry), fno(geometry)

# Field magnitudes |Ez| of both solvers, each min-max rescaled by `normalize`.
ez_fdfd_norm, ez_fno_norm = (normalize(field.abs()) for field in (ez_fdfd, ez_fno))

In [None]:
# Same comparison as in the error section, but the "Sample" button rescales
# the sampled geometry by the permittivity typed into the eps_max box.
plt.close("all")
fig, ax = plt.subplots(1, 4, figsize=(10, 3), sharey=True)

ax[0].set_title("Geometry")
ax[1].set_title("|Ez| FDFD")
ax[2].set_title("|Ez| FNO")
ax[3].set_title("abs. err.")
ax[0].set_ylabel("y (μm)")
plt.setp(ax, xlabel="x (μm)")

im1 = ax[0].imshow(geometry, cmap="gray_r", extent=extent)
im2 = ax[1].imshow(ez_fdfd_norm, cmap="magma", extent=extent)
im3 = ax[2].imshow(ez_fno_norm, cmap="magma", extent=extent)
im4 = ax[3].imshow(torch.abs(ez_fdfd_norm - ez_fno_norm), cmap="viridis", extent=extent)
# Colorbar in its own axis to the right of the error plot.
divider = make_axes_locatable(ax[3])
cax = divider.append_axes("right", size="5%", pad=0.25)
fig.colorbar(im4, cax)
plt.show()

eps_input = widgets.FloatText(value=2.25, description="eps_max", disabled=False)


def update(
    _,
    eps_input=eps_input,
    fig=fig,
    geometry_image=im1,
    field_images=(im2, im3, im4),
):
    """Sample a rescaled geometry, solve it with FDFD and FNO and redraw.

    Used as the ``on_click`` callback of the "Sample" button; the clicked
    button is passed as the first argument and ignored.

    The widget and images are bound as default arguments on purpose, so that
    this callback keeps driving this cell's figure even if the globals
    ``eps_input``, ``fig`` or ``im1``..``im4`` are rebound later.
    """
    # NOTE(review): presumably the sampled geometry's maximum permittivity is
    # 2.25 (the widget default), so this ratio rescales it to eps_max — confirm.
    eps = eps_input.value / 2.25

    x = eps * blobs.sample()
    ez_fdfd = normalize(fdfd(x).abs())
    ez_fno = normalize(fno(x).abs())

    geometry_image.set_data(x)
    new_data = (ez_fdfd, ez_fno, torch.abs(ez_fdfd - ez_fno))
    for image, data in zip(field_images, new_data):
        image.set_data(data)
        # Rescale the colour limits to the new data. For the error image this
        # also refreshes the attached colorbar, so the axis (and its
        # "abs. err." title) never has to be cleared and rebuilt.
        image.autoscale()
    fig.canvas.draw_idle()


button = widgets.Button(description="Sample")
button.on_click(update)

widgets.HBox([eps_input, button])