In [None]:
%matplotlib inline


PyTorch: Custom nn Modules
--------------------------

A fully-connected ReLU network with one hidden layer, trained to predict y from x
by minimizing squared Euclidean distance.

This implementation defines the model as a custom subclass of torch.nn.Module.
Whenever you want a model more complex than a simple sequence of existing
Modules, you will need to define your model this way.



In [1]:
import torch


class TwoLayerNet(torch.nn.Module):
    """Two stacked nn.Linear layers with a ReLU (clamp at zero) in between."""

    def __init__(self, D_in, H, D_out):
        """
        Create the two nn.Linear sub-modules and store them as attributes so
        that their parameters are registered with this Module.

        D_in and H are the in/out feature sizes of the first layer; H and
        D_out are the in/out feature sizes of the second layer.
        """
        super().__init__()
        # Keep this construction order: each nn.Linear initialises its own
        # parameters when it is created.
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        Map an input Tensor to an output Tensor.

        The input goes through the first linear layer, negative activations
        are clamped to zero, and the result goes through the second linear
        layer. Any Tensor operators may be mixed with the Modules created in
        the constructor.
        """
        hidden = self.linear1(x)
        activated = torch.clamp(hidden, min=0)
        return self.linear2(activated)


# Problem sizes: N samples per batch, D_in input features,
# H hidden units, D_out output features.
N = 64
D_in = 1000
H = 100
D_out = 10

# Random input batch and random regression targets (x is drawn before y).
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Instantiate the network defined above.
model = TwoLayerNet(D_in, H, D_out)

# Objective: squared error summed over all elements. Optimizer: plain SGD over
# everything model.parameters() yields, i.e. the learnable parameters of the
# two nn.Linear members of the model.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)

for t in range(500):
    # Clear the gradients accumulated by the previous iteration before
    # computing new ones.
    optimizer.zero_grad()

    # Forward pass, then report the current loss.
    y_pred = model(x)
    loss = criterion(y_pred, y)
    print(t, loss.item())

    # Backward pass followed by one SGD update of the weights.
    loss.backward()
    optimizer.step()

0 684.2272338867188
1 634.1856689453125
2 590.7869873046875
3 552.588134765625
4 518.6983032226562
5 487.9130554199219
6 460.0026550292969
7 434.5419616699219
8 410.87799072265625
9 388.839599609375
10 368.19500732421875
11 348.62908935546875
12 330.1213684082031
13 312.5187683105469
14 295.8311462402344
15 280.01043701171875
16 265.0495300292969
17 250.75555419921875
18 237.0924072265625
19 224.05728149414062
20 211.62413024902344
21 199.82315063476562
22 188.5401153564453
23 177.83114624023438
24 167.6607666015625
25 158.0272979736328
26 148.86019897460938
27 140.16738891601562
28 131.95082092285156
29 124.18256378173828
30 116.8108139038086
31 109.85496520996094
32 103.27379608154297
33 97.06419372558594
34 91.2105484008789
35 85.70203399658203
36 80.5042953491211
37 75.62415313720703
38 71.0205078125
39 66.69683074951172
40 62.632568359375
41 58.81942367553711
42 55.238826751708984
43 51.88187789916992
44 48.73298645019531
45 45.78246307373047
46 43.01930236816406
47 40.43046951293