# Inverse design using surrogate solvers and VAEs

## Setup

In [None]:
import sys

# Colab-only setup. Outside Google Colab this cell does nothing.
if "google.colab" in sys.modules:
    from google.colab import output

    # Colab needs the custom widget manager for third-party widgets
    # (presumably for the ipympl interactive figures used below).
    output.enable_custom_widget_manager()
    # This checks whether mlatint is already *imported* in this session, not
    # whether it is installed. A fresh runtime therefore always re-runs the
    # installation.
    # The apt packages (cairo dev headers, pkg-config, python3-dev) are
    # presumably needed to build one of mlatint's dependencies. TODO confirm.
    if "mlatint" not in sys.modules:
        !sudo apt install libcairo2-dev pkg-config python3-dev
        !pip install git+https://github.com/yaugenst/mlatint

In [None]:
%matplotlib ipympl
import torch
import ipywidgets as widgets
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

from mlatint import VAE, Sampler, FNO, FDFD

First, we load everything we need: the FNO surrogate solver, a sampler and a VAE for "blobs" geometries, and the FDFD solver that we use later to check our results.

In [None]:
# Solvers: the FNO surrogate used during optimization and the FDFD solver used
# as a reference to check the surrogate's fields.
fno, fdfd = FNO(), FDFD()
# Geometry models for the "blobs" family: a sampler that draws random
# structures and the VAE that maps structures to/from a latent space.
blobs, vae_blobs = Sampler("blobs"), VAE("blobs")

# imshow/contour extent (left, right, bottom, top) in μm, i.e. the plotted
# domain is a 5.12 μm square centered on the origin.
extent = (-2.56, 2.56, -2.56, 2.56)  # for plotting

For inverse design, we need to define an objective function. Here, we define a boolean mask and maximize the mean magnitude of the electric field, |Ez|, over the region where the mask is `True`.

In [None]:
# Boolean target mask with two rectangular regions set to True. Its shape is
# assumed to match the spatial shape of the FNO field. TODO confirm.
mask = torch.zeros((128, 128), dtype=torch.bool)
mask[40:50, 90:100] = True
mask[110:120, 20:30] = True


def objective(latent):
    """Return the mean |Ez| over ``mask`` for the design encoded by ``latent``.

    The latent vector is decoded into a geometry by the "blobs" VAE. The FNO
    surrogate then predicts the field for that geometry, and the magnitudes
    of the field values selected by ``mask`` are averaged. Larger is better.
    """
    field = fno(vae_blobs.decode(latent))
    selected = field[mask]
    return selected.abs().mean()

Next, we initialize a latent vector, which serves as the set of parameters in our optimization.
Note that we initialize this vector by encoding a randomly sampled structure with the VAE encoder.
This is not strictly necessary; we could also sample it from, e.g., a uniform or normal distribution (with the correct shape).
Using an encoded structure, however, gives the optimization a starting point that the VAE can represent well.

In [None]:
# Draw a random "blobs" structure and use its VAE latent code as the starting
# point of the optimization.
seed_geometry = blobs.sample()
x0 = vae_blobs.encode(seed_geometry)

Now we visualize our setup by plotting the initial geometry and the target regions, along with reference FDFD and FNO simulations to make sure that the initial FNO fields are reasonable.

In [None]:
plt.close("all")
fig, ax = plt.subplots(1, 3, figsize=(9, 3), sharey=True)

# Panel titles, left to right. The y axis is shared, so only the first panel
# gets a y label.
for axi, title in zip(ax, ("Geometry", "|Ez| FDFD", "|Ez| FNO")):
    axi.set_title(title)
ax[0].set_ylabel("y (μm)")
plt.setp(ax, xlabel="x (μm)")

# Decode the starting latent vector and show the geometry next to the field
# magnitudes from the FDFD reference solver and the FNO surrogate.
_init = vae_blobs.decode(x0)

geometry_img = _init.numpy(force=True)
im1 = ax[0].imshow(geometry_img, cmap="gray_r", extent=extent)
fdfd_mag = fdfd(_init).abs().numpy(force=True)
im2 = ax[1].imshow(fdfd_mag, cmap="magma", extent=extent)
fno_mag = fno(_init).abs().numpy(force=True)
im3 = ax[2].imshow(fno_mag, cmap="magma", extent=extent)

