Load the SPR45-v9 benchmark case and plot its ECRH profile, normalised so that its peak is 1:

In [1]:
import jetto_tools
import plotly.graph_objects as go
import numpy as np

# Read the benchmark JETTO template and the binary `jetto.ex` output shipped with it.
template = jetto_tools.template.from_directory("../jetto/templates/spr45-v9")
exfile = jetto_tools.binary.read_binary_file(template.extra_files["jetto.ex"])

# Radial grid and ECRH profile, both taken at index 0 of the stored arrays
# (presumably the first time slice — TODO confirm against jetto_tools docs).
xrho = exfile["XRHO"][0]
benchmark_ecrh = exfile["QECE"][0]
# Normalise ECRH by dividing by its peak value, so the maximum becomes 1.
benchmark_ecrh = benchmark_ecrh / np.max(benchmark_ecrh)

# Comparison figure; later cells add fitted parametrisations to it.
figure = go.Figure(
    data=[go.Scatter(x=xrho, y=benchmark_ecrh, name="SPR45-v9")]
)



Load the candidate ECRH parametrisations and record how many free parameters each one is fitted with:

In [2]:
from jetto_mobo.ecrh import (
    piecewise_linear,
    piecewise_linear_2,
    sum_of_gaussians,
    cubic_spline,
)

def sum_of_gaussians_2(x, p, width=0.0025):
    """Evaluate ``sum_of_gaussians`` from a single flat parameter vector.

    ``p`` is split in half: the first half is passed as the first list argument
    of ``sum_of_gaussians`` and the second half as its last list argument. Every
    Gaussian receives the same fixed ``width`` as its middle argument. The flat
    layout lets ``scipy.optimize.curve_fit`` vary a single parameter vector.

    NOTE(review): from the argument order and the initial guesses used when
    fitting, the halves are presumably centres and amplitudes, and ``width``
    presumably a width or variance. Confirm against
    ``jetto_mobo.ecrh.sum_of_gaussians``.

    Parameters
    ----------
    x : array_like
        Points at which to evaluate the profile (the notebook passes ``xrho``).
    p : sequence of float
        Flat parameter vector of even length ``2 * n``.
    width : float, optional
        Value shared by all ``n`` Gaussians (default 0.0025, the value that was
        previously hard-coded).

    Returns
    -------
    The result of ``sum_of_gaussians`` for these inputs.

    Raises
    ------
    ValueError
        If ``p`` has odd length and cannot be split into two equal halves.
    """
    if len(p) % 2:
        raise ValueError(
            f"sum_of_gaussians_2 needs an even-length parameter vector, got {len(p)}"
        )
    half = len(p) // 2
    return sum_of_gaussians(x, p[:half], [width] * half, p[half:])

# Number of free parameters each ECRH parametrisation is fitted / sampled with.
# Insertion order is preserved, so it is also the order in which traces are
# added to the figures.
f_dict = dict(
    [
        (piecewise_linear, 12),
        (piecewise_linear_2, 12),
        (sum_of_gaussians_2, 8),
        (cubic_spline, 10),
    ]
)

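Fit each parametrisation to the normalised benchmark ECRH profile with `scipy.optimize.curve_fit`, and overlay the fitted curves on the benchmark: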

In [3]:
from scipy.optimize import curve_fit

# Fit every parametrisation in `f_dict` to the normalised benchmark profile and
# overlay the fitted curves on `figure`. Fitted parameter vectors are stored in
# `parameters`, keyed by function name.
parameters = {}

# `figure` is created in an earlier cell and extended here. Drop fitted traces
# left over from a previous run of this cell, so re-running it does not add
# duplicate curves.
fitted_names = {func.__name__ for func in f_dict}
figure.data = tuple(trace for trace in figure.data if trace.name not in fitted_names)

for f, n_params in f_dict.items():
    # Initial guess: the first half rises 0.2 -> 0.5, the second half falls 1 -> 0.1.
    # The second half takes the remainder, so the guess always has exactly
    # `n_params` entries, including when `n_params` is odd.
    n_first_half = n_params // 2
    initial_guess = [
        *np.linspace(0.2, 0.5, n_first_half),
        *np.linspace(1, 0.1, n_params - n_first_half),
    ]
    optimised_parameters, _ = curve_fit(
        # curve_fit calls model(x, *p), but our functions take the vector as
        # f(x, p). `f` is bound as a default so the model cannot pick up a later
        # value of the loop variable.
        f=lambda x, *p, _model=f: _model(x, p),
        xdata=xrho,
        ydata=benchmark_ecrh,
        p0=initial_guess,
        bounds=(0, 1),  # every parameter constrained to [0, 1]
        max_nfev=10000,  # integer evaluation budget, passed through to the solver
    )
    parameters[f.__name__] = optimised_parameters

    figure.add_trace(
        go.Scatter(
            x=xrho,
            y=f(xrho, optimised_parameters),
            name=f.__name__,
        )
    )

In [4]:
# Label the benchmark-vs-fits comparison so the figure is readable on its own.
figure.update_layout(
    title="ECRH parametrisations fitted to the SPR45-v9 benchmark",
    xaxis_title="XRHO",
    yaxis_title="Normalised ECRH (QECE / max QECE)",
)
figure.show()

In [9]:
# Evaluate each parametrisation at a random parameter vector drawn uniformly
# from [0, 1), the same range as the fit bounds. This presumably illustrates
# the shapes each one can produce. The generator is seeded so the plot
# reproduces under Restart & Run All.
sample_rng = np.random.default_rng(seed=0)
figure2 = go.Figure()

for f, n_params in f_dict.items():
    figure2.add_trace(
        go.Scatter(
            x=xrho,
            y=f(xrho, sample_rng.random(n_params)),
            name=f.__name__,
        )
    )

figure2.update_layout(
    title="ECRH parametrisations evaluated at random parameter vectors",
    xaxis_title="XRHO",
    yaxis_title="ECRH",
)
figure2.show()

In [6]:
# Five random parameter draws for a single parametrisation, to compare how much
# its profile varies between draws.
figure3 = go.Figure()

f = sum_of_gaussians_2
# Use the parameter count registered for this function. The bare `n_params`
# left over from the earlier loops over `f_dict` holds the count of the last
# entry (cubic_spline: 10), which would silently plot 5 Gaussians instead of 4.
n_gaussian_params = f_dict[f]
# Seeded so the plot reproduces under Restart & Run All.
sample_rng = np.random.default_rng(seed=0)

for i in range(5):
    figure3.add_trace(
        go.Scatter(
            x=xrho,
            y=f(xrho, sample_rng.random(n_gaussian_params)),
            name=str(i),
        )
    )

figure3.update_layout(
    title="sum_of_gaussians_2 evaluated at random parameter vectors",
    xaxis_title="XRHO",
    yaxis_title="ECRH",
)
figure3.show()