In [1]:
# Enable inline matplotlib figures. NOTE(review): no cell visible here plots anything, so this is likely unnecessary — confirm before removing.
%matplotlib inline


PyTorch: Defining New autograd Functions
----------------------------------------

A fully-connected ReLU network with one hidden layer and no biases, trained to
predict y from x by minimizing squared Euclidean distance.

This implementation computes the forward pass using operations on PyTorch
Tensors, and uses PyTorch autograd to compute gradients.

Here we implement our own custom autograd Function to compute the ReLU
nonlinearity.



In [2]:
import torch


class MyReLU(torch.autograd.Function):
    """
    Custom autograd Function implementing the rectified linear unit.

    Subclassing torch.autograd.Function and supplying static ``forward`` and
    ``backward`` methods lets us define both halves of the computation on
    Tensors ourselves; autograd invokes ``backward`` during back-propagation.
    Call it through ``MyReLU.apply(x)`` rather than instantiating it.
    """

    @staticmethod
    def forward(ctx, input):
        """
        Return max(input, 0) elementwise.

        ctx is the per-call context object. The raw input is stashed on it via
        ctx.save_for_backward so that backward can recover which entries were
        negative.
        """
        ctx.save_for_backward(input)
        rectified = input.clamp(min=0)
        return rectified

    @staticmethod
    def backward(ctx, grad_output):
        """
        Propagate the upstream gradient through the ReLU.

        grad_output is the gradient of the loss w.r.t. this Function's output.
        The result is the gradient w.r.t. the input: a new tensor equal to
        grad_output except that entries where the saved input was strictly
        negative are zeroed. grad_output itself is not modified.
        """
        (saved_input,) = ctx.saved_tensors
        negative_mask = saved_input < 0
        # Out-of-place masked_fill allocates a fresh tensor, so the incoming
        # gradient buffer is left untouched.
        return grad_output.masked_fill(negative_mask, 0)


dtype = torch.float
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU
# so the notebook also runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed the RNG (CPU and CUDA) so the random data, the initial weights and
# therefore the printed loss curve are reproducible on Restart & Run All.
torch.manual_seed(0)

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs.
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Create random Tensors for weights. requires_grad=True makes autograd
# accumulate the loss gradient into w1.grad / w2.grad on backward().
w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)

# To apply our Function, we use its Function.apply method; alias it as 'relu'.
# This binding never changes, so it is done once outside the training loop.
relu = MyReLU.apply

learning_rate = 1e-6
for t in range(500):
    # Forward pass: compute predicted y with two matrix products, using our
    # custom autograd Function for the ReLU between them.
    y_pred = relu(x.mm(w1)).mm(w2)

    # Sum-of-squared-errors loss; .item() extracts it as a Python float.
    loss = (y_pred - y).pow(2).sum()
    print(t, loss.item())

    # Use autograd to compute the backward pass; this fills w1.grad and w2.grad.
    loss.backward()

    # Gradient-descent update. Wrapped in no_grad so the in-place weight
    # updates are not recorded in the autograd graph.
    with torch.no_grad():
        w1 -= learning_rate * w1.grad
        w2 -= learning_rate * w2.grad

        # backward() accumulates into .grad, so reset it before the next step.
        w1.grad.zero_()
        w2.grad.zero_()

0 35081024.0
1 28128328.0
2 21602744.0
3 15034894.0
4 9651694.0
5 6010901.5
6 3840269.75
7 2603328.0
8 1887548.375
9 1448593.125
10 1158133.0
11 951632.1875
12 796328.1875
13 674861.5625
14 577375.6875
15 497710.5625
16 431645.0
17 376271.59375
18 329491.0
19 289738.84375
20 255735.8125
21 226549.203125
22 201297.0625
23 179376.140625
24 160287.390625
25 143592.015625
26 128949.1015625
27 116086.4453125
28 104732.984375
29 94681.140625
30 85745.1171875
31 77801.3125
32 70715.875
33 64377.1171875
34 58691.6015625
35 53584.3359375
36 48986.328125
37 44838.15625
38 41089.30078125
39 37691.26171875
40 34611.16796875
41 31814.564453125
42 29271.46484375
43 26957.57421875
44 24847.767578125
45 22922.626953125
46 21163.7578125
47 19556.193359375
48 18085.419921875
49 16737.5703125
50 15500.296875
51 14365.607421875
52 13325.6474609375
53 12368.5341796875
54 11491.6865234375
55 10683.7705078125
56 9940.1591796875
57 9254.26953125
58 8620.2470703125
59 8033.951171875
60 7494.40771484375
61 6995

466 0.0001050714126904495
467 0.00010325733455829322
468 0.00010142367682419717
469 9.949003288056701e-05
470 9.831743955146521e-05
471 9.697275527287275e-05
472 9.527366637485102e-05
473 9.339013922726735e-05
474 9.197303734254092e-05
475 9.049190703080967e-05
476 8.931881166063249e-05
477 8.755223825573921e-05
478 8.638578583486378e-05
479 8.494203211739659e-05
480 8.34128659334965e-05
481 8.20680070319213e-05
482 8.081789565039799e-05
483 7.973943138495088e-05
484 7.837264274712652e-05
485 7.71837294450961e-05
486 7.596635987283662e-05
487 7.52713021938689e-05
488 7.423126953653991e-05
489 7.309508509933949e-05
490 7.195628859335557e-05
491 7.098781497916207e-05
492 6.97308496455662e-05
493 6.895632395753637e-05
494 6.783132994314656e-05
495 6.696485070278868e-05
496 6.57920609228313e-05
497 6.484766345238313e-05
498 6.407267937902361e-05
499 6.314895290415734e-05
