In [21]:
import numpy as np
import sys
sys.path.append("..")
%load_ext autoreload

In [95]:
# Fix the global NumPy seed so the random fixtures below are reproducible.
# All three draws consume the same global random stream, so their order
# determines the values each one receives.
np.random.seed(1)
# Input batch. The reference cell further down transposes it with (0, 2, 3, 1)
# before calling pool_forward (which unpacks channels last), so the layout here
# is (batch, channels, height, width).
x = np.random.randn(16, 3, 32, 32)
# NOTE(review): W and b are not used by any cell visible here — presumably
# left over from a convolution test; confirm before removing.
W = np.random.randn(8, 3, 2, 2)
b = np.random.randn(8, 1)

## Layer under test (`Pool2D`)

In [105]:
%autoreload

from src.layers.pooling import Pool2D

# Pooling layer under test, built with a (3, 3) size argument, stride 3 and
# mode 'avg' (presumably a 3x3 average-pooling window — Pool2D's source is not
# shown here).
# NOTE(review): the variable is named `conv` although it holds a Pool2D.
conv = Pool2D((3, 3), stride=3, mode='avg')
# presumably prepares the layer for this input shape — TODO confirm in Pool2D
conv.init_params(x.shape)
# Last expression of the cell, so its return value is what the notebook displays.
conv.forward(x)

