In [1]:
import numpy as np
import tensorflow as tf

# 07-05: Efficient Convolution Backward Pass using col2im

Try calculating with NumPy, then check with TensorFlow's autograd

#### 1. Basic 2D Convolution
$(4 \times 4) * (3 \times 3) = (2 \times 2)$

#### 2. Padding
$(4 \times 4) * (3 \times 3) = (4 \times 4)$ where $P=1$

#### 3. Stride
$(7 \times 7) * (3 \times 3) = (3 \times 3)$ where $S=2$

#### 4. Padding and Stride
$(7 \times 7) * (3 \times 3) = (4 \times 4)$ where $P=1, S=2$

#### 5. Channel
$(4 \times 4 \times 3) * (3 \times 3 \times 3) = (2 \times 2)$

#### 6. Channel and bias 
$(4 \times 4 \times 3) * (3 \times 3 \times 3) + (1) = (2 \times 2)$

#### 7. Multiple Filters
$(4 \times 4 \times 3) * (3 \times 3 \times 3 \times 4) = (2 \times 2 \times 4)$

#### 8.Multiple Filters + bias 
$(4 \times 4 \times 3) * (3 \times 3 \times 3 \times 4) + (4)= (2 \times 2 \times 4)$

#### 9. Mini-batch + bias
$(3 \times 4 \times 4 \times 3) * (3 \times 3 \times 3 \times 4) + (4)= (3 \times 2 \times 2 \times 4)$

#### 10. RGB Mini-batch $*$ Multiple Filters with stride and padding
$(3 \times 7 \times 7 \times 3) * (3 \times 3 \times 3 \times 4) + (4)= (3 \times 4 \times 4 \times 4)$ where $P=1, S=2$


In [2]:
def float_sequence(size):
    return np.arange(size, dtype=np.float32)

### 1. Convolution Backward using col2im

$(4 \times 4) * (3 \times 3) = (2 \times 2)$

In [24]:
X = float_sequence(4*4).reshape(4,4)
W = 12 - float_sequence(3*3).reshape(3,3)

#================== Forward ==================
# We have X_col and W_col from forward pass

X_col = np.zeros((4,9))
for h in range(2):
    for w in range(2):
        h_start = h
        h_end   = h_start + 3
        w_start = w
        w_end   = w_start + 3
        
        X_slice = X[h_start:h_end, w_start:w_end]
        X_col_row_index = h * 2 + w
        X_col[X_col_row_index, :] = X_slice.reshape(1, -1)

W_col = W.reshape(-1, 1)


#================== dY ==================
dY = np.ones((2,2))
dY_col = dY.reshape(4, 1)

# print("=== dY ===")
# print(dY)
# print("=== dY_col ===")
# print(dY_col)

#================== dW ==================
dW_col = np.dot(X_col.T, dY_col)
dW = dW_col.reshape(3,3)
        
print("=== dW ===")     
print(dW)

#================== dX ==================
dX_col = np.dot(dY_col, W_col.T)
dX = np.zeros((4,4))
for h in range(2):
    for w in range(2):
        h_start = h
        h_end   = h_start + 3
        w_start = w
        w_end   = w_start + 3
        
        dX_col_row_index = h * 2 + w
        dX_col_slice = dX_col[dX_col_row_index, :]
        dX[h_start:h_end, w_start:w_end] += dX_col_slice.reshape(3, 3)


print("=== dX_col ===")     
print(dX_col)

print("=== dX ===")     
print(dX)

#================== tf ==================
with tf.Session() as sess:
    tf_X = tf.constant(X.reshape(1, 4, 4, 1))
    tf_W = tf.Variable(W.reshape(3, 3, 1, 1))
    tf_Y = tf.nn.conv2d(tf_X, tf_W, strides=[1, 1, 1, 1], padding='VALID')
    tf_L = tf.reduce_sum(tf_Y)
    tf_grad = tf.gradients(tf_L, [tf_X, tf_W])
    
    sess.run(tf.global_variables_initializer())
    tf_L_val = sess.run(tf_L)
    tf_dX, tf_dW = sess.run(tf_grad)
    print("=== dX (tf) ===")     
    print(tf_dX[0, :, :, 0])
    print("=== dW (tf) ===")     
    print(tf_dW[:, :, 0, 0])

print("=== Matched? ===")    
print("dX: ", np.all(dX == tf_dX[0, :, :, 0]))
print("dW: ", np.all(dW == tf_dW[:, :, 0, 0]))

=== dW ===
[[ 10.  14.  18.]
 [ 26.  30.  34.]
 [ 42.  46.  50.]]
=== dX_col ===
[[ 12.  11.  10.   9.   8.   7.   6.   5.   4.]
 [ 12.  11.  10.   9.   8.   7.   6.   5.   4.]
 [ 12.  11.  10.   9.   8.   7.   6.   5.   4.]
 [ 12.  11.  10.   9.   8.   7.   6.   5.   4.]]
=== dX ===
[[ 12.  23.  21.  10.]
 [ 21.  40.  36.  17.]
 [ 15.  28.  24.  11.]
 [  6.  11.   9.   4.]]
=== dX (tf) ===
[[ 12.  23.  21.  10.]
 [ 21.  40.  36.  17.]
 [ 15.  28.  24.  11.]
 [  6.  11.   9.   4.]]
=== dW (tf) ===
[[ 10.  14.  18.]
 [ 26.  30.  34.]
 [ 42.  46.  50.]]
=== Matched? ===
dX:  True
dW:  True


### 2. Convolution with padding using col2im

$(4 \times 4) * (3 \times 3) = (4 \times 4)$ where $P=1$

