# Performance of Kronecker-product linear algebra

This notebook tests the performance of linear algebra with 3D Kronecker-product matrices, namely the matrix-vector product (matvec) and the LU solve. Summation over repeated indices is implied:

* Matvec: $y_{ijk} = A_{im} B_{jn} C_{ko} x_{mno}$
* Solve for $x$: $A_{im} B_{jn} C_{ko} x_{mno} = r_{ijk}$

In [None]:
import numpy             as np
import time
import scipy.sparse as sparse
from scipy.sparse.linalg import splu

import hylife.utilitis_FEEC.projectors_global as proj_glob
import hylife.utilitis_FEEC.evaluationV2 as evaV2

import hylife.utilitis_FEEC.bsplines as bsp

In [None]:
def kron_matvec_3d(kmat,vec3d):
    ''' Matrix-vector product of a 3d Kronecker-product matrix with a 3d array.

        res_ijk = (A_im * B_jn * C_ko) * vec3d_mno    (sum over m, n, o)

        with A = kmat[0], B = kmat[1], C = kmat[2]. This equals
        kron(kron(A, B), C) applied to vec3d.flatten(), reshaped to (k0,k1,k2),
        but the Kronecker matrix is never assembled. The product is computed as
        three matrix-matrix multiplications, one per direction, each followed
        by a transpose and a reshape:

        step1(v1*v2, k0) <= ( A(k0,v0) * reshaped_vec3d(v0, v1*v2) )^T
        step2(v2*k0, k1) <= ( B(k1,v1) * reshaped_step1(v1, v2*k0) )^T
        step3(k0*k1, k2) <= ( C(k2,v2) * reshaped_step2(v2, k0*k1) )^T
        res              <= reshaped_step3(k0, k1, k2)

        The transposes are views. Reshaping a transposed (non-contiguous)
        array, however, generally makes numpy copy the data, so each step
        allocates one temporary array.

        Parameters
        ----------
        kmat  : sequence of 3 matrices (sparse or dense, anything with
                .shape and .dot) of sizes (k0,v0), (k1,v1), (k2,v2)

        vec3d : 3d array of size (v0,v1,v2)

        Returns
        -------
        res : 3d array of size (k0,k1,k2)

        Raises
        ------
        ValueError : if the column counts of kmat do not match vec3d.shape
    '''
    v0, v1, v2 = vec3d.shape
    k0, k1, k2 = (kmat[d].shape[0] for d in range(3))

    # fail early with a readable message instead of a bare "dimension mismatch"
    ncols = tuple(kmat[d].shape[1] for d in range(3))
    if ncols != (v0, v1, v2):
        raise ValueError('kron_matvec_3d: matrix column counts %s do not match vec3d.shape %s'
                         % (ncols, tuple(vec3d.shape)))

    # direction 0: contract over v0, the new k0 axis moves to the end
    step1 = (kmat[0].dot(vec3d.reshape(v0, v1*v2))).T   # (v1*v2, k0)
    # direction 1: contract over v1, the new k1 axis moves to the end
    step2 = (kmat[1].dot(step1.reshape(v1, v2*k0))).T   # (v2*k0, k1)
    # direction 2: contract over v2, the new k2 axis moves to the end
    step3 = (kmat[2].dot(step2.reshape(v2, k0*k1))).T   # (k0*k1, k2)
    return step3.reshape(k0, k1, k2)

