diff --git a/tfjs-backend-wasm/src/backend_wasm.ts b/tfjs-backend-wasm/src/backend_wasm.ts index 8717019c9bf..51308b3624b 100644 --- a/tfjs-backend-wasm/src/backend_wasm.ts +++ b/tfjs-backend-wasm/src/backend_wasm.ts @@ -14,6 +14,7 @@ * limitations under the License. * ============================================================================= */ +import './flags_wasm'; import {backend_util, BackendTimingInfo, DataStorage, DataType, engine, KernelBackend, registerBackend, TensorInfo, util} from '@tensorflow/tfjs-core'; diff --git a/tfjs-backend-wasm/src/flags_wasm.ts b/tfjs-backend-wasm/src/flags_wasm.ts new file mode 100644 index 00000000000..69c3c286cf0 --- /dev/null +++ b/tfjs-backend-wasm/src/flags_wasm.ts @@ -0,0 +1,30 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {env} from '@tensorflow/tfjs-core'; + +const ENV = env(); + +/** + * True if SIMD is supported. + */ +// From: https://github.com/GoogleChromeLabs/wasm-feature-detect +ENV.registerFlag( + 'WASM_HAS_SIMD_SUPPORT', async () => WebAssembly.validate(new Uint8Array([ + 0, 97, 115, 109, 1, 0, 0, 0, 1, 4, 1, 96, 0, 0, 3, + 2, 1, 0, 10, 9, 1, 7, 0, 65, 0, 253, 15, 26, 11 + ]))); diff --git a/tfjs-core/src/environment.ts b/tfjs-core/src/environment.ts index abd8f993877..b94b5aec636 100644 --- a/tfjs-core/src/environment.ts +++ b/tfjs-core/src/environment.ts @@ -21,11 +21,12 @@ import {Platform} from './platforms/platform'; const TENSORFLOWJS_FLAGS_PREFIX = 'tfjsflags'; type FlagValue = number|boolean; +type FlagEvaluationFn = (() => FlagValue)|(() => Promise); export type Flags = { [featureName: string]: FlagValue }; export type FlagRegistryEntry = { - evaluationFn: () => FlagValue; + evaluationFn: FlagEvaluationFn; setHook?: (value: FlagValue) => void; }; @@ -60,7 +61,7 @@ export class Environment { } registerFlag( - flagName: string, evaluationFn: () => FlagValue, + flagName: string, evaluationFn: FlagEvaluationFn, setHook?: (value: FlagValue) => void) { this.flagRegistry[flagName] = {evaluationFn, setHook}; @@ -74,12 +75,28 @@ export class Environment { } } + async getAsync(flagName: string): Promise { + if (flagName in this.flags) { + return this.flags[flagName]; + } + + this.flags[flagName] = await this.evaluateFlag(flagName); + return this.flags[flagName]; + } + get(flagName: string): FlagValue { if (flagName in this.flags) { return this.flags[flagName]; } - this.flags[flagName] = this.evaluateFlag(flagName); + const flagValue = this.evaluateFlag(flagName); + if (flagValue instanceof Promise) { + throw new Error( + `Flag ${flagName} cannot be synchronously evaluated. ` + + `Please use getAsync() instead.`); + } + + this.flags[flagName] = flagValue; return this.flags[flagName]; } @@ -111,7 +128,7 @@ export class Environment { } } - private evaluateFlag(flagName: string): FlagValue { + private evaluateFlag(flagName: string): FlagValue|Promise { if (this.flagRegistry[flagName] == null) { throw new Error( `Cannot evaluate flag '${flagName}': no evaluation function found.`); diff --git a/tfjs-core/src/flags_test.ts b/tfjs-core/src/flags_test.ts index a0510077b3f..ee656cda90d 100644 --- a/tfjs-core/src/flags_test.ts +++ b/tfjs-core/src/flags_test.ts @@ -107,3 +107,19 @@ describe('IS_TEST', () => { expect(tf.env().getBool('IS_TEST')).toBe(false); }); }); + +describe('async flags test', () => { + const asyncFlagName = 'ASYNC_FLAG'; + beforeEach(() => tf.env().registerFlag(asyncFlagName, async () => true)); + + afterEach(() => tf.env().reset()); + + it('evaluating async flag works', async () => { + const flagVal = await tf.env().getAsync(asyncFlagName); + expect(flagVal).toBe(true); + }); + + it('evaluating async flag synchronously fails', async () => { + expect(() => tf.env().get(asyncFlagName)).toThrow(); + }); +});