In [1]:
%matplotlib inline


PyTorch: nn
-----------

A fully-connected ReLU network with one hidden layer, trained to predict y from x
by minimizing squared Euclidean distance.

This implementation uses the nn package from PyTorch to build the network.
PyTorch autograd makes it easy to define computational graphs and take gradients,
but raw autograd can be a bit too low-level for defining complex neural networks;
this is where the nn package can help. The nn package defines a set of Modules,
which you can think of as a neural network layer that has produces output from
input and may have some trainable weights.



In [2]:
import torch

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Use the nn package to define our model as a sequence of layers. nn.Sequential
# is a Module which contains other Modules, and applies them in sequence to
# produce its output. Each Linear Module computes output from input using a
# linear function, and holds internal Tensors for its weight and bias.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)

# The nn package also contains definitions of popular loss functions; in this
# case we will use Mean Squared Error (MSE) as our loss function.
loss_fn = torch.nn.MSELoss(reduction='sum')

learning_rate = 1e-4
for t in range(500):
    # Forward pass: compute predicted y by passing x to the model. Module objects
    # override the __call__ operator so you can call them like functions. When
    # doing so you pass a Tensor of input data to the Module and it produces
    # a Tensor of output data.
    y_pred = model(x)

    # Compute and print loss. We pass Tensors containing the predicted and true
    # values of y, and the loss function returns a Tensor containing the
    # loss.
    loss = loss_fn(y_pred, y)
    print(t, loss.item())

    # Zero the gradients before running the backward pass.
    model.zero_grad()

    # Backward pass: compute gradient of the loss with respect to all the learnable
    # parameters of the model. Internally, the parameters of each Module are stored
    # in Tensors with requires_grad=True, so this call will compute gradients for
    # all learnable parameters in the model.
    loss.backward()

    # Update the weights using gradient descent. Each parameter is a Tensor, so
    # we can access its gradients like we did before.
    with torch.no_grad():
        for param in model.parameters():
            param -= learning_rate * param.grad

0 673.6611328125
1 620.7244873046875
2 575.39306640625
3 535.7691650390625
4 500.5320129394531
5 468.818359375
6 440.01556396484375
7 413.6776123046875
8 389.4150390625
9 366.7279357910156
10 345.65960693359375
11 325.8575439453125
12 307.1839904785156
13 289.67047119140625
14 273.0721130371094
15 257.2470703125
16 242.19493103027344
17 227.973388671875
18 214.5721893310547
19 201.8910675048828
20 189.86151123046875
21 178.47463989257812
22 167.697509765625
23 157.4932098388672
24 147.8601531982422
25 138.77113342285156
26 130.20680236816406
27 122.16253662109375
28 114.59638977050781
29 107.49974060058594
30 100.84286499023438
31 94.61349487304688
32 88.7770767211914
33 83.3063735961914
34 78.17693328857422
35 73.3779525756836
36 68.88845825195312
37 64.68101501464844
38 60.749969482421875
39 57.07925033569336
40 53.64588165283203
41 50.437843322753906
42 47.441776275634766
43 44.637664794921875
44 42.02363586425781
45 39.57916259765625
46 37.29598617553711
47 35.16404724121094
48 33.

422 0.0006081120809540153
423 0.0005958164110779762
424 0.0005837808130308986
425 0.0005719895707443357
426 0.0005604309844784439
427 0.0005491272895596921
428 0.0005380465881898999
429 0.000527192372828722
430 0.000516562897246331
431 0.000506150652654469
432 0.0004959469079039991
433 0.0004859517503064126
434 0.0004761554882861674
435 0.00046656333142891526
436 0.000457166024716571
437 0.0004479571362026036
438 0.00043893742258660495
439 0.00043011686648242176
440 0.0004214527434669435
441 0.00041298221913166344
442 0.0004046796530019492
443 0.00039654469583183527
444 0.00038857170147821307
445 0.00038076157215982676
446 0.00037311026244424284
447 0.0003656221379060298
448 0.00035828116233460605
449 0.00035108078736811876
450 0.00034403454628773034
451 0.0003371344646438956
452 0.0003303627308923751
453 0.0003237365745007992
454 0.0003172425786033273
455 0.0003108739329036325
456 0.0003046443744096905
457 0.0002985339378938079
458 0.00029255577828735113
459 0.000286688533378765
460 0