def kron_lusolve_3d(kmatlu,rhs):
    ''' Solve a 3d Kronecker-product system using precomputed 1d LU factorizations.

        solve for x: (A_im * B_jn * C_ko) * x_mno = rhs_ijk    (sum over m, n, o)

        kmatlu holds the factorizations of A, B, C. The Kronecker matrix is
        never formed. Each direction is solved in turn on a 2d view of the
        data, and the result is transposed so that the next direction leads:

        step1(r1*r2, r0) <= ( A(r0,r0)^-1 * reshaped_rhs(r0, r1*r2)   )^T
        step2(r2*r0, r1) <= ( B(r1,r1)^-1 * reshaped_step1(r1, r2*r0) )^T
        step3(r0*r1, r2) <= ( C(r2,r2)^-1 * reshaped_step2(r2, r0*r1) )^T
        res              <= reshaped_step3(r0, r1, r2)

        Reshaping the transposed intermediates generally copies the data.

        Parameters
        ----------
        kmatlu : sequence of 3 factorization objects with a .solve method
                 (e.g. results of scipy.sparse.linalg.splu) for matrices of
                 sizes (r0,r0), (r1,r1), (r2,r2)

        rhs    : 3d array of size (r0,r1,r2), right-hand side

        Returns
        -------
        res : 3d array of size (r0,r1,r2), solution
    '''
    n0, n1, n2 = rhs.shape
    # 2d layout (solved direction, everything else) used before each solve
    layouts = ((n0, n1 * n2), (n1, n2 * n0), (n2, n0 * n1))
    work = rhs
    for direction, layout in enumerate(layouts):
        work = kmatlu[direction].solve(work.reshape(layout)).T
    return work.reshape(n0, n1, n2)

def kron_solve_3d(kmat,rhs):
    ''' Solve a 3d Kronecker-product system, factorizing each direction on the fly.

        solve for x: (A_im * B_jn * C_ko) * x_mno = rhs_ijk    (sum over m, n, o)

        with A = kmat[0], B = kmat[1], C = kmat[2]. The three 1d matrices are
        LU-factorized with splu on every call (kron_lusolve_3d reuses existing
        factorizations instead). The factors are then applied direction by
        direction:

        step1(r1*r2, r0) <= ( A(r0,r0)^-1 * reshaped_rhs(r0, r1*r2)   )^T
        step2(r2*r0, r1) <= ( B(r1,r1)^-1 * reshaped_step1(r1, r2*r0) )^T
        step3(r0*r1, r2) <= ( C(r2,r2)^-1 * reshaped_step2(r2, r0*r1) )^T
        res              <= reshaped_step3(r0, r1, r2)

        Reshaping the transposed intermediates generally copies the data.

        Parameters
        ----------
        kmat : sequence of 3 square matrices (sparse in any format, or dense)
               of sizes (r0,r0), (r1,r1), (r2,r2)

        rhs  : 3d array of size (r0,r1,r2), right-hand side

        Returns
        -------
        res : 3d array of size (r0,r1,r2), solution
    '''
    r0, r1, r2 = rhs.shape
    # splu works on CSC. Converting explicitly avoids scipy's
    # SparseEfficiencyWarning for CSR/COO input and also accepts dense arrays.
    lu = [splu(sparse.csc_matrix(kmat[d])) for d in range(3)]

    step1 = (lu[0].solve(rhs.reshape(r0, r1*r2))).T     # (r1*r2, r0)
    step2 = (lu[1].solve(step1.reshape(r1, r2*r0))).T   # (r2*r0, r1)
    step3 = (lu[2].solve(step2.reshape(r2, r0*r1))).T   # (r0*r1, r2)
    return step3.reshape(r0, r1, r2)

In [None]:
# ------------------------
# which projector to test (0, 11, 12, 13, 21, 22, 23, 3):
# ------------------------
comp = 21

# ------------------------
# compare to full 3d implementation (switch off because of memory issues!):
# ------------------------
do_full = True

#-----------------
# Create the grid:
#-----------------
# side lengths of logical cube
L = [3., 2., 1.]

# spline degrees
p = [3, 3, 3]

# periodic boundary conditions (use 'False' if clamped)
bc = [True, True, True]

# loop over different number of elements (convergence test)
Nel_cases = [2**n for n in range(2, 5)]

# loop over different number of quadrature points per element
Nq_cases = [6]

