## Lecture 28: Stochastic Gradient Descent + Neural Networks for Dynamical Systems

In today's lecture, we will explore
1. Implementing our own Stochastic Gradient Descent algorithm
2. Neural networks' ability to predict future states of a dynamical system (a system of ODEs). In particular, we will look at how to train an NN to predict the dynamics of the Lorenz system

In [1]:
import numpy as np
import os

import matplotlib.pyplot as plt
from matplotlib import rc

# Notebook-wide matplotlib styling, applied in a single update.
plt.rcParams.update({
    'xtick.labelsize': 16,      # tick label size on the x axis
    'ytick.labelsize': 16,      # tick label size on the y axis
    'axes.linewidth': 1,        # line width of the axes frame
    'xtick.major.width': 3,     # major tick line width on the x axis
    'ytick.major.width': 3,     # major tick line width on the y axis
})
rc('text', usetex=False)            # disable LaTeX rendering in plots
rc('font', family='DejaVu Sans')    # use DejaVu Sans for plot text

In [2]:
from scipy import integrate

from keras.models import Sequential
from keras.layers import Dense, Conv2D, Flatten, MaxPool2D
from keras import optimizers
from keras.layers import Activation
from keras import backend as K

### 1. Stochastic Gradient Descent

Here we present the implementation based on the textbook, but there are many different ways to implement SGD, see below for a few examples:

1. https://www.geeksforgeeks.org/ml-stochastic-gradient-descent-sgd/
2. https://realpython.com/gradient-descent-algorithm-python/
3. https://scikit-learn.org/stable/modules/sgd.html

In [None]:
from matplotlib import rcParams
from scipy import interpolate
from mpl_toolkits.mplot3d import Axes3D

# Font size and default figure size for the figures drawn after this point.
# (`rcParams` imported from matplotlib is the same object as `plt.rcParams`.)
rcParams.update({'font.size': 18, 'figure.figsize': [12, 12]})

iterMax = 100           # maximum number of descent steps per trajectory
h = 0.1                 # grid spacing used to tabulate the objective
n_sub = 10              # grid lines kept per axis in each random mini-batch
tau = 1.5               # fixed step size (learning rate)
interp_type = 'linear'  # interpolation method used on each mini-batch

# Tabulate the objective F(x, y) on a uniform grid: two Gaussian-shaped wells,
# one centred at (-3, -3) and one at (3, 3).
# np.meshgrid defaults to 'xy' indexing, so ROWS of F follow y and COLUMNS follow x.
x_grid = np.arange(-6, 6+h, h)
y_grid = np.copy(x_grid)
n = len(x_grid)
X, Y = np.meshgrid(x_grid, y_grid)
F1 = 1.5 - 1.6 * np.exp(-0.05 * (3 * (X+3)**2 + (Y+3)**2))
F = F1 + (0.5 - np.exp(-0.1*(3 * (X-3)**2 + (Y-3)**2)))
# np.gradient returns the derivative along axis 0 (y) first, then axis 1 (x).
dFy, dFx = np.gradient(F, h, h)

# Starting points, one (x0[k], y0[k]) pair per trajectory.
x0 = np.array([4, 0, -5])
y0 = np.array([0, -5, 2])

# Scratch buffers holding the trajectory currently being computed.
x = np.zeros(iterMax+1)
y = np.copy(x)
f = np.copy(x)

# Stored trajectories, one column per starting point.
x_out = np.zeros((iterMax+1, 3))
y_out = np.copy(x_out)
f_out = np.copy(x_out)


