# Automated model extraction

You can use gdsfactory simulation plugins to build SDict models for circuit simulations. 

The parent `Model` class contains common logic for model building such as input-output vector definition from a set of input parameters, as well as fitting of the input-output vector relationships (for instance, through ND-ND interpolation and feedforward neural nets).

Child subclasses inherit all of this machinery, but additionally define solver- or component-specific information such as:

- `outputs_from_inputs` method: how the input vectors (typically, `Component` or `LayerStack` arguments) are mapped to output vectors (this could directly be the S-parameters, or some solver results used to generate S-parameters like effective index)
- `sdict` method: how the output vectors are mapped to S-parameter dictionaries for circuit simulation (this could directly be the result of `outputs_from_inputs`, or some downstream calculation that combines the output vectors with extra Component parameters whose effect on the S-parameters is known and does not require training)
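The division of labour described above can be sketched in plain Python. Everything here is hypothetical: `ToyModel` and `ToyWaveguideModel` are illustrative stand-ins, not the gdsfactory classes, and a made-up linear formula replaces the mode solver. The point is only that the parent builds the input grid and calls the subclass hooks, while the subclass supplies `outputs_from_inputs` and `sdict`.

```python
import cmath
import itertools
from typing import Dict, List, Sequence, Tuple


class ToyModel:
    """Hypothetical stand-in for the parent `Model` class (not the gdsfactory API)."""

    def __init__(self, trainable_parameters: Dict[str, Sequence[float]]):
        # Parameter name -> sample values to sweep.
        self.trainable_parameters = trainable_parameters

    def get_model_input_output(self) -> Tuple[List[tuple], List[List[float]]]:
        # Input vectors: Cartesian product of all sampled parameter values.
        names = list(self.trainable_parameters)
        inputs = list(itertools.product(*self.trainable_parameters.values()))
        # Output vectors: one "solver" call per input vector, delegated to the subclass.
        outputs = [self.outputs_from_inputs(dict(zip(names, x))) for x in inputs]
        return inputs, outputs

    def outputs_from_inputs(self, params: Dict[str, float]) -> List[float]:
        raise NotImplementedError  # solver-specific

    def sdict(self, params: Dict[str, float]) -> Dict[Tuple[str, str], complex]:
        raise NotImplementedError  # component-specific


class ToyWaveguideModel(ToyModel):
    """Hypothetical child class; a made-up linear formula replaces the mode solver."""

    def outputs_from_inputs(self, params: Dict[str, float]) -> List[float]:
        return [1.5 + 2.0 * params["width"]]  # fake "effective index"

    def sdict(self, params: Dict[str, float]) -> Dict[Tuple[str, str], complex]:
        # A real model would query its fitted interpolator rather than the solver here.
        neff = self.outputs_from_inputs(params)[0]
        s21 = cmath.exp(2j * cmath.pi * neff * params["length"] / params["wavelength"])
        return {("o1", "o2"): s21, ("o2", "o1"): s21}
```

For example, `ToyWaveguideModel({"width": [0.4, 0.5], "wavelength": [1.55]}).get_model_input_output()` returns two input vectors and their fake effective indices.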

For instance, consider a `straight` component in the generic LayerStack:

In [1]:
import jax.numpy as jnp
from sax.utils import reciprocal

from gdsfactory.pdk import get_layer_stack
from gdsfactory.simulation.fem.mode_solver import compute_cross_section_modes
from gdsfactory.simulation.sax.build_model import Model

import gdsfactory as gf
from gdsfactory.cross_section import rib
from gdsfactory.simulation.sax.parameter import LayerStackThickness, NamedParameter
from gdsfactory.technology import LayerStack
from gdsfactory.generic_tech import get_generic_pdk

gf.config.rich_output()
PDK = get_generic_pdk()
PDK.activate()

c = gf.components.straight(
    cross_section=rib(width=2),
    length=10,
)
c

2023-02-20 17:57:46.219 | INFO     | gdsfactory.config:<module>:50 - Load '/home/runner/work/gdsfactory/gdsfactory/gdsfactory' 6.43.1


2023-02-20 17:57:47.843 | INFO     | gdsfactory.technology.layer_views:__init__:785 - Importing LayerViews from YAML file: /home/runner/work/gdsfactory/gdsfactory/gdsfactory/generic_tech/layer_views.yaml.


2023-02-20 17:57:47.850 | INFO     | gdsfactory.pdk:activate:206 - 'generic' PDK is now active


straight_ad10dbf8: uid c2296b1e, ports ['o1', 'o2'], references [], 2 polygons


In [2]:
layerstack = get_layer_stack()

filtered_layerstack = LayerStack(
    layers={
        k: layerstack.layers[k]
        for k in (
            "slab90",
            "core",
            "box",
            "clad",
        )
    }
)

We first wrap this component in a function that takes a single dictionary as its argument; the dictionary keys parametrize the Component arguments we want to vary. Below, for instance, we force the `straight` component to have a `rib` cross-section whose width can be varied.

In [3]:
def trainable_straight_rib(parameters):
    return gf.components.straight(cross_section=rib(width=parameters["width"]))
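Because the wrapper receives a single dictionary, it can simply ignore keys that are not Component arguments. A small stdlib-only sketch of that pattern: `fake_straight` is a hypothetical stand-in for `gf.components.straight`, `make_trainable` is a hypothetical helper, and we assume (without checking the gdsfactory source) that the model may pass extra keys such as `wavelength` in the same dictionary.

```python
from typing import Any, Callable, Dict, Tuple


def fake_straight(width: float, length: float = 10.0) -> Dict[str, float]:
    """Hypothetical stand-in for gf.components.straight: it just records its geometry."""
    return {"width": width, "length": length}


def make_trainable(
    component_fn: Callable[..., Any], keys: Tuple[str, ...]
) -> Callable[[Dict[str, float]], Any]:
    """Wrap component_fn so it takes one parameter dict and forwards only `keys`."""

    def trainable(parameters: Dict[str, float]) -> Any:
        # Entries not listed in `keys` (e.g. wavelength, layer thicknesses) are ignored.
        return component_fn(**{k: parameters[k] for k in keys})

    return trainable


trainable_fake_straight = make_trainable(fake_straight, keys=("width",))
print(trainable_fake_straight({"width": 0.5, "wavelength": 1.55, "core_thickness": 0.22}))
```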

## Instantiating Models

Next, we instantiate the `Model` proper. Here, we use the child class `FemwellWaveguideModel`. Its `outputs_from_inputs` method returns the effective indices computed from the input geometry, and its `sdict` method uses the input geometry, length, and loss to return the S-parameters of the corresponding straight waveguide:

In [4]:
from gdsfactory.simulation.sax.femwell_waveguide_model import FemwellWaveguideModel

rib_waveguide_model = FemwellWaveguideModel(
    trainable_component=trainable_straight_rib,
    layerstack=filtered_layerstack,
    simulation_settings={
        "resolutions": {
            "core": {"resolution": 0.02, "distance": 2},
            "clad": {"resolution": 0.2, "distance": 1},
            "box": {"resolution": 0.2, "distance": 1},
            "slab90": {"resolution": 0.05, "distance": 1},
        },
        "overwrite": False,
        "order": 1,
        "radius": jnp.inf,
    },
    trainable_parameters={
        "width": NamedParameter(
            min_value=0.4, max_value=0.6, nominal_value=0.5, step=0.05
        ),
        "wavelength": NamedParameter(
            min_value=1.545, max_value=1.555, nominal_value=1.55, step=0.005
        ),
        "core_thickness": LayerStackThickness(
            layerstack=filtered_layerstack,
            min_value=0.21,
            max_value=0.23,
            nominal_value=0.22,
            layername="core",
            step=0.1,
        ),
    },
    non_trainable_parameters={
        "length": NamedParameter(nominal_value=10),
        "loss": NamedParameter(nominal_value=1),
    },
    num_modes=4,
)

Note the dictionary parameters:

(1) the entries of `simulation_settings` are used by the model builder to parametrize the simulator,

(2) the entries of `trainable_parameters` are used to define the simulation space that maps inputs to outputs and which requires interpolation, and

(3) the entries of `non_trainable_parameters` are required to calculate the S-parameters, but do not appear in the simulator (their effect can be added after intermediate results have been interpolated).
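The number of input vectors follows from how the `trainable_parameters` are discretized. A minimal sketch, assuming each parameter is sampled from `min_value` to `max_value` inclusive in increments of `step` and that the grid is their Cartesian product; this is our reading rather than gdsfactory's code, but it reproduces the 15 rows and the first row `[0.4 1.545 0.21]` printed further down. Note that with `step=0.1`, `core_thickness` contributes a single value.

```python
import itertools
import math
from typing import Dict, List, Tuple


def sample_values(min_value: float, max_value: float, step: float) -> List[float]:
    """Values from min_value up to max_value (inclusive) in increments of step."""
    # The small tolerance guards against round-off: (0.6 - 0.4) / 0.05 can evaluate
    # to slightly less than 4.
    n = math.floor((max_value - min_value) / step + 1e-9) + 1
    return [min_value + i * step for i in range(n)]


# (min_value, max_value, step) copied from the `trainable_parameters` above.
ranges: Dict[str, Tuple[float, float, float]] = {
    "width": (0.4, 0.6, 0.05),
    "wavelength": (1.545, 1.555, 0.005),
    "core_thickness": (0.21, 0.23, 0.1),
}

axes = {name: sample_values(*r) for name, r in ranges.items()}
grid = list(itertools.product(*axes.values()))

for name, values in axes.items():
    print(name, [round(v, 4) for v in values])
print(len(grid), "points; first:", grid[0])
```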

## Training models

From these dictionaries, the `Model` object can generate the input and output vectors to be modelled:

In [5]:
%%capture
input_vectors, output_vectors = rib_waveguide_model.get_model_input_output()

2023-02-20 17:57:48.606 | INFO     | gdsfactory.simulation.fem.mode_solver:compute_component_slice_modes:151 - Simulation loaded from PosixPath('/home/runner/.gdsfactory/modes/straight_d69acef1_c866923b.npz')


2023-02-20 17:57:49.272 | INFO     | gdsfactory.simulation.fem.mode_solver:compute_component_slice_modes:151 - Simulation loaded from PosixPath('/home/runner/.gdsfactory/modes/straight_d69acef1_b531f320.npz')


2023-02-20 17:57:49.901 | INFO     | gdsfactory.simulation.fem.mode_solver:compute_component_slice_modes:151 - Simulation loaded from PosixPath('/home/runner/.gdsfactory/modes/straight_d69acef1_a1913a4a.npz')


2023-02-20 17:57:50.536 | INFO     | gdsfactory.simulation.fem.mode_solver:compute_component_slice_modes:151 - Simulation loaded from PosixPath('/home/runner/.gdsfactory/modes/straight_14c7c91f_169072fa.npz')


2023-02-20 17:57:51.194 | INFO     | gdsfactory.simulation.fem.mode_solver:compute_component_slice_modes:151 - Simulation loaded from PosixPath('/home/runner/.gdsfactory/modes/straight_14c7c91f_700ec387.npz')


2023-02-20 17:57:51.856 | INFO     | gdsfactory.simulation.fem.mode_solver:compute_component_slice_modes:151 - Simulation loaded from PosixPath('/home/runner/.gdsfactory/modes/straight_14c7c91f_b5da7b7c.npz')


2023-02-20 17:57:52.516 | INFO     | gdsfactory.simulation.fem.mode_solver:compute_component_slice_modes:151 - Simulation loaded from PosixPath('/home/runner/.gdsfactory/modes/straight_4309c6bd_b1a3b49a.npz')


2023-02-20 17:57:53.184 | INFO     | gdsfactory.simulation.fem.mode_solver:compute_component_slice_modes:151 - Simulation loaded from PosixPath('/home/runner/.gdsfactory/modes/straight_4309c6bd_1a9e5015.npz')


2023-02-20 17:57:53.856 | INFO     | gdsfactory.simulation.fem.mode_solver:compute_component_slice_modes:151 - Simulation loaded from PosixPath('/home/runner/.gdsfactory/modes/straight_4309c6bd_2bd41f55.npz')


2023-02-20 17:57:54.525 | INFO     | gdsfactory.simulation.fem.mode_solver:compute_component_slice_modes:151 - Simulation loaded from PosixPath('/home/runner/.gdsfactory/modes/straight_80d57308_9eb3c64a.npz')


2023-02-20 17:57:55.191 | INFO     | gdsfactory.simulation.fem.mode_solver:compute_component_slice_modes:151 - Simulation loaded from PosixPath('/home/runner/.gdsfactory/modes/straight_80d57308_7feb320a.npz')


2023-02-20 17:57:55.861 | INFO     | gdsfactory.simulation.fem.mode_solver:compute_component_slice_modes:151 - Simulation loaded from PosixPath('/home/runner/.gdsfactory/modes/straight_80d57308_81ff7582.npz')


2023-02-20 17:57:56.534 | INFO     | gdsfactory.simulation.fem.mode_solver:compute_component_slice_modes:151 - Simulation loaded from PosixPath('/home/runner/.gdsfactory/modes/straight_ec84fb00_2f50474a.npz')


2023-02-20 17:57:57.217 | INFO     | gdsfactory.simulation.fem.mode_solver:compute_component_slice_modes:151 - Simulation loaded from PosixPath('/home/runner/.gdsfactory/modes/straight_ec84fb00_ea28a376.npz')


2023-02-20 17:57:57.897 | INFO     | gdsfactory.simulation.fem.mode_solver:compute_component_slice_modes:151 - Simulation loaded from PosixPath('/home/runner/.gdsfactory/modes/straight_ec84fb00_e158a597.npz')


From the above, we expect the input vector to have one row per point of the trainable-parameter grid, here 5 widths × 3 wavelengths × 1 core thickness = 15 (since the `core_thickness` step of 0.1 exceeds its 0.21 to 0.23 range, only 0.21 is sampled), and one column per trainable parameter (3):

In [6]:
import numpy as np

print(np.shape(input_vectors))
print(input_vectors[0])

(15, 3)
[0.4   1.545 0.21 ]


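The output (here, the effective indices) has one row per input vector and twice as many columns as modes. Judging from the printed row, the first `num_modes` (4) columns hold the real parts of the effective indices and the last 4 their (negligible) imaginary parts: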

In [7]:
print(output_vectors[0])
print(np.shape(output_vectors))

[ 2.4341786e+00  2.0900087e+00  2.0828242e+00  2.0582621e+00
 -1.0564244e-08  4.4820982e-09  2.6217071e-09  2.0750175e-09]
(15, 8)
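Assuming that real-then-imaginary layout (inferred from the printed numbers, not from the gdsfactory source), a row can be turned back into complex effective indices with a small hypothetical helper:

```python
from typing import List, Sequence


def to_complex_neff(row: Sequence[float], num_modes: int) -> List[complex]:
    """Pair the first num_modes entries (real parts) with the next num_modes (imaginary parts)."""
    if len(row) != 2 * num_modes:
        raise ValueError(f"expected {2 * num_modes} values, got {len(row)}")
    return [complex(row[i], row[i + num_modes]) for i in range(num_modes)]


# First output row printed above (num_modes=4).
row = [2.4341786e00, 2.0900087e00, 2.0828242e00, 2.0582621e00,
       -1.0564244e-08, 4.4820982e-09, 2.6217071e-09, 2.0750175e-09]
for m, neff in enumerate(to_complex_neff(row, num_modes=4)):
    print(f"mode {m}: neff = {neff.real:.4f} + {neff.imag:.1e}j")
```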


Typically, we are not interested in these vectors per se, but in an interpolation model between them. One option is ND-ND interpolation:

In [8]:
%%capture
rib_waveguide_model.set_nd_nd_interp()


This populates the model with an interpolator.
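For intuition about what an ND-ND interpolator does, here is a generic multilinear interpolator on a regular grid in plain Python. It is a hypothetical sketch, not the scheme `set_nd_nd_interp` actually uses: the output at a query point is a weighted average of the 2^N corners of the enclosing grid cell. It needs at least two samples along each axis, and points outside the grid are linearly extrapolated from the edge cell.

```python
import itertools
from bisect import bisect_right
from typing import Dict, List, Sequence, Tuple


def multilinear_interp(
    axes: Sequence[Sequence[float]],
    values: Dict[Tuple[int, ...], float],
    point: Sequence[float],
) -> float:
    """Multilinear interpolation on a rectilinear grid.

    axes[d] holds the sorted sample coordinates along dimension d (at least 2 each);
    values maps grid-index tuples to the sampled output.
    """
    lower: List[int] = []
    frac: List[float] = []
    for coords, x in zip(axes, point):
        # Lower corner of the enclosing cell, clamped so out-of-range points use the edge cell.
        i = min(max(bisect_right(coords, x) - 1, 0), len(coords) - 2)
        lower.append(i)
        frac.append((x - coords[i]) / (coords[i + 1] - coords[i]))
    result = 0.0
    # Weighted sum over the 2**N corners of the cell.
    for corner in itertools.product((0, 1), repeat=len(axes)):
        weight = 1.0
        for c, t in zip(corner, frac):
            weight *= t if c else 1.0 - t
        result += weight * values[tuple(i + c for i, c in zip(lower, corner))]
    return result
```

Multilinear interpolation reproduces functions that are linear in each coordinate exactly, which makes a convenient sanity check.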

## Model inference

The interpolator can then be used to construct the S-parameters within the `trainable_parameters` range:

In [9]:
params_dict = {
    "width": 0.5,
    "wavelength": 1.55,
    "core_thickness": 0.22,
    "length": 10,
    "loss": 1,
}

print(rib_waveguide_model.sdict(params_dict))

{('o1@0', 'o2@0'): Array(-0.13705602+0.2849836j, dtype=complex64), ('o1@1', 'o2@1'): Array(-0.28867194-0.1291066j, dtype=complex64), ('o1@2', 'o2@2'): Array(-0.28186253+0.14336498j, dtype=complex64), ('o1@3', 'o2@3'): Array(-0.18617052+0.25561795j, dtype=complex64), ('o2@0', 'o1@0'): Array(-0.13705602+0.2849836j, dtype=complex64), ('o2@1', 'o1@1'): Array(-0.28867194-0.1291066j, dtype=complex64), ('o2@2', 'o1@2'): Array(-0.28186253+0.14336498j, dtype=complex64), ('o2@3', 'o1@3'): Array(-0.18617052+0.25561795j, dtype=complex64)}
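The printed S-parameters all have magnitude ≈0.316 ≈ 10^(-10/20), which is consistent with a total loss of `loss × length` = 1 × 10 = 10 dB, i.e. with `loss` being expressed in dB per unit length (µm here). That reading of the units is inferred from the output, not taken from the gdsfactory source. Under that assumption, the reciprocal S-dictionary of a lossy multimode straight waveguide can be sketched as follows; `straight_sdict` is hypothetical, and the sign of the phase term is a convention we chose.

```python
import cmath
import math
from typing import Dict, Sequence, Tuple


def straight_sdict(
    neffs: Sequence[complex], wavelength: float, length: float, loss_db_per_length: float
) -> Dict[Tuple[str, str], complex]:
    """Reciprocal multimode S-dict for a straight waveguide (hypothetical helper).

    Assumes `length` and `wavelength` share units and that the loss is given in
    dB per unit length; the exp(+j*phase) sign convention is also an assumption.
    """
    amplitude = 10 ** (-loss_db_per_length * length / 20)  # dB (power) -> field amplitude
    sdict: Dict[Tuple[str, str], complex] = {}
    for m, neff in enumerate(neffs):
        # Only the real part of the effective index sets the phase in this sketch.
        phase = 2 * math.pi * neff.real * length / wavelength
        s21 = amplitude * cmath.exp(1j * phase)
        # Mode m couples o1@m <-> o2@m only; reciprocity gives S12 = S21.
        sdict[(f"o1@{m}", f"o2@{m}")] = s21
        sdict[(f"o2@{m}", f"o1@{m}")] = s21
    return sdict
```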


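The model can also be evaluated on arrays of parameter values. Note that the second and third widths (0.3 and 0.65) lie outside the 0.4 to 0.6 training range, so those entries are not backed by simulation data and should be treated with caution: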

In [10]:
params_dict = {
    "width": jnp.array([0.5, 0.3, 0.65]),
    "wavelength": jnp.array([1.55, 1.547, 1.55]),
    "core_thickness": jnp.array([0.22, 0.22, 0.21]),
    "length": jnp.ones(3) * 10,
    "loss": jnp.ones(3) * 1,
}

print(rib_waveguide_model.sdict(params_dict))

{('o1@0', 'o2@0'): Array([-0.13705602+0.2849836j , -0.05266529-0.31181145j,
        0.03626653-0.31414127j], dtype=complex64), ('o1@1', 'o2@1'): Array([-0.28867194-0.1291066j , -0.31618416-0.00525091j,
        0.25494823-0.18708663j], dtype=complex64), ('o1@2', 'o2@2'): Array([-0.28186253+0.14336498j, -0.30435038+0.08585367j,
       -0.28721702+0.13231172j], dtype=complex64), ('o1@3', 'o2@3'): Array([-0.18617052+0.25561795j, -0.09122577+0.30278352j,
       -0.2501116 +0.19350496j], dtype=complex64), ('o2@0', 'o1@0'): Array([-0.13705602+0.2849836j , -0.05266529-0.31181145j,
        0.03626653-0.31414127j], dtype=complex64), ('o2@1', 'o1@1'): Array([-0.28867194-0.1291066j , -0.31618416-0.00525091j,
        0.25494823-0.18708663j], dtype=complex64), ('o2@2', 'o1@2'): Array([-0.28186253+0.14336498j, -0.30435038+0.08585367j,
       -0.28721702+0.13231172j], dtype=complex64), ('o2@3', 'o1@3'): Array([-0.18617052+0.25561795j, -0.09122577+0.30278352j,
       -0.2501116 +0.19350496j], dtype=com
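Before trusting vectorized results, it can help to check which entries fall outside the trained bounds. A stdlib-only sketch; the helper `out_of_range` is hypothetical (not part of gdsfactory), the bounds are copied from `trainable_parameters`, and the values from the array example above.

```python
from typing import Dict, List, Sequence, Tuple


def out_of_range(
    params: Dict[str, Sequence[float]], bounds: Dict[str, Tuple[float, float]]
) -> Dict[str, List[int]]:
    """Return, per trainable parameter, the indices of entries outside [min, max]."""
    flagged: Dict[str, List[int]] = {}
    for name, (lo, hi) in bounds.items():
        bad = [i for i, v in enumerate(params[name]) if not lo <= v <= hi]
        if bad:
            flagged[name] = bad
    return flagged


bounds = {"width": (0.4, 0.6), "wavelength": (1.545, 1.555), "core_thickness": (0.21, 0.23)}
params = {
    "width": [0.5, 0.3, 0.65],
    "wavelength": [1.55, 1.547, 1.55],
    "core_thickness": [0.22, 0.22, 0.21],
}
print(out_of_range(params, bounds))  # {'width': [1, 2]}
```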

## Model validation

We can validate the intermediate input-output relationships by comparing the predictions to new simulations within the trainable parameter space:

In [11]:
%%capture
validation_inputs, calculated_outputs, inferred_outputs = rib_waveguide_model.validate(
    num_samples=3
)

2023-02-20 17:58:11.898 | INFO     | gdsfactory.simulation.gtidy3d:<module>:54 - Tidy3d '1.8.4' installed at ['/usr/share/miniconda/envs/anaconda-client-env/lib/python3.9/site-packages/tidy3d']


2023-02-20 17:58:19.763 | INFO     | gdsfactory.simulation.fem.mode_solver:compute_component_slice_modes:205 - Write mode to PosixPath('/home/runner/.gdsfactory/modes/straight_dd799dbb_d21258b8.npz')


2023-02-20 17:58:29.375 | INFO     | gdsfactory.simulation.fem.mode_solver:compute_component_slice_modes:205 - Write mode to PosixPath('/home/runner/.gdsfactory/modes/straight_c8106782_122693ac.npz')


2023-02-20 17:58:37.964 | INFO     | gdsfactory.simulation.fem.mode_solver:compute_component_slice_modes:205 - Write mode to PosixPath('/home/runner/.gdsfactory/modes/straight_742d5ba9_addb2c0c.npz')


In [12]:
validation_inputs

In [13]:
calculated_outputs

In [14]:
inferred_outputs

While the trend seems reasonable, the model above could benefit from more training samples or better-tuned simulation parameters.
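The comparison in the last three cells can be condensed into a single number. A hedged sketch: `max_relative_error` is a hypothetical helper, and restricting it to the first four columns assumes, as the printed output row suggests, that those hold the real parts of the effective indices; relative errors on the ~1e-8 imaginary parts would not be meaningful.

```python
from typing import Iterable, Sequence


def max_relative_error(
    calculated: Sequence[Sequence[float]],
    inferred: Sequence[Sequence[float]],
    columns: Iterable[int],
) -> float:
    """Largest |inferred - calculated| / |calculated| over all rows, restricted to `columns`."""
    if len(calculated) != len(inferred):
        raise ValueError("calculated and inferred must have the same number of rows")
    cols = list(columns)
    worst = 0.0
    for calc_row, inf_row in zip(calculated, inferred):
        for j in cols:
            worst = max(worst, abs(inf_row[j] - calc_row[j]) / abs(calc_row[j]))
    return worst
```

With the notebook variables this would be called as `max_relative_error(calculated_outputs, inferred_outputs, columns=range(4))`.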