In [1]:
%matplotlib inline


PyTorch: Control Flow + Weight Sharing
--------------------------------------

To showcase the power of PyTorch dynamic graphs, we will implement a very strange
model: a fully-connected ReLU network that on each forward pass randomly chooses
a number between 1 and 4 and has that many hidden layers. The first hidden layer
always comes from the input layer; the remaining 0 to 3 hidden layers are computed
by reusing the same middle-layer weights multiple times.



In [12]:
import random
import torch


class DynamicNet(torch.nn.Module):
    """Fully-connected ReLU network whose depth is re-sampled on every forward pass.

    The network is ``input_linear -> ReLU -> (middle_linear -> ReLU) * k ->
    output_linear`` where ``k`` is drawn uniformly from
    ``0 .. max_middle_layers`` (inclusive) each time ``forward`` is called.
    The same ``middle_linear`` module, and therefore the same weights, is
    reused for every one of the ``k`` middle layers.
    """

    def __init__(self, D_in, H, D_out, max_middle_layers=3):
        """
        In the constructor we construct three nn.Linear instances that we will use
        in the forward pass.

        Args:
            D_in: size of each input sample.
            H: size of every hidden representation.
            D_out: size of each output sample.
            max_middle_layers: inclusive upper bound on how many times
                ``middle_linear`` may be applied in one forward pass.
                Defaults to 3, i.e. 0, 1, 2, or 3 repetitions.

        Raises:
            ValueError: if ``max_middle_layers`` is negative.
        """
        if max_middle_layers < 0:
            raise ValueError(
                "max_middle_layers must be non-negative, got {}".format(max_middle_layers)
            )
        super().__init__()
        self.max_middle_layers = max_middle_layers
        self.input_linear = torch.nn.Linear(D_in, H)
        self.middle_linear = torch.nn.Linear(H, H)
        self.output_linear = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        For the forward pass of the model, we randomly choose an integer between 0
        and ``max_middle_layers`` (0, 1, 2, or 3 by default) and reuse the
        middle_linear Module that many times to compute hidden layer
        representations.

        Since each forward pass builds a dynamic computation graph, we can use normal
        Python control-flow operators like loops or conditional statements when
        defining the forward pass of the model.

        Here we also see that it is perfectly safe to reuse the same Module many
        times when defining a computational graph. This is a big improvement from Lua
        Torch, where each Module could be used only once.

        Args:
            x: input Tensor whose last dimension is ``D_in``.

        Returns:
            Tensor of predictions whose last dimension is ``D_out``.
        """
        # clamp(min=0) is the ReLU non-linearity.
        h_relu = self.input_linear(x).clamp(min=0)
        # The depth comes from Python's `random` module, so seed it with
        # random.seed(...) when a reproducible sequence of depths is needed.
        for _ in range(random.randint(0, self.max_middle_layers)):
            h_relu = self.middle_linear(h_relu).clamp(min=0)
        y_pred = self.output_linear(h_relu)
        return y_pred


# Seed every source of randomness so that Restart & Run All reproduces the same
# losses: `torch` drives the random data and weight initialisation, while
# Python's `random` drives the number of middle layers chosen in
# DynamicNet.forward on each step.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Construct our model by instantiating the class defined above
model = DynamicNet(D_in, H, D_out)

# Construct our loss function and an Optimizer. Training this strange model with
# vanilla stochastic gradient descent is tough, so we use momentum.
# reduction='sum' makes the loss the summed (not averaged) squared error.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)

for t in range(500):
    # Forward pass: Compute predicted y by passing x to the model.
    # The depth of the network is re-sampled on every call, which is why the
    # loss can jump between consecutive steps.
    y_pred = model(x)

    # Compute and print loss
    loss = criterion(y_pred, y)
    print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    # Gradients accumulate across backward() calls, so they must be cleared
    # before each new backward pass.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 648.777099609375
1 597.255859375
2 606.57373046875
3 520.9375610351562
4 597.3939819335938
5 583.7222900390625
6 380.49993896484375
7 590.938720703125
8 589.9258422851562
9 550.6576538085938
10 253.9046630859375
11 584.337158203125
12 202.6620330810547
13 172.86862182617188
14 575.093994140625
15 582.7093505859375
16 563.162353515625
17 553.6776733398438
18 465.10675048828125
19 566.54931640625
20 510.8924255371094
21 400.8252258300781
22 91.05535888671875
23 520.2099609375
24 93.90271759033203
25 307.71368408203125
26 283.29888916015625
27 377.6636962890625
28 351.3905334472656
29 410.18743896484375
30 106.5442123413086
31 271.9952087402344
32 247.11453247070312
33 221.507568359375
34 272.93182373046875
35 176.0035858154297
36 112.40538787841797
37 162.09669494628906
38 92.98609924316406
39 144.59396362304688
40 162.7040557861328
41 51.29045867919922
42 41.16269302368164
43 159.15647888183594
44 107.45955657958984
45 35.554386138916016
46 187.73654174804688
47 34.99728012084961
48 1

436 1.2572232484817505
437 1.1814792156219482
438 4.911816120147705
439 0.7525244951248169
440 1.7329334020614624
441 0.47518405318260193
442 0.4733206033706665
443 2.359785556793213
444 5.249399185180664
445 0.7289707064628601
446 1.9865050315856934
447 2.880176544189453
448 1.2112845182418823
449 0.6547799706459045
450 0.6273237466812134
451 7.417335033416748
452 1.795701026916504
453 1.436179757118225
454 6.508212089538574
455 8.927679061889648
456 0.2955068349838257
457 2.510340929031372
458 9.974349975585938
459 1.6311030387878418
460 0.5990903377532959
461 3.2645273208618164
462 0.3583051264286041
463 0.3057057857513428
464 0.2504807114601135
465 4.308474063873291
466 0.6669840812683105
467 0.7849133014678955
468 5.367554664611816
469 2.742260694503784
470 0.5812362432479858
471 0.3435352146625519
472 8.882439613342285
473 3.0795650482177734
474 0.5928665399551392
475 2.215555429458618
476 1.403088092803955
477 26.26254653930664
478 4.291867256164551
479 1.6539314985275269
480 4.