In [1]:
%matplotlib inline


PyTorch: nn
-----------

A fully-connected ReLU network with one hidden layer, trained to predict y from x
by minimizing squared Euclidean distance.

This implementation uses the nn package from PyTorch to build the network.
PyTorch autograd makes it easy to define computational graphs and take gradients,
but raw autograd can be a bit too low-level for defining complex neural networks;
this is where the nn package can help. The nn package defines a set of Modules,
which you can think of as a neural network layer that has produces output from
input and may have some trainable weights.



In [2]:
import torch

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Use the nn package to define our model as a sequence of layers. nn.Sequential
# is a Module which contains other Modules, and applies them in sequence to
# produce its output. Each Linear Module computes output from input using a
# linear function, and holds internal Tensors for its weight and bias.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)

# The nn package also contains definitions of popular loss functions; in this
# case we will use Mean Squared Error (MSE) as our loss function.
loss_fn = torch.nn.MSELoss(reduction='sum')

learning_rate = 1e-4
for t in range(500):
    # Forward pass: compute predicted y by passing x to the model. Module objects
    # override the __call__ operator so you can call them like functions. When
    # doing so you pass a Tensor of input data to the Module and it produces
    # a Tensor of output data.
    y_pred = model(x)

    # Compute and print loss. We pass Tensors containing the predicted and true
    # values of y, and the loss function returns a Tensor containing the
    # loss.
    loss = loss_fn(y_pred, y)
    print(t, loss.item())

    # Zero the gradients before running the backward pass.
    model.zero_grad()

    # Backward pass: compute gradient of the loss with respect to all the learnable
    # parameters of the model. Internally, the parameters of each Module are stored
    # in Tensors with requires_grad=True, so this call will compute gradients for
    # all learnable parameters in the model.
    loss.backward()

    # Update the weights using gradient descent. Each parameter is a Tensor, so
    # we can access its gradients like we did before.
    with torch.no_grad():
        for param in model.parameters():
            param -= learning_rate * param.grad

0 685.8623657226562
1 634.03369140625
2 589.7142944335938
3 551.1286010742188
4 517.04150390625
5 486.3191223144531
6 458.5585632324219
7 433.0906677246094
8 409.5806884765625
9 387.8254699707031
10 367.6499328613281
11 348.76806640625
12 331.0064697265625
13 314.1520690917969
14 298.1526184082031
15 283.01666259765625
16 268.58380126953125
17 254.8072509765625
18 241.6618194580078
19 229.14276123046875
20 217.12139892578125
21 205.60546875
22 194.60401916503906
23 184.09361267089844
24 174.05401611328125
25 164.4669647216797
26 155.3260955810547
27 146.62376403808594
28 138.3309326171875
29 130.44900512695312
30 122.9742431640625
31 115.88301849365234
32 109.15935516357422
33 102.78791809082031
34 96.75810241699219
35 91.06149291992188
36 85.68048858642578
37 80.62894439697266
38 75.85957336425781
39 71.36870574951172
40 67.14002990722656
41 63.160987854003906
42 59.418792724609375
43 55.9035758972168
44 52.60743713378906
45 49.52303695678711
46 46.633785247802734
47 43.92216491699219

429 0.0002673197886906564
430 0.0002614241966512054
431 0.0002556674007792026
432 0.0002500360133126378
433 0.00024452892830595374
434 0.0002391437883488834
435 0.00023387608234770596
436 0.00022873243142385036
437 0.00022369532962329686
438 0.0002187759819207713
439 0.00021395696967374533
440 0.0002092450304189697
441 0.00020464372937567532
442 0.00020014327310491353
443 0.00019573845202103257
444 0.00019143553799949586
445 0.0001872278080554679
446 0.00018311179883312434
447 0.00017908288282342255
448 0.00017515482613816857
449 0.00017130363266915083
450 0.000167535908985883
451 0.00016385794151574373
452 0.00016025349032133818
453 0.00015673530288040638
454 0.0001532882743049413
455 0.0001499232603237033
456 0.00014663349429611117
457 0.0001434114237781614
458 0.0001402614580001682
459 0.0001371874095639214
460 0.00013417701120488346
461 0.00013122592645231634
462 0.00012835148663725704
463 0.00012553589476738125
464 0.0001227759348694235
465 0.00012008831981802359
466 0.00011745431