# Low-rank functional approximation by hand
This is an illustrative example of alternating least squares (ALS) applied to compute successive rank-1 approximations.

In [None]:
# Load IPython's autoreload extension and switch it to mode 2, so edits to the
# imported project modules are picked up without restarting the kernel
# (per the IPython docs, mode 2 reloads all modules before running code).
%load_ext autoreload
%autoreload 2
# Optional interactive backend, left disabled.
# %matplotlib widget

In [None]:
import numpy as np
from matplotlib import cm
from tensorflow import keras
import matplotlib.pyplot as plt
from sklearn.linear_model import Ridge
 
# our stuff
from nnu import gss_kernels as ssk
from nnu import points_generator as pgen
from nnu import gss_report_generator as ssrg


## Create the function to fit and sample it randomly

In [None]:
# Dimension of the input space; the cells below index xs[:, 0] and xs[:, 1],
# so they rely on ndim == 2.
ndim = 2 
# sim_range * stretch is the half-width of the node grid built further down
# (nodes span [-sim_range*stretch, sim_range*stretch]).
sim_range = 4

# Multiplier applied to sim_range before it is handed to the point generator.
stretch = 1.1 
# Sample count passed to the generator.
nx = 10000
# Seed passed to the generator; the same value is given to the report
# generator in a later cell.
input_seed = 1917
# Only the first element of the generator's result is kept as the sample points.
# NOTE(review): presumably xs has shape (nx, ndim) -- confirm against pgen.
xs = pgen.generate_points(
    sim_range*stretch, nx, ndim, 'random', seed=input_seed)[0]

In [None]:
# Specification string of the target function, forwarded as input_f_spec.
input_f_spec = 'laplace_1'
# The report generator is called only to obtain the target function: of
# everything it returns, just the last element is used below.
genres =  ssrg.generate_inputs_and_nodes(
    ndim = ndim,
    nsamples = nx,
    nnodes = nx,
    input_f_spec = input_f_spec,
    input_seed=input_seed,
    nsr_stretch=1.0,
)
# Last element of the result is used as a callable.
func = genres[-1]
# Target values at the sample points xs.
# NOTE(review): the fitting code below needs ys to be 1-D -- confirm.
ys = func(xs)

## Calculate nodes and basis function matrices

In [None]:
# Number of 1-D nodes per coordinate: twice the integer part of nx**(1/ndim).
nnodes = 2*int(pow(nx, 1.0/ndim))
# Key used to pick the kernel out of the dictionary returned by ssk.
kernel = 'invquad'
# The kernel scale is scale_mult times global_scale.
scale_mult = 4.0
# Width of the node interval divided by the node count (approximately the
# node spacing; linspace with endpoint=True actually uses nnodes-1 intervals).
global_scale = 2*sim_range*stretch / nnodes
knl_f = ssk.global_kernel_dict(global_scale * scale_mult)[kernel]
# Equally spaced 1-D nodes, shared by both coordinates.
nodes = np.linspace(-sim_range*stretch, sim_range*stretch, nnodes, endpoint=True)


# Basis-function matrices: the kernel applied to (sample coordinate - node).
# The (n, 1) column slice broadcasts against the (1, nnodes) node row, so each
# matrix has one row per sample and one column per node.
# NOTE(review): the .numpy() call implies knl_f returns a tensor-like object
# (presumably TensorFlow) -- confirm against ssk.
knl_v1s = knl_f(xs[:,0:1] - np.expand_dims(nodes,0)).numpy()
knl_v2s = knl_f(xs[:,1:2] - np.expand_dims(nodes,0)).numpy()


## Function computing a single rank-1 functional approximation (ALS iterated to convergence)

