From bcb4ce2e034226159b32751aa1f46d7aee159bfc Mon Sep 17 00:00:00 2001 From: lisa0314 Date: Fri, 10 May 2024 01:16:16 -0700 Subject: [PATCH] webnn: Enforce input data type constraints for some reduce operators As specified in https://github.com/webmachinelearning/webnn/pull/646 Bug: 328567884 Change-Id: I5e600bdc791ecd4530408b65c15dd20611c211a3 Cq-Include-Trybots: luci.chromium.try:mac14-blink-rel,mac14.arm64-blink-rel --- webnn/validation_tests/reduction.https.any.js | 63 +++++++++++++++---- 1 file changed, 52 insertions(+), 11 deletions(-) diff --git a/webnn/validation_tests/reduction.https.any.js b/webnn/validation_tests/reduction.https.any.js index 7da6b24dcf0583..e9643c0f8a5917 100644 --- a/webnn/validation_tests/reduction.https.any.js +++ b/webnn/validation_tests/reduction.https.any.js @@ -24,15 +24,18 @@ const kFloatRestrictReductionOperators = [ 'reduceMean', ]; -const kFloatAllowReductionOperators = [ +const kFloatInt32Uint32RestrictReductionOperators = [ 'reduceL1', - 'reduceMax', - 'reduceMin', 'reduceProduct', 'reduceSum', 'reduceSumSquare', ]; +const kNoTypeRestrictReductionOperators = [ + 'reduceMax', + 'reduceMin', +]; + const allReductionOperatorsTests = [ { name: '[reduce] Test reduce with default options.', @@ -76,8 +79,7 @@ const allReductionOperatorsTests = [ const kFloatRestrictOperatorsTests = [ { - name: - '[reduce] Throw if the input data type is not one of the floating point.', + name: '[reduce] Throw if the input data type is int32.', input: {dataType: 'int32', dimensions: [1, 2, 5, 5]}, options: { axes: [0, 1], @@ -85,13 +87,48 @@ const kFloatRestrictOperatorsTests = [ }, ]; -const kFloatAllowOperatorsTests = [ +const kFloatInt32Uint32RestrictOperatorsTests = [ { - name: - '[reduce] Test when the input data type is not one of the floating point.', - input: {dataType: 'int32', dimensions: [1, 3, 4, 4]}, + name: '[reduce] Test reduce when input\'s datatype is int32.', + input: {dataType: 'int32', dimensions: [1, 2, 5, 5]}, output: {dataType: 'int32', dimensions: []} }, + { + name: '[reduce] Test reduce when input\'s datatype is uint32.', + input: {dataType: 'uint32', dimensions: [1, 2, 5, 5]}, + output: {dataType: 'uint32', dimensions: []} + }, + { + name: + '[reduce] Throw if the input data type is not one of the {float32, float16, int32, uint32}.', + input: {dataType: 'int64', dimensions: [1, 2, 5, 5]}, + options: { + axes: [0, 1], + }, + }, +]; + +const kNoTypeRestrictOperatorsTests = [ + { + name: '[reduce] Test reduce when input\'s datatype is int64.', + input: {dataType: 'int64', dimensions: [1, 3, 4, 4]}, + output: {dataType: 'int64', dimensions: []} + }, + { + name: '[reduce] Test reduce when input\'s datatype is uint64.', + input: {dataType: 'uint64', dimensions: [1, 3, 4, 4]}, + output: {dataType: 'uint64', dimensions: []} + }, + { + name: '[reduce] Test reduce when input\'s datatype is int8.', + input: {dataType: 'int8', dimensions: [1, 3, 4, 4]}, + output: {dataType: 'int8', dimensions: []} + }, + { + name: '[reduce] Test reduce when input\'s datatype is uint8.', + input: {dataType: 'uint8', dimensions: [1, 3, 4, 4]}, + output: {dataType: 'uint8', dimensions: []} + }, ]; function runReductionTests(operatorName, tests) { @@ -122,6 +159,10 @@ kFloatRestrictReductionOperators.forEach((operatorName) => { runReductionTests(operatorName, kFloatRestrictOperatorsTests); }); -kFloatAllowReductionOperators.forEach((operatorName) => { - runReductionTests(operatorName, kFloatAllowOperatorsTests); +kFloatInt32Uint32RestrictReductionOperators.forEach((operatorName) => { + runReductionTests(operatorName, kFloatInt32Uint32RestrictOperatorsTests); +}); + +kNoTypeRestrictReductionOperators.forEach((operatorName) => { + runReductionTests(operatorName, kNoTypeRestrictOperatorsTests); });