In [1]:
%matplotlib inline


PyTorch: Tensors
----------------

A fully-connected ReLU network with one hidden layer and no biases, trained to
predict y from x by minimizing squared Euclidean distance.

This implementation uses PyTorch tensors to manually compute the forward pass,
loss, and backward pass.

A PyTorch Tensor is basically the same as a numpy array: it does not know
anything about deep learning or computational graphs or gradients, and is just
a generic n-dimensional array to be used for arbitrary numeric computation.

The biggest difference between a numpy array and a PyTorch Tensor is that
a PyTorch Tensor can run on either CPU or GPU. To run operations on the GPU,
just cast the Tensor to a cuda datatype.



In [2]:
import torch


dtype = torch.float
# device = torch.device("cpu")
device = torch.device("cuda:1") # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Randomly initialize weights
w1 = torch.randn(D_in, H, device=device, dtype=dtype)
w2 = torch.randn(H, D_out, device=device, dtype=dtype)

learning_rate = 1e-6
for t in range(500):
    # Forward pass: compute predicted y
    h = x.mm(w1)
    h_relu = h.clamp(min=0)
    y_pred = h_relu.mm(w2)

    # Compute and print loss
    loss = (y_pred - y).pow(2).sum().item()
    print(t, loss)

    # Backprop to compute gradients of w1 and w2 with respect to loss
    grad_y_pred = 2.0 * (y_pred - y)
    grad_w2 = h_relu.t().mm(grad_y_pred)
    grad_h_relu = grad_y_pred.mm(w2.t())
    grad_h = grad_h_relu.clone()
    grad_h[h < 0] = 0
    grad_w1 = x.t().mm(grad_h)

    # Update weights using gradient descent
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2

0 29266960.0
1 25477014.0
2 27260208.0
3 30241428.0
4 30207164.0
5 25019268.0
6 16550626.0
7 9194384.0
8 4705195.0
9 2503613.25
10 1491022.125
11 1013848.625
12 764834.1875
13 615944.9375
14 514465.5625
15 438525.125
16 378407.875
17 329093.71875
18 287852.5
19 252939.140625
20 223106.5625
21 197447.125
22 175266.234375
23 156014.3125
24 139235.984375
25 124555.3046875
26 111688.09375
27 100362.8203125
28 90364.3828125
29 81518.0703125
30 73668.6640625
31 66682.203125
32 60449.52734375
33 54883.21875
34 49919.86328125
35 45471.68359375
36 41476.265625
37 37879.41796875
38 34634.4921875
39 31704.412109375
40 29053.955078125
41 26653.27734375
42 24475.0546875
43 22497.203125
44 20697.529296875
45 19058.849609375
46 17564.28515625
47 16199.8564453125
48 14952.7568359375
49 13811.7900390625
50 12768.75
51 11813.8359375
52 10937.40625
53 10133.0517578125
54 9393.5966796875
55 8713.83984375
56 8088.37890625
57 7512.13037109375
58 6980.56298828125
59 6490.02099609375
60 6037.01953125
61 5618.

383 0.00029685787740163505
384 0.0002885052526835352
385 0.0002808722492773086
386 0.0002728764957282692
387 0.00026524218264967203
388 0.00025755914975889027
389 0.00025115051539614797
390 0.0002441358083160594
391 0.00023690806119702756
392 0.0002304964727954939
393 0.00022458033345174044
394 0.00021919768187217414
395 0.00021322167594917119
396 0.0002077500830637291
397 0.00020248429791536182
398 0.00019730074563995004
399 0.00019228708697482944
400 0.00018758948135655373
401 0.00018313754117116332
402 0.00017861778906080872
403 0.00017444466357119381
404 0.00017021280655171722
405 0.0001659562112763524
406 0.00016227713786065578
407 0.00015833263751119375
408 0.00015444692689925432
409 0.00015071018424350768
410 0.00014737022866029292
411 0.00014419946819543839
412 0.00014061227557249367
413 0.00013734021922573447
414 0.0001345641940133646
415 0.00013141818635631353
416 0.00012852488725911826
417 0.0001258644333574921
418 0.00012298462388571352
419 0.00012033754319418222
420 0.0001