In [59]:
import numpy as np
from numpy import ndarray
from typing import Callable, List

In [60]:
# Numerical derivative via central finite differences.
def deriv(func: Callable[[ndarray], ndarray], _input: ndarray, delta: float = 0.001) -> ndarray:
    """Estimate ``func'`` at every point of ``_input`` with a central difference.

    Computes ``(func(x + delta) - func(x - delta)) / (2 * delta)``. For smooth
    functions the error shrinks like ``delta**2``; e.g. for ``x**3`` the
    estimate exceeds the exact ``3 * x**2`` by exactly ``delta**2``.

    Args:
        func: function of an ndarray returning an ndarray.
        _input: points at which to estimate the derivative.
        delta: half-width of the symmetric step.

    Returns:
        The finite-difference estimate, as returned by ``func`` arithmetic.
    """
    step = delta * 2
    ahead = func(_input + delta)
    behind = func(_input - delta)
    return (ahead - behind) / step
def cube(x: ndarray) -> ndarray:
    """Return ``x`` raised to the third power (elementwise for arrays)."""
    return pow(x, 3)
def square(x: ndarray) -> ndarray:
    """Return ``x`` raised to the second power (elementwise for arrays)."""
    return pow(x, 2)
# Exact derivatives: d/dx x^3 = 3x^2 -> [27, 12, 0.03]; d/dx x^2 = 2x -> [6, 4, 0.2].
print(deriv(cube, np.array([3, 2, 0.1])))
print(deriv(square, np.array([3, 2, 0.1])))

[27.000001 12.000001  0.030001]
[6.  4.  0.2]


## Chain Rule
Chain of 2 functions, $F(x) = f_2(f_1(x))$:

$$F'(x) = f_2'(f_1(x)) \cdot f_1'(x)$$

Chain of 3 functions, $F(x) = f_3(f_2(f_1(x)))$:

$$F'(x) = f_3'(f_2(f_1(x))) \cdot f_2'(f_1(x)) \cdot f_1'(x)$$

In [61]:
# Chain rule for a composition of 2 functions.
# Type aliases: `array_func` maps an ndarray to an ndarray; `chain` is an
# ordered list of such functions, applied first to last.
array_func = Callable[[ndarray], ndarray]
chain = List[array_func]

def chain_deriv_2(chain: chain, _input: ndarray) -> ndarray:
    """Numerically differentiate ``F(x) = f2(f1(x))`` using the chain rule.

    ``F'(x) = f2'(f1(x)) * f1'(x)``, where each factor is estimated with ``deriv``.
    Note: inside this function the name ``chain`` refers to the parameter,
    which shadows the module-level type alias of the same name.

    Args:
        chain: exactly two functions; ``chain[0]`` is applied first (f1),
            ``chain[1]`` second (f2).
        _input: 1-D array of evaluation points.

    Returns:
        The product of the two derivative estimates.
    """
    assert len(chain) == 2  # exactly two composed functions
    assert _input.ndim == 1  # 1-D input, e.g. np.array([3, 2, 1])
    inner, outer = chain[0], chain[1]

    inner_value = inner(_input)
    outer_slope = deriv(outer, inner_value)  # f2'(f1(x))
    inner_slope = deriv(inner, _input)       # f1'(x)
    return outer_slope * inner_slope
# f1 = cube, f2 = square, so F(x) = (x^3)^2 = x^6 and F'(x) = 6x^5 -> [192, 6, 6144].
chain_2 = [cube, square]
print(chain_deriv_2(chain_2, np.array([2, 1, 4])))

[1.92000016e+02 6.00000200e+00 6.14400013e+03]


In [62]:
# Chain rule for a composition of 3 functions.
# The aliases repeat the chain-of-2 cell so this cell runs on its own.
array_func = Callable[[ndarray], ndarray]
chain = List[array_func]

def chain_deriv_3(chain: chain, _input: ndarray) -> ndarray:
    """Numerically differentiate ``F(x) = f3(f2(f1(x)))`` using the chain rule.

    ``F'(x) = f3'(f2(f1(x))) * f2'(f1(x)) * f1'(x)``, where each factor is
    estimated with ``deriv`` (central finite difference).

    Args:
        chain: exactly three functions, applied in order ``chain[0]`` (f1),
            ``chain[1]`` (f2), ``chain[2]`` (f3).
        _input: 1-D array of evaluation points.

    Returns:
        The product of the three derivative estimates.
    """
    assert len(chain) == 3  # exactly three composed functions
    assert _input.ndim == 1  # 1-D input, e.g. np.array([3, 2, 1])
    f1, f2, f3 = chain[0], chain[1], chain[2]

    # Evaluate each intermediate once; f1(x) feeds both f2 and f2's derivative.
    f1_of_x = f1(_input)
    f2_of_x = f2(f1_of_x)

    return deriv(f3, f2_of_x) * deriv(f2, f1_of_x) * deriv(f1, _input)

def relu(x: ndarray) -> ndarray:
    """Rectified linear unit: elementwise ``max(0, x)``."""
    rectified = np.maximum(0, x)
    return rectified

# f1 = square, f2 = cube, f3 = relu: F(x) = relu(x^6) = x^6 (x^6 >= 0), so F'(x) = 6x^5.
chain_3 = [square, cube, relu]
print(chain_deriv_3(chain_3, np.array([3, 7, 2])))

[  1458.00000596 100842.00040122    192.000004  ]
