# Extending PyTorch

For details, see the [PyTorch notes on extending autograd](https://pytorch.org/docs/stable/notes/extending.html?highlight=extend%20autograd).

For the `Exp` example, see the [`torch.autograd.Function` documentation](https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function).

## Function
### forward
The `forward` function of a class derived from `Function` is similar to the `forward` of a class derived from `torch.nn.Module`. The difference is that its first parameter must be `ctx`, a context object in which `forward` can save values that are needed later in the `backward` function.

### backward
The parameters are `ctx` and $\frac{\partial loss}{\partial output}$, where $loss = f(output)$ and `output` is exactly the `output` returned by the `forward` method. Note that `output` can be a scalar, a vector or a matrix, whereas `loss` must be a scalar. If `output` is a matrix, then the `grad_output` in the parameter list of the `backward` method, i.e., $\frac{\partial loss}{\partial output}$, is also a matrix of the same shape.

In [1]:
import torch
from torch.autograd import Function, gradcheck

In [2]:
class LinearFunction(Function):
    """Custom autograd function computing ``input.mm(weight)`` plus an optional bias.

    Both ``forward`` and ``backward`` are static methods; the function is
    invoked through ``LinearFunction.apply``. Intermediate values are printed
    so that the data flowing through the forward and backward passes can be
    inspected.
    """

    @staticmethod
    def forward(ctx, input, weight, bias=None):
        """Compute the linear map and stash the operands for ``backward``.

        Args:
            ctx: Context object used to pass tensors to ``backward``.
            input: 2-D tensor, multiplied from the left (``input.mm(weight)``).
            weight: 2-D tensor, multiplied from the right.
            bias: Optional 1-D tensor added to every row of the product.

        Returns:
            The tensor ``input.mm(weight)``, with ``bias`` added if given.
        """
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight)
        if bias is not None:
            # Broadcast the bias across the rows of the product.
            output += bias.unsqueeze(0).expand_as(output)
        print('FORWARD: output=', output)
        return output

    # This function has only a single output, so it gets only one gradient.
    @staticmethod
    def backward(ctx, grad_output):
        """Propagate ``grad_output`` back to ``input``, ``weight`` and ``bias``.

        Args:
            ctx: Context object holding the tensors saved in ``forward``.
            grad_output: Gradient of the loss w.r.t. the output of ``forward``.

        Returns:
            A tuple ``(grad_input, grad_weight, grad_bias)``; an entry is
            ``None`` when the corresponding input does not need a gradient.
        """
        # The misspelled "gard_" labels are kept as-is so that the printed
        # text matches the recorded cell output.
        print('BACKWARD: gard_output=', grad_output)
        input, weight, bias = ctx.saved_tensors
        # Start every gradient at None; only the ones that are actually
        # required are computed below.
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight.t())
        if ctx.needs_input_grad[1]:
            grad_weight = input.t().mm(grad_output)
        if bias is not None and ctx.needs_input_grad[2]:
            # sum(0) already yields a 1-D tensor with one entry per output
            # column, i.e. the shape of ``bias``. A further squeeze(0) would
            # collapse a single-column gradient of shape (1,) to a 0-d tensor,
            # which autograd rejects as an invalid gradient shape.
            grad_bias = grad_output.sum(0)
        print('BACKWARD: gard_input=', grad_input)
        print('BACKWARD: gard_weight=', grad_weight)
        print('BACKWARD: gard_bias=', grad_bias)
        return grad_input, grad_weight, grad_bias


linear = LinearFunction.apply

`linear(a, b)` is the `output` returned by the `forward` method of `LinearFunction`. Since the loss is simply the sum of all elements of `output`, the `grad_output` passed to the `backward` method equals
$$
\text{grad_output} = 
 \begin{pmatrix}
  1 & 1 & 1 \\
  1 & 1 & 1 \\
 \end{pmatrix}$$
 
 
If $y = \sum_{i,j} (A \times B)_{ij}$, i.e. $y$ is the sum of all elements of the matrix product $A \times B$, then 
1. $\frac{\partial y}{\partial A} = \frac{\partial y}{\partial u} \times \frac{\partial u}{\partial A} = \frac{\partial y}{\partial u} \times B^T, \text{where } u = A \times B$