In [32]:
X_org = float_sequence(4*4).reshape(4,4)
P = 1
X = np.pad(X_org, ((P, P), (P, P)), 'constant')
W = 12 - float_sequence(3*3).reshape(3,3)

H_out = (4 + 2*P - 3) + 1
W_out = (4 + 2*P - 3) + 1


#================== Forward ==================
X_col = np.zeros((16,9))
for h in range(H_out):
    for w in range(W_out):
        h_start = h
        h_end   = h_start + 3
        w_start = w
        w_end   = w_start + 3
        
        X_slice = X[h_start:h_end, w_start:w_end]
        X_col_row_index = h * H_out + w
        X_col[X_col_row_index, :] = X_slice.reshape(1, -1)

W_col = W.reshape(-1, 1)


#================== dY ==================
dY = np.ones((4,4))
dY_col = dY.reshape(16, 1)


#================== dW ==================
dW_col = np.dot(X_col.T, dY_col)
dW = dW_col.reshape(3,3)
        
print("=== dW ===")     
print(dW)

#================== dX ==================
dX_col = np.dot(dY_col, W_col.T)
dX = np.zeros((6,6))
for h in range(H_out):
    for w in range(W_out):
        h_start = h
        h_end   = h_start + 3
        w_start = w
        w_end   = w_start + 3
        
        dX_col_row_index = h * 2 + w
        dX_col_slice = dX_col[dX_col_row_index, :]
        dX[h_start:h_end, w_start:w_end] += dX_col_slice.reshape(3, 3)
dX = dX[P:-P, P:-P]

#================== tf ==================
with tf.Session() as sess:
    tf_X = tf.constant(X_org.reshape(1, 4, 4, 1))
    tf_W = tf.Variable(W.reshape(3, 3, 1, 1))
    tf_Y = tf.nn.conv2d(tf_X, tf_W, strides=[1, 1, 1, 1], padding='SAME')
    tf_L = tf.reduce_sum(tf_Y)
    tf_grad = tf.gradients(tf_L, [tf_X, tf_W])
    
    sess.run(tf.global_variables_initializer())
    tf_L_val = sess.run(tf_L)
    tf_dX, tf_dW = sess.run(tf_grad)
    print("=== dX (tf) ===")     
    print(tf_dX[0, :, :, 0])
    print("=== dW (tf) ===")     
    print(tf_dW[:, :, 0, 0])

print("=== Matched? ===")    
print("dX: ", np.all(dX == tf_dX[0, :, :, 0]))
print("dW: ", np.all(dW == tf_dW[:, :, 0, 0]))

=== dW ===
[[  45.   66.   54.]
 [  84.  120.   96.]
 [  81.  114.   90.]]
=== dX (tf) ===
[[ 40.  57.  57.  36.]
 [ 51.  72.  72.  45.]
 [ 51.  72.  72.  45.]
 [ 28.  39.  39.  24.]]
=== dW (tf) ===
[[  45.   66.   54.]
 [  84.  120.   96.]
 [  81.  114.   90.]]
=== Matched? ===
dX:  True
dW:  True


### 3. Convolution with Stride using col2im

$(7 \times 7) * (3 \times 3) = (3 \times 3)$ where $S=2$

In [46]:
P = 0
S = 2
X = float_sequence(7*7).reshape(7,7)
W = 12 - float_sequence(3*3).reshape(3,3)

H_out = (7 + 2*P - 3) // S + 1
W_out = (7 + 2*P - 3) // S + 1

#================== Forward ==================
X_col = np.zeros((H_out * W_out, 9))
for h in range(H_out):
    for w in range(W_out):
        h_start = h * S
        h_end   = h_start + 3
        w_start = w * S
        w_end   = w_start + 3
        
        X_slice = X[h_start:h_end, w_start:w_end]
        X_col_row_index = h * H_out + w
        X_col[X_col_row_index, :] = X_slice.reshape(1, -1)
W_col = W.reshape(-1, 1)




#================== dY ==================
dY = np.ones((3,3))
dY_col = dY.reshape(9, 1)


#================== dW ==================
dW_col = np.dot(X_col.T, dY_col)
dW = dW_col.reshape(3,3)

# print("=== dW ===")     
# print(dW)


#================== dX ==================
dX_col = np.dot(dY_col, W_col.T)
dX = np.zeros((7,7))
for h in range(H_out):
    for w in range(W_out):
        h_start = h * S
        h_end   = h_start + 3
        w_start = w * S
        w_end   = w_start + 3
        
        dX_col_row_index = h * 2 + w
        dX_col_slice = dX_col[dX_col_row_index, :]
        dX[h_start:h_end, w_start:w_end] += dX_col_slice.reshape(3, 3)


#================== tf ==================
with tf.Session() as sess:
    tf_X = tf.constant(X.reshape(1, 7, 7, 1))
    tf_W = tf.Variable(W.reshape(3, 3, 1, 1))
    tf_Y = tf.nn.conv2d(tf_X, tf_W, strides=[1, S, S, 1], padding='VALID')
    tf_L = tf.reduce_sum(tf_Y)
    tf_grad = tf.gradients(tf_L, [tf_X, tf_W])
    
    sess.run(tf.global_variables_initializer())
    tf_L_val = sess.run(tf_L)
    tf_dX, tf_dW = sess.run(tf_grad)
    print("=== dX (tf) ===")     
    print(tf_dX[0, :, :, 0])
    print("=== dW (tf) ===")     
    print(tf_dW[:, :, 0, 0])

print("=== Matched? ===")    
print("dX: ", np.all(dX == tf_dX[0, :, :, 0]))
print("dW: ", np.all(dW == tf_dW[:, :, 0, 0]))

