In [1]:
%matplotlib inline


PyTorch: Control Flow + Weight Sharing
--------------------------------------

To showcase the power of PyTorch dynamic graphs, we will implement a very strange
model: a fully-connected ReLU network that on each forward pass randomly chooses
a number between 1 and 4 and has that many hidden layers, reusing the same
weights multiple times to compute the innermost hidden layers.



In [2]:
import random
import torch


class DynamicNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        """
        In the constructor we construct three nn.Linear instances that we will use
        in the forward pass.
        """
        super(DynamicNet, self).__init__()
        self.input_linear = torch.nn.Linear(D_in, H)
        self.middle_linear = torch.nn.Linear(H, H)
        self.output_linear = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        For the forward pass of the model, we randomly choose either 0, 1, 2, or 3
        and reuse the middle_linear Module that many times to compute hidden layer
        representations.

        Since each forward pass builds a dynamic computation graph, we can use normal
        Python control-flow operators like loops or conditional statements when
        defining the forward pass of the model.

        Here we also see that it is perfectly safe to reuse the same Module many
        times when defining a computational graph. This is a big improvement from Lua
        Torch, where each Module could be used only once.
        """
        h_relu = self.input_linear(x).clamp(min=0)
        for _ in range(random.randint(0, 3)):
            h_relu = self.middle_linear(h_relu).clamp(min=0)
        y_pred = self.output_linear(h_relu)
        return y_pred


# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Construct our model by instantiating the class defined above
model = DynamicNet(D_in, H, D_out)

# Construct our loss function and an Optimizer. Training this strange model with
# vanilla stochastic gradient descent is tough, so we use momentum
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
for t in range(500):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute and print loss
    loss = criterion(y_pred, y)
    print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 648.9775390625
1 646.5050659179688
2 642.1493530273438
3 641.8135986328125
4 685.8380126953125
5 642.6522216796875
6 584.0333251953125
7 621.5007934570312
8 455.608154296875
9 608.1010131835938
10 632.3406372070312
11 616.1444091796875
12 286.30694580078125
13 630.7894897460938
14 610.2770385742188
15 606.1801147460938
16 192.44374084472656
17 548.6043090820312
18 624.430419921875
19 128.59329223632812
20 618.9412231445312
21 615.141357421875
22 469.7901306152344
23 444.9983215332031
24 528.701904296875
25 508.823974609375
26 98.26380920410156
27 556.087646484375
28 297.55560302734375
29 402.8298034667969
30 103.24909973144531
31 222.23342895507812
32 317.0555419921875
33 179.3186798095703
34 156.77008056640625
35 348.9098205566406
36 112.85706329345703
37 100.86849212646484
38 76.3909683227539
39 52.423030853271484
40 110.68841552734375
41 26.18639373779297
42 251.68824768066406
43 79.24877166748047
44 180.97434997558594
45 144.20106506347656
46 64.11666107177734
47 82.7652053833007

389 0.5718395113945007
390 0.45080092549324036
391 0.5746116638183594
392 0.8309420943260193
393 0.6727114319801331
394 0.6056891083717346
395 0.505732536315918
396 0.1628689020872116
397 0.12365368008613586
398 0.08635041862726212
399 0.5267108678817749
400 0.4192533791065216
401 0.4209129810333252
402 0.41378602385520935
403 0.19856002926826477
404 0.16643434762954712
405 0.11080358922481537
406 0.26556992530822754
407 0.9801313877105713
408 0.5869809985160828
409 0.2858535349369049
410 1.042236566543579
411 0.8219985365867615
412 0.6575505137443542
413 0.36951741576194763
414 0.9422004222869873
415 0.8377964496612549
416 0.4281536042690277
417 0.3331311047077179
418 0.5153644680976868
419 0.49676772952079773
420 0.6329869627952576
421 1.096946120262146
422 0.17794817686080933
423 0.23617829382419586
424 1.1639161109924316
425 0.13709668815135956
426 0.47403398156166077
427 0.43766409158706665
428 0.8926413059234619
429 0.08124659210443497
430 0.08675819635391235
431 0.87858045101165