In [None]:
# Mount to my Google Drive
from google.colab import drive
import os
import pickle
# Attach Google Drive at /content/drive; force_remount=True remounts even if it is
# already mounted. Then make the project folder the working directory
# (presumably where NTK_callback.py and model.py live — they are imported next).
drive.mount('/content/drive', force_remount=True)
project_dir = "/content/drive/MyDrive/###"  # NOTE: "###" is a placeholder for the project folder
os.chdir(project_dir)

In [None]:
# Import necessary packages
import tensorflow as tf
import numpy as np

# Import code for NTK callback function
from NTK_callback import NTKCallback

# As well as the code for the linear regression model from Woodworth 2020
from model import Linear_Regression

In [None]:
# Write TensorFlow callback function for early stopping when a certain loss threshold is reached
class LossThreshold(tf.keras.callbacks.Callback):
    """Stop training as soon as a monitored metric reaches an absolute threshold.

    Keras' built-in ``EarlyStopping`` reacts to a lack of improvement; this
    callback instead stops at the end of the first epoch whose monitored value
    is ``<= threshold``.

    Args:
        threshold: Target value; training stops once the monitored value is
            less than or equal to it.
        monitor: Key of the value in the epoch-end ``logs`` dict to compare
            against ``threshold``. Defaults to ``"loss"`` (the training loss).
        **kwargs: Forwarded to ``tf.keras.callbacks.Callback.__init__``.
    """

    def __init__(self, threshold, monitor="loss", **kwargs):
        super(LossThreshold, self).__init__(**kwargs)
        self.threshold = threshold
        self.monitor = monitor
        # Epoch index at which training was stopped; stays None if never reached.
        self.stopped_epoch = None

    def on_epoch_end(self, epoch, logs=None):
        """Request a stop (``self.model.stop_training = True``) once the threshold is hit."""
        logs = logs or {}
        value = logs.get(self.monitor)
        if value is None:
            # Metric not reported (e.g. a typo in `monitor`): warn rather than crash fit().
            import warnings
            warnings.warn(
                f"LossThreshold: metric '{self.monitor}' not found in epoch logs; "
                "training will not be stopped by this callback."
            )
            return
        if value <= self.threshold:
            self.stopped_epoch = epoch
            self.model.stop_training = True

# Initializing the Linear Model and Neural Tangent Kernel (NTK) Callback

In [None]:
# Create our training data set

# For reproducibility of results, set the random seeds (TensorFlow draws beta, NumPy draws the inputs)
tf.random.set_seed(500)
np.random.seed(500)

# Dimension of input points
d = 20
# Number of training points
N = 10

# Create the \beta which parameterizes our linear regression model
# Here, we generate \beta by taking each entry to be an i.i.d. Unif(0,1) random variable
beta = tf.random.uniform([d, 1], dtype=tf.float32)

# As in Woodworth et al., suppose our training points are drawn from a d-dimensional
# standard multivariate normal distribution (zero mean, identity covariance)
train_x = np.random.multivariate_normal(np.zeros(d), np.identity(d), size=N)
# NOTE: train_x here has dimension N x d
train_x = tf.convert_to_tensor(train_x, dtype=tf.float32)

# Compute the corresponding (noiseless) y-values: (N x d) @ (d x 1) is already N x 1
train_y = tf.matmul(train_x, beta)
assert train_y.shape == (N, 1), train_y.shape

In [None]:
# Initialize the linear regression model

# Initialization scale alpha and initialization shape w0 (a 2d x 1 vector of ones)
alpha = 0.1
w0 = tf.ones([2*d, 1])

# Build the model from (w0, alpha), plus the NTK callback, which receives the training
# inputs; step=10 is forwarded to NTKCallback (presumably the evaluation interval — confirm
# in NTK_callback.py)
model = Linear_Regression(w0, alpha=alpha)
ntk_callback = NTKCallback(train_x, step=10)

In [None]:
# Loss L: mean squared error.
# NOTE: this is the squared loss used in Woodworth et al., scaled by a factor of 1/N (hence 'mean')
MSE = tf.keras.losses.MeanSquaredError()

# Plain gradient descent (Keras SGD) with step size 1e-2
optimizer = tf.keras.optimizers.SGD(learning_rate=1e-2)

model.compile(optimizer=optimizer, loss=MSE)

In [None]:
# Finally, fit our model with full-batch gradient descent.
# batch_size=N puts every training point in a single batch, so each epoch is exactly one
# gradient step; Keras' default batch_size of 32 would only match this while N <= 32.
# Training stops when the training loss reaches 10^{-4}, or after at most 10^4 epochs.
history = model.fit(
    train_x,
    train_y,
    batch_size=N,
    epochs=10**4,
    verbose=1,
    callbacks=[ntk_callback, LossThreshold(1e-4)],
)

In [None]:
# The callback has indeed stored every NTK evaluation made during training.
# Each list item is a 10 x 10 tensor: the NTK evaluated on the training grid of N = 10 points.
ntk_callback.NTK_evals[:3]

