In [1]:
%matplotlib inline


PyTorch: Custom nn Modules
--------------------------

A fully-connected ReLU network with one hidden layer, trained to predict y from x
by minimizing squared Euclidean distance.

This implementation defines the model as a custom ``Module`` subclass. Whenever you
want a model more complex than a simple sequence of existing Modules, you will
need to define your model this way.



In [7]:
import torch


class TwoLayerNet(torch.nn.Module):
    """Fully-connected network of the form Linear -> ReLU -> Linear.

    Args:
        D_in: number of input features passed to the first nn.Linear.
        H: number of hidden units (output of the first layer, input of the second).
        D_out: number of output features produced by the second nn.Linear.
    """

    def __init__(self, D_in, H, D_out):
        """Build the two nn.Linear layers and store them as attributes."""
        super().__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """Run the input Tensor ``x`` through both layers and return the prediction."""
        hidden = self.linear1(x)
        # Clamping at zero discards negative activations, i.e. applies ReLU.
        activated = torch.clamp(hidden, min=0)
        return self.linear2(activated)


# Problem sizes: N samples per batch, D_in input features,
# H hidden units, D_out output features.

N = 64
D_in = 1000
H = 100
D_out = 10

# Random Tensors that serve as the inputs and the regression targets

x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Instantiate the network class defined above

model = TwoLayerNet(D_in, H, D_out)

# Loss function and optimizer. model.parameters() hands SGD the learnable
# parameters of the two nn.Linear modules held by the model.

criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)

for t in range(500):
    # Clear the gradients left over from the previous update
    optimizer.zero_grad()

    # Forward pass: run x through the model, then score the prediction against y
    y_pred = model(x)
    loss = criterion(y_pred, y)

    # Report the iteration index together with the scalar loss value
    print(t, loss.item())

    # Backward pass to populate gradients, then apply the SGD update
    loss.backward()
    optimizer.step()

0 604.6024780273438
1 560.1512451171875
2 521.8259887695312
3 487.7924499511719
4 457.4266052246094
5 430.1676025390625
6 405.25262451171875
7 382.372802734375
8 361.31787109375
9 341.60870361328125
10 323.09124755859375
11 305.7405700683594
12 289.4228515625
13 274.0544738769531
14 259.6551513671875
15 245.99337768554688
16 233.0807342529297
17 220.82118225097656
18 209.14036560058594
19 198.0222625732422
20 187.42755126953125
21 177.36294555664062
22 167.75955200195312
23 158.5957794189453
24 149.7960968017578
25 141.4335174560547
26 133.4853515625
27 125.9326400756836
28 118.7765884399414
29 111.95243072509766
30 105.4765853881836
31 99.33563995361328
32 93.52220153808594
33 88.0361557006836
34 82.85165405273438
35 77.94575500488281
36 73.3017349243164
37 68.92732238769531
38 64.80717468261719
39 60.92436218261719
40 57.270450592041016
41 53.830631256103516
42 50.58713150024414
43 47.53158187866211
44 44.655242919921875
45 41.94533157348633
46 39.39404296875
47 37.00051498413086
48 

478 7.017076768534025e-06
479 6.818489055149257e-06
480 6.625204605370527e-06
481 6.436962848965777e-06
482 6.254150775930611e-06
483 6.075832516216906e-06
484 5.903920737182489e-06
485 5.737116680393228e-06
486 5.57404291612329e-06
487 5.416246494860388e-06
488 5.263013463263633e-06
489 5.1138845265086275e-06
490 4.968732810084475e-06
491 4.8283891374012455e-06
492 4.691185040428536e-06
493 4.558832188195083e-06
494 4.42976215708768e-06
495 4.304716185288271e-06
496 4.182324119028635e-06
497 4.064916993229417e-06
498 3.95008555642562e-06
499 3.838177235593321e-06
