In [1]:
%matplotlib inline


PyTorch: Control Flow + Weight Sharing
--------------------------------------

To showcase the power of PyTorch's dynamic graphs, we will implement a very strange
model: a fully-connected ReLU network that on each forward pass randomly chooses
a number between 1 and 4 and has that many hidden layers (the input layer followed
by 0 to 3 applications of a shared middle layer), reusing the same weights multiple
times to compute the innermost hidden layers.



In [1]:
import random
import torch


class DynamicNet(torch.nn.Module):
    """
    Fully-connected ReLU network whose depth is re-sampled on every forward pass.

    The network always applies ``input_linear`` and ``output_linear``; in between
    it applies the single shared ``middle_linear`` layer a random number of times.
    """

    def __init__(self, D_in, H, D_out, max_middle_layers=3):
        """
        In the constructor we construct three nn.Linear instances that we will use
        in the forward pass.

        Args:
            D_in: size of each input sample.
            H: size of every hidden representation.
            D_out: size of each output sample.
            max_middle_layers: inclusive upper bound on how many times
                ``middle_linear`` is applied in one forward pass. Defaults to 3.

        Raises:
            ValueError: if ``max_middle_layers`` is negative.
        """
        super(DynamicNet, self).__init__()
        # Fail at construction time rather than on the first forward pass, where
        # random.randint would reject an empty range.
        if max_middle_layers < 0:
            raise ValueError(
                "max_middle_layers must be non-negative, got {}".format(max_middle_layers)
            )
        self.max_middle_layers = max_middle_layers
        self.input_linear = torch.nn.Linear(D_in, H)
        self.middle_linear = torch.nn.Linear(H, H)
        self.output_linear = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        For the forward pass of the model, we randomly choose an integer between 0
        and ``max_middle_layers`` (inclusive; 0, 1, 2, or 3 by default) and reuse
        the middle_linear Module that many times to compute hidden layer
        representations.

        Since each forward pass builds a dynamic computation graph, we can use normal
        Python control-flow operators like loops or conditional statements when
        defining the forward pass of the model.

        Here we also see that it is perfectly safe to reuse the same Module many
        times when defining a computational graph. This is a big improvement from Lua
        Torch, where each Module could be used only once.

        Args:
            x: input Tensor whose last dimension is ``D_in``.

        Returns:
            Tensor of predictions whose last dimension is ``D_out``.
        """
        # clamp(min=0) is the ReLU non-linearity.
        h_relu = self.input_linear(x).clamp(min=0)
        # The depth is drawn from Python's global `random` state, so seeding
        # `random` makes the sequence of depths reproducible.
        num_middle_layers = random.randint(0, self.max_middle_layers)
        for _ in range(num_middle_layers):
            # Same weights on every iteration: this is the weight sharing.
            h_relu = self.middle_linear(h_relu).clamp(min=0)
        y_pred = self.output_linear(h_relu)
        return y_pred


# Problem dimensions.
N = 64        # batch size
D_in = 1000   # input dimension
H = 100       # hidden dimension
D_out = 10    # output dimension

# Optimisation hyperparameters.
learning_rate = 1e-4
momentum = 0.9
num_steps = 500

# Random Tensors standing in for the inputs and the regression targets.
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Instantiate the dynamic-depth model defined above.
model = DynamicNet(D_in, H, D_out)

# Summed squared-error loss. Training this strange model with vanilla
# stochastic gradient descent is tough, so SGD is run with momentum.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=learning_rate,
    momentum=momentum,
)

for t in range(num_steps):
    # Forward pass: the depth of the network is re-sampled inside model(x).
    y_pred = model(x)

    # Measure the error for this step and report it.
    loss = criterion(y_pred, y)
    print(t, loss.item())

    # Clear stale gradients, backpropagate, then apply the parameter update.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 641.9730224609375
1 632.764892578125
2 625.1275634765625
3 610.7161254882812
4 629.3384399414062
5 577.951171875
6 626.5672607421875
7 629.0173950195312
8 624.083740234375
9 622.90869140625
10 520.5401611328125
11 620.6424560546875
12 500.1115417480469
13 675.5016479492188
14 476.09600830078125
15 616.5093383789062
16 617.3945922851562
17 613.6233520507812
18 425.2801513671875
19 500.6015930175781
20 452.28765869140625
21 604.3907470703125
22 331.84295654296875
23 586.7882690429688
24 355.1574401855469
25 193.2036590576172
26 321.7364807128906
27 573.7020874023438
28 121.79070281982422
29 505.0581359863281
30 480.1884765625
31 446.55865478515625
32 218.00439453125
33 204.83644104003906
34 187.24295043945312
35 402.8926086425781
36 288.7848815917969
37 233.28794860839844
38 240.8118896484375
39 293.09893798828125
40 263.2111511230469
41 234.69061279296875
42 176.88014221191406
43 141.3565216064453
44 92.90983581542969
45 169.4600067138672
46 112.80691528320312
47 183.60671997070312
48

421 0.9670897126197815
422 0.838800311088562
423 0.6047565937042236
424 0.9169514775276184
425 1.0084422826766968
426 0.5078537464141846
427 0.35471203923225403
428 0.29958903789520264
429 0.5651524662971497
430 0.2133977711200714
431 0.7011054754257202
432 0.08669000118970871
433 0.7021254301071167
434 0.6098252534866333
435 0.08633134514093399
436 0.7956994771957397
437 0.7396377921104431
438 0.41057732701301575
439 0.5790735483169556
440 0.4202149212360382
441 0.10675828158855438
442 0.6275072693824768
443 0.556876540184021
444 0.4558011293411255
445 0.47552794218063354
446 0.7978129386901855
447 0.6748902201652527
448 0.3440309762954712
449 0.6505711078643799
450 0.3404955267906189
451 0.5946183204650879
452 0.4593640863895416
453 0.11427736282348633
454 0.11978919059038162
455 0.11441303044557571
456 0.5777626037597656
457 0.7058702707290649
458 0.45843884348869324
459 0.08315662294626236
460 0.4365514814853668
461 0.755645215511322
462 0.5874375104904175
463 0.05661796033382416
4