In [None]:
def rank_1_knl_approx_iter(mleft, mright, ys, wl0, wr0, reltol = 1e-4,
                           reg_coef = 1e-4, max_iter = 1000, verbose = False):
    """Fit the rank-1 model ``(mleft @ wl) * (mright @ wr)`` to ``ys`` by ALS.

    Each sweep performs two ridge regressions without intercept: the right
    weights are fitted with the left factor frozen, then the left weights
    are fitted with the freshly updated right factor frozen.

    Parameters
    ----------
    mleft, mright : 2-D arrays with one row per sample
        Basis-function matrices of the left and right factor.  They may
        have different numbers of columns.
    ys : array
        Target values, one per sample (expected to be 1-D so that the fitted
        coefficients come back as 1-D vectors).
    wl0, wr0 : numpy arrays
        Starting weights.  ``wl0`` seeds the first regression; ``wr0`` only
        enters the convergence measure of the first sweep.
    reltol : float
        Tolerance of the stopping rule (see the note in the loop).
    reg_coef : float
        Ridge penalty ``alpha`` used in both regressions.
    max_iter : int
        Upper bound on the number of sweeps.  Without it, degenerate input
        (e.g. all-zero starting weights, which make the convergence measure
        NaN) would loop forever.  A ``RuntimeWarning`` is issued when the
        bound is hit, and the last iterate is returned.
    verbose : bool
        Print the convergence measure after every sweep.

    Returns
    -------
    wl, wr : fitted left and right weights
    ys_fit : the fitted values ``fit_f(wl, wr)``
    fit_f : callable ``(w1, w2) -> (mleft @ w1) * (mright @ w2)``

    Raises
    ------
    ValueError
        If ``max_iter`` is smaller than 1.
    """
    import warnings

    if max_iter < 1:
        raise ValueError('max_iter must be at least 1')

    def fit_f(w1, w2):
        return (mleft @ w1)*(mright @ w2)

    # One estimator is enough: fit() resets its state, and the coefficients
    # are copied out right after each fit.
    ridge = Ridge(alpha=reg_coef, fit_intercept=False)

    wl_prev = wl0
    wr_prev = wr0
    old_rel_err = 0.0
    for _ in range(max_iter):

        # Equivalent to np.diag(mleft @ wl_prev) @ mright: the (n, 1) column
        # broadcasts across the columns of mright, no explicit repeat needed.
        Ar = (mleft @ wl_prev.reshape(-1,1)) * mright
        wr = ridge.fit(Ar, ys).coef_.copy()

        # Same construction for the left factor, using the updated wr.
        Al = (mright @ wr.reshape(-1,1)) * mleft
        wl = ridge.fit(Al, ys).coef_.copy()

        # Sum of the relative changes of both weight vectors in this sweep.
        rel_err = (np.linalg.norm(wl - wl_prev)/np.linalg.norm(wl_prev)
                   + np.linalg.norm(wr - wr_prev)/np.linalg.norm(wr_prev))
        if verbose:
            print(rel_err)
        wl_prev = wl
        wr_prev = wr
        # NOTE(review): the loop stops when the *change* of rel_err between
        # consecutive sweeps drops below reltol, not when rel_err itself
        # does.  Kept as in the original; confirm this is the intended rule.
        if abs(rel_err - old_rel_err) < reltol:
            break
        old_rel_err = rel_err
    else:
        warnings.warn(
            f'rank-1 ALS stopped after max_iter={max_iter} sweeps without '
            f'meeting reltol={reltol}',
            RuntimeWarning,
        )
    return wl, wr, fit_f(wl, wr), fit_f

## Rank-n approximation
Call rank-1 approximation successively on residuals of the previous approximation

