# Setup: mount Google Drive and import the project code

In [None]:
# Mount to my Google Drive
from google.colab import drive
import os
import pickle
# Project folder on Drive; "###" appears to be a redacted placeholder — replace with your folder.
PROJECT_DIR = "/content/drive/MyDrive/###"

# force_remount=True so re-running this cell works even if Drive is already mounted.
drive.mount('/content/drive', force_remount=True)
# Work from the project folder so `model.py`, `linearized_model.py` and all saved outputs resolve there.
os.chdir(PROJECT_DIR)

In [None]:
# Import necessary packages
# For graphing
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# And to create animations
import matplotlib.animation as animation
from IPython import display

In [None]:
# Import my code
from model import Linear_Regression
from linearized_model import Linearized_Model

# Visualizing the model and its linearization (static image)

In [None]:
# Fixed point at which to evaluate the networks
x = tf.constant([1], dtype=tf.float32, shape=[1,1])
tf.print(x)

In [None]:
# Now, initialize our networks.
# alpha is the initialization scale; w0 is the initialization shape (a (2, 1) tensor of ones).
alpha = 1
w0 = tf.ones([2, 1])

# The nonlinear network
model = Linear_Regression(w0, alpha)

# The linearized network. model_const is a second, identically initialized Linear_Regression
# that is handed to Linearized_Model and passed again at call time; presumably it is the fixed
# expansion point of the linearization (TODO confirm in linearized_model.py).
model_const = Linear_Regression(w0, alpha)
model_linearized = Linearized_Model(model_const)

# Evaluate both at x = 1
model_out = model(x)
tf.print(f"Model: {model_out}")
linearized_out = model_linearized((x, model_const))
tf.print(f"Linearized model: {linearized_out}")

In [None]:
# Build the grid of weight vectors at which the networks are evaluated:
# Nsample points per axis over [alpha - 2, alpha + 2]^2, giving Nsample**2 rows.
Nsample = 50
w = tf.linspace(alpha - 2, alpha + 2, Nsample)
w_x, w_y = np.meshgrid(w, w)
# One row per grid point: column 0 holds w_1, column 1 holds w_2 (w_1 varies fastest).
weights = tf.convert_to_tensor(np.column_stack((w_x.ravel(), w_y.ravel())), dtype=tf.float32)
tf.print(weights)

In [None]:
# Evaluate both networks at x for every weight vector in the grid.
# NOTE: one forward pass per model per grid point, so this may take a while to run.
# NOTE: the loop overwrites the first-layer weights of `model` and `model_linearized` in place,
# leaving both at the last grid point afterwards; re-run the initialization cell before reusing them.
# It also needs `x` to be the 1x1 evaluation tensor, which the later reshape cell rebinds.
N = int(tf.shape(weights)[0])
output_model = np.zeros([N,1])
output_linearized = np.zeros([N,1])

for i in range(N):

  if not i % 10**2:
    print(i)

  # Same (2, 1) weight column for both networks
  w_i = tf.reshape(weights[i,:], [-1,1])

  model.linear_layer_1.w = w_i
  # .item() extracts the single output value explicitly (and raises if there is more than one),
  # instead of relying on NumPy's deprecated implicit conversion of an ndim>0 array to a scalar.
  output_model[i,0] = np.asarray(model(x)).item()

  model_linearized.linearized_layer_1.w = w_i
  output_linearized[i,0] = np.asarray(model_linearized((x, model_const))).item()

In [None]:
# Fold the flat per-grid-point arrays back into (Nsample, Nsample) grids for surface plotting.
# NOTE: this rebinds `x` (until now the 1x1 evaluation-point tensor) to the w_1 grid, so re-run
# the evaluation-point cell before re-running the evaluation loop.
x, y = (np.reshape(weights[:, col].numpy(), (Nsample, Nsample)) for col in (0, 1))
z_model, z_linearized = (
    np.reshape(flat, (Nsample, Nsample)) for flat in (output_model, output_linearized)
)

