# PYTORCH: OPTIM

全连接ReLU网络：含1层隐藏层，根据𝑥预测𝑦，通过最小化平方欧氏距离（误差平方和）训练网络。

使用optim包定义优化器，自动更新权值。

In [3]:
import torch

# Problem dimensions: N samples per batch, D_in input features,
# H hidden units, D_out output features.
N, D_in, H, D_out = 64, 1000, 100, 10

# Random input batch and random regression targets.
x = torch.randn((N, D_in))
y = torch.randn((N, D_out))

# Model and loss both come from the nn package: two Linear layers with a
# ReLU in between, and a squared-error loss summed over every element.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)
loss_fn = torch.nn.MSELoss(reduction="sum")

# Let an Optimizer from the optim package take care of the weight updates.
# Adam is used here (optim ships many other optimization algorithms); it is
# handed the model's parameters, i.e. the Tensors it is responsible for
# updating.
learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for t in range(500):
    # Gradients are accumulated (not overwritten) each time .backward() is
    # called, so clear whatever the previous iteration left behind before
    # computing new ones. See the torch.autograd.backward docs for details.
    optimizer.zero_grad()

    # Forward pass: run x through the model, then measure and report how far
    # the prediction is from y.
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    print(t, loss.item())

    # Backward pass: gradient of the loss w.r.t. every model parameter.
    loss.backward()

    # One optimizer step applies the update to the parameters it manages.
    optimizer.step()
    

0 679.3148193359375
1 662.0254516601562
2 645.2996215820312
3 629.042236328125
4 613.2485961914062
5 597.9348754882812
6 583.0579833984375
7 568.6455078125
8 554.6974487304688
9 541.0750122070312
10 527.81591796875
11 514.846435546875
12 502.2281494140625
13 489.99932861328125
14 478.1239318847656
15 466.57342529296875
16 455.3585510253906
17 444.4521179199219
18 433.8395690917969
19 423.54534912109375
20 413.5301513671875
21 403.74346923828125
22 394.2352294921875
23 384.95343017578125
24 375.8745422363281
25 367.0218505859375
26 358.34820556640625
27 349.8387451171875
28 341.48297119140625
29 333.332275390625
30 325.3876953125
31 317.633544921875
32 310.0392761230469
33 302.6335144042969
34 295.3958740234375
35 288.3063049316406
36 281.4027099609375
37 274.6609802246094
38 268.0578308105469
39 261.5999450683594
40 255.25628662109375
41 249.03692626953125
42 242.9569854736328
43 237.01266479492188
44 231.17845153808594
45 225.4643096923828
46 219.85177612304688
47 214.35244750976562
... (output truncated) ...

359 8.31662691780366e-05
360 7.851669943192974e-05
361 7.412872946588323e-05
362 6.99858574080281e-05
363 6.607697287108749e-05
364 6.23852465650998e-05
365 5.889864405617118e-05
366 5.560982026509009e-05
367 5.250363756204024e-05
368 4.957564306096174e-05
369 4.680691927205771e-05
370 4.419451579451561e-05
371 4.172870103502646e-05
372 3.940042370231822e-05
373 3.7203222746029496e-05
374 3.512845796649344e-05
375 3.317007940495387e-05
376 3.132053097942844e-05
377 2.9574674044852145e-05
378 2.7924481400987133e-05
379 2.6368737962911837e-05
380 2.4901144570321776e-05
381 2.3512726329499856e-05
382 2.220225906057749e-05
383 2.0965231669833884e-05
384 1.9797291315626353e-05
385 1.869399238785263e-05
386 1.7652260794420727e-05
387 1.6668012904119678e-05
388 1.5738893125671893e-05
389 1.486267865402624e-05
390 1.403445639880374e-05
391 1.3252568351163063e-05
392 1.2512199646153022e-05
393 1.1815413017757237e-05
394 1.115834584197728e-05
395 1.0535924047871958e-05
396 9.947159014700446e-06
