In [0]:
# Install Pytorch.
from os import path
from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag
platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())

accelerator = 'cu80' if path.exists('/opt/bin/nvidia-smi') else 'cpu'

!pip install -q http://download.pytorch.org/whl/{accelerator}/torch-0.3.0.post4-{platform}-linux_x86_64.whl torchvision

In [0]:
%matplotlib inline


PyTorch: Control Flow + Weight Sharing
--------------------------------------

To showcase the power of PyTorch dynamic graphs, we will implement a very strange
model: a fully-connected ReLU network that on each forward pass randomly chooses
a number between 1 and 4 and has that many hidden layers, reusing the same
weights multiple times to compute the innermost hidden layers.



In [0]:
import random
import torch
from torch.autograd import Variable

class DynamicNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        """
        In the constructor we construct three nn.Linear instances that we will 
        use in the forward pass.
        """
        super(DynamicNet, self).__init__()
        self.input_linear = torch.nn.Linear(D_in, H)
        self.middle_linear = torch.nn.Linear(H, H)
        self.output_linear = torch.nn.Linear(H, D_out)
        
    def forward(self, x):
        """
        For the forward pass of the model, we randomly choose either 0, 1, 2 or 
        3 and reuse the middle_linear Module that many times to compute hidden 
        layer representation.
        
        Since each forward pass builds a dynamic computation graph, we can use 
        normal Python control-flow operations like loops or conditional 
        statements when defining the forward pass of the model.
        
        Here we also see that it is prefectly safe to reuse the same Module 
        many times when defining a computational graph. This is a big 
        improvement from Lua Torch, where each Module could be used only once.
        """
        h_relu = self.input_linear(x).clamp(min=0)
        for _ in range(random.randint(0, 3)):
            h_relu = self.middle_linear(h_relu).clamp(min=0)
        y_pred = self.output_linear(h_relu)
        return y_pred

In [6]:
# N: batch size, D_in: input dim, H: hidden dim, D_out: output dim
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs, and wrap them in Variables.
x = Variable(torch.randn(N, D_in))
y = Variable(torch.randn(N, D_out), requires_grad=False)

# Construct our model by instantiating the class defined above.
model = DynamicNet(D_in, H, D_out)

# Construct our loss function and an Optimizer. Training this strange model 
# with vanilla stochastic gradient descent is tough, so we use momentum.
criterion = torch.nn.MSELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)

for t in range(500):
    # Forward pass: Compute predicted y by passing x to the model.
    y_pred = model(x)
    
    # Compute and print loss.
    loss = criterion(y_pred, y)
    print(t, loss.data[0])
    
    # Zero gradients, perform a backward pass, and update the weights.
    otpimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 611.2299194335938
1 605.0342407226562
2 589.7377319335938
3 657.3991088867188
4 600.1734619140625
5 535.580078125
6 437.531005859375
7 600.4235229492188
8 535.2461547851562
9 624.2636108398438
10 802.2166748046875
11 623.890625
12 589.0665893554688
13 580.9469604492188
14 5807.146484375
15 4858.7958984375
16 2374.6298828125
17 5670.1064453125
18 719.2976684570312
19 626.1219482421875
20 611.5909423828125
21 613.2060546875
22 323168.5625
23 664.7927856445312
24 990.290771484375
25 1944.661376953125
26 69392176.0
27 7054.572265625
28 1010197248.0
29 2.63035888756313e+21
30 inf
31 nan
32 nan
33 nan
34 nan
35 nan
36 nan
37 nan
38 nan
39 nan
40 nan
41 nan
42 nan
43 nan
44 nan
45 nan
46 nan
47 nan
48 nan
49 nan
50 nan
51 nan
52 nan
53 nan
54 nan
55 nan
56 nan
57 nan
58 nan
59 nan
60 nan
61 nan
62 nan
63 nan
64 nan
65 nan
66 nan
67 nan
68 nan
69 nan
70 nan
71 nan
72 nan
73 nan
74 nan
75 nan
76 nan
77 nan
78 nan
79 nan
80 nan
81 nan
82 nan
83 nan
84 nan
85 nan
86 nan
87 nan
88 nan
89 nan
90 