In [1]:
%matplotlib inline


PyTorch: Defining New autograd Functions
----------------------------------------

A fully-connected ReLU network with one hidden layer and no biases, trained to
predict y from x by minimizing squared Euclidean distance.

This implementation computes the forward pass using operations on PyTorch
Tensors, and uses PyTorch autograd to compute gradients.

In this implementation we implement our own custom autograd function to perform
the ReLU function.



In [2]:
import torch


class MyReLU(torch.autograd.Function):
    """
    ReLU written as a custom autograd Function.

    The class subclasses torch.autograd.Function and supplies static forward
    and backward methods that operate directly on Tensors. It is invoked
    through ``MyReLU.apply``.
    """

    @staticmethod
    def forward(ctx, input):
        """
        Compute the output Tensor: every entry of ``input`` below 0 becomes 0.

        ``ctx`` is the context object shared with ``backward``. The input
        Tensor is stashed on it with ``ctx.save_for_backward`` so that the
        backward pass can tell which entries were negative.
        """
        ctx.save_for_backward(input)
        return torch.clamp(input, min=0)

    @staticmethod
    def backward(ctx, grad_output):
        """
        Turn the gradient of the loss with respect to the output
        (``grad_output``) into the gradient of the loss with respect to the
        input.

        The incoming gradient is passed through unchanged wherever the saved
        input was >= 0 and replaced by 0 wherever the saved input was negative.
        """
        (saved_input,) = ctx.saved_tensors
        was_negative = saved_input < 0
        # masked_fill is out-of-place, so grad_output itself is never modified.
        return grad_output.masked_fill(was_negative, 0)


dtype = torch.float
# Run on the first CUDA GPU when one is available; otherwise fall back to the
# CPU so the notebook still runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold input and outputs.
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Create random Tensors for weights. requires_grad=True asks autograd to
# record operations on them so their gradients can be computed.
w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)

# Step size used for the gradient-descent weight updates.
learning_rate = 1e-6

# To apply our Function, we use the Function.apply method, aliased as 'relu'.
# The alias never changes, so it is bound once here instead of on every
# iteration of the loop.
relu = MyReLU.apply

for t in range(500):
    # Forward pass: compute predicted y using operations; we compute
    # ReLU using our custom autograd operation.
    y_pred = relu(x.mm(w1)).mm(w2)

    # Compute and print loss (sum of squared differences between y_pred and y).
    loss = (y_pred - y).pow(2).sum()
    print(t, loss.item())

    # Use autograd to compute the backward pass; this fills w1.grad and w2.grad.
    loss.backward()

    # Update weights using gradient descent. The updates run under
    # torch.no_grad() so the in-place changes are not tracked by autograd.
    with torch.no_grad():
        w1 -= learning_rate * w1.grad
        w2 -= learning_rate * w2.grad

        # Manually zero the gradients after updating weights; otherwise the
        # next backward() call would add to the existing values.
        w1.grad.zero_()
        w2.grad.zero_()

0 35139152.0
1 35131632.0
2 37414624.0
3 35390184.0
4 27062944.0
5 16222431.0
6 8288489.5
7 4100730.75
8 2248848.5
9 1431954.5
10 1037573.625
11 815257.625
12 669799.25
13 563676.75
14 481028.625
15 414226.5
16 359101.53125
17 313016.1875
18 274149.03125
19 241091.953125
20 212779.1875
21 188408.40625
22 167361.90625
23 149069.0625
24 133105.625
25 119135.953125
26 106878.984375
27 96085.59375
28 86547.234375
29 78093.859375
30 70590.75
31 63911.40234375
32 57952.359375
33 52626.625
34 47854.01171875
35 43567.140625
36 39709.875
37 36235.125
38 33102.0390625
39 30267.6015625
40 27705.2109375
41 25383.31640625
42 23275.599609375
43 21360.498046875
44 19618.76953125
45 18031.56640625
46 16584.80078125
47 15265.4326171875
48 14059.42578125
49 12962.119140625
50 11963.80078125
51 11050.5712890625
52 10212.84765625
53 9443.609375
54 8737.1796875
55 8088.0615234375
56 7490.67431640625
57 6940.80859375
58 6434.2265625
59 5967.0947265625
60 5536.162109375
61 5138.681640625
62 4771.51220703125


480 3.9455633668694645e-05
481 3.9049053157214075e-05
482 3.8600144762312993e-05
483 3.805687083513476e-05
484 3.746600850718096e-05
485 3.7060446629766375e-05
486 3.663549432530999e-05
487 3.610183921409771e-05
488 3.5652545193443075e-05
489 3.539401586749591e-05
490 3.517290315357968e-05
491 3.499575541354716e-05
492 3.4479526220820844e-05
493 3.4018354199361056e-05
494 3.369605110492557e-05
495 3.325490979477763e-05
496 3.309805833850987e-05
497 3.2828196708578616e-05
498 3.260125595261343e-05
499 3.2029274734668434e-05
