In [1]:
%matplotlib inline


PyTorch: Control Flow + Weight Sharing
--------------------------------------

To showcase the power of PyTorch dynamic graphs, we will implement a very strange
model: a fully-connected ReLU network that on each forward pass randomly chooses
a number between 1 and 4 and has that many hidden layers, reusing the same
weights multiple times to compute the innermost hidden layers.



In [2]:
import random
import torch


class DynamicNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        """
        In the constructor we construct three nn.Linear instances that we will use
        in the forward pass.
        """
        super(DynamicNet, self).__init__()
        self.input_linear = torch.nn.Linear(D_in, H)
        self.middle_linear = torch.nn.Linear(H, H)
        self.output_linear = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        For the forward pass of the model, we randomly choose either 0, 1, 2, or 3
        and reuse the middle_linear Module that many times to compute hidden layer
        representations.

        Since each forward pass builds a dynamic computation graph, we can use normal
        Python control-flow operators like loops or conditional statements when
        defining the forward pass of the model.

        Here we also see that it is perfectly safe to reuse the same Module many
        times when defining a computational graph. This is a big improvement from Lua
        Torch, where each Module could be used only once.
        """
        h_relu = self.input_linear(x).clamp(min=0)
        for _ in range(random.randint(0, 3)):
            h_relu = self.middle_linear(h_relu).clamp(min=0)
        y_pred = self.output_linear(h_relu)
        return y_pred


# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Construct our model by instantiating the class defined above
model = DynamicNet(D_in, H, D_out)

# Construct our loss function and an Optimizer. Training this strange model with
# vanilla stochastic gradient descent is tough, so we use momentum
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
for t in range(500):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute and print loss
    loss = criterion(y_pred, y)
    print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 644.6627807617188
1 643.5653076171875
2 688.9827270507812
3 639.3114624023438
4 634.9447021484375
5 633.6138916015625
6 635.6046752929688
7 627.9896240234375
8 625.3439331054688
9 455.9207458496094
10 606.3209838867188
11 596.0707397460938
12 609.029541015625
13 617.5655517578125
14 340.4871826171875
15 615.0897216796875
16 287.83599853515625
17 251.9388885498047
18 611.0119018554688
19 518.3916625976562
20 501.4783020019531
21 580.9364624023438
22 116.7131576538086
23 560.9356079101562
24 93.58380889892578
25 531.449951171875
26 75.97832489013672
27 559.9642944335938
28 543.8159790039062
29 521.3397216796875
30 302.643798828125
31 393.863037109375
32 433.0780944824219
33 104.15827178955078
34 318.39215087890625
35 102.70698547363281
36 343.8573913574219
37 208.93682861328125
38 189.76467895507812
39 224.5198516845703
40 203.00758361816406
41 136.34304809570312
42 118.95207977294922
43 98.97998046875
44 84.15271759033203
45 206.9392547607422
46 215.49209594726562
47 166.5280303955078

375 5.517325401306152
376 0.31607767939567566
377 1.8967866897583008
378 1.2167118787765503
379 0.8606668710708618
380 3.7295093536376953
381 1.2953513860702515
382 1.4246147871017456
383 3.4370598793029785
384 0.4661564826965332
385 2.3090455532073975
386 0.14436478912830353
387 0.22734835743904114
388 1.7997506856918335
389 1.5710887908935547
390 0.7745603322982788
391 1.3040941953659058
392 0.5360349416732788
393 2.6828181743621826
394 2.989166259765625
395 2.292545795440674
396 1.1375465393066406
397 0.6958595514297485
398 1.0297913551330566
399 3.6965930461883545
400 0.429613322019577
401 3.3196282386779785
402 0.7028681635856628
403 2.3800013065338135
404 4.218087196350098
405 2.0438485145568848
406 0.7637020945549011
407 0.3728456497192383
408 2.0448098182678223
409 1.4178811311721802
410 0.1351921111345291
411 1.9018195867538452
412 0.8187096118927002
413 0.5728773474693298
414 0.950486958026886
415 1.8198342323303223
416 0.9922348260879517
417 0.9671313166618347
418 0.55240780