In [2]:
import numpy as np

In [84]:
class  MaxPool1D():
    def __init__(self, pool_size: int=2, stride: int=1, axis: int=1):
        self.pool_size = pool_size
        self.stride = stride
        self.axis = axis
        
    def forward(self, inputs: np.ndarray) -> np.ndarray:
        """inputs: np.ndarray(batch_size, n_elements)
           return (inputs[1] - pool_size + 1) // stride"""
        pooled = np.apply_along_axis(
            func1d=self.__pool,
            axis=self.axis,
            arr=inputs
        )
        result = np.take(pooled, indices=0, axis=self.axis)
        self.max_inds = np.take(pooled, indices=1, axis=self.axis).astype(np.uint32)
        return result
    
    def __pool(self, inputs: np.ndarray) -> np.ndarray:
        output_size = np.floor((inputs.size - self.pool_size + 1) / self.stride + 0.5).astype(int)
        result = np.zeros(output_size, dtype=np.float32)
        max_inds = np.zeros_like(result, dtype=np.uint32)
        for out_ind in range(output_size):
            start = out_ind * self.stride
            stop = start + self.pool_size
            max_inds[out_ind] = np.argmax(inputs[start: stop]) + start
            result[out_ind] = inputs[max_inds[out_ind]]
        return result, max_inds
    
    def backward(self,
                 inputs: np.ndarray,
                 error_grad_mat: np.ndarray,
                 l1: np.float32 = 0.001,
                 l2: np.float32 = 0.001) -> np.ndarray:
        error_by_inputs = np.zeros_like(inputs)
        indexes = np.array(list(np.ndindex(*(error_grad_mat.shape)))) 
        np.put(indexes, np.arange(indexes.shape[0]) * indexes.shape[1] + self.axis, self.max_inds.flatten())
        error_by_inputs[tuple(indexes.T)] = error_grad_mat.flatten()
        return error_by_inputs

In [88]:
a = np.random.rand(2,3, 4)
print("a", a, sep="\n")
pooling = MaxPool1D(pool_size=3, stride=2, axis=1)
res = pooling.forward(a)
print("res", res, sep="\n")
pooling.backward(a, res)


a
[[[0.9838107  0.02120399 0.73411237 0.99301228]
  [0.0226108  0.03694803 0.70356613 0.99976487]
  [0.57210921 0.90463535 0.28209485 0.10785797]]

 [[0.47585868 0.31439878 0.98992964 0.48488062]
  [0.46186045 0.44829631 0.26614923 0.39666725]
  [0.0807282  0.80308296 0.34583937 0.03564313]]]
res
[[[0.98381072 0.90463537 0.73411238 0.99976486]]

 [[0.47585869 0.80308294 0.98992962 0.48488063]]]


array([[[0.98381072, 0.        , 0.73411238, 0.        ],
        [0.        , 0.        , 0.        , 0.99976486],
        [0.        , 0.90463537, 0.        , 0.        ]],

       [[0.47585869, 0.        , 0.98992962, 0.48488063],
        [0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.80308294, 0.        , 0.        ]]])

In [97]:
pooling.result_inds

array([[1, 3, 5],
       [1, 3, 5]], dtype=uint32)

In [74]:
a = np.arange(24).reshape((2, 3,-1))
b = np.array(list(np.ndindex(*(a.shape))))
np.put(b, np.arange(b.shape[0]) * b.shape[1] + 1, np.array([10] * b.shape[0]))
b

array([[ 0, 10,  0],
       [ 0, 10,  1],
       [ 0, 10,  2],
       [ 0, 10,  3],
       [ 0, 10,  0],
       [ 0, 10,  1],
       [ 0, 10,  2],
       [ 0, 10,  3],
       [ 0, 10,  0],
       [ 0, 10,  1],
       [ 0, 10,  2],
       [ 0, 10,  3],
       [ 1, 10,  0],
       [ 1, 10,  1],
       [ 1, 10,  2],
       [ 1, 10,  3],
       [ 1, 10,  0],
       [ 1, 10,  1],
       [ 1, 10,  2],
       [ 1, 10,  3],
       [ 1, 10,  0],
       [ 1, 10,  1],
       [ 1, 10,  2],
       [ 1, 10,  3]])