Skip to content

Commit

Permalink
[wasm] Allow users to set threads count at runtime (#5727)
Browse files Browse the repository at this point in the history
* [wasm] Allow users to set threads count at runtime

* update

* fix

* fix

* address comments

* fix
  • Loading branch information
jinjingforever committed Oct 15, 2021
1 parent bceafd1 commit 306b577
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 20 deletions.
30 changes: 30 additions & 0 deletions tfjs-backend-wasm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,36 @@ If the steps above are correctly done, you can check the Network tab from the
console and make sure the
<code>tfjs-backend-wasm-<b>threaded-simd</b>.wasm</code> WASM binary is loaded.

## Threads count

By default, the backend will use the number of logical CPU cores as the
threads count when creating the threadpool used by XNNPACK. You can use the
`setThreadsCount` API to manually set it (must be called before calling
`tf.setBackend('wasm')`). `getThreadsCount` API can be used to get the actual
number of threads being used (must be called after the WASM backend is
initialized).

### Via NPM

```js
import * as tf from '@tensorflow/tfjs';
import {getThreadsCount, setThreadsCount} from '@tensorflow/tfjs-backend-wasm';

setThreadsCount(2);
tf.setBackend('wasm').then(() => {
console.log(getThreadsCount());
});
```

### Via script tag

```js
tf.wasm.setThreadsCount(2);
tf.setBackend('wasm').then(() => {
consosle.log(tf.wasm.getThreadsCount());
});
```

## Running MobileNet

```js
Expand Down
35 changes: 32 additions & 3 deletions tfjs-backend-wasm/src/backend_wasm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,10 @@ export class BackendWasm extends KernelBackend {
private dataIdNextNumber = 1;
dataIdMap: DataStorage<TensorData>;

constructor(public wasm: BackendWasmModule | BackendWasmThreadedSimdModule) {
constructor(public wasm: BackendWasmModule|BackendWasmThreadedSimdModule) {
super();
this.wasm.tfjs.init();
this.wasm.tfjs.initWithThreadsCount(threadsCount);
actualThreadsCount = this.wasm.tfjs.getThreadsCount();
this.dataIdMap = new DataStorage(this, engine());
}

Expand Down Expand Up @@ -210,7 +211,7 @@ export class BackendWasm extends KernelBackend {
}

function createInstantiateWasmFunc(path: string) {
// this will be replace by rollup plugin patchWechatWebAssembly in
// this will be replace by rollup plugin patchWechatWebAssembly in
// minprogram's output.
// tslint:disable-next-line:no-any
return (imports: any, callback: any) => {
Expand Down Expand Up @@ -346,6 +347,9 @@ export async function init(): Promise<{wasm: BackendWasmModule}> {
// Using the tfjs namespace to avoid conflict with emscripten's API.
module.tfjs = {
init: module.cwrap('init', null, []),
initWithThreadsCount:
module.cwrap('init_with_threads_count', null, ['number']),
getThreadsCount: module.cwrap('get_threads_count', 'number', []),
registerTensor: module.cwrap(
'register_tensor', null,
[
Expand Down Expand Up @@ -474,3 +478,28 @@ export function resetWasmPath(): void {
customFetch = false;
initAborted = false;
}

let threadsCount = -1;
let actualThreadsCount = -1;

/**
* Sets the number of threads that will be used by XNNPACK to create
* threadpool (default to the number of logical CPU cores).
*
* This must be called before calling `tf.setBackend('wasm')`.
*/
export function setThreadsCount(numThreads: number) {
threadsCount = numThreads;
}

/**
* Gets the actual threads count that is used by XNNPACK.
*
* It is set after the backend is intialized.
*/
export function getThreadsCount(): number {
if (actualThreadsCount === -1) {
throw new Error(`WASM backend not initialized.`);
}
return actualThreadsCount;
}
2 changes: 1 addition & 1 deletion tfjs-backend-wasm/src/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import {registerBackend} from '@tensorflow/tfjs-core';

import {BackendWasm, init} from './backend_wasm';

export {BackendWasm, setWasmPath, setWasmPaths} from './backend_wasm';
export {BackendWasm, getThreadsCount, setThreadsCount, setWasmPath, setWasmPaths} from './backend_wasm';
export {version as version_wasm} from './version';

const WASM_PRIORITY = 2;
Expand Down
9 changes: 6 additions & 3 deletions tfjs-backend-wasm/src/cc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,12 @@ cc_binary(
"-s MALLOC=emmalloc",
"-s USE_PTHREADS=1",
"-s PROXY_TO_PTHREAD=1",
# Many x86-64 processors have 2 threads per core, so we divide by 2.
"-s PTHREAD_POOL_SIZE=" +
"'Math.min(4, Math.max(1, (navigator.hardwareConcurrency || 1) / 2))'",
# Pre-create 8 webworkers (threads). The actual number of threads that
# will be used by XNNPACK for creating the threadpool might be fewer. It
# is by default set to the number of logical cores which in most cases
# should be fewer than 8. It can also be set manually by calling the
# `setThreadsCount` API.
"-s PTHREAD_POOL_SIZE=8",
],
deps = [
":all_kernels",
Expand Down
57 changes: 44 additions & 13 deletions tfjs-backend-wasm/src/cc/backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,7 @@ TensorInfo &get_tensor_info_out(const size_t tensor_id) {

size_t xnn_operator_count = 0;

// emscripten_num_logical_cores corresponds to navigator.hardwareConcurrency.
// Many x86-64 processors have 2 threads per core, so we are dividing by 2.
#ifdef __EMSCRIPTEN_PTHREADS__
int num_cores = emscripten_num_logical_cores() / 2;
#else
int num_cores = 1;
#endif

int min_num_threads = 1;
int max_num_threads = 4;
pthreadpool *threadpool = pthreadpool_create(
std::min(std::max(num_cores, min_num_threads), max_num_threads));
pthreadpool *threadpool = NULL;

// Registers a disposal callback for a tensor id with a given callback function.
void register_disposal_callback(const size_t tensor_id,
Expand All @@ -82,13 +71,52 @@ const size_t num_tensors() { return data.size(); }
} // namespace backend

namespace wasm {

// emscripten_num_logical_cores corresponds to navigator.hardwareConcurrency.
// Many x86-64 processors have 2 threads per core, so we are dividing by 2.
#ifdef __EMSCRIPTEN_PTHREADS__
int num_cores = emscripten_num_logical_cores() / 2;
#else
int num_cores = 1;
#endif

int min_num_threads = 1;
// In cc/BUILD, we pre-created 8 webworker threads which is the maximum number
// of threads that the threadpool could use here.
int max_num_threads = 8;

// We use C-style API to interface with Javascript.

extern "C" {

#ifdef __EMSCRIPTEN__
EMSCRIPTEN_KEEPALIVE
#endif
void init() { xnn_initialize(nullptr); }
void init() { init_with_threads_count(num_cores); }

#ifdef __EMSCRIPTEN__
EMSCRIPTEN_KEEPALIVE
#endif
void init_with_threads_count(const int threads_count) {
int count = threads_count;
if (threads_count < 0) {
count = num_cores;
}
int capped_num_threads = std::min(
std::min(std::max(count, min_num_threads), max_num_threads), num_cores);
tfjs::backend::threadpool = pthreadpool_create(capped_num_threads);
xnn_initialize(nullptr);
}

#ifdef __EMSCRIPTEN__
EMSCRIPTEN_KEEPALIVE
#endif
int get_threads_count() {
if (tfjs::backend::threadpool == NULL) {
return -1;
}
return pthreadpool_get_threads_count(tfjs::backend::threadpool);
}

#ifdef __EMSCRIPTEN__
EMSCRIPTEN_KEEPALIVE
Expand Down Expand Up @@ -139,6 +167,9 @@ void dispose() {

data.clear();
disposal_callbacks.clear();

pthreadpool_destroy(tfjs::backend::threadpool);
tfjs::backend::threadpool = NULL;
}

} // extern "C"
Expand Down
6 changes: 6 additions & 0 deletions tfjs-backend-wasm/src/cc/backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,12 @@ extern "C" {
// Initializes the WASM backend.
void init();

// Initializes the WASM backend with the given threads count.
void init_with_threads_count(const int threads_count);

// Get the actual number of threads used in the XNNPACK threadpool.
int get_threads_count();

// Registers a tensor with a tensor ID, size, and the pointer to where the
// tensor data lives.
void register_tensor(const size_t tensor_id, const size_t size,
Expand Down
2 changes: 2 additions & 0 deletions tfjs-backend-wasm/wasm-out/tfjs-backend-wasm.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ export interface BackendWasmModule extends EmscriptenModule {
// Using the tfjs namespace to avoid conflict with emscripten's API.
tfjs: {
init(): void,
initWithThreadsCount(threadsCount: number): void,
getThreadsCount(): number,
registerTensor(id: number, size: number, memoryOffset: number): void,
// Disposes the data behind the data bucket.
disposeData(id: number): void,
Expand Down

0 comments on commit 306b577

Please sign in to comment.