### Training a 2-layer neural network and visualizing its basis functions

Author: Teun Mathijssen

In [None]:
%pylab inline

In [None]:
def f(x):
    """Target function the network learns to approximate; currently |x|.

    Other targets to experiment with (swap into the return statement):
        x * np.cos(1.5 * x)
        np.heaviside(x, 0.5)
        np.sin(x)
    """
    target = np.abs(x)
    return target


def h(x):
    """Hidden-layer activation function; currently the hyperbolic tangent.

    Alternatives (keep ``h_p`` in sync with whichever is chosen):
        ReLU:    np.select([x < 0, x >= 0], [0, x])
        Sigmoid: 1 / (1 + np.exp(-x))
    """
    activated = np.tanh(x)
    return activated


def h_p(x):
    """Derivative of the activation ``h``: d/dx tanh(x) = 1 - tanh(x)**2.

    Alternatives matching the other activations:
        ReLU:    np.select([x < 0, x >= 0], [0, 1])
        Sigmoid: (1/(1 + np.exp(-x)))*(1 - 1/(1 + np.exp(-x)))
    """
    tanh_x = np.tanh(x)
    return 1 - tanh_x**2

In [None]:
def NN_propagate(t, x, W_1, w_2, lamb):
    """Propagate a batch through the 2-layer NN and backpropagate the error.

    The hidden layer applies the activation ``h`` (derivative ``h_p``); the
    output layer is linear. A constant 1 is prepended to the input of each
    layer, so column 0 of ``W_1`` and of ``w_2`` acts as the bias.

    Parameters
    ----------
    t : ndarray
        Targets for the batch; broadcast against the (K, B) network output.
    x : ndarray
        Inputs, shape (B,) when D == 1 or (D, B); the last axis is the batch.
    W_1 : ndarray, shape (M, D + 1)
        Input-to-hidden weights (column 0 is the bias).
    w_2 : ndarray, shape (K, M + 1)
        Hidden-to-output weights (column 0 is the bias).
    lamb : float
        Quadratic regularization coefficient; its gradient term is scaled by
        the batch size.

    Returns
    -------
    y : ndarray
        Network output with singleton axes squeezed out.
    dE_dw_2 : ndarray, shape (K, M + 1)
        Gradient of the batch error with respect to ``w_2``.
    dE_dW_1 : ndarray, shape (M, D + 1)
        Gradient of the batch error with respect to ``W_1``.
    z_ : ndarray, shape (M + 1, B)
        Augmented hidden-layer output (row 0 is the constant 1).
    """
    # Forward pass. Augment x with a row of ones for the bias column of W_1.
    x_ = np.vstack((np.ones((1, x.shape[-1])), x))
    a = W_1 @ x_
    z = h(a)
    # Augment z with a row of ones for the bias column of w_2.
    z_ = np.vstack((np.ones((1, z.shape[-1])), z))
    y = w_2 @ z_

    # Backward pass for the sum-of-squares error sum((y - t)**2) / 2.
    dy = y - t

    # Batch size from the inputs' batch axis; t.size would also count the
    # K outputs when K > 1.
    batch_size = x_.shape[-1]

    # Summing the per-sample outer products dy_b z_b^T over the batch is a
    # single matrix product, without materializing a (K, M+1, B) tensor.
    dE_dw_2 = dy @ z_.T + batch_size * lamb * w_2

    # Backpropagate through w_2; row 0 belongs to the constant bias unit,
    # which does not depend on W_1, so it is dropped.
    dz = h_p(a) * (w_2.T @ dy)[1:]
    dE_dW_1 = dz @ x_.T + batch_size * lamb * W_1

    return np.squeeze(y), dE_dw_2, dE_dW_1, z_


def _is_log_spaced_epoch(epoch, base):
    """Return True if ``epoch`` is a multiple of the largest power of ``base`` not exceeding it.

    For base 10 this selects every epoch up to 10, then every 10th up to 100,
    every 100th up to 1000, and so on. Exact integer arithmetic avoids the
    floating-point ``log`` ratio, which can land just below an integer
    (e.g. log(1000) / log(10) < 3). A base below 2 selects every epoch.
    """
    if base < 2:
        return True
    interval = 1
    while interval * base <= epoch:
        interval *= base
    return epoch % interval == 0


