Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions tfjs-backend-wasm/src/backend_wasm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,18 @@ export class BackendWasm extends KernelBackend {

constructor(public wasm: BackendWasmModule) {
  super();
  // Register a JS callback that the wasm backend invokes to surface C++
  // errors as JavaScript exceptions. The C++ side passes an offset into the
  // wasm heap plus the message length ('vii' = void fn(int, int)).
  // NOTE: the stale zero-arg `this.wasm.tfjs.init();` call that preceded
  // this change has been removed — init must be called exactly once, with
  // the callback pointer.
  const newFuncPtr =
      this.wasm.addFunction((strOffset: number, strSize: number) => {
        // Copy the raw bytes out of the heap and decode the ASCII error
        // string back into a JavaScript string to throw an error.
        const heapValue =
            this.wasm.HEAPU8.slice(strOffset, strOffset + strSize);
        const str =
            Array.from(heapValue).map(x => String.fromCharCode(x)).join('');
        throw new Error(str);
      }, 'vii');
  this.wasm.tfjs.init(newFuncPtr);

  this.dataIdMap = new DataStorage(this, engine());
}

Expand Down Expand Up @@ -163,7 +174,11 @@ async function init(): Promise<{wasm: BackendWasmModule}> {
const voidReturnType: string = null;
// Using the tfjs namespace to avoid conflict with emscripten's API.
wasm.tfjs = {
init: wasm.cwrap('init', null, []),
init: wasm.cwrap(
'init', null,
[
'number', // throwFnPointer
]),
registerTensor: wasm.cwrap(
'register_tensor', null,
[
Expand Down
2 changes: 2 additions & 0 deletions tfjs-backend-wasm/src/cc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ cc_binary(
"-s MODULARIZE=1",
"-s EXPORT_NAME=WasmBackendModule",
"-s MALLOC=emmalloc",
"-s RESERVED_FUNCTION_POINTERS=20",
"-s EXTRA_EXPORTED_RUNTIME_METHODS=\"[\'addFunction\', \'removeFunction\', \'cwrap\']\"",
],
deps = [
":backend",
Expand Down
35 changes: 34 additions & 1 deletion tfjs-backend-wasm/src/cc/backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@
#endif

#include <xnnpack.h>
#include <cstring>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "src/cc/backend.h"
#include "src/cc/util.h"

namespace {
// Maps a unique tensor id to info about that tensor. The map owns all of its
Expand All @@ -32,6 +35,10 @@ std::unordered_map<int, TensorInfo> data;
// id.
std::unordered_map<int, std::vector<tfjs::backend::DisposeFunction>>
disposal_callbacks;

// Holds the function pointer to call back into JavaScript to throw an
// exception.
void (*throw_js_exception_fn)(char *str_location, int str_length);
} // namespace

namespace tfjs {
Expand Down Expand Up @@ -60,6 +67,27 @@ void register_disposal_callback(const int tensor_id,

const int num_tensors() { return data.size(); }

// Formats a printf-style message and hands it to the registered JavaScript
// callback, which throws it as a JS Error. Because the callback throws,
// control normally does not return past throw_js_exception_fn.
void throw_js_exception(const char *format, ...) {
  va_list args;
  va_start(args, format);

  // buffer_size (vsnprintf) consumes its va_list, so measure the formatted
  // length on a copy and keep `args` intact for the real formatting pass.
  // (The original reused `args` after buffer_size — undefined behavior.)
  va_list args_copy;
  va_copy(args_copy, args);
  int size = tfjs::util::buffer_size(format, args_copy);
  va_end(args_copy);
  if (size < 0) size = 0;  // Encoding error: fall back to an empty message.

  // +1 for the '\0': vsnprintf's count excludes the terminator, so the old
  // `new char[size]` + vsprintf overflowed the buffer by one byte. Debug
  // log calls that echoed the message have also been removed.
  char *cstr = new char[size + 1];
  vsnprintf(cstr, size + 1, format, args);
  va_end(args);

  throw_js_exception_fn(cstr, size);

  // Reached only if the callback returns instead of throwing.
  delete[] cstr;
}

// Registers the callback used by throw_js_exception to raise errors on the
// JavaScript side. Called from wasm::init (with a function-table pointer
// created by Emscripten's addFunction) and directly by native unit tests.
void set_throw_js_exception_fn(void (*f)(char *, int)) {
throw_js_exception_fn = f;
}

} // namespace backend

namespace wasm {
Expand All @@ -69,7 +97,12 @@ extern "C" {
#ifdef __EMSCRIPTEN__
EMSCRIPTEN_KEEPALIVE
#endif
void init() { xnn_initialize(); }
// Initializes the wasm backend.
// `throw_js_exception_ptr` is a wasm function-table index (created on the
// JS side via Emscripten's addFunction) for the callback that throws a JS
// exception. Native unit tests pass 0, in which case the callback must
// never actually fire.
void init(int throw_js_exception_ptr) {
// NOTE(review): xnn_initialize()'s return status is ignored — confirm that
// failure is impossible/acceptable here, since kernels assume XNNPACK is up.
xnn_initialize();

// Store the callback so backend::throw_js_exception can call back into JS.
tfjs::backend::set_throw_js_exception_fn(
reinterpret_cast<void (*)(char *, int)>(throw_js_exception_ptr));
}

#ifdef __EMSCRIPTEN__
EMSCRIPTEN_KEEPALIVE
Expand Down
6 changes: 5 additions & 1 deletion tfjs-backend-wasm/src/cc/backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,18 @@ void register_disposal_callback(int tensor_id, DisposeFunction dispose_fn);
// Returns the number of tensors registered and owned by the backend.
const int num_tensors();

// Throws a JavaScript exception.
void throw_js_exception(const char *format, ...);
void set_throw_js_exception_fn(void (*f)(char *, int));

// The number of instantiated XNN operators.
extern int xnn_operator_count;
} // namespace backend

namespace wasm {
extern "C" {
// Initializes the WASM backend.
void init();
void init(int fp);

// Registers a tensor with a tensor ID, size, and the pointer to where the
// tensor data lives.
Expand Down
39 changes: 18 additions & 21 deletions tfjs-backend-wasm/src/cc/backend_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,14 @@

#include <gtest/gtest.h>

#include <string.h>

#include "src/cc/backend.h"
#include "src/cc/kernels/Prelu.h"
#include "src/cc/util.h"

TEST(BACKEND, register_tensor) {
tfjs::wasm::init();
tfjs::wasm::init(0);

ASSERT_EQ(0, tfjs::backend::num_tensors());

Expand Down Expand Up @@ -55,7 +58,7 @@ void fake_dispose_tensor_callback(int tensor_id) {
}
}
TEST(BACKEND, disposal_callback) {
tfjs::wasm::init();
tfjs::wasm::init(0);

ASSERT_EQ(0, tfjs::backend::num_tensors());

Expand Down Expand Up @@ -92,28 +95,22 @@ TEST(BACKEND, disposal_callback) {
tensor_1_callback_count = 0;
}

TEST(BACKEND, dispose_backend) {
tfjs::wasm::init();

ASSERT_EQ(0, tfjs::backend::num_tensors());
// Captures the last error reported through the fake exception hook so tests
// can assert on it. 200 chars is longer than the longest error message used
// (the old comment said "100", which contradicted the buffer size).
char last_str[200];
int last_str_size = -1;

// Test double for the JS-exception callback: records the message instead of
// throwing. Uses a bounded snprintf copy so an oversized message truncates
// rather than overflowing last_str (the original strcpy was unbounded).
void throw_js_exception_fake(char* str, int str_size) {
  snprintf(last_str, sizeof(last_str), "%s", str);
  last_str_size = str_size;
}

const int tensor_id_0 = 0;
const int tensor_id_1 = 1;
const int size = 2;
float values_0[size] = {1, 2};
float values_1[size] = {3, 4};
TEST(BACKEND, throw_js_exception) {
tfjs::backend::set_throw_js_exception_fn(&throw_js_exception_fake);

tfjs::wasm::register_tensor(tensor_id_0, size, values_0);
tfjs::wasm::register_tensor(tensor_id_1, size, values_1);
ASSERT_EQ(2, tfjs::backend::num_tensors());
ASSERT_EQ(0, tfjs::backend::xnn_operator_count);
ASSERT_STREQ(last_str, "");
ASSERT_EQ(last_str_size, -1);

// One new xnn_operator should be created for the first call to prelu.
tfjs::wasm::Prelu(tensor_id_0, tensor_id_0, tensor_id_1);
ASSERT_EQ(1, tfjs::backend::xnn_operator_count);
tfjs::backend::throw_js_exception("fake error message %d", 22);
ASSERT_STREQ("fake error message 22", last_str);

// Dispose removes all tensors and xnn operators.
tfjs::wasm::dispose();
ASSERT_EQ(0, tfjs::backend::num_tensors());
ASSERT_EQ(0, tfjs::backend::xnn_operator_count);
}
6 changes: 4 additions & 2 deletions tfjs-backend-wasm/src/cc/kernels/Add.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ EMSCRIPTEN_KEEPALIVE
#endif
void Add(const int a_id, const int b_id, const DType dtype, const int out_id) {
auto& a_info = backend::get_tensor_info(a_id);

switch (dtype) {
case DType::float32:
binary_f32(a_id, b_id, out_id, add<float>);
Expand All @@ -48,8 +49,9 @@ void Add(const int a_id, const int b_id, const DType dtype, const int out_id) {
binary_bool(a_id, b_id, out_id, add<bool>);
break;
default:
util::warn("Add for tensor ids %d and %d failed. Unknown dtype %d", a_id,
b_id, dtype);
tfjs::backend::throw_js_exception(
"Add for tensor ids %d and %d failed. Unknown dtype %d", a_id, b_id,
dtype);
}
}

Expand Down
4 changes: 2 additions & 2 deletions tfjs-backend-wasm/src/cc/kernels/Conv2D.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ void Conv2D(const int x_id, const int batch_size, const int input_height,
output_channels /* output_pixel_stride */, transposed_filter.data(),
bias_buf, output_min, output_max, flags, &conv2d_op);
if (status != xnn_status_success) {
util::warn(
tfjs::backend::throw_js_exception(
"XNN status for xnn_create_convolution2d_nhwc_f32 is not successful. "
"Got status %d. Use -c dbg to see XNN logs.",
status);
Expand Down Expand Up @@ -155,7 +155,7 @@ void Conv2D(const int x_id, const int batch_size, const int input_height,
conv2d_op, batch_size, input_height, input_width, x_buf, out_buf,
nullptr /* thread pool */);
if (status != xnn_status_success) {
util::warn(
tfjs::backend::throw_js_exception(
"XNN status for xnn_setup_convolution2d_nhwc_f32 is not successful. "
"Got status %d. Use -c dbg to see XNN logs.",
status);
Expand Down
2 changes: 1 addition & 1 deletion tfjs-backend-wasm/src/cc/kernels/Conv2D_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
#include "src/cc/util.h"

TEST(CONV2D, xnn_operator_lifetime) {
tfjs::wasm::init();
tfjs::wasm::init(0);

ASSERT_EQ(0, tfjs::backend::num_tensors());

Expand Down
5 changes: 3 additions & 2 deletions tfjs-backend-wasm/src/cc/kernels/Div.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ void Div(const int a_id, const int b_id, const DType dtype, const int out_id) {
binary_bool(a_id, b_id, out_id, div<bool>);
break;
default:
util::warn("Mul for tensor ids %d and %d failed. Unknown dtype %d", a_id,
b_id, dtype);
tfjs::backend::throw_js_exception(
"Mul for tensor ids %d and %d failed. Unknown dtype %d", a_id, b_id,
dtype);
}
}

Expand Down
5 changes: 3 additions & 2 deletions tfjs-backend-wasm/src/cc/kernels/Mul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ void Mul(const int a_id, const int b_id, const DType dtype, const int out_id) {
binary_bool(a_id, b_id, out_id, mul<bool>);
break;
default:
util::warn("Mul for tensor ids %d and %d failed. Unknown dtype %d", a_id,
b_id, dtype);
tfjs::backend::throw_js_exception(
"Mul for tensor ids %d and %d failed. Unknown dtype %d", a_id, b_id,
dtype);
}
}

Expand Down
4 changes: 2 additions & 2 deletions tfjs-backend-wasm/src/cc/kernels/Prelu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ void Prelu(const int x_id, const int weights_id, const int out_id) {
xnn_create_prelu_nc_f32(channels, strides, strides, weights_buf,
output_min, output_max, flags, &prelu_op);
if (status != xnn_status_success) {
util::warn(
tfjs::backend::throw_js_exception(
"XNN status for xnn_create_prelu_nc_f32 is not successful. Got "
"status %d. Use -c dbg to see XNN logs.",
status);
Expand All @@ -90,7 +90,7 @@ void Prelu(const int x_id, const int weights_id, const int out_id) {
xnn_status status = xnn_setup_prelu_nc_f32(
prelu_op, batch_size, x_buf, out_buf, nullptr /* thread pool */);
if (status != xnn_status_success) {
util::warn(
tfjs::backend::throw_js_exception(
"XNN status for xnn_setup_prelu_nc_f32 is not successful. Got "
"status %d. Use -c dbg to see XNN logs.",
status);
Expand Down
2 changes: 1 addition & 1 deletion tfjs-backend-wasm/src/cc/kernels/Prelu_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
#include "src/cc/kernels/Prelu.h"

TEST(PRELU, xnn_operator_lifetime) {
tfjs::wasm::init();
tfjs::wasm::init(0);

ASSERT_EQ(0, tfjs::backend::num_tensors());

Expand Down
5 changes: 3 additions & 2 deletions tfjs-backend-wasm/src/cc/kernels/Sub.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ void Sub(const int a_id, const int b_id, const DType dtype, const int out_id) {
binary_bool(a_id, b_id, out_id, sub<bool>);
break;
default:
util::warn("Sub for tensor ids %d and %d failed. Unknown dtype %d", a_id,
b_id, dtype);
tfjs::backend::throw_js_exception(
"Sub for tensor ids %d and %d failed. Unknown dtype %d", a_id, b_id,
dtype);
}
}

Expand Down
9 changes: 9 additions & 0 deletions tfjs-backend-wasm/src/cc/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,14 @@ inline void print_log(const char* format, ...) {
va_list args;
va_start(args, format);
print_log(format, args);
va_end(args);
}

// Variadic front-end for warnings: packs the caller's arguments into a
// va_list and forwards to the va_list overload of print_warn (defined
// earlier in this header — not visible here; presumably it writes to the JS
// console via emscripten. Confirm.).
inline void print_warn(const char* format, ...) {
va_list args;
va_start(args, format);
print_warn(format, args);
// Matching va_end for the va_start above, required by the C standard.
va_end(args);
}

// Logs and flushes the message to the js console (console.log).
Expand All @@ -53,6 +55,7 @@ inline void log(const char* format, ...) {
va_start(args, format);
print_log(format, args);
print_log("\n");
va_end(args);
}

// Logs and flushes the message to the js console (console.err).
Expand All @@ -61,6 +64,7 @@ inline void warn(const char* format, ...) {
va_start(args, format);
print_warn(format, args);
print_warn("\n");
va_end(args);
}

// Helper method to log values in a vector. Used for debugging.
Expand All @@ -82,6 +86,11 @@ inline int size_from_shape(const std::vector<int>& shape) {
return prod;
}

// Returns the number of characters (excluding the trailing '\0') that
// formatting `format` with `args` would produce; negative on an encoding
// error. Note: this consumes `args` — callers that need the list again must
// va_copy it first.
inline int buffer_size(const char* format, va_list args) {
  return vsnprintf(nullptr, 0, format, args);
}

// Returns the indices of an n-dim tensor given the flat offset and its strides.
inline const std::vector<int> offset_to_loc(int index,
const std::vector<int>& strides) {
Expand Down
12 changes: 12 additions & 0 deletions tfjs-backend-wasm/test.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
<script src="../tfjs-core/dist/tf-core.js"></script>
<script src="./dist/tf-backend-wasm.js"></script>


<script>
  // Manual smoke test: select the wasm backend, wait until it is ready,
  // then run a trivial op and print the result to the console.
  (async () => {
    tf.setBackend('wasm');
    await tf.ready();
    tf.add(3, 3).print();
  })();
</script>
6 changes: 5 additions & 1 deletion tfjs-backend-wasm/wasm-out/tfjs-backend-wasm.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,13 @@

export interface BackendWasmModule extends EmscriptenModule {
onRuntimeInitialized: () => void;
callMain: (fns: string[]) => void;
addFunction:
(fn: (strOffset: number, strSize: number) => void,
signature: string) => number;
// Using the tfjs namespace to avoid conflict with emscripten's API.
tfjs: {
init(): void,
init(throwFnPointer: number): void,
registerTensor(dataId: number, size: number, memoryOffset: number): void,
// Disposes the data behind the data bucket.
disposeData(dataId: number): void,
Expand Down