=== dX (tf) ===
[[ 12.  11.  22.  11.  22.  11.  10.]
 [  9.   8.  16.   8.  16.   8.   7.]
 [ 18.  16.  32.  16.  32.  16.  14.]
 [  9.   8.  16.   8.  16.   8.   7.]
 [ 18.  16.  32.  16.  32.  16.  14.]
 [  9.   8.  16.   8.  16.   8.   7.]
 [  6.   5.  10.   5.  10.   5.   4.]]
=== dW (tf) ===
[[ 144.  153.  162.]
 [ 207.  216.  225.]
 [ 270.  279.  288.]]
=== Matched? ===
dX:  True
dW:  True


### 4. Padding and Stride using col2im

$(7 \times 7) * (3 \times 3) = (4 \times 4)$ where $P=1, S=2$

In [50]:
P = 1
S = 2
X_org = float_sequence(7*7).reshape(7,7)
X = np.pad(X_org, ((P, P), (P, P)), 'constant')
W = 12 - float_sequence(3*3).reshape(3,3)

H_out = (7 + 2*P - 3) // S + 1
W_out = (7 + 2*P - 3) // S + 1


#================== Forward ==================
X_col = np.zeros((H_out * W_out, 9))
for h in range(H_out):
    for w in range(W_out):
        h_start = h * S
        h_end   = h_start + 3
        w_start = w * S
        w_end   = w_start + 3
        
        X_slice = X[h_start:h_end, w_start:w_end]
        X_col_row_index = h * H_out + w
        X_col[X_col_row_index, :] = X_slice.reshape(1, -1)

W_col = W.reshape(-1, 1)



#================== dY ==================
dY = np.ones((4,4))
dY_col = dY.reshape(16, 1)


#================== dW ==================
dW_col = np.dot(X_col.T, dY_col)
dW = dW_col.reshape(3,3)

# print("=== dW ===")     
# print(dW)


#================== dX ==================
dX_col = np.dot(dY_col, W_col.T)
dX = np.zeros((7+2*P, 7+2*P))
for h in range(H_out):
    for w in range(W_out):
        h_start = h * S
        h_end   = h_start + 3
        w_start = w * S
        w_end   = w_start + 3
        
        dX_col_row_index = h * 2 + w
        dX_col_slice = dX_col[dX_col_row_index, :]
        dX[h_start:h_end, w_start:w_end] += dX_col_slice.reshape(3, 3)
dX = dX[P:-P, P:-P]


#================== tf ==================
with tf.Session() as sess:
    tf_X = tf.constant(X_org.reshape(1, 7, 7, 1))
    tf_W = tf.Variable(W.reshape(3, 3, 1, 1))
    tf_Y = tf.nn.conv2d(tf_X, tf_W, strides=[1, S, S, 1], padding='SAME')
    tf_L = tf.reduce_sum(tf_Y)
    tf_grad = tf.gradients(tf_L, [tf_X, tf_W])
    
    sess.run(tf.global_variables_initializer())
    tf_L_val = sess.run(tf_L)
    tf_dX, tf_dW = sess.run(tf_grad)
    print("=== L (tf) ===")     
    print(tf_L_val)
    print("=== dX (tf) ===")     
    print(tf_dX[0, :, :, 0])
    print("=== dW (tf) ===")     
    print(tf_dW[:, :, 0, 0])

# print("=== Matched? ===")    
print("dX: ", np.all(dX == tf_dX[0, :, :, 0]))
print("dY: ", np.all(dW == tf_dW[:, :, 0, 0]))

=== L (tf) ===
19200.0
=== dX (tf) ===
[[  8.  16.   8.  16.   8.  16.   8.]
 [ 16.  32.  16.  32.  16.  32.  16.]
 [  8.  16.   8.  16.   8.  16.   8.]
 [ 16.  32.  16.  32.  16.  32.  16.]
 [  8.  16.   8.  16.   8.  16.   8.]
 [ 16.  32.  16.  32.  16.  32.  16.]
 [  8.  16.   8.  16.   8.  16.   8.]]
=== dW (tf) ===
[[ 216.  288.  216.]
 [ 288.  384.  288.]
 [ 216.  288.  216.]]
dX:  True
dY:  True


### 5. Channel with col2im

$(4 \times 4 \times 3) * (3 \times 3 \times 3) = (2 \times 2)$

In [75]:
P = 0
S = 1
X = float_sequence(4*4*3).reshape(4,4,3)
W = 30 - float_sequence(3*3*3).reshape(3,3,3)

H_out = (4 + 2*P - 3) // S + 1
W_out = (4 + 2*P - 3) // S + 1

#================== Forward ==================
X_col = np.zeros((H_out * W_out, 3*3*3))
for h in range(H_out):
    for w in range(W_out):
        h_start = h * S
        h_end   = h_start + 3
        w_start = w * S
        w_end   = w_start + 3
        
        X_slice = X[h_start:h_end, w_start:w_end, :].transpose(2, 0, 1)
        X_col_row_index = h * H_out + w
        X_col[X_col_row_index, :] = X_slice.reshape(1, -1)

W_col = W.transpose(2, 0, 1).reshape(-1, 1)      



#================== dY ==================
dY = np.ones((2,2))
dY_col = dY.reshape(4, 1)


#================== dW ==================
dW_col = np.dot(X_col.T, dY_col)
dW = dW_col.reshape(3,3,3).transpose(1, 2, 0)

# print("=== dW ===")     
# print(dW)