In [None]:
# Surface plot of the network and of its linearization over the weight grid.
# A fresh figure each run: plt.figure(0) would reuse figure 0 and stack a second 3D axes on re-run.
fig, ax = plt.subplots(subplot_kw={'projection': '3d'})

# Plot the model evaluated at each w
ax.plot_surface(x, y, z_model, alpha=0.5)
# And the linearized model at each w
ax.plot_surface(x, y, z_linearized, alpha=0.5)
# Finally, mark w = w0 (drawn at (alpha, alpha)), where the model and its linearization intersect.
# NOTE(review): the marker height 0 assumes f(w0)(1) == 0 for this model — TODO confirm.
ax.scatter(alpha, alpha, 0, color='black', s=30)

ax.set_xlabel(r'$w_1$')
ax.set_ylabel(r'$w_2$')
ax.set_zlabel(r'$f(w_1, w_2)(1)$')

ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_zticklabels([])

ax.view_init(10, 60)

# Save the image as a png file *before* showing it: notebook backends may close the figure
# on plt.show(), so saving first guarantees the PNG comes from the fully drawn figure.
fig.savefig(f'visualize_linearized_{alpha}.png', dpi=300)
plt.show()

# Animating the comparison across initialization scales (GIF)

In [None]:
# Function to get the grid of weights for a given initialization scale alpha and number of sample points (for each axis)
# For a given initialization scale alpha, we evaluate the network on [alpha - 2, alpha + 2]^2
def get_weights(alpha, Nsample):
  w = tf.linspace(alpha - 2, alpha + 2, Nsample)
  w_x, w_y = np.meshgrid(w, w)
  weights = np.array([w_x.ravel(),w_y.ravel()])
  weights = tf.convert_to_tensor(weights, dtype=tf.float32)
  return tf.transpose(weights)

In [None]:
# GIF parameters
Nsample = 10      # grid points per weight axis -> Nsample**2 evaluations per frame
num_frames = 100
fps = 5

# Initialization: evaluation point x = 1 and initialization shape w0.
# Redefining x here also makes this cell independent of the earlier reshape cell, which rebinds x.
x = tf.constant([1], dtype=tf.float32, shape=[1,1])
w0 = tf.ones([2, 1])

# Arrays to hold output values: row j = grid point, column i = frame
output_model_array = np.zeros([Nsample**2, num_frames])
output_linearized_array = np.zeros([Nsample**2, num_frames])

# Array containing initialization scales, one per frame
alphas = np.linspace(0.1, 10, num_frames)

for i in range(num_frames):

  if not i % 10:
    print(i)

  # Initialize the nonlinear network
  model = Linear_Regression(w0, alphas[i])

  # And the linear network
  model_const = Linear_Regression(w0, alphas[i])
  model_linearized = Linearized_Model(model_const)

  # Get the corresponding grid of weights
  weights = get_weights(alphas[i], Nsample)

  for j in range(Nsample**2):

    # And evaluate the models on the grid, using the same (2, 1) weight column for both
    w_j = tf.reshape(weights[j,:], [-1,1])

    model.linear_layer_1.w = w_j
    # .item() extracts the single output value explicitly instead of relying on NumPy's
    # deprecated implicit ndim>0 -> scalar conversion (and raises if there is more than one value).
    output_model_array[j,i] = np.asarray(model(x)).item()

    model_linearized.linearized_layer_1.w = w_j
    output_linearized_array[j,i] = np.asarray(model_linearized((x, model_const))).item()

In [None]:
# Pickle the output np.ndarray objects for posterity.
# Files are opened in write mode so the arrays computed above are saved, not replaced.
# To reuse a previous run instead, load them with pickle.load(open(..., 'rb')) — only from
# files you trust, since unpickling can execute arbitrary code.
with open('model_output.pkl', 'wb') as f:
    pickle.dump(output_model_array, f)

