In [1]:
%matplotlib inline


PyTorch: nn
-----------

A fully-connected ReLU network with one hidden layer, trained to predict y from x
by minimizing squared Euclidean distance.

This implementation uses the nn package from PyTorch to build the network.
PyTorch autograd makes it easy to define computational graphs and take gradients,
but raw autograd can be a bit too low-level for defining complex neural networks;
this is where the nn package can help. The nn package defines a set of Modules,
which you can think of as a neural network layer that has produces output from
input and may have some trainable weights.



In [2]:
import torch

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.

N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Use the nn package to define our model as a sequence of layers. nn.Sequential
# is a Module which contains other Modules, and applies them in sequence to
# produce its output. Each Linear Module computes output from input using a
# linear function, and holds internal Tensors for its weight and bias.

model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)

# The nn package also contains definitions of popular loss functions; in this
# case we will use Mean Squared Error (MSE) as our loss function.

loss_fn = torch.nn.MSELoss(reduction='sum')

learning_rate = 1e-4

for t in range(500):
    
    # Forward pass: compute predicted y by passing x to the model. Module objects
    # override the __call__ operator so you can call them like functions. When
    # doing so you pass a Tensor of input data to the Module and it produces
    # a Tensor of output data.
    
    y_pred = model(x)

    # Compute and print loss. We pass Tensors containing the predicted and true
    # values of y, and the loss function returns a Tensor containing the
    # loss.
    
    loss = loss_fn(y_pred, y)
    
    print(t, loss.item())

    # Zero the gradients before running the backward pass.
    
    model.zero_grad()

    # Backward pass: compute gradient of the loss with respect to all the learnable
    # parameters of the model. Internally, the parameters of each Module are stored
    # in Tensors with requires_grad=True, so this call will compute gradients for
    # all learnable parameters in the model.
    
    loss.backward()

    # Update the weights using gradient descent. Each parameter is a Tensor, so
    # we can access its gradients like we did before.
    
    with torch.no_grad():
        for param in model.parameters():
            param -= learning_rate * param.grad

0 702.1624755859375
1 644.8729858398438
2 596.2704467773438
3 554.158203125
4 517.5407104492188
5 485.1929931640625
6 456.0840759277344
7 429.58087158203125
8 405.23162841796875
9 382.74615478515625
10 361.8940124511719
11 342.3479919433594
12 324.0294189453125
13 306.8989562988281
14 290.67108154296875
15 275.32373046875
16 260.77618408203125
17 246.9761199951172
18 233.8465576171875
19 221.32080078125
20 209.43507385253906
21 198.10658264160156
22 187.3758087158203
23 177.1374053955078
24 167.39028930664062
25 158.1318817138672
26 149.34852600097656
27 141.0047607421875
28 133.0923614501953
29 125.61160278320312
30 118.52947235107422
31 111.81786346435547
32 105.4681167602539
33 99.46053314208984
34 93.78005981445312
35 88.42120361328125
36 83.35578918457031
37 78.56588745117188
38 74.04135131835938
39 69.77720642089844
40 65.75477600097656
41 61.957584381103516
42 58.38185119628906
43 55.013404846191406
44 51.84122085571289
45 48.850852966308594
46 46.03667449951172
47 43.3819999694

362 9.086611680686474e-05
363 8.789556159172207e-05
364 8.50231881486252e-05
365 8.224906923715025e-05
366 7.955894398037344e-05
367 7.69707839936018e-05
368 7.446299423463643e-05
369 7.203932909760624e-05
370 6.968862726353109e-05
371 6.741877587046474e-05
372 6.522342300741002e-05
373 6.310307799139991e-05
374 6.105418287916109e-05
375 5.9069934650324285e-05
376 5.7153454690705985e-05
377 5.529707777895965e-05
378 5.3503765229834244e-05
379 5.1766975957434624e-05
380 5.00904097862076e-05
381 4.8469250032212585e-05
382 4.689871275331825e-05
383 4.537718996289186e-05
384 4.3909818487009034e-05
385 4.24871759605594e-05
386 4.1115028579952195e-05
387 3.978548920713365e-05
388 3.850062421406619e-05
389 3.725615533767268e-05
390 3.605216261348687e-05
391 3.4887292713392526e-05
392 3.376423046574928e-05
393 3.267428473918699e-05
394 3.161907443427481e-05
395 3.060369999730028e-05
396 2.9615610401378945e-05
397 2.866145587177016e-05
398 2.774038875941187e-05
399 2.684600076463539e-05
400 2.5