#================== dX ==================
dX_col = np.dot(dY_col, W_col.T)
dX = np.zeros((4, 4, 3))
for h in range(H_out):
    for w in range(W_out):
        h_start = h * S
        h_end   = h_start + 3
        w_start = w * S
        w_end   = w_start + 3
        
        dX_col_row_index = h * 2 + w
        dX_col_slice = dX_col[dX_col_row_index, :]
        dX[h_start:h_end, w_start:w_end, :] += dX_col_slice.reshape(3, 3, 3).transpose(1, 2, 0)



#================== tf ==================
with tf.Session() as sess:
    tf_X = tf.constant(X.reshape(1, 4, 4, 3))
    tf_W = tf.Variable(W.reshape(3, 3, 3, 1))
    tf_Y = tf.nn.conv2d(tf_X, tf_W, strides=[1, 1, 1, 1], padding='VALID')
    tf_L = tf.reduce_sum(tf_Y)
    tf_grad = tf.gradients(tf_L, [tf_X, tf_W])
    
    sess.run(tf.global_variables_initializer())
    tf_L_val = sess.run(tf_L)
    tf_dX, tf_dW = sess.run(tf_grad)
    print("=== L (tf) ===")     
    print(tf_L_val)
#     print("=== dX (tf) ===")     
#     print(tf_dX[0, :, :, :])
#     print("=== dW (tf) ===")     
#     print(tf_dW[:, :, :, 0])

print("=== Matched? ===")    
print("dX: ", np.all(dX == tf_dX[0, :, :, :]))
print("dW: ", np.all(dW == tf_dW[:, :, :, 0]))

=== L (tf) ===
34650.0
=== Matched? ===
dX:  True
dW:  True


### 6. Channel and bias using col2im

$(4 \times 4 \times 3) * (3 \times 3 \times 3) + (1) = (2 \times 2)$

In [74]:
P = 0
S = 1
X = float_sequence(4*4*3).reshape(4,4,3)
W = 30 - float_sequence(3*3*3).reshape(3,3,3)
b = np.array([10], dtype=np.float32)

H_out = (4 + 2*P - 3) // S + 1
W_out = (4 + 2*P - 3) // S + 1

#================== Forward ==================
X_col = np.zeros((H_out * W_out, 3*3*3))
for h in range(H_out):
    for w in range(W_out):
        h_start = h * S
        h_end   = h_start + 3
        w_start = w * S
        w_end   = w_start + 3
        
        X_slice = X[h_start:h_end, w_start:w_end, :].transpose(2, 0, 1)
        X_col_row_index = h * H_out + w
        X_col[X_col_row_index, :] = X_slice.reshape(1, -1)

W_col = W.transpose(2, 0, 1).reshape(-1, 1)      


#================== dY ==================
dY = np.ones((2,2))
dY_col = dY.reshape(4, 1)

#================== db ==================
db = np.sum(dY)

#================== dW ==================
dW_col = np.dot(X_col.T, dY_col)
dW = dW_col.reshape(3,3,3).transpose(1, 2, 0)

# print("=== dW ===")     
# print(dW)

#================== dX ==================
dX_col = np.dot(dY_col, W_col.T)
dX = np.zeros((4, 4, 3))
for h in range(H_out):
    for w in range(W_out):
        h_start = h * S
        h_end   = h_start + 3
        w_start = w * S
        w_end   = w_start + 3
        
        dX_col_row_index = h * 2 + w
        dX_col_slice = dX_col[dX_col_row_index, :]
        dX[h_start:h_end, w_start:w_end, :] += dX_col_slice.reshape(3, 3, 3).transpose(1, 2, 0)


#================== tf ==================
with tf.Session() as sess:
    tf_X = tf.constant(X.reshape(1, 4, 4, 3))
    tf_W = tf.Variable(W.reshape(3, 3, 3, 1))
    tf_b = tf.constant(b)
    tf_Y = tf.nn.conv2d(tf_X, tf_W, strides=[1, 1, 1, 1], padding='VALID') + tf_b
    tf_L = tf.reduce_sum(tf_Y)
    tf_grad = tf.gradients(tf_L, [tf_X, tf_W, tf_b])
    
    sess.run(tf.global_variables_initializer())
    tf_L_val = sess.run(tf_L)
    tf_dX, tf_dW, tf_db = sess.run(tf_grad)
#     print("=== L (tf) ===")     
#     print(tf_L_val)
#     print("=== dX (tf) ===")     
#     print(tf_dX[0, :, :, :])
#     print("=== dW (tf) ===")     
#     print(tf_dW[:, :, :, 0])
#     print("=== db (tf) ===")     
#     print(tf_db[0])

print("=== Matched? ===")    
print("dX: ", np.all(dX == tf_dX[0, :, :, :]))
print("dW: ", np.all(dW == tf_dW[:, :, :, 0]))
print("db: ", np.all(db == tf_db[0]))

=== Matched? ===
dX:  True
dW:  True
db:  True


### 7. Multiple Filters using col2im

$(4 \times 4 \times 3) * (3 \times 3 \times 3 \times 4) = (2 \times 2 \times 4)$

In [73]:
P = 0
S = 1
X = float_sequence(4*4*3).reshape(4,4,3)
W = 120 - float_sequence(3*3*3*4).reshape(3,3,3,4)

H_out = (4 + 2*P - 3) // S + 1
W_out = (4 + 2*P - 3) // S + 1


#================== Foward ==================
X_col = np.zeros((H_out * W_out, 3*3*3))
for h in range(H_out):
    for w in range(W_out):
        h_start = h * S
        h_end   = h_start + 3
        w_start = w * S
        w_end   = w_start + 3
        
        X_slice = X[h_start:h_end, w_start:w_end, :].transpose(2, 0, 1)
        X_col_row_index = h * H_out + w
        X_col[X_col_row_index, :] = X_slice.reshape(1, -1)

