Skip to content

Commit

Permalink
webnn: Migrate argMinMax validation tests to WPTs
Browse files Browse the repository at this point in the history
This CL adds WPT tests for argMinMax and removes the unit tests
`MLGraphBuilderTest.ArgMinMaxTest` and `MLGraphTestMojo.ArgMinMaxTest`.

Bug: 327337526, 328026885
Change-Id: I8fcce0db9e4b1673e721baf0b7abdfe3a06bdf2b
  • Loading branch information
mei1127 authored and chromium-wpt-export-bot committed May 29, 2024
1 parent fd2e974 commit 6f76cd1
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 11 deletions.
21 changes: 11 additions & 10 deletions webnn/resources/utils_validation.js
Original file line number Diff line number Diff line change
Expand Up @@ -306,26 +306,27 @@ function validateOptionsAxes(operationName) {
}
}, `[${subOperationName}] TypeError is expected if any of options.axes elements is not an unsigned long interger`);

// DataError is expected if any of options.axes elements is greater or equal to the size of input
// TypeError is expected if any of options.axes elements is greater or equal
// to the size of input
promise_test(async t => {
for (let dataType of allWebNNOperandDataTypes) {
for (let dimensions of allWebNNDimensionsArray) {
const rank = getRank(dimensions);
if (rank >= 1) {
const input =
builder.input(`input${++inputIndex}`, {dataType, dimensions});
assert_throws_dom(
'DataError',
assert_throws_js(
TypeError,
() => builder[subOperationName](input, {axes: [rank]}));
assert_throws_dom(
'DataError',
assert_throws_js(
TypeError,
() => builder[subOperationName](input, {axes: [rank + 1]}));
}
}
}
}, `[${subOperationName}] DataError is expected if any of options.axes elements is greater or equal to the size of input`);
}, `[${subOperationName}] TypeError is expected if any of options.axes elements is greater or equal to the size of input`);

// DataError is expected if two or more values are same in the axes sequence
// TypeError is expected if two or more values are same in the axes sequence
promise_test(async t => {
for (let dataType of allWebNNOperandDataTypes) {
for (let dimensions of allWebNNDimensionsArray) {
Expand All @@ -336,13 +337,13 @@ function validateOptionsAxes(operationName) {
const axesArrayContainSameValues =
getAxesArrayContainSameValues(dimensions);
for (let axes of axesArrayContainSameValues) {
assert_throws_dom(
'DataError', () => builder[subOperationName](input, {axes}));
assert_throws_js(
TypeError, () => builder[subOperationName](input, {axes}));
}
}
}
}
}, `[${subOperationName}] DataError is expected if two or more values are same in the axes sequence`);
}, `[${subOperationName}] TypeError is expected if two or more values are same in the axes sequence`);
}
}

Expand Down
112 changes: 111 additions & 1 deletion webnn/validation_tests/argMinMax.https.any.js
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,117 @@ const kArgMinMaxOperators = [
'argMax',
];

const tests = [
{
name: '[argMin/Max] Test with default options.',
input: {dataType: 'float32', dimensions: [1, 2, 3, 4]},
output: {dataType: 'float32', dimensions: []}
},
{
name: '[argMin/Max] Test with axes=[].',
input: {dataType: 'float32', dimensions: [1, 2, 3, 4]},
options: {
axes: [],
},
output: {dataType: 'float32', dimensions: [1, 2, 3, 4]}
},
{
name: '[argMin/Max] Test scalar input with empty axes.',
input: {dataType: 'float32', dimensions: []},
options: {
axes: [],
},
output: {dataType: 'float32', dimensions: []}
},
{
name: '[argMin/Max] Test with axes=[1].',
input: {dataType: 'float32', dimensions: [1, 2, 3, 4]},
options: {
axes: [1],
},
output: {dataType: 'float32', dimensions: [1, 3, 4]}
},
{
name: '[argMin/Max] Test with axes=[1, 3] and keepDimensions=true.',
input: {dataType: 'float32', dimensions: [1, 2, 3, 4]},
options: {
axes: [1, 3],
keepDimensions: true,
},
output: {dataType: 'float32', dimensions: [1, 1, 3, 1]}
},
{
name: '[argMin/Max] Test with axes=[1, 3] and keepDimensions=false.',
input: {dataType: 'float32', dimensions: [1, 2, 3, 4]},
options: {
axes: [1, 3],
keepDimensions: false,
},
output: {dataType: 'float32', dimensions: [1, 3]}
},
{
name: '[argMin/Max] Test with axes=[1] and selectLastIndex=true.',
input: {dataType: 'float32', dimensions: [1, 2, 3, 4]},
options: {
axes: [1],
selectLastIndex: true,
},
output: {dataType: 'float32', dimensions: [1, 3, 4]}
},
{
name: '[argMin/Max] Test with axes=[1] and selectLastIndex=false.',
input: {dataType: 'float32', dimensions: [1, 2, 3, 4]},
options: {
axes: [1],
selectLastIndex: false,
},
output: {dataType: 'float32', dimensions: [1, 3, 4]}
},
{
name:
'[argMin/Max] Throw if the value in axes is greater than or equal to input rank.',
input: {dataType: 'float32', dimensions: [1, 2, 3, 4]},
options: {
axes: [4],
},
},
{
name:
'[argMin/Max] Throw if two or more values are same in the axes sequence.',
input: {dataType: 'float32', dimensions: [1, 2, 3, 4]},
options: {
axes: [1, 1],
},
},
{
name: '[argMin/Max] Throw if input is a scalar and axes is non-empty.',
input: {dataType: 'float32', dimensions: []},
options: {
axes: [1],
},
},
];

function runTests(operatorName, tests) {
tests.forEach(test => {
promise_test(async t => {
const input = builder.input(
'input',
{dataType: test.input.dataType, dimensions: test.input.dimensions});

if (test.output) {
const output = builder[operatorName](input, test.options);
assert_equals(output.dataType(), 'int64');
assert_array_equals(output.shape(), test.output.dimensions);
} else {
assert_throws_js(
TypeError, () => builder[operatorName](input, test.options));
}
}, test.name.replace('[argMin/Max]', `[${operatorName}]`));
});
}

kArgMinMaxOperators.forEach((operatorName) => {
validateOptionsAxes(operatorName);
validateInputFromAnotherBuilder(operatorName);
runTests(operatorName, tests);
});

0 comments on commit 6f76cd1

Please sign in to comment.