In [1]:
%matplotlib inline


PyTorch: Control Flow + Weight Sharing
--------------------------------------

To showcase the power of PyTorch dynamic graphs, we will implement a very strange
model: a fully-connected ReLU network that on each forward pass randomly chooses
a number between 1 and 4 and uses that many hidden layers. The first hidden layer
has its own weights; the remaining 0 to 3 innermost hidden layers all reuse the
same weights.



In [2]:
import random
import torch


class DynamicNet(torch.nn.Module):
    """
    Fully-connected ReLU network whose depth is re-drawn on every forward pass.

    A single H -> H linear layer is shared by all of the inner hidden layers,
    however many of them a given forward pass happens to use.
    """

    def __init__(self, D_in, H, D_out):
        """
        Build the three nn.Linear layers used by forward():
        D_in -> H (input), H -> H (shared middle) and H -> D_out (output).
        """
        super().__init__()
        self.input_linear = torch.nn.Linear(D_in, H)
        self.middle_linear = torch.nn.Linear(H, H)
        self.output_linear = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        Map x to a prediction through a randomly deep stack of ReLU layers.

        The input layer is always applied once. After that, middle_linear is
        applied 0, 1, 2 or 3 times (drawn uniformly with random.randint), each
        application followed by a ReLU, and finally output_linear produces the
        prediction.

        Because the computation graph is rebuilt on every call, ordinary Python
        control flow (here a plain loop) is all that is needed to vary the
        depth. Applying the same Module several times within one graph is safe;
        this differs from Lua Torch, where a Module could be used only once.
        """
        # clamp(min=0) is the ReLU non-linearity.
        hidden = torch.clamp(self.input_linear(x), min=0)

        # Draw the number of shared middle layers exactly once per forward pass.
        extra_layers = random.randint(0, 3)
        applied = 0
        while applied < extra_layers:
            hidden = torch.clamp(self.middle_linear(hidden), min=0)
            applied += 1

        return self.output_linear(hidden)


# Seed both random number generators used below so that Restart & Run All
# reproduces the same loss trace: torch draws the data and the initial weights,
# and Python's `random` module picks the network depth on every forward pass.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Construct our model by instantiating the class defined above
model = DynamicNet(D_in, H, D_out)

# Construct our loss function and an Optimizer. Training this strange model with
# vanilla stochastic gradient descent is tough, so we use momentum.
# reduction='sum' adds up the squared errors instead of averaging them.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
for t in range(500):
    # Forward pass: Compute predicted y by passing x to the model.
    # The depth of the network is re-drawn inside this call, which is why the
    # printed loss jumps around from one step to the next.
    y_pred = model(x)

    # Compute and print loss
    loss = criterion(y_pred, y)
    print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    # Gradients accumulate across backward() calls, so they are cleared first.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 626.7391967773438
1 622.1098022460938
2 607.9678955078125
3 592.232177734375
4 667.7291870117188
5 556.2840576171875
6 612.1655883789062
7 612.8768920898438
8 611.8601684570312
9 492.64544677734375
10 604.688720703125
11 501.07867431640625
12 599.7156372070312
13 483.9908447265625
14 348.12298583984375
15 455.5238342285156
16 282.84173583984375
17 415.7019348144531
18 567.4804077148438
19 595.324462890625
20 163.48101806640625
21 582.60791015625
22 307.8831787109375
23 115.3112564086914
24 264.4409484863281
25 472.8134765625
26 91.59617614746094
27 194.427734375
28 79.7347640991211
29 68.36082458496094
30 374.44561767578125
31 457.7557067871094
32 44.2841682434082
33 38.868927001953125
34 30.612096786499023
35 113.44486236572266
36 357.3103942871094
37 24.652080535888672
38 219.2179718017578
39 287.04736328125
40 28.36749839782715
41 189.6727294921875
42 141.13368225097656
43 101.55004119873047
44 157.08685302734375
45 77.69342803955078
46 46.861412048339844
47 52.627864837646484
48 

413 0.22295868396759033
414 0.5914254784584045
415 0.39000770449638367
416 0.3812252879142761
417 0.3281693458557129
418 0.3528061509132385
419 0.36040496826171875
420 0.5483784079551697
421 0.32801246643066406
422 0.557026743888855
423 0.5376536846160889
424 0.24820666015148163
425 0.2392413318157196
426 0.26452377438545227
427 0.38757047057151794
428 0.47083815932273865
429 0.2511433959007263
430 0.250446617603302
431 0.43168702721595764
432 0.20669762790203094
433 0.1851249635219574
434 0.3737412989139557
435 0.337125688791275
436 0.5394218564033508
437 0.2949449121952057
438 0.3000316917896271
439 0.11241020262241364
440 0.26359468698501587
441 0.09783808141946793
442 0.09177977591753006
443 0.46085914969444275
444 0.36122798919677734
445 0.5160552263259888
446 0.5299643278121948
447 0.4629538357257843
448 0.2947358787059784
449 0.08217603713274002
450 0.34182727336883545
451 0.11415931582450867
452 0.118045374751091
453 0.34842973947525024
454 0.7273983359336853
455 0.305568546056