In [None]:
from collections import defaultdict

import pandas as pd

from ho_stray_field import *
from ho_stray_field.bspline import BSpline
from ho_stray_field.tensor_grid import TensorGrid
from ho_stray_field.base import fit, factors_pinv
from ho_stray_field.tucker_tensor import TuckerTensor

from ho_stray_field.sources import flower_state, vortex_state

jax.config.update("jax_enable_x64", True)

# B-Spline fitting of a flower state

We want to fit a flower state in the unit cube $[-0.5, 0.5]^3$. The flower state is given by the normalized version of
\begin{align*}
    m_1(x_1, x_2, x_3) &= x_1x_3 \\
    m_2(x_1, x_2, x_3) &= x_2x_3 + \frac{1}{8} x_2^3x_3^3 \\
    m_3(x_1, x_2, x_3) &= 1.
\end{align*}

First, we need to define the B-Spline model:

In [2]:
r, k = 10, 5
model = BSpline(
    TensorGrid(*[jnp.linspace(-0.5, 0.5, r)] * 3),
    degree=k - 1
)

Fitting can be done by first defining a quadrature tensor grid and then call the `fit` method:

In [3]:
quad_tg = TensorGrid(*[jnp.array([-0.5, 0.5])] * 3).to_gauss(140)
flower_state_ft = model.fit(quad_tg, flower_state)

The error can be computed as follows

In [4]:
tg_val = TensorGrid(*([jnp.linspace(-0.5, 0.5, 200)] * 3))
mag_true = jnp.apply_along_axis(flower_state, -1, tg_val.grid)
mag_pred = flower_state_ft(tg_val)
err = jnp.max(jnp.abs(mag_true - mag_pred))
err

Array(2.15792975e-08, dtype=float64)

Combinig this, we can compute some statistics for different B-Spline models:

In [5]:
@partial(jax.jit, static_argnames=("f"))
def fit_function(f, model):
    quad_tg = TensorGrid(*[jnp.array([-0.5, 0.5])] * 3).to_gauss(140)
    return model.fit(quad_tg, f)


def error(f, ft_model):
    tg_val = TensorGrid(*([jnp.linspace(-0.5, 0.5, 200)] * 3))
    mag_true = jnp.apply_along_axis(f, -1, tg_val.grid)
    mag_pred = ft_model(tg_val)
    return jnp.max(jnp.abs(mag_true - mag_pred))


_table_flower = []

for k in [3, 4, 5, 6, 7]:
    for r in [10, 20, 40, 80]:
        model = BSpline(TensorGrid(*[jnp.linspace(-0.5, 0.5, r)] * 3), degree=k - 1)
        flower_state_ft = fit_function(flower_state, model)
        flower_state_ft.core.block_until_ready()  # compile
        err = error(flower_state, flower_state_ft)
        t = %timeit -q -o -n 10 -r 2 fit_function(flower_state, model).core.block_until_ready()
        _table_flower.append({"k": k, "r": r, "error": err, "fitting_time": t.average})
        print(f"k={k}, r={r}, error={err:.4e}, fitting_time={t.average * 1000:.3f} ms")


k=3, r=10, error=1.2594e-05, fitting_time=3.546 ms
k=3, r=20, error=1.3198e-06, fitting_time=4.500 ms
k=3, r=40, error=1.5213e-07, fitting_time=7.788 ms
k=3, r=80, error=2.2988e-08, fitting_time=13.375 ms
k=4, r=10, error=4.4098e-07, fitting_time=3.220 ms
k=4, r=20, error=2.1672e-08, fitting_time=4.032 ms
k=4, r=40, error=1.2835e-09, fitting_time=7.781 ms
k=4, r=80, error=1.2482e-10, fitting_time=13.508 ms
k=5, r=10, error=2.1579e-08, fitting_time=3.344 ms
k=5, r=20, error=4.9088e-10, fitting_time=4.079 ms
k=5, r=40, error=1.3661e-11, fitting_time=8.896 ms
k=5, r=80, error=6.6766e-13, fitting_time=14.345 ms
k=6, r=10, error=1.7179e-09, fitting_time=3.403 ms
k=6, r=20, error=1.7146e-11, fitting_time=4.329 ms
k=6, r=40, error=2.3614e-13, fitting_time=7.764 ms
k=6, r=80, error=1.4655e-14, fitting_time=14.131 ms
k=7, r=10, error=1.4070e-10, fitting_time=4.343 ms
k=7, r=20, error=5.8442e-13, fitting_time=5.218 ms
k=7, r=40, error=1.9873e-14, fitting_time=8.385 ms
k=7, r=80, error=1.5876e-14

# B-Spline fitting of a vortex state

