# 10 - Spline spaces

In this tutorial we will learn some details about the tensor-product spline spaces used in Struphy.

## Uni-variate spline spaces

The theory is roughly explained [here](https://struphy.pages.mpcdf.de/struphy/sections/discretization.html#uni-variate-spline-spaces). Let us have a look at the Struphy API. We shall create 12 uni-variate spline spaces, namely 6 periodic and 6 clamped spaces (3 N-spaces and 3 D-spaces each) of different degrees and dimensions.

In [1]:
from mpi4py import MPI
from struphy.feec.psydac_derham import Derham

comm = MPI.COMM_WORLD

Nel = [16, 32, 64]  # Number of grid cells
p = [3, 3, 3]  # spline degrees

spl_kind = [True]*3  # periodic
derham_periodic = Derham(Nel, p, spl_kind, comm=comm)

spl_kind = [False]*3  # clamped
derham_clamped = Derham(Nel, p, spl_kind, comm=comm)

The 12 uni-variate spaces can be accessed through the FE spaces $V_0$ (N-splines) and $V_3$ (D-splines), respectively:

In [2]:
periodic_0s = derham_periodic.Vh_fem['0'].spaces
periodic_3s = derham_periodic.Vh_fem['3'].spaces
clamped_0s = derham_clamped.Vh_fem['0'].spaces
clamped_3s = derham_clamped.Vh_fem['3'].spaces
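Since $V_0$ and $V_3$ are tensor products of these uni-variate spaces, their dimensions are products of the 1D dimensions. The following plain-Python sketch computes them for the parameters above. It assumes a 1D dimension of `ncells` for periodic and `ncells + degree` for clamped spaces, and D-splines of degree `p - 1`; `nbasis_1d` is a hypothetical helper, not part of Struphy:

```python
from math import prod

def nbasis_1d(ncells, degree, periodic):
    """Hypothetical helper: dimension of a uni-variate spline space."""
    return ncells if periodic else ncells + degree

Nel = [16, 32, 64]  # number of grid cells per direction
p = [3, 3, 3]       # N-spline degrees per direction

dims = {}
for kind, periodic in [('periodic', True), ('clamped', False)]:
    # V_0: N-splines of degree p in every direction
    dims[(kind, 'V0')] = prod(nbasis_1d(n, d, periodic) for n, d in zip(Nel, p))
    # V_3: D-splines of degree p - 1 in every direction
    dims[(kind, 'V3')] = prod(nbasis_1d(n, d - 1, periodic) for n, d in zip(Nel, p))

for key, dim in dims.items():
    print(key, dim)
```

If the assumptions hold, each number is the product of the `nbasis` attributes of the corresponding three 1D spaces.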

Let us get a sense of the spline space attributes:

In [None]:
from struphy.tutorials.utilities import print_all_attr

print_all_attr(periodic_0s[0])

Here is a comparison of the space attributes:

In [None]:
import numpy as np

np.set_printoptions(precision=2)

for d, (periodic_0, periodic_3, clamped_0, clamped_3) in enumerate(zip(periodic_0s, periodic_3s, clamped_0s, clamped_3s)):
    print('-'*30)
    print(f'Direction {d + 1}:\n')
    print('The degree of the D-splines is always one lower than the degree of the N-splines:')
    print(f'{periodic_0.degree = }')
    print(f'{periodic_3.degree = }')
    print(f'{clamped_0.degree  = }')
    print(f'{clamped_3.degree  = }\n')
    
    print('The break points (i.e. grid points) are always the same:')
    print(f'{periodic_0.breaks = }')
    print(f'{periodic_3.breaks = }')
    print(f'{clamped_0.breaks  = }')
    print(f'{clamped_3.breaks  = }\n')
    
    print('The number of grid cells defined by the break points is always the same:')
    print(f'{periodic_0.ncells = }')
    print(f'{periodic_3.ncells = }')
    print(f'{clamped_0.ncells  = }')
    print(f'{clamped_3.ncells  = }\n')
    
    print('The basis type is indicated: "B" stands for N-spline and "M" stands for D-spline (for historical reasons):')
    print(f'{periodic_0.basis = }')
    print(f'{periodic_3.basis = }')
    print(f'{clamped_0.basis  = }')
    print(f'{clamped_3.basis  = }\n')
    
    print('The dimension of the spline space is `ncells` for periodic and `ncells + degree` for clamped spaces:')
    print(f'{periodic_0.nbasis = }')
    print(f'{periodic_3.nbasis = }')
    print(f'{clamped_0.nbasis  = }')
    print(f'{clamped_3.nbasis  = }\n')
    
    print('The knot sequences are the break points with `degree` additional knots on the left and on the right:')
    print(f'{periodic_0.knots = }')
    print(f'{periodic_3.knots = }')
    print(f'{clamped_0.knots  = }')
    print(f'{clamped_3.knots  = }\n')
    
    print('The Greville points are averages of `degree` consecutive knots, roughly the centers of the splines:')
    print(f'{periodic_0.greville = }')
    print(f'{periodic_3.greville = }')
    print(f'{clamped_0.greville  = }')
    print(f'{clamped_3.greville  = }\n')
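To see where the knot sequences and Greville points come from, here is a NumPy sketch that builds them by hand for a uniform grid. It assumes that periodic knots continue the uniform grid beyond $[0, 1]$ and that clamped knots repeat the end points; the function names are hypothetical, and the periodic Greville points may be treated differently in Struphy/Psydac, so only the clamped case is covered:

```python
import numpy as np

def uniform_knots(ncells, degree, periodic):
    """Break points of a uniform grid on [0, 1] plus `degree` extra knots on each side."""
    breaks = np.linspace(0., 1., ncells + 1)
    h = 1. / ncells
    if periodic:
        # assumption: continue the uniform grid beyond the end points
        left = -h * np.arange(degree, 0, -1)
        right = 1. + h * np.arange(1, degree + 1)
    else:
        # clamped: repeat the end points, so 0 and 1 appear degree + 1 times
        left = np.zeros(degree)
        right = np.ones(degree)
    return np.concatenate([left, breaks, right])

def greville_clamped(knots, degree):
    """Greville points: averages of `degree` consecutive knots (requires degree >= 1)."""
    nbasis = len(knots) - degree - 1
    return np.array([knots[i + 1:i + degree + 1].mean() for i in range(nbasis)])

knots_p = uniform_knots(4, 3, periodic=True)
knots_c = uniform_knots(4, 3, periodic=False)
print(knots_p)
print(knots_c)
print(greville_clamped(knots_c, 3))
```

For the clamped cubic space with 4 cells this gives $4 + 3 = 7$ Greville points, starting at 0 and ending at 1.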

Next we want to find out about the indexing of the splines. For this we can use the function `find_span`, which returns the **knot span index** at a given point $\eta \in [0, 1]$.
For a degree $p$, the knot span index $i$ identifies the indices $i-p, i-p+1,\ldots, i$ of all $p + 1$ non-zero basis functions at a given location $\eta \in [0, 1]$. 
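The knot span index is typically found by a binary search over the knots. The following pure-NumPy sketch follows the classic FindSpan algorithm (The NURBS Book, Algorithm A2.1); `find_span_sketch` is a hypothetical name, and we assume, as in Psydac, that the right end point $\eta = 1$ is assigned to the last non-empty span:

```python
import numpy as np

def find_span_sketch(knots, degree, x):
    """Return i with knots[i] <= x < knots[i + 1] for x in [0, 1].

    The right end point is assigned to the last non-empty span.
    """
    low = degree
    high = len(knots) - 1 - degree
    if x <= knots[low]:
        return low
    if x >= knots[high]:
        return high - 1
    # invariant: knots[low] <= x < knots[high]
    while high - low > 1:
        mid = (low + high) // 2
        if x < knots[mid]:
            high = mid
        else:
            low = mid
    return low

p = 3
knots = np.array([0., 0., 0., 0., 0.25, 0.5, 0.75, 1., 1., 1., 1.])  # clamped, 4 cells
spans = {eta: find_span_sketch(knots, p, eta) for eta in [0.0, 0.3, 0.5, 1.0]}
for eta, span in spans.items():
    print(f'{eta = }, {span = }, nonzero indices: {list(range(span - p, span + 1))}')
```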

It will be useful to print the indices of the non-zero basis functions at each break point:

In [5]:
from struphy.bsplines.bsplines import find_span

for d, (periodic_0, periodic_3, clamped_0, clamped_3) in enumerate(zip(periodic_0s, periodic_3s, clamped_0s, clamped_3s)):
    print('-'*30)
    print(f'Direction {d + 1}:\n')
    for eta in periodic_0.breaks:
        print(f'{eta = }:')
        # knot span index of eta in each space
        i_p_0 = find_span(periodic_0.knots, periodic_0.degree, eta)
        i_p_3 = find_span(periodic_3.knots, periodic_3.degree, eta)
        i_c_0 = find_span(clamped_0.knots, clamped_0.degree, eta)
        i_c_3 = find_span(clamped_3.knots, clamped_3.degree, eta)
        # indices i - p, ..., i of the basis functions that are nonzero on that span
        periodic_0_inds = np.arange(i_p_0 - periodic_0.degree, i_p_0 + 1)
        periodic_3_inds = np.arange(i_p_3 - periodic_3.degree, i_p_3 + 1)
        clamped_0_inds = np.arange(i_c_0 - clamped_0.degree, i_c_0 + 1)
        clamped_3_inds = np.arange(i_c_3 - clamped_3.degree, i_c_3 + 1)
        print(f'{periodic_0.degree = }, {periodic_0_inds = }')
        print(f'{periodic_3.degree = }, {periodic_3_inds = }')
        print(f'{clamped_0.degree  = }, {clamped_0_inds  = }')
        print(f'{clamped_3.degree  = }, {clamped_3_inds  = }')

We can deduce the following: let us denote the spline degree by $p$ and the number of grid cells by $n$. Moreover, let us denote the knot sequence by

$$
T = (\eta_0, \ldots, \eta_{p-1}, \eta_p = 0.0, \eta_{p+1}, \ldots, \eta_{p + n} = 1.0, \eta_{p + n + 1}, \ldots, \eta_{2p + n}) \,.
$$

Hence, the break points are $(\eta_p = 0.0, \ldots, \eta_{p + n} = 1.0)$. The **knot span index** of a break point $\eta_i$ with $p \leq i < p + n$ is $i$ (the right end point $\eta_{p + n} = 1.0$ is assigned to the last span, $i = p + n - 1$), and thus the basis splines that are nonzero on the span $[\eta_i, \eta_{i+1})$ are $N_j^p$ with $i - p \leq j \leq i$. In the periodic case the index $j$ has to be read as $j \bmod n$ to account for the circular nature of the indices. Moreover, the support of any basis spline $N_j^p$ is

$$
\operatorname{supp} N_j^p = [\eta_j, \eta_{j + p + 1})\,.
$$
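These index rules can be checked numerically with a self-contained Cox-de Boor evaluation. This is a NumPy sketch: `bspline_basis` is a hypothetical helper (not Struphy's `basis_funs`), and `np.searchsorted` plays the role of `find_span` for points strictly inside $(0, 1)$. For a clamped cubic space with $n = 4$ cells, the splines that are nonzero at a point of span $i$ should be exactly $N_{i-p}^p, \ldots, N_i^p$, and they should sum to one:

```python
import numpy as np

def bspline_basis(knots, degree, j, x):
    """N_j^degree at the points x via the Cox-de Boor recursion (terms with 0/0 are dropped)."""
    x = np.atleast_1d(np.asarray(x, dtype=float))
    if degree == 0:
        inside = (knots[j] <= x) & (x < knots[j + 1])
        # close the last non-empty interval so that x = 1 is covered
        if knots[j] < knots[j + 1] == knots[-1]:
            inside = inside | (x == knots[-1])
        return inside.astype(float)
    out = np.zeros_like(x)
    d1 = knots[j + degree] - knots[j]
    if d1 > 0:
        out += (x - knots[j]) / d1 * bspline_basis(knots, degree - 1, j, x)
    d2 = knots[j + degree + 1] - knots[j + 1]
    if d2 > 0:
        out += (knots[j + degree + 1] - x) / d2 * bspline_basis(knots, degree - 1, j + 1, x)
    return out

p, n = 3, 4
T = np.concatenate([np.zeros(p), np.linspace(0., 1., n + 1), np.ones(p)])  # clamped knots
N = n + p  # clamped dimension

results = []
for x in [0.1, 0.3, 0.6, 0.9]:
    # knot span index: T[span] <= x < T[span + 1]
    span = int(np.searchsorted(T, x, side='right')) - 1
    values = np.array([bspline_basis(T, p, j, x)[0] for j in range(N)])
    nonzero = np.flatnonzero(values > 0).tolist()
    results.append((x, span, nonzero, values.sum()))
    print(f'{x = }, {span = }, {nonzero = }, sum = {values.sum():.12f}')
```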

## Quasi interpolation

Assume $f \in C([0, 1])$ to be a continuous function on the unit interval. Moreover, let $\mathbb S^p_n$ denote the spline space of degree $p$ obtained from a uniform partition of $[0, 1]$ into $n$ cells, $\mathbb S^p_n = \operatorname{span}(N^p_i)_{i=0}^{N-1}$, where $N=n$ in the periodic and $N = n + p$ in the clamped case. Quasi interpolation is a projection $Q: C([0, 1]) \to \mathbb S^p_n$ defined by

$$
 Q f(\eta) = \sum_{i=0}^{N-1} \lambda_i f \, N^p_i(\eta)\,,
$$

where $\lambda_i$ are functionals $C([0, 1]) \to \mathbb R$ to be determined. In quasi interpolation the functionals $\lambda_i$ need to be "local", in the same sense as the basis splines $N_i^p$ have a local support $\operatorname{supp} N_i^p = [\eta_i, \eta_{i + p + 1})$: computing $\lambda_i f$ should only require values of $f$ in the close vicinity of this support. This is in contrast to global interpolation, where the coefficients $\lambda_i f$ cannot be computed separately but are obtained by inverting a global collocation matrix; this requires values of $f$ on the whole interval $[0, 1]$.
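For comparison, here is a NumPy sketch of the global approach: interpolation at the Greville points (one common choice of collocation points, assumed here) by solving with the full collocation matrix. The collocation matrix is banded, but its inverse is not, so every coefficient depends on values of $f$ all over $[0, 1]$. `bspline_basis` is a hypothetical Cox-de Boor helper, not Struphy's `basis_funs`:

```python
import numpy as np

def bspline_basis(knots, degree, j, x):
    """N_j^degree at the points x via the Cox-de Boor recursion (terms with 0/0 are dropped)."""
    x = np.atleast_1d(np.asarray(x, dtype=float))
    if degree == 0:
        inside = (knots[j] <= x) & (x < knots[j + 1])
        # close the last non-empty interval so that x = 1 is covered
        if knots[j] < knots[j + 1] == knots[-1]:
            inside = inside | (x == knots[-1])
        return inside.astype(float)
    out = np.zeros_like(x)
    d1 = knots[j + degree] - knots[j]
    if d1 > 0:
        out += (x - knots[j]) / d1 * bspline_basis(knots, degree - 1, j, x)
    d2 = knots[j + degree + 1] - knots[j + 1]
    if d2 > 0:
        out += (knots[j + degree + 1] - x) / d2 * bspline_basis(knots, degree - 1, j + 1, x)
    return out

p, n = 3, 8
T = np.concatenate([np.zeros(p), np.linspace(0., 1., n + 1), np.ones(p)])  # clamped knots
N = n + p
greville = np.array([T[i + 1:i + p + 1].mean() for i in range(N)])

# collocation matrix C[k, j] = N_j(g_k): at most p + 1 nonzeros per row
C = np.column_stack([bspline_basis(T, p, j, greville) for j in range(N)])

fun = lambda eta: np.sin(2 * np.pi * eta)
coeffs = np.linalg.solve(C, fun(greville))  # global interpolation

# the inverse couples every coefficient to (almost) every data value
C_inv = np.linalg.inv(C)
nnz_C = (np.abs(C) > 1e-12).sum(axis=1)
nnz_C_inv = (np.abs(C_inv) > 1e-12).sum(axis=1)
print('nonzeros per row of C:   ', nnz_C)
print('nonzeros per row of C^-1:', nnz_C_inv)
```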

In order to be useful, quasi interpolation should reproduce the spline basis,

$$
 Q N^p_i = N^p_i \qquad \forall \, i\,,
$$

and, for smooth $f$, it should have the optimal order of convergence,

$$
 \|Qf - f\| \sim (1/n)^{p+1}\,.
$$
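The simplest local choice, $\lambda_i f = f(g_i)$ at the Greville points $g_i$ (Schoenberg's variation-diminishing spline approximation), shows why both requirements matter. It reproduces linear functions, but not the basis splines for $p \geq 2$, and it only converges with second order. A NumPy sketch for clamped cubic splines; all helper names are hypothetical:

```python
import numpy as np

def bspline_basis(knots, degree, j, x):
    """N_j^degree at the points x via the Cox-de Boor recursion (terms with 0/0 are dropped)."""
    x = np.atleast_1d(np.asarray(x, dtype=float))
    if degree == 0:
        inside = (knots[j] <= x) & (x < knots[j + 1])
        # close the last non-empty interval so that x = 1 is covered
        if knots[j] < knots[j + 1] == knots[-1]:
            inside = inside | (x == knots[-1])
        return inside.astype(float)
    out = np.zeros_like(x)
    d1 = knots[j + degree] - knots[j]
    if d1 > 0:
        out += (x - knots[j]) / d1 * bspline_basis(knots, degree - 1, j, x)
    d2 = knots[j + degree + 1] - knots[j + 1]
    if d2 > 0:
        out += (knots[j + degree + 1] - x) / d2 * bspline_basis(knots, degree - 1, j + 1, x)
    return out

def clamped_knots(n, p):
    return np.concatenate([np.zeros(p), np.linspace(0., 1., n + 1), np.ones(p)])

def schoenberg(fun, n, p, x):
    """V f(x) = sum_i f(g_i) N_i(x), i.e. lambda_i f = f(g_i) at the Greville points g_i."""
    T = clamped_knots(n, p)
    g = np.array([T[i + 1:i + p + 1].mean() for i in range(n + p)])
    return sum(fun(g[i]) * bspline_basis(T, p, i, x) for i in range(n + p))

p = 3
x = np.linspace(0., 1., 201)

# linear functions are reproduced exactly ...
err_linear = np.max(np.abs(schoenberg(lambda eta: 2 * eta - 1, 16, p, x) - (2 * x - 1)))

# ... but a cubic basis spline is not, so Q N_i = N_i fails
T16 = clamped_knots(16, p)
N5 = lambda eta: bspline_basis(T16, p, 5, eta)
err_basis = np.max(np.abs(schoenberg(N5, 16, p, x) - N5(x)))

# for smooth f the error decays like (1/n)^2 instead of (1/n)^(p + 1)
f = lambda eta: np.sin(2 * np.pi * eta)
errors = [np.max(np.abs(schoenberg(f, n, p, x) - f(x))) for n in (16, 32)]
rate = np.log2(errors[0] / errors[1])
print(f'{err_linear = :.2e}, {err_basis = :.2e}, {rate = :.2f}')
```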

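Let us try out a concrete choice of the $\lambda_i$: for each $i$ we interpolate $f$ at a few equidistant points in a small interval inside $\operatorname{supp} N_i^p$ (a single cell in the periodic case, up to $p + 1$ cells in the clamped case). This leads to a small local collocation system, of which we keep only the coefficient belonging to $N_i^p$: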

In [6]:
import numpy as np
from struphy.bsplines.bsplines import basis_funs

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

# local projector for given spline space
def make_local_proj(space):
    N = space.nbasis
    ncells = space.ncells
    p = space.degree
    T = space.knots
    periodic = space.periodic
    basis = space.basis
    print(f"{N = }, {p = }, {periodic = }, {basis=='M' = }")
    
    def local_proj(fun):
        lambdas = np.zeros(N)
        for i in range(N):
            pts, nloc = _get_pts(i, p, T)
            start_span = find_span(T, p, pts[0])
            end_span = find_span(T, p, pts[-1])
            inds = np.arange(start_span - p, end_span + 1) # indices of nonzero basis functions on pts
            pos = np.argwhere(inds == i)[0, 0] # this is very important, as it tells us which of the computed lambdas will be taken as lambda_i
            Cmat = np.zeros((nloc, nloc))
            for j, pt in enumerate(pts):
                span = find_span(T, p, pt)
                sl = slice(span - start_span, span - start_span + p + 1)
                #print(f'{pt = }, {span = }, {sl = }')
                Cmat[j, sl] = basis_funs(T, p, pt, span, basis=='M')
            #print(f'{Cmat = }')
            rhs = fun(pts)
            lambda_jk = np.linalg.solve(Cmat, rhs)
            lambdas[i] = lambda_jk[pos]
        return lambdas
    
    # local evaluation points
    def _get_pts(i, p, T):
        if periodic:
            mu = i + p
            nu = i + p + 1
        else:
            if i < p:
                mu = p
                nu = p + 1 + i
            elif i > ncells - 1:
                diff = ncells + p - 1 - i
                mu = ncells + p - 1 - diff
                nu = ncells + p
            else:
                mu = i 
                nu = i + 1 + p
        nloc = p + nu - mu # number of local interpolation points (equal to number of nonzero basis funs in the interval)
        pts = np.linspace(T[mu], T[nu], nloc + 1)[:-1]
        #print(f'{pts = }')
        return pts, nloc
    
    return local_proj, N, ncells, p, T, basis
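The same construction can be checked independently of Struphy. Below is a NumPy sketch of the clamped branch of `make_local_proj`, using a hand-written Cox-de Boor recursion instead of `find_span`/`basis_funs` (assuming both give the same basis values); `local_proj_clamped` and `eval_spline` are hypothetical names. Since every polynomial of degree $\leq p$ lies in $\mathbb S^p_n$, reproduction of the spline space implies that such polynomials, and each basis spline, are recovered up to round-off:

```python
import numpy as np

def bspline_basis(knots, degree, j, x):
    """N_j^degree at the points x via the Cox-de Boor recursion (terms with 0/0 are dropped)."""
    x = np.atleast_1d(np.asarray(x, dtype=float))
    if degree == 0:
        inside = (knots[j] <= x) & (x < knots[j + 1])
        if knots[j] < knots[j + 1] == knots[-1]:
            inside = inside | (x == knots[-1])
        return inside.astype(float)
    out = np.zeros_like(x)
    d1 = knots[j + degree] - knots[j]
    if d1 > 0:
        out += (x - knots[j]) / d1 * bspline_basis(knots, degree - 1, j, x)
    d2 = knots[j + degree + 1] - knots[j + 1]
    if d2 > 0:
        out += (knots[j + degree + 1] - x) / d2 * bspline_basis(knots, degree - 1, j + 1, x)
    return out

def local_proj_clamped(fun, n, p):
    """Coefficients lambda_i f of the local projector on clamped N-splines (sketch)."""
    T = np.concatenate([np.zeros(p), np.linspace(0., 1., n + 1), np.ones(p)])
    N = n + p
    lambdas = np.zeros(N)
    for i in range(N):
        # local interval [T[mu], T[nu]], same choice as `_get_pts` in the clamped case
        if i < p:
            mu, nu = p, p + 1 + i
        elif i > n - 1:
            mu, nu = i, n + p
        else:
            mu, nu = i, i + p + 1
        a, b = T[mu], T[nu]
        # all basis functions that do not vanish on (a, b)
        J = [j for j in range(N) if T[j] < b and T[j + p + 1] > a]
        pts = np.linspace(a, b, len(J) + 1)[:-1]
        C = np.column_stack([bspline_basis(T, p, j, pts) for j in J])
        local_coeffs = np.linalg.solve(C, fun(pts))
        lambdas[i] = local_coeffs[J.index(i)]  # keep only the coefficient of N_i
    return T, lambdas

def eval_spline(T, p, coeffs, x):
    return sum(c * bspline_basis(T, p, j, x) for j, c in enumerate(coeffs))

n, p = 8, 3
cubic = lambda eta: eta**3 - 2 * eta + 0.5
T, lambdas = local_proj_clamped(cubic, n, p)
x = np.linspace(0., 1., 101)
err_cubic = np.max(np.abs(eval_spline(T, p, lambdas, x) - cubic(x)))

_, lam5 = local_proj_clamped(lambda eta: bspline_basis(T, p, 5, eta), n, p)
print(f'{err_cubic = :.2e}')
print(np.round(lam5, 12))
```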

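Let's test the projection of a sine function on the three clamped N-spline spaces ($n = 16, 32, 64$) and estimate the convergence rate from the two finest grids: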

In [None]:
from struphy.bsplines.evaluation_kernels_1d import evaluation_kernel_1d
from matplotlib import pyplot as plt

# function to project
fun = lambda eta: np.sin(2*np.pi*eta) 

# figure
plt.figure(figsize=(16, 4))

# convergence test
max_errors = []
for d, space in enumerate(clamped_0s):
    local_proj, N, ncells, p, T, basis = make_local_proj(space)
    normalize = basis=='M'

    # project
    lambdas = local_proj(fun)

    etas = np.linspace(0., 1., 100)
    fun_h = np.zeros(100)
    for k, eta in enumerate(etas):
        span = find_span(T, p, eta)
        ind1 = np.arange(span - p, span + 1) % N
        basis = basis_funs(T, p, eta, span, normalize=normalize)
        fun_h[k] = evaluation_kernel_1d(p, basis, ind1, lambdas)
    
    max_errors += [np.max(np.abs(fun(etas) - fun_h))]
    print(f'max error: {max_errors[-1]}')
    
    plt.subplot(1, 3, d + 1)
    plt.plot(etas, fun_h, 'o', label='fun_h')
    plt.plot(etas, fun(etas), label='fun')
    plt.title(f'{ncells = }')
    plt.legend()
    
print(f'\nConvergence rate: {np.log2(max_errors[-2]/max_errors[-1])}, expected: {float(p + 1)}')

Let's check that the basis functions are exactly reproduced by `local_proj`:

In [None]:
local_proj, N, ncells, p, T, basis = make_local_proj(clamped_0s[0])
normalize = basis=='M'

def make_basis_fun(i):
    def fun(etas):
        # accept scalars and arrays; use float dtype so that basis values are not truncated to integers
        etas = np.atleast_1d(np.asarray(etas, dtype=float))
        out = np.zeros_like(etas)
        for j, eta in enumerate(etas):
            span = find_span(T, p, eta)
            inds = np.arange(span - p, span + 1) % N
            pos = np.argwhere(inds == i)
            #print(f'{pos = }')
            if pos.size > 0:    
                pos = pos[0, 0]
                out[j] = basis_funs(T, p, eta, span, normalize=normalize)[pos]
            else:
                out[j] = 0.
        return out
    return fun

In [None]:
plt.figure(figsize=(16, (N//4 + 1)*4))

for j in range(N):
    fun = make_basis_fun(j)
    lambdas = local_proj(fun)

    etas = np.linspace(0., 1., 100)
    fun_h = np.zeros(100)
    for k, eta in enumerate(etas):
        span = find_span(T, p, eta)
        ind1 = np.arange(span - p, span + 1) % N
        basis = basis_funs(T, p, eta, span, normalize=normalize)
        fun_h[k] = evaluation_kernel_1d(p, basis, ind1, lambdas)
        
    print(f'{j = }, max error: {np.max(np.abs(fun(etas) - fun_h))}')
    
    plt.subplot(N//4 + 1, 4, j + 1)    
    plt.plot(etas, fun_h, 'o', label='fun_h')
    plt.plot(etas, fun(etas), label='fun')
    plt.title(f'Basis function {j = }')
    plt.legend()

Finally, we check the sparsity pattern of the following basis projection operator, where $\hat \Pi^{loc}_i$ denotes the local functional $\lambda_i$ computed by `local_proj`:

$$
 A_{ij} = \hat \Pi^{loc}_i\left( \sin(2\pi\eta) \, N_j^p \right)
$$
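For the clamped projector, $\lambda_i$ only samples $f$ inside $\operatorname{supp} N_i^p$, so $A_{ij}$ can only be nonzero when the supports of $N_i^p$ and $N_j^p$ overlap, i.e. $|i - j| \leq p$. A minimal NumPy sketch of this expected band structure (`expected_pattern` is a hypothetical helper), for the first clamped space with 16 cells and degree 3, hence $N = 19$:

```python
import numpy as np

def expected_pattern(N, p):
    """True where A_ij may be nonzero: supports of N_i and N_j overlap, |i - j| <= p."""
    i, j = np.indices((N, N))
    return np.abs(i - j) <= p

mask = expected_pattern(N=19, p=3)
nnz_per_row = mask.sum(axis=1)
print(nnz_per_row)
print('total nonzeros:', mask.sum(), 'of', mask.size)
```

The `i_inds` printed for basis function `j` should then be a subset of the `True` entries in column `j` of the mask.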

In [None]:
plt.figure(figsize=(16, (N//4 + 1)*4))

for j in range(N):
    fun = lambda eta: make_basis_fun(j)(eta) * np.sin(2*np.pi*eta)
    lambdas = local_proj(fun)

    etas = np.linspace(0., 1., 100)
    fun_h = np.zeros(100)
    for k, eta in enumerate(etas):
        span = find_span(T, p, eta)
        ind1 = np.arange(span - p, span + 1) % N
        basis = basis_funs(T, p, eta, span, normalize=normalize)
        fun_h[k] = evaluation_kernel_1d(p, basis, ind1, lambdas)
        
    print(f'projected basis function {j = }, max error: {np.max(np.abs(fun(etas) - fun_h))}')
    i_inds = np.where(np.abs(lambdas) > 1e-15)[0] % N
    print(f'Nonzero i indices: {i_inds = }')
    # A_ij can only be nonzero if the supports of N_i and N_j overlap, i.e. |i - j| <= p
    print(f'{np.all(np.abs(i_inds - j) <= p) = }')
    
    plt.subplot(N//4 + 1, 4, j + 1)    
    plt.plot(etas, fun_h, 'o', label='fun_h')
    plt.plot(etas, fun(etas), label='fun')
    plt.title(f'sin(2 pi eta) * N_j, {j = }')
    plt.legend()