In [1]:
# Render matplotlib figures inline in the notebook output.
# NOTE(review): no plotting is done in the cells shown here.
%matplotlib inline


PyTorch: Tensors
----------------

A fully-connected ReLU network with one hidden layer and no biases, trained to
predict y from x by minimizing squared Euclidean distance.

This implementation uses PyTorch tensors to manually compute the forward pass,
loss, and backward pass.

A PyTorch Tensor is conceptually the same as a numpy array: a generic
n-dimensional array to be used for arbitrary numeric computation. Tensors can
also track computational graphs and gradients (autograd), but this example does
not use that: here they are plain numeric arrays and every gradient is computed
by hand.

The biggest difference between a numpy array and a PyTorch Tensor is that
a PyTorch Tensor can run on either CPU or GPU. To run operations on the GPU,
create the Tensor on a CUDA device (or move it there), for example by passing
``device=torch.device("cuda:0")`` when constructing it.



In [2]:
import torch


# Train a two-layer, bias-free ReLU network on random data, computing the
# forward pass, the loss, and the backward pass by hand on plain Tensors.

# Seed PyTorch's random number generator so that re-running the notebook on
# the same device starts from the same data and weights.
SEED = 0
torch.manual_seed(SEED)

dtype = torch.float
# Use the first GPU when CUDA is available; otherwise fall back to the CPU so
# the cell also runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data
x = torch.randn(N, D_in, device=device, dtype=dtype)   # (N, D_in)
y = torch.randn(N, D_out, device=device, dtype=dtype)  # (N, D_out)

# Randomly initialize weights
w1 = torch.randn(D_in, H, device=device, dtype=dtype)   # (D_in, H)
w2 = torch.randn(H, D_out, device=device, dtype=dtype)  # (H, D_out)

learning_rate = 1e-6
for t in range(500):
    # Forward pass: compute predicted y
    h = x.mm(w1)              # hidden pre-activation, (N, H)
    h_relu = h.clamp(min=0)   # ReLU: negative entries become 0
    y_pred = h_relu.mm(w2)    # (N, D_out)

    # Compute and print loss (sum of squared errors); .item() converts the
    # 0-dim Tensor to a Python float.
    loss = (y_pred - y).pow(2).sum().item()
    print(t, loss)

    # Backprop to compute gradients of the loss with respect to w1 and w2
    grad_y_pred = 2.0 * (y_pred - y)        # d(loss)/d(y_pred)
    grad_w2 = h_relu.t().mm(grad_y_pred)    # (H, D_out), same shape as w2
    grad_h_relu = grad_y_pred.mm(w2.t())    # (N, H)
    # Clone so that zeroing entries below does not modify grad_h_relu.
    grad_h = grad_h_relu.clone()
    # ReLU passes no gradient where its input was negative.
    grad_h[h < 0] = 0
    grad_w1 = x.t().mm(grad_h)              # (D_in, H), same shape as w1

    # Update weights in place using gradient descent
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2

0 32191600.0
1 30232470.0
2 32035356.0
3 32276274.0
4 27730908.0
5 19226408.0
6 11083586.0
7 5797139.0
8 3110949.0
9 1853443.25
10 1253224.375
11 936506.6875
12 746138.4375
13 616381.0625
14 519625.96875
15 443535.21875
16 381860.40625
17 330828.84375
18 287979.3125
19 251671.828125
20 220748.78125
21 194219.390625
22 171368.0625
23 151586.015625
24 134422.25
25 119478.453125
26 106431.3828125
27 94992.5859375
28 84936.9921875
29 76080.171875
30 68269.75
31 61371.59765625
32 55258.23828125
33 49823.29296875
34 44982.41015625
35 40665.76171875
36 36814.1171875
37 33371.6640625
38 30286.734375
39 27515.091796875
40 25026.23046875
41 22788.890625
42 20770.666015625
43 18948.271484375
44 17301.5078125
45 15811.5986328125
46 14462.28125
47 13238.865234375
48 12129.0341796875
49 11121.1162109375
50 10204.5400390625
51 9369.9296875
52 8609.4794921875
53 7915.71826171875
54 7282.58056640625
55 6704.294921875
56 6176.32861328125
57 5693.2236328125
58 5251.64697265625
59 4846.912109375
60 4475.8

494 2.9781964258290827e-05
495 2.9460790756274946e-05
496 2.9189997803769074e-05
497 2.8827062124037184e-05
498 2.8519962143036537e-05
499 2.825907176884357e-05
