In [1]:
%matplotlib inline


PyTorch: Custom nn Modules
--------------------------

A fully-connected ReLU network with one hidden layer, trained to predict y from x
by minimizing squared Euclidean distance.

This implementation defines the model as a custom Module subclass. Whenever you
want a model more complex than a simple sequence of existing Modules you will
need to define your model this way.



In [2]:
import torch


class TwoLayerNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        """
        In the constructor we instantiate two nn.Linear modules and assign them as
        member variables.
        """
        super(TwoLayerNet, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out) # Read more about torch.nn doc, especially about layers and their arguments.

    def forward(self, x):
        """
        In the forward function we accept a Tensor of input data and we must return
        a Tensor of output data. We can use Modules defined in the constructor as
        well as arbitrary operators on Tensors.
        """
        h_relu = self.linear1(x).clamp(min=0)
        y_pred = self.linear2(h_relu)
        return y_pred


# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Construct our model by instantiating the class defined above
model = TwoLayerNet(D_in, H, D_out)

# Construct our loss function and an Optimizer. The call to model.parameters()
# in the SGD constructor will contain the learnable parameters of the two
# nn.Linear modules which are members of the model.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
for t in range(500):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute and print loss
    loss = criterion(y_pred, y)
    print("[{}] {}".format(t, loss.item()))

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

[0] 636.9194946289062
[1] 592.531494140625
[2] 553.710205078125
[3] 519.7258911132812
[4] 489.281005859375
[5] 461.8702392578125
[6] 436.8700256347656
[7] 413.71893310546875
[8] 392.3301086425781
[9] 372.62286376953125
[10] 354.1456298828125
[11] 336.6742248535156
[12] 320.2906799316406
[13] 304.73016357421875
[14] 290.0709533691406
[15] 276.1735534667969
[16] 262.94757080078125
[17] 250.31365966796875
[18] 238.23040771484375
[19] 226.64918518066406
[20] 215.57943725585938
[21] 204.9400177001953
[22] 194.76821899414062
[23] 184.98814392089844
[24] 175.65174865722656
[25] 166.73196411132812
[26] 158.19517517089844
[27] 150.04238891601562
[28] 142.27337646484375
[29] 134.86276245117188
[30] 127.79107666015625
[31] 121.02002716064453
[32] 114.56597137451172
[33] 108.42288208007812
[34] 102.55166625976562
[35] 96.98040008544922
[36] 91.70208740234375
[37] 86.69710540771484
[38] 81.9527587890625
[39] 77.45317840576172
[40] 73.18809509277344
[41] 69.14895629882812
[42] 65.32052612304688
[43]

[384] 0.00015105766942724586
[385] 0.0001464281667722389
[386] 0.00014195531548466533
[387] 0.0001376177096972242
[388] 0.00013340443547349423
[389] 0.000129326872411184
[390] 0.00012537886505015194
[391] 0.00012155556032666937
[392] 0.00011784442176576704
[393] 0.00011424621334299445
[394] 0.0001107561110984534
[395] 0.00010737638513091952
[396] 0.00010410111281089485
[397] 0.00010092943557538092
[398] 9.784851863514632e-05
[399] 9.486500493949279e-05
[400] 9.197470353683457e-05
[401] 8.917315426515415e-05
[402] 8.645620982861146e-05
[403] 8.381728548556566e-05
[404] 8.126273314701393e-05
[405] 7.878962787799537e-05
[406] 7.63933639973402e-05
[407] 7.406780787277967e-05
[408] 7.181249384302646e-05
[409] 6.962746556382626e-05
[410] 6.75102201057598e-05
[411] 6.54604154988192e-05
[412] 6.347094313241541e-05
[413] 6.154257425805554e-05
[414] 5.966968819848262e-05
[415] 5.7858473155647516e-05
[416] 5.610283187706955e-05
[417] 5.4400985391112044e-05
[418] 5.274452269077301e-05
[419] 5.1145