The differential formulation of the self-consistency constraint requires
enforcing an equation of the form

$$
v_G(x) = \dfrac{\partial G(a,g)(x)}{\partial a(y)} v_a(y) + \dfrac{\partial G(a,g)(x)}{\partial g(\sigma)} v_G(\sigma)
$$

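where $v_a(y) = y\cdot\nabla_y a(y)$ for $y\in\Omega$ and $v_G(x) = x\cdot\nabla_x G(a,g)(x)$ for $x\in\Omega$ are radial derivatives. On the right-hand side, $v_G$ enters only through its values $v_G(\sigma)$ on the boundary $\sigma\in\partial\Omega$. Repeated arguments are integrated over their domains ($y$ over $\Omega$, $\sigma$ over $\partial\Omega$).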

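The terms $\dfrac{\partial G}{\partial a} v_a$ and $\dfrac{\partial G}{\partial g} v_G$ are directional derivatives of the model $G$: with respect to $a$ in the direction $v_a$, and with respect to $g$ in the direction of the boundary values of $v_G$. Their sum is the Jacobian-vector product (JVP) of $G$ at $(a, g)$ with the tangent $(v_a, v_G)$.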
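As a sanity check of this reading, here is a minimal sketch with a hypothetical toy model `G(a, g)` (not the FNO), using `torch.func.jvp` (PyTorch 2.0 or later). Because a JVP is linear in the tangent, one call with tangents for both inputs equals the sum of the two partial directional derivatives, each obtained by passing a zero tangent for the other input; both agree with a central finite difference.

```python
import torch
from torch.func import jvp

torch.manual_seed(0)

def G(a, g):
    # hypothetical toy "model": depends on a field a and on boundary data g
    return torch.sin(a) * g.sum() + a ** 2

a = torch.randn(5, dtype=torch.float64)
g = torch.randn(2, dtype=torch.float64)
v_a = torch.randn(5, dtype=torch.float64)
v_g = torch.randn(2, dtype=torch.float64)

# one JVP with tangents for both inputs: dG/da v_a + dG/dg v_g
_, both = jvp(G, (a, g), (v_a, v_g))

# the two partial directional derivatives, obtained by zeroing the other tangent
_, d_a = jvp(G, (a, g), (v_a, torch.zeros_like(g)))
_, d_g = jvp(G, (a, g), (torch.zeros_like(a), v_g))

# central finite-difference approximation of the same directional derivative
eps = 1e-6
fd = (G(a + eps * v_a, g + eps * v_g) - G(a - eps * v_a, g - eps * v_g)) / (2 * eps)
```

In the self-consistency constraint this means the two terms on the right-hand side can be computed with a single `jvp` call.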

PyTorch supports computing such directional derivatives with forward-mode automatic differentiation:
- https://pytorch.org/tutorials/intermediate/forward_ad_usage.html
- This blog post discusses the performance of different approaches (it may be outdated: running its code, I get deprecation warnings): https://leimao.github.io/blog/PyTorch-Automatic-Differentiation/
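The tutorial linked above also covers the lower-level dual-tensor API in `torch.autograd.forward_ad`. A minimal sketch of a directional derivative computed that way, for the toy function $f(x, y) = x^2 + y^2$ with tangent $(1, 1)$, whose expected JVP is $2x + 2y$:

```python
import torch
import torch.autograd.forward_ad as fwAD

def fn(x, y):
    return x ** 2 + y ** 2

primal_x, primal_y = torch.randn(10, 10), torch.randn(10, 10)
tangent_x, tangent_y = torch.ones(10, 10), torch.ones(10, 10)

with fwAD.dual_level():
    # a dual tensor carries a primal value together with its tangent
    dual_x = fwAD.make_dual(primal_x, tangent_x)
    dual_y = fwAD.make_dual(primal_y, tangent_y)
    dual_out = fn(dual_x, dual_y)
    # unpack_dual returns a namedtuple (primal, tangent); clone before leaving the level
    value = fwAD.unpack_dual(dual_out).primal.clone()
    jvp_out = fwAD.unpack_dual(dual_out).tangent.clone()
```

The functional `jvp` used below wraps this machinery and is usually more convenient.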


From the PyTorch tutorial, using `functorch` (whose transforms are available as `torch.func` in PyTorch 2.0 and later):

In [10]:
import functools

import torch
import torch.func as ft  # functorch's jvp is available as torch.func.jvp (PyTorch >= 2.0)

# values (primals)
primal0 = torch.randn(10, 10)
primal1 = torch.randn(10, 10)

# direction vectors (tangents)
tangent0 = torch.ones(10, 10)
tangent1 = torch.ones(10, 10)

# function to be differentiated
def fn(x, y):
    return x ** 2 + y ** 2

# note: the Jacobian-vector product of fn with the direction (1, 1) is 2*(x + y),
#       i.e. J_fn(x, y) @ (1, 1) = 2*x + 2*y

# Here is a basic example to compute the JVP of the above function.
# ``jvp(func, primals, tangents)`` returns ``func(*primals)`` as well as the
# computed Jacobian-vector product (JVP). Each primal must be associated with a tangent of the same shape.
primal_out, tangent_out = ft.jvp(fn, (primal0, primal1), (tangent0, tangent1))

# check that we get the expected results
def fn_directional(x, y, vx, vy):
    return 2 * x * vx + 2 * y * vy

primal_check = fn(primal0, primal1)
tangent_check = fn_directional(primal0, primal1, tangent0, tangent1)

assert torch.allclose(primal_out, primal_check), 'primals wrong??'
assert torch.allclose(tangent_out, tangent_check), 'tangents wrong??'
print('Checks successful!')


# ``jvp`` requires every primal to be associated with a tangent.
# If we only want to associate certain inputs of `fn` with tangents,
# then we need to create a new function that captures the inputs without tangents:
primal = torch.randn(10, 10)
tangent = torch.randn(10, 10)
y = torch.randn(10, 10)

# the following computes the directional derivative with respect to x only (y held fixed)
new_fn = functools.partial(fn, y=y)
primal_out, tangent_out = ft.jvp(new_fn, (primal,), (tangent,))
assert torch.allclose(tangent_out, 2 * primal * tangent)

Checks successful!
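To make the difference between a JVP and the full Jacobian concrete, here is a small sketch comparing `torch.func.jvp` with contracting a Jacobian materialized by `torch.func.jacfwd` (the map `f` and the weight `W` are hypothetical). Both give $J v$, but `jvp` never forms $J$, which likely matters when $a$ and $g$ are discretized fields with many entries.

```python
import torch
from torch.func import jacfwd, jvp

# hypothetical map R^4 -> R^3, used only for illustration
W = torch.arange(12.0).reshape(3, 4) / 10

def f(x):
    return torch.tanh(W @ x)

x = torch.randn(4)
v = torch.randn(4)

J = jacfwd(f)(x)              # full 3x4 Jacobian (one JVP per input dimension, batched with vmap)
jv_dense = J @ v              # contract the materialized Jacobian with v
_, jv = jvp(f, (x,), (v,))    # single JVP, J is never formed
```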


### It seems the self-consistency constraint for FNO could be implemented as follows (?)

$$
v_G(x) = \dfrac{\partial G(a,g)(x)}{\partial a(y)} v_a(y) + \dfrac{\partial G(a,g)(x)}{\partial g(\sigma)} v_G(\sigma)
$$

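where $v_a(y) = y\cdot\nabla_y a(y)$ and $v_G(x) = x\cdot\nabla_x G(a,g)(x)$; the second term on the right uses $v_G$ restricted to the boundary, i.e. $v_G(\sigma)$ with $\sigma\in\partial\Omega$.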
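The code cell that follows assumes helpers `xDx` and `restriction`. One possible shape for them, sketched under the assumption of a uniform grid on the square $[-1,1]^2$ and second-order finite differences via `torch.gradient(..., edge_order=2)`; the grid size and the boundary ordering are arbitrary, hypothetical choices. The check uses Euler's theorem for homogeneous functions: $a(x, y) = x^2 + y^2$ has degree 2, so $x\cdot\nabla a = 2a$.

```python
import torch

# hypothetical uniform n x n grid on the square [-1, 1]^2
n = 65
coords = torch.linspace(-1.0, 1.0, n)
h = (coords[1] - coords[0]).item()
X, Y = torch.meshgrid(coords, coords, indexing="ij")

def xDx(f):
    """Radial derivative x . grad f of a 2D grid function, via second-order finite differences."""
    df_dx, df_dy = torch.gradient(f, spacing=h, edge_order=2)
    return X * df_dx + Y * df_dy

def restriction(f):
    """Boundary values of a 2D grid function, walked around the square (each corner once)."""
    return torch.cat([f[0, :], f[1:, -1], f[-1, :-1].flip(0), f[1:-1, 0].flip(0)])

# sanity check: a(x, y) = x^2 + y^2 is homogeneous of degree 2, so x . grad a = 2a
a = X ** 2 + Y ** 2
v_a = xDx(a)
```

Second-order differences are exact for quadratics, so `v_a` should match `2 * a` up to floating-point error; for a real FNO discretization, a spectral derivative may be a better fit.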

In [11]:
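# assume we have access to (not defined in this notebook):
# 1. a function "xDx" computing the (approximate) radial derivative x . grad f of its input
# 2. a function "restriction" computing the restriction of a function to the boundary
# 3. the operator "model(a, g)" and a scalar loss "loss_fn"

import torch.func as ft  # functorch.jvp is available as torch.func.jvp (PyTorch >= 2.0)

# forward pass and tangent directions
G = model(a, g)
tangent_a = xDx(a)                  # v_a(y) = y . grad a(y) on Omega
tangent_G = xDx(G)                  # v_G(x) = x . grad G(a, g)(x) on Omega
tangent_g = restriction(tangent_G)  # v_G restricted to the boundary

# left- and right-hand sides of the constraint
LHS = tangent_G
_, RHS = ft.jvp(model, (a, g), (tangent_a, tangent_g))

# penalize the residual
loss = loss_fn(LHS - RHS)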

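For a runnable end-to-end check of the cell above, here is a minimal 1D sketch. Everything in it is a hypothetical stand-in: `model` is a toy operator on $\Omega = [-1, 1]$ (not an FNO) whose boundary $\partial\Omega = \{-1, 1\}$ carries the two values in `g`, `w` are toy learnable parameters, `xDx` and `restriction` use finite differences and endpoint values, and `loss_fn` is assumed to be the mean squared residual. The final `backward()` shows that gradients flow through `torch.func.jvp` to the parameters.

```python
import math

import torch
from torch.func import jvp

# hypothetical 1D setting: Omega = [-1, 1] on a uniform grid, boundary = {-1, 1}
n = 101
x = torch.linspace(-1.0, 1.0, n)
h = (x[1] - x[0]).item()

# hypothetical learnable parameters of the toy operator
w = torch.tensor([0.5, -0.3], requires_grad=True)

def model(a, g):
    """Toy stand-in: linear lift of boundary data g plus an interior term vanishing at x = -1, 1."""
    t = (x + 1.0) / 2.0
    lift = g[0] * (1.0 - t) + g[1] * t
    return lift + (1.0 - x ** 2) * torch.tanh(w[0] * a + w[1] * a ** 2)

def xDx(f):
    """Radial derivative x * df/dx via second-order finite differences."""
    (df_dx,) = torch.gradient(f, spacing=h, edge_order=2)
    return x * df_dx

def restriction(f):
    """Values at the two boundary points x = -1 and x = 1."""
    return torch.stack([f[0], f[-1]])

def loss_fn(residual):
    """Assumed loss: mean squared residual."""
    return torch.mean(residual ** 2)

a = torch.sin(math.pi * x)
g = torch.tensor([0.2, -0.1])

G = model(a, g)
tangent_a = xDx(a)
tangent_G = xDx(G)
tangent_g = restriction(tangent_G)

LHS = tangent_G
_, RHS = jvp(model, (a, g), (tangent_a, tangent_g))
loss = loss_fn(LHS - RHS)
loss.backward()  # populates w.grad through both LHS and the jvp on the right-hand side
```

Whether `tangent_g` (which depends on the model output) should be detached during training is a design choice this sketch leaves open.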