In [None]:
# How many NTK evaluations the callback recorded during training
num_ntk_evals = len(ntk_callback.NTK_evals)
num_ntk_evals

 # Visualizing the NTK During Training

In [None]:
# Import plotting tools
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

# And packages to create animation
import matplotlib.animation as animation
from IPython import display

In [None]:
# Visualize the Neural Tangent Kernel of the model upon initialization
# This is equal to the NTK of the corresponding linearized model from Chizat et al. 2018
fig, ax = plt.subplots()

# K(x, x') = <grad_w f(w)(x), grad_w f(w)(x')>, evaluated for every pair of training points
ax.title.set_text(r"$\langle \nabla_w f(w)(x), \nabla_w f(w)(x') \rangle$, epoch = 0")

# Labels x_1, ..., x_N for the N training points (braces keep multi-digit subscripts together)
pt_labels = [rf'$x_{{{i}}}$' for i in range(1, N + 1)]

# Plot the NTK evaluated at the training points
im = ax.imshow(ntk_callback.NTK_evals[0], cmap="Greens", origin="lower")

# Add a colorbar to the right of the plot
divider = make_axes_locatable(ax)
cax = divider.append_axes('right', size='5%', pad=0.1)
fig.colorbar(im, cax=cax, orientation='vertical')

# One tick per training point, derived from the labels rather than hard-coded
ticks = range(len(pt_labels))
ax.set_xticks(ticks)
ax.set_xticklabels(pt_labels)
ax.set_yticks(ticks)
ax.set_yticklabels(pt_labels)

# Save first: plt.show() on the inline backend displays and then closes the figure.
# (fig.show() only works with GUI backends and warns in a notebook.)
fig.savefig(f'linearized_NTK_{alpha}.png', dpi=300)
plt.show()

In [None]:
# Animation parameters: one frame per recorded NTK evaluation, played at `fps` frames per second
fps = 10
num_frames = len(ntk_callback.NTK_evals)

In [None]:
# Minimum, maximum for colorbar
# We want the scale of the plot to be constant across frames; otherwise, we cannot visualize
# change throughout training. Using the global extrema over every recorded NTK evaluation keeps
# the scale fixed without clipping any frame. Hard-coded bounds (e.g. [-1, 2.5]) can saturate
# whenever the kernel's values fall outside them, e.g. after changing alpha.
col_min = float(min(np.min(k) for k in ntk_callback.NTK_evals))
col_max = float(max(np.max(k) for k in ntk_callback.NTK_evals))

In [None]:
def update_plot(frame_number, zarray, plot):
    """Draw animation frame `frame_number`: the NTK recorded at that evaluation.

    Args:
        frame_number: Index into ``ntk_callback.NTK_evals``, supplied by FuncAnimation.
        zarray: Unused; kept so the ``fargs`` passed to FuncAnimation stay unchanged.
        plot: One-element list holding the heatmap image; it is updated in place.

    Returns:
        ``plot``, the list of modified artists (what FuncAnimation expects when blitting).
    """
    image = plot[0]
    # Swap in the new values instead of removing and re-creating the image. vmin/vmax were
    # fixed when the image was created, so every frame shares one color scale.
    image.set_data(ntk_callback.NTK_evals[frame_number])
    image.axes.set_title(f"NTK evaluation {frame_number + 1} of {num_frames}")
    return plot

# Unused placeholder forwarded to update_plot through fargs
zarray = 1

# Initialize the plot, using the fixed color range [col_min, col_max] for every frame
fig, ax = plt.subplots()

plot = [ax.imshow(ntk_callback.NTK_evals[0], cmap="Greens", origin="lower", vmin=col_min, vmax=col_max)]

# Colorbar to the right, as in the static plot; it reflects the fixed vmin/vmax
divider = make_axes_locatable(ax)
cax = divider.append_axes('right', size='5%', pad=0.1)
fig.colorbar(plot[0], cax=cax, orientation='vertical')

# One tick per training point, derived from the labels rather than hard-coded
ticks = range(len(pt_labels))
ax.set_xticks(ticks)
ax.set_xticklabels(pt_labels)
ax.set_yticks(ticks)
ax.set_yticklabels(pt_labels)

# Draw frame 0 (sets the title) so the static preview matches the first animation frame
update_plot(0, zarray, plot)

# Instantiate the animation object: num_frames frames, one every 1000/fps milliseconds
ani = animation.FuncAnimation(fig, update_plot, num_frames, fargs=(zarray, plot), interval=1000/fps)

In [None]:
# Show the animation inline. The 'html5' setting embeds it as an HTML5 <video>, not a GIF;
# this goes through matplotlib's movie writer (rcParams['animation.writer'], ffmpeg by default).
plt.rcParams.update({'animation.html': 'html5'})
ani

In [None]:
# Install writer to save GIF
!apt-get update
!apt install imagemagick

In [None]:
# Write the animation to disk as a GIF with the ImageMagick writer; the file name records alpha
gif_path = f'NTK_{alpha}.gif'
ani.save(gif_path, writer='imagemagick', dpi=200)