In [1]:
%matplotlib inline


PyTorch: optim
--------------

A fully-connected ReLU network with one hidden layer, trained to predict y from x
by minimizing squared Euclidean distance.

This implementation uses the nn package from PyTorch to build the network.

Rather than manually updating the weights of the model as we have been doing,
we use the optim package to define an Optimizer that will update the weights
for us. The optim package defines many optimization algorithms that are commonly
used for deep learning, including SGD+momentum, RMSProp, Adam, etc.



In [2]:
import torch

# Seed PyTorch's global random number generator before anything random is
# drawn, so that the synthetic data, the initial weights of the model, and
# therefore the printed loss curve are the same on every "Restart & Run All".
SEED = 0
torch.manual_seed(SEED)

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Number of optimization steps to run on the single fixed batch (x, y).
NUM_STEPS = 500

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Use the nn package to define our model and loss function.
# The model is Linear -> ReLU -> Linear, i.e. one hidden layer of width H.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)
# reduction='sum' adds up the squared errors over every element instead of
# averaging them, so the printed loss is the total squared error of the batch.
loss_fn = torch.nn.MSELoss(reduction='sum')

# Use the optim package to define an Optimizer that will update the weights of
# the model for us. Here we will use Adam; the optim package contains many other
# optimization algorithms. The first argument to the Adam constructor tells the
# optimizer which Tensors it should update.
learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for t in range(NUM_STEPS):
    # Forward pass: compute predicted y by passing x to the model.
    y_pred = model(x)

    # Compute and print loss.
    loss = loss_fn(y_pred, y)
    print(t, loss.item())

    # Before the backward pass, use the optimizer object to zero all of the
    # gradients for the variables it will update (which are the learnable
    # weights of the model). This is because by default, gradients are
    # accumulated in buffers (i.e., not overwritten) whenever .backward()
    # is called. Check out the docs of torch.autograd.backward for more details.
    optimizer.zero_grad()

    # Backward pass: compute gradient of the loss with respect to model
    # parameters
    loss.backward()

    # Calling the step function on an Optimizer makes an update to its
    # parameters
    optimizer.step()

0 679.9183959960938
1 662.9070434570312
2 646.473388671875
3 630.4984741210938
4 615.0211181640625
5 600.1182250976562
6 585.6351318359375
7 571.5257568359375
8 557.7781982421875
9 544.4041137695312
10 531.518310546875
11 519.05224609375
12 506.8771057128906
13 495.04608154296875
14 483.6139221191406
15 472.4757995605469
16 461.6009826660156
17 451.005615234375
18 440.6210632324219
19 430.47998046875
20 420.5804443359375
21 410.8900451660156
22 401.4541015625
23 392.2784729003906
24 383.32794189453125
25 374.6073913574219
26 366.1451416015625
27 357.86370849609375
28 349.7484130859375
29 341.8279724121094
30 334.140869140625
31 326.6395568847656
32 319.3004455566406
33 312.0951232910156
34 305.0321350097656
35 298.0843505859375
36 291.26312255859375
37 284.5833435058594
38 278.0383605957031
39 271.6129455566406
40 265.3201904296875
41 259.1630859375
42 253.11764526367188
43 247.20481872558594
44 241.4191436767578
45 235.76377868652344
46 230.22100830078125
47 224.77601623535156
48 219.

367 0.00014986732276156545
368 0.00014143975568003953
369 0.00013346923515200615
370 0.00012593058636412024
371 0.00011880750389536843
372 0.0001120663364417851
373 0.0001056940745911561
374 9.966887591872364e-05
375 9.397224494023249e-05
376 8.859475201461464e-05
377 8.350666757905856e-05
378 7.870786066632718e-05
379 7.416665903292596e-05
380 6.988333188928664e-05
381 6.583752838196233e-05
382 6.201442010933533e-05
383 5.840765516040847e-05
384 5.500317638507113e-05
385 5.17895859957207e-05
386 4.875446029473096e-05
387 4.589304080582224e-05
388 4.319365689298138e-05
389 4.06493891205173e-05
390 3.824679515673779e-05
391 3.598061084630899e-05
392 3.384935553185642e-05
393 3.183273656759411e-05
394 2.9935612474218942e-05
395 2.8149368517915718e-05
396 2.646451503096614e-05
397 2.4875349481590092e-05
398 2.3378959667752497e-05
399 2.1970863599563017e-05
400 2.0644551113946363e-05
401 1.9393857655813918e-05
402 1.8218435798189603e-05
403 1.711109871394001e-05
404 1.6068479453679174e-05
