##### `torch.autograd.Function`

##### Example 1

In [None]:
# Leaf tensor whose gradient the custom Function below will populate.
x = torch.tensor([1.0, 2.0, 3.0]).requires_grad_()

In [None]:
# Display the leaf tensor; the notebook renders this trailing bare expression.
x

tensor([1., 2., 3.], requires_grad=True)

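Write a custom autograd function that returns its `input` unchanged in the forward pass and, in the backward pass, returns the incoming gradient plus one (`grad_output + 1`). Because `output` itself is passed below as the upstream gradient, the resulting gradient equals `input + 1`.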

**Hints**: `ctx` is a context object that can be used to save information needed for the backward pass (tensors via `ctx.save_for_backward(...)`, other values as plain attributes on `ctx`).

In [None]:
from torch.autograd import Function

In [None]:
class MultiplyConstant(Function):
    """Identity forward pass whose backward pass adds one to the incoming gradient.

    Despite the name, nothing is multiplied here. The example shows that
    ``backward`` may return any gradient, not necessarily the true derivative
    of ``forward`` (which for the identity would be ``grad_output`` itself).
    """

    @staticmethod
    def forward(ctx, input):
        # The backward rule depends only on the incoming gradient, so nothing
        # is stored on ``ctx``.
        return input

    @staticmethod
    def backward(ctx, grad_output):
        # Deliberately not the true derivative of the identity: shift by one.
        shifted = grad_output.add(1)
        return shifted

To use the custom function, you need to call its `apply` method:

In [None]:
# Custom Functions are invoked through the ``apply`` classmethod rather than
# by instantiating the class or calling ``forward`` directly.
output = MultiplyConstant.apply(x)

In [None]:
# Display the result; its grad_fn is the backward node of the custom Function.
output

tensor([1., 2., 3.], grad_fn=<MultiplyConstantBackward>)

In [None]:
# Seed backpropagation with ``output`` itself as the upstream gradient, so
# ``grad_output`` inside the custom backward equals the forward result.
output.backward(gradient=output)

In [None]:
# Gradient accumulated into the leaf; re-running the backward cell adds to it
# rather than overwriting it.
x.grad

tensor([2., 3., 4.])

##### Example 2

In [None]:
class MultiplyConstant(Function):
    """Multiply a tensor by a constant: ``forward`` returns ``input * constant``.

    ``constant`` is treated as non-differentiable, so ``backward`` returns
    ``None`` in its gradient slot. It is presumably a Python scalar; if a
    tensor requiring grad is passed, its gradient is reported as ``None``.
    """

    @staticmethod
    def forward(ctx, input, constant):
        # d(constant * input)/d(input) == constant, so only the constant is
        # needed in backward. The input itself does not have to be saved.
        # Non-tensor values are stored directly as attributes on ``ctx``.
        ctx.constant = constant
        return input * constant

    @staticmethod
    def backward(ctx, grad_output):
        # Return one gradient per forward argument after ctx: (input, constant).
        return grad_output * ctx.constant, None

##### Example 3

In [None]:
import torch

In [None]:
class Square(torch.autograd.Function):
    """Elementwise square ``x**2`` with a hand-written gradient ``2 * x``."""

    @staticmethod
    def forward(ctx, x):
        # ``x`` is an input tensor, so it goes through ``save_for_backward``.
        # Non-tensors, or tensors that are neither inputs nor outputs, would
        # instead be set as plain attributes on ``ctx``.
        ctx.save_for_backward(x)
        squared = x**2
        return squared

    @staticmethod
    def backward(ctx, grad_out):
        # Only ordinary differentiable tensor ops are used here. When autograd
        # records them (``create_graph=True``), double backward works with no
        # separate implementation.
        (saved_x,) = ctx.saved_tensors
        return grad_out * 2 * saved_x

In [None]:
# gradcheck compares against finite differences, so use float64 for precision.
x = torch.rand((3, 3), dtype=torch.float64).requires_grad_()

In [None]:
# Compare Square's analytical backward against numerical finite differences.
torch.autograd.gradcheck(Square.apply, (x,))

True

In [None]:
import torchviz

# Cloning a grad-requiring scalar yields a non-leaf tensor (it has a grad_fn).
x = torch.clone(torch.tensor(1.0, requires_grad=True))
out = Square.apply(x)

In [None]:
# create_graph=True records the backward computation, so grad_x is itself
# differentiable (enabling a second, double-backward pass).
(grad_x,) = torch.autograd.grad(out, x, create_graph=True)