## 使用 PyTorch 的 nn 模块构建网络

借助 PyTorch 的 autograd 构建计算图并计算梯度：前向计算得到 loss 之后调用 `loss.backward()`，PyTorch 会自动求出所有参数的梯度。也就是说，反向求导是自动完成的，无需手写。

In [10]:
import torch
import torch.nn as nn


In [11]:
# Problem sizes: batch size N, input width D_in, hidden width H, output width D_out.
N, D_in, H, D_out = 64, 1000, 100, 10

# Random inputs and regression targets.
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Linear -> ReLU -> Linear, assembled with nn.Sequential
# (nn is torch.nn, so these are the same classes as torch.nn.*).
model = nn.Sequential(
    nn.Linear(D_in, H, bias=True),
    nn.ReLU(),
    nn.Linear(H, D_out),
)

# Squared errors are summed (not averaged) over all N * D_out outputs.
loss_fn = nn.MSELoss(reduction='sum')
# NOTE: because the loss is summed rather than averaged, this step size is
# fairly aggressive for plain SGD -- the recorded run below spikes around
# iterations 10-17 before recovering.
learning_rate = 1e-3

for step in range(100):
    # Forward pass and loss in one expression.
    loss = loss_fn(model(x), y)
    print(step, loss.item())

    # Autograd populates .grad on every parameter of the model.
    loss.backward()

    # Manual SGD step; no_grad() keeps the in-place parameter update
    # from being recorded in the autograd graph.
    with torch.no_grad():
        for param in model.parameters():
            param -= learning_rate * param.grad

    # Reset gradients so the next backward() does not add onto these ones.
    model.zero_grad()


0 706.3843994140625
1 374.1192321777344
2 228.05047607421875
3 129.03292846679688
4 69.18785858154297
5 36.58161544799805
6 19.765111923217773
7 11.099528312683105
8 6.788845539093018
9 5.079034328460693
10 5.8254265785217285
11 10.73802375793457
12 25.386960983276367
13 61.42582321166992
14 136.05303955078125
15 218.1109619140625
16 203.9059295654297
17 74.78096771240234
18 18.580764770507812
19 5.456915855407715
20 2.317540168762207
21 1.2080621719360352
22 0.7092007994651794
23 0.44692322611808777
24 0.29530179500579834
25 0.2022421807050705
26 0.14213846623897552
27 0.10207336395978928
28 0.07448776066303253
29 0.05518586188554764
30 0.04135351628065109
31 0.031392280012369156
32 0.024075591936707497
33 0.018648957833647728
34 0.014576628804206848
35 0.011498242616653442
36 0.009144549258053303
37 0.007333141751587391
38 0.00592443160712719
39 0.004821091890335083
40 0.003949672449380159
41 0.0032571072224527597
42 0.0027022019494324923
43 0.002255238825455308
44 0.0018920354777947

In [12]:
# Look at the trained weight matrix of the first Linear layer (D_in -> H).
first_linear = model[0]
first_linear.weight

Parameter containing:
tensor([[-0.0381, -0.0235, -0.0215,  ...,  0.0148,  0.0135,  0.0084],
        [ 0.0263,  0.0207,  0.0099,  ..., -0.0064,  0.0003,  0.0236],
        [ 0.0332, -0.0011,  0.0023,  ..., -0.0158, -0.0149, -0.0265],
        ...,
        [-0.0266,  0.0055,  0.0050,  ...,  0.0274, -0.0259, -0.0034],
        [ 0.0089, -0.0148, -0.0422,  ..., -0.0373, -0.0153, -0.0336],
        [-0.0235,  0.0014,  0.0119,  ...,  0.0402,  0.0228, -0.0072]],
       requires_grad=True)

In [13]:
# Same problem sizes as the manual-SGD cell, with freshly drawn random data.
N, D_in, H, D_out = 64, 1000, 100, 10

x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Linear -> ReLU -> Linear (nn is torch.nn).
model = nn.Sequential(
    nn.Linear(D_in, H, bias=True),
    nn.ReLU(),
    nn.Linear(H, D_out),
)

loss_fn = nn.MSELoss(reduction='sum')
learning_rate = 1e-3

# The optimizer replaces the hand-written update loop: optimizer.step()
# applies the Adam update to every parameter passed in here.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for step in range(100):
    loss = loss_fn(model(x), y)
    print(step, loss.item())

    # Clear stale gradients, backpropagate, then let Adam update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


