# Abstractions for solving optimisation problems

Techniques such as ultrasound computed tomography or optoacoustic tomography are most generally formulated as mathematical optimisation problems, which are solved numerically by using local gradient-based methods like gradient descent. 

We therefore need abstractions that allow us to pose our optimisation problems, calculate gradients
of those problems with respect to the relevant parameters, and then apply these gradients through some local
optimisation algorithm.

In this notebook, we will introduce these abstractions from the point of view of Stride.

## Mathematical basics - Gradient calculation

We will first review some of the mathematical basics behind these abstractions. Feel free to skip to the next section if you are not interested in diving into the math!

Consider a continuously differentiable function $f(\mathbf{y}) = \left\langle \hat{f}(\mathbf{y}), 1 \right\rangle$ with some bilinear form $\left\langle \alpha, \beta \right\rangle$. We know that the directional derivative of $f(\mathbf{y})$ in the direction $\delta\mathbf{y}$ is,

$$
\nabla_\mathbf{y} f(\mathbf{y}) \delta\mathbf{y} 
    = \left\langle \nabla_\mathbf{y} \hat{f}(\mathbf{y}) \delta\mathbf{y}, 1 \right\rangle
    = \left\langle \nabla_\mathbf{y} \hat{f}(\mathbf{y}), \delta\mathbf{y} \right\rangle
$$

Consider now that $\mathbf{y} = \mathbf{g}(\mathbf{z})$ is another continuously differentiable function. Then the derivative of $f(\mathbf{y})$ with respect to $\mathbf{z}$ is,

$$
    \nabla_\mathbf{z} f(\mathbf{y}) \delta\mathbf{z} 
    = \left\langle \nabla_\mathbf{y} \hat{f}(\mathbf{y}), \delta\mathbf{y} \right\rangle
    = \left\langle \nabla_\mathbf{y} \hat{f}(\mathbf{y}), \nabla_\mathbf{z} \mathbf{g}(\mathbf{z}) \delta\mathbf{z} \right\rangle
$$

by virtue of the chain rule. Let's now introduce the concept of the adjoint of an operator: given an operator $D\cdot$, its adjoint is $D^*\cdot$, defined so that $\left\langle a, Db  \right\rangle = \left\langle b, D^*a  \right\rangle$. Then, we can rewrite the expression as,

$$
    \nabla_\mathbf{z} f(\mathbf{y}) \delta\mathbf{z} 
    = \left\langle \nabla_\mathbf{y} \hat{f}(\mathbf{y}), \nabla_\mathbf{z} \mathbf{g}(\mathbf{z}) \delta\mathbf{z} \right\rangle
    = \left\langle \nabla_\mathbf{z}^* \mathbf{g}(\mathbf{z}) \nabla_\mathbf{y} \hat{f}(\mathbf{y}), \delta\mathbf{z} \right\rangle
$$

That is, the derivative of the function $f(\mathbf{y})$ with respect to $\mathbf{z}$ can be calculated by finding the derivative of $\hat{f}(\mathbf{y})$ with respect to its input $\mathbf{y}$ and then applying the adjoint of the Jacobian of $\mathbf{g}(\mathbf{z})$ to the result. In the discrete case, this is equivalent to a vector-Jacobian product, that is, multiplying by the transposed Jacobian.
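As a small numerical illustration of the statement above, the following sketch applies the transposed Jacobian of a hand-written $\mathbf{g}$ to the gradient of a simple $f$, and compares the result against central finite differences. The functions `g`, `f`, and the helper names are arbitrary choices made for this example; they are not part of Stride.

```python
import math

def g(z):
    """Inner function y = g(z), mapping R^2 -> R^2."""
    z0, z1 = z
    return [z0 * z1, math.sin(z0)]

def g_jacobian_adjoint(z, v):
    """Apply the transposed Jacobian of g at z to the vector v."""
    z0, z1 = z
    # Jacobian of g is [[z1, z0], [cos(z0), 0]]; here we multiply by its transpose
    return [z1 * v[0] + math.cos(z0) * v[1], z0 * v[0]]

def f(y):
    """Outer scalar function f(y) = 0.5 * sum(y_i^2)."""
    return 0.5 * sum(yi * yi for yi in y)

def f_grad(y):
    """Gradient of f with respect to its input y."""
    return list(y)

z = [0.7, -1.3]
y = g(z)

# Gradient of f(g(z)) wrt z: adjoint Jacobian of g applied to the gradient of f
grad_z = g_jacobian_adjoint(z, f_grad(y))

# Central finite-difference approximation of the same gradient
eps = 1e-6
fd = []
for i in range(len(z)):
    z_plus = list(z)
    z_minus = list(z)
    z_plus[i] += eps
    z_minus[i] -= eps
    fd.append((f(g(z_plus)) - f(g(z_minus))) / (2 * eps))

print(grad_z, fd)
```

The two printed vectors should agree to within the finite-difference error.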

Similarly, if we added a third function $\mathbf{z} = \mathbf{h}(\mathbf{x})$, then the same result could be obtained for the derivative of $f(\mathbf{y})$ with respect to $\mathbf{x}$,

$$
\begin{aligned}
    \nabla_\mathbf{x} f(\mathbf{y}) \delta\mathbf{x} 
    &= \left\langle \nabla_\mathbf{z}^* \mathbf{g}(\mathbf{z}) \nabla_\mathbf{y} \hat{f}(\mathbf{y}), \delta\mathbf{z} \right\rangle \\
    &= \left\langle \nabla_\mathbf{z}^* \mathbf{g}(\mathbf{z}) \nabla_\mathbf{y} \hat{f}(\mathbf{y}), \nabla_\mathbf{x} \mathbf{h}(\mathbf{x}) \delta\mathbf{x} \right\rangle \\
    &= \left\langle \nabla_\mathbf{x}^* \mathbf{h}(\mathbf{x}) \nabla_\mathbf{z}^* \mathbf{g}(\mathbf{z}) \nabla_\mathbf{y} \hat{f}(\mathbf{y}), \delta\mathbf{x} \right\rangle
\end{aligned}
$$

and the same procedure could be followed for any arbitrary chain of functions for whose inputs we wanted to calculate a derivative. This procedure, known as the adjoint method, or as backpropagation in the field of machine learning, is effectively the reverse mode of automatic differentiation, and it is the core abstraction used in Stride.
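The chain of adjoints can be sketched in a few lines of plain Python. In this hypothetical example, each function in the chain is paired with the adjoint of its derivative, the chain is evaluated forwards while storing the inputs, and the adjoints are then applied in reverse order. The specific functions (`x ** 2`, `sin`, and multiplication by 3) are assumptions made for illustration only.

```python
import math

# Each entry pairs a forward function with the adjoint of its derivative.
# For scalar functions, the adjoint is multiplication by the local derivative.
chain = [
    (lambda x: x ** 2,      lambda x, grad: 2.0 * x * grad),      # z = h(x)
    (lambda z: math.sin(z), lambda z, grad: math.cos(z) * grad),  # y = g(z)
    (lambda y: 3.0 * y,     lambda y, grad: 3.0 * grad),          # w = f(y)
]

def forward_and_adjoint(x):
    """Run the chain forwards, then apply the adjoints in reverse order."""
    inputs = []
    value = x
    for fwd, _ in chain:
        inputs.append(value)  # remember the input of every function
        value = fwd(value)

    grad = 1.0  # derivative of the output with respect to itself
    for (_, adj), inp in zip(reversed(chain), reversed(inputs)):
        grad = adj(inp, grad)
    return value, grad

w, dw_dx = forward_and_adjoint(0.5)

# Analytical derivative of 3 * sin(x^2) for comparison
expected = 3.0 * math.cos(0.5 ** 2) * 2.0 * 0.5
print(w, dw_dx, expected)
```

Note that the forward pass has to store the input of every function, because each adjoint is evaluated at the point where its function was originally applied.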

## Gradient calculation in Stride

Stride considers all components in the optimisation problem, from partial differential equations to objective functions, as mathematical functions that can be arbitrarily composed, and whose derivative can be automatically calculated. In Stride, each of these functions is a ``stride.Operator`` object, where their inputs and outputs are ``stride.Variable`` objects.

Let's see how this works by creating a ``stride.Scalar`` object ``x``, which inherits from ``stride.Variable``, and using Stride to calculate the gradient of some arbitrary functions with respect to ``x``.

In [1]:
from stride import Scalar
from stride_examples import f, g, h

x = Scalar(name="x", needs_grad=True)
z = await h(x)
y = await g(z)
w = await f(y)

w.clear_grad()
await w.adjoint()
# The gradient is now in "x.grad"

x

When each ``stride.Operator`` is called, it is immediately applied on its inputs to generate some outputs. At the same time, these outputs keep a record of the chain of calls that have led to them within a directed acyclic graph. When ``w.adjoint()`` is called, this graph is traversed from the root ``w`` to the leaf ``x``, calculating the gradient in the process. Only the leaves for which the flag ``needs_grad`` is set to ``True`` will have their gradient computed, which will be stored in the internal buffer of the variable ``x.grad``.
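To make the idea of recording a graph and traversing it from root to leaves more concrete, here is a deliberately simplified toy in plain Python. It is not Stride's implementation: the names `Var`, `apply_op`, and `adjoint` are hypothetical, the functions are scalar, and the traversal simply follows every path from the root, which is enough for a small chain but would be inefficient for large graphs.

```python
class Var:
    """Toy variable that remembers which operation produced it."""

    def __init__(self, value, name=None, needs_grad=False):
        self.value = value
        self.name = name
        self.needs_grad = needs_grad
        self.grad = 0.0
        self.parents = []  # list of (parent Var, local derivative)

def apply_op(fn, dfn, x, name=None):
    """Apply fn to x immediately and record the edge x -> out."""
    out = Var(fn(x.value), name=name)
    out.parents.append((x, dfn(x.value)))
    return out

def adjoint(root):
    """Walk the graph from the root to the leaves, accumulating gradients."""
    stack = [(root, 1.0)]
    while stack:
        var, grad = stack.pop()
        if not var.parents:
            # Leaf: only store the gradient if it was requested
            if var.needs_grad:
                var.grad += grad
            continue
        for parent, local in var.parents:
            stack.append((parent, local * grad))

x = Var(2.0, name="x", needs_grad=True)
z = apply_op(lambda v: v * v, lambda v: 2.0 * v, x, name="z")  # z = x^2
y = apply_op(lambda v: v + 1.0, lambda v: 1.0, z, name="y")    # y = z + 1
w = apply_op(lambda v: 3.0 * v, lambda v: 3.0, y, name="w")    # w = 3y

adjoint(w)
print(w.value, x.grad)
```

Here `w = 3 * (x^2 + 1)`, so the derivative with respect to `x` is `6 * x`.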

## Mathematical basics - PDE-constrained optimisation

Now, we proceed to apply these general abstractions to find the gradient of a more practical optimisation problem. This section will contain some more math, so feel free to jump to the next section if you are not interested. 

Consider the PDE-constrained optimisation problem,

$$
    \mathbf{m}^* = \operatorname{argmin}_{\mathbf{m}} J(\mathbf{u}, \mathbf{m}) = 
    \operatorname{argmin}_{\mathbf{m}} \left\langle \hat{J}(\mathbf{u}, \mathbf{m}), 1 \right\rangle
$$
$$
    \text{s.t.}\; \mathbf{L}(\mathbf{u},\mathbf{m}) = \mathbf{0}
$$

given some scalar objective function or loss function $J(\mathbf{u}, \mathbf{m})$ and some PDE $\mathbf{L}(\mathbf{u},\mathbf{m}) = \mathbf{0}$, for some vector of state variables $\mathbf{u}$ and a vector of design variables $\mathbf{m}$. 

If $\mathbf{L}(\mathbf{u},\mathbf{m})$ is continuously differentiable in some neighbourhood of a solution, and its derivative with respect to the state, $\nabla_\mathbf{u}\mathbf{L}$, is invertible there, we can apply the implicit function theorem. Then $\mathbf{L}(\mathbf{u},\mathbf{m}) = \mathbf{0}$ has, locally, a unique continuously differentiable solution $\mathbf{u}(\mathbf{m})$, whose derivative is given by the solution of,

$$
    \nabla_\mathbf{u}\mathbf{L}(\mathbf{u}(\mathbf{m}), \mathbf{m}) \nabla_\mathbf{m}\mathbf{u}(\mathbf{m}) \delta\mathbf{m} +
    \nabla_\mathbf{m}\mathbf{L}(\mathbf{u}(\mathbf{m}), \mathbf{m}) \delta\mathbf{m} = \mathbf{0}
$$
$$
    \nabla_\mathbf{m}\mathbf{u}(\mathbf{m})\delta\mathbf{m} = - \nabla_\mathbf{u}\mathbf{L}^{-1}(\mathbf{u}(\mathbf{m}), \mathbf{m})
    \nabla_\mathbf{m}\mathbf{L}(\mathbf{u}(\mathbf{m}), \mathbf{m}) \delta\mathbf{m}
$$

We can then define a reduced objective $F(\mathbf{m}) = J(\mathbf{u}(\mathbf{m}), \mathbf{m}) = \left\langle \hat{J}(\mathbf{u}(\mathbf{m}), \mathbf{m}), 1 \right\rangle$, and we can take its derivative with respect to $\mathbf{m}$,

$$
    \nabla_\mathbf{m} F(\mathbf{m})(\delta \mathbf{m}) = 
    \left\langle \nabla_\mathbf{u}\hat{J}(\mathbf{u}(\mathbf{m}), \mathbf{m}), \nabla_\mathbf{m}\mathbf{u}(\mathbf{m})\delta\mathbf{m} \right\rangle 
    + \left\langle \nabla_\mathbf{m}\hat{J}(\mathbf{u}(\mathbf{m}), \mathbf{m}), \delta \mathbf{m} \right\rangle 
    = \left\langle \nabla_\mathbf{m}^*\mathbf{u}(\mathbf{m}) \nabla_\mathbf{u}\hat{J}(\mathbf{u}(\mathbf{m}), \mathbf{m}), \delta\mathbf{m} \right\rangle 
    + \left\langle \nabla_\mathbf{m}\hat{J}(\mathbf{u}(\mathbf{m}), \mathbf{m}), \delta \mathbf{m} \right\rangle
$$

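Substituting the expression for $\nabla_\mathbf{m}\mathbf{u}(\mathbf{m})\delta\mathbf{m}$ given by the implicit function theorem, and taking adjoints, we obtain,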

$$
    \nabla_\mathbf{m} F(\mathbf{m})(\delta \mathbf{m}) = 
    \left\langle \nabla_\mathbf{m}^*\mathbf{u}(\mathbf{m}) \nabla_\mathbf{u}\hat{J}(\mathbf{u}(\mathbf{m}), \mathbf{m}), \delta\mathbf{m} \right\rangle 
    + \left\langle \nabla_\mathbf{m}\hat{J}(\mathbf{u}(\mathbf{m}), \mathbf{m}), \delta \mathbf{m} \right\rangle 
    = - \left\langle \nabla_\mathbf{m}\mathbf{L}^*(\mathbf{u}(\mathbf{m}), \mathbf{m})
    \nabla_\mathbf{u}\mathbf{L}^{-*}(\mathbf{u}(\mathbf{m}), \mathbf{m}) \right. 
     \left. \nabla_\mathbf{u}\hat{J}(\mathbf{u}(\mathbf{m}), \mathbf{m}), \delta\mathbf{m} \right\rangle 
    + \left\langle \nabla_\mathbf{m}\hat{J}(\mathbf{u}(\mathbf{m}), \mathbf{m}), \delta \mathbf{m} \right\rangle 
    = \left\langle \nabla_\mathbf{m}\mathbf{L}^*(\mathbf{u}(\mathbf{m}), \mathbf{m}) \mathbf{w}(\mathbf{m}), \delta\mathbf{m} \right\rangle 
    + \left\langle \nabla_\mathbf{m}\hat{J}(\mathbf{u}(\mathbf{m}), \mathbf{m}), \delta \mathbf{m} \right\rangle
$$

where $\mathbf{w}(\mathbf{m})$ is the solution of the adjoint PDE,

$$
    \mathbf{w}(\mathbf{m}) = 
    - \nabla_\mathbf{u}\mathbf{L}^{-*} (\mathbf{u}(\mathbf{m}), \mathbf{m})
    \nabla_\mathbf{u}\hat{J}(\mathbf{u}(\mathbf{m}), \mathbf{m})
$$

In this optimisation problem, both $\mathbf{L}(\mathbf{u}, \mathbf{m})$ and $J(\mathbf{u}, \mathbf{m})$ would be ``stride.Operator`` objects.
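The derivation can be checked on a tiny discrete problem. The sketch below assumes a hypothetical linear constraint $\mathbf{L}(\mathbf{u}, m) = \mathbf{A}(m)\mathbf{u} - \mathbf{b}$ with a 2x2 matrix and a least-squares objective $J = \frac{1}{2}\|\mathbf{u} - \mathbf{d}\|^2$, so that $\nabla_\mathbf{u}\mathbf{L} = \mathbf{A}$ and the adjoint problem is $\mathbf{A}^T \mathbf{w} = -(\mathbf{u} - \mathbf{d})$. All matrices, vectors, and helper names are made up for this example.

```python
def solve2(a, rhs):
    """Solve a 2x2 linear system with Cramer's rule."""
    det = a[0][0] * a[1][1] - a[0][1] * a[1][0]
    return [
        (rhs[0] * a[1][1] - a[0][1] * rhs[1]) / det,
        (a[0][0] * rhs[1] - rhs[0] * a[1][0]) / det,
    ]

def system_matrix(m):
    """A(m): only the top-left entry depends on the design variable m."""
    return [[2.0 + m, 1.0], [1.0, 3.0]]

b = [1.0, 2.0]  # right-hand side of the forward problem
d = [0.2, 0.5]  # "observed" data used in the objective

def reduced_objective(m):
    """F(m) = J(u(m)), with u(m) the solution of A(m) u = b."""
    u = solve2(system_matrix(m), b)
    return 0.5 * sum((ui - di) ** 2 for ui, di in zip(u, d))

def adjoint_gradient(m):
    """Gradient of F(m) computed with one forward and one adjoint solve."""
    a = system_matrix(m)
    u = solve2(a, b)                                  # forward problem
    a_t = [[a[0][0], a[1][0]], [a[0][1], a[1][1]]]    # transpose of A
    w = solve2(a_t, [-(u[0] - d[0]), -(u[1] - d[1])]) # adjoint problem
    # dL/dm = (dA/dm) u = [u[0], 0], so its adjoint applied to w is u[0] * w[0]
    return u[0] * w[0]

m0 = 0.4
grad = adjoint_gradient(m0)

# Central finite-difference approximation of dF/dm for comparison
eps = 1e-6
fd = (reduced_objective(m0 + eps) - reduced_objective(m0 - eps)) / (2 * eps)
print(grad, fd)
```

The adjoint approach needs one forward solve and one adjoint solve regardless of the number of design variables, whereas finite differences would need additional forward solves for every entry of $\mathbf{m}$.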

## Stride operators

Adding new functions to Stride requires defining a new ``stride.Operator`` subclass that implements two methods, ``forward`` and ``adjoint``.

Let's see how we can do this for a function that represents the PDE ``L`` and one that represents a loss function ``J``. We will then use them to calculate the gradient with respect to the ``stride.Scalar`` ``m``.

In [2]:
from stride import Operator, Scalar

class L(Operator):
    """
    L represents a partial differential equation and its adjoint.
    
    """
    def forward(self, m):
        u = m.alike()
        # Compute wave equation solution
        return u
        
    def adjoint(self, grad_u, m):
        grad_m = m.alike()
        # Calculate the derivative wrt m by
        # applying the adjoint to grad_u
        return grad_m
        
class J(Operator):
    """
    J represents a loss function or functional.
    
    """
    def forward(self, u, m):
        loss = Scalar()
        # Calculate loss value
        return loss
        
    def adjoint(self, grad_loss, u, m):
        grad_u = u.alike()
        # Calculate the derivative wrt u
        grad_m = m.alike()
        # Calculate the derivative wrt m
        return grad_u, grad_m
        
# Create the design parameters
m = Scalar(name="m")
m.needs_grad = True

# Instantiate the operators
l = L()
j = J()

# Apply to calculate gradient
u = await l(m)
loss = await j(u, m)

m.clear_grad()
await loss.adjoint()
# The gradient is now in "m.grad"

scalar

## Applying the gradients

The abstractions presented allow us to intuitively pose optimisation problems and calculate derivatives of an objective function with respect to the parameters of interest. However, in order to solve the problem, we have to apply this derivative to update our guess of the parameters and repeat the procedure iteratively until we are satisfied with the final result.

Stride provides local optimisers of type ``stride.Optimiser`` that determine how parameters should be updated given an available derivative. 

For our previous example, we can then apply a step of gradient descent in the direction of our calculated derivative by using the class ``stride.GradientDescent``.

In [3]:
from stride import GradientDescent

optimiser = GradientDescent(m, step_size=1.)
await optimiser.step()

Updating variable m,
	 grad before processing in range [2.627646e-17, 2.627646e-17]
	 grad after processing in range [1.313823e-19, 1.313823e-19]
	 variable range before update [2.627646e-17, 2.627646e-17]
	 taking final update step of 1.000000e+00 [unclipped step of 1.000000e+00]
	 variable range after update [2.614507e-17, 2.614507e-17]


m
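For reference, the plain gradient descent update, without any gradient processing or step clipping, can be sketched as follows. This is a generic illustration on a toy quadratic objective, not a description of what ``stride.GradientDescent`` does internally; the function names and the objective are assumptions made for this example.

```python
def loss_and_grad(m):
    """Toy objective F(m) = 0.5 * (m - 3)^2 and its derivative."""
    return 0.5 * (m - 3.0) ** 2, m - 3.0

def gradient_descent_step(m, grad, step_size):
    """Move the parameter against the gradient by a fixed step size."""
    return m - step_size * grad

m_val = 0.0
history = []
for _ in range(50):
    loss_val, grad_val = loss_and_grad(m_val)
    history.append(loss_val)
    m_val = gradient_descent_step(m_val, grad_val, step_size=0.5)

print(m_val, history[0], history[-1])
```

With this step size, the distance to the minimiser at `m = 3` is halved at every iteration.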

In order to iterate through the optimisation procedure, we could use a standard Python ``for`` loop. However, Stride also provides a ``stride.OptimisationLoop`` for these cases, which helps structure and keep track of the optimisation process.

Iterations in Stride are grouped together in blocks, with the ``stride.OptimisationLoop`` containing multiple blocks and each block containing multiple iterations. Partitioning the inversion in this way allows us to divide the optimisation more easily into logical units that share some characteristics. For instance, in full-waveform inversion (FWI) it is common to gradually introduce frequency information into the inversion to better condition the optimisation. In this case, it would make sense to assign one block to each frequency band, and run that band for some desired number of iterations.
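The block and iteration structure described above amounts to a nested loop in which the outer level changes some shared setting, such as the frequency band. The following plain-Python sketch only enumerates the order in which such a schedule would run; the function `run_schedule` and the band values are hypothetical and do not correspond to any Stride API.

```python
def run_schedule(bands, num_iters):
    """Yield (block_id, iteration_id, band) in the order they would run."""
    for block_id, band in enumerate(bands):
        for iteration_id in range(num_iters):
            yield block_id, iteration_id, band

# Hypothetical (f_min, f_max) frequency bands, one per block
bands = [(0.1, 0.3), (0.1, 0.5), (0.1, 0.7)]

schedule = list(run_schedule(bands, num_iters=2))
for block_id, iteration_id, band in schedule:
    print(f"block {block_id}, iteration {iteration_id}, band {band}")
```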

Let's add a ``stride.OptimisationLoop`` around our previous example.

In [4]:
from stride import OptimisationLoop

opt_loop = OptimisationLoop()

num_blocks = 2
num_iters = 3

for block in opt_loop.blocks(num_blocks):
    for iteration in block.iterations(num_iters):
        m.clear_grad()
        
        u = await l(m)
        loss = await j(u, m)
        await loss.adjoint()
        
        await optimiser.step()

Updating variable m,
	 grad before processing in range [5.242153e-17, 5.242153e-17]
	 grad after processing in range [1.307254e-19, 1.307254e-19]
	 variable range before update [2.614507e-17, 2.614507e-17]
	 taking final update step of 1.000000e+00 [unclipped step of 1.000000e+00]
	 variable range after update [2.601435e-17, 2.601435e-17]
Updating variable m,
	 grad before processing in range [7.843588e-17, 7.843588e-17]
	 grad after processing in range [1.300717e-19, 1.300717e-19]
	 variable range before update [2.601435e-17, 2.601435e-17]
	 taking final update step of 1.000000e+00 [unclipped step of 1.000000e+00]
	 variable range after update [2.588428e-17, 2.588428e-17]
Updating variable m,
	 grad before processing in range [1.043202e-16, 1.043202e-16]
	 grad after processing in range [1.294214e-19, 1.294214e-19]
	 variable range before update [2.588428e-17, 2.588428e-17]
	 taking final update step of 1.000000e+00 [unclipped step of 1.000000e+00]
	 variable range after update [2.575