In [1]:
%matplotlib inline


PyTorch: Defining New autograd Functions
----------------------------------------

A fully-connected ReLU network with one hidden layer and no biases, trained to
predict y from x by minimizing squared Euclidean distance.

This implementation computes the forward pass using operations on PyTorch
Tensors, and uses PyTorch autograd to compute gradients.

Here we define our own custom autograd function to compute the ReLU
nonlinearity, instead of relying on a built-in operation.



In [2]:
import torch


class MyReLU(torch.autograd.Function):
    """
    A hand-written ReLU that plugs into autograd.

    A custom differentiable operation is defined by subclassing
    torch.autograd.Function and supplying static ``forward`` and ``backward``
    methods, both of which work directly on Tensors.
    """

    @staticmethod
    def forward(ctx, input):
        """
        Compute the ReLU of ``input`` and return it as a new Tensor.

        ``ctx`` is the context object shared with ``backward``. The input
        Tensor is stashed on it via ``ctx.save_for_backward`` so that the
        backward pass can tell which entries were negative.
        """
        ctx.save_for_backward(input)
        return torch.clamp(input, min=0)

    @staticmethod
    def backward(ctx, grad_output):
        """
        Map the gradient of the loss with respect to the output
        (``grad_output``) to the gradient of the loss with respect to the
        input.

        The gradient is zeroed wherever the saved input was strictly negative
        and passed through unchanged everywhere else.
        """
        (saved_input,) = ctx.saved_tensors
        was_negative = saved_input < 0
        # masked_fill is out-of-place, so grad_output itself is never modified.
        return grad_output.masked_fill(was_negative, 0)


dtype = torch.float
# Run on the first GPU when CUDA is available; otherwise fall back to the CPU
# so the cell still runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold input and outputs.
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Create random Tensors for weights. requires_grad=True asks autograd to
# accumulate gradients for them in w1.grad / w2.grad during loss.backward().
w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)

# To apply our Function, we use the Function.apply method. We alias this as
# 'relu' once, outside the loop, because the alias never changes.
relu = MyReLU.apply

learning_rate = 1e-6
for t in range(500):
    # Forward pass: compute predicted y using operations; we compute
    # ReLU using our custom autograd operation.
    y_pred = relu(x.mm(w1)).mm(w2)

    # Compute and print loss (sum of squared errors over the whole batch).
    loss = (y_pred - y).pow(2).sum()
    print(t, loss.item())

    # Use autograd to compute the backward pass.
    loss.backward()

    # Update weights using gradient descent. The update runs under
    # torch.no_grad() so the in-place changes are not tracked by autograd.
    with torch.no_grad():
        w1 -= learning_rate * w1.grad
        w2 -= learning_rate * w2.grad

        # Manually zero the gradients after updating weights; backward()
        # accumulates into .grad rather than overwriting it.
        w1.grad.zero_()
        w2.grad.zero_()

0 29054064.0
1 25657364.0
2 28104228.0
3 32269688.0
4 33563164.0
5 28637716.0
6 19101244.0
7 10342804.0
8 5067596.0
9 2585022.75
10 1510459.25
11 1028231.125
12 783344.75
13 637096.25
14 536253.875
15 459728.625
16 398431.75
17 347762.15625
18 305214.9375
19 269118.28125
20 238287.5625
21 211781.78125
22 188851.25
23 168933.53125
24 151549.03125
25 136308.125
26 122902.390625
27 111073.078125
28 100600.671875
29 91324.1328125
30 83070.046875
31 75704.6640625
32 69111.0
33 63191.6796875
34 57862.84375
35 53058.58984375
36 48718.3671875
37 44790.9609375
38 41234.26953125
39 38004.0546875
40 35066.375
41 32389.73046875
42 29948.15625
43 27717.8515625
44 25678.353515625
45 23811.451171875
46 22099.123046875
47 20527.67578125
48 19082.919921875
49 17754.0703125
50 16530.421875
51 15403.0087890625
52 14363.220703125
53 13403.572265625
54 12516.25390625
55 11694.740234375
56 10934.021484375
57 10229.349609375
58 9575.646484375
59 8968.7734375
60 8405.2099609375
61 7881.32177734375
62 7393.864

404 0.004738704767078161
405 0.004582607187330723
406 0.004436916671693325
407 0.004291645251214504
408 0.004153021145612001
409 0.004018784034997225
410 0.0038907628040760756
411 0.003769516246393323
412 0.0036483623553067446
413 0.003534065093845129
414 0.0034203235991299152
415 0.0033133020624518394
416 0.0032083282712846994
417 0.0031095552258193493
418 0.0030122706666588783
419 0.0029170173220336437
420 0.0028239742387086153
421 0.0027423426508903503
422 0.002656154567375779
423 0.0025746093597263098
424 0.0024968008510768414
425 0.002421833109110594
426 0.0023483342956751585
427 0.0022770653013139963
428 0.0022092743311077356
429 0.002143465681001544
430 0.002080705715343356
431 0.0020195599645376205
432 0.0019602628890424967
433 0.0019039383623749018
434 0.0018479425925761461
435 0.001794559066183865
436 0.001741899410262704
437 0.0016920154448598623
438 0.0016422360204160213
439 0.0015963874757289886
440 0.001549141830764711
441 0.0015050973743200302
442 0.0014651136007159948
4