# Outline the target regions in red. With an extent and no explicit origin,
# contour puts row 0 at the bottom, whereas imshow (default origin "upper")
# puts it at the top. Flipping the mask lines the two up.
target_outline = torch.flipud(mask)
for axi in ax:
    axi.contour(target_outline, colors="r", extent=extent)
plt.show()

## Optimization

Now we are almost ready to optimize! We use a PyTorch optimizer here for convenience, but in principle any gradient-based local optimizer will do.

In [None]:
# The latent vector itself is the trainable parameter. requires_grad_ flips
# the flag in place and returns the same tensor, which Adam then optimizes
# with a learning rate of 0.1.
opt = torch.optim.Adam([x0.requires_grad_(True)], lr=0.1)

Next, we optimize for a fixed number of iterations (modify this value if needed!) and show the evolution of the structure and the field.
Note that you can re-run this cell to continue optimizing for additional iterations.

In [None]:
plt.close("all")
fig, ax = plt.subplots(1, 2, figsize=(6, 3), sharey=True)

ax[0].set_title("Geometry")
ax[1].set_title("|Ez| FNO")
ax[0].set_ylabel("y (μm)")
plt.setp(ax, xlabel="x (μm)")

# Number of optimizer steps per run of this cell. Re-running the cell
# continues from the current x0.
n_iter = 10

# Preview the current design. x0 requires grad at this point, so decode and
# run the FNO under no_grad. Otherwise an unused autograd graph would be
# built (and kept alive through the global `_init`).
# `extent` is the module-level plot extent; it is deliberately not redefined
# here so every figure uses the same value.
with torch.no_grad():
    _init = vae_blobs.decode(x0)
    im1 = ax[0].imshow(_init.numpy(force=True), cmap="gray_r", extent=extent)
    im2 = ax[1].imshow(
        fno(_init).abs().numpy(force=True), cmap="magma", extent=extent
    )
for axi in ax:
    axi.contour(torch.flipud(mask), colors="r", extent=extent)
# `display` is provided by IPython/Jupyter. The returned handle lets us
# redraw the same output area after every step.
hfig = display(fig, display_id=True)

for _ in range(n_iter):
    opt.zero_grad()
    loss = -objective(x0)  # maximize the objective by minimizing its negative
    loss.backward()
    opt.step()
    with torch.no_grad():
        # Keep the latent vector within [-0.5, 2.0]. This is not strictly
        # necessary but helps the VAE.
        x0.clamp_(-0.5, 2.0)
        design = vae_blobs.decode(x0)
        ez = fno(design).abs()
        im1.set_data(design.numpy(force=True))
        im2.set_data(ez.numpy(force=True))
        fig.canvas.draw()
        hfig.update(fig)

plt.close(fig)

## Evaluation

Lastly, we visualize the result of the optimization and compare the FNO field to the FDFD reference.

In [None]:
plt.close("all")
fig, ax = plt.subplots(1, 3, figsize=(9, 3), sharey=True)

ax[0].set_title("Geometry")
ax[1].set_title("|Ez| FDFD")
ax[2].set_title("|Ez| FNO")
ax[0].set_ylabel("y (μm)")
plt.setp(ax, xlabel="x (μm)")

# Evaluate the optimized design with both solvers. x0 requires grad after
# the optimizer setup, but these results are only plotted. Running under
# no_grad therefore avoids recording an autograd graph through the decoder,
# the FDFD solve and the FNO.
with torch.no_grad():
    final_design = vae_blobs.decode(x0)
    ez_fdfd = fdfd(final_design).abs()
    ez_fno = fno(final_design).abs()

im1 = ax[0].imshow(final_design.numpy(force=True), cmap="gray_r", extent=extent)
im2 = ax[1].imshow(ez_fdfd.numpy(force=True), cmap="magma", extent=extent)
im3 = ax[2].imshow(ez_fno.numpy(force=True), cmap="magma", extent=extent)
# Outline the target regions. The mask is flipped so that contour's
# bottom-left origin matches imshow's top-left default.
for axi in ax:
    axi.contour(torch.flipud(mask), colors="r", extent=extent)
plt.show()