In [None]:
%matplotlib inline


PyTorch: optim
--------------

A fully-connected ReLU network with one hidden layer, trained to predict y from x
by minimizing squared Euclidean distance.

This implementation uses the nn package from PyTorch to build the network.

Rather than manually updating the weights of the model as we have been doing,
we use the optim package to define an Optimizer that will update the weights
for us. The optim package defines many optimization algorithms that are commonly
used for deep learning, including SGD with momentum, RMSprop, and Adam.



In [2]:
import torch

# Seed PyTorch's global RNG before any random draw so that the random data,
# the initial layer weights, and therefore the printed losses are the same on
# every "Restart Kernel -> Run All".
seed = 0
torch.manual_seed(seed)

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Use the nn package to define our model and loss function.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)
# reduction='sum' sums the squared errors over all elements rather than
# averaging them.
loss_fn = torch.nn.MSELoss(reduction='sum')

# Use the optim package to define an Optimizer that will update the weights of
# the model for us. Here we will use Adam; the optim package contains many other
# optimization algorithms. The first argument to the Adam constructor tells the
# optimizer which Tensors it should update.
learning_rate = 1e-4
# Number of optimization steps to run; kept next to learning_rate so both
# training hyperparameters can be tuned in one place.
num_steps = 500
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for t in range(num_steps):
    # Forward pass: compute predicted y by passing x to the model.
    y_pred = model(x)

    # Compute and print loss.
    loss = loss_fn(y_pred, y)
    print(t, loss.item())

    # Before the backward pass, use the optimizer object to zero all of the
    # gradients for the variables it will update (which are the learnable
    # weights of the model). This is because by default, gradients are
    # accumulated in buffers (i.e., not overwritten) whenever .backward()
    # is called. Check out the docs of torch.autograd.backward for more details.
    optimizer.zero_grad()

    # Backward pass: compute gradient of the loss with respect to model
    # parameters
    loss.backward()

    # Calling the step function on an Optimizer makes an update to its
    # parameters
    optimizer.step()

0 647.7970581054688
1 631.4168090820312
2 615.5759887695312
3 600.192626953125
4 585.282958984375
5 570.8240966796875
6 556.8074951171875
7 543.2206420898438
8 530.08837890625
9 517.2682495117188
10 504.80755615234375
11 492.6728820800781
12 480.89227294921875
13 469.4423828125
14 458.26824951171875
15 447.3427429199219
16 436.7566223144531
17 426.5788879394531
18 416.7215881347656
19 407.1139221191406
20 397.75689697265625
21 388.6580505371094
22 379.7793273925781
23 371.06695556640625
24 362.56378173828125
25 354.2771911621094
26 346.1903381347656
27 338.27459716796875
28 330.576416015625
29 323.1023254394531
30 315.7698974609375
31 308.62615966796875
32 301.6797180175781
33 294.9222412109375
34 288.3114929199219
35 281.8810119628906
36 275.59588623046875
37 269.4170837402344
38 263.3517150878906
39 257.4182434082031
40 251.58865356445312
41 245.87501525878906
42 240.28076171875
43 234.789306640625
44 229.40037536621094
45 224.10150146484375
46 218.88644409179688
47 213.7545623779297

458 1.725835971910783e-07
459 1.6128875302001688e-07
460 1.5062214231420512e-07
461 1.4068280052015325e-07
462 1.3132776643942634e-07
463 1.225586316877525e-07
464 1.1442126179872503e-07
465 1.0669930361473234e-07
466 9.95835094386166e-08
467 9.291969860214522e-08
468 8.656142824747803e-08
469 8.081198643594689e-08
470 7.533245849344894e-08
471 7.018574166295366e-08
472 6.546056141587542e-08
473 6.108065520038508e-08
474 5.6924839952898765e-08
475 5.308511319412901e-08
476 4.956076438134005e-08
477 4.6111967577644464e-08
478 4.301745448742622e-08
479 4.007460319144229e-08
480 3.726602670894863e-08
481 3.475969734267892e-08
482 3.2316105347263147e-08
483 3.009715499047161e-08
484 2.803743193169339e-08
485 2.616223859774891e-08
486 2.4328963732500597e-08
487 2.2695312296150405e-08
488 2.115428543447706e-08
489 1.96442897504312e-08
490 1.829220508398066e-08
491 1.7044799349719142e-08
492 1.584763609230322e-08
493 1.4740689557868336e-08
494 1.3708579160720546e-08
495 1.2774322044606379e-08