W_col = W.transpose(2, 0, 1, 3).reshape(-1, 4)


#================== dY ==================
dY = np.ones((2,2,4))
dY_col = dY.reshape(4, 4)


#================== dW ==================
dW_col = np.dot(X_col.T, dY_col)
dW = dW_col.reshape(3,3,3,4).transpose(1, 2, 0, 3)


#================== dX ==================
dX_col = np.dot(dY_col, W_col.T)
dX = np.zeros((4, 4, 3))
for h in range(H_out):
    for w in range(W_out):
        h_start = h * S
        h_end   = h_start + 3
        w_start = w * S
        w_end   = w_start + 3
        
        dX_col_row_index = h * 2 + w
        dX_col_slice = dX_col[dX_col_row_index, :]
        dX[h_start:h_end, w_start:w_end, :] += dX_col_slice.reshape(3, 3, 3).transpose(1, 2, 0)


#================== tf ==================
with tf.Session() as sess:
    tf_X = tf.constant(X.reshape(1, 4, 4, 3))
    tf_W = tf.Variable(W.reshape(3, 3, 3, 4))
    tf_Y = tf.nn.conv2d(tf_X, tf_W, strides=[1, 1, 1, 1], padding='VALID')
    tf_L = tf.reduce_sum(tf_Y)
    tf_grad = tf.gradients(tf_L, [tf_X, tf_W])
    
    sess.run(tf.global_variables_initializer())
    tf_L_val = sess.run(tf_L)
    tf_dX, tf_dW = sess.run(tf_grad)
#     print("=== L (tf) ===")     
#     print(tf_L_val)
#     print("=== dX (tf) ===")     
#     print(tf_dX[0, :, :, :])
#     print("=== dW (tf) ===")     
#     print(tf_dW[:, :, :, :])

print("=== Matched? ===")    
print("dX: ", np.all(dX == tf_dX[0, :, :, :]))
print("dW: ", np.all(dW == tf_dW[:, :, :, :]))

=== Matched? ===
dX:  True
dW:  True


### 8. Multiple Filters + bias using col2im

$(4 \times 4 \times 3) * (3 \times 3 \times 3 \times 4) + (4)= (2 \times 2 \times 4)$

In [72]:
P = 0
S = 1
X = float_sequence(4*4*3).reshape(4,4,3)
W = 120 - float_sequence(3*3*3*4).reshape(3,3,3,4)
b = np.array([10, 100, 1000, 10000], dtype=np.float32)


#================== Foward ==================
X_col = np.zeros((H_out * W_out, 3*3*3))
for h in range(H_out):
    for w in range(W_out):
        h_start = h * S
        h_end   = h_start + 3
        w_start = w * S
        w_end   = w_start + 3
        
        X_slice = X[h_start:h_end, w_start:w_end, :].transpose(2, 0, 1)
        X_col_row_index = h * H_out + w
        X_col[X_col_row_index, :] = X_slice.reshape(1, -1)

W_col = W.transpose(2, 0, 1, 3).reshape(-1, 4)


#================== dY ==================
dY = np.ones((2,2,4))
dY_col = dY.reshape(4, 4)


#================== db ==================
db = np.sum(dY, axis=(0,1))


#================== dW ==================
dW_col = np.dot(X_col.T, dY_col)
dW = dW_col.reshape(3,3,3,4).transpose(1, 2, 0, 3)


#================== dX ==================
dX_col = np.dot(dY_col, W_col.T)
dX = np.zeros((4, 4, 3))
for h in range(H_out):
    for w in range(W_out):
        h_start = h * S
        h_end   = h_start + 3
        w_start = w * S
        w_end   = w_start + 3
        
        dX_col_row_index = h * 2 + w
        dX_col_slice = dX_col[dX_col_row_index, :]
        dX[h_start:h_end, w_start:w_end, :] += dX_col_slice.reshape(3, 3, 3).transpose(1, 2, 0)


#================== tf ==================
with tf.Session() as sess:
    tf_X = tf.constant(X.reshape(1, 4, 4, 3))
    tf_W = tf.Variable(W.reshape(3, 3, 3, 4))
    tf_b = tf.constant(b)
    tf_Y = tf.nn.conv2d(tf_X, tf_W, strides=[1, 1, 1, 1], padding='VALID') + tf_b
    tf_L = tf.reduce_sum(tf_Y)
    tf_grad = tf.gradients(tf_L, [tf_X, tf_W, tf_b])
    
    sess.run(tf.global_variables_initializer())
    tf_L_val = sess.run(tf_L)
    tf_dX, tf_dW, tf_db = sess.run(tf_grad)
#     print("=== L (tf) ===")     
#     print(tf_L_val)
#     print("=== dX (tf) ===")     
#     print(tf_dX[0, :, :, :])
#     print("=== dW (tf) ===")     
#     print(tf_dW[:, :, :, :])
    print("=== db (tf) ===")     
    print(tf_db)


print("=== Matched? ===")    
print("dX: ", np.all(dX == tf_dX[0, :, :, :]))
print("dW: ", np.all(dW == tf_dW[:, :, :, :]))
print("db: ", np.all(db == tf_db))

=== db (tf) ===
[ 4.  4.  4.  4.]
=== Matched? ===
dX:  True
dW:  True
db:  True


### 9. Mini-batch + bias using col2im

$(3 \times 4 \times 4 \times 3) * (3 \times 3 \times 3 \times 4) + (4)= (3 \times 2 \times 2 \times 4)$

In [80]:
P = 0
S = 1
X = float_sequence(3*4*4*3).reshape(3,4,4,3)
W = 120 - float_sequence(3*3*3*4).reshape(3,3,3,4)
b = np.array([10, 100, 1000, 10000], dtype=np.float32)

