In [1]:
%matplotlib inline


PyTorch: optim
--------------

A fully-connected ReLU network with one hidden layer, trained to predict y from x
by minimizing squared Euclidean distance.

This implementation uses the nn package from PyTorch to build the network.

Rather than manually updating the weights of the model as we have been doing,
we use the optim package to define an Optimizer that will update the weights
for us. The optim package defines many optimization algorithms that are commonly
used for deep learning, including SGD+momentum, RMSProp, Adam, etc.



In [2]:
import torch

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Use the nn package to define our model and loss function.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)
loss_fn = torch.nn.MSELoss(reduction='sum')

# Use the optim package to define an Optimizer that will update the weights of
# the model for us. Here we will use Adam; the optim package contains many other
# optimization algoriths. The first argument to the Adam constructor tells the
# optimizer which Tensors it should update.
learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for t in range(500):
    # Forward pass: compute predicted y by passing x to the model.
    y_pred = model(x)

    # Compute and print loss.
    loss = loss_fn(y_pred, y)
    print(t, loss.item())

    # Before the backward pass, use the optimizer object to zero all of the
    # gradients for the variables it will update (which are the learnable
    # weights of the model). This is because by default, gradients are
    # accumulated in buffers( i.e, not overwritten) whenever .backward()
    # is called. Checkout docs of torch.autograd.backward for more details.
    optimizer.zero_grad()

    # Backward pass: compute gradient of the loss with respect to model
    # parameters
    loss.backward()

    # Calling the step function on an Optimizer makes an update to its
    # parameters
    optimizer.step()

0 684.3551635742188
1 667.095458984375
2 650.3870239257812
3 634.1436767578125
4 618.408203125
5 603.1204833984375
6 588.3328857421875
7 574.05615234375
8 560.1829833984375
9 546.7208251953125
10 533.6954345703125
11 521.0040893554688
12 508.6482238769531
13 496.6545104980469
14 485.01629638671875
15 473.7292175292969
16 462.75445556640625
17 452.0337219238281
18 441.5627136230469
19 431.3260498046875
20 421.3327941894531
21 411.58489990234375
22 402.0726318359375
23 392.80609130859375
24 383.7644348144531
25 374.9657287597656
26 366.4366455078125
27 358.1143493652344
28 349.98919677734375
29 341.99493408203125
30 334.18548583984375
31 326.5727233886719
32 319.1500549316406
33 311.90655517578125
34 304.819580078125
35 297.8678894042969
36 291.08636474609375
37 284.437744140625
38 277.9030456542969
39 271.525146484375
40 265.302490234375
41 259.1946716308594
42 253.19956970214844
43 247.33604431152344
44 241.61395263671875
45 236.01527404785156
46 230.53538513183594
47 225.1658630371093

375 0.002005272079259157
376 0.0019399202428758144
377 0.0018767089350149035
378 0.0018155439756810665
379 0.0017563734436407685
380 0.0016991207376122475
381 0.001643732888624072
382 0.0015901431906968355
383 0.0015383128775283694
384 0.001488149631768465
385 0.0014395912876352668
386 0.0013926371466368437
387 0.0013472017599269748
388 0.0013032093411311507
389 0.0012606660602614284
390 0.001219500438310206
391 0.0011796746402978897
392 0.001141120446845889
393 0.001103815040551126
394 0.0010677268728613853
395 0.0010327977361157537
396 0.0009990156395360827
397 0.0009663134114816785
398 0.000934666080866009
399 0.0009040598524734378
400 0.0008744239457882941
401 0.0008457497460767627
402 0.0008179885917343199
403 0.0007911555003374815
404 0.0007651803316548467
405 0.0007400286267511547
406 0.0007156960782594979
407 0.0006921631866134703
408 0.0006693812320008874
409 0.0006473288522101939
410 0.0006260102964006364
411 0.0006053695688024163
412 0.0005853961920365691
413 0.0005660749156