In [1]:
import numpy as np
import tensorflow as tf

# Notation

Max pooling reduces the width and height of given volumes.


| Symbols                   | Meaning               | Size                                              |
|:--------------------------|:----------------------|:--------------------------------------------------|
| $\boldsymbol{X}$          | 2D Input Matrix       | $(H_{in}$, $W_{in})$                              |
| $\boldsymbol{\mathsf{X}}$ | 3D Input Tensor       | $(H_{in}$, $W_{in}$, $C_{in})$                    |
| $\boldsymbol{\mathsf{X}}$ | 4D Input Tensor       | $(N_{batch}$, $H_{in}$, $W_{in}$, $C_{in})$       |
| $\boldsymbol{Y}$          | 2D Output Matrix      | $(H_{out}$, $W_{out})$                            |
| $\boldsymbol{\mathsf{Y}}$ | 3D Output Tensor      | $(H_{out}$, $W_{out}$, $C_{out})$                 |
| $\boldsymbol{\mathsf{Y}}$ | 4D Output Tensor      | $(N_{batch}$, $H_{out}$, $W_{out}$, $C_{out})$    |

As opposed to convolution, max pooling doesn't have any learnable parameters,
so there's no $\boldsymbol{\mathsf{W}}$ or $\boldsymbol{b}$ in the list of notation above.

# 08-2: Max Pooling Backward Pass

Try calculation with NumPy and TensorFlow for the followings

#### 1. Max Pooling with stride
$(4 \times 4) \rightarrow (2 \times 2)$ where $S=2$

#### 2. Multiple Channels
$(4 \times 4 \times 3) \rightarrow (2 \times 2 \times 3)$ where $S=2$

#### 3. Mini-batch
$(4 \times 4 \times 4 \times 3) \rightarrow (4 \times 2 \times 2 \times 3)$ where $S=2$

In [2]:
def float_sequence(size):
    return np.arange(size, dtype=np.float32)

## 1. Max Pooling with stride backward
$(4 \times 4) \rightarrow (2 \times 2)$ where $S=2$

