Add non-linear activation functions.

In [None]:
import numpy as np
from IPython import display
from common import *

In [None]:
# Target values the network is trained against (provided by common)
expected = expected1

# Neuron count of every layer, input layer first and output layer last
layers = [2, 6, 1]

# One randomly initialised weight matrix per pair of adjacent layers,
# shaped (neurons in, neurons out)
w = [np.random.rand(*dims) for dims in zip(layers[:-1], layers[1:])]

# One randomly initialised bias vector per layer after the input layer
b = list(map(np.random.rand, layers[1:]))

In [None]:
# Learning rate: factor applied to each gradient before it is subtracted
lr = 0.0008

def reluF(x):
  """ReLU for a single scalar: x when it is positive, otherwise 0."""
  return max(x, 0)


def relu(x):
  """Apply ReLU element-wise to an array-like and return an array.

  Uses np.maximum rather than np.vectorize(reluF): np.vectorize picks its
  output dtype from the result for the first element, and reluF returns the
  Python int 0 for a negative first element, which silently truncated every
  other activation to an integer. np.maximum keeps the input's float dtype
  and also accepts empty input.
  """
  return np.maximum(x, 0)

def reluFPrime(x):
  """Derivative of ReLU for a single scalar: 0 below zero, otherwise 1."""
  if x < 0:
    return 0
  return 1

# Element-wise version of reluFPrime for arrays
reluPrime = np.vectorize(reluFPrime)

def forward(x):
  """Run x through the network and return the values of every layer.

  The returned list starts with x itself, followed by the output of each
  layer in order; the final entry is the model's output.
  """
  outputIdx = len(layers) - 2
  values = [x]

  for layerIdx in range(outputIdx + 1):
    layerOut = np.matmul(values[-1], w[layerIdx]) + b[layerIdx]
    if layerIdx != outputIdx:
      # Only hidden layers go through relu; the output layer stays linear
      layerOut = relu(layerOut)
    values.append(layerOut)

  return values


def loss(out):
  """Sum of squared differences between out and the global expected."""
  residual = out - expected
  return np.square(residual).sum()


def backward(xi):
  """Calculate how changes to the weights and biases affect the loss.

  xi is the list produced by forward(): the input followed by the output of
  every layer. Returns (dLdW, dLdB), two lists with one gradient per layer,
  ordered like w and b.
  """
  # Derivative of the summed squared error with respect to the model output
  dLdX = 2 * (xi[-1] - expected)
  dLdW = []
  dLdB = []

  # Visit the layers from the output back towards the input
  for i in range(len(layers) - 1):
    prevX = xi[-(i+2)]
    curW = w[-(i+1)]
    curB = b[-(i+1)]

    # Recompute this layer's values before the activation was applied
    nonActivPrevX = np.matmul(prevX, curW) + curB
    # No activation function on output layer
    activPrime = np.ones_like(nonActivPrevX) if i == 0 else reluPrime(nonActivPrevX)

    # elemOuter / elemDot come from common; presumably they operate per
    # sample along axis 0, which is then summed away -- confirm in common
    dXdW = elemOuter(prevX, activPrime)
    dLdWCur = (elemDot(dXdW, dLdX)).sum(0)
    dLdW.append(dLdWCur)

    dXdB = activPrime
    dLdBCur = (dLdX * dXdB).sum(0)
    dLdB.append(dLdBCur)

    # Carry the gradient back to the previous layer's output. The activation
    # derivative has to be applied to each sample's own gradient before
    # going through the weights; summing it over the batch first would
    # scale every sample's gradient by a batch-wide count instead.
    dLdX = np.matmul(dLdX * activPrime, curW.T)

  # Gradients were collected output-first; flip them to match w and b
  dLdW.reverse()
  dLdB.reverse()
  return dLdW, dLdB


def updateWeights(dLdW, dLdB):
  """Take one gradient-descent step on the global weights and biases.

  Each entry of w, then each entry of b, is reduced in place by lr times
  the gradient at the same index in dLdW or dLdB.
  """
  for params, grads in ((w, dLdW), (b, dLdB)):
    for idx in range(len(params)):
      params[idx] -= lr * grads[idx]

In [None]:
# Number of forward/backward/update passes to run
epochs = 100

# Training loop
for i in range(epochs):
  # Forward pass on the training inputs; xi holds every layer's values
  xi = forward(input)

  # Visualize output and loss
  # (clear_output replaces the previous epoch's output instead of appending)
  display.clear_output(wait=True)
  plot(forward(test)[-1], expected)
  print("Epoch: ", i + 1)
  # Loss is reported for the parameters as they were before this epoch's update
  print("Loss: ", loss(xi[-1]))
  
  # Compute gradients from the forward pass, then step the parameters
  dLdW, dLdB = backward(xi)
  updateWeights(dLdW, dLdB)