def make_clamped_interpolant(xs, ys, values, kind='linear'):
  """Build a scalar interpolant g(xq, yq) over a rectangular (sub-)grid.

  Parameters
  ----------
  xs, ys : strictly increasing 1-D coordinates of the grid columns / rows.
  values : 2-D table with values[r, c] being the sample at (xs[c], ys[r]).
  kind   : interpolation method passed to RegularGridInterpolator.

  Returns
  -------
  A function taking (xq, yq) and returning a Python float. Query points
  outside the grid are clamped to its nearest edge, so the interpolant
  never extrapolates beyond the sampled values.
  """
  xs = np.asarray(xs, dtype=float)
  ys = np.asarray(ys, dtype=float)
  # Axis order is (rows, columns) = (y, x), matching the layout of `values`.
  table = interpolate.RegularGridInterpolator(
      (ys, xs), np.asarray(values, dtype=float), method=kind)

  def evaluate(xq, yq):
    xq = min(max(float(xq), xs[0]), xs[-1])
    yq = min(max(float(yq), ys[0]), ys[-1])
    return float(table([[yq, xq]])[0])

  return evaluate


def draw_minibatch_interpolants():
  """Sample a random n_sub x n_sub sub-grid and interpolate F, dF/dx, dF/dy on it.

  The value tables are sliced with the SAME freshly drawn indices that select
  the coordinates, so every mini-batch is self-consistent.

  Returns
  -------
  (F_interp, dfx_interp, dfy_interp) : three callables of (xq, yq).
  """
  rows = np.sort(np.random.permutation(n)[:n_sub])  # indices along y (axis 0)
  cols = np.sort(np.random.permutation(n)[:n_sub])  # indices along x (axis 1)
  batch = np.ix_(rows, cols)
  xs, ys = x_grid[cols], y_grid[rows]
  return (make_clamped_interpolant(xs, ys, F[batch], interp_type),
          make_clamped_interpolant(xs, ys, dFx[batch], interp_type),
          make_clamped_interpolant(xs, ys, dFy[batch], interp_type))


for jj in range(3):
  x[0] = x0[jj]
  y[0] = y0[jj]

  # Objective value and gradient estimate at the starting point.
  F_interp, dfx_interp, dfy_interp = draw_minibatch_interpolants()
  f[0] = F_interp(x[0], y[0])
  dfx = dfx_interp(x[0], y[0])
  dfy = dfy_interp(x[0], y[0])

  for j in range(iterMax):
    x[j+1] = x[j]-tau*dfx # update x, y, and f
    y[j+1] = y[j]-tau*dfy

    # New random mini-batch for every step: this is the "stochastic" part.
    F_interp, dfx_interp, dfy_interp = draw_minibatch_interpolants()
    f[j+1] = F_interp(x[j+1], y[j+1])
    dfx = dfx_interp(x[j+1], y[j+1])
    dfy = dfy_interp(x[j+1], y[j+1])

    if np.abs(f[j+1]-f[j]) < 10**(-6): # check convergence
      print('Converged after {} iterations'.format(j+1))
      break
  else:
    # Reached only when the loop ran all iterMax steps without a break.
    print('Failed to converge after {} iterations'.format(iterMax))

  x_out[:,jj] = x
  y_out[:,jj] = y
  f_out[:,jj] = f

  # Entries past the last step taken (index j+1) are not part of this
  # trajectory; mark them NaN so they are skipped when plotting.
  x_out[(j+2):,jj] = np.nan
  y_out[(j+2):,jj] = np.nan
  f_out[(j+2):,jj] = np.nan

In [None]:
# Top-down view: contour lines of F with the stored SGD trajectories overlaid.
fig_contour, ax_contour = plt.subplots()
ax_contour.contour(X, Y, F, colors='k')
for jj in range(x_out.shape[1]):
  ax_contour.plot(x_out[:,jj], y_out[:,jj], 'o')
ax_contour.set_xlabel('x')
ax_contour.set_ylabel('y')
plt.show()

# 3-D view: the surface of F with the same trajectories lifted slightly
# (+0.1) above it so the markers are not hidden by the surface.
fig,ax = plt.subplots(1, 1, subplot_kw={'projection': '3d'})
ax.plot_surface(X, Y, F, linewidth=0, cmap='binary', alpha=0.3)
for jj in range(x_out.shape[1]):
  # The marker has to be passed by keyword: the fourth positional argument
  # of Axes3D.scatter is `zdir`, not the marker style.
  ax.scatter(x_out[:,jj], y_out[:,jj], f_out[:,jj]+0.1, marker='o', s=100)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('F')
