In [1]:
import numpy as np
import tensorflow as tf

# Notation

Max pooling reduces the width and height of given volumes.


| Symbols                   | Meaning               | Size                                              |
|:--------------------------|:----------------------|:--------------------------------------------------|
| $\boldsymbol{X}$          | 2D Input Matrix       | $(H_{in}$, $W_{in})$                              |
| $\boldsymbol{\mathsf{X}}$ | 3D Input Tensor       | $(H_{in}$, $W_{in}$, $C_{in})$                    |
| $\boldsymbol{\mathsf{X}}$ | 4D Input Tensor       | $(N_{batch}$, $H_{in}$, $W_{in}$, $C_{in})$       |
| $\boldsymbol{Y}$          | 2D Output Matrix      | $(H_{out}$, $W_{out})$                            |
| $\boldsymbol{\mathsf{Y}}$ | 3D Output Tensor      | $(H_{out}$, $W_{out}$, $C_{out})$                 |
| $\boldsymbol{\mathsf{Y}}$ | 4D Output Tensor      | $(N_{batch}$, $H_{out}$, $W_{out}$, $C_{out})$    |

As opposed to convolution, max pooling doesn't have any learnable parameters,
so there's no $\boldsymbol{\mathsf{W}}$ or $\boldsymbol{b}$ in the list of notation above.

# 08-2: Max Pooling Backward Pass

Try calculation with NumPy and TensorFlow for the followings

#### 1. Max Pooling with stride
$(4 \times 4) \rightarrow (2 \times 2)$ where $S=2$

#### 2. Multiple Channels
$(4 \times 4 \times 3) \rightarrow (2 \times 2 \times 3)$ where $S=2$

#### 3. Mini-batch
$(4 \times 4 \times 4 \times 3) \rightarrow (4 \times 2 \times 2 \times 3)$ where $S=2$

In [2]:
def float_sequence(size):
    return np.arange(size, dtype=np.float32)

## 1. Max Pooling with stride backward
$(4 \times 4) \rightarrow (2 \times 2)$ where $S=2$

