Skip to content
1 change: 1 addition & 0 deletions tfjs-backend-wasm/src/backend_wasm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* limitations under the License.
* =============================================================================
*/
import './flags_wasm';

import {backend_util, BackendTimingInfo, DataStorage, DataType, engine, KernelBackend, registerBackend, TensorInfo, util} from '@tensorflow/tfjs-core';

Expand Down
30 changes: 30 additions & 0 deletions tfjs-backend-wasm/src/flags_wasm.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {env} from '@tensorflow/tfjs-core';

const ENV = env();

/**
* True if SIMD is supported.
*/
// From: https://github.com/GoogleChromeLabs/wasm-feature-detect
ENV.registerFlag(
'WASM_HAS_SIMD_SUPPORT', async () => WebAssembly.validate(new Uint8Array([
0, 97, 115, 109, 1, 0, 0, 0, 1, 4, 1, 96, 0, 0, 3,
2, 1, 0, 10, 9, 1, 7, 0, 65, 0, 253, 15, 26, 11
])));
25 changes: 21 additions & 4 deletions tfjs-core/src/environment.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@ import {Platform} from './platforms/platform';
const TENSORFLOWJS_FLAGS_PREFIX = 'tfjsflags';

type FlagValue = number|boolean;
type FlagEvaluationFn = (() => FlagValue)|(() => Promise<FlagValue>);
export type Flags = {
[featureName: string]: FlagValue
};
export type FlagRegistryEntry = {
evaluationFn: () => FlagValue;
evaluationFn: FlagEvaluationFn;
setHook?: (value: FlagValue) => void;
};

Expand Down Expand Up @@ -60,7 +61,7 @@ export class Environment {
}

registerFlag(
flagName: string, evaluationFn: () => FlagValue,
flagName: string, evaluationFn: FlagEvaluationFn,
setHook?: (value: FlagValue) => void) {
this.flagRegistry[flagName] = {evaluationFn, setHook};

Expand All @@ -74,12 +75,28 @@ export class Environment {
}
}

async getAsync(flagName: string): Promise<FlagValue> {
if (flagName in this.flags) {
return this.flags[flagName];
}

this.flags[flagName] = await this.evaluateFlag(flagName);
return this.flags[flagName];
}

get(flagName: string): FlagValue {
if (flagName in this.flags) {
return this.flags[flagName];
}

this.flags[flagName] = this.evaluateFlag(flagName);
const flagValue = this.evaluateFlag(flagName);
if (flagValue instanceof Promise) {
throw new Error(
`Flag ${flagName} cannot be synchronously evaluated. ` +
`Please use getAsync() instead.`);
}

this.flags[flagName] = flagValue;

return this.flags[flagName];
}
Expand Down Expand Up @@ -111,7 +128,7 @@ export class Environment {
}
}

private evaluateFlag(flagName: string): FlagValue {
private evaluateFlag(flagName: string): FlagValue|Promise<FlagValue> {
if (this.flagRegistry[flagName] == null) {
throw new Error(
`Cannot evaluate flag '${flagName}': no evaluation function found.`);
Expand Down
16 changes: 16 additions & 0 deletions tfjs-core/src/flags_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,19 @@ describe('IS_TEST', () => {
expect(tf.env().getBool('IS_TEST')).toBe(false);
});
});

describe('async flags test', () => {
const asyncFlagName = 'ASYNC_FLAG';
beforeEach(() => tf.env().registerFlag(asyncFlagName, async () => true));

afterEach(() => tf.env().reset());

it('evaluating async flag works', async () => {
const flagVal = await tf.env().getAsync(asyncFlagName);
expect(flagVal).toBe(true);
});

it('evaluating async flag synchronously fails', async () => {
expect(() => tf.env().get(asyncFlagName)).toThrow();
});
});