0 694.5061645507812
1 543.2144775390625
2 427.7025146484375
3 341.1319274902344
4 273.02130126953125
5 217.68820190429688
6 172.48863220214844
7 135.4821319580078
8 105.56867218017578
9 81.93224334716797
10 63.56071090698242
11 49.58879089355469
12 39.196964263916016
13 31.796823501586914
14 26.697673797607422
15 23.431018829345703
16 21.525157928466797
17 20.461130142211914
18 19.76848602294922
19 19.157073974609375
20 18.417787551879883
21 17.432016372680664
22 16.20676040649414
23 14.83311653137207
24 13.398720741271973
25 11.946622848510742
26 10.532184600830078
27 9.204917907714844
28 8.004523277282715
29 6.958259105682373
30 6.075457572937012
31 5.354724884033203
32 4.762564659118652
33 4.276313781738281
34 3.874342918395996
35 3.534607172012329
36 3.24225115776062
37 2.9839839935302734
38 2.751889705657959
39 2.536383867263794
40 2.327824592590332
41 2.1199288368225098
42 1.9123057126998901
43 1.7117180824279785
44 1.5266356468200684
45 1.3630446195602417
46 1.2240729331970215
4

77 0.050906624644994736
78 0.04516780748963356
79 0.040317803621292114
80 0.036278627812862396
81 0.03283309563994408
82 0.02978379651904106
83 0.027047086507081985
84 0.02462133765220642
85 0.022494500502943993
86 0.02059607021510601
87 0.018833370879292488
88 0.017159709706902504
89 0.01559380441904068
90 0.014173580333590508
91 0.012894760817289352
92 0.011700402945280075
93 0.010528761893510818
94 0.009368537925183773
95 0.008268062025308609
96 0.007292214781045914
97 0.006472138222306967
98 0.0057903132401406765
99 0.005207961890846491


## Build a model by subclassing torch.nn.Module

In [14]:
# Problem sizes: batch N, input width D_in, hidden width H, output width D_out.
N, D_in, H, D_out = 64, 1000, 100, 10

# Fresh random inputs and targets (x is drawn first, then y).
x, y = torch.randn(N, D_in), torch.randn(N, D_out)


class TwoLayerNet(torch.nn.Module):
    """Two fully connected layers with a ReLU non-linearity in between.

    Args:
        D_in: number of input features (last dimension of the input).
        H: width of the hidden layer.
        D_out: number of output features.
    """

    def __init__(self, D_in, H, D_out):
        # Zero-argument super() is the Python 3 idiom; it registers this
        # object as an nn.Module before submodules are assigned.
        super().__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """Map ``x`` of shape (..., D_in) to predictions of shape (..., D_out)."""
        # clamp(min=0) zeroes negative hidden activations, i.e. a ReLU.
        h_relu = self.linear1(x).clamp(min=0)
        y_pred = self.linear2(h_relu)
        return y_pred


model = TwoLayerNet(D_in, H, D_out)

# Summed squared error, optimized with Adam.
loss_fn = nn.MSELoss(reduction='sum')
learning_rate = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Each iteration: forward pass + loss, clear old grads, backprop, update.
for step in range(100):
    loss = loss_fn(model(x), y)
    print(step, loss.item())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


0 705.3535766601562
1 549.8038330078125
2 432.47235107421875
3 340.1869201660156
4 266.9209899902344
5 207.5878448486328
6 159.68301391601562
7 121.83702850341797
8 92.56317138671875
9 70.62801361083984
10 54.74915313720703
11 43.72365951538086
12 36.36562728881836
13 31.69542694091797
14 28.907480239868164
15 27.28572654724121
16 26.213579177856445
17 25.195505142211914
18 23.968252182006836
19 22.37776756286621
20 20.49827766418457
21 18.427186965942383
22 16.359954833984375
23 14.394976615905762
24 12.605077743530273
25 11.017556190490723
26 9.630640983581543
27 8.441401481628418
28 7.430429935455322
29 6.575705051422119
30 5.866136074066162
31 5.289244174957275
32 4.814431190490723
33 4.417747974395752
34 4.073944568634033
35 3.765989303588867
36 3.482741117477417
37 3.2172741889953613
38 2.96570086479187
39 2.722938299179077
40 2.484678030014038
41 2.248218059539795
42 2.012460231781006
43 1.7819868326187134
44 1.5640674829483032
45 1.3666749000549316
46 1.1954727172851562
47 1.05

89 0.01598650962114334
90 0.014504293911159039
91 0.013121262192726135
92 0.011874046176671982
93 0.010766381397843361
94 0.009767632931470871
95 0.008833660744130611
96 0.007932527922093868
97 0.007055517751723528
98 0.006215090863406658
99 0.0054369112476706505