TensorFlow: [tf.nn.max_pool](https://www.tensorflow.org/api_docs/python/tf/nn/max_pool)

For the size of output width and height, the same fomula as convolution operation hold.

$$
H_{out} = \frac{H_{in} + 2P - H_{filter}}{S} + 1
$$

$$
W_{out} = \frac{W_{in} + 2P - F_{filter}}{S} + 1
$$

However, padding $P$ is not used in max pooling, and often filter size $H_{filter}$ and $F_{filter}$ is the same as stride $S$.
As a result, the size can be calculated as follwing.

$$
H_{out} = \frac{H_{in}}{S}
$$

$$
W_{out} = \frac{W_{in}}{S}
$$

In [13]:
W = float_sequence(2*2).reshape(2,2) + 1
X = float_sequence(4*4).reshape(4,4)
S = 2

H_out = 4 // S
W_out = 4 // S

print("=== X ===")
print(X)

dY = W

dX = np.zeros((4, 4))
for h in range(W_out):
    for w in range(W_out):
        h_start = h * S
        h_end   = h_start + S
        w_start = w * S
        w_end   = w_start + S

        current_dY = dY[h,w]

        X_slice = X[h_start:h_end, w_start:w_end]
        flat_X_slice = X_slice.reshape(-1)
        max_index = np.argmax(flat_X_slice)

        gradient = np.zeros_like(flat_X_slice)
        gradient[max_index] = current_dY
        gradient = gradient.reshape(X_slice.shape)
        
#         print("gradient")
#         print(gradient)
                
        dX[h_start:h_end, w_start:w_end] = gradient

print("=== dX ===")     
print(dX)

with tf.Session() as sess:
    tf_X = tf.constant(X.reshape(1, 4, 4, 1))
    tf_Y = tf.nn.max_pool(
        tf_X,
        ksize=[1, S, S, 1],
        strides=[1, S, S, 1],
        padding='VALID'
    ) * tf.constant(W.reshape(1, 2, 2, 1))
    tf_L = tf.reduce_sum(tf_Y)
    tf_grad = tf.gradients(tf_L, [tf_X])
    tf_dX, = sess.run(tf_grad)
    print("=== tf_dX ===")     
    print(tf_dX[0, :, :, 0])

print("=== Matched? ===")    
print(np.all(dX == tf_dX[0, :, :, 0]))

=== X ===
[[  0.   1.   2.   3.]
 [  4.   5.   6.   7.]
 [  8.   9.  10.  11.]
 [ 12.  13.  14.  15.]]
=== dX ===
[[ 0.  0.  0.  0.]
 [ 0.  1.  0.  2.]
 [ 0.  0.  0.  0.]
 [ 0.  3.  0.  4.]]
=== tf_dX ===
[[ 0.  0.  0.  0.]
 [ 0.  1.  0.  2.]
 [ 0.  0.  0.  0.]
 [ 0.  3.  0.  4.]]
=== Matched? ===
True


## 2. Multiple Channels backward
$(4 \times 4 \times 3) \rightarrow (2 \times 2 \times 3)$ where $S=2$

In [54]:
W = float_sequence(2*2*3).reshape(2,2,3) + 1
X = float_sequence(4*4*3).reshape(4,4,3)

S = 2

H_out = 4 // S
W_out = 4 // S

# print("=== X ===")
# print(X.transpose(2, 0, 1))

dY = W
dX = np.zeros((4, 4, 3))

for h in range(W_out):
    for w in range(W_out):
        h_start = h * S
        h_end   = h_start + S
        w_start = w * S
        w_end   = w_start + S
        
        current_dY = dY[h,w,:]
#         print("====current_dY")
#         print(current_dY)

        X_slice = X[h_start:h_end, w_start:w_end, :]
#         print("====X_slice (raw)")
#         print(X_slice)
#         print("====X_slice (by channel)")
#         print(X_slice.transpose(2, 0, 1))
        
        flat_X_slice_by_channel = X_slice.transpose(2, 0, 1).reshape(3, -1)
#         print("====flat_X_slice_by_channel")
#         print(flat_X_slice_by_channel)
        max_index = np.argmax(flat_X_slice_by_channel, axis=1)
#         print("====max_index")
#         print(max_index)

        gradient = np.zeros_like(flat_X_slice_by_channel)
        gradient[np.arange(3), max_index] = current_dY
#         print("gradient")
#         print(gradient)

        gradient = gradient.reshape(X_slice.shape[2], X_slice.shape[0], X_slice.shape[1]).transpose(1, 2, 0)
        
#         print("gradient")
#         print(gradient)
                
        dX[h_start:h_end, w_start:w_end, :] = gradient


print("=== dX ===")     
print(dX.transpose(2, 0, 1))

with tf.Session() as sess:
    tf_X = tf.constant(X.reshape(1, 4, 4, 3))
    tf_Y = tf.nn.max_pool(
        tf_X,
        ksize=[1, S, S, 1],
        strides=[1, S, S, 1],
        padding='VALID'
    ) * tf.constant(W.reshape(1, 2, 2, 3))
    tf_L = tf.reduce_sum(tf_Y)
    tf_grad = tf.gradients(tf_L, [tf_X])
    tf_dX, = sess.run(tf_grad)
    print("=== tf_dX ===")     
    print(tf_dX[0, :, :, :].transpose(2, 0, 1))

print("=== Matched? ===")    
print(np.all(dX == tf_dX[0, :, :, :]))

=== dX ===
[[[  0.   0.   0.   0.]
  [  0.   1.   0.   4.]
  [  0.   0.   0.   0.]
  [  0.   7.   0.  10.]]

 [[  0.   0.   0.   0.]
  [  0.   2.   0.   5.]
  [  0.   0.   0.   0.]
  [  0.   8.   0.  11.]]

 [[  0.   0.   0.   0.]
  [  0.   3.   0.   6.]
  [  0.   0.   0.   0.]
  [  0.   9.   0.  12.]]]
=== tf_dX ===
[[[  0.   0.   0.   0.]
  [  0.   1.   0.   4.]
  [  0.   0.   0.   0.]
  [  0.   7.   0.  10.]]

 [[  0.   0.   0.   0.]
  [  0.   2.   0.   5.]
  [  0.   0.   0.   0.]
  [  0.   8.   0.  11.]]

 [[  0.   0.   0.   0.]
  [  0.   3.   0.   6.]
  [  0.   0.   0.   0.]
  [  0.   9.   0.  12.]]]
=== Matched? ===
True


## 3. Mini-batch
$(4 \times 4 \times 4 \times 3) \rightarrow (4 \times 2 \times 2 \times 3)$ where $S=2$

In [55]:
W = float_sequence(4*2*2*3).reshape(4,2,2,3) + 1
X = float_sequence(4*4*4*3).reshape(4,4,4,3)

S = 2

H_out = 4 // S
W_out = 4 // S

# print("=== X (first) ===")
# print(X[0, :, :, :].transpose(2, 0, 1))

dY = W
dX = np.zeros((4, 4, 4, 3))

for n_batch in range(4):
    for h in range(W_out):
        for w in range(W_out):
            h_start = h * S
            h_end   = h_start + S
            w_start = w * S
            w_end   = w_start + S

            
            
            current_dY = dY[n_batch, h,w,:]

            X_slice = X[h_start:h_end, w_start:w_end, :]
#         print("====X_slice (raw)")
#         print(X_slice)
#         print("====X_slice (by channel)")
#         print(X_slice.transpose(2, 0, 1))
        
        flat_X_slice_by_channel = X_slice.transpose(2, 0, 1).reshape(3, -1)
#         print("====flat_X_slice_by_channel")
#         print(flat_X_slice_by_channel)
        max_index = np.argmax(flat_X_slice_by_channel, axis=1)
#         print("====max_index")
#         print(max_index)

        gradient = np.zeros_like(flat_X_slice_by_channel)
        gradient[np.arange(3), max_index] = current_dY
#         print("gradient")
#         print(gradient)

        gradient = gradient.reshape(X_slice.shape[2], X_slice.shape[0], X_slice.shape[1]).transpose(1, 2, 0)
        
#         print("gradient")
#         print(gradient)
                
        dX[h_start:h_end, w_start:w_end, :] = gradient
            
            
            
            
            
            
            
            X_slice = X[n_batch, h_start:h_end, w_start:w_end, :]
            X_slice_mask = X_slice == np.max(X_slice, axis=(0,1))
            
            dX_slice = X_slice_mask * current_dY
            dX[n_batch, h_start:h_end, w_start:w_end, :] = dX_slice


print("=== dX (first) ===")     
print(dX[0, :, :, :].transpose(2, 0, 1))


with tf.Session() as sess:
    tf_X = tf.constant(X.reshape(4, 4, 4, 3))
    tf_Y = tf.nn.max_pool(
        tf_X,
        ksize=[1, S, S, 1],
        strides=[1, S, S, 1],
        padding='VALID'
    ) * tf.constant(W.reshape(4, 2, 2, 3))
    tf_L = tf.reduce_sum(tf_Y)
    tf_grad = tf.gradients(tf_L, [tf_X])
    tf_dX, = sess.run(tf_grad)
    print("=== tf_dX (first) ===")     
    print(tf_dX[0, :, :, :].transpose(2, 0, 1))

print("=== Matched? ===")    
print(np.all(dX == tf_dX[:, :, :, :]))

=== dX (first) ===
[[[  0.   0.   0.   0.]
  [  0.   1.   0.   4.]
  [  0.   0.   0.   0.]
  [  0.   7.   0.  10.]]

 [[  0.   0.   0.   0.]
  [  0.   2.   0.   5.]
  [  0.   0.   0.   0.]
  [  0.   8.   0.  11.]]

 [[  0.   0.   0.   0.]
  [  0.   3.   0.   6.]
  [  0.   0.   0.   0.]
  [  0.   9.   0.  12.]]]
=== tf_dX (first) ===
[[[  0.   0.   0.   0.]
  [  0.   1.   0.   4.]
  [  0.   0.   0.   0.]
  [  0.   7.   0.  10.]]

 [[  0.   0.   0.   0.]
  [  0.   2.   0.   5.]
  [  0.   0.   0.   0.]
  [  0.   8.   0.  11.]]

 [[  0.   0.   0.   0.]
  [  0.   3.   0.   6.]
  [  0.   0.   0.   0.]
  [  0.   9.   0.  12.]]]
=== Matched? ===
True


# Generalize naive max pooling backward

In [133]:
def max_pool_naive_backward(dY, X, S):
    N_batch, H_in, W_in, C_in = X.shape
    
    H_out = H_in // S
    W_out = W_in // S
    
    dX = np.zeros_like(X)

    for n_batch in range(N_batch):
        for h in range(H_out):
            for w in range(W_out):
                h_start = h * S
                h_end   = h_start + S
                w_start = w * S
                w_end   = w_start + S

                X_slice = X[n_batch, h_start:h_end, w_start:w_end, :]
                X_slice_mask = X_slice == np.max(X_slice, axis=(0,1))
                current_dY = dY[n_batch, h,w,:]
                dX_slice = X_slice_mask * current_dY
                dX[n_batch, h_start:h_end, w_start:w_end, :] = dX_slice

    return dX

In [137]:
S = 2
W = float_sequence(10*4*4*3).reshape(10,4,4,3) + 1
X = np.random.randn(10, 8, 8, 3).astype(np.float32) * 45676.362342398
dY = W

dX = max_pool_naive_backward(dY, X, S)

with tf.Session() as sess:
    tf_X = tf.constant(X)
    tf_Y = tf.nn.max_pool(
        tf_X,
        ksize=[1, S, S, 1],
        strides=[1, S, S, 1],
        padding='VALID'
    ) * tf.constant(W)
    tf_L = tf.reduce_sum(tf_Y)
    tf_grad = tf.gradients(tf_L, [tf_X])
    tf_dX, = sess.run(tf_grad)

print("=== Matched? ===")    
check = np.linalg.norm(dX - tf_dX) / ((np.linalg.norm(dX) + np.linalg.norm(tf_dX)))
print(check < 1e-7, check)

=== Matched? ===
True 0.0


# Check with numerical gradient

In [96]:
def numerical_gradient(loss):
    h = 1e-5
    gradients = np.zeros_like(X)

    itr = np.nditer(X, flags=['multi_index'], op_flags=['readwrite'])
    while not itr.finished:
        original = itr[0].copy()

        itr[0] = original + h
        print("original + h: {}".format(itr[0]))
        v1 = loss()
        itr[0] = original - h
        print("original - h: {}".format(itr[0]))
        v2 = loss()
        gradients[itr.multi_index] = (v1 - v2) / (2 * h)
        print("grad: {}".format(gradients[itr.multi_index]))

        itr[0] = original
        itr.iternext()

    return gradients

In [97]:
total = lambda: np.sum(max_pool_naive_backward(dY, X, S))

dX_numerical = numerical_gradient(total)

print(dX_numerical)
# print(dX - dX_numerical)
# print(np.linalg.norm(dX))
# print(np.linalg.norm(dX_numerical))

print("=== Matched? ===")
check = np.linalg.norm(dX - dX_numerical) / (np.linalg.norm(dX) + np.linalg.norm(dX_numerical))
print(check < 1e-7, check)

original + h: 0.8382022976875305
original - h: 0.838182270526886
grad: 0.0
original + h: -0.3872354328632355
original - h: -0.38725546002388
grad: 0.0
original + h: -4.137702941894531
original - h: -4.137722969055176
grad: 0.0
original + h: 0.22300973534584045
original - h: 0.2229897379875183
grad: 0.0
original + h: 0.6211698651313782
original - h: 0.6211498379707336
grad: 0.0
original + h: 0.5609484910964966
original - h: 0.560928463935852
grad: 0.0
original + h: 0.692089319229126
original - h: 0.6920692920684814
grad: 0.0
original + h: -0.22145341336727142
original - h: -0.22147341072559357
grad: 0.0
original + h: 1.0013631582260132
original - h: 1.0013431310653687
grad: 0.0
original + h: -1.0250244140625
original - h: -1.0250444412231445
grad: 0.0
original + h: -0.5196880102157593
original - h: -0.5197080373764038
grad: 0.0
original + h: -0.16758014261722565
original - h: -0.1676001399755478
grad: 0.0
original + h: -0.8113705515861511
original - h: -0.8113905787467957
grad: 0.0
orig

grad: 0.0
original + h: -0.04352133721113205
original - h: -0.04354133456945419
grad: 0.0
original + h: 0.028296200558543205
original - h: 0.028276199474930763
grad: 0.0
original + h: 1.8697636127471924
original - h: 1.8697435855865479
grad: 0.0
original + h: -2.2315330505371094
original - h: -2.231553077697754
grad: 0.0
original + h: -1.1276960372924805
original - h: -1.127716064453125
grad: 0.0
original + h: -0.8171078562736511
original - h: -0.8171278834342957
grad: 0.0
original + h: 0.36280742287635803
original - h: 0.3627873957157135
grad: 0.0
original + h: -0.33513307571411133
original - h: -0.33515310287475586
grad: 0.0
original + h: -0.2741439640522003
original - h: -0.27416399121284485
grad: 0.0
original + h: 1.1198334693908691
original - h: 1.1198134422302246
grad: 0.0
original + h: 0.05569630116224289
original - h: 0.055676303803920746
grad: 0.0
original + h: -0.9885122776031494
original - h: -0.988532304763794
grad: 0.0
original + h: -1.0627262592315674
original - h: -1.062

original - h: -1.412673830986023
grad: 0.0
original + h: -1.6700736284255981
original - h: -1.6700936555862427
grad: 0.0
original + h: -0.24514368176460266
original - h: -0.2451636791229248
grad: 0.0
original + h: -1.4256255626678467
original - h: -1.4256455898284912
grad: 0.0
original + h: 1.986966848373413
original - h: 1.9869468212127686
grad: 0.0
original + h: 0.5099841356277466
original - h: 0.509964108467102
grad: 0.0
original + h: -2.0652244091033936
original - h: -2.065244436264038
grad: 0.0
original + h: -1.2192121744155884
original - h: -1.219232201576233
grad: 0.0
original + h: -0.3195916414260864
original - h: -0.31961166858673096
grad: 0.0
original + h: -0.8336932063102722
original - h: -0.8337132334709167
grad: 0.0
original + h: 0.5287779569625854
original - h: 0.5287579298019409
grad: 0.0
original + h: 0.03746993839740753
original - h: 0.03744994103908539
grad: 0.0
original + h: 0.3713480532169342
original - h: 0.3713280260562897
grad: 0.0
original + h: 0.978754043579101

original + h: -1.6410590410232544
original - h: -1.641079068183899
grad: 0.0
original + h: 0.23330265283584595
original - h: 0.2332826554775238
grad: 0.0
original + h: -0.041563261300325394
original - h: -0.04158325865864754
grad: 0.0
original + h: 0.3703995645046234
original - h: 0.3703795373439789
grad: 0.0
original + h: 0.3012899160385132
original - h: 0.30126988887786865
grad: 0.0
original + h: -0.9762188196182251
original - h: -0.9762388467788696
grad: 0.0
original + h: 0.09773822873830795
original - h: 0.09771823137998581
grad: 0.0
original + h: -0.698483943939209
original - h: -0.6985039710998535
grad: 0.0
original + h: 0.7441554069519043
original - h: 0.7441353797912598
grad: 0.0
original + h: 0.23292307555675507
original - h: 0.23290307819843292
grad: 0.0
original + h: -1.9155515432357788
original - h: -1.9155715703964233
grad: 0.0
original + h: -0.985560953617096
original - h: -0.9855809807777405
grad: 0.0
original + h: 1.2292439937591553
original - h: 1.2292239665985107
grad

original + h: -0.6474200487136841
original - h: -0.6474400758743286
grad: 0.0
original + h: 1.7134207487106323
original - h: 1.7134007215499878
grad: 0.0
original + h: 0.9388141632080078
original - h: 0.9387941360473633
grad: 0.0
original + h: -0.7538413405418396
original - h: -0.7538613677024841
grad: 0.0
original + h: -0.30938974022865295
original - h: -0.3094097673892975
grad: 0.0
original + h: 1.8536862134933472
original - h: 1.8536661863327026
grad: 0.0
original + h: -2.0213842391967773
original - h: -2.021404266357422
grad: 0.0
original + h: -0.1679116040468216
original - h: -0.16793160140514374
grad: 0.0
original + h: -0.6018453240394592
original - h: -0.6018653512001038
grad: 0.0
original + h: 0.15234538912773132
original - h: 0.15232539176940918
grad: 0.0
original + h: -1.4364103078842163
original - h: -1.4364303350448608
grad: 0.0
original + h: -1.9480576515197754
original - h: -1.94807767868042
grad: 0.0
original + h: -0.47653934359550476
original - h: -0.4765593707561493
gr

grad: 0.0
original + h: 0.9765400290489197
original - h: 0.9765200018882751
grad: 0.0
original + h: -0.5263507962226868
original - h: -0.5263708233833313
grad: 0.0
original + h: -0.10296695679426193
original - h: -0.10298695415258408
grad: 0.0
original + h: 1.0001765489578247
original - h: 1.0001565217971802
grad: 0.0
original + h: -0.23083211481571198
original - h: -0.23085211217403412
grad: 0.0
original + h: -1.7636572122573853
original - h: -1.7636772394180298
grad: 0.0
original + h: -0.7683173418045044
original - h: -0.7683373689651489
grad: 0.0
original + h: 0.326436311006546
original - h: 0.3264162838459015
grad: 0.0
original + h: 0.525391697883606
original - h: 0.5253716707229614
grad: 0.0
original + h: -1.7205579280853271
original - h: -1.7205779552459717
grad: 0.0
original + h: -0.6196492910385132
original - h: -0.6196693181991577
grad: 0.0
original + h: 0.8174890279769897
original - h: 0.8174690008163452
grad: 0.0
original + h: -0.21998053789138794
original - h: -0.2200005352

grad: 0.0
original + h: -0.5371118783950806
original - h: -0.5371319055557251
grad: 0.0
original + h: -1.5542477369308472
original - h: -1.5542677640914917
grad: 0.0
original + h: -1.8024804592132568
original - h: -1.8025004863739014
grad: 0.0
original + h: -0.27021732926368713
original - h: -0.27023735642433167
grad: 0.0
original + h: -0.4972253441810608
original - h: -0.4972453713417053
grad: 0.0
original + h: 0.5449322462081909
original - h: 0.5449122190475464
grad: 0.0
original + h: -1.260988712310791
original - h: -1.2610087394714355
grad: 0.0
original + h: -1.3805646896362305
original - h: -1.380584716796875
grad: 0.0
original + h: 0.07310640811920166
original - h: 0.07308641076087952
grad: 0.0
original + h: 0.5100798010826111
original - h: 0.5100597739219666
grad: 0.0
original + h: 0.15693353116512299
original - h: 0.15691353380680084
grad: 0.0
original + h: -0.8968213796615601
original - h: -0.8968414068222046
grad: 0.0
original + h: 1.404510259628296
original - h: 1.4044902324

grad: 0.0
original + h: 0.7086333632469177
original - h: 0.7086133360862732
grad: 0.0
original + h: -1.2170346975326538
original - h: -1.2170547246932983
grad: 0.0
original + h: -0.4780312180519104
original - h: -0.47805124521255493
grad: 0.0
original + h: -1.1483337879180908
original - h: -1.1483538150787354
grad: 0.0
original + h: -2.4549992084503174
original - h: -2.455019235610962
grad: 0.0
original + h: -1.5800632238388062
original - h: -1.5800832509994507
grad: 0.0
original + h: 0.5735613703727722
original - h: 0.5735413432121277
grad: 0.0
original + h: 0.09856375306844711
original - h: 0.09854375571012497
grad: 0.0
original + h: -0.8086383938789368
original - h: -0.8086584210395813
grad: 0.0
original + h: 0.906036913394928
original - h: 0.9060168862342834
grad: 0.0
original + h: -0.4430875778198242
original - h: -0.44310760498046875
grad: 0.0
original + h: -1.4152435064315796
original - h: -1.4152635335922241
grad: 0.0
original + h: -0.721487283706665
original - h: -0.7215073108

grad: 0.0
original + h: -0.22065147757530212
original - h: -0.22067147493362427
grad: 0.0
original + h: -0.17495228350162506
original - h: -0.1749722808599472
grad: 0.0
original + h: 1.0958714485168457
original - h: 1.0958514213562012
grad: 0.0
original + h: 0.9834659695625305
original - h: 0.983445942401886
grad: 0.0
original + h: -1.2791131734848022
original - h: -1.2791332006454468
grad: 0.0
original + h: -0.639503002166748
original - h: -0.6395230293273926
grad: 0.0
original + h: 1.5283727645874023
original - h: 1.5283527374267578
grad: 0.0
original + h: 1.154041051864624
original - h: 1.1540210247039795
grad: 0.0
original + h: -2.321488618850708
original - h: -2.3215086460113525
grad: 0.0
original + h: 0.7102553844451904
original - h: 0.7102353572845459
grad: 0.0
original + h: 0.7052170038223267
original - h: 0.7051969766616821
grad: 0.0
original + h: 1.0764306783676147
original - h: 1.0764106512069702
grad: 0.0
original + h: -0.5387367606163025
original - h: -0.538756787776947
gr

original - h: 0.8221956491470337
grad: 0.0
original + h: -0.37921953201293945
original - h: -0.379239559173584
grad: 0.0
original + h: -0.5950680375099182
original - h: -0.5950880646705627
grad: 0.0
original + h: 0.7599470019340515
original - h: 0.759926974773407
grad: 0.0
original + h: 0.8493183255195618
original - h: 0.8492982983589172
grad: 0.0
original + h: 0.8157874941825867
original - h: 0.8157674670219421
grad: 0.0
original + h: -0.03924637660384178
original - h: -0.039266373962163925
grad: 0.0
original + h: 1.0374723672866821
original - h: 1.0374523401260376
grad: 0.0
original + h: 0.8730651140213013
original - h: 0.8730450868606567
grad: 0.0
original + h: 0.9972822666168213
original - h: 0.9972622394561768
grad: 0.0
original + h: -0.16990503668785095
original - h: -0.1699250340461731
grad: 0.0
original + h: 0.4145919978618622
original - h: 0.41457197070121765
grad: 0.0
original + h: 0.56217360496521
original - h: 0.5621535778045654
grad: 0.0
original + h: 0.6950957179069519
or

grad: 0.0
original + h: 1.256126046180725
original - h: 1.2561060190200806
grad: 0.0
original + h: -0.053203877061605453
original - h: -0.0532238744199276
grad: 0.0
original + h: -0.8788478374481201
original - h: -0.8788678646087646
grad: 0.0
original + h: 1.0839519500732422
original - h: 1.0839319229125977
grad: 0.0
original + h: 0.8056451082229614
original - h: 0.8056250810623169
grad: 0.0
original + h: 0.22377106547355652
original - h: 0.22375106811523438
grad: 0.0
original + h: -0.6366736888885498
original - h: -0.6366937160491943
grad: 0.0
original + h: -2.7180778980255127
original - h: -2.7180979251861572
grad: 0.0
original + h: 0.4947183132171631
original - h: 0.49469828605651855
grad: 0.0
original + h: 0.07889744639396667
original - h: 0.07887744903564453
grad: 0.0
original + h: -0.6822459697723389
original - h: -0.6822659969329834
grad: 0.0
original + h: 1.824265956878662
original - h: 1.8242459297180176
grad: 0.0
original + h: 1.437852144241333
original - h: 1.437832117080688

original - h: -0.9049864411354065
grad: 0.0
original + h: 0.05585339665412903
original - h: 0.055833399295806885
grad: 0.0
original + h: 1.061158537864685
original - h: 1.0611385107040405
grad: 0.0
original + h: 0.28013208508491516
original - h: 0.28011205792427063
grad: 0.0
original + h: 1.8392763137817383
original - h: 1.8392562866210938
grad: 0.0
original + h: -0.6119778752326965
original - h: -0.6119979023933411
grad: 0.0
original + h: 0.8393813967704773
original - h: 0.8393613696098328
grad: 0.0
original + h: -0.6304080486297607
original - h: -0.6304280757904053
grad: 0.0
original + h: -0.3026934862136841
original - h: -0.3027135133743286
grad: 0.0
original + h: 0.028129354119300842
original - h: 0.0281093530356884
grad: 0.0
original + h: 2.0877254009246826
original - h: 2.087705373764038
grad: 0.0
original + h: -0.11273603141307831
original - h: -0.11275602877140045
grad: 0.0
original + h: 0.3388834297657013
original - h: 0.33886340260505676
grad: 0.0
original + h: -1.11104309558

original - h: -0.47021928429603577
grad: 0.0
original + h: -0.46415072679519653
original - h: -0.46417075395584106
grad: 0.0
original + h: -0.7681942582130432
original - h: -0.7682142853736877
grad: 0.0
original + h: -0.7254016995429993
original - h: -0.7254217267036438
grad: 0.0
original + h: -0.2421034425497055
original - h: -0.24212343990802765
grad: 0.0
original + h: -0.7767100930213928
original - h: -0.7767301201820374
grad: 0.0
original + h: -0.08737002313137054
original - h: -0.08739002048969269
grad: 0.0
original + h: 0.1370571106672287
original - h: 0.13703711330890656
grad: 0.0
original + h: 0.49729302525520325
original - h: 0.4972729980945587
grad: 0.0
original + h: 0.20930488407611847
original - h: 0.20928488671779633
grad: 0.0
original + h: -0.06739649921655655
original - h: -0.06741649657487869
grad: 0.0
original + h: 0.5611785054206848
original - h: 0.5611584782600403
grad: 0.0
original + h: 0.9599012136459351
original - h: 0.9598811864852905
grad: 0.0
original + h: 0.77

grad: 0.0
original + h: 0.12129584699869156
original - h: 0.12127584964036942
grad: 0.0
original + h: -0.4754384756088257
original - h: -0.4754585027694702
grad: 0.0
original + h: -2.427929162979126
original - h: -2.4279491901397705
grad: 0.0
original + h: -1.2079322338104248
original - h: -1.2079522609710693
grad: 0.0
original + h: -0.15523231029510498
original - h: -0.15525230765342712
grad: 0.0
original + h: -0.8350353837013245
original - h: -0.835055410861969
grad: 0.0
original + h: 0.08567081391811371
original - h: 0.08565081655979156
grad: 0.0
original + h: 0.1256280094385147
original - h: 0.12560801208019257
grad: 0.0
original + h: 0.7972666025161743
original - h: 0.7972465753555298
grad: 0.0
original + h: 0.9868473410606384
original - h: 0.9868273138999939
grad: 0.0
original + h: -1.496427297592163
original - h: -1.4964473247528076
grad: 0.0
original + h: 0.6888688206672668
original - h: 0.6888487935066223
grad: 0.0
original + h: -0.43473049998283386
original - h: -0.4347505271

original - h: -1.5330089330673218
grad: 0.0
original + h: -0.4310056269168854
original - h: -0.4310256540775299
grad: 0.0
original + h: 0.7322885394096375
original - h: 0.7322685122489929
grad: 0.0
original + h: -0.6928749084472656
original - h: -0.6928949356079102
grad: 0.0
original + h: 0.05987907573580742
original - h: 0.059859078377485275
grad: 0.0
original + h: -1.0379937887191772
original - h: -1.0380138158798218
grad: 0.0
original + h: 0.8945986032485962
original - h: 0.8945785760879517
grad: 0.0
original + h: 1.441392421722412
original - h: 1.4413723945617676
grad: 0.0
original + h: -1.0972809791564941
original - h: -1.0973010063171387
grad: 0.0
original + h: 0.8025840520858765
original - h: 0.8025640249252319
grad: 0.0
original + h: 1.5981472730636597
original - h: 1.5981272459030151
grad: 0.0
original + h: 0.6203116774559021
original - h: 0.6202916502952576
grad: 0.0
original + h: 0.17585940659046173
original - h: 0.1758394092321396
grad: 0.0
original + h: -1.9283305406570435

original - h: -0.11375462263822556
grad: 0.0
original + h: -0.1854669153690338
original - h: -0.18548691272735596
grad: 0.0
original + h: -0.6152654886245728
original - h: -0.6152855157852173
grad: 0.0
original + h: -1.0193904638290405
original - h: -1.019410490989685
grad: 0.0
original + h: -1.3993350267410278
original - h: -1.3993550539016724
grad: 0.0
original + h: 0.09594280272722244
original - h: 0.0959228053689003
grad: 0.0
original + h: 0.14164267480373383
original - h: 0.14162267744541168
grad: 0.0
original + h: 0.36980605125427246
original - h: 0.36978602409362793
grad: 0.0
original + h: -0.8815073370933533
original - h: -0.8815273642539978
grad: 0.0
original + h: 1.4705644845962524
original - h: 1.470544457435608
grad: 0.0
original + h: -0.04456494003534317
original - h: -0.044584937393665314
grad: 0.0
original + h: 0.7907955050468445
original - h: 0.7907754778862
grad: 0.0
original + h: -0.5959331393241882
original - h: -0.5959531664848328
grad: 0.0
original + h: -1.87948858

grad: 0.0
original + h: -0.11904595047235489
original - h: -0.11906594783067703
grad: 0.0
original + h: -0.6177168488502502
original - h: -0.6177368760108948
grad: 0.0
original + h: 0.47421082854270935
original - h: 0.4741908013820648
grad: 0.0
original + h: -0.21876947581768036
original - h: -0.2187894731760025
grad: 0.0
original + h: 0.46607664227485657
original - h: 0.46605661511421204
grad: 0.0
original + h: 0.4105904996395111
original - h: 0.4105704724788666
grad: 0.0
original + h: 1.3664027452468872
original - h: 1.3663827180862427
grad: 0.0
original + h: -0.17825521528720856
original - h: -0.1782752126455307
grad: 0.0
original + h: -0.06682686507701874
original - h: -0.06684686243534088
grad: 0.0
original + h: 0.5412506461143494
original - h: 0.5412306189537048
grad: 0.0
original + h: -1.6551541090011597
original - h: -1.6551741361618042
grad: 0.0
original + h: 1.0711899995803833
original - h: 1.0711699724197388
grad: 0.0
original + h: -0.17537908256053925
original - h: -0.17539

In [90]:
Y.shape
np.sum(max_pool_naive_backward(dY, X, S))

480.0

# Debug

In [144]:
class MaxPoolingLayer:
    def __init__(self, stride=1):
        self.stride = stride
        self.X = None

    def forward(self, X):
        N_batch, H_in, W_in, C_in = X.shape

        H_out = H_in // self.stride
        W_out = W_in // self.stride

        Y = np.zeros((N_batch, H_out, W_out, C_in))
        for h in range(H_out):
            h_start = h * self.stride
            h_end = h_start + self.stride
            for w in range(W_out):
                w_start = w * self.stride
                w_end = w_start + self.stride
                X_slice = X[:, h_start:h_end, w_start:w_end, :]
                Y[:, h, w, :] = np.max(X_slice, axis=(1, 2))

        self.X = X

        return Y

    def backward(self, dY):
        N_batch, H_in, W_in, C_in = self.X.shape

        H_out = H_in // self.stride
        W_out = W_in // self.stride

        dX = np.zeros_like(self.X)

        for n_batch in range(N_batch):
            for h in range(H_out):
                for w in range(W_out):
                    h_start = h * self.stride
                    h_end = h_start + self.stride
                    w_start = w * self.stride
                    w_end = w_start + self.stride

                    X_slice = self.X[n_batch, h_start:h_end, w_start:w_end, :]
                    X_slice_mask = X_slice == np.max(X_slice, axis=(0, 1))
                    current_dY = dY[n_batch, h, w, :]
                    dX_slice = X_slice_mask * current_dY
                    dX[n_batch, h_start:h_end, w_start:w_end, :] = dX_slice

        return dX

In [147]:
S = 2
W = float_sequence(2*2*2*1).reshape(2,2,2,1) + 46.789793
X = np.random.randn(2, 4, 4, 1).astype(np.float32) * 523.34637
dY = W

pool = MaxPoolingLayer(stride=S)
pool.forward(X)
dX = pool.backward(dY)
print(dX[0,:,:,:])

with tf.Session() as sess:
    tf_X = tf.constant(X)
    tf_Y = tf.nn.max_pool(
        tf_X,
        ksize=[1, S, S, 1],
        strides=[1, S, S, 1],
        padding='VALID'
    ) * tf.constant(W)
    tf_L = tf.reduce_sum(tf_Y)
    tf_grad = tf.gradients(tf_L, [tf_X])
    tf_dX, = sess.run(tf_grad)


print("=== Matched? ===")    
print(dX - tf_dX)
print(np.linalg.norm(dX))
print(np.linalg.norm(tf_dX))
check = np.linalg.norm(dX - tf_dX) / ((np.linalg.norm(dX) + np.linalg.norm(tf_dX)))
print(check < 1e-7, check)

[[[  0.        ]
  [  0.        ]
  [  0.        ]
  [  0.        ]]

 [[ 46.78979111]
  [  0.        ]
  [  0.        ]
  [ 47.78979111]]

 [[  0.        ]
  [  0.        ]
  [ 49.78979111]
  [  0.        ]]

 [[  0.        ]
  [ 48.78979111]
  [  0.        ]
  [  0.        ]]]
=== Matched? ===
[[[[ 0.]
   [ 0.]
   [ 0.]
   [ 0.]]

  [[ 0.]
   [ 0.]
   [ 0.]
   [ 0.]]

  [[ 0.]
   [ 0.]
   [ 0.]
   [ 0.]]

  [[ 0.]
   [ 0.]
   [ 0.]
   [ 0.]]]


 [[[ 0.]
   [ 0.]
   [ 0.]
   [ 0.]]

  [[ 0.]
   [ 0.]
   [ 0.]
   [ 0.]]

  [[ 0.]
   [ 0.]
   [ 0.]
   [ 0.]]

  [[ 0.]
   [ 0.]
   [ 0.]
   [ 0.]]]]
142.389
142.389
True 0.0


# Benchmark

In [69]:
S = 2
X = np.random.randn(128, 28, 28, 3).astype(np.float32)
dY = np.ones((128, 14, 14, 3))

dX = max_pool_naive_backward(dY, X, S)

with tf.Session() as sess:
    tf_X = tf.constant(X)
    tf_Y = tf.nn.max_pool(
        tf_X,
        ksize=[1, S, S, 1],
        strides=[1, S, S, 1],
        padding='VALID'
    )
    tf_L = tf.reduce_sum(tf_Y)
    tf_grad = tf.gradients(tf_L, [tf_X])
    tf_dX, = sess.run(tf_grad)

print("=== Matched? ===")    
check = np.linalg.norm(dX - tf_dX) / ((np.linalg.norm(dX) + np.linalg.norm(tf_dX)))
print(check < 1e-7, check)

=== Matched? ===
True 0.0


In [70]:
%%timeit -n3 -r3

max_pool_naive_backward(dY, X, S)

212 ms ± 3.22 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)


