In [27]:
import numpy as np
from numpy import ndarray
from abc import ABC, abstractmethod

In [28]:
def assert_same_shape(array: ndarray, array_grad: ndarray) -> None:
    '''
    Check that an array and its gradient have identical shapes.

    Raises an AssertionError whose message reports both shapes when
    array.shape and array_grad.shape differ; otherwise returns None.
    '''
    assert array.shape == array_grad.shape, f'''
    Two ndarrays should have the same shape;
    instead, first ndarray's shape is {array.shape}
    and second ndarray's shape is {array_grad.shape}
    '''

<h1> <code>Operation</code> and <code>ParamOperation</code></h1>

In [29]:
# abstract class for any Operation
class Operation(ABC):
    '''
    Abstract base class for an operation with a forward and a backward pass.

    Subclasses must implement _output (computes the result from
    self.input_) and _input_grad (computes the gradient passed back to
    the input from the gradient received at the output).
    '''
    def __init__(self):
        pass

    def forward(self, input_: ndarray):
        '''
        Stores input in the self.input_ instance variable.
        Calls the self._output() function, stores the result in
        self.output and returns it.
        '''
        self.input_ = input_

        self.output = self._output()

        return self.output

    def backward(self, output_grad: ndarray):
        '''
        Calls the self._input_grad() function.
        Checks that the appropriate shapes match:
        output_grad against self.output, and the computed input gradient
        against self.input_. Both self.output and self.input_ are set by
        forward().

        Returns the input gradient, also stored in self.input_grad.
        Raises AssertionError (via assert_same_shape) on a shape mismatch.
        '''
        assert_same_shape(self.output, output_grad)

        self.input_grad = self._input_grad(output_grad)

        # Compare against the computed gradient (self.input_grad), not the
        # bound method self._input_grad, which has no .shape attribute.
        assert_same_shape(self.input_, self.input_grad)

        return self.input_grad

    @abstractmethod
    def _output(self):
        '''
        The output method must be defined for each Operation
        '''
        pass

    @abstractmethod
    def _input_grad(self, output_grad: ndarray) -> ndarray:
        '''
        The _input_grad method must be defined for each Operation
        '''
        pass
    

In [30]:
# Another abstract class for "parameter" operations
class ParamOperation(Operation):
    '''
    Abstract base class for operations that carry a parameter array.

    In addition to the Operation contract, subclasses must implement
    _param_grad, which computes the gradient for self.param.
    '''

    def __init__(self, param: ndarray):
        '''Store the parameter array in self.param.'''
        super().__init__()
        self.param = param

    def backward(self, output_grad: ndarray) -> ndarray:
        '''
        Calls the self._input_grad and self._param_grad.
        Checks appropriate shapes: output_grad against self.output,
        the input gradient against self.input_, and the parameter
        gradient against self.param.

        Stores the results in self.input_grad and self.param_grad and
        returns self.input_grad, matching Operation.backward.
        Raises AssertionError (via assert_same_shape) on a shape mismatch.
        '''
        assert_same_shape(self.output, output_grad)

        self.input_grad = self._input_grad(output_grad)
        self.param_grad = self._param_grad(output_grad)

        assert_same_shape(self.input_, self.input_grad)
        assert_same_shape(self.param, self.param_grad)

        # The base-class backward() returns the input gradient; without this
        # return the override would hand None back to the caller.
        return self.input_grad

    @abstractmethod
    def _param_grad(self, output_grad: ndarray):
        '''
        Every subclass of ParamOperation must implement _param_grad
        '''
        pass

# Specific Operations

## 1. weight multiply

In [40]:
class WeightMultiply(ParamOperation):
    '''
    Multiplies the input by a weight matrix kept in self.param.
    '''

    def __init__(self, W: ndarray):
        '''Pass W to ParamOperation, which stores it as self.param.'''
        super().__init__(W)

    def _output(self) -> ndarray:
        '''Return the dot product of the stored input and the weights.'''
        inputs, weights = self.input_, self.param
        return np.dot(inputs, weights)

    def _input_grad(self, output_grad: ndarray) -> ndarray:
        '''Return output_grad dotted with the transposed weights.'''
        # Swapping the two axes requires self.param to be 2-D.
        weights_t = np.transpose(self.param, axes=(1, 0))
        return np.dot(output_grad, weights_t)

    def _param_grad(self, output_grad: ndarray) -> ndarray:
        '''Return the transposed input dotted with output_grad.'''
        # Swapping the two axes requires self.input_ to be 2-D.
        inputs_t = np.transpose(self.input_, axes=(1, 0))
        return np.dot(inputs_t, output_grad)
        

## 2. bias add

In [33]:
class BiasAdd(ParamOperation):
    '''
    Compute bias addition: adds the bias row held in self.param to the input.
    '''

    def __init__(self, B: ndarray):
        '''
        Initialize the operation with self.param = B.

        B must have exactly one row (B.shape[0] == 1); an AssertionError
        is raised otherwise.
        '''
        assert B.shape[0] == 1

        super().__init__(B)

    def _output(self) -> ndarray:
        '''
        Compute output: the stored input plus the bias.
        '''
        return self.input_ + self.param

    def _input_grad(self, output_grad: ndarray) -> ndarray:
        '''
        Compute input gradient: output_grad multiplied elementwise by an
        array of ones shaped like the input.
        '''
        return np.ones_like(self.input_) * output_grad

    def _param_grad(self, output_grad: ndarray) -> ndarray:
        '''
        Compute the param grad: output_grad summed over axis 0 and
        reshaped to a single row.
        '''
        # The parameter was previously misspelled "ouput_grad", so the body
        # read an undefined (or stray global) name instead of the argument.
        param_grad = np.ones_like(self.param) * output_grad
        return np.sum(param_grad, axis=0).reshape(1, param_grad.shape[1])

In [37]:
# Seed NumPy's global random generator so the sampled values repeat on every run.
np.random.seed(100)

# Sample input: a 10 x 3 array of random floats.
X = np.random.random((10, 3))

print(X)

[[0.54340494 0.27836939 0.42451759]
 [0.84477613 0.00471886 0.12156912]
 [0.67074908 0.82585276 0.13670659]
 [0.57509333 0.89132195 0.20920212]
 [0.18532822 0.10837689 0.21969749]
 [0.97862378 0.81168315 0.17194101]
 [0.81622475 0.27407375 0.43170418]
 [0.94002982 0.81764938 0.33611195]
 [0.17541045 0.37283205 0.00568851]
 [0.25242635 0.79566251 0.01525497]]


In [38]:
# Random weight matrix with one row per column of X and a single column.
W = np.random.random((X.shape[1], 1))

print(W)

[[0.59884338]
 [0.60380454]
 [0.10514769]]


In [41]:
# Build the weight-multiplication operation; W is stored as its param.
w = WeightMultiply(W)

In [44]:
# Forward pass: stores X as the operation's input and returns np.dot(X, W).
# backward() cannot be called here: it needs the output saved by forward().
w.forward(X)

array([[0.53813219],
       [0.52152057],
       [0.91470167],
       [0.90457219],
       [0.19952172],
       [1.09421954],
       [0.69967045],
       [1.09197243],
       [0.3307592 ],
       [0.63319251]])