with open('model_linearized_output.pkl', 'wb') as f1:
    pickle.dump(output_linearized_array, f1)

In [None]:
# And generate the GIF

# Updates the plot at each frame in the GIF
def update_plot(frame_number, zarray, plot):
    plot[0].remove()
    plot[1].remove()
    plot[2].remove()

    textvar.set_text(r"$\alpha =$" + f"{round(alphas[frame_number], 1)}")

    weights = get_weights(alphas[frame_number], Nsample)
    x = np.reshape(weights[:,0].numpy(), (Nsample,Nsample))
    y = np.reshape(weights[:,1].numpy(), (Nsample,Nsample))

    plot[0] = ax.plot_surface(x, y, np.reshape(output_model_array[:,frame_number], (Nsample, Nsample)), color="#1f77b4", alpha=0.5)
    plot[1] = ax.plot_surface(x, y, np.reshape(output_linearized_array[:,frame_number], (Nsample, Nsample)), color="#ff7f0e", alpha=0.5)
    plot[2] = ax.scatter(alphas[frame_number], alphas[frame_number], 0, color='black', s=30)


# Initialize the plot
fig = plt.figure()

ax = fig.add_subplot(111, projection='3d')
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])

ax.zaxis.set_rotate_label(False)
ax.set_xlabel(r'$w_1$', labelpad=0)
ax.set_ylabel(r'$w_2$', labelpad=0)
ax.set_zlabel(r'$f(w_1, w_2)(1)$', rotation=90, labelpad=0)

ax.grid(False)
textvar = ax.text2D(0.10, 0.90, r"$\alpha =$" + f"{round(alphas[0],1)}", transform=ax.transAxes)

# tight_layout() only adjusts axes that already exist, so it must run after add_subplot.
fig.tight_layout()

# update_plot's signature takes a `zarray` argument that it never uses; pass a placeholder.
zarray = 1

# Draw frame 0 so update_plot always has exactly three artists to remove:
# [model surface, linearized surface, w0 marker].
weights = get_weights(alphas[0], Nsample)
grid_w1 = np.reshape(weights[:,0].numpy(), (Nsample,Nsample))
grid_w2 = np.reshape(weights[:,1].numpy(), (Nsample,Nsample))
plot = [
    ax.plot_surface(grid_w1, grid_w2, np.reshape(output_model_array[:,0], (Nsample, Nsample)), color="#1f77b4", alpha=0.5),
    ax.plot_surface(grid_w1, grid_w2, np.reshape(output_linearized_array[:,0], (Nsample, Nsample)), color="#ff7f0e", alpha=0.5),
    ax.scatter(alphas[0], alphas[0], 0, color='black', s=30),
]

ax.view_init(10, 60)

# Instantiate the animation object (one frame per alpha; interval in ms matches fps).
# `plot` is passed by reference, so update_plot's in-place replacements persist across frames.
ani = animation.FuncAnimation(fig, update_plot, num_frames, fargs=(zarray, plot), interval=1000/fps)

In [None]:
# Render the animation inline as an HTML5 video; the bare `ani` on the last line is what gets displayed.
# NOTE(review): html5 output encodes a video through matplotlib's movie writer (ffmpeg by default) — confirm it is installed.
plt.rc('animation', html='html5')
ani

In [None]:
# Install ImageMagick, which matplotlib's 'imagemagick' animation writer shells out to.
# apt-get (stable scripting CLI) with -y, because a notebook `!` command may not be able to answer
# apt's interactive confirmation prompt; -qq keeps the log short.
!apt-get -qq update
!apt-get -qq install -y imagemagick

In [None]:
# And save the GIF.
# PillowWriter encodes GIFs with Pillow (a matplotlib dependency), so no external ImageMagick
# binary is required. fps=fps matches the animation's frame interval of 1000/fps ms.
ani.save('linearized_model.gif', writer=animation.PillowWriter(fps=fps), dpi=200)