In [1]:
import numpy as np

%matplotlib inline


Warm-up: numpy
--------------

A fully-connected ReLU network with one hidden layer and no biases, trained to
predict y from x by minimizing the squared Euclidean error (the sum of squared
differences between the prediction and y).

This implementation uses numpy to manually compute the forward pass, loss, and
backward pass.

A numpy array is a generic n-dimensional array; it does not know anything about
deep learning or gradients or computational graphs, and is just a way to perform
generic numeric computations.



In [3]:
# Train a two-layer ReLU network (no biases) on random data with plain
# gradient descent; the forward pass, loss and gradients are all hand-written.

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N = 64
D_in = 1000
H = 100
D_out = 10

# Draw every random array from a dedicated, seeded generator so that a fresh
# "Restart & Run All" reproduces the same data, initial weights and loss
# curve. A private RandomState is used so numpy's global random state is
# left untouched for any other cell.
SEED = 0
rng = np.random.RandomState(SEED)

# Random input and target data: x is (N, D_in), y is (N, D_out).
x = rng.randn(N, D_in)
y = rng.randn(N, D_out)

# Randomly initialised weights: w1 is (D_in, H), w2 is (H, D_out).
w1 = rng.randn(D_in, H)
w2 = rng.randn(H, D_out)

learning_rate = 1e-6
for t in range(500):
    # Forward pass: hidden pre-activation, ReLU, then linear output layer.
    h = np.dot(x, w1)
    # Compute the ReLU mask once; the backward pass reuses it below.
    relu_mask = h > 0
    h_relu = np.where(relu_mask, h, 0)
    y_pred = np.dot(h_relu, w2)

    # Loss is the sum of squared differences over the whole batch.
    loss = np.sum(np.square(y_pred - y))
    print(t, loss)

    # Backward pass: apply the chain rule from the loss back to each weight.
    grad_loss_y = 2.0 * (y_pred - y)
    grad_loss_w2 = np.dot(np.transpose(h_relu), grad_loss_y)
    # The gradient only flows through units whose pre-activation was positive.
    grad_loss_h = np.where(relu_mask, np.dot(grad_loss_y, np.transpose(w2)), 0)
    grad_loss_w1 = np.dot(np.transpose(x), grad_loss_h)

    # Gradient-descent step; the weights are updated in place.
    w1 -= learning_rate * grad_loss_w1
    w2 -= learning_rate * grad_loss_w2

0 25986978.236384504
1 19881125.675122853
2 16767597.530874074
3 14247396.701699208
4 11736421.343079653
5 9107428.283333126
6 6746153.545564149
7 4782602.011727516
8 3335868.413700737
9 2318628.702504961
10 1635533.5405566741
11 1178922.5305528657
12 874772.2469936967
13 668090.1505974202
14 524625.3120818299
15 422143.56370441953
16 346780.2955678095
17 289731.7959497693
18 245354.68823523683
19 210040.15010709787
20 181382.17866963515
21 157748.73158435617
22 138025.16897972894
23 121363.95229003113
24 107164.71523086357
25 94984.96239524276
26 84495.2278547584
27 75383.59771416945
28 67430.1614512668
29 60467.83921376068
30 54345.066693944165
31 48940.251447393995
32 44155.837295539706
33 39909.41989584151
34 36132.5897828081
35 32764.501493465763
36 29759.548481595237
37 27074.795507656658
38 24666.267204822747
39 22501.28766729821
40 20552.30632173452
41 18795.228120327403
42 17207.672198682078
43 15770.692619004498
44 14468.551521712934
45 13286.01003355156
46 12212.372414415393