In [1]:
import torch


dtype = torch.FloatTensor
# dtype = torch.cuda.FloatTensor # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data
x = torch.randn(N, D_in).type(dtype)
y = torch.randn(N, D_out).type(dtype)

# Randomly initialize weights
w1 = torch.randn(D_in, H).type(dtype)
w2 = torch.randn(H, D_out).type(dtype)

learning_rate = 1e-6
for t in range(500):
    # Forward pass: compute predicted y
    h = x.mm(w1)
    h_relu = h.clamp(min=0)
    y_pred = h_relu.mm(w2)

    # Compute and print loss
    loss = (y_pred - y).pow(2).sum()
    print(t, loss)

    # Backprop to compute gradients of w1 and w2 with respect to loss
    grad_y_pred = 2.0 * (y_pred - y)
    grad_w2 = h_relu.t().mm(grad_y_pred)
    grad_h_relu = grad_y_pred.mm(w2.t())
    grad_h = grad_h_relu.clone()
    grad_h[h < 0] = 0
    grad_w1 = x.t().mm(grad_h)

    # Update weights using gradient descent
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2

0 35090523.0278067
1 29045796.3823992
2 24696500.58335717
3 19532206.76665774
4 13865405.691609848
5 9010140.522053719
6 5602638.522759974
7 3508884.164679032
8 2296756.097012648
9 1599564.2936631443
10 1184586.0809310973
11 923193.1286269674
12 747199.3278339786
13 620919.9669531495
14 525604.7280968336
15 450679.5152780383
16 390128.00666437706
17 340274.04284718144
18 298519.4457290344
19 263227.3224592699
20 233084.9766429955
21 207147.66269326932
22 184689.81596939627
23 165166.52417192582
24 148104.93311139056
25 133141.2933900836
26 119981.37532983854
27 108355.64432906895
28 98037.45137331185
29 88860.76190719994
30 80668.83764731011
31 73355.82484234268
32 66797.2489586891
33 60903.0263983456
34 55600.63271949979
35 50819.869893130264
36 46499.473022608785
37 42596.865354055306
38 39075.93243163888
39 35883.38306151076
40 32978.34812050893
41 30334.72332485122
42 27925.154629069715
43 25726.81654360704
44 23719.933755257633
45 21884.86315289595
46 20205.09613119097
47 18667.34