diff --git a/tfjs-backend-wasm/scripts/build-wasm.sh b/tfjs-backend-wasm/scripts/build-wasm.sh index 73e950a2d5..4171bbb894 100755 --- a/tfjs-backend-wasm/scripts/build-wasm.sh +++ b/tfjs-backend-wasm/scripts/build-wasm.sh @@ -33,18 +33,19 @@ cp -f ../../dist/bin/tfjs-backend-wasm/src/cc/tfjs-backend-wasm/tfjs-backend-was if [[ "$1" != "--dev" ]]; then # SIMD and threaded + SIMD builds. - yarn bazel build $BAZEL_REMOTE -c opt --copt="-msimd128" //tfjs-backend-wasm/src/cc:tfjs-backend-wasm-simd \ - //tfjs-backend-wasm/src/cc:tfjs-backend-wasm-threaded-simd + yarn bazel build $BAZEL_REMOTE -c opt --copt="-msimd128" //tfjs-backend-wasm/src/cc:tfjs-backend-wasm-simd + yarn bazel build $BAZEL_REMOTE -c opt --copt="-pthread" --copt="-msimd128" //tfjs-backend-wasm/src/cc:tfjs-backend-wasm-threaded-simd + # Copy SIMD - cp -f ../../dist/bin/tfjs-backend-wasm/src/cc/tfjs-backend-wasm-simd/tfjs-backend-wasm.wasm \ + cp -f ../../dist/bin/tfjs-backend-wasm/src/cc/tfjs-backend-wasm-simd/tfjs-backend-wasm-simd.wasm \ ../wasm-out/tfjs-backend-wasm-simd.wasm # Copy threaded - cp -f ../../dist/bin/tfjs-backend-wasm/src/cc/tfjs-backend-wasm-threaded-simd/tfjs-backend-wasm.js \ + cp -f ../../dist/bin/tfjs-backend-wasm/src/cc/tfjs-backend-wasm-threaded-simd/tfjs-backend-wasm-threaded-simd.js \ ../wasm-out/tfjs-backend-wasm-threaded-simd.js - cp -f ../../dist/bin/tfjs-backend-wasm/src/cc/tfjs-backend-wasm-threaded-simd/tfjs-backend-wasm.worker.js \ + cp -f ../../dist/bin/tfjs-backend-wasm/src/cc/tfjs-backend-wasm-threaded-simd/tfjs-backend-wasm-threaded-simd.worker.js \ ../wasm-out/tfjs-backend-wasm-threaded-simd.worker.js - cp -f ../../dist/bin/tfjs-backend-wasm/src/cc/tfjs-backend-wasm-threaded-simd/tfjs-backend-wasm.wasm \ + cp -f ../../dist/bin/tfjs-backend-wasm/src/cc/tfjs-backend-wasm-threaded-simd/tfjs-backend-wasm-threaded-simd.wasm \ ../wasm-out/tfjs-backend-wasm-threaded-simd.wasm node ./create-worker-module.js diff --git a/tfjs-backend-wasm/src/backend_wasm.ts b/tfjs-backend-wasm/src/backend_wasm.ts index e33ac8729d..640c879223 100644 --- a/tfjs-backend-wasm/src/backend_wasm.ts +++ b/tfjs-backend-wasm/src/backend_wasm.ts @@ -19,6 +19,7 @@ import './flags_wasm'; import {backend_util, BackendTimingInfo, DataStorage, DataType, deprecationWarn, engine, env, KernelBackend, TensorInfo, util} from '@tensorflow/tfjs-core'; import {BackendWasmModule, WasmFactoryConfig} from '../wasm-out/tfjs-backend-wasm'; +import {BackendWasmThreadedSimdModule} from '../wasm-out/tfjs-backend-wasm-threaded-simd'; import wasmFactoryThreadedSimd from '../wasm-out/tfjs-backend-wasm-threaded-simd.js'; // @ts-ignore import {wasmWorkerContents} from '../wasm-out/tfjs-backend-wasm-threaded-simd.worker.js'; @@ -41,7 +42,7 @@ export class BackendWasm extends KernelBackend { private dataIdNextNumber = 1; dataIdMap: DataStorage; - constructor(public wasm: BackendWasmModule) { + constructor(public wasm: BackendWasmModule | BackendWasmThreadedSimdModule) { super(); this.wasm.tfjs.init(); this.dataIdMap = new DataStorage(this, engine()); @@ -159,6 +160,9 @@ export class BackendWasm extends KernelBackend { dispose() { this.wasm.tfjs.dispose(); + if ('PThread' in this.wasm) { + this.wasm.PThread.terminateAllThreads(); + } this.wasm = null; } @@ -214,7 +218,7 @@ function createInstantiateWasmFunc(path: string) { } response.arrayBuffer().then(binary => { WebAssembly.instantiate(binary, imports).then(output => { - callback(output.instance); + callback(output.instance, output.module); }); }); }); diff --git a/tfjs-backend-wasm/src/cc/BUILD b/tfjs-backend-wasm/src/cc/BUILD index 42f36072dc..bb4c0b056e 100644 --- a/tfjs-backend-wasm/src/cc/BUILD +++ b/tfjs-backend-wasm/src/cc/BUILD @@ -8,6 +8,18 @@ KERNELS_WITH_KEEPALIVE = glob( exclude = ["**/*_test.cc"], ) +BASE_LINKOPTS = [ + "-s ALLOW_MEMORY_GROWTH=1", + "-s DEFAULT_LIBRARY_FUNCS_TO_INCLUDE=[]", + "-s DISABLE_EXCEPTION_CATCHING=1", + "-s FILESYSTEM=0", + "-s EXIT_RUNTIME=0", + "-s EXPORTED_FUNCTIONS='[\"_malloc\", \"_free\"]'", + "-s EXTRA_EXPORTED_RUNTIME_METHODS='[\"cwrap\"]'", + "-s MODULARIZE=1", + "-s MALLOC=emmalloc", +] + # This build rule generates tfjs-backend-wasm.{js,wasm}. # # The ".js" at the end of the build name is significant because it determines @@ -19,17 +31,36 @@ KERNELS_WITH_KEEPALIVE = glob( cc_binary( name = "tfjs-backend-wasm.js", srcs = ["backend.cc"] + KERNELS_WITH_KEEPALIVE, - linkopts = [ - "-s ALLOW_MEMORY_GROWTH=1", - "-s DEFAULT_LIBRARY_FUNCS_TO_INCLUDE=[]", - "-s DISABLE_EXCEPTION_CATCHING=1", - "-s FILESYSTEM=0", - "-s EXIT_RUNTIME=0", - "-s EXPORTED_FUNCTIONS='[\"_malloc\", \"_free\"]'", - "-s EXTRA_EXPORTED_RUNTIME_METHODS='[\"cwrap\"]'", - "-s MODULARIZE=1", + linkopts = BASE_LINKOPTS + [ "-s EXPORT_NAME=WasmBackendModule", + ], + deps = [ + ":all_kernels", + ":backend", + ], +) + +cc_binary( + name = "tfjs-backend-wasm-simd.js", + srcs = ["backend.cc"] + KERNELS_WITH_KEEPALIVE, + linkopts = BASE_LINKOPTS, + deps = [ + ":all_kernels", + ":backend", + ], +) + +cc_binary( + name = "tfjs-backend-wasm-threaded-simd.js", + srcs = ["backend.cc"] + KERNELS_WITH_KEEPALIVE, + linkopts = BASE_LINKOPTS + [ + "-s EXPORT_NAME=WasmBackendModuleThreadedSimd", "-s MALLOC=emmalloc", + "-s USE_PTHREADS=1", + "-s PROXY_TO_PTHREAD=1", + # Many x86-64 processors have 2 threads per core, so we divide by 2. + "-s PTHREAD_POOL_SIZE=" + + "'Math.min(4, Math.max(1, (navigator.hardwareConcurrency || 1) / 2))'", ], deps = [ ":all_kernels", @@ -44,13 +75,13 @@ wasm_cc_binary( wasm_cc_binary( name = "tfjs-backend-wasm-simd", - cc_target = ":tfjs-backend-wasm.js", + cc_target = ":tfjs-backend-wasm-simd.js", simd = True, ) wasm_cc_binary( name = "tfjs-backend-wasm-threaded-simd", - cc_target = ":tfjs-backend-wasm.js", + cc_target = ":tfjs-backend-wasm-threaded-simd.js", simd = True, threads = "emscripten", ) diff --git a/tfjs-backend-wasm/src/index_test.ts b/tfjs-backend-wasm/src/index_test.ts index 4ff26f19be..ce0ef4f612 100644 --- a/tfjs-backend-wasm/src/index_test.ts +++ b/tfjs-backend-wasm/src/index_test.ts @@ -165,9 +165,9 @@ describeWithFlags('wasm init', BROWSER_ENVS, () => { { 'tfjs-backend-wasm.wasm': '/base/wasm-out/tfjs-backend-wasm.wasm', 'tfjs-backend-wasm-simd.wasm': - '/base/wasm-out/tfjs-backend-wasm.wasm', + '/base/wasm-out/tfjs-backend-wasm-simd.wasm', 'tfjs-backend-wasm-threaded-simd.wasm': - '/base/wasm-out/tfjs-backend-wasm.wasm' + '/base/wasm-out/tfjs-backend-wasm-threaded-simd.wasm' }, usePlatformFetch); let wasmPath: string; diff --git a/tfjs-backend-wasm/wasm-out/tfjs-backend-wasm-threaded-simd.d.ts b/tfjs-backend-wasm/wasm-out/tfjs-backend-wasm-threaded-simd.d.ts index c1d20303be..71ef9ef109 100644 --- a/tfjs-backend-wasm/wasm-out/tfjs-backend-wasm-threaded-simd.d.ts +++ b/tfjs-backend-wasm/wasm-out/tfjs-backend-wasm-threaded-simd.d.ts @@ -15,16 +15,13 @@ * ============================================================================= */ -export interface BackendWasmModule extends EmscriptenModule { - // Using the tfjs namespace to avoid conflict with emscripten's API. - tfjs: { - init(): void, - registerTensor(id: number, size: number, memoryOffset: number): void, - // Disposes the data behind the data bucket. - disposeData(id: number): void, - // Disposes the backend and all of its associated data. - dispose(): void, - } +import {BackendWasmModule} from './tfjs-backend-wasm'; + +export interface BackendWasmThreadedSimdModule extends BackendWasmModule { + PThread: { + // Terminates all webworkers + terminateAllThreads(): void, + }; } export interface WasmFactoryConfig {