In [1]:
import numpy as np
from numpy import ndarray
from typing import Callable
from s_utils import deriv, sigmoid

In [15]:
def matrix_function_backward(X: ndarray, W: ndarray, func: Callable[[ndarray], ndarray]) -> tuple:
    """Evaluate ``func(np.dot(X, W))`` and back-propagate to ``X`` via the chain rule.

    Args:
        X: 2-D input array; its number of columns must equal the number of
            rows of ``W``.
        W: 2-D weight array.
        func: Array-to-array function applied to ``N = np.dot(X, W)``.

    Returns:
        A tuple ``((N, S), grad)`` where ``N = np.dot(X, W)``, ``S = func(N)``
        and ``grad = np.dot(dS/dN, W.T)``. ``grad`` has the same shape as
        ``X``; provided ``deriv`` yields the element-wise derivative of
        ``func`` at ``N``, it is the gradient of ``S.sum()`` with respect
        to ``X``.

    Raises:
        AssertionError: If the inner dimensions of ``X`` and ``W`` differ.
    """
    assert X.shape[1] == W.shape[0], "X columns must match W rows"

    # Forward pass.
    N = np.dot(X, W)
    S = func(N)

    # Backward pass (chain rule).
    # Derivative of the outer function ``func`` with respect to N.
    # NOTE(review): assumes ``deriv`` returns an array shaped like N -- confirm in s_utils.
    dsdn = deriv(func, N)
    # Derivative of the inner function N = X.W with respect to X is W transposed.
    dndx = np.transpose(W, (1, 0))
    # Combine with a matrix product, not an element-wise one: ``dsdn * dndx``
    # only agrees with this when N is 1x1 and otherwise mis-broadcasts or fails.
    dsdx = np.dot(dsdn, dndx)
    return (N, S), dsdx

# Demo inputs as (label, input row(s), weight column) triples.
demo_cases = (
    ('example 1: ', [[4, 2, 1]], [[0.2], [0.1], [0.3]]),
    ('example 2: ', [[0.2, 0.1, -0.5]], [[1], [2], [1]]),
    ('example 3: ', [[4, 2, 1]], [[8], [5], [4]]),  # vanishing grad in sigmoid
)

demo_results = []
for label, x_rows, w_rows in demo_cases:
    x = np.array(x_rows)
    w = np.array(w_rows)
    demo_results.append(matrix_function_backward(x, w, sigmoid))
    # Each result is ((N, S), gradient); print both parts after the label.
    print(label, *demo_results[-1])

# Keep the per-example names available to later cells.
(forward1, backward1), (forward2, backward2), (forward3, backward3) = demo_results

example 1:  (array([[1.3]]), array([[0.78583498]])) [[0.03365967 0.01682984 0.05048951]]
example 2:  (array([[-0.1]]), array([[0.47502081]])) [[0.24937602 0.49875204 0.24937602]]
example 3:  (array([[46]]), array([[1.]])) [[0. 0. 0.]]
