In [1]:
import random
import torch

%matplotlib inline


PyTorch: Control Flow + Weight Sharing
--------------------------------------

To showcase the power of PyTorch's dynamic graphs, we will implement a very strange
model: a fully-connected ReLU network that on each forward pass randomly chooses
a number between 1 and 4 and has that many hidden layers (the input layer followed
by 0 to 3 applications of a shared middle layer), reusing the same
weights multiple times to compute the innermost hidden layers.



In [2]:
class DynamicNet(torch.nn.Module):
    """
    Fully-connected ReLU network whose depth is re-drawn on every forward pass.

    The network owns three linear layers: an input layer, a single middle
    layer that may be applied several times in a row (weight sharing),
    and an output layer.
    """

    def _forward_unimplemented(self, *input) -> None:
        """
        Always raise NotImplementedError.

        Raises:
            NotImplementedError: unconditionally, with the message
                "forward_unimplemented".
        """
        # NotImplemented is the sentinel returned by binary special methods,
        # not an exception class; calling it fails with TypeError.
        # NotImplementedError is the exception meant for this purpose.
        raise NotImplementedError("forward_unimplemented")

    def __init__(self, d_in, h, d_out):
        """
        In the constructor we construct three nn.Linear instances
        that we will use in the forward pass.

        d_in: dimension of input
        h: dimension of hidden
        d_out: dimension of output
        """
        super().__init__()
        self.input_linear = torch.nn.Linear(d_in, h)
        self.middle_linear = torch.nn.Linear(h, h)
        self.output_linear = torch.nn.Linear(h, d_out)

    def forward(self, x):
        """
        For the forward pass of the model,
        we randomly choose either 0, 1, 2, or 3
        and reuse the middle_linear Module
        that many times to compute hidden layer representations.

        Since each forward pass builds a dynamic computation graph,
        we can use normal Python control-flow operators
        like loops or conditional statements
        when defining the forward pass of the model.

        Here we also see that it is perfectly safe
        to reuse the same Module many times
        when defining a computational graph.
        This is a big improvement from Lua Torch,
        where each Module could be used only once.

        x: input tensor whose last dimension is d_in
        Returns the prediction tensor, whose last dimension is d_out.
        """
        # clamp(min=0) is the ReLU non-linearity.
        h_relu = torch.clamp(self.input_linear(x), min=0)
        # randint is inclusive on both ends, so the shared middle layer
        # is applied between 0 and 3 times.
        for _ in range(random.randint(0, 3)):
            h_relu = torch.clamp(self.middle_linear(h_relu), min=0)
        y_pred = self.output_linear(h_relu)
        return y_pred

In [3]:
# Problem dimensions: batch size N, input width D_in,
# hidden width H, output width D_out.
N, D_in, H, D_out = 64, 1000, 100, 10

# Random inputs and random regression targets.
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

model = DynamicNet(D_in, H, D_out)

# Mean-squared-error objective, minimized by SGD with momentum.
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

for t in range(501):
    # Clear gradients left over from the previous iteration
    # before this iteration's backward pass accumulates new ones.
    optimizer.zero_grad()

    # Forward pass: the model picks its own depth on each call.
    y_pred = model(x)
    loss = criterion(y_pred, y)

    # Report the loss on every 100th iteration.
    if not t % 100:
        print(t, loss.item())

    # Backward pass and parameter update.
    loss.backward()
    optimizer.step()

0 0.9826545715332031
100 0.016170386224985123
200 0.0018488296773284674
300 0.0009112719562835991
400 0.07059653103351593
500 0.0008878077496774495
