In [1]:
using Symbolics
using LinearAlgebra
using ForwardDiff

# Matrix Gradients

The goal of this notebook is to investigate what precisely a _matrix gradient_ is, because it differs from the concept of the _gradient with respect to a matrix_. To be precise, I will investigate the derivative of a matrix function with respect to a matrix input.

For our examples, I will use a 2x2 matrix:
$$ A = \begin{bmatrix} a_{11} & a_{12} \\ a_{21} & a_{22} \end{bmatrix}$$

In [2]:
@variables a_11 a_12 a_21 a_22;
A = [a_11 a_12; a_21 a_22]

2×2 Matrix{Num}:
 a_11  a_12
 a_21  a_22

Let's begin with the gradient of a scalar function of $A$ with respect to $A$. Then I will continue with the gradient of a matrix function of $A$ with respect to $A$.

## 1. Scalar Function
As an example of a scalar function $f_s(A)$, I use the [Frobenius norm](https://mathworld.wolfram.com/FrobeniusNorm.html), which maps a matrix $A$ to a scalar.

In [3]:
f_s(X) = sqrt(sum(x^2 for x in X));  # Frobenius norm
f_s(A)

sqrt(a_11^2 + a_12^2 + a_21^2 + a_22^2)

The gradient of $f_s(A)$ with respect to $A$, written $\nabla_A f_s(A)$, is:

In [4]:
ForwardDiff.gradient(f_s, A)

2×2 Matrix{Num}:
 a_11 / sqrt(a_11^2 + a_12^2 + a_21^2 + a_22^2)  …  a_12 / sqrt(a_11^2 + a_12^2 + a_21^2 + a_22^2)
 a_21 / sqrt(a_11^2 + a_12^2 + a_21^2 + a_22^2)     a_22 / sqrt(a_11^2 + a_12^2 + a_21^2 + a_22^2)
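
For reference, the entries of this result match the closed form of the Frobenius-norm gradient. A short sketch, differentiating $f_s(A) = \sqrt{\sum_{i,j} a_{ij}^2}$ entry by entry and assuming $A \neq 0$:

$$
\frac{\partial f_s}{\partial a_{ij}} = \frac{a_{ij}}{\sqrt{\sum_{k,l} a_{kl}^2}}
\quad\Rightarrow\quad
\nabla_A f_s(A) = \frac{A}{\|A\|_F}
$$

The gradient of a scalar function therefore has the same shape as $A$, with one partial derivative per entry.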

## 2. Matrix Function
Let's use the matrix power $A^3 = A * A * A$ as an example of a matrix function $f_m(A)$, which maps a 2x2 matrix to another 2x2 matrix.

In [5]:
f_m(X) = X * X * X
f_m(A)

2×2 Matrix{Num}:
 a_11*(a_11^2 + a_12*a_21) + (a_11*a_12 + a_12*a_22)*a_21  …  (a_11^2 + a_12*a_21)*a_12 + (a_11*a_12 + a_12*a_22)*a_22
 a_11*(a_11*a_21 + a_21*a_22) + (a_12*a_21 + a_22^2)*a_21     (a_11*a_21 + a_21*a_22)*a_12 + (a_12*a_21 + a_22^2)*a_22

Let's ask again: what is the gradient of the matrix function $f_m(A)$ with respect to its input matrix $A$, i.e. $\nabla_A f_m(A)$? The function is
$$ 
f_m(A) = A * A * A
$$

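and, by the product rule (keeping the order of the factors, since matrices do not commute in general), the differential is
$$
\begin{align}
df_m &= dA \, A^2 + A \, dA \, A + A^2 \, dA
\end{align}
$$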
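
One way to see where the three terms come from is to expand the perturbed function and group the terms by their order in $dA$. This is only a sketch of the reasoning:

$$
\begin{align}
(A + dA)^3 &= A^3 + \left(dA \, A^2 + A \, dA \, A + A^2 \, dA\right) \\
&\quad + \left(dA^2 \, A + dA \, A \, dA + A \, dA^2\right) + dA^3
\end{align}
$$

The first bracket is linear in $dA$ and is exactly $df_m$. The remaining terms are of second and third order in $dA$, so they vanish faster than $df_m$ as $dA \to 0$.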

In [6]:
df_m(x, dx) = dx*x^2 + x*dx*x + x^2*dx;

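The solution below does not need the closed-form differential: the Jacobian of $\operatorname{vec}(f_m(A))$ with respect to $\operatorname{vec}(A)$ is a 4x4 matrix that does not depend on $dA$, and multiplying it by $\operatorname{vec}(dA)$ gives $\operatorname{vec}(df_m)$: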

In [7]:
# ForwardDiff.gradient(f_m, A) throws an exception because f_m is not scalar-valued, so use the Jacobian instead.
grad_A(dA) = ForwardDiff.jacobian(f_m, A) * vec(dA)

grad_A (generic function with 1 method)
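
The same Jacobian can also be written down by hand. The sketch below is not part of the original notebook: it assumes the standard identity $\operatorname{vec}(P \, dX \, Q) = (Q^T \otimes P) \operatorname{vec}(dX)$ for column-major `vec`, and the name `jacobian_cube` is hypothetical.

```julia
using LinearAlgebra

# Hypothetical helper: Jacobian of X -> X^3 with respect to vec(X),
# built from vec(P * dX * Q) = kron(transpose(Q), P) * vec(dX).
function jacobian_cube(X::AbstractMatrix)
    n = size(X, 1)
    Id = Matrix{eltype(X)}(I, n, n)
    X2 = X * X
    # One Kronecker term per summand of dX*X^2 + X*dX*X + X^2*dX.
    return kron(permutedims(X2), Id) + kron(permutedims(X), X) + kron(Id, X2)
end

X  = [1.0 2.0; 3.0 4.0]
dX = [0.5 -1.0; 2.0 0.25]

# Differential from the product rule, written out as in df_m.
df = dX * X^2 + X * dX * X + X^2 * dX

# Applying the Jacobian to vec(dX) should reproduce vec(df).
J = jacobian_cube(X)
agrees = J * vec(dX) ≈ vec(df)
```

For a 2x2 input, `J` is 4x4: one row per output entry and one column per input entry.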

### Checking with finite differences
A finite difference gives an independent numerical approximation against which the differential can be verified. For a small perturbation $\delta X$:
$$ df \approx f(X + \delta X) - f(X) $$

In [8]:
Y = randn(2, 2)

2×2 Matrix{Float64}:
 0.173128  -0.383209
 0.296065  -0.0614478

In [9]:
δY = randn(2, 2) * 1e-8

2×2 Matrix{Float64}:
 -6.54275e-9  -1.58793e-8
 -1.31257e-9   7.22189e-9

Differential from the finite difference:

In [10]:
finite_difference = f_m(Y + δY) - f_m(Y)

2×2 Matrix{Float64}:
 -1.11879e-9   3.6185e-9
 -1.56869e-9  -1.0255e-9

Differential from the linear function `df_m`:

In [11]:
df_m(Y, δY)

2×2 Matrix{Float64}:
 -1.11879e-9   3.6185e-9
 -1.56869e-9  -1.0255e-9

Differential from the symbolic Jacobian:

In [12]:
substitute(grad_A(δY), Dict(a_11 => Y[1, 1], a_12 => Y[1, 2], a_21 => Y[2, 1], a_22 => Y[2, 2]))

4-element Vector{Num}:
 -1.1187921176613405e-9
 -1.5686865959619862e-9
  3.6184972720749093e-9
 -1.0254980938277533e-9

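All three results agree to the displayed precision. The Jacobian-based result is the same 2x2 matrix, returned as a vector in column-major order.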
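
The agreement is expected because the finite difference and the differential differ only by terms of second and higher order in the perturbation. The following sketch illustrates this scaling; the names `cube`, `dcube`, and `gap`, and the fixed direction `D`, are illustrative and not part of the notebook.

```julia
using LinearAlgebra

# Illustrative stand-ins for f_m and df_m.
cube(X) = X * X * X
dcube(X, dX) = dX * X^2 + X * dX * X + X^2 * dX

X = [1.0 2.0; 3.0 4.0]
D = [0.5 -1.0; 2.0 0.25]   # fixed direction, scaled by h below

# Gap between the finite difference and the differential for step size h.
gap(h) = norm(cube(X + h * D) - cube(X) - dcube(X, h * D))

# Shrinking h by a factor of 10 should shrink the gap by roughly 100,
# which is the signature of a second-order remainder.
ratio = gap(1e-2) / gap(1e-3)
```

With a perturbation of size `1e-8`, as used above, the remainder is far below the digits that are displayed.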

### Question
Do the gradients depend on the concrete $\delta Y$ that we pass?

In [13]:
δ(X) = randn(size(X)) * 1e-4

δ (generic function with 1 method)

In [14]:
Y = [1 2; 3 4]

2×2 Matrix{Int64}:
 1  2
 3  4

In [15]:
df_m(Y, δ(Y))

2×2 Matrix{Float64}:
 -0.00129718  -0.000196377
 -0.00326642  -0.000935364

In [16]:
df_m(Y, δ(Y))

2×2 Matrix{Float64}:
 -0.000551738   4.94507e-5
 -0.0031414    -0.0027555

Yes, $df$ changes depending on $\delta Y$. But perhaps, if we divide back element-wise by $\delta Y$, we could get the same gradient for any $\delta Y$?

In [17]:
d = δ(Y)
df_m(Y, d) ./ d

2×2 Matrix{Float64}:
 -68.4698   24.2113
  24.2964  -32.4979

In [18]:
d = δ(Y)
df_m(Y, d) ./ d

2×2 Matrix{Float64}:
 0.759711   -4.34671
 9.03675   -54.1683

In [19]:
d = δ(Y)
df_m(Y, d) ./ d

2×2 Matrix{Float64}:
 16.9079  -112.07
 47.0862    30.1931
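
A possible explanation for why the element-wise ratio keeps changing, sketched in index notation: each entry of $df_m$ is a sum over all entries of $\delta Y$, not only the matching one.

$$
(df_m)_{ij} = \sum_{k,l} \frac{\partial (f_m)_{ij}}{\partial a_{kl}} \, \delta Y_{kl}
\quad\Rightarrow\quad
\frac{(df_m)_{ij}}{\delta Y_{ij}} = \frac{\partial (f_m)_{ij}}{\partial a_{ij}} + \sum_{(k,l) \neq (i,j)} \frac{\partial (f_m)_{ij}}{\partial a_{kl}} \, \frac{\delta Y_{kl}}{\delta Y_{ij}}
$$

The second term depends on the ratios between the entries of $\delta Y$, so the quotient differs for every random draw. What stays fixed for a given $Y$ is the full set of partial derivatives $\partial (f_m)_{ij} / \partial a_{kl}$, that is, the 4x4 Jacobian.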