SEP = "---------------------------------------"


def _avg_time(fun, nrep):
    ''' Call fun() nrep times; return (mean seconds per call, result of the last call).

        Uses time.perf_counter instead of time.time. It is monotonic and has the
        highest available resolution, whereas time.time can be too coarse for
        short runs: it then measures 0 s and the speed-up ratios divide by zero.
    '''
    res = None
    t0 = time.perf_counter()
    for _ in range(nrep):
        res = fun()
    return (time.perf_counter() - t0) / nrep, res


def _speedup(t_ref, t_new):
    ''' Format the speed-up t_ref/t_new, guarding against a zero denominator. '''
    if t_new > 0:
        return '%6.2f x faster' % (t_ref / t_new)
    return '   inf x faster'


for Nel_k in Nel_cases:

    # number of elements
    Nel = [3*Nel_k, 2*Nel_k, Nel_k]
    print('Nel=', Nel)

    # element boundaries
    el_b = [np.linspace(0., L_i, Nel_i + 1) for L_i, Nel_i in zip(L, Nel)]

    # knot sequences
    T = [bsp.make_knots(el_b_i, p_i, bc_i) for el_b_i, p_i, bc_i in zip(el_b, p, bc)]

    for Nq_k in Nq_cases:

        # number of quadrature points
        Nq = [Nq_k, Nq_k, Nq_k]

        # create an instance of the projector class
        obj = proj_glob.projectors_3d(T, p, bc, Nq)

        # grid points for error computation
        pts_loc, wts_loc = np.polynomial.legendre.leggauss(5)   # quadrature points per element

        pts1, wts1 = bsp.quadrature_grid(el_b[0], pts_loc, wts_loc)
        pts2, wts2 = bsp.quadrature_grid(el_b[1], pts_loc, wts_loc)
        pts3, wts3 = bsp.quadrature_grid(el_b[2], pts_loc, wts_loc)

        xgrid = [pts1.flatten(), pts2.flatten(), pts3.flatten()]   # error grid

        # 1d basis of each direction evaluated at the grid points
        basemat = evaV2.FEM_evalbase_3d(comp, xgrid, T, p, bc)

        # random inputs, so that the 3D and kron results can be compared meaningfully
        # (assumes Nel basis functions per direction, as the original setup did -- periodic case)
        x = np.random.rand(*Nel)
        rhs = np.random.rand(*Nel)
        y0, y1, y2 = [len(xgrid[0]), len(xgrid[1]), len(xgrid[2])]

        nrep = max(int(1000*(Nel_cases[0]/Nel_k)**2), 1)
        print('number of repetitions    =     %d' % (nrep))
        print(SEP)

        # ---------------- matvec ----------------
        if do_full:
            t0 = time.perf_counter()
            # 3D assemble of the basemat (CSR: the efficient format for matvec)
            bm = sparse.kron(sparse.kron(basemat[0], basemat[1]), basemat[2], format='csr')
            print('time build 3D sparse     =     %10.7e' % (time.perf_counter() - t0))
            print(SEP)

            dt1, y_full = _avg_time(lambda: bm.dot(x.flatten()).reshape(y0, y1, y2), nrep)
            print('time 3D sparse matvec    =     %10.7e' % (dt1))

        dt2, y = _avg_time(lambda: kron_matvec_3d(basemat, x), nrep)
        if do_full:
            print('time kron sparse matvec  =     %10.7e, %s' % (dt2, _speedup(dt1, dt2)))
            print('max |y_3D - y_kron|      =     %10.7e' % np.amax(np.abs(y_full - y)))
        else:
            print('time kron sparse matvec  =     %10.7e' % (dt2))
        print(SEP)

        # ---------------- LU factorization ----------------
        if do_full:
            t0 = time.perf_counter()
            full_lu = splu(sparse.kron(sparse.kron(obj.N[0], obj.N[1]), obj.N[2], format='csc'))
            dt1 = time.perf_counter() - t0
            print('time build 3D LU(NNN)    =     %10.7e' % (dt1))

        t0 = time.perf_counter()
        matlu = [splu(obj.N[0]), splu(obj.N[1]), splu(obj.N[2])]
        dt2 = time.perf_counter() - t0
        if do_full:
            print('time build 3x 1D LU(N)   =     %10.7e, %s' % (dt2, _speedup(dt1, dt2)))
        else:
            print('time build 3x 1D LU(N)   =     %10.7e' % (dt2))
        print(SEP)

        # ---------------- LU solve ----------------
        if do_full:
            dt1, x_full = _avg_time(lambda: full_lu.solve(rhs.flatten()).reshape(Nel), nrep)
            print('time solve 3D LU         =     %10.7e' % (dt1))

        dt2, x = _avg_time(lambda: kron_lusolve_3d(matlu, rhs), nrep)
        if do_full:
            print('time solve kron LU       =     %10.7e, %s' % (dt2, _speedup(dt1, dt2)))
            print('max |x_3D - x_kron|      =     %10.7e' % np.amax(np.abs(x_full - x)))
        else:
            print('time solve kron LU       =     %10.7e' % (dt2))
        print(SEP)

        # NOTE: kron_solve_3d re-factorizes the three 1D matrices on every call,
        # but is compared against the 3D *solve only* (3D factorization excluded).
        dt2, x = _avg_time(lambda: kron_solve_3d(obj.N, rhs), nrep)
        if do_full:
            print('time solve kron          =     %10.7e, %s' % (dt2, _speedup(dt1, dt2)))
        else:
            print('time solve kron          =     %10.7e' % (dt2))
        print(SEP)
        print('done.')

    print('')

