In [3]:
import numpy as np
from numpy.lib.stride_tricks import as_strided

In [7]:
def split_by_strides(input_data: np.ndarray, kernel_x: int, kernel_y: int, stride: int) -> np.ndarray:
    """
    Split a 4-D tensor into kernel-sized windows taken every `stride` elements.

    No data is copied: the result is a strided view built with
    ``numpy.lib.stride_tricks.as_strided``. It shares memory with
    ``input_data``, and neighbouring windows alias the same elements whenever
    ``stride`` is smaller than the kernel size, so treat the result as
    read-only unless modifying ``input_data`` is intended.

    :param input_data: tensor to be convolved, shape (batches, channels, height, width)
    :param kernel_x: kernel height (window size along axis 2)
    :param kernel_y: kernel width (window size along axis 3)
    :param stride: step between neighbouring windows, shared by both spatial axes
    :return: output_data: view of shape
             (batches, channels, out_x, out_y, kernel_x, kernel_y), where
             out_x = (height - kernel_x) // stride + 1 and
             out_y = (width - kernel_y) // stride + 1
    :raises ValueError: if input_data is not 4-D, if stride or a kernel size is
             smaller than 1, or if the kernel is larger than the input plane

    Example (a single 4x4 plane, kernel 2x2, stride 2):

             [[1, 2, 3, 4],    2, 2, 2       [[[[1, 2],
              [5, 6, 7, 8],             =>      [5, 6]],
              [9, 10, 11, 12],                 [[3, 4],
              [13, 14, 15, 16]]                 [7, 8]]],
                                              [[[9, 10],
                                                [13, 14]],
                                               [[11, 12],
                                                [15, 16]]]]
    """
    if input_data.ndim != 4:
        raise ValueError(
            f"input_data must be 4-D (batches, channels, height, width), got {input_data.ndim}-D"
        )
    # as_strided performs no bounds checking, so every argument that feeds the
    # shape/strides computation is validated first; a zero or negative stride
    # would otherwise divide by zero or yield a view outside the buffer.
    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}")
    if kernel_x < 1 or kernel_y < 1:
        raise ValueError(f"kernel size must be >= 1, got ({kernel_x}, {kernel_y})")

    batches, channels, x, y = input_data.shape
    if kernel_x > x or kernel_y > y:
        raise ValueError(
            f"kernel ({kernel_x}, {kernel_y}) does not fit into input plane ({x}, {y})"
        )

    # Number of window positions along each spatial axis.
    out_x, out_y = (x - kernel_x) // stride + 1, (y - kernel_y) // stride + 1
    shape = (batches, channels, out_x, out_y, kernel_x, kernel_y)

    batch_step, channel_step, row_step, col_step = input_data.strides
    # Moving to the next window jumps `stride` rows/columns; moving inside a
    # window advances by one row/column of the original array.
    strides = (batch_step, channel_step,
               row_step * stride, col_step * stride,
               row_step, col_step)
    output_data = as_strided(input_data, shape, strides=strides)
    return output_data

In [8]:
# Random integer test tensor with values in [0, 10): 2 batches x 3 channels x 4 x 4.
x = np.random.randint(low=0, high=10, size=(2, 3, 4, 4))

In [9]:
# Show the randomly generated sample tensor.
x

array([[[[0, 1, 4, 1],
         [8, 4, 0, 6],
         [7, 7, 3, 5],
         [0, 2, 1, 6]],

        [[1, 1, 6, 6],
         [0, 7, 7, 1],
         [2, 5, 8, 6],
         [9, 8, 3, 5]],

        [[7, 8, 9, 5],
         [7, 1, 6, 4],
         [3, 2, 9, 8],
         [4, 8, 1, 1]]],


       [[[9, 4, 0, 0],
         [9, 5, 9, 8],
         [5, 4, 1, 9],
         [2, 5, 6, 5]],

        [[8, 8, 0, 2],
         [0, 1, 6, 0],
         [5, 5, 1, 2],
         [9, 5, 6, 5]],

        [[0, 1, 0, 5],
         [6, 1, 8, 0],
         [0, 6, 0, 0],
         [2, 9, 2, 6]]]])

In [10]:
# Cut every 4x4 plane of x into 2x2 windows, stepping 2 elements at a time.
split_by_strides(x, kernel_x=2, kernel_y=2, stride=2)

array([[[[[[0, 1],
           [8, 4]],

          [[4, 1],
           [0, 6]]],


         [[[7, 7],
           [0, 2]],

          [[3, 5],
           [1, 6]]]],



        [[[[1, 1],
           [0, 7]],

          [[6, 6],
           [7, 1]]],


         [[[2, 5],
           [9, 8]],

          [[8, 6],
           [3, 5]]]],



        [[[[7, 8],
           [7, 1]],

          [[9, 5],
           [6, 4]]],


         [[[3, 2],
           [4, 8]],

          [[9, 8],
           [1, 1]]]]],




       [[[[[9, 4],
           [9, 5]],

          [[0, 0],
           [9, 8]]],


         [[[5, 4],
           [2, 5]],

          [[1, 9],
           [6, 5]]]],



        [[[[8, 8],
           [0, 1]],

          [[0, 2],
           [6, 0]]],


         [[[5, 5],
           [9, 5]],

          [[1, 2],
           [6, 5]]]],



        [[[[0, 1],
           [6, 1]],

          [[0, 5],
           [8, 0]]],


         [[[0, 6],
           [2, 9]],

          [[0, 0],
           [2, 6]]]]]])