In [1]:
# -*- coding: utf-8 -*-
import torch
from torch.autograd import Variable

In [2]:
# Problem dimensions for the two-layer network:
#   N     - batch size
#   D_in  - input dimension
#   H     - hidden dimension
#   D_out - output dimension
N = 64
D_in = 1000
H = 100
D_out = 10

In [3]:
# Draw random Tensors for the inputs (N x D_in) and the targets (N x D_out).
# The inputs are drawn first, then the targets.
input_data = torch.randn(N, D_in)
target_data = torch.randn(N, D_out)

# Wrap both Tensors in Variables; the targets are explicitly marked as not
# requiring gradients.
x = Variable(input_data)
y = Variable(target_data, requires_grad=False)

In [4]:
# Build the network with the nn package. The layers are listed in the order
# they are applied: a Linear map from D_in to H, a ReLU non-linearity, and a
# Linear map from H to D_out. Each Linear Module holds its own weight and bias.
layers = [
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
]

# nn.Sequential is a Module that contains other Modules and applies them one
# after another to produce its output.
model = torch.nn.Sequential(*layers)

In [5]:
# The nn package also contains definitions of popular loss functions; in this
# case we will use Mean Squared Error (MSE) as our loss function.
# reduction='sum' adds up the squared errors over all elements instead of
# averaging them. It replaces the deprecated size_average=False argument,
# which selected the same summing behavior.
loss_fn = torch.nn.MSELoss(reduction='sum')

In [6]:
# Step size for the manual gradient-descent update below.
learning_rate = 1e-4
for t in range(500):
    # Forward pass: compute predicted y by passing x to the model. Module objects
    # override the __call__ operator so you can call them like functions; the
    # model takes the input data and returns the output data.
    y_pred = model(x)

    # Compute and print loss. The loss function takes the predicted and true
    # values of y and returns a single-element tensor holding the loss.
    # .item() extracts that value as a plain Python number; indexing a 0-dim
    # tensor with [0] (the old `loss.data[0]` idiom) raises an error on
    # current PyTorch versions.
    loss = loss_fn(y_pred, y)
    print(t, loss.item())

    # Zero the gradients before running the backward pass.
    model.zero_grad()

    # Backward pass: compute gradient of the loss with respect to all the learnable
    # parameters of the model. Internally, the parameters of each Module are stored
    # with requires_grad=True, so this call will compute gradients for all
    # learnable parameters in the model.
    loss.backward()

    # Update the weights using gradient descent. The update is wrapped in
    # torch.no_grad() so that the in-place modification of the parameters is
    # not itself recorded by autograd.
    with torch.no_grad():
        for param in model.parameters():
            param -= learning_rate * param.grad

0 653.754150390625
1 604.8116455078125
2 562.9284057617188
3 526.3858642578125
4 493.84210205078125
5 464.5747375488281
6 437.7945861816406
7 413.3089294433594
8 390.8015441894531
9 369.94122314453125
10 350.336181640625
11 331.95782470703125
12 314.7698669433594
13 298.53753662109375
14 283.1656188964844
15 268.5443115234375
16 254.6393585205078
17 241.33905029296875
18 228.64451599121094
19 216.5913848876953
20 205.0789031982422
21 194.12847900390625
22 183.67974853515625
23 173.7438201904297
24 164.28929138183594
25 155.26695251464844
26 146.6893310546875
27 138.49925231933594
28 130.69253540039062
29 123.28348541259766
30 116.27928924560547
31 109.64290618896484
32 103.37094116210938
33 97.45169830322266
34 91.87053680419922
35 86.59980773925781
36 81.6212387084961
37 76.92918395996094
38 72.50639343261719
39 68.33364868164062
40 64.39875793457031
41 60.69244384765625
42 57.19712448120117
43 53.908226013183594
44 50.81309127807617
45 47.89667892456055
46 45.15626525878906
47 42.577

399 0.00012956478167325258
400 0.00012589411926455796
401 0.00012232636800035834
402 0.00011886360152857378
403 0.00011550266208359972
404 0.00011223548790439963
405 0.0001090596488211304
406 0.00010598018707241863
407 0.0001029886188916862
408 0.00010008101526182145
409 9.725766722112894e-05
410 9.451406367588788e-05
411 9.185069939121604e-05
412 8.926396549213678e-05
413 8.675303251948208e-05
414 8.431493915850297e-05
415 8.194475958589464e-05
416 7.964025280671194e-05
417 7.739647844573483e-05
418 7.52237974666059e-05
419 7.311221997952089e-05
420 7.10595995769836e-05
421 6.906721682753414e-05
422 6.712978210998699e-05
423 6.524617492686957e-05
424 6.342099368339404e-05
425 6.16444376646541e-05
426 5.991887519485317e-05
427 5.823955143569037e-05
428 5.661269824486226e-05
429 5.503001739270985e-05
430 5.349464845494367e-05
431 5.1997216360177845e-05
432 5.054855500929989e-05
433 4.913657539873384e-05
434 4.776730929734185e-05
435 4.643502688850276e-05
436 4.514312604442239e-05
437 4.