
PyTorch: optim
--------------

A fully-connected ReLU network with one hidden layer, trained to predict y from x
by minimizing squared Euclidean distance.
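Written out, the model and its objective look as follows. The weight and bias symbols $W_1, b_1, W_2, b_2$ are introduced here for illustration only. The dimensions are the `N`, `D_in`, `H`, `D_out` used in the code below, and each weight matrix is shaped `(out, in)` as `torch.nn.Linear` stores it:

```latex
\hat{y}_n = W_2 \max(0,\; W_1 x_n + b_1) + b_2,
\qquad W_1 \in \mathbb{R}^{H \times D_{in}},\; W_2 \in \mathbb{R}^{D_{out} \times H}

L = \sum_{n=1}^{N} \lVert \hat{y}_n - y_n \rVert_2^2
  = \sum_{n=1}^{N} \sum_{k=1}^{D_{out}} \left(\hat{y}_{nk} - y_{nk}\right)^2
```

This summed (not averaged) form is what a mean-squared-error loss with summation reduction computes.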

This implementation uses the nn package from PyTorch to build the network.
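As a sanity check of what the `nn` package builds, the sketch below recomputes the network's output by hand from the parameters stored in each `torch.nn.Linear` layer. It assumes a recent PyTorch, and the seed and the names `W1`, `b1`, `W2`, `b2`, `manual` are illustrative. Each `Linear` layer keeps its weight as `(out_features, in_features)` and computes `x @ weight.T + bias`.

```python
import torch

torch.manual_seed(0)  # illustrative seed so the run is repeatable
N, D_in, H, D_out = 64, 1000, 100, 10
x = torch.randn(N, D_in)

model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)

# Each nn.Linear stores its weight with shape (out_features, in_features)
W1, b1 = model[0].weight, model[0].bias
W2, b2 = model[2].weight, model[2].bias

with torch.no_grad():  # no gradients are needed for this comparison
    out = model(x)
    # The same computation by hand: affine -> ReLU -> affine
    manual = (x @ W1.T + b1).clamp(min=0) @ W2.T + b2

print(out.shape)  # torch.Size([64, 10])
print(torch.allclose(out, manual, atol=1e-5))
```

The small `atol` only absorbs float32 rounding differences between the two matmul paths.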

<strong style="color:red">Rather than manually updating the weights of the model</strong>, as we have been doing,
we use the <strong style="color:red">optim package</strong> to define an Optimizer that will <strong style="color:green">update the weights for us</strong>.
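To make "the optimizer updates the weights for us" concrete, here is a minimal sketch, assuming a recent PyTorch. The tiny `Linear(5, 2)` model, the random data and the learning rate are made up for illustration. It applies one hand-written gradient-descent update to one copy of a model and one `torch.optim.SGD` step to an identical copy, then checks that the weights agree. With its defaults (no momentum, no weight decay), plain SGD performs the update `p -= lr * p.grad`.

```python
import copy

import torch

torch.manual_seed(0)
x, y = torch.randn(8, 5), torch.randn(8, 2)
model_a = torch.nn.Linear(5, 2)
model_b = copy.deepcopy(model_a)  # identical starting weights
loss_fn = torch.nn.MSELoss(reduction="sum")
lr = 1e-2

# (1) Manual update: compute gradients, then step each parameter by hand
loss_fn(model_a(x), y).backward()
with torch.no_grad():  # the update itself must not be tracked by autograd
    for p in model_a.parameters():
        p -= lr * p.grad
        p.grad.zero_()

# (2) The same update delegated to an optimizer
opt = torch.optim.SGD(model_b.parameters(), lr=lr)
opt.zero_grad()
loss_fn(model_b(x), y).backward()
opt.step()

same = all(
    torch.allclose(pa, pb)
    for pa, pb in zip(model_a.parameters(), model_b.parameters())
)
print(same)
```

Swapping `torch.optim.SGD` for another optimizer class changes only the update rule; the `zero_grad()` / `backward()` / `step()` pattern stays the same.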

The optim package defines many optimization algorithms that are commonly used for deep learning, including SGD+momentum, RMSProp, Adam, etc.
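These three algorithms differ only in how they turn a gradient into a step. The framework-free sketch below applies each one to the toy one-dimensional problem f(p) = (p - 3)^2, which is a made-up example and not part of the tutorial. The update rules are intended to mirror the default forms described in the `torch.optim` documentation for `SGD` (with `momentum`), `RMSprop` and `Adam`. Treat them as a sketch rather than a reference implementation, and treat the hyperparameter values as illustrative.

```python
import math


def grad(p):
    """Gradient of the toy objective f(p) = (p - 3)**2."""
    return 2.0 * (p - 3.0)


def sgd_momentum(p, lr=0.01, mu=0.9, steps=2000):
    v = 0.0
    for _ in range(steps):
        v = mu * v + grad(p)  # velocity: decaying sum of past gradients
        p -= lr * v
    return p


def rmsprop(p, lr=0.01, alpha=0.99, eps=1e-8, steps=2000):
    s = 0.0
    for _ in range(steps):
        g = grad(p)
        s = alpha * s + (1 - alpha) * g * g  # running average of squared gradients
        p -= lr * g / (math.sqrt(s) + eps)  # step scaled by the gradient's recent size
    return p


def adam(p, lr=0.01, beta1=0.9, beta2=0.999, eps=1e-8, steps=2000):
    m = v = 0.0
    for t in range(1, steps + 1):
        g = grad(p)
        m = beta1 * m + (1 - beta1) * g  # first moment (momentum-like)
        v = beta2 * v + (1 - beta2) * g * g  # second moment (RMSProp-like)
        m_hat = m / (1 - beta1 ** t)  # bias correction for the zero initialization
        v_hat = v / (1 - beta2 ** t)
        p -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return p


for name, opt in [("SGD+momentum", sgd_momentum), ("RMSProp", rmsprop), ("Adam", adam)]:
    print(f"{name:13s} p = {opt(0.0):.4f}")
```

Adam combines the momentum of the first rule with the per-step scaling of the second, which is one reason it is a common default.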



In [1]:
%matplotlib inline

<h1 style="background-image: linear-gradient( 135deg, #ABDCFF 10%, #0396FF 100%);"> Original Tutorial code</h1>

In [2]:
import torch

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs.
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Use the nn package to define our model and loss function.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)

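# Loss function: the sum (not the mean) of squared errors over the batch,
# i.e. the squared Euclidean distance between y_pred and y.
loss_fn = torch.nn.MSELoss(reduction='sum')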

# Use the optim package to define an Optimizer that will update the weights of
# the model for us. Here we use Adam; the optim package contains many other
# optimization algorithms. The first argument to the Adam constructor tells the
# optimizer which Tensors (the model's parameters) it should update.
learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for t in range(500):
    # Forward pass: compute predicted y by passing x to the model.
    y_pred = model(x)

    # Compute and print loss.
    loss = loss_fn(y_pred, y)
    print(t, loss.item())

    # Before the backward pass, use the optimizer object to zero all of the
    # gradients for the Tensors it will update (which are the learnable
    # weights of the model). This is because by default, gradients are
    # accumulated in buffers (i.e., not overwritten) whenever .backward()
    # is called. Check out the docs of torch.autograd.backward for more details.
    optimizer.zero_grad()

    # Backward pass: compute gradient of the loss with respect to model parameters
    loss.backward()

    # Calling the step function on an Optimizer makes an update to its parameters
    optimizer.step()

0 682.8252563476562
1 666.3059692382812
2 650.2415161132812
3 634.7083740234375
4 619.6635131835938
5 605.11474609375
6 590.9830932617188
7 577.2006225585938
8 563.7789916992188
9 550.6875
10 537.9588623046875
11 525.55322265625
12 513.4520874023438
13 501.6575012207031
14 490.21295166015625
15 479.06243896484375
16 468.2439880371094
17 457.754150390625
18 447.5335388183594
19 437.612548828125
20 427.9945373535156
21 418.6335144042969
22 409.53851318359375
23 400.7070617675781
24 392.14263916015625
25 383.79864501953125
26 375.6753234863281
27 367.7168273925781
28 359.9151916503906
29 352.2444763183594
30 344.73577880859375
31 337.41595458984375
32 330.25653076171875
33 323.2464904785156
34 316.3948669433594
35 309.6763000488281
36 303.0809326171875
37 296.64385986328125
38 290.3456726074219
39 284.16912841796875
40 278.12176513671875
41 272.1761169433594
42 266.3402099609375
43 260.6186828613281
44 255.04672241210938
45 249.58338928222656
46 244.2247314453125
47 238.97825622558594
48 

400 1.2921124834974762e-05
401 1.2110701391065959e-05
402 1.1349991837050766e-05
403 1.0638404091878328e-05
404 9.96866583591327e-06
405 9.339854841527995e-06
406 8.751857421884779e-06
407 8.19921842776239e-06
408 7.680937414988875e-06
409 7.193659712356748e-06
410 6.736984687449876e-06
411 6.30889644526178e-06
412 5.907951617700746e-06
413 5.5324198910966516e-06
414 5.179457730264403e-06
415 4.8479714678251185e-06
416 4.537897439149674e-06
417 4.2478386603761464e-06
418 3.973785169364419e-06
419 3.719066171470331e-06
420 3.4795232295437017e-06
421 3.2558061775489477e-06
422 3.044882305403007e-06
423 2.8484348604251863e-06
424 2.6637344490154646e-06
425 2.491130317139323e-06
426 2.329381459276192e-06
427 2.177650230805739e-06
428 2.0360484995762818e-06
429 1.9024266748601804e-06
430 1.7789180901672808e-06
431 1.6623555438854964e-06
432 1.5532614270341583e-06
433 1.4511902008962352e-06
434 1.3558548062064801e-06
435 1.266283334189211e-06
436 1.1829897630377673e-06
437 1.1043914582842262
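The comment about `optimizer.zero_grad()` in the training loop above can be checked directly. In this small sketch, which assumes a recent PyTorch and uses an illustrative scalar `w` and toy loss `(3w)^2`, calling `.backward()` twice without clearing adds the second gradient to the first. Clearing `.grad` restores the single-pass value. `optimizer.zero_grad()` performs that clearing for every parameter it manages; in recent PyTorch releases it does so by default by setting `.grad` to `None` rather than filling it with zeros.

```python
import torch

w = torch.tensor(2.0, requires_grad=True)


def toy_loss():
    # d/dw of (3w)^2 is 18w, i.e. 36 at w = 2
    return (3.0 * w) ** 2


toy_loss().backward()
g1 = w.grad.item()  # 36.0

toy_loss().backward()  # no clearing in between: the new gradient is added
g2 = w.grad.item()  # 72.0

w.grad = None  # clear the buffer, as optimizer.zero_grad() does by default in recent versions
toy_loss().backward()
g3 = w.grad.item()  # 36.0 again

print(g1, g2, g3)
```

Forgetting to clear gradients therefore does not raise an error. Each step would silently use the sum of all gradients computed so far.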