In [1]:
using Symbolics
using LinearAlgebra
using ForwardDiff

# Matrix Gradients

The goal of this notebook is to investigate what precisely a _matrix gradient_ is and how it differs from the concept of the _gradient with respect to a matrix_. To be precise, I will investigate the derivative of a matrix function with respect to a matrix input.
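As a sketch of the distinction, using the standard first-order definitions (the notation $\operatorname{vec}$, which stacks the columns of a matrix, and $J_{f_m}$ is my choice): for a scalar function $f_s$ the gradient is a matrix of the same shape as $A$ that pairs with $dA$ through the trace (Frobenius) inner product, while for a matrix function $f_m$ the derivative is a linear map from $dA$ to $df_m$, which becomes an $n^2 \times n^2$ Jacobian once both sides are flattened (for the 2×2 examples here, $n = 2$):

$$
\begin{align}
df_s &= \operatorname{tr}\!\left( (\nabla_A f_s)^T dA \right), & \nabla_A f_s &\in \mathbb{R}^{n \times n} \\
\operatorname{vec}(df_m) &= J_{f_m}(A)\, \operatorname{vec}(dA), & J_{f_m}(A) &\in \mathbb{R}^{n^2 \times n^2}
\end{align}
$$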

For our examples, I will use a 2×2 matrix:
$$ A = \begin{bmatrix} a_{11} & a_{12} \\ a_{21} & a_{22} \end{bmatrix}$$

In [2]:
@variables a_11 a_12 a_21 a_22;
A = [a_11 a_12; a_21 a_22]

2×2 Matrix{Num}:
 a_11  a_12
 a_21  a_22

Let's begin with the gradient of a scalar function of $A$ with respect to $A$. Then I will continue with the gradient of a matrix function of $A$ with respect to $A$.

## 1. Scalar Function
As an example of a scalar function $f_s(A)$, I use the [Frobenius norm](https://mathworld.wolfram.com/FrobeniusNorm.html), $\|A\|_F = \sqrt{\sum_{i,j} a_{ij}^2}$, which maps a matrix $A$ to a scalar.

In [3]:
f_s(X) = sqrt(sum(x^2 for x in X));  # Frobenius norm
f_s(A)

sqrt(a_11^2 + a_12^2 + a_21^2 + a_22^2)

The gradient of $f_s(A)$ with respect to $A$, written $\nabla_A f_s(A)$, is a 2×2 matrix of the same shape as $A$:

In [4]:
ForwardDiff.gradient(f_s, A)

2×2 Matrix{Num}:
 a_11 / sqrt(a_11^2 + a_12^2 + a_21^2 + a_22^2)  …  a_12 / sqrt(a_11^2 + a_12^2 + a_21^2 + a_22^2)
 a_21 / sqrt(a_11^2 + a_12^2 + a_21^2 + a_22^2)     a_22 / sqrt(a_11^2 + a_12^2 + a_21^2 + a_22^2)
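Entry by entry, the symbolic result is $A / \|A\|_F$. A numeric sketch of the same check, with an arbitrary test matrix and perturbation (my choices), compares ForwardDiff against that closed form and confirms the first-order relation $f_s(Y + dY) - f_s(Y) \approx \operatorname{tr}\!\left( (\nabla f_s)^T dY \right)$, which `LinearAlgebra.dot` computes for two real matrices:

```julia
using LinearAlgebra, ForwardDiff

f_s(X) = sqrt(sum(abs2, X))   # Frobenius norm, same value as norm(X)

Y  = [1.0 2.0; 3.0 4.0]            # arbitrary test matrix
dY = 1e-6 * [0.3 -0.1; 0.2 0.5]    # small arbitrary perturbation

g = ForwardDiff.gradient(f_s, Y)   # 2×2, same shape as Y

@assert size(g) == size(Y)
@assert isapprox(g, Y / norm(Y); rtol = 1e-12)                    # closed form A / ‖A‖_F
@assert isapprox(f_s(Y + dY) - f_s(Y), dot(g, dY); rtol = 1e-4)   # df ≈ tr(gᵀ dY)
```

The gradient pairs with a perturbation of the same shape through a single inner product, which is what makes it "the gradient with respect to a matrix".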

## 2. Matrix Function
As an example of a matrix function $f_m(A)$, let's use the matrix cube $A^3 = A * A * A$ (a matrix power, not the matrix exponential $e^A$), which maps a 2×2 matrix to another 2×2 matrix.

In [5]:
f_m(X) = X * X * X
f_m(A)

2×2 Matrix{Num}:
 a_11*(a_11^2 + a_12*a_21) + (a_11*a_12 + a_12*a_22)*a_21  …  (a_11^2 + a_12*a_21)*a_12 + (a_11*a_12 + a_12*a_22)*a_22
 a_11*(a_11*a_21 + a_21*a_22) + (a_12*a_21 + a_22^2)*a_21     (a_11*a_21 + a_21*a_22)*a_12 + (a_12*a_21 + a_22^2)*a_22

Let's again ask for the gradient $\nabla_A f_m(A)$ of the matrix function $f_m$ with respect to its input matrix $A$, where
$$ 
f_m(A) = A * A * A
$$

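Applying the product rule to each of the three factors (keeping their order, since matrix multiplication does not commute), the differential is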
$$
\begin{align}
df_m &=  dA A^2 + A dA A + A^2 dA
\end{align}
$$
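Flattening the differential with $\operatorname{vec}$ (column-major, matching Julia's `vec`) and using the standard identity $\operatorname{vec}(B\, X\, C) = (C^T \otimes B) \operatorname{vec}(X)$ turns each of the three terms into a Kronecker product, so the $4 \times 4$ Jacobian of $f_m$ can be written down explicitly:

$$
\begin{align}
\operatorname{vec}(df_m) &= \left( (A^2)^T \otimes I + A^T \otimes A + I \otimes A^2 \right) \operatorname{vec}(dA)
\end{align}
$$

This matrix acts on $\operatorname{vec}(dA)$; it is not a single matrix of the same shape as $A$.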

In [6]:
df_m(x, dx) = dx*x^2 + x*dx*x + x^2*dx;
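A quick numeric sketch compares ForwardDiff's Jacobian of `f_m` with the Kronecker-product matrix obtained from `df_m` via $\operatorname{vec}(B\, dX\, C) = (C^T \otimes B) \operatorname{vec}(dX)$. The helper name `jacobian_cube` and the test matrix are hypothetical choices for illustration; the sketch assumes `ForwardDiff.jacobian` flattens input and output in column-major order, as Julia's `vec` does:

```julia
using LinearAlgebra, ForwardDiff

f_m(X) = X * X * X

# Hypothetical helper: Jacobian of X ↦ X^3 built from
# df_m = dX*X^2 + X*dX*X + X^2*dX and vec(B*dX*C) = kron(transpose(C), B) * vec(dX)
function jacobian_cube(X)
    n = size(X, 1)
    Id = Matrix{eltype(X)}(I, n, n)
    X2 = X * X
    return kron(permutedims(X2), Id) + kron(permutedims(X), X) + kron(Id, X2)
end

X = [1.0 2.0; 3.0 4.0]                 # arbitrary test matrix
J_ad = ForwardDiff.jacobian(f_m, X)    # rows: vec(f_m(X)), columns: vec(X)
J_kron = jacobian_cube(X)

@assert size(J_ad) == (4, 4)
@assert isapprox(J_ad, J_kron; rtol = 1e-12)
```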

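The attempt below does not use $dx$ at all: it multiplies the $4 \times 4$ Jacobian of $f_m$ (taken with respect to $\operatorname{vec}(A)$) by $\operatorname{vec}(A)$: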

In [7]:
grad_A = ForwardDiff.jacobian(f_m, A) * vec(A)  # ForwardDiff.gradient(f_m, A) throws: gradient expects a scalar-valued function

4-element Vector{Num}:
          a_11*(3(a_11^2) + 2a_12*a_21) + (2a_11*a_12 + a_12*a_22)*a_21 + ((a_11 + a_22)*a_21 + a_11*a_21)*a_12 + a_12*a_21*a_22
 (a_11*(a_11 + a_22) + 2a_12*a_21 + a_22^2)*a_21 + a_11*(2a_11*a_21 + a_21*a_22) + (a_11*a_21 + 2a_21*a_22)*a_22 + a_12*(a_21^2)
 (a_11^2 + (a_11 + a_22)*a_22 + 2a_12*a_21)*a_12 + a_11*(2a_11*a_12 + a_12*a_22) + (a_11*a_12 + 2a_12*a_22)*a_22 + (a_12^2)*a_21
          a_11*a_12*a_21 + ((a_11 + a_22)*a_12 + a_12*a_22)*a_21 + (a_11*a_21 + 2a_21*a_22)*a_12 + (2a_12*a_21 + 3(a_22^2))*a_22
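One way to read this output: multiplying the Jacobian by $\operatorname{vec}(A)$ evaluates the derivative of $f_m$ in the direction $dA = A$. Since $f_m$ is homogeneous of degree 3, $(tA)^3 = t^3 A^3$, and differentiating in $t$ at $t = 1$ gives $J \operatorname{vec}(A) = 3 \operatorname{vec}(A^3)$. A short numeric sketch with an arbitrary test matrix:

```julia
using LinearAlgebra, ForwardDiff

f_m(X) = X * X * X

X = [1.0 2.0; 3.0 4.0]                 # arbitrary test matrix
J = ForwardDiff.jacobian(f_m, X)       # 4×4 Jacobian with respect to vec(X)

# Directional derivative in the direction dX = X.
# Since (t*X)^3 = t^3 * X^3, d/dt at t = 1 gives 3 * X^3.
@assert isapprox(J * vec(X), 3 * vec(X^3); rtol = 1e-12)
```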

### Checking with finite differences
A finite difference gives a reference approximation for checking the derivative implementations: for a small perturbation $\delta X$,
$$ f_m(X + \delta X) - f_m(X) \approx df_m = \delta X\, X^2 + X\, \delta X\, X + X^2\, \delta X $$

In [8]:
Y = randn(2, 2)

2×2 Matrix{Float64}:
 -0.614174   1.60984
  1.73402   -0.302657

In [9]:
δY = randn(2, 2) * 1e-10

2×2 Matrix{Float64}:
 -1.8394e-10    1.09165e-11
 -4.72675e-12  -3.25168e-11

In [10]:
finite_difference = f_m(Y + δY) - f_m(Y)

2×2 Matrix{Float64}:
 -1.34319e-9    5.73031e-10
  5.60424e-10  -7.17747e-10

In [11]:
df_m(Y, δY)

2×2 Matrix{Float64}:
 -1.34318e-9    5.73031e-10
  5.60423e-10  -7.17747e-10
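Rather than comparing printed digits by eye, the agreement can be summarized as a relative error in the Frobenius norm. A sketch, with a fixed test matrix and perturbation direction (my choices, replacing `randn` for reproducibility), evaluates that error at several step sizes; for a correct first-order differential it should shrink roughly in proportion to the step until rounding error takes over:

```julia
using LinearAlgebra

f_m(X) = X * X * X
df_m(X, dX) = dX * X^2 + X * dX * X + X^2 * dX

Y = [1.0 2.0; 3.0 4.0]     # fixed test values instead of randn
D = [0.3 -0.1; 0.2 0.5]    # fixed perturbation direction

# Relative error between the finite difference and the linearization df_m
rel_err(h) = norm((f_m(Y + h * D) - f_m(Y)) - df_m(Y, h * D)) / norm(df_m(Y, h * D))

errs = [rel_err(h) for h in (1e-2, 1e-4, 1e-6)]
@assert errs[1] > errs[2] > errs[3]   # error shrinks as the step shrinks
@assert errs[3] < 1e-4
```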

Evaluated at $Y$, my solution below does not yet yield a result equal to the finite-difference result:

In [12]:
substitute(grad_A, Dict(a_11 => Y[1, 1], a_12 => Y[1, 2], a_21 => Y[2, 1], a_22 => Y[2, 2]))

4-element Vector{Num}:
 -13.51638761045899
  17.92726613749402
  16.643410597169613
 -10.2957502450735
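To compare like with like, the Jacobian has to act on the perturbation $\operatorname{vec}(\delta Y)$ rather than on $\operatorname{vec}(Y)$: the product $J \operatorname{vec}(\delta Y)$, reshaped back to 2×2, is the first-order change $df_m$. A sketch with fixed test values (my choices, in place of `randn`):

```julia
using LinearAlgebra, ForwardDiff

f_m(X) = X * X * X
df_m(X, dX) = dX * X^2 + X * dX * X + X^2 * dX

Y  = [1.0 2.0; 3.0 4.0]             # fixed test values instead of randn
δY = 1e-6 * [0.3 -0.1; 0.2 0.5]     # small fixed perturbation

J = ForwardDiff.jacobian(f_m, Y)         # 4×4 Jacobian with respect to vec(Y)
jvp = reshape(J * vec(δY), size(Y))      # J * vec(δY), reshaped back to 2×2

@assert isapprox(jvp, df_m(Y, δY); rtol = 1e-10)
@assert isapprox(jvp, f_m(Y + δY) - f_m(Y); rtol = 1e-4)
```

In this reading, the object that generalizes the gradient to a matrix function is the whole Jacobian (equivalently, the linear map `dX -> df_m(Y, dX)`), not a single 2×2 matrix.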