array([[[[-1.94034949e-02, -1.21751450e-01,  1.91568016e-01, ...,
           2.67011683e-01,  5.11125972e-01, -1.31721787e-02],
         [-3.00269200e-01,  2.39818798e-02,  4.42508846e-01, ...,
          -4.21098032e-03,  4.41788673e-01,  1.59768802e-01],
         [ 8.24997430e-01, -4.14880407e-01,  3.46074606e-01, ...,
           2.53384165e-01,  7.01967441e-02,  2.61746794e-01],
         ...,
         [-4.86520747e-01, -6.43915407e-03, -4.69610096e-01, ...,
          -9.04644801e-02,  2.49868664e-01,  4.27480374e-01],
         [-2.75072826e-01,  2.17742311e-01,  4.38931985e-01, ...,
           3.69127904e-01,  3.31069363e-01, -2.37631915e-01],
         [-9.96343715e-02, -1.65785669e-01,  1.63011988e-01, ...,
          -1.63812937e-01,  2.04527635e-01, -2.48942967e-01]],

        [[ 4.86001743e-01,  2.29998833e-01, -3.25817801e-02, ...,
           1.96565070e-01,  1.13494411e-01, -3.44495232e-01],
         [ 1.06508119e-01,  5.29191188e-02,  1.68261855e-01, ...,
           3.46717681e

## Original (loop-based reference implementation)

In [74]:
def pool_forward(A_prev, hparameters, mode = "max"):
    """
    Forward pass of a 2-D pooling layer (reference implementation).

    Arguments:
    A_prev -- NumPy array of shape (m, n_H_prev, n_W_prev, n_C_prev), channels last
    hparameters -- dict with keys "f" (side of the square pooling window) and "stride"
    mode -- "max" or "average"

    Returns:
    A -- pooled output, float64 array of shape (m, n_H, n_W, n_C_prev)
    cache -- tuple (A_prev, hparameters), the input expected by pool_backward

    Raises:
    ValueError -- if mode is neither "max" nor "average"
    """
    # Fail fast on an unknown mode. Previously neither branch matched and the
    # function silently returned an all-zero array (easy to hit with e.g. "avg").
    if mode == "max":
        reduce_window = np.max
    elif mode == "average":
        reduce_window = np.mean
    else:
        raise ValueError('mode must be "max" or "average", got %r' % (mode,))

    (m, n_H_prev, n_W_prev, n_C_prev) = A_prev.shape
    f = hparameters["f"]
    stride = hparameters["stride"]

    # Number of window positions that fit along each spatial axis.
    n_H = int(1 + (n_H_prev - f) / stride)
    n_W = int(1 + (n_W_prev - f) / stride)
    n_C = n_C_prev

    A = np.zeros((m, n_H, n_W, n_C))
    for h in range(n_H):
        vert_start = h * stride
        vert_end = vert_start + f
        for w in range(n_W):
            horiz_start = w * stride
            horiz_end = horiz_start + f
            # Reduce over the two spatial axes only, so every example and
            # channel is pooled independently in a single NumPy call.
            window = A_prev[:, vert_start:vert_end, horiz_start:horiz_end, :]
            A[:, h, w, :] = reduce_window(window, axis=(1, 2))

    cache = (A_prev, hparameters)
    assert(A.shape == (m, n_H, n_W, n_C))

    return A, cache

def create_mask_from_window(x):
    """
    Mark where the maximum of a window is located.

    Arguments:
    x -- window of values (a NumPy array in the visible caller)

    Returns:
    The element-wise comparison of x with its maximum; every entry tied for
    the maximum compares equal, so ties yield several True entries.
    """
    peak = np.max(x)
    return x == peak

def distribute_value(dz, shape):
    """
    Spread a value evenly over a 2-D array.

    Arguments:
    dz -- value to distribute
    shape -- pair (rows, cols) giving the shape of the result

    Returns:
    Array of the given shape in which every entry equals dz / (rows * cols).
    """
    rows, cols = shape
    share = dz / (rows * cols)
    return np.full(shape, share)

def pool_backward(dA, cache, mode = "max"):
    """
    Backward pass of a 2-D pooling layer (reference implementation).

    Arguments:
    dA -- gradient with respect to the pooling output, NumPy array of shape (m, n_H, n_W, n_C)
    cache -- tuple (A_prev, hparameters) as returned by pool_forward; hparameters
             is a dict with keys "f" and "stride"
    mode -- "max" or "average"

    Returns:
    dA_prev -- gradient with respect to the pooling input, float64 array with the shape of A_prev

    Raises:
    ValueError -- if mode is neither "max" nor "average", or if the batch or
                  channel sizes of dA and A_prev differ
    """
    # Fail fast on an unknown mode. Previously neither branch matched and the
    # function silently returned an all-zero gradient.
    if mode not in ("max", "average"):
        raise ValueError('mode must be "max" or "average", got %r' % (mode,))

    (A_prev, hparameters) = cache
    stride = hparameters['stride']
    f = hparameters['f']

    # Unpacking also enforces that both arrays are 4-D.
    m_prev, _, _, n_C_prev = A_prev.shape
    m, n_H, n_W, n_C = dA.shape
    if m != m_prev or n_C != n_C_prev:
        raise ValueError(
            "dA and A_prev must agree on batch and channel sizes, got %r and %r"
            % (dA.shape, A_prev.shape)
        )

    dA_prev = np.zeros(A_prev.shape)

    for h in range(n_H):
        vert_start = h * stride
        vert_end = vert_start + f
        for w in range(n_W):
            horiz_start = w * stride
            horiz_end = horiz_start + f

            # Upstream gradient for this output position, kept as (m, 1, 1, n_C)
            # so it broadcasts over the spatial extent of the window.
            grad = dA[:, h:h + 1, w:w + 1, :]

            if mode == "max":
                window = A_prev[:, vert_start:vert_end, horiz_start:horiz_end, :]
                # Per example and channel, route the gradient to the entries
                # equal to the window maximum (all tied entries receive it).
                mask = window == window.max(axis=(1, 2), keepdims=True)
                dA_prev[:, vert_start:vert_end, horiz_start:horiz_end, :] += mask * grad
            else:
                # Average pooling: every entry of the f x f window gets an equal share.
                dA_prev[:, vert_start:vert_end, horiz_start:horiz_end, :] += grad / (f * f)

    assert(dA_prev.shape == A_prev.shape)

    return dA_prev

In [104]:
# Reference result for comparison with the layer output above: pool_forward
# works on channels-last arrays, so move the channel axis to the end for the
# call and move it back to position 1 on the pooled result.
np.transpose(
    pool_forward(np.transpose(x, (0, 2, 3, 1)), {'stride': 3, 'f': 3}, mode='average')[0],
    (0, 3, 1, 2),
)

array([[[[-1.94034949e-02, -1.21751450e-01,  1.91568016e-01, ...,
           2.67011683e-01,  5.11125972e-01, -1.31721787e-02],
         [-3.00269200e-01,  2.39818798e-02,  4.42508846e-01, ...,
          -4.21098032e-03,  4.41788673e-01,  1.59768802e-01],
         [ 8.24997430e-01, -4.14880407e-01,  3.46074606e-01, ...,
           2.53384165e-01,  7.01967441e-02,  2.61746794e-01],
         ...,
         [-4.86520747e-01, -6.43915407e-03, -4.69610096e-01, ...,
          -9.04644801e-02,  2.49868664e-01,  4.27480374e-01],
         [-2.75072826e-01,  2.17742311e-01,  4.38931985e-01, ...,
           3.69127904e-01,  3.31069363e-01, -2.37631915e-01],
         [-9.96343715e-02, -1.65785669e-01,  1.63011988e-01, ...,
          -1.63812937e-01,  2.04527635e-01, -2.48942967e-01]],

        [[ 4.86001743e-01,  2.29998833e-01, -3.25817801e-02, ...,
           1.96565070e-01,  1.13494411e-01, -3.44495232e-01],
         [ 1.06508119e-01,  5.29191188e-02,  1.68261855e-01, ...,
           3.46717681e