In [1]:
%matplotlib inline


PyTorch: Tensors
----------------

A fully-connected ReLU network with one hidden layer and no biases, trained to
predict y from x by minimizing squared Euclidean distance.

This implementation uses PyTorch tensors to manually compute the forward pass,
loss, and backward pass.

A PyTorch Tensor is basically the same as a numpy array: it does not know
anything about deep learning or computational graphs or gradients, and is just
a generic n-dimensional array to be used for arbitrary numeric computation.

The biggest difference between a numpy array and a PyTorch Tensor is that
a PyTorch Tensor can run on either CPU or GPU. To run operations on the GPU,
create the Tensor on a CUDA device by passing a CUDA ``torch.device`` (for
example ``torch.device("cuda:0")``) as the ``device`` argument.



In [2]:
import torch


# Two-layer fully-connected ReLU network (no biases) trained with hand-written
# forward and backward passes on plain tensors; autograd is not used.

dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0") # Uncomment this to run on GPU

# Seed PyTorch's random number generator before any random draw so that the
# random data, the initial weights, and therefore the printed loss values are
# repeatable from a fresh kernel (on the same PyTorch build and device).
SEED = 0
torch.manual_seed(SEED)

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Randomly initialize weights
w1 = torch.randn(D_in, H, device=device, dtype=dtype)
w2 = torch.randn(H, D_out, device=device, dtype=dtype)

learning_rate = 1e-6
for t in range(500):
    # Forward pass: compute predicted y
    h = x.mm(w1)                # hidden pre-activations, shape (N, H)
    h_relu = h.clamp(min=0)     # ReLU: negative entries become 0
    y_pred = h_relu.mm(w2)      # predictions, shape (N, D_out)

    # Compute and print loss (sum of squared errors, converted to a Python float)
    loss = (y_pred - y).pow(2).sum().item()
    print(t, loss)

    # Backprop to compute gradients of the loss with respect to w1 and w2
    grad_y_pred = 2.0 * (y_pred - y)        # d(loss)/d(y_pred)
    grad_w2 = h_relu.t().mm(grad_y_pred)    # d(loss)/d(w2), shape (H, D_out)
    grad_h_relu = grad_y_pred.mm(w2.t())    # d(loss)/d(h_relu), shape (N, H)
    # ReLU passes no gradient where its input was negative; work on a copy so
    # grad_h_relu itself is left untouched.
    grad_h = grad_h_relu.clone()
    grad_h[h < 0] = 0
    grad_w1 = x.t().mm(grad_h)              # d(loss)/d(w1), shape (D_in, H)

    # Update weights in place using gradient descent
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2

0 37597084.0
1 36396412.0
2 38321832.0
3 35813600.0
4 27044666.0
5 15965046.0
6 8112039.5
7 4073650.5
8 2305108.0
9 1519853.625
10 1130337.0
11 903473.625
12 750821.0
13 637223.1875
14 547651.9375
15 474633.78125
16 414022.625
17 363067.75
18 319932.125
19 283089.0625
20 251429.25
21 224051.421875
22 200284.140625
23 179616.46875
24 161550.125
25 145687.21875
26 131698.4375
27 119329.6875
28 108340.0625
29 98568.6640625
30 89846.6953125
31 82038.1484375
32 75030.8125
33 68727.25
34 63046.0078125
35 57918.828125
36 53271.0546875
37 49058.03515625
38 45230.734375
39 41748.77734375
40 38578.5078125
41 35685.1015625
42 33040.01171875
43 30620.6796875
44 28404.685546875
45 26370.7265625
46 24502.478515625
47 22784.626953125
48 21203.703125
49 19746.087890625
50 18402.25
51 17161.53125
52 16014.94921875
53 14954.2666015625
54 13973.1083984375
55 13064.4462890625
56 12221.1123046875
57 11438.6298828125
58 10712.328125
59 10037.0205078125
60 9408.7763671875
61 8824.2646484375
62 8279.5859375
... (output truncated: iterations 63-493 omitted) ...

494 0.00016739772399887443
495 0.00016407747170887887
496 0.00016079966735560447
497 0.00015878243721090257
498 0.0001553332549519837
499 0.00015263931709341705
