In [1]:
import numpy as np

In [99]:
class  MaxPool1D():
    def __init__(self, pool_size: int=2, stride: int=1):
        self.pool_size = pool_size
        self.stride = stride
        
    def forward(self, inputs: np.ndarray) -> np.ndarray:
        """inputs: np.ndarray(batch_size, n_elements)
           return (inputs[1] - pool_size + 1) // stride"""
        assert len(inputs.shape) == 2 
        output_size = np.floor((inputs.shape[1] - self.pool_size + 1) / self.stride + 0.5).astype(int)
        result = np.empty((inputs.shape[0], output_size), dtype=np.float32)
        self.max_inds = np.empty_like(result, dtype=np.uint32)
        for out_ind in range(output_size):
            start = out_ind * self.stride
            stop = start + self.pool_size
            self.max_inds[:, out_ind: out_ind + 1] = np.argmax(inputs[:, start: stop], axis=1, keepdims=True) + start
            result[:, out_ind] = inputs[range(inputs.shape[0]), self.max_inds[:, out_ind]]
        return result
    
    def backward(self,
                 inputs: np.ndarray,
                 error_grad_mat: np.ndarray,
                 l1: np.float32 = 0.001,
                 l2: np.float32 = 0.001) -> np.ndarray:
        error_by_inputs = np.zeros_like(inputs)
        rows = np.repeat(range(error_grad_mat.shape[0]), error_grad_mat.shape[1]).flatten()
        cols = self.max_inds.flatten()
        error_by_inputs[rows, cols] = error_grad_mat.flatten()
        return error_by_inputs

In [105]:
a = np.random.rand(2,5)
print(a)
pooling = MaxPool1D(stride=2)
res = pooling.forward(a)
display(res)
pooling.backward(a, res)


[[0.01913571 0.92748992 0.99478995 0.53823202 0.23681796]
 [0.00688022 0.33551887 0.9533459  0.7388261  0.04049291]]


array([[0.92748994, 0.99478996],
       [0.33551887, 0.9533459 ]], dtype=float32)

array([[0.        , 0.92748994, 0.99478996, 0.        , 0.        ],
       [0.        , 0.33551887, 0.95334589, 0.        , 0.        ]])

In [97]:
pooling.result_inds

array([[1, 3, 5],
       [1, 3, 5]], dtype=uint32)

In [None]:
1