# PyTorch: nn

全连接ReLU网络：含1个隐藏层，根据 $x$ 预测 $y$，通过最小化预测值与真实值之间欧氏距离的平方来训练网络。

使用torch.nn包构建网络。autograd较为底层，在定义复杂的神经网络时不够方便，此时可以使用nn包。nn包定义了一组模块（Module），每个模块大致相当于神经网络中的一层。

In [2]:
import torch
# Train a two-layer fully connected ReLU network on random data with plain
# gradient descent, using torch.nn to define both the layers and the loss.

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs.
# NOTE: no random seed is set, so the data, the layers' random initial
# weights and therefore the printed losses differ from run to run.
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Use the nn package to define our model as a sequence of layers. nn.Sequential
# is a Module which contains other Modules, and applies them in sequence to
# produce its output. Each Linear Module computes output from input using a
# linear function, and holds internal Tensors for its weight and bias.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)

# The nn package also contains definitions of popular loss functions. Note that
# MSELoss with reduction="sum" does NOT average: it returns the *sum* of the
# squared errors over all N * D_out elements, i.e. the squared Euclidean
# distance between y_pred and y. The loss and its gradient therefore scale
# with the number of elements, which is why the learning rate below is small.
loss_fn = torch.nn.MSELoss(reduction="sum")

learning_rate = 1e-4
for t in range(500):
    # Forward pass: compute predicted y by passing x to the model. Module objects
    # override the __call__ operator so you can call them like functions. When
    # doing so you pass a Tensor of input data to the Module and it produces
    # a Tensor of output data.
    y_pred = model(x)

    # Compute and print the loss for this step. loss_fn takes the predicted
    # and true values of y and returns a scalar Tensor; .item() converts it
    # to a Python float for printing.
    loss = loss_fn(y_pred, y)
    print(t, loss.item())

    # Zero the gradients before running the backward pass: backward()
    # accumulates into each parameter's .grad instead of overwriting it.
    model.zero_grad()

    # Backward pass: compute gradient of the loss with respect to all the learnable
    # parameters of the model. Internally, the parameters of each Module are stored
    # in Tensors with requires_grad=True, so this call will compute gradients for
    # all learnable parameters in the model.
    loss.backward()

    # Update the weights with one step of gradient descent. The in-place
    # update runs under torch.no_grad() so that autograd does not record it
    # as part of the computation graph.
    with torch.no_grad():
        for param in model.parameters():
            param -= learning_rate * param.grad

0 703.64111328125
1 652.8242797851562
2 608.5263061523438
3 569.1094360351562
4 533.6871948242188
5 501.4747009277344
6 471.9813537597656
7 444.6668701171875
8 419.27532958984375
9 395.6815185546875
10 373.56927490234375
11 352.96929931640625
12 333.67041015625
13 315.3907470703125
14 298.0212097167969
15 281.6432189941406
16 266.0712585449219
17 251.33004760742188
18 237.2996826171875
19 223.9430389404297
20 211.2364959716797
21 199.1659698486328
22 187.72239685058594
23 176.8980712890625
24 166.62005615234375
25 156.8442840576172
26 147.60935974121094
27 138.8663330078125
28 130.57003784179688
29 122.76390075683594
30 115.40692901611328
31 108.47644805908203
32 101.94702911376953
33 95.80756378173828
34 90.04463958740234
35 84.63249969482422
36 79.52958679199219
37 74.7374267578125
38 70.23615264892578
39 66.01703643798828
40 62.072113037109375
41 58.36521530151367
42 54.88894271850586
43 51.63022232055664
44 48.57445526123047
45 45.714508056640625
46 43.036659240722656
47 40.5281829

378 0.0004979411023668945
379 0.0004861084744334221
380 0.0004745707556139678
381 0.0004633040225598961
382 0.0004523315583355725
383 0.00044161872938275337
384 0.000431168737122789
385 0.00042097672121599317
386 0.0004110296431463212
387 0.00040133422589860857
388 0.00039186514914035797
389 0.0003826353349722922
390 0.0003736171929631382
391 0.0003648261190392077
392 0.0003562507627066225
393 0.00034789086203090847
394 0.0003397151886019856
395 0.0003317527298349887
396 0.00032397458562627435
397 0.0003163834335282445
398 0.0003089886449743062
399 0.0003017618437297642
400 0.000294710072921589
401 0.00028783048037439585
402 0.0002811135200317949
403 0.0002745606761891395
404 0.0002681627811398357
405 0.00026191677898168564
406 0.00025582939269952476
407 0.00024988409131765366
408 0.00024408216995652765
409 0.00023842080554459244
410 0.00023288781812880188
411 0.00022748597257304937
412 0.00022221784456633031
413 0.0002170747466152534
414 0.0002120579592883587
415 0.0002071506896754726