# Create and test the generalized Stochastic Sampling (gSS) model
## This notebook uses ``gss_model_factory`` and other higher-level interfaces 

We create a ``onedim`` version of the model with a one-dim optimizer for the frequency bounds (aka scales) and linear regression for the outer (linear) weights.

In [None]:
# Load the autoreload extension and use mode 2 so that edits to imported
# modules are picked up without restarting the kernel
%load_ext autoreload
%autoreload 2
# Interactive (ipympl) backend, left disabled
#%matplotlib widget

In [None]:
import scipy
import numpy as np
from scipy import stats
import tensorflow as tf
from matplotlib import cm
from tensorflow import keras
from functools import reduce
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from timeit import default_timer as timer

# our stuff
from nnu import points_generator as pgen
from nnu import laplace_kernel, fit_function_factory
from nnu import gss_layer, gss_model_factory, gss_report_config

# globals
np.set_printoptions(precision =3, suppress=True)


## Create the function we want to fit

In [None]:
# Problem dimension; the 3-d templates below are cut down to ndim entries
ndim = 2

# Kernel type and shift handed to the function factory
laplace_mixture = fit_function_factory.KernelType.LpM
laplace_shift = 1

# Standard deviations (last ndim entries of the template) and the
# off-diagonal correlation are passed to simple_covar_matrix
off_diag_correl = 0.0
stds = [1.5, 1.0, 0.5][-ndim:]
covar_matr = laplace_kernel.simple_covar_matrix(stds, off_diag_correl)

# One row per mixture component; only the first ndim coordinates are kept
means = np.array([
    [1, 1, 0],
    [-1, -1, 0],
    [0.5, -0.5, 0],
])[:, :ndim]

# Presumably one multiplier / weight per mixture component
# NOTE(review): mix_weights add up to 1.1, not 1 -- confirm this is intended
cov_multipliers = [0.5, 0.3, 0.1]
mix_weights = [0.4, 0.35, 0.35]

function_to_fit = fit_function_factory.generate_nd(
    laplace_mixture,
    covar_matr,
    means=means,
    shift=laplace_shift,
    mix_weights=mix_weights,
    cov_multipliers=cov_multipliers,
)


## Generate the learning set (inputX, inputY)


In [None]:

# Settings for the learning set
points_type = "random"
sim_range = 4
nsamples = 10000
input_seed = 1917

# generate_points returns a sequence; its first element is used as the inputs
inputX = pgen.generate_points(
    sim_range, nsamples, ndim, points_type, seed=input_seed)[0]
# Target values: the function we want to fit, evaluated at the inputs
inputY = function_to_fit(inputX)


## Generate nodes for our model


In [None]:

# How the nodes are laid out ("regular" is kept below as the alternative)
nodes_seed = 2022
nodes_type = "random"
# nodes_type="regular"

# Nodes are generated over a range stretched relative to the sampling range
nsr_stretch = 1.2
nodes_sim_range = nsr_stretch * sim_range
nnodes = 200

nodes = pgen.generate_points(
    nodes_sim_range, nnodes, ndim, nodes_type,
    seed=nodes_seed, plot=0,
)[0]
# From here on use the number of nodes actually returned
nnodes = len(nodes)

nnodes_per_dim = round(nnodes ** (1.0 / ndim))
global_scale = pgen.average_distance(nodes)

## Set up the gSS ``onedim`` model

In [None]:
# Precision used in the notebook; the numpy dtype is kept in step with the TF one
tf_dtype = tf.float32  # tf.float64
if tf_dtype == tf.float32:
    np_dtype = np.float32
else:
    np_dtype = np.float64

# Settings shared by every model built in this notebook
model_specs = {
    'model_type': gss_report_config.ModelType.SSL,
    'use_outer_regression': True,
    'optimize_knots': False,
    'optimize_scales': True,
    'scales_dim': gss_layer.ScalesDim.OnlyOne,
    'apply_final_aggr': False,
    'kernel': 'invquad',
}

seed_for_keras = 2021
model = gss_model_factory.generate_model(
    ndim=ndim, global_scale=global_scale, nodes=nodes,
    inputX=inputX, inputY=inputY, scales=None,
    seed_for_keras=seed_for_keras,
    **model_specs,
)

