In [None]:
%matplotlib inline


PyTorch: Custom nn Modules
--------------------------

A fully-connected ReLU network with one hidden layer, trained to predict y from x
by minimizing squared Euclidean distance.

This implementation defines the model as a custom Module subclass. Whenever you
want a model more complex than a simple sequence of existing Modules you will
need to define your model this way.



In [1]:
import torch


class TwoLayerNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        """
        In the constructor we instantiate two nn.Linear modules and assign them as
        member variables.
        """
        super(TwoLayerNet, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        In the forward function we accept a Tensor of input data and we must return
        a Tensor of output data. We can use Modules defined in the constructor as
        well as arbitrary operators on Tensors.
        """
        h_relu = self.linear1(x).clamp(min=0)
        y_pred = self.linear2(h_relu)
        return y_pred


# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Construct our model by instantiating the class defined above
model = TwoLayerNet(D_in, H, D_out)

# Construct our loss function and an Optimizer. The call to model.parameters()
# in the SGD constructor will contain the learnable parameters of the two
# nn.Linear modules which are members of the model.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
for t in range(500):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute and print loss
    loss = criterion(y_pred, y)
    print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

(0, 715.3890380859375)
(1, 660.4926147460938)
(2, 612.8158569335938)
(3, 571.4666748046875)
(4, 534.936767578125)
(5, 501.9896240234375)
(6, 471.7643737792969)
(7, 444.1095886230469)
(8, 418.4523620605469)
(9, 394.58624267578125)
(10, 372.15325927734375)
(11, 351.12353515625)
(12, 331.174072265625)
(13, 312.2660217285156)
(14, 294.3206787109375)
(15, 277.27886962890625)
(16, 261.1218566894531)
(17, 245.73541259765625)
(18, 231.10032653808594)
(19, 217.25906372070312)
(20, 204.12539672851562)
(21, 191.71636962890625)
(22, 179.95968627929688)
(23, 168.8275909423828)
(24, 158.31483459472656)
(25, 148.42327880859375)
(26, 139.1171417236328)
(27, 130.35455322265625)
(28, 122.10467529296875)
(29, 114.35877990722656)
(30, 107.08446502685547)
(31, 100.28746032714844)
(32, 93.9335708618164)
(33, 87.98413848876953)
(34, 82.42335510253906)
(35, 77.20706176757812)
(36, 72.32859802246094)
(37, 67.75935363769531)
(38, 63.49537658691406)
(39, 59.525611877441406)
(40, 55.82455062866211)
(41, 52.372299

(408, 6.725919229211286e-05)
(409, 6.524274795083329e-05)
(410, 6.328708695946261e-05)
(411, 6.139417382655665e-05)
(412, 5.9555059124249965e-05)
(413, 5.777029218734242e-05)
(414, 5.604069519904442e-05)
(415, 5.4359654313884676e-05)
(416, 5.273879651213065e-05)
(417, 5.1155446271877736e-05)
(418, 4.962425009580329e-05)
(419, 4.814197382074781e-05)
(420, 4.669983536587097e-05)
(421, 4.5303593651624396e-05)
(422, 4.39482246292755e-05)
(423, 4.26323531428352e-05)
(424, 4.135877679800615e-05)
(425, 4.012539648101665e-05)
(426, 3.8922149542486295e-05)
(427, 3.7759287806693465e-05)
(428, 3.663180905277841e-05)
(429, 3.553661372279748e-05)
(430, 3.4472683182684705e-05)
(431, 3.344576180097647e-05)
(432, 3.2447424018755555e-05)
(433, 3.148192627122626e-05)
(434, 3.053904583794065e-05)
(435, 2.9626677132910118e-05)
(436, 2.8743117582052946e-05)
(437, 2.7886255338671617e-05)
(438, 2.705230326682795e-05)
(439, 2.6247471396345645e-05)
(440, 2.5463654310442507e-05)
(441, 2.470429717504885e-05)
(44