In [1]:
%matplotlib inline


PyTorch: Tensors
----------------

A fully-connected ReLU network with one hidden layer and no biases, trained to
predict y from x by minimizing squared Euclidean distance.

This implementation uses PyTorch tensors to manually compute the forward pass,
loss, and backward pass.

A PyTorch Tensor is basically the same as a numpy array: it does not know
anything about deep learning or computational graphs or gradients, and is just
a generic n-dimensional array to be used for arbitrary numeric computation.

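The biggest difference between a numpy array and a PyTorch Tensor is that
a PyTorch Tensor can run on either CPU or GPU. To run operations on the GPU,
create the Tensor on a CUDA device by passing the `device` argument (as the
code below does), or move an existing Tensor there with `.to(device)`.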



In [2]:
import torch


# Two-layer fully-connected ReLU network (no biases) trained with hand-written
# forward and backward passes on plain tensors; autograd is not used.

dtype = torch.float
# Use the first CUDA device when one is available; otherwise fall back to the
# CPU so this cell still runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed the random number generator so that re-running the notebook on the same
# device starts from the same data and weights.
torch.manual_seed(0)

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Randomly initialize weights
w1 = torch.randn(D_in, H, device=device, dtype=dtype)
w2 = torch.randn(H, D_out, device=device, dtype=dtype)

learning_rate = 1e-6
for t in range(500):
    # Forward pass: compute predicted y
    h = x.mm(w1)
    h_relu = h.clamp(min=0)
    y_pred = h_relu.mm(w2)

    # The prediction error is needed by both the loss and its gradient, so
    # compute it once per iteration.
    residual = y_pred - y

    # Compute and print loss (sum of squared errors, as a Python float)
    loss = residual.pow(2).sum().item()
    print(t, loss)

    # Backprop to compute gradients of the loss with respect to w1 and w2
    grad_y_pred = 2.0 * residual
    grad_w2 = h_relu.t().mm(grad_y_pred)
    grad_h_relu = grad_y_pred.mm(w2.t())
    # ReLU passes gradient only where its input was non-negative; clone first
    # so grad_h_relu itself is left untouched.
    grad_h = grad_h_relu.clone()
    grad_h[h < 0] = 0
    grad_w1 = x.t().mm(grad_h)

    # Update weights in place using gradient descent
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2

0 31738264.0
1 30105388.0
2 35463872.0
3 41497176.0
4 40875664.0
5 29998896.0
6 16078730.0
7 7013804.5
8 3112514.0
9 1682273.375
10 1133358.125
11 876578.125
12 723564.875
13 614592.1875
14 529287.625
15 459575.25
16 401391.40625
17 352274.8125
18 310460.875
19 274652.625
20 243859.578125
21 217208.625
22 194055.46875
23 173863.703125
24 156179.71875
25 140648.875
26 126961.3203125
27 114862.890625
28 104135.7265625
29 94607.5078125
30 86096.6640625
31 78484.921875
32 71660.984375
33 65545.734375
34 60055.83203125
35 55100.7109375
36 50625.30078125
37 46575.828125
38 42899.453125
39 39557.7109375
40 36516.3984375
41 33745.265625
42 31217.908203125
43 28915.62890625
44 26806.78515625
45 24871.3359375
46 23094.0234375
47 21459.619140625
48 19956.6015625
49 18574.935546875
50 17299.205078125
51 16121.994140625
52 15033.3671875
53 14026.140625
54 13092.8037109375
55 12227.880859375
56 11425.6328125
57 10681.171875
58 9989.642578125
59 9346.8232421875
60 8748.783203125
61 8192.0986328125
62

430 0.00014780036872252822
431 0.0001450083072995767
432 0.00014158798148855567
433 0.00013839598977938294
434 0.00013615578063763678
435 0.0001333008403889835
436 0.00013010657858103514
437 0.0001278881391044706
438 0.00012472557136788964
439 0.0001223755971295759
440 0.00011991038627456874
441 0.00011754576553357765
442 0.00011532929784152657
443 0.00011320968042127788
444 0.00011128491314593703
445 0.00010905153612839058
446 0.00010706766624934971
447 0.00010456096788402647
448 0.00010271654173266143
449 0.00010082474182127044
450 9.896120900521055e-05
451 9.743105329107493e-05
452 9.535052231512964e-05
453 9.362763375975192e-05
454 9.207546099787578e-05
455 9.067733481060714e-05
456 8.903218258637935e-05
457 8.724925282876939e-05
458 8.598413114668801e-05
459 8.47732080728747e-05
460 8.335051825270057e-05
461 8.177434210665524e-05
462 8.005530980881304e-05
463 7.890181586844847e-05
464 7.763069879729301e-05
465 7.635540532646701e-05
466 7.491218275390565e-05
467 7.394140993710607e-