Currently, AveragePooling2D works when the input is "tiled" by the kernel/stride settings (e.g. a 2x2 kernel over a 4x4 input with stride 2), and this holds whether padding is SAME or VALID. It also works for more general cases when padding is VALID. However, when padding is SAME and the pool_size requires zero padding to complete the last pooling patch, the behavior in TFE diverges from TF. The discrepancy mirrors how the count_include_pad argument works in PyTorch.
In TFE, it's as if we've chosen count_include_pad=True, so that the zero padding gets included in our calculation of the average.
In TF, pooling with SAME padding works as if count_include_pad=False, so that the zero padding does not get included in the calculation of the average.
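To make the difference concrete, here is a minimal plain-NumPy sketch of both behaviors on a single 2D channel. It is not the TFE or TF implementation; the helper names `same_pad_amounts` and `avg_pool_same` are made up for illustration, and the padding split is assumed to follow TF's SAME convention (any extra padding goes at the bottom/right).

```python
import numpy as np

def same_pad_amounts(size, pool, stride):
    """Padding (before, after) along one dimension, assuming TF's SAME rule."""
    out = -(-size // stride)  # ceil(size / stride)
    total = max((out - 1) * stride + pool - size, 0)
    return total // 2, total - total // 2

def avg_pool_same(x, pool, stride, count_include_pad):
    """Average-pool a 2D array with SAME padding (illustrative helper)."""
    h, w = x.shape
    rows = same_pad_amounts(h, pool, stride)
    cols = same_pad_amounts(w, pool, stride)
    padded = np.pad(x, (rows, cols))              # zero padding
    valid = np.pad(np.ones_like(x), (rows, cols))  # 1 where real input, 0 where padding
    out_h, out_w = -(-h // stride), -(-w // stride)
    out = np.zeros((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            r, c = i * stride, j * stride
            window_sum = padded[r:r + pool, c:c + pool].sum()
            if count_include_pad:
                divisor = pool * pool  # padding counted (TFE today)
            else:
                divisor = valid[r:r + pool, c:c + pool].sum()  # padding ignored (TF)
            out[i, j] = window_sum / divisor
    return out

x = np.ones((3, 3))
tfe_like = avg_pool_same(x, pool=2, stride=2, count_include_pad=True)
tf_like = avg_pool_same(x, pool=2, stride=2, count_include_pad=False)
```

For a 3x3 input of ones with a 2x2 pool and stride 2, `tfe_like` comes out as `[[1, 0.5], [0.5, 0.25]]` while `tf_like` is all ones, so only the patches touching the padded edge disagree.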
Ideally, we'd just switch, but our current AveragePooling2D implements its average by summing over the correct pooling sections, constructing a public/private/masked tensor from the result of that sum, and then multiplying with the inverse of pool_height*pool_width. Because that divisor is the same for every patch, the zero padding (if it exists) is necessarily included in the average.
The only way to fix it within the current paradigm (still doing a private-public multiply after the sum operation) would be to separate the summed tensor into patches where the zero padding is included, multiply those with the inverse of (pool_height*pool_width - number_of_padded_zeros) instead, and then recombine them all into the correct output tensor with padding not included. The major problem here is that we'd likely lose any speed we gain from using im2col in the first place. A cleaner solution would be to do everything we need straight from the output of tf.extract_image_patches.
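As a rough illustration of the per-patch correction described above: the number of non-padded elements in each window depends only on the input shape, pool size, and stride, so it could be computed ahead of time as a public array. The NumPy sketch below assumes TF's SAME padding convention and uses hypothetical helpers (`same_pad_amounts`, `valid_counts`); whether an elementwise public multiply like this fits the current private-public multiply and im2col path in TFE is an open assumption, not something verified here.

```python
import numpy as np

def same_pad_amounts(size, pool, stride):
    """Padding (before, after) along one dimension, assuming TF's SAME rule."""
    out = -(-size // stride)  # ceil(size / stride)
    total = max((out - 1) * stride + pool - size, 0)
    return total // 2, total - total // 2

def valid_counts(in_h, in_w, pool, stride):
    """Count of non-padding elements in each pooling window (public, data-independent)."""
    rows = same_pad_amounts(in_h, pool, stride)
    cols = same_pad_amounts(in_w, pool, stride)
    mask = np.pad(np.ones((in_h, in_w)), (rows, cols))  # 0 marks zero padding
    out_h, out_w = -(-in_h // stride), -(-in_w // stride)
    counts = np.empty((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            r, c = i * stride, j * stride
            counts[i, j] = mask[r:r + pool, c:c + pool].sum()
    return counts

# Stand-in for the window sums the existing implementation already produces.
x = np.arange(9, dtype=float).reshape(3, 3)
padded = np.pad(x, ((0, 1), (0, 1)))               # SAME padding for pool=2, stride=2
sums = padded.reshape(2, 2, 2, 2).sum(axis=(1, 3))  # one sum per 2x2 window

# Multiply by a per-window inverse count instead of 1 / (pool_height * pool_width).
inverse_counts = 1.0 / valid_counts(3, 3, pool=2, stride=2)
averages = sums * inverse_counts
```

When the input is exactly tiled, `valid_counts` is pool_height*pool_width everywhere, so this reduces to the current behavior; only the edge windows get a different divisor.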
Definitely open to suggestions here! We should consider the trade-offs of overengineering this piece for complete compatibility versus documenting the discrepancy and advising users to stick with certain padding settings when doing pooling.
bendecoste added the bug label Aug 31, 2018