In [1]:
import numpy as np

In [20]:
image = np.arange(28*28*64).reshape(64,1,28,28)

In [21]:
class naive_MaxPool():
    def __init__(self, F, stride):
        self.F = F
        self.S = stride
        self.cache = None

    def _forward(self, X):
        # X: (N, Cin, H, W): maxpool along 3rd, 4th dim
        (N,Cin,H,W) = X.shape
        F = self.F
        W_ = int(float(W)/F)
        H_ = int(float(H)/F)
        Y = np.zeros((N,Cin,W_,H_))
        M = np.zeros(X.shape) # mask
        for n in range(N):
            for cin in range(Cin):
                for w_ in range(W_):
                    for h_ in range(H_):
                        Y[n,cin,w_,h_] = np.max(X[n,cin,F*w_:F*(w_+1),F*h_:F*(h_+1)])
                        i,j = np.unravel_index(X[n,cin,F*w_:F*(w_+1),F*h_:F*(h_+1)].argmax(), (F,F))
                        M[n,cin,F*w_+i,F*h_+j] = 1
        self.cache = M
        return Y

    def _backward(self, dout):
        M = self.cache
        (N,Cin,H,W) = M.shape
        dout = np.array(dout)
        #print("dout.shape: %s, M.shape: %s" % (dout.shape, M.shape))
        dX = np.zeros(M.shape)
        for n in range(N):
            for c in range(Cin):
                #print("(n,c): (%s,%s)" % (n,c))
                dX[n,c,:,:] = dout[n,c,:,:].repeat(2, axis=0).repeat(2, axis=1)
        return dX*M

In [22]:
naive_maxpool = naive_MaxPool(2,2)
%timeit naive_maxpool._forward(image)

133 ms ± 620 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [23]:
out = naive_maxpool._forward(image)
%timeit naive_maxpool._backward(out)

556 µs ± 10.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [10]:
class im2col_MaxPool:
    def __init__(self, pool_h, pool_w, stride=1, pad=0):
        self.pool_h = pool_h
        self.pool_w = pool_w
        self.stride = stride
        self.pad = pad
        
        self.x = None
        self.arg_max = None

    def _forward(self, x):
        N, C, H, W = x.shape
        out_h = int(1 + (H - self.pool_h) / self.stride)
        out_w = int(1 + (W - self.pool_w) / self.stride)
        
        col = im2col(x, self.pool_h, self.pool_w, self.stride)
        col = col.reshape(-1, self.pool_h*self.pool_w)
        
        arg_max = np.argmax(col, axis=1)
        out = np.max(col, axis=1)

        out = out.reshape(N, out_h, out_w, C).transpose(0, 3, 1, 2)

        self.x = x
        self.arg_max = arg_max

        return out

    def _backward(self, dout):
        dout = dout.transpose(0, 2, 3, 1)
        
        pool_size = self.pool_h * self.pool_w
        dmax = np.zeros((dout.size, pool_size))
        dmax[np.arange(self.arg_max.size), self.arg_max.flatten()] = dout.flatten()
        dmax = dmax.reshape(dout.shape + (pool_size,)) 
        
        dcol = dmax.reshape(dmax.shape[0] * dmax.shape[1] * dmax.shape[2], -1)
        dx = col2im(dcol, self.x.shape, self.pool_h, self.pool_w, self.stride, self.pad)
        
        return dx
    
def im2col(input_data, filter_h, filter_w, stride=1):
    N, C, H, W = input_data.shape  
    out_h = (H - filter_h)//stride + 1  
    out_w = (W - filter_w)//stride + 1  

    col = np.zeros((N, C, filter_h, filter_w, out_h, out_w))
    
    for y in range(filter_h):
        y_max = y + stride*out_h
        for x in range(filter_w):
            x_max = x + stride*out_w
            col[:, :, y, x, :, :] = input_data[:, :, y:y_max:stride, x:x_max:stride]
    col = col.transpose(0, 4, 5, 1, 2, 3).reshape(N*out_h*out_w, -1)
    return col

def col2im(col, input_shape, filter_h, filter_w, stride=1, pad=0):
    N, C, H, W = input_shape
    out_h = (H + 2*pad - filter_h)//stride + 1
    out_w = (W + 2*pad - filter_w)//stride + 1
    col = col.reshape(N, out_h, out_w, C, filter_h, filter_w).transpose(0, 3, 4, 5, 1, 2)

    img = np.zeros((N, C, H + 2*pad + stride - 1, W + 2*pad + stride - 1))
    for y in range(filter_h):
        y_max = y + stride*out_h
        for x in range(filter_w):
            x_max = x + stride*out_w
            img[:, :, y:y_max:stride, x:x_max:stride] += col[:, :, y, x, :, :]

    return img[:, :, pad:H + pad, pad:W + pad]

In [24]:
im2col_maxpool = im2col_MaxPool(2,2,2)
%timeit im2col_maxpool._forward(image)

1.21 ms ± 16 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [25]:
%timeit im2col_maxpool._backward(out)

282 µs ± 5.81 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [28]:
class reshape_MaxPool:
    def __init__(self, pool_h, pool_w, strides):
        self.pool_h = pool_h
        self.pool_w = pool_w
        self.strides = strides
        self.cache = None
    
    
    def _forward(self, x):
        N, C, H, W = x.shape
        assert self.pool_h == self.pool_w == self.strides, 'Invalid pool params'
        assert H % self.pool_h == 0
        assert W % self.pool_w == 0
        x_reshaped = x.reshape(N, C, H // self.pool_h, self.pool_h,
                               W // self.pool_w, self.pool_w)
        out = x_reshaped.max(axis=3).max(axis=4)

        self.cache = (x, x_reshaped, out)
        return out

    def _backward(self, dout): 
        dx_reshaped = np.zeros_like(self.cache[1])
        out_newaxis = self.cache[2][:, :, :, np.newaxis, :, np.newaxis]
        mask = (self.cache[1] == out_newaxis)
        dout_newaxis = dout[:, :, :, np.newaxis, :, np.newaxis]
        dout_broadcast, _ = np.broadcast_arrays(dout_newaxis, dx_reshaped)
        dx_reshaped = dout_broadcast * mask
        #The line blow is to ensure everyone get correct result
        #dx_reshaped /= np.sum(mask, axis=(3, 5), keepdims=True)
        dx = dx_reshaped.reshape(self.cache[0].shape)

        return dx

In [29]:
reshape_maxpool = reshape_MaxPool(2,2,2)
%timeit reshape_maxpool._forward(image)

109 µs ± 560 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [30]:
%timeit reshape_maxpool._backward(out)

1.49 ms ± 12.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
