In [1]:
import torch


class TwoLayerNet(torch.nn.Module):
    """Fully connected network computing ``linear2(relu(linear1(x)))``.

    Args:
        D_in: number of input features fed to ``linear1``.
        H: width of the hidden layer (output of ``linear1``, input of ``linear2``).
        D_out: number of output features produced by ``linear2``.
    """

    def __init__(self, D_in, H, D_out):
        super().__init__()
        # Assigning nn.Linear submodules as attributes registers their weights
        # and biases with the parent Module, so model.parameters() sees them.
        # The attribute names also become state_dict key prefixes; keep them.
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """Return the network's prediction for the input batch ``x``.

        The hidden activation is a ReLU written as ``clamp(min=0)``; any
        Tensor operation may be mixed freely with the registered submodules.
        """
        hidden = torch.clamp(self.linear1(x), min=0)
        return self.linear2(hidden)


# Problem sizes: N examples per batch, D_in input features,
# H hidden units, D_out output features.
N, D_in, H, D_out = 64, 1000, 100, 10

# Synthetic regression data: random inputs x and random targets y.
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Build the network. nn.Linear initialises its weights randomly, so this stays
# after the two randn calls above to draw from the same point in the RNG stream.
model = TwoLayerNet(D_in, H, D_out)

# Summed (not averaged) squared error, minimised by plain SGD over every
# parameter registered on the model (the weights/biases of both nn.Linear layers).
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)

for t in range(500):
    # Discard gradients from the previous iteration; backward() accumulates.
    optimizer.zero_grad()

    # Forward pass, then report the current loss.
    y_pred = model(x)
    loss = criterion(y_pred, y)
    print(t, loss.item())

    # Backpropagate and apply one SGD update to the parameters.
    loss.backward()
    optimizer.step()

0 641.0698852539062
1 596.6747436523438
2 557.7380981445312
3 523.2838134765625
4 492.4844055175781
5 464.98486328125
6 439.8149719238281
7 416.69049072265625
8 395.1796875
9 375.2060241699219
10 356.4792175292969
11 338.8349304199219
12 322.2173156738281
13 306.3661193847656
14 291.2605285644531
15 276.9498291015625
16 263.28656005859375
17 250.20050048828125
18 237.67408752441406
19 225.7671356201172
20 214.42071533203125
21 203.572265625
22 193.23353576660156
23 183.376953125
24 173.96966552734375
25 164.99476623535156
26 156.42930603027344
27 148.25270080566406
28 140.4392852783203
29 132.94419860839844
30 125.80224609375
31 119.00061798095703
32 112.54032897949219
33 106.39104461669922
34 100.5414810180664
35 94.9902114868164
36 89.73754119873047
37 84.74893951416016
38 79.9990463256836
39 75.49420166015625
40 71.22856140136719
41 67.19147491455078
42 63.372955322265625
43 59.76323318481445
44 56.35694885253906
45 53.1414909362793
46 50.1088981628418
47 47.25053405761719
48 44.559

470 3.6409585391083965e-06
471 3.5260416098026326e-06
472 3.4132888231397374e-06
473 3.3061205613194034e-06
474 3.201116442141938e-06
475 3.1005440632725367e-06
476 3.001604909513844e-06
477 2.9067341529298574e-06
478 2.8154388473922154e-06
479 2.7262694857199676e-06
480 2.6399663966003573e-06
481 2.5568981527612777e-06
482 2.4758132894930895e-06
483 2.3980633159226272e-06
484 2.322020009160042e-06
485 2.24806376536435e-06
486 2.177251190005336e-06
487 2.1092737370054238e-06
488 2.0423960904736305e-06
489 1.9779051854129648e-06
490 1.9154119854647433e-06
491 1.8556474969955161e-06
492 1.7966224277188303e-06
493 1.740423158480553e-06
494 1.685539132267877e-06
495 1.632676571716729e-06
496 1.5816104905752582e-06
497 1.531224484097038e-06
498 1.4830255850029062e-06
499 1.4364138678502059e-06
