In [1]:
%matplotlib inline


PyTorch: Custom nn Modules
--------------------------

A fully-connected ReLU network with one hidden layer, trained to predict y from x
by minimizing squared Euclidean distance.

This implementation defines the model as a custom Module subclass. Whenever you
want a model more complex than a simple sequence of existing Modules, you will
need to define your model this way.



In [2]:
import torch


class TwoLayerNet(torch.nn.Module):
    """Fully-connected network: Linear -> clamp-at-zero (ReLU) -> Linear."""

    def __init__(self, D_in, H, D_out):
        """
        Create the two nn.Linear layers and store them as attributes of the
        module.

        Args:
            D_in: number of input features fed to the first Linear layer.
            H: number of hidden units (output of the first layer, input of
                the second).
            D_out: number of output features produced by the second layer.
        """
        super().__init__()
        # Keep this construction order: each Linear draws its initial weights
        # from the global RNG when it is created.
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        Compute the network output for a Tensor of inputs.

        The input goes through the first Linear layer, negative activations
        are clamped to zero, and the result goes through the second Linear
        layer.

        Args:
            x: input Tensor accepted by ``self.linear1``.

        Returns:
            The Tensor produced by ``self.linear2``.
        """
        hidden = self.linear1(x)
        # Clamping at a minimum of zero is the ReLU non-linearity.
        activated = torch.clamp(hidden, min=0)
        return self.linear2(activated)


# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Training hyper-parameters, named so they can be tuned in one place.
SEED = 0              # seed for the global PyTorch RNG
LEARNING_RATE = 1e-4  # SGD step size
NUM_STEPS = 500       # number of full-batch gradient steps

# Seed the global RNG before anything random is drawn. Both the data below
# and the initial weights of the nn.Linear layers come from this RNG, so
# seeding makes the printed loss curve repeatable across Restart & Run All
# (on the same PyTorch build and device).
torch.manual_seed(SEED)

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Construct our model by instantiating the class defined above
model = TwoLayerNet(D_in, H, D_out)

# Construct our loss function and an Optimizer. The call to model.parameters()
# in the SGD constructor will contain the learnable parameters of the two
# nn.Linear modules which are members of the model.
# reduction='sum' adds up the squared errors over all elements instead of
# averaging them, i.e. the loss is the squared Euclidean distance.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)
for t in range(NUM_STEPS):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute and print loss
    loss = criterion(y_pred, y)
    print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    # Gradients accumulate across backward() calls, so they must be cleared
    # before each new backward pass.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 695.6751708984375
1 646.6774291992188
2 604.6642456054688
3 567.6747436523438
4 534.6677856445312
5 504.7969970703125
6 477.57464599609375
7 452.2152099609375
8 428.53265380859375
9 406.2913513183594
10 385.5321350097656
11 366.0522766113281
12 347.6016540527344
13 330.047607421875
14 313.3310546875
15 297.4138488769531
16 282.1618957519531
17 267.3375244140625
18 253.1653289794922
19 239.61422729492188
20 226.6613311767578
21 214.24618530273438
22 202.36819458007812
23 191.0445098876953
24 180.2359161376953
25 169.89537048339844
26 160.0499725341797
27 150.6020965576172
28 141.6372833251953
29 133.14633178710938
30 125.10774230957031
31 117.48148345947266
32 110.2770004272461
33 103.47874450683594
34 97.08332061767578
35 91.05171966552734
36 85.38423156738281
37 80.06195068359375
38 75.0705337524414
39 70.38928985595703
40 65.99249267578125
41 61.8761100769043
42 58.02488708496094
43 54.42625427246094
44 51.065330505371094
45 47.921329498291016
46 44.984230041503906
47 42.2364540100

399 3.123076385236345e-05
400 3.0291414077510126e-05
401 2.9382754291873425e-05
402 2.8502812710939907e-05
403 2.7649935873341747e-05
404 2.682219201233238e-05
405 2.6019482902484015e-05
406 2.524228511902038e-05
407 2.4488987037329935e-05
408 2.375877375015989e-05
409 2.3049065930536017e-05
410 2.2364731194102205e-05
411 2.1701531295548193e-05
412 2.1057199774077162e-05
413 2.0431876691873185e-05
414 1.9824223272735253e-05
415 1.9237295418861322e-05
416 1.8667933545657434e-05
417 1.8115664715878665e-05
418 1.7580881831236184e-05
419 1.7061689504771493e-05
420 1.655934647715185e-05
421 1.606955447641667e-05
422 1.5596539014950395e-05
423 1.5138444723561406e-05
424 1.4693932826048695e-05
425 1.4262114746088628e-05
426 1.384341339871753e-05
427 1.3437865163723473e-05
428 1.3043837498116773e-05
429 1.2662081644521095e-05
430 1.2293076906644274e-05
431 1.1932302186323795e-05
432 1.1584532330743968e-05
433 1.1246196663705632e-05
434 1.0919346095761284e-05
435 1.0601158464851324e-05
436 1.02