In [1]:
%matplotlib inline


PyTorch: Control Flow + Weight Sharing
--------------------------------------

To showcase the power of PyTorch dynamic graphs, we will implement a very strange
model: a fully-connected ReLU network that on each forward pass randomly chooses
a number between 1 and 4 and has that many hidden layers, reusing the same
weights multiple times to compute the innermost hidden layers.



In [25]:
import random
import torch

In [38]:
class DynamicNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        super(DynamicNet,self).__init__()
        self.input_linear = torch.nn.Linear(D_in, H)
        self.middle_linear = torch.nn.Linear(H,H)
        self.output_linear = torch.nn.Linear(H, D_out)
    
    def forward(self, x):
        h_relu = self.input_linear(x).clamp(min =0)
        for _ in range(random.randint(0,13)):
            h_relu = self.middle_linear(h_relu).clamp(min=0)
        y_pred = self.output_linear(h_relu)
        return y_pred
        

In [39]:
N, D_in , H, D_out = 64,1000,100,10
x = torch.randn(N,D_in)
y = torch.randn(N, D_out)
model = DynamicNet(D_in, H , D_out)
criterion = torch.nn.MSELoss(size_average = False)
optimizer = torch.optim.SGD(model.parameters(), lr  = 1e-4, momentum =0.9)



In [48]:
for t in range(500):
    y_pred = model(x)
    loss = criterion(y_pred, y)
    
    print(t, loss.item())
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    

0 0.041949477046728134
1 0.05507908761501312
2 0.11131494492292404
3 0.046538095921278
4 0.48481667041778564
5 0.03988025709986687
6 0.10894663631916046
7 0.02434893697500229
8 0.046595241874456406
9 0.03672238811850548
10 0.04745016619563103
11 0.02112298272550106
12 0.280609667301178
13 0.026688670739531517
14 0.024086952209472656
15 0.46437764167785645
16 0.09547724574804306
17 0.44588950276374817
18 0.03642099350690842
19 0.024704398587346077
20 0.1091727614402771
21 0.23184779286384583
22 0.2773444652557373
23 0.04407142475247383
24 0.10099394619464874
25 0.03863563388586044
26 0.0281369611620903
27 0.20964214205741882
28 0.02950955368578434
29 0.042603522539138794
30 0.03304529935121536
31 0.18336796760559082
32 0.045622751116752625
33 0.02935270592570305
34 0.02099362201988697
35 0.02140446938574314
36 0.0550759956240654
37 0.04333793371915817
38 0.03885185346007347
39 0.032527148723602295
40 0.0194929838180542
41 0.03881419077515602
42 0.15713541209697723
43 0.05569931492209434

351 0.015394452027976513
352 0.24853353202342987
353 0.016587959602475166
354 0.012807020917534828
355 0.07926558703184128
356 0.01142523530870676
357 0.025613483041524887
358 0.024297932162880898
359 0.04439034312963486
360 0.0228982362896204
361 0.02506488561630249
362 0.013569771312177181
363 0.009998390451073647
364 0.01428870391100645
365 0.014520736411213875
366 0.07248790562152863
367 0.23242008686065674
368 0.033387407660484314
369 0.2178194224834442
370 0.016191447153687477
371 0.01526546012610197
372 0.167498379945755
373 0.03032037988305092
374 0.013560508377850056
375 0.014999346807599068
376 0.014540543779730797
377 0.01584547571837902
378 0.03145025297999382
379 0.2019221931695938
380 0.14362752437591553
381 0.5283604264259338
382 0.0285023283213377
383 0.49382370710372925
384 0.02747708559036255
385 0.05380163714289665
386 0.0647129938006401
387 0.027709703892469406
388 0.22755838930606842
389 0.049277644604444504
390 0.037501346319913864
391 0.20944884419441223
392 0.02

In [33]:
import random
import torch


class DynamicNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        """
        In the constructor we construct three nn.Linear instances that we will use
        in the forward pass.
        """
        super(DynamicNet, self).__init__()
        self.input_linear = torch.nn.Linear(D_in, H)
        self.middle_linear = torch.nn.Linear(H, H)
        self.output_linear = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        For the forward pass of the model, we randomly choose either 0, 1, 2, or 3
        and reuse the middle_linear Module that many times to compute hidden layer
        representations.

        Since each forward pass builds a dynamic computation graph, we can use normal
        Python control-flow operators like loops or conditional statements when
        defining the forward pass of the model.

        Here we also see that it is perfectly safe to reuse the same Module many
        times when defining a computational graph. This is a big improvement from Lua
        Torch, where each Module could be used only once.
        """
        h_relu = self.input_linear(x).clamp(min=0)
        for _ in range(random.randint(0, 3)):
            h_relu = self.middle_linear(h_relu).clamp(min=0)
        y_pred = self.output_linear(h_relu)
        return y_pred


# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Construct our model by instantiating the class defined above
model = DynamicNet(D_in, H, D_out)

# Construct our loss function and an Optimizer. Training this strange model with
# vanilla stochastic gradient descent is tough, so we use momentum
criterion = torch.nn.MSELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
for t in range(500):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute and print loss
    loss = criterion(y_pred, y)
    print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 613.0685424804688
1 615.8286743164062
2 613.2557373046875
3 650.8713989257812
4 599.9246826171875
5 607.2288208007812
6 604.2871704101562
7 594.74658203125
8 380.8633117675781
9 337.76458740234375
10 574.0498657226562
11 254.71676635742188
12 215.2542724609375
13 547.0830688476562
14 596.153564453125
15 594.25634765625
16 105.63397979736328
17 486.4574890136719
18 462.6673889160156
19 430.0323181152344
20 575.714599609375
21 567.71044921875
22 324.9360656738281
23 290.7056579589844
24 255.46754455566406
25 512.6973266601562
26 252.57281494140625
27 179.08229064941406
28 388.6142272949219
29 360.3486022949219
30 323.6364440917969
31 381.949462890625
32 146.20822143554688
33 217.07272338867188
34 174.17236328125
35 198.99378967285156
36 254.5848846435547
37 66.42955780029297
38 181.45506286621094
39 44.743797302246094
40 190.1979522705078
41 45.063438415527344
42 153.6013946533203
43 85.84371948242188
44 73.45912170410156
45 179.5205078125
46 69.43209075927734
47 67.40245056152344
48 5

389 0.2760138511657715
390 0.3296429216861725
391 0.2844026982784271
392 0.3580690026283264
393 1.1755170822143555
394 0.6573696732521057
395 0.6575215458869934
396 0.27164921164512634
397 0.6314816474914551
398 0.9922917485237122
399 0.32082033157348633
400 0.4094148576259613
401 1.063256859779358
402 0.23239365220069885
403 0.9050524234771729
404 0.4101547300815582
405 0.401638925075531
406 0.39605221152305603
407 0.3675440847873688
408 0.7991337180137634
409 0.29436978697776794
410 0.27433863282203674
411 0.6941025257110596
412 0.24372470378875732
413 0.6698904633522034
414 0.7998530864715576
415 0.7635285258293152
416 0.25608134269714355
417 0.6880286931991577
418 0.2174615114927292
419 0.27769482135772705
420 0.28720515966415405
421 0.16207922995090485
422 0.2862741947174072
423 0.2719910442829132
424 0.6603026390075684
425 0.6199650764465332
426 0.20767149329185486
427 0.14101791381835938
428 0.23078055679798126
429 0.15196079015731812
430 0.5913381576538086
431 0.133661717176437