
PyTorch: Control Flow + Weight Sharing
--------------------------------------

To showcase the power of PyTorch's dynamic graphs, we will implement a very strange
model: <p style="color:red">a fully-connected ReLU network that, on each forward pass, randomly chooses
a number between 1 and 4 and uses that many hidden layers,</p> reusing the same weights multiple times to compute the innermost hidden layers.
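
The count of 1 to 4 hidden layers follows from how the forward pass in the code below is written: one hidden layer always comes from the input layer, and `random.randint(0, 3)`, which is inclusive at both ends, adds between 0 and 3 more. Here is a minimal sketch of that bookkeeping using only the standard library. The helper name `sample_hidden_layers` is hypothetical and not part of the tutorial.

```python
import random
from collections import Counter


def sample_hidden_layers(rng):
    """Return the number of hidden layers for one simulated forward pass."""
    # One hidden layer is always produced by the input layer.
    always = 1
    # randint is inclusive on both ends, so this yields 0, 1, 2, or 3.
    extra = rng.randint(0, 3)
    return always + extra


rng = random.Random(0)  # seeded only so the sketch is repeatable
counts = Counter(sample_hidden_layers(rng) for _ in range(10_000))
print(sorted(counts))  # the distinct depths that were drawn
```

Each depth should turn up roughly a quarter of the time, since `randint` draws uniformly.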



In [None]:
%matplotlib inline

<h1 style="background-image: linear-gradient( 135deg, #ABDCFF 10%, #0396FF 100%);"> Original Tutorial code </h1>

In [1]:
import random
import torch


class DynamicNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        """
        In the constructor we construct three nn.Linear instances that we will use
        in the forward pass.
        """
        super().__init__()
        self.input_linear = torch.nn.Linear(D_in, H)
        self.middle_linear = torch.nn.Linear(H, H)
        self.output_linear = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        For the forward pass of the model, we randomly choose either 0, 1, 2, or 3
        and reuse the middle_linear Module that many times to compute hidden layer
        representations. Together with the layer produced by input_linear, this
        gives between 1 and 4 hidden layers.

        Since each forward pass builds a dynamic computation graph, we can use normal
        Python control-flow operators like loops or conditional statements when
        defining the forward pass of the model.

        Here we also see that it is perfectly safe to reuse the same Module many
        times when defining a computational graph. This is a big improvement over Lua
        Torch, where each Module could be used only once.
        """
        h_relu = self.input_linear(x).clamp(min=0)
        for _ in range(random.randint(0, 3)):
            h_relu = self.middle_linear(h_relu).clamp(min=0)
        y_pred = self.output_linear(h_relu)
        return y_pred


# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs. Since PyTorch 0.4, plain
# Tensors work with autograd directly, so no Variable wrapper is needed.
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Construct our model by instantiating the class defined above
model = DynamicNet(D_in, H, D_out)

# Construct our loss function and an Optimizer. Training this strange model with
# vanilla stochastic gradient descent is tough, so we use momentum
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
for t in range(500):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute and print loss
    loss = criterion(y_pred, y)
    print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 644.5894775390625
1 639.4033813476562
2 638.3150024414062
3 670.4744262695312
4 635.1463012695312
5 587.4576416015625
6 634.0379638671875
7 630.0640258789062
8 439.10614013671875
9 627.6533813476562
10 626.1864013671875
11 624.9271240234375
12 301.4223327636719
13 622.37158203125
14 617.2005004882812
15 601.5762939453125
16 592.2183837890625
17 605.8775634765625
18 599.1980590820312
19 539.9949951171875
20 580.2672119140625
21 567.6670532226562
22 595.6627197265625
23 536.1651000976562
24 144.83700561523438
25 133.03121948242188
26 482.430419921875
27 99.84098815917969
28 364.9731140136719
29 515.7755737304688
30 396.20147705078125
31 366.62249755859375
32 333.6767272949219
33 401.99188232421875
34 240.6501922607422
35 133.25538635253906
36 335.3799743652344
37 296.30657958984375
38 110.04197692871094
39 217.27403259277344
40 74.6769027709961
41 167.05450439453125
42 147.43907165527344
43 125.50924682617188
44 242.14285278320312
45 109.07363891601562
46 252.27027893066406
47 90.68605

418 1.1911224126815796
419 0.8813117146492004
420 1.302135944366455
421 0.7241997122764587
422 8.525102615356445
423 15.13372802734375
424 2.0926566123962402
425 8.081398963928223
426 16.25365447998047
427 13.61452865600586
428 6.380634784698486
429 3.8261148929595947
430 3.3459224700927734
431 6.9500041007995605
432 11.443480491638184
433 7.508928298950195
434 3.414703845977783
435 1.4115219116210938
436 1.6969949007034302
437 1.3115109205245972
438 2.016249179840088
439 30.04707145690918
440 15.225565910339355
441 2.0710134506225586
442 3.3934998512268066
443 17.490060806274414
444 17.14543914794922
445 13.245141983032227
446 7.75592565536499
447 2.808720588684082
448 52.353885650634766
449 2.649913787841797
450 1.0520647764205933
451 15.007600784301758
452 2.204399585723877
453 5.307397365570068
454 9.032097816467285
455 8.293466567993164
456 0.9743533730506897
457 1.6548889875411987
458 1.7362059354782104
459 2.461306095123291
460 3.49342679977417
461 6.878848552703857
462 1.878756
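
The loss above does not decrease smoothly. One likely contributor is that every step samples a new depth, so consecutive losses come from networks of different depth that share one set of weights. The sketch below illustrates this under stated assumptions: `FixedDepthNet` and its `n_middle` argument are hypothetical and not part of the tutorial; they only make the depth explicit instead of random. It also shows that the parameter count does not depend on the depth, because `middle_linear` is reused rather than duplicated.

```python
import torch


class FixedDepthNet(torch.nn.Module):
    """Hypothetical variant of DynamicNet with the depth passed in explicitly."""

    def __init__(self, D_in, H, D_out):
        super().__init__()
        self.input_linear = torch.nn.Linear(D_in, H)
        self.middle_linear = torch.nn.Linear(H, H)
        self.output_linear = torch.nn.Linear(H, D_out)

    def forward(self, x, n_middle):
        h_relu = self.input_linear(x).clamp(min=0)
        # The same middle_linear weights are applied n_middle times.
        for _ in range(n_middle):
            h_relu = self.middle_linear(h_relu).clamp(min=0)
        return self.output_linear(h_relu)


torch.manual_seed(0)  # seeded only so the sketch is repeatable
N, D_in, H, D_out = 64, 1000, 100, 10
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

model = FixedDepthNet(D_in, H, D_out)
criterion = torch.nn.MSELoss(reduction='sum')

# The parameter count is fixed by the three Linear layers, whatever the depth.
n_params = sum(p.numel() for p in model.parameters())
print("parameters:", n_params)

# Same weights, same data, four depths: evaluate the loss for each one.
with torch.no_grad():
    losses = {k: criterion(model(x, k), y).item() for k in range(4)}
for k, value in losses.items():
    print(f"{k} middle layers -> loss {value:.2f}")
```

With untrained weights the four losses generally differ, which is consistent with the jumps in the training log.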