ax.view_init(elev = 40, azim = -100)
plt.show()

### 2. NN for Dynamical System
#### 2.1 Generate Trajectory for Lorenz System Using `scipy.integrate.odeint`.

Equations for the Lorenz system:

\begin{align}
x' &= \sigma(y - x)\\
y' &= x(\rho - z) - y\\
z' &= xy - \beta z
\end{align}

In [None]:
## Lorenz system: solver settings, parameters and storage

# Time step and final time; t is the grid of output times given to the ODE solver.
dt, T = 0.01, 8
t = np.arange(0,T+dt,dt)

# Lorenz system parameters.
sigma, rho, beta = 10, 28, 8/3.0

# Zero-filled storage: (len(t)-1) rows for each of 100 trajectories, 3 columns.
pair_shape = (100*(len(t)-1), 3)
nn_input = np.zeros(pair_shape)
nn_output = np.zeros(pair_shape)

# 3-D axes that the simulated trajectories are drawn on.
fig, ax = plt.subplots(subplot_kw={'projection': '3d'})

def lorenz_deriv(x_y_z, t0, sigma=sigma, beta=beta, rho=rho):
  """Right-hand side of the Lorenz equations.

  x_y_z : sequence of three values, unpacked as the state (x, y, z).
  t0    : time argument; accepted but not used in the computation.
  sigma, beta, rho : system parameters; the defaults are the module-level
      values captured when this function is defined.

  Returns the list [x', y', z'].
  """
  u, v, w = x_y_z
  du = sigma * (v - u)
  dv = u * (rho - w) - v
  dw = u * v - beta * w
  return [du, dv, dw]

# 100 seeded initial conditions, each coordinate drawn uniformly from [-15, 15).
np.random.seed(123)
x0 = -15 + 30 * np.random.random((100, 3))

# Integrate every initial condition; the stacked result is indexed
# (trajectory, time step, coordinate).
x_t = np.asarray([integrate.odeint(lorenz_deriv, start, t) for start in x0])

# Fill the preallocated arrays in place: each input row is a state and the
# matching output row is the state one time step later; trajectories are
# stacked one after another.
nn_input[:] = x_t[:, :-1, :].reshape(-1, 3)
nn_output[:] = x_t[:, 1:, :].reshape(-1, 3)

# Draw each trajectory and mark its starting point in red.
for start, path in zip(x0, x_t):
  x, y, z = path.T
  ax.plot(x, y, z, linewidth=1)
  ax.scatter(start[0], start[1], start[2], color='r')

ax.view_init(18, -113)
plt.show()

In [None]:
# Fully connected network, 3 inputs -> 10 (sigmoid) -> 10 (relu) -> 3 (linear),
# fitted with mean-squared error on the (nn_input, nn_output) pairs.
net = Sequential([
  Dense(10, input_dim=3, activation='sigmoid'),
  Dense(10, activation='relu'),
  Dense(3, activation='linear'),
])
net.compile(loss='mse', optimizer='adam')
History = net.fit(nn_input, nn_output, epochs=30)

In [None]:
# Roll the trained network forward in time from random initial states.
num_traj = 100
ynn = np.zeros((num_traj, len(t), 3))   # indexed (trajectory, time step, coordinate)
ynn[:, 0, :] = -15 + 30 * np.random.random((num_traj, 3))

# Each prediction is fed back in as the next input, so the steps have to be
# taken sequentially. verbose=0 stops Keras from printing one progress bar
# per time step.
for jj in range(len(t) - 1):
  ynn[:, jj+1, :] = net.predict(ynn[:, jj, :], verbose=0)

In [None]:
# Draw every network-predicted trajectory on one 3-D axes.
fig, ax = plt.subplots(subplot_kw={'projection': '3d'})
for traj in ynn[:num_traj]:
  ax.plot(traj[:, 0], traj[:, 1], traj[:, 2], linewidth=1)

ax.view_init(18, -113)
plt.show()