In [1]:
%matplotlib inline


PyTorch: Control Flow + Weight Sharing
--------------------------------------

To showcase the power of PyTorch dynamic graphs, we will implement a very strange
model: a fully-connected ReLU network that on each forward pass randomly chooses
a number between 1 and 4 (inclusive) and has that many hidden layers. The first
hidden layer has its own weights; the remaining 0 to 3 innermost hidden layers
are computed by reusing one shared set of weights multiple times.



In [7]:
import random
import torch


class DynamicNet(torch.nn.Module):
    """Fully-connected ReLU network whose depth is re-drawn on every forward pass.

    The network owns exactly three linear layers. The middle one is applied a
    random number of times (0 to 3) per call, so its weights are shared across
    all of the inner hidden layers.
    """

    def __init__(self, D_in, H, D_out):
        """
        Build the three nn.Linear layers that ``forward`` combines.

        Args:
            D_in: number of input features.
            H: number of hidden features.
            D_out: number of output features.
        """
        super().__init__()
        make_linear = torch.nn.Linear
        self.input_linear = make_linear(D_in, H)
        self.middle_linear = make_linear(H, H)
        self.output_linear = make_linear(H, D_out)

    def forward(self, x):
        """
        Map ``x`` to a prediction through a randomly chosen number of hidden layers.

        The input layer is always applied once. After that, an integer in
        {0, 1, 2, 3} is drawn with ``random.randint`` and ``middle_linear`` is
        applied that many times; every hidden activation is clamped at zero (ReLU).

        Because the computation graph is rebuilt on each call, ordinary Python
        control flow can decide the architecture here, and the same Module can be
        applied repeatedly within a single graph.

        Returns:
            The output of ``output_linear`` applied to the last hidden activation.
        """
        hidden = torch.clamp(self.input_linear(x), min=0)

        # Draw the number of shared-weight layers only after the input layer has
        # run, then count it down.
        remaining = random.randint(0, 3)
        while remaining > 0:
            hidden = torch.clamp(self.middle_linear(hidden), min=0)
            remaining -= 1

        return self.output_linear(hidden)


# Seed both random number generators so that Restart & Run All reproduces the
# same run: Python's `random` picks the network depth inside DynamicNet.forward
# on every step, and torch generates the data and the initial weights.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Construct our model by instantiating the class defined above
model = DynamicNet(D_in, H, D_out)

# Construct our loss function and an Optimizer. Training this strange model with
# vanilla stochastic gradient descent is tough, so we use momentum: the optimizer
# accumulates a velocity from past gradients, which smooths the noisy updates
# caused by the depth changing from one step to the next.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
for t in range(500):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute and print loss
    loss = criterion(y_pred, y)
    print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    # Gradients accumulate across backward() calls, so they must be cleared first.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 632.0070190429688
1 666.876953125
2 638.161865234375
3 627.4966430664062
4 540.6419677734375
5 625.9701538085938
6 624.4487915039062
7 404.8503112792969
8 360.1130676269531
9 309.9834899902344
10 619.8305053710938
11 617.7686767578125
12 187.36077880859375
13 619.4473266601562
14 572.4185180664062
15 560.6989135742188
16 538.2614135742188
17 615.2755737304688
18 478.0504150390625
19 91.02987670898438
20 607.3973999023438
21 86.8807601928711
22 599.2420654296875
23 593.2177734375
24 541.4122314453125
25 524.780029296875
26 298.0304870605469
27 274.3638610839844
28 451.6155700683594
29 420.3275146484375
30 474.326904296875
31 193.68658447265625
32 415.274169921875
33 304.5037841796875
34 279.4242248535156
35 328.88238525390625
36 199.12786865234375
37 211.66763305664062
38 162.36727905273438
39 164.17698669433594
40 220.40512084960938
41 192.2908477783203
42 318.90118408203125
43 120.50137329101562
44 150.4662322998047
45 110.84688568115234
46 141.32614135742188
47 106.79881286621094
4

404 0.10278981178998947
405 0.5079350471496582
406 0.2549491226673126
407 0.135505810379982
408 0.8293576240539551
409 0.7672693133354187
410 0.11015011370182037
411 0.6788011789321899
412 0.3940145671367645
413 0.050739821046590805
414 0.5166071653366089
415 0.566130518913269
416 0.043732915073633194
417 0.37018921971321106
418 0.04540588706731796
419 0.47056618332862854
420 0.045646168291568756
421 0.6915830969810486
422 0.04357588291168213
423 0.4896980822086334
424 0.4213946461677551
425 0.36572784185409546
426 0.12923529744148254
427 0.5229228734970093
428 0.770810604095459
429 0.4281606078147888
430 0.38533082604408264
431 0.11228673905134201
432 0.12014661729335785
433 0.5066710114479065
434 0.7154948711395264
435 0.08985961973667145
436 0.36750471591949463
437 1.1001806259155273
438 0.4051971733570099
439 0.5390541553497314
440 0.5732767581939697
441 0.5922435522079468
442 0.3235422968864441
443 0.44330713152885437
444 0.7471727132797241
445 0.059689927846193314
446 0.424282222