diff --git a/webnn/validation_tests/reduction.https.any.js b/webnn/validation_tests/reduction.https.any.js index 7da6b24dcf0583..e9643c0f8a5917 100644 --- a/webnn/validation_tests/reduction.https.any.js +++ b/webnn/validation_tests/reduction.https.any.js @@ -24,15 +24,18 @@ const kFloatRestrictReductionOperators = [ 'reduceMean', ]; -const kFloatAllowReductionOperators = [ +const kFloatInt32Uint32RestrictReductionOperators = [ 'reduceL1', - 'reduceMax', - 'reduceMin', 'reduceProduct', 'reduceSum', 'reduceSumSquare', ]; +const kNoTypeRestrictReductionOperators = [ + 'reduceMax', + 'reduceMin', +]; + const allReductionOperatorsTests = [ { name: '[reduce] Test reduce with default options.', @@ -76,8 +79,7 @@ const allReductionOperatorsTests = [ const kFloatRestrictOperatorsTests = [ { - name: - '[reduce] Throw if the input data type is not one of the floating point.', + name: '[reduce] Throw if the input data type is int32.', input: {dataType: 'int32', dimensions: [1, 2, 5, 5]}, options: { axes: [0, 1], @@ -85,13 +87,48 @@ const kFloatRestrictOperatorsTests = [ }, ]; -const kFloatAllowOperatorsTests = [ +const kFloatInt32Uint32RestrictOperatorsTests = [ { - name: - '[reduce] Test when the input data type is not one of the floating point.', - input: {dataType: 'int32', dimensions: [1, 3, 4, 4]}, + name: '[reduce] Test reduce when input\'s datatype is int32.', + input: {dataType: 'int32', dimensions: [1, 2, 5, 5]}, output: {dataType: 'int32', dimensions: []} }, + { + name: '[reduce] Test reduce when input\'s datatype is uint32.', + input: {dataType: 'uint32', dimensions: [1, 2, 5, 5]}, + output: {dataType: 'uint32', dimensions: []} + }, + { + name: + '[reduce] Throw if the input data type is not one of the {float32, float16, int32, uint32}.', + input: {dataType: 'int64', dimensions: [1, 2, 5, 5]}, + options: { + axes: [0, 1], + }, + }, +]; + +const kNoTypeRestrictOperatorsTests = [ + { + name: '[reduce] Test reduce when input\'s datatype is int64.', + input: {dataType: 'int64', dimensions: [1, 3, 4, 4]}, + output: {dataType: 'int64', dimensions: []} + }, + { + name: '[reduce] Test reduce when input\'s datatype is uint64.', + input: {dataType: 'uint64', dimensions: [1, 3, 4, 4]}, + output: {dataType: 'uint64', dimensions: []} + }, + { + name: '[reduce] Test reduce when input\'s datatype is int8.', + input: {dataType: 'int8', dimensions: [1, 3, 4, 4]}, + output: {dataType: 'int8', dimensions: []} + }, + { + name: '[reduce] Test reduce when input\'s datatype is uint8.', + input: {dataType: 'uint8', dimensions: [1, 3, 4, 4]}, + output: {dataType: 'uint8', dimensions: []} + }, ]; function runReductionTests(operatorName, tests) { @@ -122,6 +159,10 @@ kFloatRestrictReductionOperators.forEach((operatorName) => { runReductionTests(operatorName, kFloatRestrictOperatorsTests); }); -kFloatAllowReductionOperators.forEach((operatorName) => { - runReductionTests(operatorName, kFloatAllowOperatorsTests); +kFloatInt32Uint32RestrictReductionOperators.forEach((operatorName) => { + runReductionTests(operatorName, kFloatInt32Uint32RestrictOperatorsTests); +}); + +kNoTypeRestrictReductionOperators.forEach((operatorName) => { + runReductionTests(operatorName, kNoTypeRestrictOperatorsTests); });