TensorFlow: [tf.nn.max_pool](https://www.tensorflow.org/api_docs/python/tf/nn/max_pool)

For the size of output width and height, the same fomula as convolution operation hold.

$$
H_{out} = \frac{H_{in} + 2P - H_{filter}}{S} + 1
$$

$$
W_{out} = \frac{W_{in} + 2P - F_{filter}}{S} + 1
$$

However, padding $P$ is not used in max pooling, and often filter size $H_{filter}$ and $F_{filter}$ is the same as stride $S$.
As a result, the size can be calculated as follwing.

$$
H_{out} = \frac{H_{in}}{S}
$$

$$
W_{out} = \frac{W_{in}}{S}
$$

In [26]:
X = float_sequence(4*4).reshape(4,4)
S = 2

H_out = 4 // S
W_out = 4 // S

print("=== X ===")
print(X)

dY = np.ones((H_out, W_out))

dX = np.zeros((4, 4))
for h in range(W_out):
    for w in range(W_out):
        h_start = h * S
        h_end   = h_start + S
        w_start = w * S
        w_end   = w_start + S
        
        X_slice = X[h_start:h_end, w_start:w_end]
        X_slice_mask = X_slice == np.max(X_slice)
        current_dY = dY[h,w]
        dX[h_start:h_end, w_start:w_end] = X_slice_mask * current_dY
#         print(current_dY)
#         print(X_slice_mask)

print("=== dX ===")     
print(dX)

with tf.Session() as sess:
    tf_X = tf.constant(X.reshape(1, 4, 4, 1))
    tf_Y = tf.nn.max_pool(
        tf_X,
        ksize=[1, S, S, 1],
        strides=[1, S, S, 1],
        padding='VALID'
    )
    tf_L = tf.reduce_sum(tf_Y)
    tf_grad = tf.gradients(tf_L, [tf_X])
    tf_dX, = sess.run(tf_grad)
    print("=== tf_dX ===")     
    print(tf_dX[0, :, :, 0])

print("=== Matched? ===")    
print(np.all(dX == tf_dX[0, :, :, 0]))

=== X ===
[[  0.   1.   2.   3.]
 [  4.   5.   6.   7.]
 [  8.   9.  10.  11.]
 [ 12.  13.  14.  15.]]
=== dX ===
[[ 0.  0.  0.  0.]
 [ 0.  1.  0.  1.]
 [ 0.  0.  0.  0.]
 [ 0.  1.  0.  1.]]
=== tf_dX ===
[[ 0.  0.  0.  0.]
 [ 0.  1.  0.  1.]
 [ 0.  0.  0.  0.]
 [ 0.  1.  0.  1.]]
=== Matched? ===
True


## 2. Multiple Channels backward
$(4 \times 4 \times 3) \rightarrow (2 \times 2 \times 3)$ where $S=2$

In [44]:
X = float_sequence(4*4*3).reshape(4,4,3)

S = 2

H_out = 4 // S
W_out = 4 // S

# print("=== X ===")
# print(X.transpose(2, 0, 1))

dY = np.ones((H_out, W_out, 3))
dX = np.zeros((4, 4, 3))

for h in range(W_out):
    for w in range(W_out):
        h_start = h * S
        h_end   = h_start + S
        w_start = w * S
        w_end   = w_start + S
        
        X_slice = X[h_start:h_end, w_start:w_end, :]
        X_slice_mask = X_slice == np.max(X_slice, axis=(0,1))
        current_dY = dY[h,w,:]
        dX_slice = X_slice_mask * current_dY
        dX[h_start:h_end, w_start:w_end, :] = dX_slice


print("=== dX ===")     
print(dX.transpose(2, 0, 1))

with tf.Session() as sess:
    tf_X = tf.constant(X.reshape(1, 4, 4, 3))
    tf_Y = tf.nn.max_pool(
        tf_X,
        ksize=[1, S, S, 1],
        strides=[1, S, S, 1],
        padding='VALID'
    )
    tf_L = tf.reduce_sum(tf_Y)
    tf_grad = tf.gradients(tf_L, [tf_X])
    tf_dX, = sess.run(tf_grad)
    print("=== tf_dX ===")     
    print(tf_dX[0, :, :, :].transpose(2, 0, 1))

print("=== Matched? ===")    
print(np.all(dX == tf_dX[0, :, :, :]))

=== dX ===
[[[ 0.  0.  0.  0.]
  [ 0.  1.  0.  1.]
  [ 0.  0.  0.  0.]
  [ 0.  1.  0.  1.]]

 [[ 0.  0.  0.  0.]
  [ 0.  1.  0.  1.]
  [ 0.  0.  0.  0.]
  [ 0.  1.  0.  1.]]

 [[ 0.  0.  0.  0.]
  [ 0.  1.  0.  1.]
  [ 0.  0.  0.  0.]
  [ 0.  1.  0.  1.]]]
=== tf_dX ===
[[[ 0.  0.  0.  0.]
  [ 0.  1.  0.  1.]
  [ 0.  0.  0.  0.]
  [ 0.  1.  0.  1.]]

 [[ 0.  0.  0.  0.]
  [ 0.  1.  0.  1.]
  [ 0.  0.  0.  0.]
  [ 0.  1.  0.  1.]]

 [[ 0.  0.  0.  0.]
  [ 0.  1.  0.  1.]
  [ 0.  0.  0.  0.]
  [ 0.  1.  0.  1.]]]
=== Matched? ===
True


## 3. Mini-batch
$(4 \times 4 \times 4 \times 3) \rightarrow (4 \times 2 \times 2 \times 3)$ where $S=2$

In [54]:
X = float_sequence(4*4*4*3).reshape(4,4,4,3)

S = 2

H_out = 4 // S
W_out = 4 // S

# print("=== X (first) ===")
# print(X[0, :, :, :].transpose(2, 0, 1))


dY = np.ones((4, H_out, W_out, 3))
dX = np.zeros((4, 4, 4, 3))

for n_
for h in range(W_out):
    for w in range(W_out):
        h_start = h * S
        h_end   = h_start + S
        w_start = w * S
        w_end   = w_start + S
        
        X_slice = X[h_start:h_end, w_start:w_end, :]
        X_slice_mask = X_slice == np.max(X_slice, axis=(0,1))
        current_dY = dY[h,w,:]
        dX_slice = X_slice_mask * current_dY
        dX[h_start:h_end, w_start:w_end, :] = dX_slice


print("=== dX ===")     
print(dX[0, :, :, :].transpose(2, 0, 1))


with tf.Session() as sess:
    tf_X = tf.constant(X.reshape(4, 4, 4, 3))
    tf_Y = tf.nn.max_pool(
        tf_X,
        ksize=[1, S, S, 1],
        strides=[1, S, S, 1],
        padding='VALID'
    )
    tf_L = tf.reduce_sum(tf_Y)
    tf_grad = tf.gradients(tf_L, [tf_X])
    tf_dX, = sess.run(tf_grad)
    print("=== tf_dX ===")     
    print(tf_dX[0, :, :, :].transpose(2, 0, 1))

print("=== Matched? ===")    
print(np.all(dX == tf_dX[0, :, :, :]))

X_slice.shape: (4, 2, 2, 3)
X_slice_max.shape: (4, 3)
False
X_slice.shape: (4, 2, 2, 3)
X_slice_max.shape: (4, 3)
False
X_slice.shape: (4, 2, 2, 3)
X_slice_max.shape: (4, 3)
False
X_slice.shape: (4, 2, 2, 3)
X_slice_max.shape: (4, 3)
False
=== dX ===
[[[ 0.  0.  0.  0.]
  [ 0.  0.  0.  0.]
  [ 0.  0.  0.  0.]
  [ 0.  0.  0.  0.]]

 [[ 0.  0.  0.  0.]
  [ 0.  0.  0.  0.]
  [ 0.  0.  0.  0.]
  [ 0.  0.  0.  0.]]

 [[ 0.  0.  0.  0.]
  [ 0.  0.  0.  0.]
  [ 0.  0.  0.  0.]
  [ 0.  0.  0.  0.]]]
=== tf_dX ===
[[[ 0.  0.  0.  0.]
  [ 0.  1.  0.  1.]
  [ 0.  0.  0.  0.]
  [ 0.  1.  0.  1.]]

 [[ 0.  0.  0.  0.]
  [ 0.  1.  0.  1.]
  [ 0.  0.  0.  0.]
  [ 0.  1.  0.  1.]]

 [[ 0.  0.  0.  0.]
  [ 0.  1.  0.  1.]
  [ 0.  0.  0.  0.]
  [ 0.  1.  0.  1.]]]
=== Matched? ===
False




# Generalize naive max pooling foward

In [60]:
def max_pool_naive_foward(X, S):
    N_batch, H_in, W_in, C_in = X.shape
    
    H_out = H_in // S
    W_out = W_in // S

    Y = np.zeros((N_batch, H_out, W_out, C_in))
    for h in range(H_out):
        h_start = h * S
        h_end   = h_start + S
        for w in range(W_out):
            w_start = w * S
            w_end   = w_start + S
            X_slice = X[:, h_start:h_end, w_start:w_end, :]
            Y[:, h, w, :] = np.max(X_slice, axis=(1,2))

    return Y

In [61]:
S = 2
X = np.random.randn(10, 8, 8, 3).astype(np.float32)

Y = max_pool_naive_foward(X, S)

with tf.Session() as sess:
    Y_tf = sess.run(tf.nn.max_pool(
        X,
        ksize=[1, S, S, 1],
        strides=[1, S, S, 1],
        padding='VALID'
    ))

print("=== Matched? ===")    
check = np.linalg.norm(Y - Y_tf) / ((np.linalg.norm(Y) + np.linalg.norm(Y_tf)))
print(check < 1e-7, check)

=== Matched? ===
True 0.0


# Benchmark

In [62]:
S = 2
X = np.random.randn(128, 28, 28, 3).astype(np.float32)

Y = max_pool_naive_foward(X, S)

with tf.Session() as sess:
    Y_tf = sess.run(tf.nn.max_pool(
        X,
        ksize=[1, S, S, 1],
        strides=[1, S, S, 1],
        padding='VALID'
    ))

print("=== Matched? ===")    
check = np.linalg.norm(Y - Y_tf) / ((np.linalg.norm(Y) + np.linalg.norm(Y_tf)))
print(check < 1e-7, check)

=== Matched? ===
True 0.0


In [63]:
%%timeit -n3 -r3

max_pool_naive_foward(X, S)

9.91 ms ± 753 µs per loop (mean ± std. dev. of 3 runs, 3 loops each)


# Shape Test

In [66]:
S = 3
X = np.random.randn(128, 30, 12, 6).astype(np.float32)

Y = max_pool_naive_foward(X, S)
print(Y.shape)

with tf.Session() as sess:
    Y_tf = sess.run(tf.nn.max_pool(
        X,
        ksize=[1, S, S, 1],
        strides=[1, S, S, 1],
        padding='VALID'
    ))

print("=== Matched? ===")    
check = np.linalg.norm(Y - Y_tf) / ((np.linalg.norm(Y) + np.linalg.norm(Y_tf)))
print(check < 1e-7, check)

(128, 10, 4, 6)
=== Matched? ===
True 0.0


In [67]:
S = 5
X = np.random.randn(1, 100, 200, 16).astype(np.float32)

Y = max_pool_naive_foward(X, S)
print(Y.shape)

with tf.Session() as sess:
    Y_tf = sess.run(tf.nn.max_pool(
        X,
        ksize=[1, S, S, 1],
        strides=[1, S, S, 1],
        padding='VALID'
    ))

print("=== Matched? ===")    
check = np.linalg.norm(Y - Y_tf) / ((np.linalg.norm(Y) + np.linalg.norm(Y_tf)))
print(check < 1e-7, check)

(1, 20, 40, 16)
=== Matched? ===
True 0.0
