# Multivariate Function Interpolation

In [None]:
import tensorflow as tf
import numpy as np
from tqdm import tqdm

import matplotlib.pyplot as plt
%matplotlib inline  

import plotly.graph_objects as go

tfk = tf.keras
tfkl = tfk.layers

## Prepare datasets using a fractal-like function to interpolate

In [None]:
def fractal_function(x, y, noise_std=0.1, rng=None):
    """Evaluate a noisy, highly oscillatory test surface at the points (x, y).

    Parameters
    ----------
    x, y : array_like
        Coordinates of the evaluation points; they are combined element-wise,
        so they must be broadcastable against each other.
    noise_std : float, optional
        Standard deviation of the zero-mean Gaussian noise added to the
        surface. Use ``0`` to obtain the noise-free function. Defaults to
        ``0.1``.
    rng : None, int or numpy.random.Generator, optional
        Source of randomness for the noise. ``None`` (default) draws from
        NumPy's global random state, so ``np.random.seed`` controls it; an int
        seed or a ``Generator`` makes the call reproducible on its own.

    Returns
    -------
    numpy.ndarray
        Function values with the broadcast shape of ``x`` and ``y``.
    """
    # The inputs are stretched by a factor of two before evaluation.
    x = 2 * np.asarray(x)
    y = 2 * np.asarray(y)

    z = np.sin(10 * np.pi * x) * np.cos(10 * np.pi * y) + np.sin(np.pi * (x**2 + y**2))
    z += np.abs(x - y) + (np.sin(5 * x * y) / (0.1 + np.abs(x + y)))
    # Gaussian envelope damping the surface away from the origin.
    z *= np.exp(-0.1 * (x**2 + y**2))

    # Add noise to z
    if rng is None:
        noise = np.random.normal(0, noise_std, z.shape)
    else:
        # default_rng passes an existing Generator through unchanged.
        noise = np.random.default_rng(rng).normal(0, noise_std, z.shape)
    z += noise

    return z

Visualize our function on a regular sub-mesh of $[-1, 1]^{2}$.

In [None]:
# Sample the (noisy) target on a regular 100 x 100 grid over [-1, 1]^2.
X, Y = np.meshgrid(np.linspace(-1, 1, 100), np.linspace(-1, 1, 100))
Z = fractal_function(X, Y)

# Single 3-D surface trace of the sampled values.
fig = go.Figure(data=[go.Surface(x=X, y=Y, z=Z)])

# Enable z-contours on the surface and their projection (project_z).
fig.update_traces(
    contours_z=dict(
        show=True,
        usecolormap=True,
        highlightcolor="limegreen",
        project_z=True,
    )
)

fig.update_layout(title='Original function', autosize=True, width=512, height=512)

fig.show()

Generate a low-discrepancy training dataset of 22,500 samples using fast Halton sampling.

In [None]:
from skopt.space import Space
from skopt.sampler import Halton


# 150 * 150 = 22,500 sample points drawn from the square [-1, 1] x [-1, 1].
n_samples = 150 * 150
space = Space([(-1.0, 1.0), (-1.0, 1.0)])
sampler = Halton()

# Inputs: one row per sample, columns are the two coordinates.
halton_points = sampler.generate(space.dimensions, n_samples)
x_train = np.array(halton_points)

# Targets: noisy function values as a column vector.
y_train = fractal_function(x_train[:, 0], x_train[:, 1]).reshape((-1, 1))

## Define models

We define a simple MLP as our baseline and a number of KANs using polynomial, wavelet, and radial basis functions.

In [None]:
from arnold.layers.polynomial.orthogonal import (
    AskeyWilson,
    Chebyshev1st,
    Gegenbauer, 
)

from arnold.layers.wavelet import (
    Bump, 
    Ricker, 
    Poisson
)

from arnold.layers.radial import (
    GaussianRBF,
    InverseMultiQuadricRBF
)

In [None]:
# Layer widths shared by every KAN below: 2 inputs -> 8 -> 16 -> 32 -> 1 output.
KAN_WIDTHS = (2, 8, 16, 32, 1)
# Per-layer degree passed to the orthogonal-polynomial layers.
POLYNOMIAL_DEGREES = (4, 6, 4, 3)


def _build_kan(layer_cls, name, degrees=None):
    """Stack `layer_cls` layers of widths KAN_WIDTHS with LayerNormalization in between.

    Parameters
    ----------
    layer_cls : callable
        Layer class accepting ``input_dim`` and ``output_dim`` keyword
        arguments (and ``degree`` when `degrees` is given).
    name : str
        Name given to the resulting Sequential model.
    degrees : sequence of int, optional
        One ``degree`` value per layer. When ``None`` no ``degree`` argument
        is passed to `layer_cls`.

    Returns
    -------
    tf.keras.Sequential
        The (unbuilt, uncompiled) model.
    """
    layers = []
    for idx, (in_dim, out_dim) in enumerate(zip(KAN_WIDTHS[:-1], KAN_WIDTHS[1:])):
        # Normalize between consecutive layers, but not before the first one
        # or after the last one.
        if idx > 0:
            layers.append(tfkl.LayerNormalization())
        extra = {} if degrees is None else {'degree': degrees[idx]}
        layers.append(layer_cls(input_dim=in_dim, output_dim=out_dim, **extra))
    return tfk.Sequential(layers, name=name)


