# 使用 im2col 实现卷积

2019 年 5 月 16 日

参考资料: [numpy实现卷积转换成矩阵乘法算法](https://zhuanlan.zhihu.com/p/63974249?utm_source=wechat_session&utm_medium=social&utm_oi=32625469161472)

In [3]:
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

%matplotlib inline

In [46]:
def im2col(img, ksize, stride=1):
    """
    Parameters:
        img: image of shape (h, w, c)
        ksize: kernel size (kh, kw)
        stride: ...
    Return:
        col: convert image to its column version
    """
    img = img.transpose(2, 0, 1)
    print(repr(img))
    c, h, w = img.shape
    kh, kw = ksize
    out_h, out_w = int((h - kh) / stride) + 1, int((w - kw) / stride) + 1
    col_h, col_w = out_h * out_w, kh * kw * c
    col = np.empty((col_h, col_w))
    for idx in range(col_h):
        i, j = int(idx / out_w), idx % out_w
        col[idx] = img[:, i : i + kh, j : j + kw].reshape(col_w)
    return col

In [45]:
img = np.arange(18).reshape(1, 3, 3, 2)
print(repr(img.squeeze()))
ksize = (2, 2)
im2col(img, ksize, stride=2)

array([[[ 0,  1],
        [ 2,  3],
        [ 4,  5]],

       [[ 6,  7],
        [ 8,  9],
        [10, 11]],

       [[12, 13],
        [14, 15],
        [16, 17]]])
array([[[ 0,  2,  4],
        [ 6,  8, 10],
        [12, 14, 16]],

       [[ 1,  3,  5],
        [ 7,  9, 11],
        [13, 15, 17]]])


array([[0., 2., 6., 8., 1., 3., 7., 9.]])

来自: [numpy实现卷积转换成矩阵乘法算法](https://zhuanlan.zhihu.com/p/63974249?utm_source=wechat_session&utm_medium=social&utm_oi=32625469161472), 感觉有点问题. 

In [47]:
def im2col(img, ksize, stride=1):
    N, H, W, C = img.shape
    out_h = (H - ksize[0]) // stride + 1
    out_w = (W - ksize[1]) // stride + 1
    col = np.empty((N*out_h*out_w, ksize[0]*ksize[1]*C))
    outsize = out_w*out_h
    for y in range(out_h):
        y_min = y * stride
        y_max = y_min + ksize[0]
        y_start = y * out_w
        for x in range(out_w):
            x_min = x * stride
            x_max = x_min + ksize[1]
            col[y_start+x::outsize, :] = img[:, y_min:y_max, x_min:x_max, :].reshape(N, -1)
    return col

In [48]:
img = np.arange(18).reshape(1, 3, 3, 2)
print(repr(img.squeeze()))
ksize = (2, 2)
im2col(img, ksize, stride=2)

array([[[ 0,  1],
        [ 2,  3],
        [ 4,  5]],

       [[ 6,  7],
        [ 8,  9],
        [10, 11]],

       [[12, 13],
        [14, 15],
        [16, 17]]])


array([[0., 1., 2., 3., 6., 7., 8., 9.]])

在上面 [numpy实现卷积转换成矩阵乘法算法](https://zhuanlan.zhihu.com/p/63974249?utm_source=wechat_session&utm_medium=social&utm_oi=32625469161472) 这篇博文中, 作者还给出了另一种思路: [卷积算法另一种高效实现，as_strided详解](https://zhuanlan.zhihu.com/p/64933417), 利用 `np.lib.stride_tricks.as_strided` 这个函数来实现 im2col, 
由于这个函数有点复杂, 这篇博文作者对这个函数进行了详细的介绍, 为了确定我理解了该函数的用法, 自行实现作者给出的一个例子.

In [52]:
X = np.arange(10, 26).reshape(4, 4)
print(X)
X.dtype

[[10 11 12 13]
 [14 15 16 17]
 [18 19 20 21]
 [22 23 24 25]]


dtype('int64')

In [54]:
Y = np.lib.stride_tricks.as_strided(X, shape=(2, 2, 3, 3), strides=(32, 8, 32, 8))
Y

array([[[[10, 11, 12],
         [14, 15, 16],
         [18, 19, 20]],

        [[11, 12, 13],
         [15, 16, 17],
         [19, 20, 21]]],


       [[[14, 15, 16],
         [18, 19, 20],
         [22, 23, 24]],

        [[15, 16, 17],
         [19, 20, 21],
         [23, 24, 25]]]])

上面例子中输入数据是 2D 的, 下面看输入数据是 3D 的

In [55]:
X = np.arange(18, dtype=np.int32).reshape(1, 2, 3, 3)
X

array([[[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8]],

       [[ 9, 10, 11],
        [12, 13, 14],
        [15, 16, 17]]], dtype=int32)

In [68]:
np.lib.stride_tricks.as_strided(X, shape=(1, 2, 2, 2, 2, 2), strides=(72, 36, 12, 4, 12, 4))

array([[[[[[ 0,  1],
           [ 3,  4]],

          [[ 1,  2],
           [ 4,  5]]],


         [[[ 3,  4],
           [ 6,  7]],

          [[ 4,  5],
           [ 7,  8]]]],



        [[[[ 9, 10],
           [12, 13]],

          [[10, 11],
           [13, 14]]],


         [[[12, 13],
           [15, 16]],

          [[13, 14],
           [16, 17]]]]]], dtype=int32)

如果 X 的维度是 (H, W, C), 情况又如何呢?

In [92]:
def split_by_strides(X, kh, kw, s):
    N, H, W, C = X.shape
    oh = (H - kh) // s + 1
    ow = (W - kw) // s + 1
    shape = (N, oh, ow, kh, kw, C)
    strides = (X.strides[0], X.strides[1]*s, X.strides[2]*s, *X.strides[1:])
    A = np.lib.stride_tricks.as_strided(X, shape=shape, strides=strides)
    return A

# X = np.arange(18, dtype=np.int32).reshape(1, 2, 3, 3).transpose(0, 2, 3, 1)
X = np.array([[[[ 0,  9],
         [ 1, 10],
         [ 2, 11]],

        [[ 3, 12],
         [ 4, 13],
         [ 5, 14]],

        [[ 6, 15],
         [ 7, 16],
         [ 8, 17]]]], dtype=np.int32)
X

(72, 24, 8, 4)

In [98]:
# np.lib.stride_tricks.as_strided(X, shape=(1, 2, 2, 2, 2, 2), strides=(72, 12, 4, 12, 4, 36))
split_by_strides(X, 2, 2, 1).strides
np.lib.stride_tricks.as_strided(X, shape=(1, 2, 2, 2, 2, 2), strides=(72, 24, 8, 24, 8, 4))

(72, 24, 8, 24, 8, 4)

In [75]:
print(X.shape)
print(X.strides)
print(X.transpose(0, 3, 1, 2).shape)
print(X.transpose(0, 3, 1, 2).strides)

(1, 3, 3, 2)
(72, 12, 4, 36)
(1, 2, 3, 3)
(72, 36, 12, 4)


In [76]:
Y = np.arange(12).reshape(1, 2, 2, 3)
Y.strides

(96, 48, 24, 8)

In [80]:
a = np.arange(18, dtype=np.int32).reshape(1, 3, 3, 2)
a.strides
print(a.transpose(0, 2, 3, 1).strides)

(72, 8, 4, 24)


In [87]:
X = np.arange(18, dtype=np.int32).reshape(2, 3, 3)
X.strides

nX = X.transpose(1, 2, 0)
nX.strides

nX[0, 0, 0] = 100

X[0, 0, 0]

100

In [101]:
X = np.arange(4).reshape(1, 1, 2, 2)
A = split_by_strides(X, 1, 1, 1)
print(A.shape)
A[0, 0, 0, 0, 0, 0] = 100
X

(1, 1, 2, 1, 1, 2)


array([[[[100,   1],
         [  2,   3]]]])