In [1]:
import numpy as np

%matplotlib inline


Warm-up: numpy
--------------

A fully-connected ReLU network with one hidden layer and no biases, trained to
predict y from x using Euclidean error.

This implementation uses numpy to manually compute the forward pass, loss, and
backward pass.

A numpy array is a generic n-dimensional array; it does not know anything about
deep learning or gradients or computational graphs, and is just a way to perform
generic numeric computations.



In [2]:
# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N = 64
D_in = 1000
H = 100
D_out = 10

x = np.random.randn(N, D_in)
y = np.random.randn(N, D_out)

w1 = np.random.randn(D_in, H)
w2 = np.random.randn(H, D_out)

learning_rate = 1e-6
for t in range(500):
    # Forward pass: compute predicted y
    h = x.dot(w1)
    h_relu = np.maximum(h, 0)
    y_pred = h_relu.dot(w2)

    # Compute and print loss
    loss = np.square(y_pred - y).sum()
    print(t, loss)

    # Backprop to compute gradients of w1 and w2 with respect to loss
    grad_y_pred = 2.0 * (y_pred - y)
    grad_w2 = h_relu.T.dot(grad_y_pred)
    grad_h_relu = grad_y_pred.dot(w2.T)
    grad_h = grad_h_relu.copy()
    grad_h[h < 0] = 0
    grad_w1 = x.T.dot(grad_h)

    # Update weights
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2

0 29343510.969464988
1 26453087.159270234
2 26865641.795774058
3 26587583.07522152
4 23369553.836728744
5 17417229.154374555
6 11136892.622852199
7 6449322.370463408
8 3662366.3285348127
9 2183778.1742578046
10 1419681.3930133197
11 1010568.8953483881
12 773624.4216397287
13 622530.5959467802
14 516744.9847994833
15 437179.1469930584
16 374308.9150606445
17 323038.0436613461
18 280413.35989143944
19 244505.23690073355
20 213998.40318308977
21 187919.3598352012
22 165570.99855205472
23 146291.96216686617
24 129624.34209235801
25 115142.52093763268
26 102521.21627941245
27 91509.24334223256
28 81850.31242417428
29 73363.04494378745
30 65875.77053747339
31 59250.95987514621
32 53374.9188203175
33 48156.40677321945
34 43516.22889761285
35 39379.70513766825
36 35685.2101951626
37 32375.865870080317
38 29408.71910181395
39 26743.568581102052
40 24346.180860756453
41 22189.6828284217
42 20244.706183758823
43 18487.183033978414
44 16898.573063838045
45 15460.957604520017
46 14157.274000446334
