In [1]:
import matplotlib.pyplot as plt

from tpelm import *
from tpelm.bspline import BSpline
from tpelm.tensor_grid import TensorGrid
from tpelm.base import fit, factors_pinv
from tpelm.tucker_tensor import TuckerTensor

# Use double precision: the fit errors reported below go down to ~1e-13, far
# below float32 resolution. Set here, right after the imports and before any
# arrays are created.
# NOTE(review): `jax`, `jnp` and `partial` are not imported explicitly in this
# notebook — presumably re-exported by `from tpelm import *`; consider importing
# them explicitly for readability.
jax.config.update("jax_enable_x64", True)

# B-Spline fitting of a flower state

First we want to fit a flower state in the unit cube $[-0.5, 0.5]^3$. The flower state is given by the normalized version of
\begin{align*}
    m_1(x_1, x_2, x_3) &= x_1x_3 \\
    m_2(x_1, x_2, x_3) &= x_2x_3 + \frac{1}{8} x_2^3x_3^3 \\
    m_3(x_1, x_2, x_3) &= 1.
\end{align*}

In [2]:
def flower_state(x):
    """Evaluate the normalized flower-state magnetization.

    The raw (unnormalized) field is
        m_1 = x_1 x_3 / a,
        m_2 = x_2 x_3 / c + x_2^3 x_3^3 / b^3,
        m_3 = 1,
    with a = c = 1 and b = 2 (so the cubic coefficient is 1/8). It is then
    rescaled to unit Euclidean length at every point.

    Args:
        x: array whose last axis holds the coordinates (x_1, x_2, x_3); any
            leading batch axes are carried through unchanged.

    Returns:
        Array of shape ``x.shape[:-1] + (3,)`` with unit norm along the last axis.
    """
    a, b, c = 1, 2, 1
    x1, x2, x3 = x[..., 0], x[..., 1], x[..., 2]
    m1 = 1 / a * x1 * x3
    m2 = 1 / c * x2 * x3 + 1 / b**3 * x2**3 * x3**3
    # m_3 == 1 keeps the raw vector away from zero, so the division below is safe.
    raw = jnp.stack([m1, m2, jnp.ones_like(m1)], axis=-1)
    return raw / jnp.linalg.norm(raw, axis=-1, keepdims=True)

Next, we need to define the B-Spline model:

In [3]:
k = 5   # spline order; the polynomial degree passed to BSpline is k - 1
r = 10  # equispaced grid points per axis on [-0.5, 0.5] (presumably the knots)
model = BSpline(
    TensorGrid(*(jnp.linspace(-0.5, 0.5, r),) * 3),
    degree=k - 1,
)

Fitting is done by first defining a quadrature tensor grid and then calling the model's `fit` method:

In [4]:
# Quadrature grid on [-0.5, 0.5]^3: the interval endpoints per axis, refined to a
# Gauss rule with parameter 140 (presumably 140 nodes per axis — see tpelm docs).
quad_tg = TensorGrid(*(jnp.array([-0.5, 0.5]),) * 3).to_gauss(140)
# Fit the B-spline model to the exact flower state using these quadrature nodes.
flower_state_ft = model.fit(quad_tg, flower_state)

The maximum pointwise error, measured on a validation grid with 200 equispaced points per axis, can be computed as follows:

In [5]:
# Validation grid: 200 equispaced points per axis over [-0.5, 0.5]^3.
tg_val = TensorGrid(*(jnp.linspace(-0.5, 0.5, 200),) * 3)
mag_true = jnp.apply_along_axis(flower_state, -1, tg_val.grid)  # exact field
mag_pred = flower_state_ft(tg_val)  # B-spline approximation
# Maximum absolute deviation over all grid points and components.
err = jnp.abs(mag_true - mag_pred).max()
err

Array(2.15792957e-08, dtype=float64)

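Combining these steps, we can compute the maximum error and the fitting time for different B-Spline models, varying the spline order $k$ (polynomial degree $k-1$) and the number $r$ of grid points per axis: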

In [6]:
# `static_argnames` takes a tuple of names: `("f")` would be the plain string "f".
# `n_quad` is static because it determines the quadrature grid (array shapes).
@partial(jax.jit, static_argnames=("f", "n_quad"))
def fit_function(f, model, n_quad=140):
    """Fit ``model`` to the callable ``f`` on [-0.5, 0.5]^3, jit-compiled.

    The quadrature grid is built inside the traced function from the interval
    [-0.5, 0.5] on each of the three axes via ``TensorGrid.to_gauss(n_quad)``.

    Args:
        f: target function handed to ``model.fit``. Static under ``jax.jit``, so
            it must be hashable and each distinct ``f`` compiles separately.
        model: the spline model to fit; passed as a traced (non-static)
            argument. Presumably each new model configuration (r, k) triggers
            a recompile — TODO confirm.
        n_quad: Gauss-rule parameter for ``to_gauss``. Defaults to 140, the value
            used throughout this notebook.

    Returns:
        The result of ``model.fit`` — in this notebook a fitted model that is
        callable on a ``TensorGrid`` and exposes ``.core``.
    """
    quad_tg = TensorGrid(*[jnp.array([-0.5, 0.5])] * 3).to_gauss(n_quad)
    return model.fit(quad_tg, f)


def error(f, ft_model):
    """Return the maximum absolute deviation between ``f`` and ``ft_model``.

    Both are evaluated on a tensor grid of 200 equispaced points per axis over
    [-0.5, 0.5]^3. ``f`` is applied along the last (coordinate) axis of the grid
    array, while ``ft_model`` is called on the ``TensorGrid`` object itself.
    """
    nodes = jnp.linspace(-0.5, 0.5, 200)
    val_grid = TensorGrid(nodes, nodes, nodes)
    reference = jnp.apply_along_axis(f, -1, val_grid.grid)
    deviation = jnp.abs(reference - ft_model(val_grid))
    return deviation.max()

    
# Sweep spline order k (degree k - 1) and grid size r; for each model print the
# max-abs error on the validation grid and the mean fitting time in ms.
# NOTE: this rebinds the notebook globals `k`, `r`, `model`, `flower_state_ft`,
# `err` and `t` that earlier cells defined.
for k in [2, 3, 4, 5, 6, 7]:
    for r in [10, 20, 40, 80]:
        model = BSpline(TensorGrid(*[jnp.linspace(-0.5, 0.5, r)] * 3), degree=k - 1)
        # Warm-up call: jit compilation happens here, outside the timed runs.
        flower_state_ft = fit_function(flower_state, model)
        flower_state_ft.core.block_until_ready()  # compile
        err = error(flower_state, flower_state_ft)
        # %timeit: -o returns the TimeitResult, -q silences its printout,
        # -n 10 loops per run, -r 2 runs. block_until_ready() waits for the
        # result so the timing covers the computation, not just dispatch.
        t = %timeit -q -o -n 10 -r 2 fit_function(flower_state, model).core.block_until_ready()
        print(f"k={k}, r={r}, error={err:.4e}, fitting_time={t.average * 1000:.3f} ms")

k=2, r=10, error=7.0427e-04, fitting_time=8.703 ms
k=2, r=20, error=1.5670e-04, fitting_time=13.157 ms
k=2, r=40, error=3.6477e-05, fitting_time=31.677 ms
k=2, r=80, error=9.8369e-06, fitting_time=75.159 ms
k=3, r=10, error=1.2594e-05, fitting_time=7.586 ms
k=3, r=20, error=1.3198e-06, fitting_time=13.632 ms
k=3, r=40, error=1.5213e-07, fitting_time=33.524 ms
k=3, r=80, error=2.2988e-08, fitting_time=72.892 ms
k=4, r=10, error=4.4098e-07, fitting_time=11.056 ms
k=4, r=20, error=2.1672e-08, fitting_time=15.192 ms
k=4, r=40, error=1.2836e-09, fitting_time=36.254 ms
k=4, r=80, error=1.2483e-10, fitting_time=74.003 ms
k=5, r=10, error=2.1579e-08, fitting_time=8.738 ms
k=5, r=20, error=4.9088e-10, fitting_time=16.566 ms
k=5, r=40, error=1.3663e-11, fitting_time=40.674 ms
k=5, r=80, error=6.6765e-13, fitting_time=78.948 ms
k=6, r=10, error=1.7179e-09, fitting_time=11.325 ms
k=6, r=20, error=1.7144e-11, fitting_time=16.750 ms
k=6, r=40, error=2.4014e-13, fitting_time=36.828 ms
k=6, r=80, erro

# B-Spline fitting of a vortex state

Here we do the same for a vortex state given by
\begin{align*}
    m_1(x_1, x_2, x_3) &= -\frac{x_2}{r} \sqrt{1 - \exp(-4 \frac{r^2}{r_c^2})} \\
    m_2(x_1, x_2, x_3) &= \frac{x_1}{r} \sqrt{1 - \exp(-4 \frac{r^2}{r_c^2})} \\
    m_3(x_1, x_2, x_3) &= \exp(-2 \frac{r^2}{r_c^2}),
\end{align*}
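where $r = \sqrt{x_1^2 + x_2^2}$ is the distance from the $x_3$-axis (not to be confused with the grid size $r$ used in the benchmark loops) and $r_c = 0.14$ controls the size of the vortex core. Since $m_1^2 + m_2^2 + m_3^2 = \left(1 - e^{-4 r^2/r_c^2}\right) + e^{-4 r^2/r_c^2} = 1$, this state is already normalized.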

In [7]:
def vortex_state(x):
    """Evaluate the vortex-state magnetization with r_c = 0.14.

    With r = sqrt(x_1^2 + x_2^2) and s = r^2 / r_c^2:
        m_1 = -(x_2 / r) * sqrt(1 - exp(-4 s)),
        m_2 =  (x_1 / r) * sqrt(1 - exp(-4 s)),
        m_3 =  exp(-2 s).
    Since (1 - exp(-4 s)) + exp(-2 s)^2 = 1, the field is already unit length
    (outside the tiny r < 1e-9 guard region), so no normalization is applied.

    Args:
        x: array whose last axis holds (x_1, x_2, x_3); x_3 is not used.
            Leading batch axes are carried through unchanged.

    Returns:
        Array of shape ``x.shape[:-1] + (3,)``.
    """
    core_radius = 0.14
    radius = jnp.sqrt(x[..., 0] ** 2 + x[..., 1] ** 2)
    s = radius**2 / core_radius**2
    amplitude = jnp.sqrt(1 - jnp.exp(-4 * s))
    # On the x_3-axis, x_i / r is 0/0; pin the in-plane direction to 0 there
    # (the amplitude is 0 at r = 0 anyway). radius >= 0, so no abs() is needed.
    on_axis = radius < 1e-9
    cos_phi = jnp.where(on_axis, 0.0, x[..., 0] / radius)
    sin_phi = jnp.where(on_axis, 0.0, x[..., 1] / radius)
    return jnp.stack(
        [-(sin_phi * amplitude), cos_phi * amplitude, jnp.exp(-2 * s)],
        axis=-1,
    )


In [8]:
# Same sweep as for the flower state: spline order k (degree k - 1) and grid
# size r; print max-abs validation error and mean fitting time in ms.
# NOTE: this rebinds the notebook globals `k`, `r`, `model`, `vortex_state_ft`,
# `err` and `t`.
for k in [2, 3, 4, 5, 6, 7]:
    for r in [10, 20, 40, 80]:
        model = BSpline(TensorGrid(*[jnp.linspace(-0.5, 0.5, r)] * 3), degree=k - 1)
        # Warm-up call: jit compilation happens here, outside the timed runs.
        vortex_state_ft = fit_function(vortex_state, model)
        vortex_state_ft.core.block_until_ready()  # compile
        err = error(vortex_state, vortex_state_ft)
        # %timeit: -o returns the TimeitResult, -q silences its printout,
        # -n 10 loops per run, -r 2 runs. block_until_ready() waits for the
        # result so the timing covers the computation, not just dispatch.
        t = %timeit -q -o -n 10 -r 2 fit_function(vortex_state, model).core.block_until_ready()
        print(f"k={k}, r={r}, error={err:.4e}, fitting_time={t.average * 1000:.3f} ms")

k=2, r=10, error=3.3749e-01, fitting_time=8.498 ms
k=2, r=20, error=6.3484e-02, fitting_time=15.036 ms
k=2, r=40, error=2.3124e-02, fitting_time=33.686 ms
k=2, r=80, error=5.0689e-03, fitting_time=76.487 ms
k=3, r=10, error=1.0676e-01, fitting_time=9.086 ms
k=3, r=20, error=1.1426e-02, fitting_time=15.419 ms
k=3, r=40, error=9.7996e-04, fitting_time=34.358 ms
k=3, r=80, error=1.5599e-04, fitting_time=73.659 ms
k=4, r=10, error=1.9049e-01, fitting_time=9.873 ms
k=4, r=20, error=5.0380e-03, fitting_time=15.894 ms
k=4, r=40, error=1.6091e-04, fitting_time=38.011 ms
k=4, r=80, error=1.2096e-05, fitting_time=75.564 ms
k=5, r=10, error=7.2920e-02, fitting_time=10.506 ms
k=5, r=20, error=1.5826e-03, fitting_time=16.440 ms
k=5, r=40, error=1.7630e-05, fitting_time=44.598 ms
k=5, r=80, error=6.4886e-07, fitting_time=82.587 ms
k=6, r=10, error=1.2999e-01, fitting_time=14.059 ms
k=6, r=20, error=1.1279e-03, fitting_time=21.743 ms
k=6, r=40, error=3.2265e-06, fitting_time=43.125 ms
k=6, r=80, erro