# PYTORCH: CONTROL FLOW + WEIGHT SHARING

PyTorch动态计算图 (PyTorch dynamic computation graphs)

In [2]:
import torch
import random

class DynamicNet(torch.nn.Module):
    """Fully connected ReLU network whose depth is re-drawn on every forward call.

    The module owns three ``torch.nn.Linear`` layers. The middle one is applied
    a random number of times (0 to 3 inclusive) during each pass, so its
    weights are shared across every repetition within that pass.
    """

    def __init__(self, D_in, H, D_out):
        """Create the input, shared middle, and output linear layers.

        Args:
            D_in: number of input features consumed by ``input_linear``.
            H: hidden width used by ``input_linear``/``middle_linear``.
            D_out: number of output features produced by ``output_linear``.
        """
        super().__init__()
        # These attribute names are also the state_dict key prefixes; keep them stable.
        self.input_linear = torch.nn.Linear(D_in, H)
        self.middle_linear = torch.nn.Linear(H, H)
        self.output_linear = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """Run one pass with a freshly sampled number of hidden layers.

        ``random.randint(0, 3)`` is drawn once per call, and ``middle_linear``
        is reapplied that many times. Autograd records the graph as the Python
        code executes, so an ordinary loop is enough to vary the depth. Reusing
        one module several times in the same graph is safe, unlike Lua Torch,
        where each module could be used only once.
        """
        # clamp(min=0) serves as the ReLU nonlinearity after every linear layer.
        hidden = torch.clamp(self.input_linear(x), min=0)
        repeats = random.randint(0, 3)
        while repeats > 0:
            hidden = torch.clamp(self.middle_linear(hidden), min=0)
            repeats -= 1
        return self.output_linear(hidden)
    
    
# Problem sizes: batch size, input width, hidden width, output width.
N = 64
D_in = 1000
H = 100
D_out = 10

# Random regression data: inputs are (N, D_in), targets are (N, D_out).
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Build the model after drawing the data; both consume torch's global RNG,
# so this order matters when a seed is set.
model = DynamicNet(D_in, H, D_out)

# Summed squared error as the objective. Plain SGD has a hard time with this
# randomly-deep model, so momentum is added.
criterion = torch.nn.MSELoss(reduction="sum")
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)

for t in range(500):
    # Discard gradients left over from the previous iteration before this step's backward.
    optimizer.zero_grad()

    # Forward pass (depth re-sampled each call) and loss report.
    y_pred = model(x)
    loss = criterion(y_pred, y)
    print(t, loss.item())

    # Backpropagate through this iteration's graph and apply the update.
    loss.backward()
    optimizer.step()

0 653.5411987304688
1 651.8615112304688
2 655.7376098632812
3 651.1415405273438
4 638.6050415039062
5 640.2182006835938
6 647.7581787109375
7 676.2109375
8 645.4925537109375
9 589.817138671875
10 629.1117553710938
11 642.6788330078125
12 625.094970703125
13 640.9523315429688
14 639.9329223632812
15 472.81048583984375
16 542.454345703125
17 408.0365905761719
18 362.7013244628906
19 632.7897338867188
20 630.42431640625
21 627.1859130859375
22 622.9035034179688
23 617.2451782226562
24 164.66934204101562
25 450.02606201171875
26 429.7854919433594
27 539.9929809570312
28 518.298095703125
29 120.62198638916016
30 529.7815551757812
31 113.78690338134766
32 486.1651611328125
33 401.12506103515625
34 265.93402099609375
35 355.46856689453125
36 95.0851058959961
37 215.71693420410156
38 80.07152557373047
39 179.13165283203125
40 321.51385498046875
41 57.68758010864258
42 137.3516082763672
43 269.7030334472656
44 112.98947143554688
45 53.676483154296875
46 188.59181213378906
47 161.91036987304688


426 0.08325709402561188
427 0.5386808514595032
428 0.5338664054870605
429 0.4145415425300598
430 0.5375057458877563
431 0.4967431426048279
432 0.46232837438583374
433 0.45350271463394165
434 0.4851626455783844
435 0.38303664326667786
436 0.15753160417079926
437 0.5679344534873962
438 0.2969215512275696
439 0.17010730504989624
440 0.15757006406784058
441 0.12982314825057983
442 0.3040786385536194
443 0.6061858534812927
444 0.5197489857673645
445 0.09443797171115875
446 0.11163611710071564
447 0.7144278287887573
448 0.501566469669342
449 0.415312796831131
450 0.5867985486984253
451 0.051113855093717575
452 0.42102304100990295
453 0.5024016499519348
454 0.4216706156730652
455 0.5016684532165527
456 0.06912396848201752
457 0.33773574233055115
458 0.29985156655311584
459 0.25205451250076294
460 0.394182413816452
461 0.3894519805908203
462 0.128422811627388
463 0.24830172955989838
464 0.3238995671272278
465 0.856502115726471
466 0.19070041179656982
467 0.11447959393262863
468 0.2335065603256