H_out = (4 + 2*P - 3) // S + 1
W_out = (4 + 2*P - 3) // S + 1

#================== Forward ==================
X_col = np.zeros((3 * H_out * W_out, 3*3*3))
for n_batch in range(3):
    for h in range(H_out):
        for w in range(W_out):
            h_start = h * S
            h_end   = h_start + 3
            w_start = w * S
            w_end   = w_start + 3

            X_slice = X[n_batch, h_start:h_end, w_start:w_end, :].transpose(2, 0, 1)
            X_col_row_index = n_batch * (H_out * W_out) + h * H_out + w
#             print("index:", X_col_row_index)
            X_col[X_col_row_index, :] = X_slice.reshape(1, -1)

W_col = W.transpose(2, 0, 1, 3).reshape(-1, 4)


#================== dY ==================
dY = np.ones((3,2,2,4))
dY_col = dY.reshape(12, 4)

#================== db ==================
db = np.sum(dY, axis=(0,1,2))


#================== dW ==================
dW_col = np.dot(X_col.T, dY_col)
dW = dW_col.reshape(3,3,3,4).transpose(1, 2, 0, 3)


#================== dX ==================
dX_col = np.dot(dY_col, W_col.T)
dX = np.zeros((3, 4+2*P, 4+2*P, 3))
for n_batch in range(3):
    for h in range(H_out):
        for w in range(W_out):
            h_start = h * S
            h_end   = h_start + 3
            w_start = w * S
            w_end   = w_start + 3

            dX_col_row_index = n_batch * (H_out * W_out) + h * H_out + w
            dX_col_slice = dX_col[dX_col_row_index, :]
            
            dX[n_batch, h_start:h_end, w_start:w_end, :] += dX_col_slice.reshape(3, 3, 3).transpose(1, 2, 0)


#================== tf ==================
with tf.Session() as sess:
    tf_X = tf.constant(X.reshape(3, 4, 4, 3))
    tf_W = tf.Variable(W.reshape(3, 3, 3, 4))
    tf_b = tf.constant(b)
    tf_Y = tf.nn.conv2d(tf_X, tf_W, strides=[1, 1, 1, 1], padding='VALID') + tf_b
    tf_L = tf.reduce_sum(tf_Y)
    tf_grad = tf.gradients(tf_L, [tf_X, tf_W, tf_b])
    
    sess.run(tf.global_variables_initializer())
    tf_L_val = sess.run(tf_L)
    tf_dX, tf_dW, tf_db = sess.run(tf_grad)
#     print("=== L (tf) ===")     
#     print(tf_L_val)
#     print("=== dX (tf) ===")     
#     print(tf_dX[0, :, :, :])
#     print("=== dW (tf) ===")     
#     print(tf_dW[:, :, :, :])
#     print("=== db (tf) ===")     
#     print(tf_db)


print("=== Matched? ===")    
print("dX: ", np.all(dX == tf_dX[:, :, :, :]))
print("dY: ", np.all(dW == tf_dW[:, :, :, :]))
print("db: ", np.all(db == tf_db))

=== Matched? ===
dX:  True
dY:  True
db:  True


#### 10. RGB Mini-batch $*$ Multiple Filters with stride and padding
$(3 \times 7 \times 7 \times 3) * (3 \times 3 \times 3 \times 4) + (4)= (3 \times 4 \times 4 \times 4)$ where $P=1, S=2$

In [83]:
P = 1
S = 2
X_org = float_sequence(3*7*7*3).reshape(3,7,7,3)
X = np.pad(X_org, ((0, 0), (P, P), (P, P), (0, 0)), 'constant')
W = 120 - float_sequence(3*3*3*4).reshape(3,3,3,4)
b = np.array([10, 100, 1000, 10000], dtype=np.float32)

H_out = (7 + 2*P - 3) // S + 1
W_out = (7 + 2*P - 3) // S + 1


#================== Forward ==================
X_col = np.zeros((3 * H_out * W_out, 3*3*3))
for n_batch in range(3):
    for h in range(H_out):
        for w in range(W_out):
            h_start = h * S
            h_end   = h_start + 3
            w_start = w * S
            w_end   = w_start + 3

            X_slice = X[n_batch, h_start:h_end, w_start:w_end, :].transpose(2, 0, 1)
            X_col_row_index = n_batch * (H_out * W_out) + h * H_out + w
            X_col[X_col_row_index, :] = X_slice.reshape(1, -1)

W_col = W.transpose(2, 0, 1, 3).reshape(-1, 4)


#================== dY ==================
dY = np.ones((3,4,4,4))
dY_col = dY.reshape(48, 4)


#================== db ==================
db = np.sum(dY, axis=(0,1,2))


#================== dW ==================
dW_col = np.dot(X_col.T, dY_col)
dW = dW_col.reshape(3,3,3,4).transpose(1, 2, 0, 3)


#================== dX ==================
dX_col = np.dot(dY_col, W_col.T)
dX = np.zeros((3, 7+2*P, 7+2*P, 3))
for n_batch in range(3):
    for h in range(H_out):
        for w in range(W_out):
            h_start = h * S
            h_end   = h_start + 3
            w_start = w * S
            w_end   = w_start + 3

            dX_col_row_index = n_batch * (H_out * W_out) + h * H_out + w
            dX_col_slice = dX_col[dX_col_row_index, :]
            
            dX[n_batch, h_start:h_end, w_start:w_end, :] += dX_col_slice.reshape(3, 3, 3).transpose(1, 2, 0)

dX = dX[:, P:-P, P:-P, :] # unpad