all_models = {
    # Baseline: plain fully-connected network.
    'mlp': tfk.Sequential([
        tfkl.Dense(1024, activation="relu"),
        tfkl.Dense(512, activation="relu"),
        tfkl.Dense(1, activation="linear")
        ],
        name='mlp'
    ),
    # Orthogonal-polynomial layers.
    'askey_wilson': _build_kan(AskeyWilson, 'askey_wilson_kan', POLYNOMIAL_DEGREES),
    'chebyshev_1st': _build_kan(Chebyshev1st, 'chebyshev_1st_kan', POLYNOMIAL_DEGREES),
    'gegenbauer': _build_kan(Gegenbauer, 'gegenbauer_kan', POLYNOMIAL_DEGREES),
    # Wavelet layers.
    'bump': _build_kan(Bump, 'bump_kan'),
    'ricker': _build_kan(Ricker, 'ricker_kan'),
    'poisson': _build_kan(Poisson, 'poisson_kan'),
    # Radial-basis-function layers.
    'gaussian_rbf': _build_kan(GaussianRBF, 'gaussian_rbf_kan'),
    'inverse_multiquadric_rbf': _build_kan(InverseMultiQuadricRBF, 'inverse_multiquadric_rbf_kan'),
}

## Train all models

Build and compile all models. 

In [None]:
# Build every model for 2-D inputs and attach optimizer, loss and metric.
for name, model in tqdm(all_models.items()):
    model.build((None, 2))
    model.compile(
        optimizer=tfk.optimizers.legacy.Adam(),
        loss='huber',
        metrics=['mse'],
    )

# Tally the number of scalar weights per model, split by trainability.
trainable_counts = {}
non_trainable_counts = {}
for name, model in all_models.items():
    trainable_counts[name] = np.sum([np.prod(w.shape) for w in model.trainable_weights])
    non_trainable_counts[name] = np.sum([np.prod(w.shape) for w in model.non_trainable_weights])

print('Trainable parameter', trainable_counts)
print('Non-trainable parameter', non_trainable_counts)

In [None]:
# Number of training epochs passed to `fit` for every model.
EPOCHS = 100
# Mini-batch size passed to `fit` for every model.
BATCH_SIZE = 512

Train all models. On an Apple M1 Max this will take ~3 minutes.

In [None]:
# Fit each model in turn and keep the History object returned by `fit`,
# keyed by the model's entry name in `all_models`.
model_train_histories = {}
for name, model in all_models.items():
    model_train_histories[name] = model.fit(
        x_train,
        y_train,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        shuffle=True,
        verbose=0,
    )

Plot all loss & mse curves.

In [None]:
import pandas as pd

# One figure per model: every recorded series in the history over the epochs.
for name, hist in model_train_histories.items():
    history_frame = pd.DataFrame(hist.history)
    history_frame.plot(figsize=(8, 5), title=name)
    plt.show()

## Visualize all interpolants on a test dataset.

In [None]:
# Regular 400 x 400 evaluation grid over [-1, 1]^2.
test_axis = np.linspace(-1, 1, 400)
X_test, Y_test = np.meshgrid(test_axis, test_axis)

# Flatten the grid into rows of (x, y) pairs for the models.
x_test = np.column_stack((X_test.ravel(), Y_test.ravel()))

# Reference values on the grid, kept in the 2-D grid layout.
y_true = fractal_function(X_test, Y_test)

In [None]:
# Predict on the flattened grid and fold each result back into the grid
# layout. The target shape is taken from X_test so it always follows the
# grid resolution instead of repeating it as a literal.
all_predictions = {
    name: model.predict(x_test).reshape(X_test.shape)
    for (name, model) in tqdm(all_models.items())
}

Visualize all interpolations.

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Grid layout: the original surface occupies the first panel, followed by one
# panel per model, filled row by row.
N_ROWS = 2
panel_titles = ['original'] + list(all_predictions.keys())
# Ceiling division, so every panel gets a cell for any number of models.
n_cols = -(-len(panel_titles) // N_ROWS)

fig = make_subplots(
    rows=N_ROWS,
    cols=n_cols,
    start_cell="top-left",
    subplot_titles=panel_titles,
    # Every cell hosts a 3-D surface; sized from the computed grid.
    specs=[[{"type": "surface"} for _ in range(n_cols)] for _ in range(N_ROWS)],
)

# Reference surface, drawn with the same axis orientation (x=X, y=Y) as the
# predictions below so the panels are directly comparable.
fig.add_trace(
    go.Surface(
        z=Z,
        x=X,
        y=Y,
    ),
    row=1, col=1
)

# Slot 0 is the reference surface, so predictions start at slot 1.
for slot, (name, prediction) in enumerate(all_predictions.items(), start=1):
    fig.add_trace(
        go.Surface(
            z=prediction,
            x=X_test,
            y=Y_test,
        ),
        row=slot // n_cols + 1,
        col=slot % n_cols + 1,
    )

fig.update_traces(contours_z=dict(show=True, usecolormap=True, highlightcolor="limegreen", project_z=True))
# update_scenes applies the z-axis settings to every subplot scene, not only
# to the first one.
fig.update_scenes(zaxis=dict(dtick=1, type='linear'))

fig.update_layout(
    autosize=True,
    width=2048,
    height=1024,
)

In [None]:
# Evaluate every model against the reference values on the test grid.
# `y_true` is stored in grid layout; reshaping it to a column vector yields the
# same row-major ordering as the rows of `x_test`. Note that `y_true` comes
# from `fractal_function` and therefore includes its additive noise.
all_evaluations = {
    name: model.evaluate(x_test, y_true.reshape((-1, 1)))
    for (name, model) in tqdm(all_models.items())
}

In [None]:
# Lowest training MSE reached by each model over all epochs.
{name: np.min(history.history['mse']) for name, history in model_train_histories.items()}