In [None]:
def rank_n_knl_approx(mleft, mright, ys, rank, reltol = 1e-4):
    """Build a rank-``rank`` separable fit of ``ys`` from successive rank-1 terms.

    Each step fits one rank-1 term to the residual left by the previous
    steps (via ``rank_1_knl_approx_iter``) and subtracts it.

    Parameters
    ----------
    mleft, mright : 2-D arrays with one row per sample
        Basis-function matrices of the left and right factor.
    ys : array
        Target values, one per sample.
    rank : int
        Number of rank-1 terms to fit.
    reltol : float
        Tolerance forwarded to the rank-1 solver.

    Returns
    -------
    wls, wrs : lists of the left/right weight vectors, one entry per term
    ys_fit : sum of all fitted rank-1 terms at the sample points
    fls, frs : lists of 1-D callables evaluating the left/right factor of
        each term at arbitrary points.  They use the notebook-level ``knl_f``
        and ``nodes``, so they are only meaningful when ``mleft``/``mright``
        were built from that same kernel and node set.
    """

    def make_factor(weights):
        # Bind the weights through a closure argument so every term keeps
        # its own vector (avoids the late-binding pitfall of loop lambdas).
        def factor(x):
            return knl_f(np.expand_dims(x,-1) - np.expand_dims(nodes,0)).numpy() @ weights
        return factor

    nsamples = len(ys)

    # Constant starting weights, scaled so the initial product has roughly
    # the mean magnitude of ys.  Computed from the arguments (not from the
    # notebook globals) and hoisted because it does not change between terms.
    avw = np.sqrt(np.abs(ys).mean()/np.mean(np.abs(mleft.sum(axis=1)*mright.sum(axis=1))))
    wl0 = avw*np.ones(mleft.shape[1])
    wr0 = avw*np.ones(mright.shape[1])

    ys_resid = ys
    wls = []
    wrs = []
    fls = []
    frs = []
    terms = []

    for r in range(rank):
        wl_fit, wr_fit, term, _ = rank_1_knl_approx_iter(
            mleft = mleft, mright = mright, ys = ys_resid, wl0 = wl0, wr0 = wr0, reltol=reltol )

        # Remove the new term first; the reported error is then the norm of
        # the remaining residual (the term must not be subtracted twice).
        ys_resid = ys_resid - term
        # Root-mean-square of the residual, printed under the label 'mse'.
        mse = np.linalg.norm(ys_resid)/np.sqrt(nsamples)
        print(f'rank-{r} mse = {mse:.4f} on step {r}')

        wls.append(wl_fit)
        wrs.append(wr_fit)
        fls.append(make_factor(wl_fit))
        frs.append(make_factor(wr_fit))
        terms.append(term)

    # Total fit is the sum of the individual rank-1 terms.
    ys_fit = np.array(terms).sum(axis=0)
    return wls, wrs, ys_fit,fls,frs


## Run the fit and see what it looks like

In [None]:
# Number of rank-1 terms to fit.
rank = 3
# Fit on the two per-coordinate basis matrices.  The results are the per-term
# weights (wls, wrs), the summed fit at the sample points (ys_fit) and the
# per-term 1-D factor functions (fls, frs).
wls,wrs,ys_fit,fls,frs = rank_n_knl_approx(
            mleft = knl_v1s, mright = knl_v2s, ys = ys, rank=rank, reltol = 1e-4)

In [None]:
%matplotlib inline
print(f'Overall mse = {1 - np.linalg.norm(ys - ys_fit)/np.linalg.norm(ys):.4f}')
plt.plot(ys_fit,ys,'.', label = 'actual vs fit')
plt.legend(loc = 'best')
plt.show()

for r in range(rank):
    plt.plot(xs[:,0], fls[r](xs[:,0]),'.', markersize = 1, label = f'left  f for r={r}')
    plt.plot(xs[:,1], frs[r](xs[:,1]),'.', markersize = 1, label = f'right f for r={r}')
plt.legend(loc = 'best')
plt.show()


In [None]:
%matplotlib auto
# %matplotlib inline

f1 = plt.figure()
ax = f1.add_subplot(projection='3d')

ax.scatter(xs[:,0],xs[:,1],ys,  # c=ys,
           cmap=cm.coolwarm, marker='.', s=1, alpha = 0.75, label = 'actual')
ax.scatter(xs[:,0],xs[:,1],ys_fit,  # c=ys,
           cmap=cm.coolwarm, marker='.', s=1, alpha = 0.75, label = 'fit')
plt.title('Actual vs fit')
plt.legend(loc = 'best')
plt.show()