We are checking the chain rule for differentiation.

Let us check the function:

$$ f(x) = \tanh( w_2 \tanh ( w_1 x ) ) $$

By the chain rule, its derivative should be the following product, where we use $d$ for the derivative of $\tanh$, $d(h) = 1 - \tanh^2(h)$:

$$ f'(x) = w_1 \, d( w_1 x ) \, w_2 \, d( w_2 \tanh ( w_1 x ) ) $$

In [1]:
import numpy as np

Function composition:

Given a list of unary functions, can they be composed into a single, longer function?

In [29]:
np.tan(np.sin(np.cos(0.5)))

0.9681102639004562

In [14]:
from functools import reduce

In [37]:
fns  = [np.tan, np.sin, np.cos]
fns1 = [np.cos, np.sin, np.tan]

In [42]:
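# Fold left to right, starting from the identity: each step wraps the
# accumulated function with the next one, giving tan(sin(cos(m))).
x = reduce(lambda f, f1: lambda m: f1(f(m)), fns1, lambda m: m)
x(0.5)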

0.9681102639004562
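The `reduce` call above can be wrapped in a small reusable helper. The name `compose` is a hypothetical helper introduced for this sketch, not part of `functools`; like the cell above, it applies the functions left to right, so the first function in the list is applied first.

```python
from functools import reduce

import numpy as np


def compose(fn_list):
    """Hypothetical helper: return a function applying fn_list left to right."""
    # Start from the identity and wrap the accumulated function with each next one.
    return reduce(lambda acc, f: (lambda m: f(acc(m))), fn_list, lambda m: m)


# np.cos is applied first and np.tan last, i.e. tan(sin(cos(x))).
composed = compose([np.cos, np.sin, np.tan])
print(composed(0.5))
```

An empty list composes to the identity function, since `reduce` then simply returns its initial value.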

In [65]:
def x(n):
    # Returns a function that multiplies its input by n.
    return lambda m: n*m

fns = [x(5), np.tanh, x(3), np.tanh]

In [66]:
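def fnAll(fnList, x, verbose=False):
    '''Apply the functions in fnList to x in order and return the result.'''
    result = x
    for i, f in enumerate(fnList):
        if verbose:
            print('[{:05d}] --> {}'.format(i, result))
        result = f(result)

    # Print the final value; len(fnList) also works when fnList is empty.
    if verbose:
        print('[{:05d}] --> {}'.format(len(fnList), result))
    return result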

In [68]:
fnAll(fns, 2, verbose=True)

[00000] --> 2
[00001] --> 10
[00002] --> 0.9999999958776927
[00003] --> 2.9999999876330783
[00004] --> 0.995054753564718


0.995054753564718

In [90]:
def fdDeltaX(fnList, x):
    # Central-difference estimate of the derivative of the composed function at x.
    deltaX = 1e-5
    result = (fnAll(fnList, x + deltaX) - fnAll(fnList, x - deltaX)) / (2 * deltaX)
    return result

In [91]:
fdDeltaX(fns, 0.1)

2.6121582854310432
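`fdDeltaX` uses a central difference. One way to see why this is preferred over a one-sided difference is to apply both to `np.tanh`, whose derivative $1 - \tanh^2(x)$ is known exactly. The helper names, the step `h = 1e-4` and the evaluation point are arbitrary choices for this sketch.

```python
import numpy as np


def forward_diff(f, x, h):
    # One-sided difference: truncation error is O(h).
    return (f(x + h) - f(x)) / h


def central_diff(f, x, h):
    # Symmetric difference: truncation error is O(h**2).
    return (f(x + h) - f(x - h)) / (2 * h)


x0, h = 0.5, 1e-4
exact = 1 - np.tanh(x0) ** 2
forward_err = abs(forward_diff(np.tanh, x0, h) - exact)
central_err = abs(central_diff(np.tanh, x0, h) - exact)
print(forward_err, central_err)
```

Halving `h` should roughly halve the forward error and quarter the central error, as long as the step stays large enough that round-off does not dominate.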

## Simulate the scalar case

In this case, we shall look at the following:

$$ f(x) = \tanh( w_2 \tanh ( w_1x ) ) $$

The derivative can be viewed as a product of one factor per step:

$$ f'(x) = [1]\,[ w_1 ]\,[ d( w_1 x ) ]\,[ w_2 ]\,[ d( w_2 \tanh ( w_1 x ) ) ] $$

Note that here, $d$ is the derivative of $\tanh$:

$$ d(x) = \frac{\mathrm{d}\tanh(x)}{\mathrm{d}x} = 1 - \tanh^2(x) $$
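With $w_1 = 5$ and $w_2 = 3$ (the values used in the function lists that follow), the bracketed product can be evaluated directly and compared with a central difference. The names `d`, `f` and `f_prime` are introduced only for this sketch.

```python
import numpy as np

w1, w2 = 5.0, 3.0


def d(z):
    # Derivative of tanh: 1 - tanh^2(z).
    return 1 - np.tanh(z) ** 2


def f(x):
    return np.tanh(w2 * np.tanh(w1 * x))


def f_prime(x):
    # Product of the bracketed factors, innermost (w_1) first.
    return 1 * w1 * d(w1 * x) * w2 * d(w2 * np.tanh(w1 * x))


x0, h = 0.1, 1e-5
numeric = (f(x0 + h) - f(x0 - h)) / (2 * h)
print(f_prime(x0), numeric)
```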


Associated function lists are:

```python
fns   = [x(5), np.tanh, x(3), np.tanh]
fDiff = [dx(5), dTanh, dx(3), dTanh]
```

In [94]:
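def x(n):
    # Scaling step: m -> n*m
    return lambda m: n*m

def dTanh(x):
    # Derivative of tanh
    return 1 - np.tanh(x)**2

def dx(n):
    # The derivative of the scaling step is the constant n
    return lambda m: n

fns   = [x(5), np.tanh, x(3), np.tanh]
fDiff = [dx(5), dTanh, dx(3), dTanh]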

In [95]:
def fnAll(fnList, fDiff, x, verbose=False):
    '''Apply fnList to x and accumulate the chain-rule product of derivatives.'''
    result  = x
    dResult = 1

    for i, f in enumerate(fnList):
        if verbose:
            print('[{:05d}] --> {} {}'.format(i, result, dResult))
        # Multiply by the derivative of step i, evaluated at that step's input.
        dResult *= fDiff[i](result)
        result   = f(result)

    if verbose:
        print('[{:05d}] --> {} {}'.format(len(fnList), result, dResult))
    return result, dResult

In [96]:
fnAll(fns, fDiff, 0.1, verbose=True)

[00000] --> 0.1 1
[00001] --> 0.5 5
[00002] --> 0.46211715726000974 3.932238664829637
[00003] --> 1.3863514717800292 11.796715994488912
[00004] --> 0.8823655878825888 2.6121582574596642


(0.8823655878825888, 2.6121582574596642)
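Keeping `fns` and `fDiff` as two parallel lists relies on their indices lining up. One alternative, sketched here for scalar steps only, is to store `(function, derivative)` pairs and walk them in a single loop. The names `scale`, `d_scale`, `d_tanh`, `steps` and `forward_chain` are hypothetical stand-ins for `x`, `dx`, `dTanh` and the two lists above.

```python
import numpy as np


def scale(n):
    # Scaling step: m -> n*m
    return lambda m: n * m


def d_scale(n):
    # Derivative of the scaling step is the constant n.
    return lambda m: n


def d_tanh(z):
    return 1 - np.tanh(z) ** 2


# Each entry pairs a step with its derivative.
steps = [(scale(5), d_scale(5)), (np.tanh, d_tanh),
         (scale(3), d_scale(3)), (np.tanh, d_tanh)]


def forward_chain(steps, x):
    value, deriv = x, 1.0
    for f, df in steps:
        # The derivative factor is evaluated at the step's input, before the step is applied.
        deriv *= df(value)
        value = f(value)
    return value, deriv


value, deriv = forward_chain(steps, 0.1)
print(value, deriv)
```

Pairing each step with its derivative means the two cannot drift out of alignment when steps are added or removed.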

## Simulate the vector case

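In this case, we are dealing with partial derivatives. Seeding the input perturbation `dxn` with a unit vector along component `i` propagates a directional derivative through the network, so the final value of `dResult` is the partial derivative of the output with respect to that input component.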
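For one layer with weight matrix $W$, input $a$ and output $\tanh(W a)$, a perturbation $\dot a$ of the input propagates as $\dot z = W \dot a$ followed by an elementwise product with $d(z)$, where $z = W a$ and $d$ is the $\tanh$ derivative from the scalar case. Chaining the three layers used below, with $z_1 = A x$, $z_2 = B \tanh(z_1)$, $z_3 = C \tanh(z_2)$, and seeding $\dot x = e_i$ (the $i$-th unit vector) should give the partial derivative of the output $y$. The dot notation is introduced only for this sketch, and $\odot$ denotes elementwise multiplication:

$$ \frac{\partial y}{\partial x_i} = d(z_3) \odot C \big( d(z_2) \odot B \, ( d(z_1) \odot A \, e_i ) \big) $$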


In [23]:
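N  = 4   # input dimension
i  = 2   # input component to differentiate with respect to

xn  = np.random.rand(N).reshape((-1, 1))
# Seed the perturbation with the unit vector along component i
dxn = np.zeros(xn.shape); dxn[i, 0] = 1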

In [40]:
def V(M):
    '''
    M : a matrix of shape (m, n)

    Returns
    -------
    A function that takes a vector of
    shape (n, 1) and returns a vector
    of shape (m, 1)
    '''
    return lambda m: np.matmul(M, m)

def dTanh(x):
    return 1 - np.tanh(x)**2


A = np.random.rand(5, 4)
B = np.random.rand(2, 5)
C = np.random.rand(1, 2)

# Each layer is (linear map, activation). A linear map is its own
# derivative, so V(A), V(B) and V(C) reappear in fDiff.
fns   = [(V(A), np.tanh), (V(B), np.tanh), (V(C), np.tanh)]
fDiff = [(V(A), dTanh), (V(B), dTanh), (V(C), dTanh)]

In [41]:
def fnAll(fnList, fDiff, xn, dxn, verbose=False):

    result  = xn.copy()
    dResult = dxn.copy()

    if verbose:
        print('[{:05d}] result: {} | {}'.format(-1, result.T, dResult.T))

    for i, (W, a) in enumerate(fnList):
        # Pre-activation z = W x
        result  = W(result)

        # Tangent: apply the linear map, then scale elementwise by tanh'(z)
        W1, a1  = fDiff[i]
        dResult = W1(dResult)
        dResult *= a1(result)

        if verbose:
            print('[{:05d}] result: {} | {}'.format(i, result.T, dResult.T))
        # Activation
        result  = a(result)
        if verbose:
            print('[{:05d}] result: {} | {}'.format(i, result.T, dResult.T))

    # Only the value is returned; dResult appears in the verbose output.
    return result

In [42]:
fnAll(fns, fDiff, xn, dxn, verbose=True)

[-0001] result: [[0.39644335 0.4235196  0.1987985  0.39992922]] | [[0. 0. 1. 0.]]
[00000] result: [[0.69035062 1.06598603 0.85873202 0.71542908 0.5247565 ]] | [[0.15914934 0.11020659 0.47874477 0.2275865  0.67711029]]
[00000] result: [[0.5982072  0.78794415 0.6956038  0.61406999 0.48136274]] | [[0.15914934 0.11020659 0.47874477 0.2275865  0.67711029]]
[00001] result: [[1.88955672 1.95783325]] | [[0.07074623 0.08465793]]
[00001] result: [[0.95533442 0.96092415]] | [[0.07074623 0.08465793]]
[00002] result: [[0.63017194]] | [[0.03688988]]
[00002] result: [[0.5581706]] | [[0.03688988]]


array([[0.5581706]])

In [43]:
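# Forward-difference check of the derivative with respect to x_i
delXn = 1e-10
xn1 = xn.copy()
xn1[i, 0] += delXn

(fnAll(fns, fDiff, xn1, dxn) - fnAll(fns, fDiff, xn, dxn))/delXn

array([[0.03689049]])

The finite-difference estimate agrees with the propagated derivative in the verbose output (0.03688988) to about four significant figures; with a step as small as 1e-10, floating-point round-off rather than truncation error limits the agreement.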
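The cells above recover a single partial derivative because `dxn` was seeded with one unit vector. Repeating the forward pass once per input component gives the full gradient. The sketch below assumes a variant of `fnAll`, here called `forward_tangent` (hypothetical), that returns the propagated derivative instead of only printing it; it also fixes the random seed for reproducibility, and its matrix shapes match `A`, `B` and `C` above.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.random((5, 4))
B = rng.random((2, 5))
C = rng.random((1, 2))
layers = [A, B, C]


def d_tanh(z):
    return 1 - np.tanh(z) ** 2


def forward_tangent(layers, x, dx):
    # Propagate the value and a tangent (directional derivative) together.
    y, dy = x, dx
    for W in layers:
        z = W @ y
        dy = d_tanh(z) * (W @ dy)  # linear map first, then the elementwise tanh factor
        y = np.tanh(z)
    return y, dy


def gradient(layers, x):
    # One forward pass per input component, seeded with the unit vector e_i.
    n = x.shape[0]
    grad = np.zeros(n)
    for i in range(n):
        e_i = np.zeros_like(x)
        e_i[i, 0] = 1.0
        grad[i] = forward_tangent(layers, x, e_i)[1][0, 0]
    return grad


x = rng.random((4, 1))
print(gradient(layers, x))
```

Forward mode costs one pass per input component. For a scalar output like this one, reverse mode (backpropagation) obtains every partial derivative from a single backward pass, which is why it is the usual choice when inputs outnumber outputs.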