# The way the model is built means fit(...) is fed placeholder xs.
# The real inputX, inputY enter the model through its 'xpts' and 'ypts' layers.
output_size = nsamples
fake_dim = 1
x_to_use = np.zeros((output_size, fake_dim), dtype=np_dtype)
y_to_use = np.zeros(output_size, dtype=np_dtype)

model.summary()
model.get_layer('predict_y_model').summary()


## Calculate average gradients
We create a 'test' model to calculate gradients. The test model
will have its outer weights (linear coefficients) calculated by regression,
so this is the "base" approximation from the paper.

In [None]:

# Build a 'test' model from the current model so that its output can be
# differentiated with respect to the inputs.
grad_model = gss_model_factory.generate_model_for_testing_2(
    model,
    ndim=ndim,
    global_scale=global_scale,
    nodes=nodes,

    **model_specs,

    average_slopes=None,
    l2_regularizer=1e-8,
    tf_dtype=tf_dtype,  # follow the notebook-wide dtype (same one x_tensor uses below)
    nsr_stretch=nsr_stretch,
    generate_more_nodes=False, 
    sim_range=sim_range)

# calc the gradients using tf
x_tensor = tf.convert_to_tensor(inputX, dtype=tf_dtype)
with tf.GradientTape() as t:
    # x_tensor is a plain tensor, not a tf.Variable, so it has to be watched explicitly
    t.watch(x_tensor)
    output = grad_model(x_tensor)
# Differentiate after leaving the tape context so the gradient ops are not recorded on the tape
gradients = t.gradient(output, x_tensor).numpy()

# the average directional risk magnitudes: L2 norm of the gradients taken
# over the samples (axis 0), then rescaled so that the largest entry is 1
average_slopes = np.linalg.norm(gradients, axis=0)
average_slopes /= np.amax(average_slopes)
print("Normalized directional risk magnitudes: ", average_slopes)

## Re-init the model with calculated average slopes

In [None]:

print_params = True

# Build the model again from scratch, this time handing it the
# average_slopes (and sim_range) that were not available the first time
model = gss_model_factory.generate_model(
    ndim=ndim, nodes=nodes, global_scale=global_scale,
    inputX=inputX, inputY=inputY,
    scales=None, seed_for_keras=seed_for_keras,
    average_slopes=average_slopes, sim_range=sim_range,
    **model_specs,
)

## Fit the gSS model
We run a one-dim optimizer for the global scale of the kernels with linear regression for outer (linear) coefficients on each step

In [None]:
# Get access to the inner weight corresponding to the common scale (known to be a scalar in this particular case)
# that drives all scales, so we can set it directly and optimize it with a 1d optimizer
predict_model = model.get_layer('predict_y_model')
prod_kernel = predict_model.get_layer('prodkernel')
# Snapshot of the layer's weights; obj_f_1d passes entry [1] back unchanged on every evaluation
orig_weights = prod_kernel.get_weights()

def obj_f_1d(w):
    '''
    1D objective function for scipy.optimize.minimize_scalar

    Sets the common kernel scale to ``w`` and returns the relative L2 error
    ``||fit - inputY|| / ||inputY||`` of the resulting fit.

    Side effect: overwrites the weights of ``prod_kernel``, so the most
    recent call determines the scale the model is left with.
    '''
    # set the right weight to w, reuse the others (nodes -- not trainable here)
    prod_kernel.set_weights([np.array([[w]]), orig_weights[1]])

    # verbose=0: this is evaluated once per optimizer step, keep the cell output quiet
    fit = predict_model.predict(x_to_use, batch_size=nsamples, verbose=0)
    rel_l2_err = np.linalg.norm(fit[:, 0] - inputY)/np.linalg.norm(inputY)
    return rel_l2_err

# Bounded scalar search for the common scale over [0.25, 1.25]
res = scipy.optimize.minimize_scalar(
    obj_f_1d,
    method='bounded',
    bounds=(0.25, 1.25),
    options={'xatol': 1e-1},
)

# Evaluating the objective at res.x leaves the model set to the solution scale;
# the value is printed as well
achieved_mse = obj_f_1d(res.x)
print(
    f'1d optimizer found solution scale = {res.x:.4f} and achieved mse = {achieved_mse:.4f}')


## Get the results of the fit

In [None]:
predict_model = model.get_layer('predict_y_model')

