In [9]:
import numpy as np
from typing import Tuple

In [43]:
import sys
sys.path.append('../')
from core.layers import MaxPool1D

In [44]:
a = np.random.rand(2,3, 4)
print("a", a, sep="\n")
pooling = MaxPool1D(pool_size=3, stride=2, axis=1)
res = pooling.forward(a)
print("res", res, sep="\n")
pooling.backward(a, res)


a
[[[0.87438169 0.58521346 0.92161849 0.95982559]
  [0.76362345 0.96857109 0.31032174 0.83557177]
  [0.30788351 0.37818086 0.28135505 0.29124525]]

 [[0.97339315 0.17963671 0.39660327 0.94229126]
  [0.71856004 0.2236836  0.40879094 0.24339306]
  [0.84279593 0.2700115  0.03313796 0.4881363 ]]]
res
[[[0.87438172 0.96857107 0.92161846 0.95982558]]

 [[0.97339314 0.27001148 0.40879095 0.94229126]]]


array([[[0.87438172, 0.        , 0.92161846, 0.95982558],
        [0.        , 0.96857107, 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        ]],

       [[0.97339314, 0.        , 0.        , 0.94229126],
        [0.        , 0.        , 0.40879095, 0.        ],
        [0.        , 0.27001148, 0.        , 0.        ]]])

In [79]:
class MaxPool2D():
    def __init__(self, pool_sizes: Tuple[int, int]=(2, 2), strides: Tuple[int, int]=(1, 1), axes: Tuple[int, int]=(1, 2)):
        assert axes[0] < axes[1]
        self.pool_sizes = pool_sizes
        self.strides = strides
        self.axes = list(axes)
        
    def forward(self, inputs: np.ndarray) -> np.ndarray:
        """inputs: np.ndarray(batch_size, *dimensions)"""
        if len(inputs.shape) == 2:
            result, self.max_inds = self.__pool(inputs)
        else:
            result_dims = np.array(list(inputs.shape))
            result_dims[self.axes] = self.__calc_out_size(result_dims[self.axes]) 
            result = np.empty(result_dims, dtype=np.float32)
            self.max_inds = np.zeros_like(result, dtype=object)
            poped_axes = list(range(len(inputs.shape)))
            poped_axes.pop(self.axes[0])
            poped_axes.pop(self.axes[1] - 1)
            poped_axes = tuple(np.expand_dims(poped_axes, 0))
            poped_dims = np.array(inputs.shape)[poped_axes]
            indexes = np.array([slice(None)] * len(inputs.shape), dtype=object)
            for others in np.ndindex(*poped_dims):
                indexes[poped_axes] = others
                tupled = tuple(indexes)
                result[tupled], self.max_inds[tupled] = self.__pool(inputs[tupled]) 
        return result
    
    def __calc_out_size(self, inputs_shape: tuple):
        """inputs_shape: shape of pooled 2d matrix"""
        return [
            np.floor((inputs_shape[i] - self.pool_sizes[i] + 1) / self.strides[i] + 0.5).astype(int)
            for i in range(2)
        ]

    def __pool(self, inputs: np.ndarray) -> np.ndarray:
        """inputs: 2d array"""
        output_sizes = self.__calc_out_size(inputs.shape)
        result = np.zeros(output_sizes, dtype=np.float32)
        max_inds = np.zeros(result.shape, dtype=tuple)
        for out_inds in np.ndindex(*output_sizes):
            start = [out_inds[i] * self.strides[i] for i in range(2)]
            stop = [start[i] + self.pool_sizes[i] for i in range(2)]
            borders = tuple([slice(start[i], stop[i]) for i in range(2)])
            max_inds[out_inds] = tuple(np.unravel_index(np.argmax(inputs[borders]), self.pool_sizes) + np.array(start))
            result[out_inds] = inputs[max_inds[out_inds]]
        return result, max_inds
    
    def backward(self,
                inputs: np.ndarray,
                error_grad_mat: np.ndarray,
                l1: np.float32 = 0.001,
                l2: np.float32 = 0.001) -> np.ndarray:
        error_by_inputs = np.zeros_like(inputs)
        flattend_max = np.array(list(self.max_inds.flatten()))
        indexes = np.array(list(np.ndindex(*(error_grad_mat.shape)))) 
        for i in range(2):
            np.put(indexes, np.arange(indexes.shape[0]) * indexes.shape[1] + self.axes[i], flattend_max[:, i])
        error_by_inputs[tuple(indexes.T)] = error_grad_mat.flatten()
        return error_by_inputs

In [80]:
a = np.random.rand(3,4,4)
print("a", a, sep="\n")
pooling = MaxPool2D(pool_sizes=(2,2), strides=(1, 2), axes=(0,2))
res = pooling.forward(a)
print("res", res, sep="\n")
pooling.backward(a, res)


a
[[[0.55259416 0.74643588 0.56623384 0.19955806]
  [0.47414004 0.5795471  0.6162695  0.57581072]
  [0.41045298 0.95141272 0.89002325 0.73921007]
  [0.08843449 0.8691462  0.76143928 0.76678522]]

 [[0.20855568 0.82524254 0.32863929 0.09428927]
  [0.19208188 0.78783922 0.27354176 0.45712893]
  [0.51345745 0.15498382 0.79437955 0.62974118]
  [0.35168093 0.76424899 0.26166434 0.9026296 ]]

 [[0.21145086 0.49116235 0.81729649 0.5965204 ]
  [0.97959501 0.19245079 0.08605412 0.3967168 ]
  [0.67602722 0.1015473  0.50434688 0.10514434]
  [0.08864261 0.84909109 0.00251432 0.21780545]]]
res
[[[0.8252425  0.5662338 ]
  [0.78783923 0.6162695 ]
  [0.95141274 0.89002323]
  [0.8691462  0.9026296 ]]

 [[0.8252425  0.8172965 ]
  [0.979595   0.45712894]
  [0.67602724 0.79437953]
  [0.8490911  0.9026296 ]]]


array([[[0.        , 0.        , 0.56623381, 0.        ],
        [0.        , 0.        , 0.61626953, 0.        ],
        [0.        , 0.95141274, 0.89002323, 0.        ],
        [0.        , 0.86914623, 0.        , 0.        ]],

       [[0.        , 0.82524252, 0.        , 0.        ],
        [0.        , 0.78783923, 0.        , 0.45712894],
        [0.        , 0.        , 0.79437953, 0.        ],
        [0.        , 0.        , 0.        , 0.90262961]],

       [[0.        , 0.        , 0.8172965 , 0.        ],
        [0.97959501, 0.        , 0.        , 0.        ],
        [0.67602724, 0.        , 0.        , 0.        ],
        [0.        , 0.84909111, 0.        , 0.        ]]])