# Implicit differentiation

The other notebook was all about defining primitives. For each of those primitives, we had some kind of basic function we wanted to add to the library, and we knew the functional form of its derivative. Adding the primitive was just an exercise in writing that all down in code. The cosine example showed us a few potential benefits of creating primitives instead of composing existing ones:

* Smaller computational graph means less memory.
* Option to pull in code that's not compatible with automatic differentiation.
* Numerical stability.

Sometimes we can't just write down the derivative because the function is messy, but we can use a trick called implicit differentiation to solve for the derivative. That's what this notebook is about. We're going to cover two varieties of "implicit layers", as their creators call them: fixed-point layers and optimization layers. Neural ODEs are another use of this technology, but we won't cover them here.

One defining feature of both fixed-point layers and optimization layers is that the output of the layer is defined implicitly: it is the solution of an equation rather than the result of an explicit formula.


## Fixed-point layers

In [17]:
from typing import Union, List, Any
from semiautograd import Scalar, Function, trace, backward, reset_grad
from semiautograd import Abs, Cos, Plus, Times, Sum, Pow

First up is a fixed-point calculation.  Suppose we want to differentiate the solution to a fixed-point problem with respect to some other data in the problem.  For this example we'll look for a fixed point
$$x = \cos(a x + b)$$
and the question is: what are $dx/da$ and $dx/db$?

Well, cosine is differentiable, and it's a nice enough function that we can find the fixed point just by applying it repeatedly. We already have cosine available in semiautograd, so it's easy to code up the fixed-point iteration.
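A rough way to make "nice enough" precise: plain iteration converges locally when the map is a contraction near the fixed point, i.e. when $|\partial f/\partial x| = |a\sin(a x^* + b)| < 1$. The sketch below uses plain floats and `math.cos` rather than semiautograd, and the helper names `iterate_cos` and `contraction_factor` are hypothetical.

```python
import math

def iterate_cos(a: float, b: float, x0: float = 0.0, tol: float = 1e-12, max_iter: int = 10_000) -> float:
    """Iterate x <- cos(a*x + b) until successive iterates differ by less than tol."""
    x = x0
    for _ in range(max_iter):
        x_new = math.cos(a * x + b)
        if abs(x_new - x) < tol:
            return x_new
        x = x_new
    raise RuntimeError("fixed-point iteration did not converge")

def contraction_factor(a: float, b: float, x: float) -> float:
    """|d/dx cos(a*x + b)| = |a*sin(a*x + b)|; below 1 means the iteration contracts near x."""
    return abs(a * math.sin(a * x + b))

for a, b in [(1.0, 0.0), (-0.5, 0.5)]:
    x_star = iterate_cos(a, b)
    print(f"a={a}, b={b}: x*={x_star:.10f}, contraction factor={contraction_factor(a, b, x_star):.4f}")
```

Both parameter settings used in this notebook give factors well below 1 (roughly 0.67 and 0), so plain iteration is a safe choice for them; for other values of $a$ and $b$ it may not be.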

In [18]:
def iterativefixedpoint(*args, fun=None, x0=0):
    ''' Find a fixed point x = fun(x,*args) by iterating until convergence

    Arguments:
        args -- List[Scalar] or List[float] arguments to fun
        fun -- the function
        x0 -- initial guess
    Returns:
        the fixed point, either as a Scalar or as a float
    '''
    eps = 1e-7
    oldx = Scalar(x0-1)
    x = Scalar(x0)
    isscalar = isinstance(args[0],Scalar)
    args = [a if isinstance(a,Scalar) else Scalar(a) for a in args]
    ii=0
    while ii<100 and Abs( Plus(x, Times(Scalar(-1), oldx)))>eps: # | x - oldx |
        oldx = x
        x = fun(x,*args)
        ii += 1
    if ii==100:
        print("Failed to converge")
    if isscalar:
        return x
    return x.value

def cosaxb(x, a, b):
    return Cos(Plus(b, Times(a,x)))
    
a = Scalar(1)
b = Scalar(0)
x = iterativefixedpoint(a, b, fun=cosaxb)
print(f'{x=}')
print(f'{len(trace(x))=}')
backward(x)
print(f'{a.grad=}, {b.grad=}')

print('')

a = Scalar(-0.5)
b = Scalar(0.5)
x = iterativefixedpoint(a, b, fun=cosaxb)
print(f'{x=}')
print(f'{len(trace(x))=}')
backward(x)
print(f'{a.grad=}, {b.grad=}')


x=0.7390851699445545 = Cos(0.7390850786891229)
len(trace(x))=126
a.grad=-0.2974721293726153, b.grad=-0.40248894024267656

x=1.0 = Cos(1.199040866595169e-14)
len(trace(x))=18
a.grad=-1.1990409980613424e-14, b.grad=-1.1990409980622253e-14

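To double-check those gradients without any autodiff, we can difference the whole solver: nudge $a$ (or $b$) up and down, re-solve, and take a central difference. This is a plain-float sketch; `solve_fixed_point` and `central_difference` are hypothetical helpers, and the step `h=1e-6` is an assumed compromise between truncation error and solver round-off.

```python
import math
from typing import Callable

def solve_fixed_point(a: float, b: float, x0: float = 0.0, tol: float = 1e-14, max_iter: int = 100_000) -> float:
    """Solve x = cos(a*x + b) by plain iteration, converged tightly so differencing is meaningful."""
    x = x0
    for _ in range(max_iter):
        x_new = math.cos(a * x + b)
        if abs(x_new - x) < tol:
            return x_new
        x = x_new
    raise RuntimeError("fixed-point iteration did not converge")

def central_difference(g: Callable[[float], float], t: float, h: float = 1e-6) -> float:
    """Approximate g'(t) by (g(t + h) - g(t - h)) / (2h)."""
    return (g(t + h) - g(t - h)) / (2 * h)

a, b = 1.0, 0.0
dx_da = central_difference(lambda s: solve_fixed_point(s, b), a)
dx_db = central_difference(lambda s: solve_fixed_point(a, s), b)
print(f"{dx_da=:.6f}, {dx_db=:.6f}")
```

Like the 126-node trace, this treats the solver as something to differentiate through, and it needs two extra solves per parameter.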

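The answers are right, but 126 nodes in that trace is a lot. Trying to solve for x in closed form and then differentiating in order to create a primitive is probably hopeless. Here comes implicit differentiation. Let $f(x,a,b)=\cos(a x + b)$. The fixed point satisfies $x(a) = f(x(a),a,b)$, so differentiating both sides with respect to $a$ gives
$$\frac{\partial x}{\partial a} = \frac{d f(x(a),a,b)}{da} = \frac{\partial f(x,a,b)}{\partial a} + \frac{\partial f(x,a,b)}{\partial x}\,\frac{\partial x}{\partial a}.$$
Collecting the $\partial x/\partial a$ terms gives
$$\frac{\partial x}{\partial a} = \frac{\partial f(x,a,b) / \partial a}{1-\partial f(x,a,b)/\partial x},$$
and the same argument with $b$ in place of $a$ gives $\partial x/\partial b$. This needs $\partial f/\partial x \neq 1$ at the fixed point, which holds whenever the iteration is a contraction there.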
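For $f(x,a,b)=\cos(ax+b)$ the partials are easy to write by hand: $\partial f/\partial x = -a\sin(ax+b)$, $\partial f/\partial a = -x\sin(ax+b)$ and $\partial f/\partial b = -\sin(ax+b)$. Plugging them into the formula at the converged $x$ gives both derivatives from a single solve. A plain-Python sketch (the name `implicit_grads_cos` is hypothetical):

```python
import math
from typing import Tuple

def implicit_grads_cos(a: float, b: float, tol: float = 1e-14, max_iter: int = 100_000) -> Tuple[float, float, float]:
    """Return (x, dx/da, dx/db) for the fixed point x = cos(a*x + b)."""
    # Forward pass: plain iteration. Any solver would do, since only the converged x is used below.
    x = 0.0
    for _ in range(max_iter):
        x_new = math.cos(a * x + b)
        converged = abs(x_new - x) < tol
        x = x_new
        if converged:
            break
    else:
        raise RuntimeError("fixed-point iteration did not converge")
    # Hand-derived partials of f(x, a, b) = cos(a*x + b), evaluated at the solution.
    s = math.sin(a * x + b)
    df_dx, df_da, df_db = -a * s, -x * s, -s
    # Implicit differentiation: dx/dp = (df/dp) / (1 - df/dx).
    return x, df_da / (1 - df_dx), df_db / (1 - df_dx)

x, dx_da, dx_db = implicit_grads_cos(1.0, 0.0)
print(f"{x=:.10f}, {dx_da=:.6f}, {dx_db=:.6f}")
```

For $a=-0.5$, $b=0.5$ the fixed point is $x=1$, where $\sin(ax+b)=0$, so both derivatives vanish, consistent with the tiny gradients printed above.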

There's no reason to care how we computed the fixed point x; implicit differentiation just magicked its derivative into existence from the solution. While we could work this out by hand for $\cos(a x + b)$, doing so would limit our fixed-point finding to functions with easy derivatives. If only we had a generic way to compute the partial derivatives on the right-hand side above... oh snap!

In [19]:
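def fpbackward(*args, fun=None, x0=0):
    ''' Gradients of the fixed point x = fun(x, *args) w.r.t. each arg, via implicit differentiation '''
    x = iterativefixedpoint(*args, fun=fun, x0=x0)  # recompute the forward pass because we didn't store it
    sargs = [Scalar(a) for a in args]
    sx = Scalar(x)
    v = fun(sx, *sargs)  # put the solution into the function (one evaluation, so a tiny graph)
    backward(v)          # sx.grad = df/dx, sa.grad = df/da for each argument
    return [sa.grad / (1 - sx.grad) for sa in sargs]  # dx/da = (df/da) / (1 - df/dx)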


FixedPoint = Function("FixedPoint", iterativefixedpoint, fpbackward)

a = Scalar(1)
b = Scalar(0)
x = FixedPoint(a,b,fun=cosaxb)
print(f'{trace(x)=}')
backward(x)
print(f'{a.grad=}, {b.grad=}')


trace(x)=[0.7390851699445545 = FixedPoint(1,0,fun=<function cosaxb at 0x1107ae980>), 0, 1]
a.grad=-0.29747436345601425, b.grad=-0.402489964016367
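The `FixedPoint` primitive gets the partials of $f$ from semiautograd, but the backward rule does not care where they come from. As a hedged illustration, here is a generic plain-Python version that estimates $\partial f/\partial x$ and $\partial f/\partial p$ with central finite differences instead (a swapped-in technique, less accurate than autodiff; `implicit_fixed_point` and the step `h` are assumptions). It only ever differentiates $f$ at the solution, never the solver.

```python
import math
from typing import Callable, List, Sequence, Tuple

def implicit_fixed_point(f: Callable[..., float], params: Sequence[float], x0: float = 0.0,
                         tol: float = 1e-13, max_iter: int = 100_000,
                         h: float = 1e-6) -> Tuple[float, List[float]]:
    """Solve x = f(x, *params) by iteration, then return (x, [dx/dp for p in params])."""
    x = x0
    for _ in range(max_iter):
        x_new = f(x, *params)
        converged = abs(x_new - x) < tol
        x = x_new
        if converged:
            break
    else:
        raise RuntimeError("fixed-point iteration did not converge")

    def partial_f(i: int) -> float:
        """Central-difference estimate of df/d(argument i) at the solution; argument 0 is x."""
        up = [x, *params]
        down = [x, *params]
        up[i] += h
        down[i] -= h
        return (f(*up) - f(*down)) / (2 * h)

    df_dx = partial_f(0)
    grads = [partial_f(i + 1) / (1 - df_dx) for i in range(len(params))]
    return x, grads

x, (dx_da, dx_db) = implicit_fixed_point(lambda x, a, b: math.cos(a * x + b), [1.0, 0.0])
print(f"{x=:.10f}, {dx_da=:.6f}, {dx_db=:.6f}")
```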


## Optimization layers

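Optimization layers work the same way. Here the output $x$ minimizes some function $f(x,a,b)$, so at the solution the first-order condition
$$f'(x,a,b) \equiv \frac{\partial f(x,a,b)}{\partial x} = 0$$
holds. Differentiating that condition with respect to $a$ gives
$$0 = \frac{d f'(x(a),a,b)}{da} = \frac{\partial f'}{\partial a} + \frac{\partial f'}{\partial x}\,\frac{\partial x}{\partial a},$$
so
$$\frac{\partial x}{\partial a} = -\frac{\partial f'/\partial a}{\partial f'/\partial x}.$$
This requires $\partial f'/\partial x = \partial^2 f/\partial x^2 \neq 0$ at the solution. The forward pass only needs $f$, but the backward pass differentiates $f'$, so the code takes both `fun` and `dfundx`.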
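To see the optimization-layer formula $\partial x/\partial a = -(\partial f'/\partial a)/(\partial f'/\partial x)$ on something other than a quadratic, here is a hypothetical toy problem, $f(x,a,b) = a\cosh x + b x$ with $a > 0$. Its minimizer has the closed form $x^* = \operatorname{asinh}(-b/a)$, which makes the implicit gradients easy to check. The forward solver is Newton's method on $f' = a\sinh x + b$, a swapped-in choice (the notebook's own forward pass uses gradient descent); the backward step uses only partials of $f'$ at the solution.

```python
import math
from typing import Tuple

def argmin_cosh(a: float, b: float, x0: float = 0.0, tol: float = 1e-14, max_iter: int = 100) -> float:
    """Minimize f(x) = a*cosh(x) + b*x (a > 0) with Newton's method on f'(x) = a*sinh(x) + b."""
    x = x0
    for _ in range(max_iter):
        step = (a * math.sinh(x) + b) / (a * math.cosh(x))  # f'(x) / f''(x)
        x -= step
        if abs(step) < tol:
            return x
    raise RuntimeError("Newton's method did not converge")

def implicit_grads_cosh(a: float, b: float) -> Tuple[float, float, float]:
    """Return (x*, dx*/da, dx*/db) using only partials of f' at the solution."""
    x = argmin_cosh(a, b)
    dfp_dx = a * math.cosh(x)  # df'/dx = d^2f/dx^2, must be nonzero
    dfp_da = math.sinh(x)      # df'/da
    dfp_db = 1.0               # df'/db
    return x, -dfp_da / dfp_dx, -dfp_db / dfp_dx

a, b = 2.0, -1.0
x, dx_da, dx_db = implicit_grads_cosh(a, b)
# Closed form for comparison: x* = asinh(-b/a), dx*/da = b / (a*sqrt(a^2 + b^2)), dx*/db = -1 / sqrt(a^2 + b^2)
print(x, dx_da, dx_db)
print(math.asinh(-b / a), b / (a * math.sqrt(a * a + b * b)), -1 / math.sqrt(a * a + b * b))
```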

In [4]:
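def minimizeforward(*args, fun=None, dfundx=None, x0=0, lr=0.1, max_iter=1000):
    ''' Minimize fun(x, *args) over x by gradient descent and return a float.
    dfundx is unused here; it is accepted so forward and backward share a signature. '''
    x = x0
    ii=0
    notconverged = True
    eps = 1e-3  # stop once |df/dx| <= eps
    while (ii<max_iter) and notconverged:
        ii+=1
        s_args = [Scalar(a) for a in args]
        s_x = Scalar(x)
        y = fun(s_x, *s_args)
        backward(y)              # s_x.grad = df/dx at the current x
        x += -lr * s_x.grad      # gradient-descent step
        notconverged = abs(s_x.grad)>eps
    if notconverged:
        print('minimizeforward failed to converge')
    return x

def minimizebackward(*args, fun=None, dfundx=None, x0=0, lr=0.1, max_iter=1000):
    ''' Gradients of the minimizer w.r.t. each arg: dx/da = -(df'/da) / (df'/dx), where f' = dfundx '''
    x = Scalar(minimizeforward(*args, fun=fun, dfundx=dfundx, x0=x0, lr=lr, max_iter=max_iter))
    s_args = [Scalar(a) for a in args]
    y = dfundx(x,*s_args)   # f'(x, args), approximately 0 at the minimum
    backward(y)             # x.grad = d^2f/dx^2, a.grad = df'/da
    if abs(x.grad)<1e-3:
        print('minimizebackward does not support functions with d^2f/dx^2 = 0')
    return [-a.grad / x.grad for a in s_args] # requires d^2f/dx^2 != 0 at the minimum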
    
def quadratic(x,a,b):
    return Sum(Times(a, Pow(x,p=2)), Times(b, x))

def dquadraticdx(x,a,b):
    return Sum(Times(a, Times(x,Scalar(2))), b)
    
minimizeforward(1,0,fun=quadratic,dfundx=None,x0=0)
Minimize = Function("Minimize", minimizeforward, minimizebackward)
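For the `quadratic` test function, $f(x,a,b) = a x^2 + b x$, everything can be done by hand, which gives exact reference values for the `Minimize` results:
$$f'(x,a,b) = 2ax + b, \qquad x^* = -\frac{b}{2a},$$
$$\frac{\partial f'}{\partial x} = 2a, \qquad \frac{\partial f'}{\partial a} = 2x, \qquad \frac{\partial f'}{\partial b} = 1,$$
so $\partial x/\partial a = -(\partial f'/\partial a)/(\partial f'/\partial x)$ gives
$$\frac{\partial x}{\partial a} = -\frac{2x}{2a} = -\frac{x}{a} = \frac{b}{2a^2}, \qquad \frac{\partial x}{\partial b} = -\frac{1}{2a}.$$
For $a=1$, $b=0$ this is $x^*=0$, $\partial x/\partial a = 0$, $\partial x/\partial b = -0.5$; for $a=2$, $b=-1$ it is $x^*=0.25$, $\partial x/\partial a=-0.125$, $\partial x/\partial b=-0.25$. The derivation needs $a \neq 0$, which is the $d^2f/dx^2 \neq 0$ condition checked in `minimizebackward`.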

In [5]:
a = Scalar(1)
b = Scalar(0)
x = Minimize(a,b,fun=quadratic,dfundx=dquadraticdx, x0=0.1)
print(x)
backward(x)
display(trace(x))

0.0003777893186295717 = Minimize(1,0,fun=<function quadratic at 0x1107ae3e0>,dfundx=<function dquadraticdx at 0x1107ae520>,x0=0.1)


[0.0003777893186295717 = Minimize(1,0,fun=<function quadratic at 0x1107ae3e0>,dfundx=<function dquadraticdx at 0x1107ae520>,x0=0.1) <grad=1>,
 0 <grad=-0.5>,
 1 <grad=-0.0003777893186295717>]

In [6]:
a = Scalar(2)
b = Scalar(-1)
x = Minimize(a,b,fun=quadratic,dfundx=dquadraticdx, x0=1, lr=0.1, max_iter=1000)
print(x)
backward(x)
display(trace(x))

0.2501269499458355 = Minimize(2,-1,fun=<function quadratic at 0x1107ae3e0>,dfundx=<function dquadraticdx at 0x1107ae520>,x0=1,lr=0.1,max_iter=1000)


[0.2501269499458355 = Minimize(2,-1,fun=<function quadratic at 0x1107ae3e0>,dfundx=<function dquadraticdx at 0x1107ae520>,x0=1,lr=0.1,max_iter=1000) <grad=1>,
 -1 <grad=-0.25>,
 2 <grad=-0.12506347497291775>]
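The small gaps from the exact values ($x^*=0.25$ versus 0.25013, $\partial x/\partial a = -0.125$ versus -0.12506) come from the loose stopping rule in `minimizeforward` (stop once $|f'| \le 10^{-3}$), not from implicit differentiation: $\partial x/\partial b$ comes out exact because $-1/(2a)$ does not depend on $x$. The plain-float sketch below mirrors that loop with hand-written derivatives instead of semiautograd; `gd_minimize_quadratic` and `implicit_grads_quadratic` are hypothetical names.

```python
from typing import Tuple

def gd_minimize_quadratic(a: float, b: float, x0: float, lr: float = 0.1,
                          eps: float = 1e-3, max_iter: int = 1000) -> float:
    """Gradient descent on f(x) = a*x**2 + b*x, stopping once |f'(x)| <= eps, as minimizeforward does."""
    x = x0
    for _ in range(max_iter):
        grad = 2 * a * x + b   # f'(x) at the current iterate
        x -= lr * grad         # step first, then test the gradient that produced the step
        if abs(grad) <= eps:
            return x
    raise RuntimeError("gradient descent did not converge")

def implicit_grads_quadratic(a: float, x: float) -> Tuple[float, float]:
    """dx/da and dx/db from f'(x, a, b) = 2*a*x + b: df'/dx = 2a, df'/da = 2x, df'/db = 1."""
    return -(2 * x) / (2 * a), -1 / (2 * a)

for a, b, x0 in [(1.0, 0.0, 0.1), (2.0, -1.0, 1.0)]:
    x = gd_minimize_quadratic(a, b, x0)
    dx_da, dx_db = implicit_grads_quadratic(a, x)
    print(f"a={a}, b={b}: x={x:.10f}, dx/da={dx_da:.8f}, dx/db={dx_db}")
```

Tightening `eps` moves both $x$ and $\partial x/\partial a$ toward the closed-form values, at the cost of more iterations.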

In [8]:
import random
random.seed(314)
truea = 1 + random.random()   # the "true" a we will try to recover
print(f'{truea=}')
N = 100
bs = [2*random.random()-1 for _ in range(N)]
xs = [minimizeforward(truea,b,fun=quadratic) for b in bs]  # observed minimizers (targets)

ahat = 1.5   # initial guess for a
notconverged = True
ii = 0
lr = 0.05
while ii<1000 and notconverged:
    sa = Scalar(ahat)
    sbs = [Scalar(b) for b in bs]
    # each Minimize is a single node whose gradient comes from implicit differentiation
    xhats = [Minimize(sa,sb,fun=quadratic,dfundx=dquadraticdx, x0=0, lr=0.1, max_iter=1000) for sb in sbs]
    # squared-error loss between predicted and observed minimizers
    loss = Sum(*[Pow(Plus(xhat,Times(Scalar(-1),Scalar(x))),p=2) for xhat,x in zip(xhats,xs)])
    backward(loss)
    ahat += -lr * sa.grad
    print(f'{ahat=}, {loss.value=}')
    if abs(sa.grad)<1e-4:
        notconverged=False
    ii+=1
if notconverged:
    print("Failed to converge")

ahat=1.4395180480454486, loss.value=0.23002584979209342
ahat=1.3847254217289886, loss.value=0.16013589480537632
ahat=1.3370430170652248, loss.value=0.10384281388725108
ahat=1.2974875261161845, loss.value=0.06211924241381286
ahat=1.2663739460182275, loss.value=0.034085922552572985
ahat=1.243223601699386, loss.value=0.017124219105256723
ahat=1.2268526206239796, loss.value=0.00795449741883351
ahat=1.215784367648377, loss.value=0.0034482780813583735
ahat=1.2085437869680224, loss.value=0.0014232901298072355
ahat=1.2039325766734423, loss.value=0.0005636729709401966
ahat=1.2010473682162583, loss.value=0.0002173408884092487
ahat=1.1992562251553487, loss.value=8.298661999315701e-05
ahat=1.1981579337378985, loss.value=3.1022288393772386e-05
ahat=1.1974860934117495, loss.value=1.156886497544171e-05
ahat=1.1970742345874161, loss.value=4.345082705508395e-06
ahat=1.1968236200428806, loss.value=1.6081443312819469e-06
ahat=1.1966708876421972, loss.value=5.94663453307333