In [1]:
import numpy as np
from numpy.lib.stride_tricks import as_strided

In [4]:
def split_by_strides(input_data: np.ndarray, kernel_x, kernel_y, stride):
    """
    Split a tensor into sliding windows of the kernel size, stepping by ``stride``.

    The last two axes of ``input_data`` are treated as the spatial axes; any
    leading axes (typically batch and channel) are carried through unchanged.
    The result is a strided *view* created with ``as_strided``: no data is
    copied, so it shares memory with ``input_data`` (windows overlap in memory
    whenever ``stride`` is smaller than the kernel size).

    :param input_data: tensor to be split, with at least two dimensions
                       (usually four: batches, channels, height, width)
    :param kernel_x: kernel height (window size along the second-to-last axis)
    :param kernel_y: kernel width (window size along the last axis)
    :param stride: step between consecutive windows, must be >= 1
    :return: output_data: view of shape
             (*leading_axes, out_x, out_y, kernel_x, kernel_y)
    :raises ValueError: if ``stride`` is smaller than 1 or ``input_data`` has
                        fewer than two dimensions

    Example: [[1, 2, 3, 4],    2, 2, 2       [[[[1, 2],
              [5, 6, 7, 8],             =>      [5, 6]],
              [9, 10, 11, 12],                 [[3, 4],
              [13, 14, 15, 16]]                 [7, 8]]],
                                              [[[9, 10],
                                                [13, 14]],
                                               [[11, 12],
                                                [15, 16]]]]
    """
    # A zero stride would divide by zero below, and a negative one can yield a
    # positive window count whose view reads outside the input buffer.
    if stride < 1:
        raise ValueError(f"stride must be a positive integer, got {stride!r}")
    if input_data.ndim < 2:
        raise ValueError(
            f"input_data must have at least 2 dimensions, got {input_data.ndim}")

    # Only the two trailing axes are windowed; leading axes pass through.
    *leading, x, y = input_data.shape
    out_x, out_y = (x - kernel_x) // stride + 1, (y - kernel_y) // stride + 1
    shape = (*leading, out_x, out_y, kernel_x, kernel_y)
    # Moving to the next window jumps `stride` rows/columns; moving inside a
    # window uses the input's own row/column strides.
    strides = (*input_data.strides[:-2], input_data.strides[-2] * stride,
               input_data.strides[-1] * stride, *input_data.strides[-2:])
    output_data = as_strided(input_data, shape, strides=strides)
    return output_data

In [5]:
# Random integer test tensor laid out as (batches, channels, height, width).
x = np.random.randint(low=0, high=10, size=(2, 3, 4, 4))

In [6]:
# Cut x into non-overlapping 2x2 windows (stride equal to the kernel size).
y = split_by_strides(x, kernel_x=2, kernel_y=2, stride=2)

In [7]:
# The windowed view is 6-D: (batches, channels, out_x, out_y, kernel_x, kernel_y).
y.shape

(2, 3, 2, 2, 2, 2)

In [8]:
# Strides are reported in bytes, so the numbers depend on the element size of
# x's dtype; the recorded output has a last stride of 4, i.e. 4-byte integers.
y.strides

(192, 64, 32, 8, 16, 4)

In [17]:
# Max-pool: reduce each window over its two kernel axes (the last two axes).
# NOTE: this rebinds y, so re-running the cell reduces the pooled array again.
y = y.max(axis=(-2, -1))

In [18]:
# Pooled result: one maximum per window, shape (batches, channels, out_x, out_y).
y

array([[[[9, 9],
         [7, 9]],

        [[9, 5],
         [7, 5]],

        [[7, 4],
         [9, 8]]],


       [[[9, 8],
         [9, 9]],

        [[9, 9],
         [7, 6]],

        [[6, 9],
         [9, 8]]]])

In [34]:
# Upsample by 2 along both spatial axes: each pooled value becomes a 2x2 patch.
np.repeat(np.repeat(y, 2, axis=-1), 2, axis=-2)

array([[[[9, 9, 9, 9],
         [9, 9, 9, 9],
         [7, 7, 9, 9],
         [7, 7, 9, 9]],

        [[9, 9, 5, 5],
         [9, 9, 5, 5],
         [7, 7, 5, 5],
         [7, 7, 5, 5]],

        [[7, 7, 4, 4],
         [7, 7, 4, 4],
         [9, 9, 8, 8],
         [9, 9, 8, 8]]],


       [[[9, 9, 8, 8],
         [9, 9, 8, 8],
         [9, 9, 9, 9],
         [9, 9, 9, 9]],

        [[9, 9, 9, 9],
         [9, 9, 9, 9],
         [7, 7, 6, 6],
         [7, 7, 6, 6]],

        [[6, 6, 9, 9],
         [6, 6, 9, 9],
         [9, 9, 8, 8],
         [9, 9, 8, 8]]]])

In [23]:
# Upsample each pooled value to a 2x2 patch with a zero-stride view.
# The strides are taken from y itself instead of hard-coded byte counts, which
# only match one particular dtype/layout and otherwise read the wrong elements.
n_batch, n_chan, n_rows, n_cols = y.shape
s_batch, s_chan, s_row, s_col = y.strides
# Axes 3 and 5 have stride 0, so they repeat the same element twice; the
# reshape then merges (row, repeat) and (col, repeat) into 4x4 planes (copying).
as_strided(y, (n_batch, n_chan, n_rows, 2, n_cols, 2),
           (s_batch, s_chan, s_row, 0, s_col, 0)).reshape(
    n_batch, n_chan, n_rows * 2, n_cols * 2)

array([[[[9, 9, 9, 9],
         [9, 9, 9, 9],
         [7, 7, 9, 9],
         [7, 7, 9, 9]],

        [[9, 9, 5, 5],
         [9, 9, 5, 5],
         [7, 7, 5, 5],
         [7, 7, 5, 5]],

        [[7, 7, 4, 4],
         [7, 7, 4, 4],
         [9, 9, 8, 8],
         [9, 9, 8, 8]]],


       [[[9, 9, 8, 8],
         [9, 9, 8, 8],
         [9, 9, 9, 9],
         [9, 9, 9, 9]],

        [[9, 9, 9, 9],
         [9, 9, 9, 9],
         [7, 7, 6, 6],
         [7, 7, 6, 6]],

        [[6, 6, 9, 9],
         [6, 6, 9, 9],
         [9, 9, 8, 8],
         [9, 9, 8, 8]]]])