@@ -370,6 +370,33 @@ def testDepthwiseMaxPool2x2DepthWindow3(self):
370370 expected = [3.0 , 6.0 , 9.0 , 12.0 , 15.0 , 18.0 , 21.0 , 24.0 ],
371371 use_gpu = False )
372372
373+ def testKernelSmallerThanStride (self ):
374+ for use_gpu in [True , False ]:
375+ self ._VerifyValues (tf .nn .max_pool , input_sizes = [1 , 3 , 3 , 1 ],
376+ ksize = [1 , 1 , 1 , 1 ], strides = [1 , 2 , 2 , 1 ],
377+ padding = "SAME" ,
378+ expected = [1 , 3 , 7 , 9 ],
379+ use_gpu = use_gpu )
380+
381+ self ._VerifyValues (tf .nn .max_pool , input_sizes = [1 , 7 , 7 , 1 ],
382+ ksize = [1 , 2 , 2 , 1 ], strides = [1 , 3 , 3 , 1 ],
383+ padding = "VALID" ,
384+ expected = [9 , 12 , 30 , 33 ],
385+ use_gpu = use_gpu )
386+
387+ self ._VerifyValues (tf .nn .avg_pool , input_sizes = [1 , 3 , 3 , 1 ],
388+ ksize = [1 , 1 , 1 , 1 ], strides = [1 , 2 , 2 , 1 ],
389+ padding = "SAME" ,
390+ expected = [1 , 3 , 7 , 9 ],
391+ use_gpu = use_gpu )
392+
393+ self ._VerifyValues (tf .nn .avg_pool , input_sizes = [1 , 7 , 7 , 1 ],
394+ ksize = [1 , 2 , 2 , 1 ], strides = [1 , 3 , 3 , 1 ],
395+ padding = "VALID" ,
396+ expected = [5 , 8 , 26 , 29 ],
397+ use_gpu = use_gpu )
398+
399+
373400 def _testDepthwiseMaxPoolInvalidConfig (self , in_size , ksize , strides ,
374401 error_msg , use_gpu = False ):
375402 t = tf .constant (1.0 , shape = in_size )
@@ -885,20 +912,6 @@ def testShapeFunctionEdgeCases(self):
885912 shape = [32 , 20 , 20 , 3 ]),
886913 ksize = [1 , 21 , 20 , 1 ], strides = [1 , 1 , 1 , 1 ], padding = "SAME" )
887914
888- # Stride larger than filter.
889- for pool_func in [tf .nn .max_pool , tf .nn .avg_pool ,
890- tf .nn .max_pool_with_argmax ]:
891- with self .assertRaisesRegexp (
892- ValueError , "stride must be less than or equal to filter" ):
893- pool_func (tf .placeholder (tf .float32 ,
894- shape = [32 , 20 , 20 , 3 ]),
895- ksize = [1 , 5 , 3 , 1 ], strides = [1 , 5 , 5 , 1 ], padding = "SAME" )
896- with self .assertRaisesRegexp (
897- ValueError , "stride must be less than or equal to filter" ):
898- pool_func (tf .placeholder (tf .float32 ,
899- shape = [32 , 20 , 20 , 3 ]),
900- ksize = [1 , 3 , 5 , 1 ], strides = [1 , 5 , 5 , 1 ], padding = "SAME" )
901-
902915
903916def GetMaxPoolFwdTest (input_size , filter_size , strides , padding ):
904917 def Test (self ):
0 commit comments