In [53]:
import re
import jax
import numpy as np
from qsc import Qsc
import jax.numpy as jnp
from flax import serialization
from scipy.optimize import root
from train_inn import DeepNN as i_DeepNN, number_of_x_parameters as i_number_of_x_parameters, model_save_path as i_model_save_path
from train_nn import DeepNN as f_DeepNN, number_of_x_parameters as f_number_of_x_parameters, model_save_path as f_model_save_path
i_nfp = int(re.search('nfp(\d+)', i_model_save_path).group(1))
f_nfp = int(re.search('nfp(\d+)', f_model_save_path).group(1))
print(f'nfp_inverse_solver = {i_nfp} and nfp_forward_solver = {f_nfp}')
# Load i_NN
i_model = i_DeepNN()
i_dummy_input = jnp.ones((1, i_number_of_x_parameters))
i_init_params = i_model.init(jax.random.PRNGKey(0), i_dummy_input)
with open(i_model_save_path, 'rb') as f:
    i_bytes_params = f.read()
i_params = serialization.from_bytes(i_init_params, i_bytes_params)
# Load f_NN
f_model = f_DeepNN()
f_dummy_input = jnp.ones((1, f_number_of_x_parameters))
f_init_params = f_model.init(jax.random.PRNGKey(0), f_dummy_input)
with open(f_model_save_path, 'rb') as f:
    f_bytes_params = f.read()
f_params = serialization.from_bytes(f_init_params, f_bytes_params)

nfp_inverse_solver = 2 and nfp_forward_solver = 2


In [63]:
# Test iNN
iota=0.1
elongation = 3.1
maxiLgradB = 1.2
wanted_stel_geometry = [iota, elongation, maxiLgradB]
%timeit predicted = i_model.apply(i_params, wanted_stel_geometry)
def objective(x):
    stel = Qsc(rc=[1,-x[0]], zs=[0, x[1]], nfp=i_nfp, etabar=x[2], nphi=51)
    result = np.array([stel.iota, float(np.max(stel.elongation)), np.max(stel.inv_L_grad_B)])
    return (result - wanted_stel_geometry)**2
%timeit true = root(objective, x0=predicted).x
print(f'predicted = {predicted}')
print(f'true = {true}')


5.61 ms ± 26.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
503 ms ± 9.72 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
predicted = [0.07039151 0.07435976 1.3126286 ]
true = [0.06373385 0.08085611 1.32867536]


In [49]:
# Test NN
rc = 0.06
zs = 0.1
etabar = 0.5
%timeit predicted_stel_geometry = f_model.apply(f_params, [rc, zs, etabar])
%timeit stel = Qsc(rc=[1,-rc], zs=[0, zs], nfp=f_nfp, etabar=etabar, nphi=51)
true_stel_geometry = [stel.iota, float(np.max(stel.elongation)), np.max(stel.inv_L_grad_B)]
print(f'predicted = {predicted_stel_geometry}')
print(f'true = {true_stel_geometry}')

5.61 ms ± 20 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
5.33 ms ± 176 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
predicted = [0.0770195 5.010864  1.3075962]
true = [0.07573669243475663, 5.101711383433303, 1.3051728263887241]
