In [1]:
import torch

%matplotlib inline


PyTorch: Defining New autograd Functions
----------------------------------------

A fully-connected ReLU network with one hidden layer and no biases, trained to
predict y from x by minimizing squared Euclidean distance.

This implementation computes the forward pass using operations on PyTorch
Tensors, and uses PyTorch autograd to compute gradients.

Here we define our own custom autograd Function to compute the ReLU
nonlinearity, instead of relying on a built-in operation.



In [15]:
class MyReLU(torch.autograd.Function):
    """
    Element-wise ReLU written as a custom autograd Function.

    A custom Function subclasses torch.autograd.Function and provides static
    ``forward`` and ``backward`` methods that operate on Tensors. It is
    invoked through ``MyReLU.apply`` rather than by instantiating the class.
    """

    @staticmethod
    def forward(ctx, *args):
        """
        Clamp the first positional argument at zero from below.

        ``ctx`` is the context object shared with ``backward``; the input
        Tensor is stored on it via ``ctx.save_for_backward`` so the backward
        pass can see which entries were positive. Only ``args[0]`` is read.

        Returns a Tensor equal to the input with every negative entry
        replaced by 0.
        """
        features = args[0]
        ctx.save_for_backward(features)
        return torch.clamp(features, min=0)

    @staticmethod
    def backward(ctx, *grad_outputs):
        """
        Propagate the gradient of the loss through the ReLU.

        ``grad_outputs[0]`` is the gradient of the loss with respect to the
        output of ``forward``. The gradient with respect to the input is that
        same value wherever the saved input was strictly positive and zero
        everywhere else (including where the input was exactly 0).

        Returns a 1-tuple holding the gradient for the single input.
        """
        (features,) = ctx.saved_tensors
        upstream = grad_outputs[0]
        # Boolean mask of the entries that passed through the ReLU unchanged.
        is_active = features > 0
        return (upstream * is_active,)

In [16]:
dtype = torch.double
device = torch.device("cpu")
# device = torch.device("cuda:0") # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N = 64
D_in = 1000
H = 100
D_out = 10

# Random input and target data.
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Randomly initialised weights. requires_grad=True makes autograd track them,
# so loss.backward() fills in w1.grad and w2.grad.
w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)

learning_rate = 1e-6
num_steps = 500  # number of gradient-descent iterations
log_every = 100  # the loss is printed on the last step of each such block

# A custom Function is called through its .apply attribute. The binding is the
# same on every iteration, so it is created once instead of inside the loop.
relu = MyReLU.apply

for t in range(num_steps):
    # Forward pass: hidden layer, custom ReLU, output layer (no biases).
    y_pred = torch.mm(relu(torch.mm(x, w1)), w2)

    # Squared Euclidean distance between prediction and target, as a scalar.
    loss = torch.sum(torch.square(y_pred - y))
    if t % log_every == log_every - 1:
        print(t, loss.item())

    # Backward pass: populates w1.grad and w2.grad.
    loss.backward()

    # The update itself must not be recorded by autograd.
    with torch.no_grad():
        w1 -= learning_rate * w1.grad
        w2 -= learning_rate * w2.grad

        # backward() accumulates into .grad, so reset it before the next step.
        w1.grad.zero_()
        w2.grad.zero_()

99 539.6183725697779
199 3.7173742934875325
299 0.06455346738438715
399 0.0015281490866061692
499 4.080399761945463e-05
