In [1]:
%matplotlib inline


PyTorch: optim
--------------

A fully-connected ReLU network with one hidden layer, trained to predict y from x
by minimizing squared Euclidean distance.

This implementation uses the nn package from PyTorch to build the network.

Rather than manually updating the weights of the model as we have been doing,
we use the optim package to define an Optimizer that will update the weights
for us. The optim package defines many optimization algorithms that are commonly
used for deep learning, including SGD+momentum, RMSProp, Adam, etc.



In [3]:
import torch

# Seed PyTorch's global random number generator before any random draw. Both
# the synthetic data below and the initial weights of the Linear layers are
# sampled from it, so without a seed every Restart & Run All prints a different
# loss curve. With a fixed seed the run is repeatable (on the same machine and
# PyTorch version).
SEED = 0
torch.manual_seed(SEED)

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Use the nn package to define our model and loss function.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)
# reduction='sum' sums the squared errors over all elements instead of
# averaging them, which is why the first printed losses are in the hundreds.
loss_fn = torch.nn.MSELoss(reduction='sum')

# Use the optim package to define an Optimizer that will update the weights of
# the model for us. Here we will use Adam; the optim package contains many other
# optimization algorithms. The first argument to the Adam constructor tells the
# optimizer which Tensors it should update.
learning_rate = 1e-4
num_steps = 500  # number of full-batch optimization steps
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
print(type(model))
for t in range(num_steps):
    # Forward pass: compute predicted y by passing x to the model.
    y_pred = model(x)

    # Compute and print loss. The printed value is the loss *before* this
    # iteration's parameter update.
    loss = loss_fn(y_pred, y)
    print(t, loss.item())

    # Before the backward pass, use the optimizer object to zero all of the
    # gradients for the variables it will update (which are the learnable
    # weights of the model). This is because by default, gradients are
    # accumulated in buffers (i.e., not overwritten) whenever .backward()
    # is called. Check out the docs of torch.autograd.backward for more details.
    optimizer.zero_grad()

    # Backward pass: compute gradient of the loss with respect to model
    # parameters
    loss.backward()

    # Calling the step function on an Optimizer makes an update to its
    # parameters
    optimizer.step()

<class 'torch.nn.modules.container.Sequential'>
0 666.5945434570312
1 649.6112670898438
2 633.0699462890625
3 617.045166015625
4 601.45556640625
5 586.2740478515625
6 571.52734375
7 557.2686157226562
8 543.4965209960938
9 530.1268310546875
10 517.22900390625
11 504.75897216796875
12 492.6857604980469
13 481.0381164550781
14 469.6769104003906
15 458.62738037109375
16 447.8845520019531
17 437.5234069824219
18 427.3841857910156
19 417.49114990234375
20 407.8642578125
21 398.55810546875
22 389.5085144042969
23 380.66082763671875
24 372.0184631347656
25 363.55621337890625
26 355.34039306640625
27 347.3174743652344
28 339.4515686035156
29 331.77410888671875
30 324.29571533203125
31 316.9756774902344
32 309.8128967285156
33 302.79852294921875
34 295.93212890625
35 289.1980895996094
36 282.59356689453125
37 276.1291809082031
38 269.82037353515625
39 263.6599426269531
40 257.6123352050781
41 251.69802856445312
42 245.90029907226562
43 240.2178192138672
44 234.63107299804688
45 229.1427764892578

355 4.7490346332779154e-05
356 4.4276730477577075e-05
357 4.127009378862567e-05
358 3.846517574856989e-05
359 3.5844608646584675e-05
360 3.339902832522057e-05
361 3.111616024398245e-05
362 2.8982331059523858e-05
363 2.6995086955139413e-05
364 2.5137875127256848e-05
365 2.3404310923069715e-05
366 2.1788309823023155e-05
367 2.027990558417514e-05
368 1.8873941371566616e-05
369 1.7564225345267914e-05
370 1.6341353330062702e-05
371 1.5203106158878654e-05
372 1.4140056919131894e-05
373 1.3150803169992287e-05
374 1.22278379421914e-05
375 1.1370153515599668e-05
376 1.0568590369075537e-05
377 9.822831088968087e-06
378 9.128914825851098e-06
379 8.481882105115801e-06
380 7.879113582021091e-06
381 7.319652468140703e-06
382 6.79758613841841e-06
383 6.313345693342853e-06
384 5.862184025318129e-06
385 5.441874236566946e-06
386 5.051737389294431e-06
387 4.68843745693448e-06
388 4.34963840234559e-06
389 4.035430720250588e-06
390 3.74409705727885e-06
391 3.4722083910310175e-06
392 3.2202092370425817e-06