Load the SPR45-v9 benchmark case and normalise its ECRH profile to [0, 1]:

In [None]:
import jetto_tools
import numpy as np  # needed for np.max below; previously only worked via leaked kernel state
import plotly.graph_objects as go

# NOTE(review): "templats" looks like a typo for "templates" — confirm the directory
# name on disk before changing this path.
template = jetto_tools.template.from_directory("../jetto/templats/spr45-v9")
exfile = jetto_tools.binary.read_binary_file(template.extra_files["jetto.ex"])

# Normalise ECRH to [0, 1] by dividing by its peak value.
# [0] selects the first entry of each signal (presumably the first time slice — TODO confirm).
benchmark_ecrh = exfile["QECE"][0] / np.max(exfile["QECE"][0])
xrho = exfile["XRHO"][0]

figure = go.Figure()
figure.add_trace(go.Scatter(x=xrho, y=benchmark_ecrh, name="SPR45-v9"))
figure.update_layout(xaxis_title="XRHO", yaxis_title="QECE / max(QECE)")
figure

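Load our ECRH parameterisations. `scipy.optimize.curve_fit` evaluates a model as `f(x, *params)`, so `sum_of_gaussians` (which takes separate list arguments) is wrapped in `sum_of_gaussians_2`, which accepts a flat parameter vector and uses a fixed width for every Gaussian: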

In [None]:
from jetto_mobo.ecrh import (
    piecewise_linear,
    piecewise_linear_2,
    sum_of_gaussians,
    cubic_spline,
)


def sum_of_gaussians_2(x, *p, width=0.85):
    """Adapt ``sum_of_gaussians`` to the flat ``f(x, *params)`` signature used by ``curve_fit``.

    The flat parameter vector ``p`` is split in half. The first half is passed as the
    argument after ``x`` (presumably the Gaussian centres) and the second half as the
    last argument (presumably the amplitudes) — TODO confirm against
    ``jetto_mobo.ecrh.sum_of_gaussians``. Every Gaussian receives the same fixed
    ``width``, which is therefore not a fitted parameter.

    Parameters
    ----------
    x : array-like
        Points at which to evaluate the profile.
    *p : float
        An even number of parameters: ``n`` first-half values followed by ``n``
        second-half values, for ``n`` Gaussians.
    width : float, optional
        Value passed for every Gaussian's third (width) argument. Defaults to 0.85.
        Keyword-only, so ``curve_fit`` never treats it as a fit parameter.

    Returns
    -------
    Whatever ``sum_of_gaussians`` returns for these arguments.

    Raises
    ------
    ValueError
        If an odd number of parameters is given, since the two halves would not pair up.
    """
    if len(p) % 2:
        raise ValueError(
            f"sum_of_gaussians_2 expects an even number of parameters, got {len(p)}"
        )
    n_gaussians = len(p) // 2
    # Exactly one width per Gaussian. The original passed len(p) widths, which is
    # twice as many as there are centres/amplitudes.
    return sum_of_gaussians(
        x, p[:n_gaussians], [width] * n_gaussians, p[n_gaussians:]
    )

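Fit each ECRH parameterisation to the normalised benchmark profile (all parameters bounded to [0, 1]) and overlay the fitted curves on the benchmark: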

In [None]:
from scipy.optimize import curve_fit

# Fitted parameter vectors, keyed by the model function.
parameters = {}
N_parameters = 12

# Initial guess shared by every fit. It does not change between iterations, so it is
# built once. The first half ramps from 0 to 0.5 and the second half from 1 down to 0.1.
initial_guess = [
    *np.linspace(0, 0.5, N_parameters // 2),
    *np.linspace(1, 0.1, N_parameters // 2),
]

# curve_fit evaluates each model as f(xdata, *params). `sum_of_gaussians` takes list
# arguments rather than a flat parameter vector, so it is fitted through the
# `sum_of_gaussians_2` adapter.
# NOTE: this appends traces to `figure`; re-run the benchmark cell first to avoid
# duplicate traces when re-executing this cell.
for f in [piecewise_linear, piecewise_linear_2, sum_of_gaussians_2, cubic_spline]:
    p, _ = curve_fit(
        f=f,
        xdata=xrho,
        ydata=benchmark_ecrh,
        p0=initial_guess,
        bounds=(0, 1),
        max_nfev=10_000,  # integer evaluation budget
    )
    parameters[f] = p

    figure.add_trace(
        go.Scatter(
            x=xrho,
            # Unpack p so the fitted model is evaluated exactly as curve_fit evaluated it.
            y=f(xrho, *p),
            name=f.__name__,
        )
    )

figure