#================== tf ==================
with tf.Session() as sess:
    tf_X = tf.constant(X_org.reshape(3, 7, 7, 3))
    tf_W = tf.Variable(W.reshape(3, 3, 3, 4))
    tf_b = tf.constant(b)
    tf_Y = tf.nn.conv2d(tf_X, tf_W, strides=[1, S, S, 1], padding='SAME') + tf_b
    tf_L = tf.reduce_sum(tf_Y)
    tf_grad = tf.gradients(tf_L, [tf_X, tf_W, tf_b])
    
    sess.run(tf.global_variables_initializer())
    tf_L_val = sess.run(tf_L)
    tf_dX, tf_dW, tf_db = sess.run(tf_grad)
#     print("=== L (tf) ===")     
#     print(tf_L_val)
#     print("=== dX (tf) ===")     
#     print(tf_dX[0, :, :, :])
#     print("=== dW (tf) ===")     
#     print(tf_dW[:, :, :, :])
#     print("=== db (tf) ===")     
#     print(tf_db)


print("=== Matched? ===")    
print("dX: ", np.all(dX == tf_dX[:, :, :, :]))
print("dY: ", np.all(dW == tf_dW[:, :, :, :]))
print("db: ", np.all(db == tf_db))

=== Matched? ===
dX:  True
dY:  True
db:  True


# Generalized efficient convolution backward using col2im

In [121]:
def conv_backward(dY, X, W, b, P=0, S=1):
    N_batch, H_in, W_in, C_in = X.shape
    H_filter, W_filter, _, C_out = W.shape
    _, H_out, W_out, _ = dY.shape
    
    if P > 0:
        X = np.pad(X, ((0, 0), (P, P), (P, P), (0, 0)), 'constant')
    
    #================== forward ==================
    X_col = np.zeros((N_batch * H_out * W_out, H_filter*W_filter*C_in))
    X_col_row_index = 0
    for n_batch in range(N_batch): # TODO: Maybe I can remove this loop over N_batch?
        for h in range(H_out):
            for w in range(W_out):
                h_start = h * S
                h_end   = h_start + H_filter
                w_start = w * S
                w_end   = w_start + W_filter

                X_slice = X[n_batch, h_start:h_end, w_start:w_end, :].transpose(2, 0, 1)
                X_col[X_col_row_index, :] = X_slice.reshape(1, -1)
                
                X_col_row_index += 1 # X_col_row_index = n_batch * (H_out * W_out) + h * W_out + w

    W_col = W.transpose(2, 0, 1, 3).reshape(-1, C_out)
    
    
    #================== dY ==================
    dY_col = dY.reshape(-1, C_out)
    
    #================== db ==================
    db = np.sum(dY, axis=(0,1,2))

    
    #================== dW ==================
    dW_col = np.dot(X_col.T, dY_col)
    dW = dW_col.reshape(C_in, H_filter, W_filter, C_out).transpose(1, 2, 0, 3)


    #================== dX ==================
    dX_col = np.dot(dY_col, W_col.T)
    dX = np.zeros((N_batch, H_in+2*P, W_in+2*P, C_in))
    dX_col_row_index = 0
    for n_batch in range(N_batch):
        for h in range(H_out):
            for w in range(W_out):
                h_start = h * S
                h_end   = h_start + H_filter
                w_start = w * S
                w_end   = w_start + W_filter

                dX_col_slice = dX_col[dX_col_row_index, :]
                dX[n_batch, h_start:h_end, w_start:w_end, :] += dX_col_slice.reshape(C_in, H_filter, W_filter).transpose(1, 2, 0)

                dX_col_row_index += 1 # dX_col_row_index = n_batch * (H_out * W_out) + h * W_out + w

    if P > 0:
        dX = dX[:, P:-P, P:-P, :]
    
    return (dX, dW, db)

In [122]:
P = 1
S = 2
X = np.random.randn(3, 7, 7, 3).astype(np.float32)
W = np.random.randn(3,3,3,4).astype(np.float32)
b = np.random.randn(4).astype(np.float32)

dY = np.ones((3,4,4,4))
dX, dW, db = conv_backward(dY, X, W, b, P, S)

with tf.Session() as sess:
    tf_X = tf.constant(X)
    tf_W = tf.Variable(W)
    tf_b = tf.constant(b)
    tf_Y = tf.nn.conv2d(tf_X, tf_W, strides=[1, S, S, 1], padding='SAME') + tf_b
    tf_L = tf.reduce_sum(tf_Y)
    tf_grad = tf.gradients(tf_L, [tf_X, tf_W, tf_b])
    
    sess.run(tf.global_variables_initializer())
    tf_L_val = sess.run(tf_L)
    tf_dX, tf_dW, tf_db = sess.run(tf_grad)
        
check_dX = np.linalg.norm(dX - tf_dX) / (np.linalg.norm(dX) + np.linalg.norm(tf_dX))
check_dW = np.linalg.norm(dW - tf_dW) / (np.linalg.norm(dW) + np.linalg.norm(tf_dW))
check_db = np.linalg.norm(db - tf_db) / (np.linalg.norm(db) + np.linalg.norm(tf_db))

print("=== Matched? ===")    
print("dX: ", check_dX < 1e-7, check_dX)
print("dW: ", check_dW < 1e-7, check_dW)
print("db: ", check_db < 1e-7, check_db)

=== Matched? ===
dX:  True 1.71279033933e-08
dW:  True 9.32435609026e-08
db:  True 0.0


# Benchmark

In [123]:
P = 1
S = 1
X = np.random.randn(128, 28, 28, 3).astype(np.float32)
W = np.random.randn(3, 3, 3, 16).astype(np.float32)
b = np.random.randn(16).astype(np.float32)


