In [1]:
import numpy as np
import tensorflow as tf

# Notation

Max pooling reduces the width and height of given volumes.


| Symbols                   | Meaning               | Size                                              |
|:--------------------------|:----------------------|:--------------------------------------------------|
| $\boldsymbol{X}$          | 2D Input Matrix       | $(H_{in}$, $W_{in})$                              |
| $\boldsymbol{\mathsf{X}}$ | 3D Input Tensor       | $(H_{in}$, $W_{in}$, $C_{in})$                    |
| $\boldsymbol{\mathsf{X}}$ | 4D Input Tensor       | $(N_{batch}$, $H_{in}$, $W_{in}$, $C_{in})$       |
| $\boldsymbol{Y}$          | 2D Output Matrix      | $(H_{out}$, $W_{out})$                            |
| $\boldsymbol{\mathsf{Y}}$ | 3D Output Tensor      | $(H_{out}$, $W_{out}$, $C_{out})$                 |
| $\boldsymbol{\mathsf{Y}}$ | 4D Output Tensor      | $(N_{batch}$, $H_{out}$, $W_{out}$, $C_{out})$    |

As opposed to convolution, max pooling doesn't have any learnable parameters,
so there's no $\boldsymbol{\mathsf{W}}$ or $\boldsymbol{b}$ in the list of notation above.

# 08-2: Max Pooling Backward Pass

Compute the max pooling backward pass with NumPy and verify it against TensorFlow for the following cases:

#### 1. Max Pooling with stride
$(4 \times 4) \rightarrow (2 \times 2)$ where $S=2$

#### 2. Multiple Channels
$(4 \times 4 \times 3) \rightarrow (2 \times 2 \times 3)$ where $S=2$

#### 3. Mini-batch
$(4 \times 4 \times 4 \times 3) \rightarrow (4 \times 2 \times 2 \times 3)$ where $S=2$

In [2]:
def float_sequence(size):
    """Build the float32 ramp 0, 1, ..., size - 1 used to fill toy tensors.

    The result is always a flat 1-D array; callers reshape it as needed.
    """
    start, step = 0, 1
    return np.arange(start, size, step, dtype=np.float32)

## 1. Max Pooling with stride backward
$(4 \times 4) \rightarrow (2 \times 2)$ where $S=2$

TensorFlow: [tf.nn.max_pool](https://www.tensorflow.org/api_docs/python/tf/nn/max_pool)

The output height and width follow the same formula as for the convolution operation:

$$
H_{out} = \frac{H_{in} + 2P - H_{filter}}{S} + 1
$$

$$
W_{out} = \frac{W_{in} + 2P - W_{filter}}{S} + 1
$$

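However, padding $P$ is not used in max pooling, and the filter sizes $H_{filter}$ and $W_{filter}$ are often equal to the stride $S$.
As a result, the output size simplifies to the following (rounded down when the input size is not a multiple of $S$, which is why the code below uses integer division `//`).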

$$
H_{out} = \frac{H_{in}}{S}
$$

$$
W_{out} = \frac{W_{in}}{S}
$$

In [13]:
# Toy example: a single 4x4 input pooled by a 2x2 window with stride S = 2.
# W (values 1..4) doubles as the upstream gradient dY so every routed value is easy to trace.
W = float_sequence(2*2).reshape(2,2) + 1
X = float_sequence(4*4).reshape(4,4)
S = 2

H_out = 4 // S
W_out = 4 // S

print("=== X ===")
print(X)

dY = W

dX = np.zeros((4, 4))
# Rows iterate over H_out and columns over W_out (the two only coincide because X is square).
for h in range(H_out):
    for w in range(W_out):
        h_start = h * S
        h_end   = h_start + S
        w_start = w * S
        w_end   = w_start + S

        current_dY = dY[h, w]

        X_slice = X[h_start:h_end, w_start:w_end]
        flat_X_slice = X_slice.reshape(-1)
        # Max pooling forwards the upstream gradient only to the window position that
        # produced the maximum; np.argmax picks the first maximum (row-major) on ties.
        max_index = np.argmax(flat_X_slice)

        gradient = np.zeros_like(flat_X_slice)
        gradient[max_index] = current_dY

        dX[h_start:h_end, w_start:w_end] = gradient.reshape(X_slice.shape)

print("=== dX ===")
print(dX)

# Reference: TensorFlow differentiates L = sum(max_pool(X) * W), whose upstream gradient
# dL/dY is exactly W, so dL/dX must equal the NumPy dX above. (TF1 graph API.)
with tf.Session() as sess:
    tf_X = tf.constant(X.reshape(1, 4, 4, 1))
    tf_Y = tf.nn.max_pool(
        tf_X,
        ksize=[1, S, S, 1],
        strides=[1, S, S, 1],
        padding='VALID'
    ) * tf.constant(W.reshape(1, 2, 2, 1))
    tf_L = tf.reduce_sum(tf_Y)
    tf_grad = tf.gradients(tf_L, [tf_X])
    tf_dX, = sess.run(tf_grad)
    print("=== tf_dX ===")
    print(tf_dX[0, :, :, 0])

print("=== Matched? ===")
print(np.all(dX == tf_dX[0, :, :, 0]))

=== X ===
[[  0.   1.   2.   3.]
 [  4.   5.   6.   7.]
 [  8.   9.  10.  11.]
 [ 12.  13.  14.  15.]]
=== dX ===
[[ 0.  0.  0.  0.]
 [ 0.  1.  0.  2.]
 [ 0.  0.  0.  0.]
 [ 0.  3.  0.  4.]]
=== tf_dX ===
[[ 0.  0.  0.  0.]
 [ 0.  1.  0.  2.]
 [ 0.  0.  0.  0.]
 [ 0.  3.  0.  4.]]
=== Matched? ===
True


## 2. Multiple Channels backward
$(4 \times 4 \times 3) \rightarrow (2 \times 2 \times 3)$ where $S=2$

In [54]:
# 4x4 input with 3 channels; each channel is pooled independently with a 2x2 window, stride S = 2.
# W doubles as the upstream gradient dY (the TF loss below is sum(max_pool(X) * W)).
W = float_sequence(2*2*3).reshape(2,2,3) + 1
X = float_sequence(4*4*3).reshape(4,4,3)

S = 2

H_in, W_in, C = X.shape
H_out = H_in // S
W_out = W_in // S

dY = W
dX = np.zeros((H_in, W_in, C))
channel_index = np.arange(C)

for h in range(H_out):
    for w in range(W_out):
        h_start = h * S
        h_end   = h_start + S
        w_start = w * S
        w_end   = w_start + S

        current_dY = dY[h, w, :]

        X_slice = X[h_start:h_end, w_start:w_end, :]
        # One row per channel holding that channel's S*S window values in row-major order.
        flat_X_slice_by_channel = X_slice.transpose(2, 0, 1).reshape(C, -1)
        # Per-channel position of the maximum (first one on ties).
        max_index = np.argmax(flat_X_slice_by_channel, axis=1)

        gradient = np.zeros_like(flat_X_slice_by_channel)
        gradient[channel_index, max_index] = current_dY
        # Back to the (S, S, C) window layout: (C, S*S) -> (C, S, S) -> (S, S, C).
        gradient = gradient.reshape(C, S, S).transpose(1, 2, 0)

        dX[h_start:h_end, w_start:w_end, :] = gradient

print("=== dX ===")
print(dX.transpose(2, 0, 1))

# Reference gradient from TensorFlow (TF1 graph API).
with tf.Session() as sess:
    tf_X = tf.constant(X.reshape(1, H_in, W_in, C))
    tf_Y = tf.nn.max_pool(
        tf_X,
        ksize=[1, S, S, 1],
        strides=[1, S, S, 1],
        padding='VALID'
    ) * tf.constant(W.reshape(1, H_out, W_out, C))
    tf_L = tf.reduce_sum(tf_Y)
    tf_grad = tf.gradients(tf_L, [tf_X])
    tf_dX, = sess.run(tf_grad)
    print("=== tf_dX ===")
    print(tf_dX[0, :, :, :].transpose(2, 0, 1))

print("=== Matched? ===")
print(np.all(dX == tf_dX[0, :, :, :]))

=== dX ===
[[[  0.   0.   0.   0.]
  [  0.   1.   0.   4.]
  [  0.   0.   0.   0.]
  [  0.   7.   0.  10.]]

 [[  0.   0.   0.   0.]
  [  0.   2.   0.   5.]
  [  0.   0.   0.   0.]
  [  0.   8.   0.  11.]]

 [[  0.   0.   0.   0.]
  [  0.   3.   0.   6.]
  [  0.   0.   0.   0.]
  [  0.   9.   0.  12.]]]
=== tf_dX ===
[[[  0.   0.   0.   0.]
  [  0.   1.   0.   4.]
  [  0.   0.   0.   0.]
  [  0.   7.   0.  10.]]

 [[  0.   0.   0.   0.]
  [  0.   2.   0.   5.]
  [  0.   0.   0.   0.]
  [  0.   8.   0.  11.]]

 [[  0.   0.   0.   0.]
  [  0.   3.   0.   6.]
  [  0.   0.   0.   0.]
  [  0.   9.   0.  12.]]]
=== Matched? ===
True


## 3. Mini-batch backward
$(4 \times 4 \times 4 \times 3) \rightarrow (4 \times 2 \times 2 \times 3)$ where $S=2$

In [57]:
# Mini-batch of 4 samples, each 4x4 with 3 channels; every sample and channel is pooled
# independently with a 2x2 window, stride S = 2.
# W doubles as the upstream gradient dY (the TF loss below is sum(max_pool(X) * W)).
W = float_sequence(4*2*2*3).reshape(4,2,2,3) + 1
X = float_sequence(4*4*4*3).reshape(4,4,4,3)

S = 2

N_batch, H_in, W_in, C = X.shape
H_out = H_in // S
W_out = W_in // S

dY = W
dX = np.zeros((N_batch, H_in, W_in, C))
channel_index = np.arange(C)

for n_batch in range(N_batch):
    for h in range(H_out):
        for w in range(W_out):
            h_start = h * S
            h_end   = h_start + S
            w_start = w * S
            w_end   = w_start + S

            current_dY = dY[n_batch, h, w, :]

            X_slice = X[n_batch, h_start:h_end, w_start:w_end, :]
            # One row per channel holding that channel's S*S window values in row-major order.
            flat_X_slice_by_channel = X_slice.transpose(2, 0, 1).reshape(C, -1)
            max_index = np.argmax(flat_X_slice_by_channel, axis=1)

            gradient = np.zeros_like(flat_X_slice_by_channel)
            gradient[channel_index, max_index] = current_dY
            # Back to the (S, S, C) window layout.
            gradient = gradient.reshape(C, S, S).transpose(1, 2, 0)

            dX[n_batch, h_start:h_end, w_start:w_end, :] = gradient

print("=== dX (first) ===")
print(dX[0, :, :, :].transpose(2, 0, 1))

# Reference gradient from TensorFlow (TF1 graph API).
with tf.Session() as sess:
    tf_X = tf.constant(X)
    tf_Y = tf.nn.max_pool(
        tf_X,
        ksize=[1, S, S, 1],
        strides=[1, S, S, 1],
        padding='VALID'
    ) * tf.constant(W)
    tf_L = tf.reduce_sum(tf_Y)
    tf_grad = tf.gradients(tf_L, [tf_X])
    tf_dX, = sess.run(tf_grad)
    print("=== tf_dX (first) ===")
    print(tf_dX[0, :, :, :].transpose(2, 0, 1))

print("=== Matched? ===")
print(np.all(dX == tf_dX))

=== dX (first) ===
[[[  0.   0.   0.   0.]
  [  0.   1.   0.   4.]
  [  0.   0.   0.   0.]
  [  0.   7.   0.  10.]]

 [[  0.   0.   0.   0.]
  [  0.   2.   0.   5.]
  [  0.   0.   0.   0.]
  [  0.   8.   0.  11.]]

 [[  0.   0.   0.   0.]
  [  0.   3.   0.   6.]
  [  0.   0.   0.   0.]
  [  0.   9.   0.  12.]]]
=== tf_dX (first) ===
[[[  0.   0.   0.   0.]
  [  0.   1.   0.   4.]
  [  0.   0.   0.   0.]
  [  0.   7.   0.  10.]]

 [[  0.   0.   0.   0.]
  [  0.   2.   0.   5.]
  [  0.   0.   0.   0.]
  [  0.   8.   0.  11.]]

 [[  0.   0.   0.   0.]
  [  0.   3.   0.   6.]
  [  0.   0.   0.   0.]
  [  0.   9.   0.  12.]]]
=== Matched? ===
True


# Generalized naive max pooling backward

In [74]:
def max_pool_naive_backward(dY, X, S):
    """Loop-based backward pass of 2D max pooling with an S x S window and stride S.

    Args:
        dY: upstream gradient, indexed as dY[n, h_out, w_out, c].
        X: forward-pass input of shape (N_batch, H_in, W_in, C_in).
        S: window size and stride (non-overlapping windows, no padding).

    Returns:
        dX with the same shape and dtype as X. In every window, each channel's
        upstream gradient is written to the first position (row-major) holding
        that channel's maximum; every other position stays zero. Rows/columns
        beyond the last full window (when H_in or W_in is not a multiple of S)
        are never visited and also stay zero.
    """
    n_samples, height, width, channels = X.shape
    out_rows, out_cols = height // S, width // S

    dX = np.zeros_like(X)
    channel_ids = np.arange(channels)

    for n in range(n_samples):
        for i in range(out_rows):
            top = i * S
            for j in range(out_cols):
                left = j * S
                window = X[n, top:top + S, left:left + S, :]
                # Row k of `pixels` is window position (k // S, k % S); columns are channels.
                pixels = window.reshape(-1, channels)
                winners = np.argmax(pixels, axis=0)

                routed = np.zeros_like(pixels)
                routed[winners, channel_ids] = dY[n, i, j, :]
                dX[n, top:top + S, left:left + S, :] = routed.reshape(window.shape)

    return dX

In [75]:
# Randomized check of max_pool_naive_backward against TensorFlow on a (10, 8, 8, 3) batch.
# Seed the RNG so the check is reproducible on Restart & Run All.
np.random.seed(0)

S = 2
# W serves as the upstream gradient dY: the TF loss below is L = sum(max_pool(X) * W).
W = float_sequence(10*4*4*3).reshape(10,4,4,3) + 1
# Large scale factor, presumably to exercise large-magnitude inputs.
X = np.random.randn(10, 8, 8, 3).astype(np.float32) * 45676.362342398
dY = W

dX = max_pool_naive_backward(dY, X, S)

with tf.Session() as sess:
    tf_X = tf.constant(X)
    tf_Y = tf.nn.max_pool(
        tf_X,
        ksize=[1, S, S, 1],
        strides=[1, S, S, 1],
        padding='VALID'
    ) * tf.constant(W)
    tf_L = tf.reduce_sum(tf_Y)
    tf_grad = tf.gradients(tf_L, [tf_X])
    tf_dX, = sess.run(tf_grad)

print("=== Matched? ===")
# Relative difference between the two gradients (0.0 means identical).
check = np.linalg.norm(dX - tf_dX) / ((np.linalg.norm(dX) + np.linalg.norm(tf_dX)))
print(check < 1e-7, check)

=== Matched? ===
True 0.0


# Benchmark

In [76]:
# Benchmark-sized batch (128, 28, 28, 3); also checked against TensorFlow before timing.
# Seed the RNG so the check is reproducible on Restart & Run All.
np.random.seed(0)

S = 2
X = np.random.randn(128, 28, 28, 3).astype(np.float32)
N_batch, H_in, W_in, C_in = X.shape
# All-ones upstream gradient matches the TF loss below, L = sum(max_pool(X)).
dY = np.ones((N_batch, H_in // S, W_in // S, C_in))

dX = max_pool_naive_backward(dY, X, S)

with tf.Session() as sess:
    tf_X = tf.constant(X)
    tf_Y = tf.nn.max_pool(
        tf_X,
        ksize=[1, S, S, 1],
        strides=[1, S, S, 1],
        padding='VALID'
    )
    tf_L = tf.reduce_sum(tf_Y)
    tf_grad = tf.gradients(tf_L, [tf_X])
    tf_dX, = sess.run(tf_grad)

print("=== Matched? ===")
check = np.linalg.norm(dX - tf_dX) / ((np.linalg.norm(dX) + np.linalg.norm(tf_dX)))
print(check < 1e-7, check)

=== Matched? ===
True 0.0


In [77]:
%%timeit -n3 -r3

# Time the naive loop-based backward pass on the dY, X, S defined in the previous cell
# (-n3: 3 loops per run, -r3: 3 runs).
max_pool_naive_backward(dY, X, S)

261 ms ± 2.66 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)


# Shape Test

In [78]:
# Shape test: non-square spatial dims, 6 channels, window/stride S = 3.
# Seed the RNG so the check is reproducible on Restart & Run All.
np.random.seed(0)

S = 3
X = np.random.randn(128, 30, 12, 6).astype(np.float32)
N_batch, H_in, W_in, C_in = X.shape
# All-ones upstream gradient shaped like the pooled output, (N, H_in // S, W_in // S, C).
dY = np.ones((N_batch, H_in // S, W_in // S, C_in))

dX = max_pool_naive_backward(dY, X, S)

with tf.Session() as sess:
    tf_X = tf.constant(X)
    tf_Y = tf.nn.max_pool(
        tf_X,
        ksize=[1, S, S, 1],
        strides=[1, S, S, 1],
        padding='VALID'
    )
    tf_L = tf.reduce_sum(tf_Y)
    tf_grad = tf.gradients(tf_L, [tf_X])
    tf_dX, = sess.run(tf_grad)

print("=== Matched? ===")
check = np.linalg.norm(dX - tf_dX) / ((np.linalg.norm(dX) + np.linalg.norm(tf_dX)))
print(check < 1e-7, check)

=== Matched? ===
True 0.0


In [79]:
# Shape test: single wide sample (100 x 200) with 16 channels, window/stride S = 5.
# Seed the RNG so the check is reproducible on Restart & Run All.
np.random.seed(0)

S = 5
X = np.random.randn(1, 100, 200, 16).astype(np.float32)
N_batch, H_in, W_in, C_in = X.shape
# All-ones upstream gradient shaped like the pooled output, (N, H_in // S, W_in // S, C).
dY = np.ones((N_batch, H_in // S, W_in // S, C_in))

dX = max_pool_naive_backward(dY, X, S)

with tf.Session() as sess:
    tf_X = tf.constant(X)
    tf_Y = tf.nn.max_pool(
        tf_X,
        ksize=[1, S, S, 1],
        strides=[1, S, S, 1],
        padding='VALID'
    )
    tf_L = tf.reduce_sum(tf_Y)
    tf_grad = tf.gradients(tf_L, [tf_X])
    tf_dX, = sess.run(tf_grad)

print("=== Matched? ===")
check = np.linalg.norm(dX - tf_dX) / ((np.linalg.norm(dX) + np.linalg.norm(tf_dX)))
print(check < 1e-7, check)

=== Matched? ===
True 0.0