In [None]:
# check: a 2D Kronecker LU solve equals two successive 1D LU solves
r0, r1 = (obj.N[d].shape[0] for d in range(2))

rhs2d = np.random.rand(r0, r1)
matlu2d = splu(sparse.kron(obj.N[0], obj.N[1], format='csc'))

matlu = [splu(obj.N[d]) for d in range(3)]

# reference: one solve with the assembled 2D matrix
x2d = matlu2d.solve(rhs2d.flatten()).reshape(r0, r1)

# direction by direction, written as two statements ...
y2d = matlu[0].solve(rhs2d).T
y2d = matlu[1].solve(y2d).T
# ... and as a single expression (overwrites the above)
y2d = matlu[1].solve(matlu[0].solve(rhs2d).T).T

np.amax(np.abs(x2d - y2d))

In [None]:
# check: a 3D Kronecker LU solve equals three successive 1D LU solves
r0, r1, r2 = (obj.N[d].shape[0] for d in range(3))
E = [np.eye(n) for n in (r0, r1, r2)]
rhs3d = np.random.rand(r0, r1, r2)
matlu3d = splu(sparse.kron(sparse.kron(obj.N[0], obj.N[1]), obj.N[2], format='csc'))

matlu = [splu(obj.N[d]) for d in range(3)]

# reference: one solve with the assembled 3D matrix
x3d = matlu3d.solve(rhs3d.flatten()).reshape(r0, r1, r2)

# direction by direction, written out step by step ...
ytmp1 = matlu[0].solve(rhs3d.reshape(r0, r1*r2)).T.reshape(r1, r2*r0)
ytmp2 = matlu[1].solve(ytmp1).T.reshape(r2, r0*r1)
y3d = matlu[2].solve(ytmp2).T.reshape(r0, r1, r2)
# ... and via the helper (overwrites y3d)
y3d = kron_lusolve_3d(matlu, rhs3d)

np.amax(np.abs(x3d - y3d))

In [None]:
# time the direction-by-direction 3D LU solve (IPython %timeit magic; runs only in a notebook)
%timeit y3d  = kron_lusolve_3d(matlu,rhs3d)