Here we do the same for a vortex state given by
\begin{align*}
    m_1(x_1, x_2, x_3) &= -\frac{x_2}{r} \sqrt{1 - \exp(-4 \frac{r^2}{r_c^2})} \\
    m_2(x_1, x_2, x_3) &= \frac{x_1}{r} \sqrt{1 - \exp(-4 \frac{r^2}{r_c^2})} \\
    m_3(x_1, x_2, x_3) &= \exp(-2 \frac{r^2}{r_c^2}),
\end{align*}
where $r = \sqrt{x_1^2 + x_2^2 }$ and $r_c = 0.14$.

In [6]:
_table_vortex = []

for k in [3, 4, 5, 6, 7]:
    for r in [10, 20, 40, 80]:
        model = BSpline(TensorGrid(*[jnp.linspace(-0.5, 0.5, r)] * 3), degree=k - 1)
        vortex_state_ft = fit_function(vortex_state, model)
        vortex_state_ft.core.block_until_ready()  # compile
        err = error(vortex_state, vortex_state_ft)
        t = %timeit -q -o -n 10 -r 2 fit_function(vortex_state, model).core.block_until_ready()
        _table_vortex.append({"k": k, "r": r, "error": err, "fitting_time": t.average})
        print(f"k={k}, r={r}, error={err:.4e}, fitting_time={t.average * 1000:.3f} ms")


k=3, r=10, error=1.0676e-01, fitting_time=2.935 ms
k=3, r=20, error=1.1426e-02, fitting_time=4.025 ms
k=3, r=40, error=9.7996e-04, fitting_time=7.451 ms
k=3, r=80, error=1.5599e-04, fitting_time=13.544 ms
k=4, r=10, error=1.9049e-01, fitting_time=3.409 ms
k=4, r=20, error=5.0380e-03, fitting_time=4.065 ms
k=4, r=40, error=1.6091e-04, fitting_time=7.985 ms
k=4, r=80, error=1.2096e-05, fitting_time=13.589 ms
k=5, r=10, error=7.2920e-02, fitting_time=3.471 ms
k=5, r=20, error=1.5826e-03, fitting_time=4.103 ms
k=5, r=40, error=1.7630e-05, fitting_time=8.425 ms
k=5, r=80, error=6.4886e-07, fitting_time=14.398 ms
k=6, r=10, error=1.2999e-01, fitting_time=3.536 ms
k=6, r=20, error=1.1279e-03, fitting_time=4.353 ms
k=6, r=40, error=3.2265e-06, fitting_time=7.901 ms
k=6, r=80, error=5.2026e-08, fitting_time=14.056 ms
k=7, r=10, error=5.1472e-02, fitting_time=3.517 ms
k=7, r=20, error=4.4722e-04, fitting_time=4.537 ms
k=7, r=40, error=5.6762e-07, fitting_time=8.163 ms
k=7, r=80, error=4.0412e-09

In [25]:
t1 = {k: [d[k] for d in _table_flower] for k in _table_flower[0]}
t2 = {k: [d[k] for d in _table_vortex] for k in _table_vortex[0]}
table = pd.DataFrame({"k": t1["k"], "r": t1["r"], "error flower": t1["error"], "error vortex": t2["error"], "time": t1["fitting_time"]})
print(table.to_latex(index=False))

\begin{tabular}{rrllr}
\toprule
k & r & error flower & error vortex & time \\
\midrule
2 & 10 & 0.0007042714464582378 & 0.33748728865236477 & 0.003416 \\
2 & 20 & 0.00015670338838780218 & 0.06348384597516266 & 0.004835 \\
2 & 40 & 3.647712010312887e-05 & 0.02312379452548885 & 0.007969 \\
2 & 80 & 9.83694126499568e-06 & 0.005068870315851459 & 0.013717 \\
3 & 10 & 1.259368880257572e-05 & 0.10676193729702843 & 0.002861 \\
3 & 20 & 1.3198077537524e-06 & 0.011425764217281809 & 0.003994 \\
3 & 40 & 1.5213209114683934e-07 & 0.0009799573758897395 & 0.007459 \\
3 & 80 & 2.2987904291921213e-08 & 0.00015598964385232783 & 0.013469 \\
4 & 10 & 4.409771863223e-07 & 0.19049143361180731 & 0.003400 \\
4 & 20 & 2.1671801020417547e-08 & 0.00503801530264314 & 0.004049 \\
4 & 40 & 1.2835442708691858e-09 & 0.00016090689460712682 & 0.007865 \\
4 & 80 & 1.2482059830176695e-10 & 1.209617161257448e-05 & 0.013501 \\
5 & 10 & 2.1579297460050384e-08 & 0.07291995110239857 & 0.003462 \\
5 & 20 & 4.908814466020317e-1