diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts index d08ba5f4f27..2e257e6f90a 100644 --- a/tfjs-backend-wasm/src/setup_test.ts +++ b/tfjs-backend-wasm/src/setup_test.ts @@ -224,17 +224,17 @@ const TEST_FILTERS: TestFilter[] = [ 'broadcasting Tensor2D shapes', // Broadcasting along inner dims not // supported yet. 'broadcasting Tensor3D shapes', // Same as above. - 'broadcasting Tensor4D shapes' // Same as above. + 'broadcasting Tensor4D shapes' // Same as above. ] }, { include: 'greaterEqual', excludes: [ - 'gradient', // Not yet implemented. + 'gradient', // Not yet implemented. 'broadcasting Tensor2D shapes', // Broadcasting along inner dims not // supported yet. 'broadcasting Tensor3D shapes', // Same as above. - 'broadcasting Tensor4D shapes' // Same as above. + 'broadcasting Tensor4D shapes' // Same as above. ] }, { @@ -243,13 +243,7 @@ const TEST_FILTERS: TestFilter[] = [ 'axis=0', // Reduction not supported along inner dimensions. ] }, - { - startsWith: 'sum ', - excludes: [ - 'axis=0', // Reduction not supported along inner dimensions. - 'axis=[-1,-2]', // Reduction not supported along inner dimensions. - ] - } + {startsWith: 'sum '} ]; const customInclude = (testName: string) => { diff --git a/tfjs-core/src/ops/reduction_ops.ts b/tfjs-core/src/ops/reduction_ops.ts index a275b830bef..37d8f0dfa8d 100644 --- a/tfjs-core/src/ops/reduction_ops.ts +++ b/tfjs-core/src/ops/reduction_ops.ts @@ -135,7 +135,7 @@ function sum_( return {x: () => gradFunc(dy)}; }; - const attrs = {axes}; + const attrs = {axes: reductionAxes}; let value = ENGINE.runKernelFunc( backend => backend.sum(permutedX, reductionAxes), {x: permutedX}, gradInputs, 'Sum', attrs);