2. $\frac{\partial y}{\partial B} = \frac{\partial y}{\partial u} \times \frac{\partial u}{\partial B} = A^T \times \frac{\partial y}{\partial u}, \text{where } u = A \times B$

In [3]:
# Two small float64 leaf tensors with gradient tracking enabled:
# ``a`` is 2x2 holding 0..3 and ``b`` is 2x3 holding 0..5.
a = torch.arange(4).double().reshape(2, 2).requires_grad_()
b = torch.arange(6).double().reshape(2, 3).requires_grad_()

print(a)
print(b)

# Forward pass through the custom function, reduce the result to a scalar
# loss by summing every element, then run the backward pass.
output = linear(a, b)
print('output=', output)
loss = output.sum()
loss.backward()

tensor([[0., 1.],
        [2., 3.]], dtype=torch.float64, requires_grad=True)
tensor([[0., 1., 2.],
        [3., 4., 5.]], dtype=torch.float64, requires_grad=True)
FORWARD: output= tensor([[ 3.,  4.,  5.],
        [ 9., 14., 19.]], dtype=torch.float64)
output= tensor([[ 3.,  4.,  5.],
        [ 9., 14., 19.]],
       dtype=torch.float64, grad_fn=<LinearFunctionBackward>)
BACKWARD: gard_output= tensor([[1., 1., 1.],
        [1., 1., 1.]], dtype=torch.float64)
BACKWARD: gard_input= tensor([[ 3., 12.],
        [ 3., 12.]], dtype=torch.float64)
BACKWARD: gard_weight= tensor([[2., 2., 2.],
        [4., 4., 4.]], dtype=torch.float64)
BACKWARD: gard_bias= None


## gradcheck

`gradcheck` checks gradients computed via small finite differences against analytical gradients w.r.t. the tensors in `inputs` that are of floating point type and have `requires_grad=True`.

If `backward` is consistent with `forward`, then `gradcheck` returns `True`; otherwise it raises an error.

Case 1: The correct way (We use a correct way to compute `grad_weight`)

In [4]:
class LinearFunction(Function):
    """Custom autograd function computing ``input.mm(weight)`` plus an optional bias.

    This variant has a ``backward`` that is consistent with ``forward``.
    """

    @staticmethod
    def forward(ctx, input, weight, bias=None):
        """Compute the linear map and stash the operands for ``backward``.

        Args:
            ctx: Context object used to pass tensors to ``backward``.
            input: 2-D tensor, multiplied from the left (``input.mm(weight)``).
            weight: 2-D tensor, multiplied from the right.
            bias: Optional 1-D tensor added to every row of the product.

        Returns:
            The tensor ``input.mm(weight)``, with ``bias`` added if given.
        """
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight)
        if bias is not None:
            # Broadcast the bias across the rows of the product.
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Propagate ``grad_output`` back to ``input``, ``weight`` and ``bias``.

        Args:
            ctx: Context object holding the tensors saved in ``forward``.
            grad_output: Gradient of the loss w.r.t. the output of ``forward``.

        Returns:
            A tuple ``(grad_input, grad_weight, grad_bias)``; an entry is
            ``None`` when the corresponding input does not need a gradient.
        """
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight.t())
        if ctx.needs_input_grad[1]:
            grad_weight = input.t().mm(grad_output)
        if bias is not None and ctx.needs_input_grad[2]:
            # sum(0) already has the 1-D shape of ``bias``. An extra
            # squeeze(0) would turn a single-column gradient of shape (1,)
            # into a 0-d tensor, which autograd rejects.
            grad_bias = grad_output.sum(0)
        return grad_input, grad_weight, grad_bias


linear = LinearFunction.apply

# Compare the analytical gradients of ``linear`` with finite-difference
# estimates at the point (a, b) and print the verdict.
test = gradcheck(
    linear,
    (a, b),
    eps=1e-6,
    atol=1e-4,
)
print(test)

True


Case 2: What if the backward has problems? (We use a wrong way to compute `grad_weight`)

In [5]:
class LinearFunction(Function):
    """Linear function whose ``backward`` is DELIBERATELY wrong for ``weight``.

    This variant exists to show what ``gradcheck`` reports when ``backward``
    does not match ``forward``: ``grad_weight`` omits the transpose of
    ``input``. Do not "fix" that line; it is the point of the example.
    """

    @staticmethod
    def forward(ctx, input, weight, bias=None):
        """Compute the linear map and stash the operands for ``backward``.

        Args:
            ctx: Context object used to pass tensors to ``backward``.
            input: 2-D tensor, multiplied from the left (``input.mm(weight)``).
            weight: 2-D tensor, multiplied from the right.
            bias: Optional 1-D tensor added to every row of the product.

        Returns:
            The tensor ``input.mm(weight)``, with ``bias`` added if given.
        """
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight)
        if bias is not None:
            # Broadcast the bias across the rows of the product.
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``input``, ``weight`` and ``bias``.

        Args:
            ctx: Context object holding the tensors saved in ``forward``.
            grad_output: Gradient of the loss w.r.t. the output of ``forward``.

        Returns:
            A tuple ``(grad_input, grad_weight, grad_bias)``; an entry is
            ``None`` when the corresponding input does not need a gradient.
            ``grad_weight`` is intentionally incorrect (see class docstring).
        """
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight.t())
        if ctx.needs_input_grad[1]:
            # INTENTIONALLY WRONG: the correct formula is
            # input.t().mm(grad_output). It only runs without a shape error
            # because the demo uses a square ``input``.
            grad_weight = input.mm(grad_output)
        if bias is not None and ctx.needs_input_grad[2]:
            # sum(0) already has the 1-D shape of ``bias``. An extra
            # squeeze(0) would turn a single-column gradient of shape (1,)
            # into a 0-d tensor, which autograd rejects.
            grad_bias = grad_output.sum(0)
        return grad_input, grad_weight, grad_bias