# Shape Test

In [71]:
S = 3
X = np.random.randn(128, 30, 12, 6).astype(np.float32)
dY = np.ones((128, 10, 4, 6))

dX = max_pool_naive_backward(dY, X, S)

with tf.Session() as sess:
    tf_X = tf.constant(X)
    tf_Y = tf.nn.max_pool(
        tf_X,
        ksize=[1, S, S, 1],
        strides=[1, S, S, 1],
        padding='VALID'
    )
    tf_L = tf.reduce_sum(tf_Y)
    tf_grad = tf.gradients(tf_L, [tf_X])
    tf_dX, = sess.run(tf_grad)

print("=== Matched? ===")    
check = np.linalg.norm(dX - tf_dX) / ((np.linalg.norm(dX) + np.linalg.norm(tf_dX)))
print(check < 1e-7, check)

=== Matched? ===
True 0.0


In [72]:
S = 5
X = np.random.randn(1, 100, 200, 16).astype(np.float32)
dY = np.ones((1, 20, 40, 16))

dX = max_pool_naive_backward(dY, X, S)

with tf.Session() as sess:
    tf_X = tf.constant(X)
    tf_Y = tf.nn.max_pool(
        tf_X,
        ksize=[1, S, S, 1],
        strides=[1, S, S, 1],
        padding='VALID'
    )
    tf_L = tf.reduce_sum(tf_Y)
    tf_grad = tf.gradients(tf_L, [tf_X])
    tf_dX, = sess.run(tf_grad)

print("=== Matched? ===")    
check = np.linalg.norm(dX - tf_dX) / ((np.linalg.norm(dX) + np.linalg.norm(tf_dX)))
print(check < 1e-7, check)

=== Matched? ===
True 0.0