dY = np.ones((128, 28, 28, 16))
dX, dW, db = conv_backward(dY, X, W, b, P, S)

with tf.Session() as sess:
    tf_X = tf.constant(X)
    tf_W = tf.Variable(W)
    tf_b = tf.constant(b)
    tf_Y = tf.nn.conv2d(tf_X, tf_W, strides=[1, S, S, 1], padding='SAME') + tf_b
    tf_L = tf.reduce_sum(tf_Y)
    tf_grad = tf.gradients(tf_L, [tf_X, tf_W, tf_b])
    
    sess.run(tf.global_variables_initializer())
    tf_L_val = sess.run(tf_L)
    tf_dX, tf_dW, tf_db = sess.run(tf_grad)
        
check_dX = np.linalg.norm(dX - tf_dX) / (np.linalg.norm(dX) + np.linalg.norm(tf_dX))
check_dW = np.linalg.norm(dW - tf_dW) / (np.linalg.norm(dW) + np.linalg.norm(tf_dW))
check_db = np.linalg.norm(db - tf_db) / (np.linalg.norm(db) + np.linalg.norm(tf_db))

print("=== Matched? ===")    
print("dX: ", check_dX < 1e-7, check_dX)
print("dW: ", check_dW < 1e-7, check_dW)
print("db: ", check_db < 1e-7, check_db)

=== Matched? ===
dX:  True 6.78232892095e-08
dW:  False 4.23263412303e-07
db:  True 0.0


In [126]:
%%timeit -n3 -r3

conv_backward(dY, X, W, b, P, S)

693 ms ± 858 µs per loop (mean ± std. dev. of 3 runs, 3 loops each)


# Shape Test

In [127]:
P = 3
S = 1
X = np.random.randn(64, 11, 9, 16).astype(np.float32)
W = np.random.randn(7, 7, 16, 32).astype(np.float32)
b = np.random.randn(32).astype(np.float32)

dY = np.ones((64, 11, 9, 32))
dX, dW, db = conv_backward(dY, X, W, b, P, S)

with tf.Session() as sess:
    tf_X = tf.constant(X)
    tf_W = tf.constant(W)
    tf_b = tf.constant(b)
    tf_Y = tf.nn.conv2d(tf_X, tf_W, strides=[1, S, S, 1], padding='SAME') + tf_b
    tf_L = tf.reduce_sum(tf_Y)
    tf_grad = tf.gradients(tf_L, [tf_X, tf_W, tf_b])
    
    sess.run(tf.global_variables_initializer())
    tf_L_val = sess.run(tf_L)
    tf_grad_val = sess.run(tf_grad)
    
    tf_dX = tf_grad_val[0]
    tf_dW = tf_grad_val[1]
    tf_db = tf_grad_val[2]

check_dX = np.linalg.norm(dX - tf_dX) / (np.linalg.norm(dX) + np.linalg.norm(tf_dX))
check_dW = np.linalg.norm(dW - tf_dW) / (np.linalg.norm(dW) + np.linalg.norm(tf_dW))
check_db = np.linalg.norm(db - tf_db) / (np.linalg.norm(db) + np.linalg.norm(tf_db))

print(np.linalg.norm(dW - tf_dW))
print(np.linalg.norm(dW))
print(np.linalg.norm(tf_dW))

print("=== Matched? ===")    
print("dX: ", check_dX < 1e-6, check_dX)
print("dW: ", check_dW < 1e-6, check_dW)
print("db: ", check_db < 1e-6, check_db)

0.00315870523616
13703.7215824
13703.7
=== Matched? ===
dX:  True 9.91269930045e-08
dW:  True 1.15249913356e-07
db:  True 0.0


In [128]:
P = 0
S = 5
X = np.random.randn(4, 28, 46, 128).astype(np.float32)
W = np.random.randn(8, 16, 128, 2).astype(np.float32)
b = np.random.randn(2).astype(np.float32)

dY = np.ones((4, 5, 7, 2))
dX, dW, db = conv_backward(dY, X, W, b, P, S)

with tf.Session() as sess:
    tf_X = tf.constant(X)
    tf_W = tf.constant(W)
    tf_b = tf.constant(b)
    tf_Y = tf.nn.conv2d(tf_X, tf_W, strides=[1, S, S, 1], padding='VALID') + tf_b
    tf_L = tf.reduce_sum(tf_Y)
    tf_grad = tf.gradients(tf_L, [tf_X, tf_W, tf_b])
    
    sess.run(tf.global_variables_initializer())
    tf_L_val = sess.run(tf_L)
    tf_grad_val = sess.run(tf_grad)
    
    tf_dX = tf_grad_val[0]
    tf_dW = tf_grad_val[1]
    tf_db = tf_grad_val[2]

check_dX = np.linalg.norm(dX - tf_dX) / (np.linalg.norm(dX) + np.linalg.norm(tf_dX))
check_dW = np.linalg.norm(dW - tf_dW) / (np.linalg.norm(dW) + np.linalg.norm(tf_dW))
check_db = np.linalg.norm(db - tf_db) / (np.linalg.norm(db) + np.linalg.norm(tf_db))

print(np.linalg.norm(dW - tf_dW))
print(np.linalg.norm(dW))
print(np.linalg.norm(tf_dW))

print("=== Matched? ===")    
print("dX: ", check_dX < 1e-6, check_dX)
print("dW: ", check_dW < 1e-6, check_dW)
print("db: ", check_db < 1e-6, check_db)

0.000250347131214
2144.87415455
2144.87
=== Matched? ===
dX:  True 2.54974839091e-08
dW:  True 5.83594003252e-08
db:  True 0.0
