In [1]:
%matplotlib inline


PyTorch: Tensors
----------------

A fully-connected ReLU network with one hidden layer and no biases, trained to
predict y from x by minimizing squared Euclidean distance.

This implementation uses PyTorch tensors to manually compute the forward pass,
loss, and backward pass.

A PyTorch Tensor is basically the same as a numpy array: it does not know
anything about deep learning or computational graphs or gradients, and is just
a generic n-dimensional array to be used for arbitrary numeric computation.

The biggest difference between a numpy array and a PyTorch Tensor is that
a PyTorch Tensor can run on either CPU or GPU. To run operations on the GPU,
just cast the Tensor to a cuda datatype.



In [4]:
import torch


dtype = torch.float
#device = torch.device("cpu")
device = torch.device("cuda:0") # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Randomly initializex weights
w1 = torch.randn(D_in, H, device=device, dtype=dtype)
w2 = torch.randn(H, D_out, device=device, dtype=dtype)

learning_rate = 1e-6

for t in range(500):
    # Forward pass: compute predicted y
    h = x.mm(w1)
    h_relu = h.clamp(min=0)
    y_pred = h_relu.mm(w2)

    # Compute and print loss
    loss = (y_pred - y).pow(2).sum().item()
    print(t, loss)

    # Backprop to compute gradients of w1 and w2 with respect to loss
    grad_y_pred = 2.0 * (y_pred - y)
    grad_w2 = h_relu.t().mm(grad_y_pred)
    grad_h_relu = grad_y_pred.mm(w2.t())
    grad_h = grad_h_relu.clone()
    grad_h[h < 0] = 0
    grad_w1 = x.t().mm(grad_h)

    # Update weights using gradient descent
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2

0 37922340.0
1 37764680.0
2 40290936.0
3 37524912.0
4 27529336.0
5 15551603.0
6 7563557.0
7 3748841.0
8 2162891.75
9 1474155.5
10 1126222.125
11 914547.25
12 765817.25
13 652067.375
14 560799.875
15 485737.4375
16 423121.84375
17 370342.53125
18 325540.03125
19 287313.65625
20 254495.09375
21 226171.5625
22 201624.96875
23 180232.65625
24 161531.6875
25 145097.5
26 130624.3984375
27 117835.765625
28 106504.578125
29 96427.6328125
30 87446.25
31 79422.84375
32 72244.6875
33 65812.1796875
34 60027.03125
35 54815.4296875
36 50119.85546875
37 45878.2734375
38 42041.3203125
39 38564.57421875
40 35427.8359375
41 32590.775390625
42 30011.955078125
43 27660.66796875
44 25513.615234375
45 23553.919921875
46 21761.427734375
47 20119.728515625
48 18614.140625
49 17233.0390625
50 15964.802734375
51 14800.240234375
52 13727.68359375
53 12739.7392578125
54 11828.814453125
55 10988.759765625
56 10213.078125
57 9497.04296875
58 8835.4912109375
59 8224.15625
60 7658.0234375
61 7133.5205078125
62 6647.7