linear = LinearFunction.apply

# Run the same finite-difference check against the deliberately broken
# backward; this call is expected to fail with a Jacobian-mismatch error
# instead of reaching the print.
test = gradcheck(
    linear,
    (a, b),
    eps=1e-6,
    atol=1e-4,
)
print(test)

RuntimeError: Jacobian mismatch for output 0 with respect to input 1,
numerical:tensor([[0.0000, 0.0000, 0.0000, 2.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 2.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 2.0000],
        [1.0000, 0.0000, 0.0000, 3.0000, 0.0000, 0.0000],
        [0.0000, 1.0000, 0.0000, 0.0000, 3.0000, 0.0000],
        [0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 3.0000]],
       dtype=torch.float64)
analytical:tensor([[0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 1.],
        [2., 0., 0., 3., 0., 0.],
        [0., 2., 0., 0., 3., 0.],
        [0., 0., 2., 0., 0., 3.]], dtype=torch.float64)


Example: $y = e^x$

In [6]:
class Exp(Function):
    """Custom autograd function for the element-wise exponential ``y = exp(x)``."""

    @staticmethod
    def forward(ctx, input):
        """Return ``input.exp()`` and save it, since d/dx exp(x) = exp(x)."""
        exp_value = input.exp()
        ctx.save_for_backward(exp_value)
        return exp_value

    @staticmethod
    def backward(ctx, grad_output):
        """Scale the incoming gradient by the saved forward result."""
        (exp_value,) = ctx.saved_tensors
        return grad_output * exp_value


exp = Exp.apply

# Scalar (0-d) float64 leaf tensor holding the value 2, with gradient tracking.
x = torch.tensor(2.0, dtype=torch.float64, requires_grad=True)

# First verify the custom backward numerically ...
test = gradcheck(
    exp,
    (x,),
    eps=1e-6,
    atol=1e-4,
)
print(test)

# ... then evaluate the function and back-propagate to obtain dy/dx.
output = exp(x)
print('output =', output)
output.backward()
print(x.grad)

True
output = tensor(7.3891, dtype=torch.float64, grad_fn=<ExpBackward>)
tensor(7.3891, dtype=torch.float64)
