diff --git a/tfjs-core/src/ops/pool_test.ts b/tfjs-core/src/ops/pool_test.ts index acf682779d9..ef0ed6895c7 100644 --- a/tfjs-core/src/ops/pool_test.ts +++ b/tfjs-core/src/ops/pool_test.ts @@ -1489,8 +1489,7 @@ describeWithFlags('maxPoolWithArgmax', ALL_ENVS, () => { const padding = 0; - const {result, indexes} = - tf.maxPoolWithArgmax(x, [1, 1], [1, 1], padding); + const {result, indexes} = tf.maxPoolWithArgmax(x, [1, 1], [1, 1], padding); expectArraysClose(await result.data(), [0]); expectArraysClose(await indexes.data(), [0]); }); @@ -1509,8 +1508,7 @@ describeWithFlags('maxPoolWithArgmax', ALL_ENVS, () => { it('x=[2,2,2,1] f=[2,2,2] s=1 p=valid includeBatchInIndex=true', async () => { const x = tf.tensor4d([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2, 1]); - const {result, indexes} = - tf.maxPoolWithArgmax(x, 2, 1, 'valid', true); + const {result, indexes} = tf.maxPoolWithArgmax(x, 2, 1, 'valid', true); expect(result.shape).toEqual([2, 1, 1, 1]); expectArraysClose(await result.data(), [4, 8]); @@ -1563,15 +1561,16 @@ describeWithFlags('maxPoolWithArgmax', ALL_ENVS, () => { expect(indexes.shape).toEqual([2, 2, 2, 1]); expectArraysClose(await indexes.data(), [4, 5, 7, 7, 13, 14, 16, 17]); }); + it('[x=[1,3,3,1] f=[2,2] s=1 ignores NaNs', async () => { - const x = tf.tensor4d([1, 2, 3, 4, 5, 6, 7, NaN, 9], [1, 3, 3, 1]); + const x = tf.tensor4d([NaN, 1, 2, 3, 4, 5, 6, 7, 9], [1, 3, 3, 1]); const {result, indexes} = tf.maxPoolWithArgmax(x, 2, 1, 0); expect(result.shape).toEqual([1, 2, 2, 1]); - expectArraysClose(await result.data(), [5, 6, 7, 9]); + expectArraysClose(await result.data(), [4, 5, 7, 9]); expect(indexes.shape).toEqual([1, 2, 2, 1]); - expectArraysClose(await indexes.data(), [4, 5, 6, 8]); + expectArraysClose(await indexes.data(), [4, 5, 7, 8]); }); it('x=[1, 3,3,2] f=[2,2] s=1', async () => { @@ -1620,8 +1619,7 @@ describeWithFlags('maxPoolWithArgmax', ALL_ENVS, () => { }); it('throws when passed a non-tensor', () => { - expect( - () => tf.maxPoolWithArgmax({} as tf.Tensor4D, 2, 1, 'valid')) + expect(() => tf.maxPoolWithArgmax({} as tf.Tensor4D, 2, 1, 'valid')) .toThrowError( /Argument 'x' passed to 'maxPoolWithArgmax' must be a Tensor/); });