In [3]:
from vmad import operator, autooperator
# from utils import finite_difference
import numpy as np

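Finite-difference approximations $g \approx f'(x)$ with step size $\epsilon$, one per supported `mode`:

\begin{equation}
\begin{split}
    g_{\text{forward}}  &= \dfrac{f(x+\epsilon) - f(x)}{\epsilon}\\
    g_{\text{backward}} &= \dfrac{f(x) - f(x-\epsilon)}{\epsilon}\\
    g_{\text{central}}  &= \dfrac{f(x+\epsilon/2) - f(x-\epsilon/2)}{\epsilon}
\end{split}
\end{equation}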

In [130]:
import numpy as np
from vmad import operator

def finite_diff(param, func, epsilon, mode='forward'):
    """
    Approximate the derivative of ``func`` at ``param`` by finite differencing.

    Params:
    __________________
    param:   point at which the derivative is evaluated
    func:    single-argument function to difference
    epsilon: step size used in the difference quotient
    mode:    differencing scheme, one of
               'forward'  -> (f(x + eps) - f(x)) / eps
               'backward' -> (f(x) - f(x - eps)) / eps
               'central'  -> (f(x + eps/2) - f(x - eps/2)) / eps

    Returns:
    _________________
    finite-difference approximation of d func / d param

    Raises:
    _________________
    ValueError: if ``mode`` is not one of the supported schemes
    """
    # Each scheme is k3 * (f(x + k1*eps) - f(x + k2*eps)) / eps.
    coefficients = {
        'forward': (1, 0, 1),
        'backward': (-1, 0, -1),
        'central': (1/2, -1/2, 1),
    }
    try:
        k1, k2, k3 = coefficients[mode]
    except KeyError:
        # Without this, an unknown mode surfaced as an UnboundLocalError.
        raise ValueError(
            "unknown mode {!r}; expected one of {}".format(
                mode, sorted(coefficients))) from None

    return k3*(func(param+k1*epsilon) - func(param+k2*epsilon))/epsilon



@operator
class finite_difference:
    """
    vmad operator returning ``func(param)`` whose derivatives are approximated
    numerically with ``finite_diff`` rather than derived analytically.

    ``param`` is the differentiable input and ``diff`` the output. ``func``,
    ``epsilon`` and ``mode`` are passed unchanged to every stage.

    NOTE(review): ``finite_diff`` shifts every element of ``param`` by the
    same scalar epsilon. The derivative is therefore exact (up to truncation
    error) only for scalar inputs or for functions applied elementwise.
    """
    ain = {'param': '*'}
    aout = {'diff':'*'}

    def apl(node, param, func, epsilon, mode='central'):
        # Forward evaluation: the operator's value is just func(param).
        return dict(diff=func(param))

    def vjp(node, _diff, param, func, epsilon, mode='central'):
        # Vector-Jacobian product: pull the output cotangent back onto param.
        delta = finite_diff(param, func, epsilon, mode=mode)
        # Elementwise product. The Jacobian of an elementwise function is
        # diag(delta), so np.dot would wrongly sum 1-D arrays into a scalar
        # and matrix-multiply 2-D ones.
        _param = np.multiply(delta, _diff)
        if np.ndim(param) == 0:
            # Scalar input: contract over all output elements. This matches
            # the original np.dot result for scalar-in functions.
            _param = np.sum(_param)
        return dict(_param=_param)

    def jvp(node, param_, param, func, epsilon, mode='central'):
        # Jacobian-vector product: push the input tangent forward to diff.
        delta = finite_diff(param, func, epsilon, mode=mode)
        # Elementwise for the same reason as in vjp. For a scalar tangent this
        # broadcasts exactly like np.dot did.
        return dict(diff_=np.multiply(delta, param_))


In [131]:
@autooperator('a->b')
def test(a,epsilon, j, k, mode='central'):
    """
    Autooperator mapping ``a`` to ``b = test_func(a, j, k)``.

    The gradient is supplied by the ``finite_difference`` operator using
    step ``epsilon`` and differencing scheme ``mode``.
    """
    # Bind j and k now, so the callable depends only on the point x.
    def polynomial(x, j=j, k=k):
        return test_func(x, j, k)

    return finite_difference(a, polynomial, epsilon, mode=mode)

def test_func(a, j,k):
    return a**3+j*a**2+k*a

In [132]:
# Check the finite-difference operator on f(a) = a**3 + j*a**2 + k*a at a = 1:
# f(1) = 1 + j + k = 4 and f'(1) = 3 + 2*j + k = 7 for j, k = 1, 2.
epsilon = 1e-5
j,k = 1,2
init = dict(a=1)
v = dict(_b=1)

# `test` takes no `func` argument (it always uses test_func), so none is passed.
model = test.build(epsilon=epsilon, j=j, k=k, mode='central')
apl, vjp = model.compute_with_vjp(init=init, v=v)

# Compare with a tolerance. int() truncates, so a numerically accurate
# 6.9999999999 would previously have been reported as a failure.
apl_ok = np.isclose(apl[0], 4)
vjp_ok = np.isclose(vjp[0], 7)
print('apl=4: {}, vjp=7: {}'.format(apl_ok, vjp_ok))
assert apl_ok and vjp_ok, "finite-difference check failed"

apl=4: True, vjp=7: True