def SGD(ts, xs, W_1, w_2, lamb=1e-4, nEpochs=1000, lr=0.003, bs=2, plotBase=10):
    """Train the 2-layer NN with mini-batch SGD and plot its progress.

    Parameters
    ----------
    ts, xs : ndarray
        Targets and inputs; samples run along the last axis.
    W_1, w_2 : ndarray
        Initial weights. They are not modified in place; updated copies
        are returned.
    lamb : float
        Quadratic regularization coefficient passed to ``NN_propagate``.
    nEpochs : int
        Number of passes over the shuffled data.
    lr : float
        Learning rate.
    bs : int
        Mini-batch size; must divide the number of samples.
    plotBase : int
        Base of the logarithmic plotting schedule. The fit is plotted at
        epoch 0, at log-spaced epochs and at the final epoch. The error
        curve is plotted at the end.

    Returns
    -------
    tuple of ndarray or None
        The trained ``(W_1, w_2)``, or None if ``bs`` does not divide the
        number of samples.
    """
    n_samples = xs.shape[-1]
    if n_samples % bs != 0:
        print("Error: batch size must divide input size.")
        return

    epochs_plot = []
    errors_plot = []

    def record(epoch):
        # Evaluate the current weights on the full data set, store the error
        # and plot the fit. Reads the current bindings of W_1 and w_2.
        ys, _, _, z_ = NN_propagate(ts, xs, W_1, w_2, lamb)
        errors_plot.append(np.sum((ys - ts)**2) / 2)
        epochs_plot.append(epoch)
        plot_data(epoch, xs, ys, ts, z_, w_2)

    record(0)

    for epoch in range(1, nEpochs + 1):
        # Shuffle inputs and targets together along the sample axis.
        perm = np.random.permutation(n_samples)
        xs_shuffled = xs[..., perm]
        ts_shuffled = ts[..., perm]

        for start in range(0, n_samples, bs):
            batch = slice(start, start + bs)
            _, dE_dw_2, dE_dW_1, _ = NN_propagate(
                ts_shuffled[..., batch], xs_shuffled[..., batch], W_1, w_2, lamb)
            W_1 = W_1 - lr * dE_dW_1
            w_2 = w_2 - lr * dE_dw_2

        if epoch == nEpochs or _is_log_spaced_epoch(epoch, plotBase):
            record(epoch)

    plot_errors(epochs_plot, errors_plot)
    return W_1, w_2

In [None]:
def plot_data(curEpoch, xs, ys, ts, z_, w_2):
    """Plot the network fit and the weighted hidden-unit terms for one epoch.

    Left: targets (light gray) and network output (black).
    Right: each term ``w_2[:, i] * z_[i]`` of the output sum. i = 0 is the
    bias; the rest are the weighted hidden units. Their sum is the network
    output.
    """
    plt.figure(figsize=(14, 3))

    plt.subplot(1, 2, 1)
    plt.plot(xs, ts, color="lightgray", marker="o", markersize=4, linestyle="")
    plt.plot(xs, ys, color="black", marker="o", markersize=4, linewidth=1)
    plt.ylim((-1.2, 1.2))
    plt.title("Output epoch " + str(curEpoch))

    plt.subplot(1, 2, 2)
    for i, z_i in enumerate(z_):
        # Cycle the dash gap through 5, 6, 7 so overlapping curves stay
        # distinguishable. Deterministic on purpose: drawing from the global
        # NumPy RNG here would perturb the training loop's shuffling.
        plt.plot(xs, w_2[:, i] * z_i, linestyle=(0, (2, 5 + i % 3)))
    plt.title("Hidden layer output")

    plt.subplots_adjust(wspace=0.12, hspace=0, top=1, bottom=0)
    plt.show()
    

def plot_errors(epochs_plot, errors_plot):
    """Draw the recorded training error E against epoch, both axes logarithmic."""
    _, ax = plt.subplots(figsize=(14, 3))
    ax.plot(epochs_plot, errors_plot, color="red", marker="x", markersize=6, linewidth=1)
    ax.set_xlabel("epoch")
    ax.set_ylabel("E")
    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.grid()
    ax.set_title("Error vs. epoch")
    plt.show()

In [None]:
# Number of data points.
N = 50

# Layer sizes: D inputs, M hidden units, K outputs.
D = 1
M = 3
K = 1

# Input-to-hidden weights; the extra column is the bias (inputs get a leading 1).
W_1 = np.random.normal(0, 0.3, (M, D + 1))

# Hidden-to-output weights; the extra column is the bias.
w_2 = np.random.normal(0, 0.3, (K, M + 1))

# N evenly spaced inputs on [-pi, pi] and their targets.
# np.pi rather than math.pi: `math` is never imported in this notebook, so it
# only resolves if `%pylab`'s star imports happen to provide it (NumPy 2.0
# removed the np.math alias).
xs = np.linspace(-np.pi, np.pi, N)
ts = f(xs)

### Running instructions

Run the cell below with `Shift + Enter` and keep its output expanded; each plot appears as soon as it is generated.

In [None]:
SGD(
    ts, xs, W_1, w_2,
    lamb=0,         # no regularization
    nEpochs=10000,
    lr=0.003,
    bs=5,           # 50 data points -> 10 mini-batches per epoch
    plotBase=10,    # log-spaced plotting schedule, base 10
)