In [1]:
# -*- coding: utf-8 -*-
import torch
from torch.autograd import Variable

In [2]:
class MyReLU(torch.autograd.Function):
    """
    A hand-written ReLU expressed as a custom autograd Function.

    Subclassing torch.autograd.Function and providing static forward and
    backward methods defines both the operation and its gradient. The
    Function is invoked through MyReLU.apply.
    """

    @staticmethod
    def forward(ctx, input):
        """
        Return a Tensor equal to input with every element clamped to be >= 0.

        ctx is the context object shared with backward. The input Tensor is
        stashed on it with ctx.save_for_backward so that backward can read it
        back from ctx.saved_tensors.
        """
        ctx.save_for_backward(input)
        return torch.clamp(input, min=0)

    @staticmethod
    def backward(ctx, grad_output):
        """
        Turn the gradient of the loss with respect to the output (grad_output)
        into the gradient of the loss with respect to the input.

        The incoming gradient is zeroed wherever the saved input is strictly
        negative and passed through unchanged everywhere else.
        """
        (saved_input,) = ctx.saved_tensors
        negative_mask = saved_input < 0
        # masked_fill is out-of-place: it returns a new Tensor, so grad_output
        # itself is never modified.
        return grad_output.masked_fill(negative_mask, 0)

In [3]:
dtype = torch.FloatTensor
# dtype = torch.cuda.FloatTensor # Uncomment this to run on GPU

# Seed the global RNG before any random draw so that the data and the initial
# weights below are identical on every fresh run of the notebook. Loss values
# printed by earlier, unseeded runs will not match a seeded run.
SEED = 0
torch.manual_seed(SEED)

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold input and outputs, and wrap them in Variables.
# requires_grad=False: no gradients are requested for the data itself.
x = Variable(torch.randn(N, D_in).type(dtype), requires_grad=False)
y = Variable(torch.randn(N, D_out).type(dtype), requires_grad=False)

# Create random Tensors for weights, and wrap them in Variables.
# requires_grad=True: gradients with respect to the weights are requested,
# so they are available in w1.grad / w2.grad after a backward pass.
w1 = Variable(torch.randn(D_in, H).type(dtype), requires_grad=True)
w2 = Variable(torch.randn(H, D_out).type(dtype), requires_grad=True)


In [4]:
learning_rate = 1e-6

# To apply our Function, we use Function.apply method. We alias this as 'relu'.
# The alias is the same on every iteration, so it is bound once, outside the loop.
relu = MyReLU.apply

for t in range(500):
    # Forward pass: compute predicted y using operations on Variables; we compute
    # ReLU using our custom autograd operation.
    y_pred = relu(x.mm(w1)).mm(w2)

    # Compute and print loss (sum of squared errors over the whole batch).
    loss = (y_pred - y).pow(2).sum()
    # Use .item() to get the loss as a Python number. The previous form,
    # loss.data[0], indexes a 0-dim tensor, which is an error on PyTorch 0.5
    # and later.
    print(t, loss.item())

    # Use autograd to compute the backward pass.
    loss.backward()

    # Update weights using gradient descent
    w1.data -= learning_rate * w1.grad.data
    w2.data -= learning_rate * w2.grad.data

    # Manually zero the gradients after updating weights, so the next
    # backward() call starts from zero.
    w1.grad.data.zero_()
    w2.grad.data.zero_()

0 29878930.0
1 23897038.0
2 21041410.0
3 18302322.0
4 14897789.0
5 11086145.0
6 7661648.5
7 5054399.0
8 3304095.5
9 2200183.0
10 1522990.25
11 1104914.5
12 839335.1875
13 663139.4375
14 540367.5
15 450562.71875
16 382020.65625
17 327856.21875
18 283928.625
19 247562.1875
20 217033.0625
21 191132.640625
22 168976.3125
23 149885.890625
24 133383.0625
25 119026.3203125
26 106483.921875
27 95487.0859375
28 85817.0546875
29 77285.8828125
30 69733.421875
31 63032.01953125
32 57067.1171875
33 51747.046875
34 46996.94921875
35 42749.14453125
36 38936.171875
37 35511.55859375
38 32429.873046875
39 29649.5859375
40 27136.15234375
41 24862.12109375
42 22800.984375
43 20930.66796875
44 19229.552734375
45 17681.6953125
46 16271.69140625
47 14985.9345703125
48 13811.8359375
49 12738.6669921875
50 11756.6826171875
51 10857.53125
52 10033.4033203125
53 9277.4228515625
54 8583.5498046875
55 7945.99658203125
56 7359.97021484375
57 6820.43603515625
58 6323.6298828125
59 5865.9072265625
60 5443.8393554687

407 0.00017896294593811035
408 0.00017513579223304987
409 0.00017127097817137837
410 0.0001668758923187852
411 0.00016330623475369066
412 0.00015963420446496457
413 0.00015618209727108479
414 0.00015311270544771105
415 0.00015096119022928178
416 0.00014802349323872477
417 0.00014500167162623256
418 0.00014170428039506078
419 0.000139113180921413
420 0.0001368122175335884
421 0.00013427190424408764
422 0.00013156070781406015
423 0.00012934257392771542
424 0.00012631944264285266
425 0.00012437612167559564
426 0.00012197007890790701
427 0.00011952786735491827
428 0.00011776256724260747
429 0.00011542669381015003
430 0.00011313352297293022
431 0.0001108312644646503
432 0.00010887372627621517
433 0.00010729231871664524
434 0.00010524308163439855
435 0.00010336748528061435
436 0.00010127449058927596
437 9.988304373109713e-05
438 9.823378786677495e-05
439 9.660375508246943e-05
440 9.524733468424529e-05
441 9.377385140396655e-05
442 9.17289435165003e-05
443 9.032275556819513e-05
444 8.87448768