In [1]:
%matplotlib inline


PyTorch: Defining New autograd Functions
----------------------------------------

A fully-connected ReLU network with one hidden layer and no biases, trained to
predict y from x by minimizing squared Euclidean distance.

This implementation computes the forward pass using operations on PyTorch
Variables, and uses PyTorch autograd to compute gradients.

In this implementation we implement our own custom autograd function to perform
the ReLU function.



In [2]:
import torch


class MyReLU(torch.autograd.Function):
    """
    We can implement our own custom autograd Functions by subclassing
    torch.autograd.Function and implementing the forward and backward passes
    which operate on Tensors.
    """

    @staticmethod
    def forward(ctx, input):
        """
        In the forward pass we receive a Tensor containing the input and return
        a Tensor containing the output. ctx is a context object that can be used
        to stash information for backward computation. You can cache arbitrary
        objects for use in the backward pass using the ctx.save_for_backward method.
        """
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass we receive a Tensor containing the gradient of the loss
        with respect to the output, and we need to compute the gradient of the loss
        with respect to the input.
        """
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input


dtype = torch.float
# device = torch.device("cpu")
device = torch.device("cuda:0") # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold input and outputs.
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Create random Tensors for weights.
w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)

learning_rate = 1e-6
for t in range(500):
    # To apply our Function, we use Function.apply method. We alias this as 'relu'.
    relu = MyReLU.apply

    # Forward pass: compute predicted y using operations; we compute
    # ReLU using our custom autograd operation.
    y_pred = relu(x.mm(w1)).mm(w2)

    # Compute and print loss
    loss = (y_pred - y).pow(2).sum()
    print(t, loss.item())

    # Use autograd to compute the backward pass.
    loss.backward()

    # Update weights using gradient descent
    with torch.no_grad():
        w1 -= learning_rate * w1.grad
        w2 -= learning_rate * w2.grad

        # Manually zero the gradients after updating weights
        w1.grad.zero_()
        w2.grad.zero_()

0 25364118.0
1 22642700.0
2 23942726.0
3 26157952.0
4 26384544.0
5 22665030.0
6 16146236.0
7 9713506.0
8 5333460.0
9 2908031.75
10 1704935.125
11 1109639.625
12 799797.3125
13 621511.5
14 506767.71875
15 425165.25
16 362988.0
17 313355.46875
18 272611.875
19 238530.515625
20 209694.796875
21 185100.046875
22 163979.5
23 145751.921875
24 129963.9765625
25 116230.3984375
26 104232.859375
27 93724.3359375
28 84477.8515625
29 76312.359375
30 69078.203125
31 62655.37109375
32 56935.08203125
33 51832.671875
34 47264.96484375
35 43167.2890625
36 39485.96484375
37 36170.37109375
38 33181.796875
39 30480.556640625
40 28033.55078125
41 25813.279296875
42 23795.767578125
43 21959.2578125
44 20285.939453125
45 18759.2109375
46 17364.482421875
47 16088.0458984375
48 14918.1259765625
49 13845.5390625
50 12860.7001953125
51 11955.0087890625
52 11121.3466796875
53 10353.423828125
54 9644.90234375
55 8990.615234375
56 8385.9853515625
57 7826.70361328125
58 7309.61279296875
59 6830.861328125
60 6386.929

405 0.0005467631854116917
406 0.0005301856435835361
407 0.0005145389586687088
408 0.000498355133458972
409 0.00048399457591585815
410 0.0004703231679741293
411 0.00045593257527798414
412 0.00044353908742778003
413 0.00043110435944981873
414 0.00041907664854079485
415 0.00040613001328893006
416 0.00039565813494846225
417 0.0003846243198495358
418 0.0003731827309820801
419 0.000363616447430104
420 0.00035337015287950635
421 0.0003435317485127598
422 0.0003341232368256897
423 0.00032533484045416117
424 0.0003177198814228177
425 0.00030870630871504545
426 0.00030067600891925395
427 0.00029281488968990743
428 0.0002856402425095439
429 0.0002777108456939459
430 0.00027118882280774415
431 0.0002640045713633299
432 0.0002572576922830194
433 0.0002510185295250267
434 0.0002444902784191072
435 0.0002382281090831384
436 0.00023256863642018288
437 0.00022660604736302048
438 0.0002210495586041361
439 0.00021560357708949596
440 0.00021034694509580731
441 0.0002048760507022962
442 0.00020052577019669