# PYTORCH: CUSTOM NN MODULES

全连接ReLU网络：含1个隐藏层，根据输入 $x$ 预测输出 $y$，通过最小化预测值与真实值 $y$ 之间欧氏距离的平方（即平方误差之和）来训练网络。

通过继承 `torch.nn.Module` 编写自定义模块子类来定义模型。

In [1]:
import torch

class TwoLayerNet(torch.nn.Module):
    """Fully connected network with a single ReLU hidden layer.

    Computes ``linear2(max(linear1(x), 0))``: an affine map from ``D_in``
    features to ``H`` hidden units, an element-wise ReLU, then an affine map
    from ``H`` hidden units to ``D_out`` outputs.
    """

    def __init__(self, D_in, H, D_out):
        """Build the two affine layers and register them as submodules.

        Args:
            D_in: number of input features.
            H: number of hidden units.
            D_out: number of output features.
        """
        super().__init__()
        # Assigning nn.Module instances as attributes registers their
        # parameters on this module, so model.parameters() yields the weights
        # and biases of both layers.
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """Return the network's prediction for the input batch ``x``.

        ``nn.Linear`` acts on the last dimension, so ``x`` must have ``D_in``
        features there; the result has ``D_out`` features there.
        """
        # Clamping negatives to zero is the ReLU non-linearity.
        hidden = torch.clamp(self.linear1(x), min=0)
        return self.linear2(hidden)
    
    
# Problem sizes: N is the batch size, D_in the input width,
# H the hidden width, and D_out the output width.
N = 64
D_in = 1000
H = 100
D_out = 10

# Random training data: inputs x and regression targets y.
# NOTE: x, y and the model's initial weights all draw from the global RNG in
# this order, so reordering these statements changes the numbers produced.
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Instantiate the two-layer network defined above.
model = TwoLayerNet(D_in, H, D_out)

# Sum-of-squared-errors loss (reduction="sum": no averaging), and plain SGD
# over every learnable tensor registered on the model, i.e. both nn.Linear
# layers.
criterion = torch.nn.MSELoss(reduction="sum")
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)

for t in range(500):
    # Clear gradients accumulated by the previous step. The forward pass never
    # reads .grad, so clearing here is equivalent to clearing right before
    # backward().
    optimizer.zero_grad()

    # Forward pass, then the loss, printed every iteration as "<step> <loss>".
    y_pred = model(x)
    loss = criterion(y_pred, y)
    print(t, loss.item())

    # Backpropagate and apply one SGD update to the weights.
    loss.backward()
    optimizer.step()


0 614.4730224609375
1 571.2723999023438
2 533.4083862304688
3 499.89703369140625
4 469.91326904296875
5 442.6645202636719
6 417.9061279296875
7 395.0079040527344
8 373.72412109375
9 353.9446105957031
10 335.28997802734375
11 317.7622375488281
12 301.2640075683594
13 285.65460205078125
14 270.8759460449219
15 256.83514404296875
16 243.431640625
17 230.66954040527344
18 218.50888061523438
19 206.8704071044922
20 195.75662231445312
21 185.15740966796875
22 175.07965087890625
23 165.48141479492188
24 156.36143493652344
25 147.69451904296875
26 139.45596313476562
27 131.65579223632812
28 124.26245880126953
29 117.2536392211914
30 110.60166931152344
31 104.31695556640625
32 98.3486557006836
33 92.7000961303711
34 87.35453033447266
35 82.310546875
36 77.52375030517578
37 73.00260162353516
38 68.74864959716797
39 64.73233795166016
40 60.94955825805664
41 57.3871955871582
42 54.03950119018555
43 50.89844512939453
44 47.94298553466797
45 45.17049789428711
46 42.5623779296875
47 40.12022399902344

385 0.00020943707204423845
386 0.0002035634097410366
387 0.00019786832854151726
388 0.00019232045451644808
389 0.00018694270693231374
390 0.00018171586270909756
391 0.00017663381004240364
392 0.0001716957922326401
393 0.000166904108482413
394 0.00016224815044552088
395 0.0001577242510393262
396 0.00015332907787524164
397 0.00014905216812621802
398 0.00014490271860267967
399 0.00014086998999118805
400 0.00013694957306142896
401 0.00013314424722921103
402 0.00012944734771735966
403 0.00012584842625074089
404 0.0001223535800818354
405 0.00011896465730387717
406 0.00011566530884010717
407 0.00011245830683037639
408 0.00010934476449619979
409 0.00010631532495608553
410 0.00010337711137253791
411 0.00010051595018012449
412 9.773592319106683e-05
413 9.50377871049568e-05
414 9.241594671038911e-05
415 8.986311149783432e-05
416 8.738879841985181e-05
417 8.497518865624443e-05
418 8.263309428002685e-05
419 8.036009239731357e-05
420 7.81456838012673e-05
421 7.599515811307356e-05
422 7.3905219323933