Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[wasm] Allow users to set threads count at runtime #5727

Merged
merged 7 commits into from
Oct 15, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions tfjs-backend-wasm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,30 @@ If the steps above are correctly done, you can check the Network tab from the
console and make sure the
<code>tfjs-backend-wasm-<b>threaded-simd</b>.wasm</code> WASM binary is loaded.

## Setting threads count

By default, the backend will use the number of logical CPU cores as the
threads count when creating the threadpool used by XNNPACK. You can use the
`setThreadsCount` API to manually set it (must be called before calling
`tf.setBackend('wasm')`).

### Via NPM

```js
import * as tf from '@tensorflow/tfjs';
import {setThreadsCount} from '@tensorflow/tfjs-backend-wasm';

setThreadsCount(2);
tf.setBackend('wasm').then(() => {...});
```

### Via script tag

```js
tf.wasm.setThreadsCount(2);
tf.setBackend('wasm').then(() => {...});
```

## Running MobileNet

```js
Expand Down
20 changes: 17 additions & 3 deletions tfjs-backend-wasm/src/backend_wasm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ export class BackendWasm extends KernelBackend {
private dataIdNextNumber = 1;
dataIdMap: DataStorage<TensorData>;

constructor(public wasm: BackendWasmModule | BackendWasmThreadedSimdModule) {
constructor(public wasm: BackendWasmModule|BackendWasmThreadedSimdModule) {
super();
this.wasm.tfjs.init();
this.wasm.tfjs.initWithThreadsCount(threadsCount);
this.dataIdMap = new DataStorage(this, engine());
}

Expand Down Expand Up @@ -210,7 +210,7 @@ export class BackendWasm extends KernelBackend {
}

function createInstantiateWasmFunc(path: string) {
// this will be replaced by rollup plugin patchWechatWebAssembly in
// this will be replaced by rollup plugin patchWechatWebAssembly in
// miniprogram's output.
// tslint:disable-next-line:no-any
return (imports: any, callback: any) => {
Expand Down Expand Up @@ -346,6 +346,8 @@ export async function init(): Promise<{wasm: BackendWasmModule}> {
// Using the tfjs namespace to avoid conflict with emscripten's API.
module.tfjs = {
init: module.cwrap('init', null, []),
initWithThreadsCount:
module.cwrap('init_with_threads_count', null, ['number']),
registerTensor: module.cwrap(
'register_tensor', null,
[
Expand Down Expand Up @@ -474,3 +476,15 @@ export function resetWasmPath(): void {
customFetch = false;
initAborted = false;
}

/**
 * Threads count requested through `setThreadsCount`. It is handed to the
 * native `initWithThreadsCount` call when the backend is constructed. The
 * initial value of -1 means "not set": the native side treats a negative count
 * as a request to use its own default.
 */
let threadsCount = -1;

/**
 * Sets the number of threads that XNNPACK will use when creating its
 * threadpool. When this is never called, the native backend picks a default
 * derived from the number of logical CPU cores.
 *
 * This must be called before calling `tf.setBackend('wasm')`, because the
 * value is only read once, while the backend is being constructed.
 *
 * @param numThreads Requested number of threads. The native side may cap the
 *     value it actually uses; a negative value selects the default.
 */
export function setThreadsCount(numThreads: number): void {
  threadsCount = numThreads;
}
2 changes: 1 addition & 1 deletion tfjs-backend-wasm/src/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import {registerBackend} from '@tensorflow/tfjs-core';

import {BackendWasm, init} from './backend_wasm';

export {BackendWasm, setWasmPath, setWasmPaths} from './backend_wasm';
export {BackendWasm, setThreadsCount, setWasmPath, setWasmPaths} from './backend_wasm';
export {version as version_wasm} from './version';

const WASM_PRIORITY = 2;
Expand Down
9 changes: 6 additions & 3 deletions tfjs-backend-wasm/src/cc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,12 @@ cc_binary(
"-s MALLOC=emmalloc",
"-s USE_PTHREADS=1",
"-s PROXY_TO_PTHREAD=1",
# Many x86-64 processors have 2 threads per core, so we divide by 2.
"-s PTHREAD_POOL_SIZE=" +
"'Math.min(4, Math.max(1, (navigator.hardwareConcurrency || 1) / 2))'",
# Pre-create 8 webworkers (threads). The actual number of threads that
# will be used by XNNPACK for creating the threadpool might be fewer. It
# is by default set to the number of logical cores which in most cases
# should be fewer than 8. It can also be set manually by calling the
# `setThreadsCount` API.
"-s PTHREAD_POOL_SIZE=8",
],
deps = [
":all_kernels",
Expand Down
47 changes: 34 additions & 13 deletions tfjs-backend-wasm/src/cc/backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,7 @@ TensorInfo &get_tensor_info_out(const size_t tensor_id) {

size_t xnn_operator_count = 0;

// emscripten_num_logical_cores corresponds to navigator.hardwareConcurrency.
// Many x86-64 processors have 2 threads per core, so we are dividing by 2.
#ifdef __EMSCRIPTEN_PTHREADS__
int num_cores = emscripten_num_logical_cores() / 2;
#else
int num_cores = 1;
#endif

int min_num_threads = 1;
int max_num_threads = 4;
pthreadpool *threadpool = pthreadpool_create(
std::min(std::max(num_cores, min_num_threads), max_num_threads));
pthreadpool *threadpool = NULL;

// Registers a disposal callback for a tensor id with a given callback function.
void register_disposal_callback(const size_t tensor_id,
Expand All @@ -82,13 +71,42 @@ const size_t num_tensors() { return data.size(); }
} // namespace backend

namespace wasm {

// emscripten_num_logical_cores corresponds to navigator.hardwareConcurrency.
// Many x86-64 processors have 2 threads per core, so we are dividing by 2.
// NOTE(review): this looks like an integer division, so a single logical core
// would yield num_cores == 0 -- confirm that users of num_cores handle that.
#ifdef __EMSCRIPTEN_PTHREADS__
int num_cores = emscripten_num_logical_cores() / 2;
#else
// Built without Emscripten pthreads support: fall back to a single core.
int num_cores = 1;
#endif

// Lower bound applied to the threads count when sizing the threadpool.
int min_num_threads = 1;
// In cc/BUILD, we pre-created 8 webworker threads which is the maximum number
// of threads that the threadpool could use here.
int max_num_threads = 8;

// We use C-style API to interface with Javascript.

extern "C" {

#ifdef __EMSCRIPTEN__
EMSCRIPTEN_KEEPALIVE
#endif
void init() { xnn_initialize(nullptr); }
void init() { init_with_threads_count(num_cores); }

#ifdef __EMSCRIPTEN__
EMSCRIPTEN_KEEPALIVE
#endif
void init_with_threads_count(const int threads_count) {
int count = threads_count;
if (threads_count < 0) {
count = num_cores;
}
int capped_num_threads = std::min(
std::min(std::max(count, min_num_threads), max_num_threads), num_cores);
tfjs::backend::threadpool = pthreadpool_create(capped_num_threads);
xnn_initialize(nullptr);
}

#ifdef __EMSCRIPTEN__
EMSCRIPTEN_KEEPALIVE
Expand Down Expand Up @@ -139,6 +157,9 @@ void dispose() {

data.clear();
disposal_callbacks.clear();

pthreadpool_destroy(tfjs::backend::threadpool);
tfjs::backend::threadpool = NULL;
}

} // extern "C"
Expand Down
3 changes: 3 additions & 0 deletions tfjs-backend-wasm/src/cc/backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ extern "C" {
// Initializes the WASM backend.
void init();

// Initializes the WASM backend with the given threads count.
void init_with_threads_count(const int threads_count);

// Registers a tensor with a tensor ID, size, and the pointer to where the
// tensor data lives.
void register_tensor(const size_t tensor_id, const size_t size,
Expand Down
1 change: 1 addition & 0 deletions tfjs-backend-wasm/wasm-out/tfjs-backend-wasm.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ export interface BackendWasmModule extends EmscriptenModule {
// Using the tfjs namespace to avoid conflict with emscripten's API.
tfjs: {
init(): void,
initWithThreadsCount(threadsCount: number): void,
registerTensor(id: number, size: number, memoryOffset: number): void,
// Disposes the data behind the data bucket.
disposeData(id: number): void,
Expand Down