diff --git a/tfjs-backend-wasm/README.md b/tfjs-backend-wasm/README.md
index 79dc70eb7e..7e461c92ba 100644
--- a/tfjs-backend-wasm/README.md
+++ b/tfjs-backend-wasm/README.md
@@ -91,6 +91,36 @@ If the steps above are correctly done, you can check the Network tab from the
console and make sure the
tfjs-backend-wasm-threaded-simd.wasm
WASM binary is loaded.
+## Threads count
+
+By default, the backend will use the number of logical CPU cores as the
+threads count when creating the threadpool used by XNNPACK. You can use the
+`setThreadsCount` API to manually set it (must be called before calling
+`tf.setBackend('wasm')`). `getThreadsCount` API can be used to get the actual
+number of threads being used (must be called after the WASM backend is
+initialized).
+
+### Via NPM
+
+```js
+import * as tf from '@tensorflow/tfjs';
+import {getThreadsCount, setThreadsCount} from '@tensorflow/tfjs-backend-wasm';
+
+setThreadsCount(2);
+tf.setBackend('wasm').then(() => {
+ console.log(getThreadsCount());
+});
+```
+
+### Via script tag
+
+```js
+tf.wasm.setThreadsCount(2);
+tf.setBackend('wasm').then(() => {
+ consosle.log(tf.wasm.getThreadsCount());
+});
+```
+
## Running MobileNet
```js
diff --git a/tfjs-backend-wasm/src/backend_wasm.ts b/tfjs-backend-wasm/src/backend_wasm.ts
index b229ad83f4..83c907b15a 100644
--- a/tfjs-backend-wasm/src/backend_wasm.ts
+++ b/tfjs-backend-wasm/src/backend_wasm.ts
@@ -42,9 +42,10 @@ export class BackendWasm extends KernelBackend {
private dataIdNextNumber = 1;
dataIdMap: DataStorage;
- constructor(public wasm: BackendWasmModule | BackendWasmThreadedSimdModule) {
+ constructor(public wasm: BackendWasmModule|BackendWasmThreadedSimdModule) {
super();
- this.wasm.tfjs.init();
+ this.wasm.tfjs.initWithThreadsCount(threadsCount);
+ actualThreadsCount = this.wasm.tfjs.getThreadsCount();
this.dataIdMap = new DataStorage(this, engine());
}
@@ -210,7 +211,7 @@ export class BackendWasm extends KernelBackend {
}
function createInstantiateWasmFunc(path: string) {
- // this will be replace by rollup plugin patchWechatWebAssembly in
+ // this will be replace by rollup plugin patchWechatWebAssembly in
// minprogram's output.
// tslint:disable-next-line:no-any
return (imports: any, callback: any) => {
@@ -346,6 +347,9 @@ export async function init(): Promise<{wasm: BackendWasmModule}> {
// Using the tfjs namespace to avoid conflict with emscripten's API.
module.tfjs = {
init: module.cwrap('init', null, []),
+ initWithThreadsCount:
+ module.cwrap('init_with_threads_count', null, ['number']),
+ getThreadsCount: module.cwrap('get_threads_count', 'number', []),
registerTensor: module.cwrap(
'register_tensor', null,
[
@@ -474,3 +478,28 @@ export function resetWasmPath(): void {
customFetch = false;
initAborted = false;
}
+
+let threadsCount = -1;
+let actualThreadsCount = -1;
+
+/**
+ * Sets the number of threads that will be used by XNNPACK to create
+ * threadpool (default to the number of logical CPU cores).
+ *
+ * This must be called before calling `tf.setBackend('wasm')`.
+ */
+export function setThreadsCount(numThreads: number) {
+ threadsCount = numThreads;
+}
+
+/**
+ * Gets the actual threads count that is used by XNNPACK.
+ *
+ * It is set after the backend is intialized.
+ */
+export function getThreadsCount(): number {
+ if (actualThreadsCount === -1) {
+ throw new Error(`WASM backend not initialized.`);
+ }
+ return actualThreadsCount;
+}
diff --git a/tfjs-backend-wasm/src/base.ts b/tfjs-backend-wasm/src/base.ts
index ae9b9da1a6..25c050c20a 100644
--- a/tfjs-backend-wasm/src/base.ts
+++ b/tfjs-backend-wasm/src/base.ts
@@ -21,7 +21,7 @@ import {registerBackend} from '@tensorflow/tfjs-core';
import {BackendWasm, init} from './backend_wasm';
-export {BackendWasm, setWasmPath, setWasmPaths} from './backend_wasm';
+export {BackendWasm, getThreadsCount, setThreadsCount, setWasmPath, setWasmPaths} from './backend_wasm';
export {version as version_wasm} from './version';
const WASM_PRIORITY = 2;
diff --git a/tfjs-backend-wasm/src/cc/BUILD b/tfjs-backend-wasm/src/cc/BUILD
index 12c4e97b67..1df5aa80e9 100644
--- a/tfjs-backend-wasm/src/cc/BUILD
+++ b/tfjs-backend-wasm/src/cc/BUILD
@@ -58,9 +58,12 @@ cc_binary(
"-s MALLOC=emmalloc",
"-s USE_PTHREADS=1",
"-s PROXY_TO_PTHREAD=1",
- # Many x86-64 processors have 2 threads per core, so we divide by 2.
- "-s PTHREAD_POOL_SIZE=" +
- "'Math.min(4, Math.max(1, (navigator.hardwareConcurrency || 1) / 2))'",
+ # Pre-create 8 webworkers (threads). The actual number of threads that
+ # will be used by XNNPACK for creating the threadpool might be fewer. It
+ # is by default set to the number of logical cores which in most cases
+ # should be fewer than 8. It can also be set manually by calling the
+ # `setThreadsCount` API.
+ "-s PTHREAD_POOL_SIZE=8",
],
deps = [
":all_kernels",
diff --git a/tfjs-backend-wasm/src/cc/backend.cc b/tfjs-backend-wasm/src/cc/backend.cc
index 72f2428c7f..5a5fee863f 100644
--- a/tfjs-backend-wasm/src/cc/backend.cc
+++ b/tfjs-backend-wasm/src/cc/backend.cc
@@ -52,18 +52,7 @@ TensorInfo &get_tensor_info_out(const size_t tensor_id) {
size_t xnn_operator_count = 0;
-// emscripten_num_logical_cores corresponds to navigator.hardwareConcurrency.
-// Many x86-64 processors have 2 threads per core, so we are dividing by 2.
-#ifdef __EMSCRIPTEN_PTHREADS__
-int num_cores = emscripten_num_logical_cores() / 2;
-#else
-int num_cores = 1;
-#endif
-
-int min_num_threads = 1;
-int max_num_threads = 4;
-pthreadpool *threadpool = pthreadpool_create(
- std::min(std::max(num_cores, min_num_threads), max_num_threads));
+pthreadpool *threadpool = NULL;
// Registers a disposal callback for a tensor id with a given callback function.
void register_disposal_callback(const size_t tensor_id,
@@ -82,13 +71,52 @@ const size_t num_tensors() { return data.size(); }
} // namespace backend
namespace wasm {
+
+// emscripten_num_logical_cores corresponds to navigator.hardwareConcurrency.
+// Many x86-64 processors have 2 threads per core, so we are dividing by 2.
+#ifdef __EMSCRIPTEN_PTHREADS__
+int num_cores = emscripten_num_logical_cores() / 2;
+#else
+int num_cores = 1;
+#endif
+
+int min_num_threads = 1;
+// In cc/BUILD, we pre-created 8 webworker threads which is the maximum number
+// of threads that the threadpool could use here.
+int max_num_threads = 8;
+
// We use C-style API to interface with Javascript.
+
extern "C" {
#ifdef __EMSCRIPTEN__
EMSCRIPTEN_KEEPALIVE
#endif
-void init() { xnn_initialize(nullptr); }
+void init() { init_with_threads_count(num_cores); }
+
+#ifdef __EMSCRIPTEN__
+EMSCRIPTEN_KEEPALIVE
+#endif
+void init_with_threads_count(const int threads_count) {
+ int count = threads_count;
+ if (threads_count < 0) {
+ count = num_cores;
+ }
+ int capped_num_threads = std::min(
+ std::min(std::max(count, min_num_threads), max_num_threads), num_cores);
+ tfjs::backend::threadpool = pthreadpool_create(capped_num_threads);
+ xnn_initialize(nullptr);
+}
+
+#ifdef __EMSCRIPTEN__
+EMSCRIPTEN_KEEPALIVE
+#endif
+int get_threads_count() {
+ if (tfjs::backend::threadpool == NULL) {
+ return -1;
+ }
+ return pthreadpool_get_threads_count(tfjs::backend::threadpool);
+}
#ifdef __EMSCRIPTEN__
EMSCRIPTEN_KEEPALIVE
@@ -139,6 +167,9 @@ void dispose() {
data.clear();
disposal_callbacks.clear();
+
+ pthreadpool_destroy(tfjs::backend::threadpool);
+ tfjs::backend::threadpool = NULL;
}
} // extern "C"
diff --git a/tfjs-backend-wasm/src/cc/backend.h b/tfjs-backend-wasm/src/cc/backend.h
index 1281ed16d0..89bfc8161d 100644
--- a/tfjs-backend-wasm/src/cc/backend.h
+++ b/tfjs-backend-wasm/src/cc/backend.h
@@ -99,6 +99,12 @@ extern "C" {
// Initializes the WASM backend.
void init();
+// Initializes the WASM backend with the given threads count.
+void init_with_threads_count(const int threads_count);
+
+// Get the actual number of threads used in the XNNPACK threadpool.
+int get_threads_count();
+
// Registers a tensor with a tensor ID, size, and the pointer to where the
// tensor data lives.
void register_tensor(const size_t tensor_id, const size_t size,
diff --git a/tfjs-backend-wasm/wasm-out/tfjs-backend-wasm.d.ts b/tfjs-backend-wasm/wasm-out/tfjs-backend-wasm.d.ts
index 2af757f8bc..f13a936920 100644
--- a/tfjs-backend-wasm/wasm-out/tfjs-backend-wasm.d.ts
+++ b/tfjs-backend-wasm/wasm-out/tfjs-backend-wasm.d.ts
@@ -19,6 +19,8 @@ export interface BackendWasmModule extends EmscriptenModule {
// Using the tfjs namespace to avoid conflict with emscripten's API.
tfjs: {
init(): void,
+ initWithThreadsCount(threadsCount: number): void,
+ getThreadsCount(): number,
registerTensor(id: number, size: number, memoryOffset: number): void,
// Disposes the data behind the data bucket.
disposeData(id: number): void,