# Tensor Differential Calculus Example: Matrix Inverse
By Eric Wong

This notebook demonstrates how to derive and implement the derivatives of a matrix-valued function with matrix arguments. It then implements the result as a differentiable function for backpropagation in PyTorch.

## Derivation

### Differential of an inverse
First, we derive the well-known differential rule for matrix inverses (eq. 40 in the Matrix Cookbook).

We start with the identity $X^{-1}X = I$. Taking the differential of both sides and applying the product rule, we have

$$0 = \mathrm d(X^{-1})X + X^{-1}\mathrm dX $$

Multiplying on the right by $X^{-1}$ and rearranging gets the desired equation: 

$$\mathrm d(X^{-1}) = -X^{-1}(\mathrm dX) X^{-1}$$
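
As a quick sanity check, the rule can be compared against the actual change in the inverse under a small perturbation. This is only a sketch: it uses NumPy rather than PyTorch for brevity, and the matrix size, random seed, and perturbation scale are arbitrary choices made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 4
# Symmetric positive definite test matrix (same construction as the gradcheck later on),
# so the inverse is well behaved
M = rng.standard_normal((n, n))
X = M @ M.T + np.eye(n)
dX = 1e-6 * rng.standard_normal((n, n))   # small perturbation

X_inv = np.linalg.inv(X)
actual = np.linalg.inv(X + dX) - X_inv    # true change in the inverse
predicted = -X_inv @ dX @ X_inv           # first-order prediction from the rule

# The mismatch is second order in dX, so it should be tiny relative to the change itself
rel_err = np.linalg.norm(actual - predicted) / np.linalg.norm(actual)
print(rel_err)
```

The relative error shrinks roughly in proportion to the size of `dX`, which is what one would expect from a first-order approximation.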

### Tensor-Jacobian form 

To get this into Jacobian form, we need to rearrange it into $\mathrm d(X^{-1}) = A \cdot_T \mathrm dX$ for some tensor $A$. Recall from the <a href="https://github.com/riceric22/riceric22.github.io/blob/master/notebooks/tensors.ipynb">tensor notebook</a> that a triple matrix product can be rewritten as a tensor dot product, using the property that $ABC = A\otimes C \cdot_T^{2,3} B$ for matrices $A,B,C$: 

$$\mathrm d(X^{-1}) = -X^{-1}\otimes X^{-1} \cdot_T^{2,3} \mathrm dX$$
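
The property $ABC = A\otimes C \cdot_T^{2,3} B$ can be checked numerically. The sketch below assumes that $A \otimes C$ denotes the 4-dimensional outer product with entries $A_{ij}C_{kl}$, and that $\cdot_T^{2,3}$ contracts its 2nd and 3rd axes against the two axes of $B$; these conventions are my reading of the notation, and the matrix shapes are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((2, 3))
B = rng.standard_normal((3, 4))
C = rng.standard_normal((4, 5))

# Outer product: AC[i, j, k, l] = A[i, j] * C[k, l]
AC = np.multiply.outer(A, C)

# Contract axes 2 and 3 of the outer product (0-based: 1 and 2) with both axes of B
lhs = np.tensordot(AC, B, axes=([1, 2], [0, 1]))

print(np.allclose(lhs, A @ B @ C))
```

Substituting $A = -X^{-1}$, $B = \mathrm dX$, and $C = X^{-1}$ recovers the equation above.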

Lastly, to get this into canonical form (without needing to specify the indices), we can rotate the 4th axis to the 2nd position. I denote this as $\textrm{roll}_{4\rightarrow 2}$. 

$$\mathrm d(X^{-1}) = \textrm{roll}_{4\rightarrow2}(-X^{-1}\otimes X^{-1}) \cdot_T \mathrm dX$$

That's it! We now have the Jacobian of the matrix inverse function: $J_{inv} = \textrm{roll}_{4\rightarrow2}(-X^{-1}\otimes X^{-1})$. This means that $J_{inv}[i,j][k,l] = \frac{\partial X^{-1}_{i,j}}{\partial X_{k,l}}$, which contains all the partial derivatives while retaining the original structure of the function's inputs and outputs. 
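
The explicit Jacobian can be built and compared against finite differences. This is a sketch under a few assumptions: $\textrm{roll}_{4\rightarrow 2}$ is implemented with `np.moveaxis`, the test matrix is an arbitrary non-symmetric but safely invertible one, and the step size `eps` is a hand-picked value.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 3
M = rng.standard_normal((n, n))
K = rng.standard_normal((n, n))
# Positive definite part plus a skew-symmetric part: invertible and deliberately
# not symmetric, so that index mix-ups would show up in the comparison
X = M @ M.T + np.eye(n) + (K - K.T)
X_inv = np.linalg.inv(X)

# T[i, j, k, l] = -X_inv[i, j] * X_inv[k, l]
T = -np.multiply.outer(X_inv, X_inv)
# roll_{4->2}: move the 4th axis to the 2nd position
J_inv = np.moveaxis(T, 3, 1)

# Finite-difference Jacobian: perturb one entry X[k, l] at a time
eps = 1e-6
J_fd = np.zeros((n, n, n, n))
for k in range(n):
    for l in range(n):
        E = np.zeros((n, n))
        E[k, l] = eps
        J_fd[:, :, k, l] = (np.linalg.inv(X + E) - np.linalg.inv(X - E)) / (2 * eps)

print(np.allclose(J_inv, J_fd, atol=1e-6))
```

With the axes in this order, `np.tensordot(J_inv, dX, axes=2)` contracts the last two axes of the Jacobian with `dX`, which is the index-free form $J_{inv} \cdot_T \mathrm dX$.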

### Chain rule in backpropagation

While it is nice to be able to compute the Jacobian explicitly (which helps a lot in debugging gradient calculations), it is more often the case that this is just one step in unrolling an expression via the chain rule. Let's consider the case of PyTorch and suppose we have a scalar valued function $f(X^{-1})$. For notational convenience, let $g(X) = X^{-1}$, so we have $f(g(X))$. 

The usual backpropagation scenario is that
* We want to compute $\frac{\partial f(g(X))}{\partial X}$. In PyTorch terms, this is what we want to return in the backward pass, i.e. `grad_input`. 
* We have $\frac{\partial f(g(X))}{\partial g(X)} = J_g$. In PyTorch terms, this is the argument `grad_output` in the backward pass. 

We start by converting to the nicer differential form. 

1. From the latter bullet point, we can write this in differential form as 
$$\mathrm df = J_g \cdot_T \mathrm dg$$
2. Using the derivation from earlier, we have
$$\mathrm dg = -X^{-1}\otimes X^{-1} \cdot_T^{2,3} \mathrm dX$$
3. Now substitute!
$$\mathrm df = J_g \cdot_T (-X^{-1}\otimes X^{-1} \cdot_T^{2,3} \mathrm dX)$$
4. All that's left is to rearrange. Recall the *associativity* of the tensordot operation, and note that the first tensordot product acts on axes (1,4) of $(-X^{-1}\otimes X^{-1})$. 
$$\mathrm df = (J_g \cdot_T^{1,4} (-X^{-1}\otimes X^{-1})) \cdot_T \mathrm dX$$
5. Next, use the *transpose invariance* and *commutative property*
$$\mathrm df = (-(X^{-1})^T\otimes (X^{-1})^T \cdot_T^{2,3} J_g) \cdot_T \mathrm dX$$
6. Lastly, convert from tensordot back to standard matrix operations. 
$$\mathrm df = (-(X^{-1})^T J_g (X^{-1})^T) \cdot_T \mathrm dX$$

We are done! We now have a nice matrix form for the Jacobian of $f$ with respect to $X$, given $J_g$: 
$$J_X = -(X^{-1})^T J_g (X^{-1})^T$$
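
The matrix form can be checked against a direct contraction of $J_g$ with the explicit Jacobian of the inverse. In this sketch, the scalar function $f(Y) = \sum_{ij} W_{ij} Y_{ij}$ and the weight matrix `W` are hypothetical choices made only so that $J_g = W$ is known in closed form; NumPy is used for brevity.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 3
M = rng.standard_normal((n, n))
K = rng.standard_normal((n, n))
# Invertible, deliberately non-symmetric test matrix so the transposes matter
X = M @ M.T + np.eye(n) + (K - K.T)
X_inv = np.linalg.inv(X)

# Hypothetical scalar function f(Y) = sum(W * Y), whose gradient w.r.t. Y is J_g = W
W = rng.standard_normal((n, n))
J_g = W

# Matrix form from the derivation
J_X = -X_inv.T @ J_g @ X_inv.T

# Same quantity via the explicit Jacobian: contract J_g with the two output axes of J_inv
J_inv = np.moveaxis(-np.multiply.outer(X_inv, X_inv), 3, 1)
J_X_tensor = np.tensordot(J_g, J_inv, axes=2)

print(np.allclose(J_X, J_X_tensor))
```

The matrix form needs only two $n \times n$ matrix products, whereas the explicit Jacobian has $n^4$ entries, which is why the backward pass uses the former.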


## PyTorch Extension

To implement this in PyTorch, we simply compute this Jacobian and return it in the backward pass as follows: 

In [3]:
import torch
from torch.autograd import Function, gradcheck

_dim_err = 'Input to inverse has {} dimension(s), expected 2'
_square_err = 'Input to inverse is not square, got dimensions {}'

class Inverse(Function): 
    """ Implementation of the Inverse Function """
    @staticmethod
    def forward(ctx, input): 
        """ In the forward pass, check for a square matrix, and save the resulting inverse 
        for the backward pass. """
        assert input.dim() == 2, _dim_err.format(input.dim())
        assert input.size(0) == input.size(1), _square_err.format(tuple(input.size()))
        inv = torch.inverse(input)
        ctx.save_for_backward(inv)
        return inv
    
    @staticmethod
    def backward(ctx, grad_output): 
        """ Apply the chain rule in matrix form: -(X^{-1})^T J_g (X^{-1})^T. """
        inv, = ctx.saved_tensors
        inv_T = torch.transpose(inv, 0, 1)
        grad_input = -torch.mm(torch.mm(inv_T, grad_output), inv_T)
        return grad_input

# Symmetric positive definite test matrix, so the inverse is well conditioned
T = torch.randn(20,20)
T = torch.mm(T, torch.transpose(T, 0, 1))
T += torch.eye(20)
# gradcheck needs double precision inputs that require gradients
input = (T.double().requires_grad_(),)
test = gradcheck(Inverse.apply, input, eps=1e-6, atol=1e-4) 
print(test)

True
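
As an additional check beyond `gradcheck`, the hand-derived formula can be compared with the gradient that PyTorch's built-in autograd produces for `torch.inverse`. This is a sketch only: the downstream function $f(Y) = \sum_{ij} W_{ij} Y_{ij}$, the weight matrix `W`, the matrix size, and the seed are hypothetical choices for illustration.

```python
import torch

torch.manual_seed(0)
M = torch.randn(5, 5, dtype=torch.double)
# Symmetric positive definite input, marked as requiring gradients
X = (M @ M.t() + torch.eye(5, dtype=torch.double)).requires_grad_()

# Hypothetical downstream scalar function f(Y) = sum(W * Y)
W = torch.randn(5, 5, dtype=torch.double)
Y = torch.inverse(X)
f = (W * Y).sum()
f.backward()                  # built-in autograd gradient is stored in X.grad

# Hand-derived formula, with grad_output = df/dY = W
inv_T = Y.detach().t()
manual = -inv_T @ W @ inv_T

print(torch.allclose(X.grad, manual))
```

Here `W` plays the role of `grad_output`, and `manual` is what the `backward` method above would return for it.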
