[WASM] Use XNNPack for div (#3480)
FEATURE
Co-authored-by: Ann Yuan <annyuan@gmail.com>
pvaneck committed Jun 25, 2020
1 parent 69dfce1 commit b7a28db
Showing 5 changed files with 10 additions and 12 deletions.
1 change: 0 additions & 1 deletion tfjs-backend-wasm/src/cc/BUILD
@@ -403,7 +403,6 @@ tfjs_cc_library(
     name = "Div",
     srcs = ["kernels/Div.cc"],
     deps = [
-        ":backend",
         ":binary",
         ":util",
     ],
9 changes: 5 additions & 4 deletions tfjs-backend-wasm/src/cc/kernels/Div.cc
@@ -16,9 +16,9 @@
 #include <emscripten.h>
 #endif
 
+#include <xnnpack.h>
 #include <cstddef>
 
-#include "src/cc/backend.h"
 #include "src/cc/binary.h"
 #include "src/cc/util.h"
 
@@ -40,10 +40,11 @@ EMSCRIPTEN_KEEPALIVE
 void Div(const size_t a_id, const size_t* a_shape_ptr, const size_t a_shape_len,
          const size_t b_id, const size_t* b_shape_ptr, const size_t b_shape_len,
          const DType dtype, const size_t out_id) {
-  auto& a_info = backend::get_tensor_info(a_id);
   switch (dtype) {
     case DType::float32:
-      binary_f32(a_id, b_id, out_id, div<float>);
+      binary_xnn_f32(a_id, a_shape_ptr, a_shape_len, b_id, b_shape_ptr,
+                     b_shape_len, out_id, xnn_create_divide_nd_f32,
+                     xnn_setup_divide_nd_f32);
       break;
     case DType::int32:
       binary_i32(a_id, b_id, out_id, div<int32_t>);
@@ -52,7 +53,7 @@ void Div(const size_t a_id, const size_t* a_shape_ptr, const size_t a_shape_len,
       binary_bool(a_id, b_id, out_id, div<bool>);
       break;
     default:
-      util::warn("Mul for tensor ids %d and %d failed. Unknown dtype %d", a_id,
+      util::warn("Div for tensor ids %d and %d failed. Unknown dtype %d", a_id,
                  b_id, dtype);
   }
 }
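
For context, here is a minimal standalone sketch of the XNNPack divide-operator lifecycle that a helper like binary_xnn_f32 presumably wraps: create the operator once, set it up against the two (broadcastable) input shapes, run it, delete it. The xnn_*_divide_nd_f32 signatures below match the XNNPack API of this era; the actual helper in src/cc/binary.cc may cache operators and handle errors differently, so treat this as an illustration rather than the shipped code.

    // Sketch only: standalone XNNPack N-D float32 divide with broadcasting,
    // assuming the circa-2020 xnn_*_divide_nd_f32 API used by this commit.
    #include <xnnpack.h>
    #include <cstddef>
    #include <cstdio>
    #include <limits>
    #include <vector>

    int main() {
      if (xnn_initialize(/*allocator=*/nullptr) != xnn_status_success) {
        std::fprintf(stderr, "xnn_initialize failed\n");
        return 1;
      }

      // a has shape [2, 3]; b has shape [1, 3] and broadcasts against a,
      // which is the kind of full broadcast this change enables for Div.
      const std::vector<size_t> a_shape = {2, 3};
      const std::vector<size_t> b_shape = {1, 3};
      const std::vector<float> a = {2, 4, 6, 8, 10, 12};
      const std::vector<float> b = {2, 2, 2};
      std::vector<float> out(6);

      xnn_operator_t div_op = nullptr;
      // No output clamping: allow the full float range.
      xnn_create_divide_nd_f32(-std::numeric_limits<float>::infinity(),
                               std::numeric_limits<float>::infinity(),
                               /*flags=*/0, &div_op);
      xnn_setup_divide_nd_f32(div_op, a_shape.size(), a_shape.data(),
                              b_shape.size(), b_shape.data(), a.data(),
                              b.data(), out.data(), /*threadpool=*/nullptr);
      xnn_run_operator(div_op, /*threadpool=*/nullptr);
      xnn_delete_operator(div_op);
      // out now holds {1, 2, 3, 4, 5, 6}.
      return 0;
    }
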
2 changes: 1 addition & 1 deletion tfjs-backend-wasm/src/kernels/Div.ts
@@ -16,5 +16,5 @@
  */
 
 import {registerBinaryKernel} from './binary_kernel';
-const supportsFullBroadcast = false;
+const supportsFullBroadcast = true;
 registerBinaryKernel('Div', supportsFullBroadcast);
5 changes: 3 additions & 2 deletions tfjs-backend-wasm/src/kernels/binary_kernel.ts
@@ -63,7 +63,8 @@ export function registerBinaryKernel(
         aId, aShapeBytes, a.shape.length, bId, bShapeBytes, b.shape.length,
         CppDType[a.dtype], outId);
 
-    if (supportsFullBroadcast) {
+    // Currently only some float operations support full broadcast.
+    if (supportsFullBroadcast && a.dtype === 'float32') {
       kernelFunc();
       return out;
     }
@@ -78,7 +79,7 @@ export function registerBinaryKernel(
     } else {
       throw new Error(
           `Broadcasting along outer dims is not yet ` +
-          `supported for ${kernelName}.`);
+          `supported for ${a.dtype} ${kernelName}.`);
     }
   }
 
5 changes: 1 addition & 4 deletions tfjs-backend-wasm/src/setup_test.ts
@@ -179,10 +179,7 @@ const TEST_FILTERS: TestFilter[] = [
     excludes: [
       'gradient',  // Gradient not defined yet.
       'upcasts',   // Cast not supported yet.
-      'broadcasting same rank Tensors different shape',  // Broadcasting along
-                                                          // inner dims not
-                                                          // supported yet.
-      'divNoNan'  // divNoNan not yet implemented.
+      'divNoNan'  // divNoNan not yet implemented.
     ]
   },
   {