fit = predict_model.predict(x_to_use, batch_size=nsamples)
learn_mse = np.linalg.norm(fit[:, 0] - inputY)/np.linalg.norm(inputY)
learn_mae = np.linalg.norm(
    fit[:, 0] - inputY, ord=1)/np.linalg.norm(inputY)/np.sqrt(nsamples)  # note we divide by L2 norm on purpose

print(f'learn_mse = {learn_mse:.4f}, learn_mae = {learn_mae:.4f}')

# %matplotlib auto
%matplotlib inline
plt.plot(fit[:,0], inputY, '.')
plt.title('learn: actual vs fit')
plt.show()

## Plot the fit in 3D

In [None]:
%matplotlib auto
# %matplotlib inline
plot_step = 1
ax = plt.axes(projection='3d')

ax.scatter(inputX[::plot_step,0], inputX[::plot_step,1],  inputY[::plot_step],
            cmap=cm.coolwarm, marker='.', alpha = 0.75, s=1, label = 'actual')
ax.scatter(inputX[::plot_step,0], inputX[::plot_step,1],  fit[::plot_step,0],
            cmap=cm.coolwarm, marker='.', alpha = 0.75, s = 1, label = 'fit')

plt.xlabel('x1')
plt.ylabel('x2')
plt.legend(loc = 'best')
plt.show()

## Generate the final model with an increased number of nodes using calibrated scales

In [None]:
# Build the final model from the calibrated one, this time asking the factory
# to generate more nodes and passing the computed average_slopes
test_model = gss_model_factory.generate_model_for_testing_2(
    model,
    ndim=ndim,
    global_scale=global_scale,
    nodes=nodes,

    **model_specs,

    average_slopes=average_slopes,
    l2_regularizer=1e-8,
    tf_dtype=tf_dtype,  # follow the notebook-wide dtype instead of hard-coding float32
    nsr_stretch=nsr_stretch,
    generate_more_nodes=True, 
    sim_range=sim_range)


## See the results of the fit for an independently generated test set (testX,testY)

In [None]:
test_res_seed=314

testX = pgen.generate_points(sim_range, nsamples, ndim, 
    points_type, seed = test_res_seed)[0]
testY = function_to_fit(testX)

test_fit = test_model.predict(testX)

test_mse = np.linalg.norm(test_fit[:, 0] - testY)/np.linalg.norm(testY)
test_mae = np.linalg.norm(
    test_fit[:, 0] - testY, ord=1)/np.linalg.norm(testY)/np.sqrt(nsamples)  # note we divide by L2 norm on purpose


print(f'test_mse = {test_mse:.4f}, test_mae = {test_mae:.4f}')

# %matplotlib auto
%matplotlib inline
plt.plot(test_fit[:,0], testY, '.')
plt.title('Test: actual vs fit')
plt.show()

In [None]:
%matplotlib auto
# %matplotlib inline
plot_step = 1
ax = plt.axes(projection='3d')

ax.scatter(testX[::plot_step,0], testX[::plot_step,1],  testY[::plot_step],
            cmap=cm.coolwarm, marker='.', alpha = 0.75, s=1, label = 'actual')
ax.scatter(testX[::plot_step,0], testX[::plot_step,1],  test_fit[::plot_step,0],
            cmap=cm.coolwarm, marker='.', alpha = 0.75, s = 1, label = 'fit')

plt.xlabel('x1')
plt.ylabel('x2')
plt.legend(loc = 'best')
plt.show()

## Plot outer weights as a function of nodes

In [None]:
# First column of the first weight array of the 'final' layer (the outer weights plotted below)
ws = test_model.get_layer('final').get_weights()[0][:, 0]
# NOTE(review): assumes weight index 2 of 'prodkernel' holds the node coordinates in the test model -- confirm
test_nodes=test_model.get_layer('prodkernel').get_weights()[2]

In [None]:
%matplotlib auto
# %matplotlib inline

plot_step = 1
ax = plt.axes(projection='3d')

ax.scatter(test_nodes[::plot_step,0], test_nodes[::plot_step,1], ws[::plot_step],
            c = ws[::plot_step],
            cmap=cm.coolwarm, marker='.',  alpha = 0.75, label = 'outer weights')

plt.xlabel('nodes_x')
plt.ylabel('nodex_y')
plt.legend(loc = 'best')
plt.show()