<a href="https://colab.research.google.com/github/wvirany/NNGP/blob/main/NNGP2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Environment setup (IPython shell escapes): upgrade pip, install a
# CUDA 11 / cuDNN 8.2 build of JAX from the JAX release index, and install
# Neural Tangents from GitHub.
# NOTE(review): neither jax nor neural_tangents is imported by the code in the
# cells shown here (they use TensorFlow/Keras) — confirm these are still needed.
!pip install -q --upgrade pip
!pip install -q --upgrade jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
!pip install -q git+https://www.github.com/google/neural-tangents

[0m

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import tensorflow as tf
from tensorflow import keras
from keras import layers
from keras.models import Sequential
from keras.layers import Dense, Activation
from keras.initializers import RandomNormal
from keras.optimizers import Adam
from keras.losses import MeanSquaredError

import warnings
# Silence all warnings to keep notebook output tidy.
# NOTE(review): this also hides useful diagnostics (e.g. Keras input-shape
# warnings or invalid string-escape warnings) — consider narrowing the filter.
warnings.filterwarnings('ignore')

# Plot appearance.  `sns.set` resets seaborn's theme (default "darkgrid"
# style) and then applies `rc` on top of it, so it is the only styling call
# needed; an earlier `sns.set_style(...)` would simply be overwritten.
rc = {
    "axes.facecolor": ".95",
    "figure.facecolor": "#f7f9fc",
    "axes.edgecolor": "#000000",
    "grid.color": "#EBEBE7",
    "font.family": "serif",
    "axes.labelcolor": "#000000",
    "xtick.color": "#000000",
    "ytick.color": "#000000",
    "grid.alpha": 0.4
}

# NOTE(review): not referenced by the code visible here.
default_palette = 'tab10'

sns.set(rc=rc)

# Seed Python's `random`, NumPy (equivalent to np.random.seed(42)) and the
# TensorFlow/Keras RNG in one call.  Seeding NumPy alone leaves the Keras
# weight initialisers unseeded, so the ensemble would differ on every run.
keras.utils.set_random_seed(42)

In [None]:
class SingleLayerNetwork(keras.Model):
  """Fully connected network with one hidden layer mapping a scalar to a scalar.

  Initialisation:
    * hidden layer: weights and biases ~ N(0, 1);
    * read-out layer: weights ~ N(0, 1/num_units) (stddev 1/sqrt(num_units)),
      bias ~ N(0, 0.1**2).
  Scaling the read-out weights by 1/sqrt(num_units) keeps the variance of the
  output at initialisation from growing with the hidden width.

  Args:
      num_units: Number of hidden units.
      activation: Activation of the hidden layer (any Keras activation
          identifier; default 'tanh').
  """

  def __init__(self, num_units, activation='tanh'):

    super().__init__()

    # Inner Sequential holding both Dense layers; `input_shape=(1,)` lets Keras
    # build it (and create its weights) right away.
    self.net = Sequential([
        Dense(num_units, input_shape=(1,), activation=activation,
              kernel_initializer=RandomNormal(mean=0.0, stddev=1.0),
              bias_initializer=RandomNormal(mean=0.0, stddev=1.0)),
        Dense(1, activation='linear',
              kernel_initializer=RandomNormal(mean=0.0, stddev=1/np.sqrt(num_units)),
              bias_initializer=RandomNormal(mean=0.0, stddev=.1))
    ])


  def call(self, X):
    """Forward pass: return the inner Sequential's output for ``X``."""
    return self.net(X)


  def train_step(self, x_train, y_train, optimizer, loss_fn):
    """Perform one gradient-descent step on ``(x_train, y_train)``.

    NOTE: this shadows ``keras.Model.train_step(data)``, the hook that
    ``Model.fit`` calls; with this signature ``fit`` cannot be used on this
    model, so call this method directly.

    Args:
        x_train: Training inputs, shape ``(n,)`` or ``(n, 1)``.
        y_train: Training targets with ``n`` elements, shape ``(n,)`` or
            ``(n, 1)``.
        optimizer: Keras optimizer used via ``apply_gradients``.  Reuse the
            same instance across steps so stateful optimizers such as Adam
            keep their moment estimates.
        loss_fn: Callable ``loss_fn(y_true, y_pred)`` returning a scalar.

    Returns:
        The loss evaluated with the parameters before this update.
    """

    with tf.GradientTape() as tape:

      output = self(x_train)

      # The network emits shape (n, 1).  Align the targets with it: leaving
      # them as (n,) can make (n, 1) - (n,) broadcast to (n, n) inside the
      # loss (e.g. Keras 2 MSE with NumPy targets), comparing every
      # prediction against every target.
      targets = tf.reshape(tf.cast(y_train, output.dtype), tf.shape(output))

      loss = loss_fn(targets, output)

    # Compute the gradient of the loss w.r.t. all trainable weights.
    gradients = tape.gradient(loss, self.trainable_variables)

    optimizer.apply_gradients(zip(gradients, self.trainable_variables))

    return loss

In [None]:
def f(x):
  """Ground-truth regression target 5*sin(x); works element-wise on arrays."""
  return 5 * np.sin(x)


# Number of noisy training observations of f.
train_points = 5
# Standard deviation of the Gaussian noise added to the training targets.
noise_scale = 1e-1

In [None]:
# Draw `train_points` inputs uniformly from [-pi, pi) and observe f with
# additive N(0, noise_scale**2) noise.  Using `train_points` (rather than a
# second hard-coded count) keeps the sample size defined in one place.
x_train = np.random.uniform(-np.pi, np.pi, train_points)
y_train = f(x_train) + noise_scale * np.random.normal(0, 1, train_points)
train = (x_train, y_train)

# Evenly spaced evaluation grid on [-pi, pi] for plotting predictions.
x_test = np.linspace(-np.pi, np.pi, 50)

In [None]:
# Fifty independently initialised networks with 512 hidden units each; every
# constructor call draws fresh random weights, in order.
nn_ensemble = [SingleLayerNetwork(num_units=512) for _ in range(50)]

In [None]:
# Print the layer-by-layer architecture of one ensemble member.  The outer
# subclassed model has never been called, so its own `summary()` raises
# "model has not yet been built" on Keras 2; the inner Sequential was given
# `input_shape`, so it is already built and can be summarised directly.
nn_ensemble[0].net.summary()

In [None]:
# Matplotlib text settings for the figure below; loop-invariant, so set once.
plt.rcParams.update({
  "mathtext.fontset": "cm",   # Computer Modern glyphs for $...$ math text
  "font.family": "serif",     # serif fonts for ordinary text
  "text.usetex": False,       # render math with Matplotlib's mathtext, not LaTeX
})

# One optimizer per network, each created once.  Adam keeps running moment
# estimates per variable, so constructing a new Adam() for every step would
# reset that state on each update whenever more than one epoch is run.
# Recent Keras optimizers also bind to the variables they first update, so
# the networks must not share a single instance.
optimizers = [Adam() for _ in nn_ensemble]
loss_fn = MeanSquaredError()

for epoch in range(1):

  fig = plt.figure()

  preds = []

  for nn, optimizer in zip(nn_ensemble, optimizers):

    # Predict *before* this epoch's update, so the figure shows the ensemble
    # as it was at the start of the epoch.
    pred = nn(x_test)

    preds.append(pred)

    plt.plot(x_test, pred, c='pink', lw=.7, alpha=.5)

    nn.train_step(x_train, y_train, optimizer=optimizer, loss_fn=loss_fn)

  # Pointwise mean and standard deviation across the ensemble members.
  mean = np.mean(preds, axis=0).reshape(len(x_test))
  std = np.std(preds, axis=0).reshape(len(x_test))

  plt.plot(x_test, mean, c='firebrick', lw=1)
  plt.scatter(x_train, y_train, c='k', s=25)

  # Shade the band mean +/- 2 standard deviations.
  plt.fill_between(x_test,
                   mean - 2 * std,
                   mean + 2 * std,
                   color='lightblue', alpha=0.2)

  # Raw strings so backslash sequences (e.g. the '\f' in '\frac') reach
  # mathtext instead of being read as Python escapes such as a form feed.
  plt.xticks([-np.pi, -np.pi/2, 0, np.pi/2, np.pi],
             [r'$-\pi$', r'$-\pi/2$', r'$0$', r'$\pi/2$', r'$\pi$'])
  plt.yticks([-5, 0, 5])

  # plt.xlabel("$x$")
  # plt.ylabel("$y$")

  plt.xlim(-np.pi, np.pi)
  plt.ylim(-7, 7)