In [1]:
# -*- coding: utf-8 -*-
import torch
from torch.autograd import Variable

In [2]:
class TwoLayerNet(torch.nn.Module):
    """Fully connected network with one hidden layer: Linear -> ReLU -> Linear."""

    def __init__(self, D_in, H, D_out):
        """
        Create the two affine layers and register them as submodules.

        Args:
            D_in: number of input features.
            H: number of hidden units.
            D_out: number of output features.
        """
        super().__init__()
        # Assigning nn.Linear instances as attributes registers their
        # weights and biases with this Module (see model.parameters()).
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        Compute predictions for a batch of inputs.

        The input goes through the first linear layer, negative activations
        are clamped to zero (ReLU), and the result goes through the second
        linear layer.

        Args:
            x: input data accepted by the first nn.Linear layer.

        Returns:
            The output of the second linear layer.
        """
        hidden = self.linear1(x)
        hidden_relu = torch.clamp(hidden, min=0)
        return self.linear2(hidden_relu)

In [3]:
# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs. Since PyTorch 0.4,
# Tensors carry autograd state themselves, so the deprecated Variable wrapper
# is unnecessary; requires_grad already defaults to False for new tensors.
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

In [4]:
# Instantiate the two-layer network using the input, hidden and output
# dimensions configured earlier in the notebook.
model = TwoLayerNet(D_in=D_in, H=H, D_out=D_out)

In [5]:
# Construct our loss function and an Optimizer. The call to model.parameters()
# in the SGD constructor will contain the learnable parameters of the two
# nn.Linear modules which are members of the model.
# reduction='sum' adds up the squared errors instead of averaging them; it is
# the replacement for the deprecated size_average=False argument.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)

In [7]:
for t in range(500):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute and print loss. The loss is a 0-dim tensor; .item() extracts its
    # value as a Python float. Indexing it with loss.data[0] is the pre-0.4
    # idiom and fails on 0-dim tensors in current PyTorch releases.
    loss = criterion(y_pred, y)
    print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    # Gradients accumulate across backward() calls, so they are cleared first.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 3.318470044177957e-05
1 3.241200465708971e-05
2 3.1657444196753204e-05
3 3.0919072742108256e-05
4 3.019823634531349e-05
5 2.949641020677518e-05
6 2.8810041840188205e-05
7 2.8139014830230735e-05
8 2.7482557925395668e-05
9 2.6843546947930008e-05
10 2.621873136376962e-05
11 2.5607358111301437e-05
12 2.501367634977214e-05
13 2.4432649297523312e-05
14 2.3862819944042712e-05
15 2.3307648007175885e-05
16 2.276632767461706e-05
17 2.223803494416643e-05
18 2.1721103621530347e-05
19 2.121534271282144e-05
20 2.0721412511193193e-05
21 2.024121567956172e-05
22 1.9768902348005213e-05
23 1.9312234144308604e-05
24 1.8863860532292165e-05
25 1.8424931113258936e-05
26 1.799802885216195e-05
27 1.7580221538082696e-05
28 1.717217128316406e-05
29 1.6774432879174128e-05
30 1.63829554367112e-05
31 1.6003172277123667e-05
32 1.5633719158358872e-05
33 1.5270919902832247e-05
34 1.4917017324478365e-05
35 1.457209509680979e-05
36 1.4232732610253152e-05
37 1.3903572835261002e-05
38 1.3581085113401059e-05
39 1.326597

335 1.616555422856436e-08
336 1.5848423018383073e-08
337 1.553178563540314e-08
338 1.5221107929619393e-08
339 1.4940523485051926e-08
340 1.4642538737064115e-08
341 1.434334784278235e-08
342 1.4105270729203312e-08
343 1.3793470365897065e-08
344 1.352294987100322e-08
345 1.3285543332131056e-08
346 1.3055378111914706e-08
347 1.2803324622723267e-08
348 1.2553474704191103e-08
349 1.232548285656776e-08
350 1.2109161673379276e-08
351 1.1886692519169628e-08
352 1.164633633976564e-08
353 1.1447398584607527e-08
354 1.1245131048553958e-08
355 1.1039448466476642e-08
356 1.0819158013930519e-08
357 1.0666548533322384e-08
358 1.0462398059019051e-08
359 1.0277391382373935e-08
360 1.0098773373101722e-08
361 9.947894064055163e-09
362 9.74311920032278e-09
363 9.584113946914385e-09
364 9.390482169635561e-09
365 9.241466258913533e-09
366 9.082654628400633e-09
367 8.909851523242196e-09
368 8.765723258363778e-09
369 8.604049916982603e-09
370 8.470135703930737e-09
371 8.312087906858778e-09
372 8.1612405722353