diff --git a/tfjs-backend-wasm/README.md b/tfjs-backend-wasm/README.md index 79dc70eb7e..7e461c92ba 100644 --- a/tfjs-backend-wasm/README.md +++ b/tfjs-backend-wasm/README.md @@ -91,6 +91,36 @@ If the steps above are correctly done, you can check the Network tab from the console and make sure the tfjs-backend-wasm-threaded-simd.wasm WASM binary is loaded. +## Threads count + +By default, the backend will use the number of logical CPU cores as the +threads count when creating the threadpool used by XNNPACK. You can use the +`setThreadsCount` API to manually set it (must be called before calling +`tf.setBackend('wasm')`). `getThreadsCount` API can be used to get the actual +number of threads being used (must be called after the WASM backend is +initialized). + +### Via NPM + +```js +import * as tf from '@tensorflow/tfjs'; +import {getThreadsCount, setThreadsCount} from '@tensorflow/tfjs-backend-wasm'; + +setThreadsCount(2); +tf.setBackend('wasm').then(() => { + console.log(getThreadsCount()); +}); +``` + +### Via script tag + +```js +tf.wasm.setThreadsCount(2); +tf.setBackend('wasm').then(() => { + consosle.log(tf.wasm.getThreadsCount()); +}); +``` + ## Running MobileNet ```js diff --git a/tfjs-backend-wasm/src/backend_wasm.ts b/tfjs-backend-wasm/src/backend_wasm.ts index b229ad83f4..83c907b15a 100644 --- a/tfjs-backend-wasm/src/backend_wasm.ts +++ b/tfjs-backend-wasm/src/backend_wasm.ts @@ -42,9 +42,10 @@ export class BackendWasm extends KernelBackend { private dataIdNextNumber = 1; dataIdMap: DataStorage; - constructor(public wasm: BackendWasmModule | BackendWasmThreadedSimdModule) { + constructor(public wasm: BackendWasmModule|BackendWasmThreadedSimdModule) { super(); - this.wasm.tfjs.init(); + this.wasm.tfjs.initWithThreadsCount(threadsCount); + actualThreadsCount = this.wasm.tfjs.getThreadsCount(); this.dataIdMap = new DataStorage(this, engine()); } @@ -210,7 +211,7 @@ export class BackendWasm extends KernelBackend { } function createInstantiateWasmFunc(path: string) { - // this will be replace by rollup plugin patchWechatWebAssembly in + // this will be replace by rollup plugin patchWechatWebAssembly in // minprogram's output. // tslint:disable-next-line:no-any return (imports: any, callback: any) => { @@ -346,6 +347,9 @@ export async function init(): Promise<{wasm: BackendWasmModule}> { // Using the tfjs namespace to avoid conflict with emscripten's API. module.tfjs = { init: module.cwrap('init', null, []), + initWithThreadsCount: + module.cwrap('init_with_threads_count', null, ['number']), + getThreadsCount: module.cwrap('get_threads_count', 'number', []), registerTensor: module.cwrap( 'register_tensor', null, [ @@ -474,3 +478,28 @@ export function resetWasmPath(): void { customFetch = false; initAborted = false; } + +let threadsCount = -1; +let actualThreadsCount = -1; + +/** + * Sets the number of threads that will be used by XNNPACK to create + * threadpool (default to the number of logical CPU cores). + * + * This must be called before calling `tf.setBackend('wasm')`. + */ +export function setThreadsCount(numThreads: number) { + threadsCount = numThreads; +} + +/** + * Gets the actual threads count that is used by XNNPACK. + * + * It is set after the backend is intialized. + */ +export function getThreadsCount(): number { + if (actualThreadsCount === -1) { + throw new Error(`WASM backend not initialized.`); + } + return actualThreadsCount; +} diff --git a/tfjs-backend-wasm/src/base.ts b/tfjs-backend-wasm/src/base.ts index ae9b9da1a6..25c050c20a 100644 --- a/tfjs-backend-wasm/src/base.ts +++ b/tfjs-backend-wasm/src/base.ts @@ -21,7 +21,7 @@ import {registerBackend} from '@tensorflow/tfjs-core'; import {BackendWasm, init} from './backend_wasm'; -export {BackendWasm, setWasmPath, setWasmPaths} from './backend_wasm'; +export {BackendWasm, getThreadsCount, setThreadsCount, setWasmPath, setWasmPaths} from './backend_wasm'; export {version as version_wasm} from './version'; const WASM_PRIORITY = 2; diff --git a/tfjs-backend-wasm/src/cc/BUILD b/tfjs-backend-wasm/src/cc/BUILD index 12c4e97b67..1df5aa80e9 100644 --- a/tfjs-backend-wasm/src/cc/BUILD +++ b/tfjs-backend-wasm/src/cc/BUILD @@ -58,9 +58,12 @@ cc_binary( "-s MALLOC=emmalloc", "-s USE_PTHREADS=1", "-s PROXY_TO_PTHREAD=1", - # Many x86-64 processors have 2 threads per core, so we divide by 2. - "-s PTHREAD_POOL_SIZE=" + - "'Math.min(4, Math.max(1, (navigator.hardwareConcurrency || 1) / 2))'", + # Pre-create 8 webworkers (threads). The actual number of threads that + # will be used by XNNPACK for creating the threadpool might be fewer. It + # is by default set to the number of logical cores which in most cases + # should be fewer than 8. It can also be set manually by calling the + # `setThreadsCount` API. + "-s PTHREAD_POOL_SIZE=8", ], deps = [ ":all_kernels", diff --git a/tfjs-backend-wasm/src/cc/backend.cc b/tfjs-backend-wasm/src/cc/backend.cc index 72f2428c7f..5a5fee863f 100644 --- a/tfjs-backend-wasm/src/cc/backend.cc +++ b/tfjs-backend-wasm/src/cc/backend.cc @@ -52,18 +52,7 @@ TensorInfo &get_tensor_info_out(const size_t tensor_id) { size_t xnn_operator_count = 0; -// emscripten_num_logical_cores corresponds to navigator.hardwareConcurrency. -// Many x86-64 processors have 2 threads per core, so we are dividing by 2. -#ifdef __EMSCRIPTEN_PTHREADS__ -int num_cores = emscripten_num_logical_cores() / 2; -#else -int num_cores = 1; -#endif - -int min_num_threads = 1; -int max_num_threads = 4; -pthreadpool *threadpool = pthreadpool_create( - std::min(std::max(num_cores, min_num_threads), max_num_threads)); +pthreadpool *threadpool = NULL; // Registers a disposal callback for a tensor id with a given callback function. void register_disposal_callback(const size_t tensor_id, @@ -82,13 +71,52 @@ const size_t num_tensors() { return data.size(); } } // namespace backend namespace wasm { + +// emscripten_num_logical_cores corresponds to navigator.hardwareConcurrency. +// Many x86-64 processors have 2 threads per core, so we are dividing by 2. +#ifdef __EMSCRIPTEN_PTHREADS__ +int num_cores = emscripten_num_logical_cores() / 2; +#else +int num_cores = 1; +#endif + +int min_num_threads = 1; +// In cc/BUILD, we pre-created 8 webworker threads which is the maximum number +// of threads that the threadpool could use here. +int max_num_threads = 8; + // We use C-style API to interface with Javascript. + extern "C" { #ifdef __EMSCRIPTEN__ EMSCRIPTEN_KEEPALIVE #endif -void init() { xnn_initialize(nullptr); } +void init() { init_with_threads_count(num_cores); } + +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif +void init_with_threads_count(const int threads_count) { + int count = threads_count; + if (threads_count < 0) { + count = num_cores; + } + int capped_num_threads = std::min( + std::min(std::max(count, min_num_threads), max_num_threads), num_cores); + tfjs::backend::threadpool = pthreadpool_create(capped_num_threads); + xnn_initialize(nullptr); +} + +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif +int get_threads_count() { + if (tfjs::backend::threadpool == NULL) { + return -1; + } + return pthreadpool_get_threads_count(tfjs::backend::threadpool); +} #ifdef __EMSCRIPTEN__ EMSCRIPTEN_KEEPALIVE @@ -139,6 +167,9 @@ void dispose() { data.clear(); disposal_callbacks.clear(); + + pthreadpool_destroy(tfjs::backend::threadpool); + tfjs::backend::threadpool = NULL; } } // extern "C" diff --git a/tfjs-backend-wasm/src/cc/backend.h b/tfjs-backend-wasm/src/cc/backend.h index 1281ed16d0..89bfc8161d 100644 --- a/tfjs-backend-wasm/src/cc/backend.h +++ b/tfjs-backend-wasm/src/cc/backend.h @@ -99,6 +99,12 @@ extern "C" { // Initializes the WASM backend. void init(); +// Initializes the WASM backend with the given threads count. +void init_with_threads_count(const int threads_count); + +// Get the actual number of threads used in the XNNPACK threadpool. +int get_threads_count(); + // Registers a tensor with a tensor ID, size, and the pointer to where the // tensor data lives. void register_tensor(const size_t tensor_id, const size_t size, diff --git a/tfjs-backend-wasm/wasm-out/tfjs-backend-wasm.d.ts b/tfjs-backend-wasm/wasm-out/tfjs-backend-wasm.d.ts index 2af757f8bc..f13a936920 100644 --- a/tfjs-backend-wasm/wasm-out/tfjs-backend-wasm.d.ts +++ b/tfjs-backend-wasm/wasm-out/tfjs-backend-wasm.d.ts @@ -19,6 +19,8 @@ export interface BackendWasmModule extends EmscriptenModule { // Using the tfjs namespace to avoid conflict with emscripten's API. tfjs: { init(): void, + initWithThreadsCount(threadsCount: number): void, + getThreadsCount(): number, registerTensor(id: number, size: number, memoryOffset: number): void, // Disposes the data behind the data bucket. disposeData(id: number): void,