From ee9586986a61b051a96337a7ffcec17d3cea2cbd Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Wed, 29 Jul 2020 14:33:55 -0400 Subject: [PATCH 01/18] initial commit of module writing tool --- tfjs/package.json | 7 ++ tfjs/src/tools/custom_bundle/cli.ts | 94 +++++++++++++++++++ tfjs/src/tools/custom_bundle/custom_module.ts | 63 +++++++++++++ .../custom_bundle/esm_module_provider.ts | 81 ++++++++++++++++ tfjs/src/tools/custom_bundle/tsconfig.json | 7 ++ tfjs/src/tools/custom_bundle/types.ts | 45 +++++++++ tfjs/yarn.lock | 24 ++++- 7 files changed, 319 insertions(+), 2 deletions(-) create mode 100644 tfjs/src/tools/custom_bundle/cli.ts create mode 100644 tfjs/src/tools/custom_bundle/custom_module.ts create mode 100644 tfjs/src/tools/custom_bundle/esm_module_provider.ts create mode 100644 tfjs/src/tools/custom_bundle/tsconfig.json create mode 100644 tfjs/src/tools/custom_bundle/types.ts diff --git a/tfjs/package.json b/tfjs/package.json index 97c0c28745b..5ab511762f2 100644 --- a/tfjs/package.json +++ b/tfjs/package.json @@ -14,6 +14,9 @@ "url": "https://github.com/tensorflow/tfjs.git" }, "license": "Apache-2.0", + "bin": { + "tfjs-custom-bundle": "dist/tools/custom_bundle/cli.js" + }, "devDependencies": { "@babel/core": "^7.9.0", "@babel/preset-env": "^7.9.5", @@ -67,6 +70,8 @@ "build-backend-webgl-ci": "cd ../tfjs-backend-webgl && yarn && yarn build-ci", "build-deps": "yarn build-core && yarn build-layers && yarn build-converter && yarn build-data && yarn build-backend-cpu && yarn build-backend-webgl", "build-deps-ci": "yarn build-core-ci && yarn build-layers-ci && yarn build-converter-ci && yarn build-data-ci && yarn build-backend-cpu-ci && yarn build-backend-webgl-ci", + "build-cli": "tsc --project ./src/tools/custom_bundle/tsconfig.json && chmod +x ./dist/tools/custom_bundle/cli.js", + "run-custom-build": "ts-node -s ./src/tools/custom_bundle/cli.ts", "build-npm": "./scripts/build-npm.sh", "link-local": "yalc link", "publish-local": "yarn build-npm && yalc push", @@ -82,6 +87,8 @@ "@tensorflow/tfjs-core": "link:../tfjs-core", "@tensorflow/tfjs-data": "link:../tfjs-data", "@tensorflow/tfjs-layers": "link:../tfjs-layers", + "argparse": "^1.0.10", + "chalk": "^4.1.0", "core-js": "3", "regenerator-runtime": "^0.13.5" } diff --git a/tfjs/src/tools/custom_bundle/cli.ts b/tfjs/src/tools/custom_bundle/cli.ts new file mode 100644 index 00000000000..a08af065001 --- /dev/null +++ b/tfjs/src/tools/custom_bundle/cli.ts @@ -0,0 +1,94 @@ +#!/usr/bin/env node + +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +/** + * Entry point for cli tool to build custom tfjs bundles + */ + +import * as argparse from 'argparse'; +import chalk from 'chalk'; +import * as fs from 'fs'; + +import {getCustomModuleString} from './custom_module'; +import {CustomTFJSBundleConfig} from './types'; +import {esmModuleProvider} from './esm_module_provider'; + +const DEFAULT_CUSTOM_BUNDLE_ARGS: Partial = { + entries: [], + models: [], + kernels: [], + forwardModeOnly: true, + backends: ['cpu', 'webgl'], +}; + +const parser = new argparse.ArgumentParser(); +parser.addArgument( + '--config', {help: 'path to custom bundle config file.', required: true}); + +function bail(errorMsg: string) { + console.log(chalk.red(errorMsg)); + process.exit(1); +} + +function validateArgs(): CustomTFJSBundleConfig { + const args = parser.parseArgs(); + const configFilePath = args.config; + if (!fs.existsSync(configFilePath)) { + bail(`Error: config file does not exist at ${configFilePath}`); + } + let config; + try { + config = JSON.parse(fs.readFileSync(configFilePath, 'utf-8')); + } catch (error) { + bail(`Error could not read/parse JSON config file. \n ${error.message}`); + } + + if (config.outputPath == null) { + bail('Error: config must specify "outputPath" property'); + } + + console.log(`Using custom bundle configuration from ${ + configFilePath}. Final config:`); + console.log(`${JSON.stringify(config, null, 2)}\n`); + + return Object.assign({}, DEFAULT_CUSTOM_BUNDLE_ARGS, config); +} + +function getKernelNamesForConfig(config: CustomTFJSBundleConfig) { + // Later on this will do a union of kernels from entries, models and kernels, + // (and kernels used by the converter itself) Currently we only support + // directly listing kernels. remember that this also needs to handle + // kernels used by gradients if forwardModeOnly is false. + return config.kernels; +} +function produceCustomTFJSModule( + kernels: string[], backends: string[], forwardModeOnly: boolean, + outputPath: string) { + const moduleStr = getCustomModuleString( + kernels, backends, forwardModeOnly, esmModuleProvider); + + console.log(`Writing custom tfjs module to ${outputPath}`); + fs.writeFileSync(outputPath, moduleStr); +} + +const customConfig = validateArgs(); +const kernelsToInclude = getKernelNamesForConfig(customConfig); +produceCustomTFJSModule( + kernelsToInclude, customConfig.backends, customConfig.forwardModeOnly, + customConfig.outputPath); diff --git a/tfjs/src/tools/custom_bundle/custom_module.ts b/tfjs/src/tools/custom_bundle/custom_module.ts new file mode 100644 index 00000000000..9725e9f6483 --- /dev/null +++ b/tfjs/src/tools/custom_bundle/custom_module.ts @@ -0,0 +1,63 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ModuleProvider} from './types'; + +export function getCustomModuleString( + kernels: string[], + backends: string[], + forwardModeOnly: boolean, + moduleProvider: ModuleProvider, + ): string { + const result: string[] = []; + + addLine(result, moduleProvider.importCoreStr()); + addLine(result, moduleProvider.importConverterStr()); + + for (const backend of backends) { + addLine(result, `\n//backend = ${backend}`); + addLine(result, moduleProvider.importBackendStr(backend)); + for (const kernelName of kernels) { + const kernelImport = moduleProvider.importKernelStr(kernelName, backend); + addLine(result, kernelImport.importStatement); + addLine(result, registerKernelStr(kernelImport.kernelConfigId)); + } + } + + if (!forwardModeOnly) { + addLine(result, `\n//Gradients`); + for (const kernelName of kernels) { + const gradImport = moduleProvider.importGradientConfigStr(kernelName); + addLine(result, gradImport.importStatement); + addLine(result, registerGradientConfigStr(gradImport.gradConfigId)); + } + } + + return result.join('\n'); +} + +function addLine(target: string[], line: string) { + target.push(line); +} + +function registerKernelStr(kernelConfigId: string) { + return `registerKernel(${kernelConfigId});`; +} + +function registerGradientConfigStr(gradConfigId: string) { + return `registerGradient(${gradConfigId});`; +} diff --git a/tfjs/src/tools/custom_bundle/esm_module_provider.ts b/tfjs/src/tools/custom_bundle/esm_module_provider.ts new file mode 100644 index 00000000000..136799c1652 --- /dev/null +++ b/tfjs/src/tools/custom_bundle/esm_module_provider.ts @@ -0,0 +1,81 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {ModuleProvider, SupportedBackend} from './types'; + +/** + * A module provider to generate custom esm modules. + */ +export const esmModuleProvider: ModuleProvider = { + importCoreStr() { + return ` +import {registerKernel, registerGradient} from '@tensorflow/tfjs-core/dist/base'; +export * from '@tensorflow/tfjs-core/dist/base'; + `; + }, + + importConverterStr() { + // TODO(yassogba). Create a 'base' file for converter, + return `export * from '@tensorflow/tfjs-converter';`; + }, + + importBackendStr(backend: SupportedBackend) { + const backendPkg = getBackendPath(backend); + return `export * from '${backendPkg}/dist/base';`; + }, + + importKernelStr(kernelName: string, backend: SupportedBackend) { + // TODO(yassogba) validate whether the target file referenced by + // importStatement exists and warn the user if it doesn't. That could happen + // here or in an earlier validation phase that uses this function + + const backendPkg = getBackendPath(backend); + const kernelConfigId = `${kernelName}_${backend}`; + const importStatement = + `import {${kernelNameToVariableName(kernelName)}Config as ${ + kernelConfigId}} from '${backendPkg}/dist/kernels/${kernelName}';`; + + return {importStatement, kernelConfigId}; + }, + + importGradientConfigStr(kernelName: string) { + // TODO(yassogba) validate whether the target file referenced by + // importStatement exists and warn the user if it doesn't. That could happen + // here or in an earlier validation phase that uses this function + + const gradConfigId = `${kernelNameToVariableName(kernelName)}GradConfig`; + const importStatement = + `import {${gradConfigId}} from '@tensorflow/tfjs-core/dist/gradients/${ + kernelName}_grad';`; + + return {importStatement, gradConfigId}; + } +}; + +function getBackendPath(backend: SupportedBackend) { + switch (backend) { + case 'cpu': + return '@tensorflow/tfjs-backend-cpu'; + case 'webgl': + return '@tensorflow/tfjs-backend-webgl'; + default: + throw new Error(`Unsupported backend ${backend}`); + } +} + +function kernelNameToVariableName(kernelName: string) { + return kernelName.charAt(0).toLowerCase() + kernelName.slice(1); +} diff --git a/tfjs/src/tools/custom_bundle/tsconfig.json b/tfjs/src/tools/custom_bundle/tsconfig.json new file mode 100644 index 00000000000..f1ed5fb3e6f --- /dev/null +++ b/tfjs/src/tools/custom_bundle/tsconfig.json @@ -0,0 +1,7 @@ +{ + "extends": "../../../tsconfig.test", + "compilerOptions": { + "module": "commonjs", + "esModuleInterop": true + } +} diff --git a/tfjs/src/tools/custom_bundle/types.ts b/tfjs/src/tools/custom_bundle/types.ts new file mode 100644 index 00000000000..7fb3989b776 --- /dev/null +++ b/tfjs/src/tools/custom_bundle/types.ts @@ -0,0 +1,45 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +export enum SupportedBackends { + cpu = 'cpu', + webgl = 'webgl' +} +export type SupportedBackend = keyof typeof SupportedBackends; + +export interface CustomTFJSBundleConfig { + entries?: string[]; // paths to javascript files to walk + models?: string[]; // paths to model.json files to walk + backends?: SupportedBackend[]; // backends to include/use kernels from + forwardModeOnly?: boolean; // whether to drop gradients + outputPath: string; // path to output file + kernels?: string[]; // Kernels to include +} + +// Interface for an object that can provide functionality to generate +// a correct custom module for that build environment (e.g. node vs g3). +export interface ModuleProvider { + importCoreStr: () => string; + importConverterStr: () => string; + importBackendStr: (backendPkg: string) => string; + importKernelStr: (kernelName: string, backend: string) => { + importStatement: string, kernelConfigId: string + }; + importGradientConfigStr: (kernelName: string) => { + importStatement: string, gradConfigId: string + }; +} diff --git a/tfjs/yarn.lock b/tfjs/yarn.lock index c8aa0e0d28d..d66381a954b 100644 --- a/tfjs/yarn.lock +++ b/tfjs/yarn.lock @@ -975,7 +975,7 @@ ansi-styles@^3.2.1: dependencies: color-convert "^1.9.0" -ansi-styles@^4.0.0: +ansi-styles@^4.0.0, ansi-styles@^4.1.0: version "4.2.1" resolved "https://registry.yarnpkg.com/ansi-styles/-/ansi-styles-4.2.1.tgz#90ae75c424d008d2624c5bf29ead3177ebfcf359" integrity sha512-9VGjrMsG1vePxcSweQsN20KY/c4zN0h9fLjqAbwbPfahM3t+NL+M9HC8xeXG2I8pX5NoamTGNuomEUFI7fcUjA== @@ -1001,7 +1001,7 @@ arg@^4.1.0: resolved "https://registry.yarnpkg.com/arg/-/arg-4.1.3.tgz#269fc7ad5b8e42cb63c896d5666017261c144089" integrity sha512-58S9QDqG0Xx27YwPSt9fJxivjYl432YCwfDMfZ+71RAqUrZef7LrKQZ3LHLOwCS4FLNBplP533Zx895SeOCHvA== -argparse@^1.0.7: +argparse@^1.0.10, argparse@^1.0.7: version "1.0.10" resolved "https://registry.yarnpkg.com/argparse/-/argparse-1.0.10.tgz#bcd6791ea5ae09725e17e5ad988134cd40b3d911" integrity sha512-o5Roy6tNG4SL/FOkCAN6RzjiakZS25RLYFrcMttJqbdd8BWrnA+fGz57iN5Pb06pvBGvl5gQ0B48dJlslXvoTg== @@ -1446,6 +1446,14 @@ chalk@^2.0.0, chalk@^2.3.0, chalk@^2.4.1: escape-string-regexp "^1.0.5" supports-color "^5.3.0" +chalk@^4.1.0: + version "4.1.0" + resolved "https://registry.yarnpkg.com/chalk/-/chalk-4.1.0.tgz#4e14870a618d9e2edd97dd8345fd9d9dc315646a" + integrity sha512-qwx12AxXe2Q5xQ43Ac//I6v5aXTipYrSESdOgzrN+9XjgEpyjpKuvSGaN4qE93f7TQTlerQQ8S+EQ0EyDoVL1A== + dependencies: + ansi-styles "^4.1.0" + supports-color "^7.1.0" + chokidar@^3.0.0: version "3.4.0" resolved "https://registry.yarnpkg.com/chokidar/-/chokidar-3.4.0.tgz#b30611423ce376357c765b9b8f904b9fba3c0be8" @@ -2305,6 +2313,11 @@ has-flag@^3.0.0: resolved "https://registry.yarnpkg.com/has-flag/-/has-flag-3.0.0.tgz#b5d454dc2199ae225699f3467e5a07f3b955bafd" integrity sha1-tdRU3CGZriJWmfNGfloH87lVuv0= +has-flag@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/has-flag/-/has-flag-4.0.0.tgz#944771fd9c81c81265c4d6941860da06bb59479b" + integrity sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ== + has-symbols@^1.0.0: version "1.0.0" resolved "https://registry.yarnpkg.com/has-symbols/-/has-symbols-1.0.0.tgz#ba1a8f1af2a0fc39650f5c850367704122063b44" @@ -4261,6 +4274,13 @@ supports-color@^6.1.0: dependencies: has-flag "^3.0.0" +supports-color@^7.1.0: + version "7.1.0" + resolved "https://registry.yarnpkg.com/supports-color/-/supports-color-7.1.0.tgz#68e32591df73e25ad1c4b49108a2ec507962bfd1" + integrity sha512-oRSIpR8pxT1Wr2FquTNnGet79b3BWljqOuoW/h4oBhxJ/HUbX5nX6JSruTkvXDCFMwDPvsaTTbvMLKZWSy0R5g== + dependencies: + has-flag "^4.0.0" + terser@^4.6.2: version "4.6.11" resolved "https://registry.yarnpkg.com/terser/-/terser-4.6.11.tgz#12ff99fdd62a26de2a82f508515407eb6ccd8a9f" From eacc171ba1280f4a8a3facc9cfb2fd0fb9b04c59 Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Wed, 29 Jul 2020 14:34:20 -0400 Subject: [PATCH 02/18] remove chaining api from browser.ts --- tfjs-core/src/ops/browser.ts | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/tfjs-core/src/ops/browser.ts b/tfjs-core/src/ops/browser.ts index 22aba683829..e060db10fa1 100644 --- a/tfjs-core/src/ops/browser.ts +++ b/tfjs-core/src/ops/browser.ts @@ -23,6 +23,9 @@ import {NamedTensorMap} from '../tensor_types'; import {convertToTensor} from '../tensor_util_env'; import {PixelData, TensorLike} from '../types'; +import {cast} from './cast'; +import {max} from './max'; +import {min} from './min'; import {op} from './operation'; import {tensor3d} from './tensor3d'; @@ -177,7 +180,7 @@ export async function toPixels( if (!(img instanceof Tensor)) { // Assume int32 if user passed a native array. const originalImgTensor = $img; - $img = originalImgTensor.toInt(); + $img = cast(originalImgTensor, 'int32'); originalImgTensor.dispose(); } if ($img.rank !== 2 && $img.rank !== 3) { @@ -194,26 +197,26 @@ export async function toPixels( } const data = await $img.data(); - const minTensor = $img.min(); - const maxTensor = $img.max(); + const minTensor = min($img); + const maxTensor = max($img); const vals = await Promise.all([minTensor.data(), maxTensor.data()]); const minVals = vals[0]; const maxVals = vals[1]; - const min = minVals[0]; - const max = maxVals[0]; + const minVal = minVals[0]; + const maxVal = maxVals[0]; minTensor.dispose(); maxTensor.dispose(); if ($img.dtype === 'float32') { - if (min < 0 || max > 1) { + if (minVal < 0 || maxVal > 1) { throw new Error( `Tensor values for a float32 Tensor must be in the ` + - `range [0 - 1] but got range [${min} - ${max}].`); + `range [0 - 1] but got range [${minVal} - ${maxVal}].`); } } else if ($img.dtype === 'int32') { - if (min < 0 || max > 255) { + if (minVal < 0 || maxVal > 255) { throw new Error( `Tensor values for a int32 Tensor must be in the ` + - `range [0 - 255] but got range [${min} - ${max}].`); + `range [0 - 255] but got range [${minVal} - ${maxVal}].`); } } else { throw new Error( From 31ee1644dfebb730bb450e08ab6ef972199d1f5a Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Wed, 29 Jul 2020 21:23:29 -0400 Subject: [PATCH 03/18] split base from index for core, cpu and webgl --- tfjs-backend-cpu/src/base.ts | 7 +- tfjs-backend-cpu/src/index.ts | 8 +- tfjs-backend-webgl/src/base.ts | 33 ++++++ tfjs-backend-webgl/src/index.ts | 16 +-- tfjs-backend-webgl/src/kernel_utils/shared.ts | 2 +- tfjs-core/src/base.ts | 110 ++++++++++++++++++ tfjs-core/src/index.ts | 91 +-------------- 7 files changed, 155 insertions(+), 112 deletions(-) create mode 100644 tfjs-backend-webgl/src/base.ts create mode 100644 tfjs-core/src/base.ts diff --git a/tfjs-backend-cpu/src/base.ts b/tfjs-backend-cpu/src/base.ts index 17fa0f9f063..0693f4d88cb 100644 --- a/tfjs-backend-cpu/src/base.ts +++ b/tfjs-backend-cpu/src/base.ts @@ -17,10 +17,15 @@ /* * base.ts contains all the exports from tfjs-backend-cpu - * that do not trigger side effects. + * but skips auto-kernel registration */ +import {registerBackend} from '@tensorflow/tfjs-core'; +import {MathBackendCPU} from './backend_cpu'; import * as shared from './shared'; export {MathBackendCPU} from './backend_cpu'; export {version as version_cpu} from './version'; export {shared}; + +// Side effects for default initialization of MathBackendCPU +registerBackend('cpu', () => new MathBackendCPU(), 1 /* priority */); diff --git a/tfjs-backend-cpu/src/index.ts b/tfjs-backend-cpu/src/index.ts index 53ca24399c5..abf89246ec3 100644 --- a/tfjs-backend-cpu/src/index.ts +++ b/tfjs-backend-cpu/src/index.ts @@ -15,12 +15,6 @@ * ============================================================================= */ -import {registerBackend} from '@tensorflow/tfjs-core'; -import {MathBackendCPU} from './base'; - -// Side effects for default initialization of MathBackendCPU -registerBackend('cpu', () => new MathBackendCPU(), 1 /* priority */); -import './register_all_kernels'; - // All exports from this package should be in base. export * from './base'; +import './register_all_kernels'; diff --git a/tfjs-backend-webgl/src/base.ts b/tfjs-backend-webgl/src/base.ts new file mode 100644 index 00000000000..f2412eab16a --- /dev/null +++ b/tfjs-backend-webgl/src/base.ts @@ -0,0 +1,33 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +// base.ts is the webgl backend without auto kernel registration. + +import {device_util, registerBackend} from '@tensorflow/tfjs-core'; +import {MathBackendWebGL} from './backend_webgl'; +export {version as version_webgl} from './version'; + +if (device_util.isBrowser()) { + registerBackend('webgl', () => new MathBackendWebGL(), 2 /* priority */); +} + +// Export webgl utilities +export * from './webgl'; + +// Export forceHalfFlost under webgl namespace for the union bundle. +import {forceHalfFloat} from './webgl'; +export const webgl = {forceHalfFloat}; diff --git a/tfjs-backend-webgl/src/index.ts b/tfjs-backend-webgl/src/index.ts index 14f5863c480..abf89246ec3 100644 --- a/tfjs-backend-webgl/src/index.ts +++ b/tfjs-backend-webgl/src/index.ts @@ -15,18 +15,6 @@ * ============================================================================= */ -import {device_util, registerBackend} from '@tensorflow/tfjs-core'; -import {MathBackendWebGL} from './backend_webgl'; -export {version as version_webgl} from './version'; - -if (device_util.isBrowser()) { - registerBackend('webgl', () => new MathBackendWebGL(), 2 /* priority */); -} +// All exports from this package should be in base. +export * from './base'; import './register_all_kernels'; - -// Export webgl utilities -export * from './webgl'; - -// Export forceHalfFlost under webgl namespace for the union bundle. -import {forceHalfFloat} from './webgl'; -export const webgl = {forceHalfFloat}; diff --git a/tfjs-backend-webgl/src/kernel_utils/shared.ts b/tfjs-backend-webgl/src/kernel_utils/shared.ts index c5b51211eef..0ac71e981c0 100644 --- a/tfjs-backend-webgl/src/kernel_utils/shared.ts +++ b/tfjs-backend-webgl/src/kernel_utils/shared.ts @@ -18,7 +18,7 @@ // Import shared functionality from tfjs-backend-cpu without triggering // side effects. // tslint:disable-next-line: no-imports-from-dist -import {shared} from '@tensorflow/tfjs-backend-cpu/dist/base'; +import * as shared from '@tensorflow/tfjs-backend-cpu/dist/shared'; const {maxImpl: maxImplCPU, transposeImpl: transposeImplCPU} = shared; diff --git a/tfjs-core/src/base.ts b/tfjs-core/src/base.ts new file mode 100644 index 00000000000..04623c446ae --- /dev/null +++ b/tfjs-core/src/base.ts @@ -0,0 +1,110 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +// base.ts is tfjs-core without auto registration of gradients or chained ops. + +/** + * @fileoverview + * @suppress {partialAlias} Optimization disabled due to passing the module + * object into a function below: + * + * import * as ops from './ops/ops'; + * setOpHandler(ops); + */ + +// Engine is the global singleton that needs to be initialized before the rest +// of the app. +import './engine'; +// Register backend-agnostic flags. +import './flags'; +// Register platforms +import './platforms/platform_browser'; +import './platforms/platform_node'; + +// Serialization. +import * as io from './io/io'; +import * as math from './math'; +import * as browser from './ops/browser'; +import * as gather_util from './ops/gather_nd_util'; +import * as scatter_util from './ops/scatter_nd_util'; +import * as slice_util from './ops/slice_util'; +import * as serialization from './serialization'; +import {setOpHandler} from './tensor'; +import * as tensor_util from './tensor_util'; +import * as test_util from './test_util'; +import * as util from './util'; +import {version} from './version'; + +export {InferenceModel, MetaGraph, MetaGraphInfo, ModelPredictConfig, ModelTensorInfo, SavedModelTensorInfo, SignatureDef, SignatureDefInfo} from './model_types'; +// Optimizers. +export {AdadeltaOptimizer} from './optimizers/adadelta_optimizer'; +export {AdagradOptimizer} from './optimizers/adagrad_optimizer'; +export {AdamOptimizer} from './optimizers/adam_optimizer'; +export {AdamaxOptimizer} from './optimizers/adamax_optimizer'; +export {MomentumOptimizer} from './optimizers/momentum_optimizer'; +export {Optimizer} from './optimizers/optimizer'; +export {RMSPropOptimizer} from './optimizers/rmsprop_optimizer'; +export {SGDOptimizer} from './optimizers/sgd_optimizer'; +export {Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer, Variable} from './tensor'; +export {GradSaveFunc, NamedTensorMap, TensorContainer, TensorContainerArray, TensorContainerObject} from './tensor_types'; +export {BackendValues, DataType, DataTypeMap, DataValues, NumericDataType, PixelData, Rank, RecursiveArray, ScalarLike, ShapeMap, sumOutType, TensorLike, TypedArray, upcastType} from './types'; + +export * from './ops/ops'; +export {Reduction} from './ops/loss_ops_utils'; + +export * from './train'; +export * from './globals'; +export * from './kernel_registry'; +export {customGrad, grad, grads, valueAndGrad, valueAndGrads, variableGrads} from './gradients'; + +export {TimingInfo, MemoryInfo, ForwardFunc} from './engine'; +export {Environment, env, ENV} from './environment'; +export {Platform} from './platforms/platform'; + +export {version as version_core}; + +// Top-level method exports. +export {nextFrame} from './browser_util'; + +// Second level exports. +import * as backend_util from './backends/backend_util'; +import * as device_util from './device_util'; +export { + browser, + io, + math, + serialization, + test_util, + util, + backend_util, + tensor_util, + slice_util, + gather_util, + scatter_util, + device_util +}; + +import * as kernel_impls from './backends/kernel_impls'; +export {kernel_impls}; +// Backend specific. +export {KernelBackend, BackendTimingInfo, DataMover, DataStorage} from './backends/backend'; + +import * as ops from './ops/ops'; +setOpHandler(ops); + +// Export all kernel names / info. +export * from './kernel_names'; diff --git a/tfjs-core/src/index.ts b/tfjs-core/src/index.ts index 49a95123436..e965553b3d4 100644 --- a/tfjs-core/src/index.ts +++ b/tfjs-core/src/index.ts @@ -15,98 +15,11 @@ * ============================================================================= */ -/** - * @fileoverview - * @suppress {partialAlias} Optimization disabled due to passing the module - * object into a function below: - * - * import * as ops from './ops/ops'; - * setOpHandler(ops); - */ +// All exports from this package should be in base. +export * from './base'; -// Engine is the global singleton that needs to be initialized before the rest -// of the app. -import './engine'; -// Register backend-agnostic flags. -import './flags'; // Register all the gradients. import './register_all_gradients'; -import './platforms/platform_browser'; -import './platforms/platform_node'; - -// Serialization. -import * as io from './io/io'; -import * as math from './math'; -import * as browser from './ops/browser'; -import * as gather_util from './ops/gather_nd_util'; -import * as scatter_util from './ops/scatter_nd_util'; -import * as slice_util from './ops/slice_util'; -import * as serialization from './serialization'; -import {setOpHandler} from './tensor'; -import * as tensor_util from './tensor_util'; -import * as test_util from './test_util'; -import * as util from './util'; -import {version} from './version'; - -export {InferenceModel, MetaGraph, MetaGraphInfo, ModelPredictConfig, ModelTensorInfo, SavedModelTensorInfo, SignatureDef, SignatureDefInfo} from './model_types'; -// Optimizers. -export {AdadeltaOptimizer} from './optimizers/adadelta_optimizer'; -export {AdagradOptimizer} from './optimizers/adagrad_optimizer'; -export {AdamOptimizer} from './optimizers/adam_optimizer'; -export {AdamaxOptimizer} from './optimizers/adamax_optimizer'; -export {MomentumOptimizer} from './optimizers/momentum_optimizer'; -export {Optimizer} from './optimizers/optimizer'; -export {RMSPropOptimizer} from './optimizers/rmsprop_optimizer'; -export {SGDOptimizer} from './optimizers/sgd_optimizer'; -export {Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer, Variable} from './tensor'; -export {GradSaveFunc, NamedTensorMap, TensorContainer, TensorContainerArray, TensorContainerObject} from './tensor_types'; -export {BackendValues, DataType, DataTypeMap, DataValues, NumericDataType, PixelData, Rank, RecursiveArray, ScalarLike, ShapeMap, sumOutType, TensorLike, TypedArray, upcastType} from './types'; - -export * from './ops/ops'; -export {Reduction} from './ops/loss_ops_utils'; - -export * from './train'; -export * from './globals'; -export * from './kernel_registry'; -export {customGrad, grad, grads, valueAndGrad, valueAndGrads, variableGrads} from './gradients'; - -export {TimingInfo, MemoryInfo, ForwardFunc} from './engine'; -export {Environment, env, ENV} from './environment'; -export {Platform} from './platforms/platform'; - -export {version as version_core}; - -// Top-level method exports. -export {nextFrame} from './browser_util'; - -// Second level exports. -import * as backend_util from './backends/backend_util'; -import * as device_util from './device_util'; -export { - browser, - io, - math, - serialization, - test_util, - util, - backend_util, - tensor_util, - slice_util, - gather_util, - scatter_util, - device_util -}; - -import * as kernel_impls from './backends/kernel_impls'; -export {kernel_impls}; -// Backend specific. -export {KernelBackend, BackendTimingInfo, DataMover, DataStorage} from './backends/backend'; - -import * as ops from './ops/ops'; -setOpHandler(ops); - -// Export all kernel names / info. -export * from './kernel_names'; // Import all op chainers and add type info to Tensor. import './public/chained_ops/register_all_chained_ops'; From 886ef69f56e5bfa03147b79ef10d3b9d22df4e92 Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Wed, 29 Jul 2020 21:23:36 -0400 Subject: [PATCH 04/18] fix import --- tfjs/src/tools/custom_bundle/cli.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) mode change 100644 => 100755 tfjs/src/tools/custom_bundle/cli.ts diff --git a/tfjs/src/tools/custom_bundle/cli.ts b/tfjs/src/tools/custom_bundle/cli.ts old mode 100644 new mode 100755 index a08af065001..6a24eb42794 --- a/tfjs/src/tools/custom_bundle/cli.ts +++ b/tfjs/src/tools/custom_bundle/cli.ts @@ -22,7 +22,7 @@ */ import * as argparse from 'argparse'; -import chalk from 'chalk'; +import * as chalk from 'chalk'; import * as fs from 'fs'; import {getCustomModuleString} from './custom_module'; From 5d1ac5fd1ab6e996cc7cf49687f1d304f6986a62 Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Wed, 29 Jul 2020 21:36:38 -0400 Subject: [PATCH 05/18] build tools separately. --- tfjs/package.json | 2 +- tfjs/src/tools/custom_bundle/tsconfig.json | 7 ------- tfjs/{src => }/tools/custom_bundle/cli.ts | 0 tfjs/{src => }/tools/custom_bundle/custom_module.ts | 0 .../tools/custom_bundle/esm_module_provider.ts | 0 tfjs/tools/custom_bundle/tsconfig.json | 10 ++++++++++ tfjs/{src => }/tools/custom_bundle/types.ts | 0 7 files changed, 11 insertions(+), 8 deletions(-) delete mode 100644 tfjs/src/tools/custom_bundle/tsconfig.json rename tfjs/{src => }/tools/custom_bundle/cli.ts (100%) rename tfjs/{src => }/tools/custom_bundle/custom_module.ts (100%) rename tfjs/{src => }/tools/custom_bundle/esm_module_provider.ts (100%) create mode 100644 tfjs/tools/custom_bundle/tsconfig.json rename tfjs/{src => }/tools/custom_bundle/types.ts (100%) diff --git a/tfjs/package.json b/tfjs/package.json index 5ab511762f2..fa0d80069fa 100644 --- a/tfjs/package.json +++ b/tfjs/package.json @@ -70,7 +70,7 @@ "build-backend-webgl-ci": "cd ../tfjs-backend-webgl && yarn && yarn build-ci", "build-deps": "yarn build-core && yarn build-layers && yarn build-converter && yarn build-data && yarn build-backend-cpu && yarn build-backend-webgl", "build-deps-ci": "yarn build-core-ci && yarn build-layers-ci && yarn build-converter-ci && yarn build-data-ci && yarn build-backend-cpu-ci && yarn build-backend-webgl-ci", - "build-cli": "tsc --project ./src/tools/custom_bundle/tsconfig.json && chmod +x ./dist/tools/custom_bundle/cli.js", + "build-cli": "tsc --project ./tools/custom_bundle/tsconfig.json && chmod +x ./dist/tools/custom_bundle/cli.js", "run-custom-build": "ts-node -s ./src/tools/custom_bundle/cli.ts", "build-npm": "./scripts/build-npm.sh", "link-local": "yalc link", diff --git a/tfjs/src/tools/custom_bundle/tsconfig.json b/tfjs/src/tools/custom_bundle/tsconfig.json deleted file mode 100644 index f1ed5fb3e6f..00000000000 --- a/tfjs/src/tools/custom_bundle/tsconfig.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "extends": "../../../tsconfig.test", - "compilerOptions": { - "module": "commonjs", - "esModuleInterop": true - } -} diff --git a/tfjs/src/tools/custom_bundle/cli.ts b/tfjs/tools/custom_bundle/cli.ts similarity index 100% rename from tfjs/src/tools/custom_bundle/cli.ts rename to tfjs/tools/custom_bundle/cli.ts diff --git a/tfjs/src/tools/custom_bundle/custom_module.ts b/tfjs/tools/custom_bundle/custom_module.ts similarity index 100% rename from tfjs/src/tools/custom_bundle/custom_module.ts rename to tfjs/tools/custom_bundle/custom_module.ts diff --git a/tfjs/src/tools/custom_bundle/esm_module_provider.ts b/tfjs/tools/custom_bundle/esm_module_provider.ts similarity index 100% rename from tfjs/src/tools/custom_bundle/esm_module_provider.ts rename to tfjs/tools/custom_bundle/esm_module_provider.ts diff --git a/tfjs/tools/custom_bundle/tsconfig.json b/tfjs/tools/custom_bundle/tsconfig.json new file mode 100644 index 00000000000..a981dfa447b --- /dev/null +++ b/tfjs/tools/custom_bundle/tsconfig.json @@ -0,0 +1,10 @@ +{ + "extends": "../../tsconfig.test", + "include": [ + "../../tools/" + ], + "compilerOptions": { + "module": "commonjs", + "outDir": "../../dist/tools/custom_bundle" + } +} diff --git a/tfjs/src/tools/custom_bundle/types.ts b/tfjs/tools/custom_bundle/types.ts similarity index 100% rename from tfjs/src/tools/custom_bundle/types.ts rename to tfjs/tools/custom_bundle/types.ts From 06372927b426cb3d1757356e038b7ad9a2aa342e Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Wed, 29 Jul 2020 21:37:04 -0400 Subject: [PATCH 06/18] set side effects property for core, cpu, webgl --- tfjs-backend-cpu/package.json | 7 ++++++- tfjs-backend-webgl/package.json | 8 +++++++- tfjs-core/package.json | 10 +++++++++- 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/tfjs-backend-cpu/package.json b/tfjs-backend-cpu/package.json index 8fed43cf62d..86efac109d8 100644 --- a/tfjs-backend-cpu/package.json +++ b/tfjs-backend-cpu/package.json @@ -70,5 +70,10 @@ "browser": { "util": false, "crypto": false - } + }, + "sideEffects": [ + "./dist/register_all_kernels.js", + "./dist/base.js", + "./dist/index.js" + ] } diff --git a/tfjs-backend-webgl/package.json b/tfjs-backend-webgl/package.json index 64b91654b52..f2de6562434 100644 --- a/tfjs-backend-webgl/package.json +++ b/tfjs-backend-webgl/package.json @@ -80,5 +80,11 @@ "browser": { "util": false, "crypto": false - } + }, + "sideEffects": [ + "./dist/register_all_kernels.js", + "./dist/flags_webgl.js", + "./dist/base.js", + "./dist/index.js" + ] } diff --git a/tfjs-core/package.json b/tfjs-core/package.json index 1d73f344c78..bb226ba8468 100644 --- a/tfjs-core/package.json +++ b/tfjs-core/package.json @@ -91,5 +91,13 @@ "node-fetch": false, "util": false, "crypto": false - } + }, + "sideEffects": [ + "./dist/index.js", + "./dist/engine.js", + "./dist/flags.js", + "./dist/platforms/*.js", + "./dist/register_all_gradients.js", + "./dist/public/chained_ops/*.js" + ] } From 55addb203a2fb577501a4dbb3af8372f6ba70668 Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Wed, 29 Jul 2020 22:39:17 -0400 Subject: [PATCH 07/18] add tests for tools --- tfjs/package.json | 1 + tfjs/run_tools_tests.ts | 32 ++++ .../tools/custom_bundle/custom_module_test.ts | 165 ++++++++++++++++++ .../custom_bundle/esm_module_provider.ts | 1 - .../custom_bundle/esm_module_provider_test.ts | 63 +++++++ 5 files changed, 261 insertions(+), 1 deletion(-) create mode 100644 tfjs/run_tools_tests.ts create mode 100644 tfjs/tools/custom_bundle/custom_module_test.ts create mode 100644 tfjs/tools/custom_bundle/esm_module_provider_test.ts diff --git a/tfjs/package.json b/tfjs/package.json index fa0d80069fa..63683a0df16 100644 --- a/tfjs/package.json +++ b/tfjs/package.json @@ -78,6 +78,7 @@ "publish-npm": "npm publish", "lint": "tslint -p . -t verbose", "test": "yarn && yarn build-deps && yarn build && karma start", + "test-tools": "ts-node --project ./tools/custom_bundle/tsconfig.json run_tools_tests.ts", "test-ci": "./scripts/test-ci.sh" }, "dependencies": { diff --git a/tfjs/run_tools_tests.ts b/tfjs/run_tools_tests.ts new file mode 100644 index 00000000000..cd7e46bc933 --- /dev/null +++ b/tfjs/run_tools_tests.ts @@ -0,0 +1,32 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +// tslint:disable-next-line:no-require-imports +const jasmineCtor = require('jasmine'); +// tslint:disable-next-line:no-require-imports + +Error.stackTraceLimit = Infinity; + +process.on('unhandledRejection', e => { + throw e; +}); +const toolsTests = './tools/**/*_test.ts'; + +const runner = new jasmineCtor(); +runner.loadConfig({spec_files: [toolsTests], random: false}); + +runner.execute(); diff --git a/tfjs/tools/custom_bundle/custom_module_test.ts b/tfjs/tools/custom_bundle/custom_module_test.ts new file mode 100644 index 00000000000..ce7549fc738 --- /dev/null +++ b/tfjs/tools/custom_bundle/custom_module_test.ts @@ -0,0 +1,165 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {getCustomModuleString} from './custom_module'; +import {ModuleProvider} from './types'; + +const mockModuleProvider: ModuleProvider = { + importCoreStr: () => 'import CORE', + importConverterStr: () => 'import CONVERTER', + importBackendStr: (name: string) => `import BACKEND ${name}`, + importKernelStr: (kernelName: string, backend: string) => ({ + importStatement: `import KERNEL ${kernelName} from BACKEND ${backend}`, + kernelConfigId: `${kernelName}_${backend}` + }), + importGradientConfigStr: (kernel: string) => ({ + importStatement: `import GRADIENT ${kernel}`, + gradConfigId: `${kernel}_GRAD_CONFIG`, + }), +}; + +describe('ESM Module Provider forwardModeOnly=true', () => { + const forwardModeOnly = true; + it('one kernel, one backend', () => { + const res = getCustomModuleString( + ['MathKrnl'], ['FastBcknd'], forwardModeOnly, mockModuleProvider); + + expect(res).toContain('import CORE'); + expect(res).toContain('import CONVERTER'); + + expect(res).toContain('import BACKEND FastBcknd'); + expect(res).toContain('import KERNEL MathKrnl from BACKEND FastBcknd'); + expect(res).toContain('registerKernel(MathKrnl_FastBcknd)'); + + expect(res).not.toContain('GRADIENT'); + }); + + it('one kernel, two backend', () => { + const res = getCustomModuleString( + ['MathKrnl'], ['FastBcknd', 'SlowBcknd'], forwardModeOnly, + mockModuleProvider); + + expect(res).toContain('import CORE'); + expect(res).toContain('import CONVERTER'); + + expect(res).toContain('import BACKEND FastBcknd'); + expect(res).toContain('import KERNEL MathKrnl from BACKEND FastBcknd'); + expect(res).toContain('registerKernel(MathKrnl_FastBcknd)'); + + expect(res).toContain('import BACKEND SlowBcknd'); + expect(res).toContain('import KERNEL MathKrnl from BACKEND SlowBcknd'); + expect(res).toContain('registerKernel(MathKrnl_SlowBcknd)'); + + expect(res).not.toContain('GRADIENT'); + }); + + it('two kernels, one backend', () => { + const res = getCustomModuleString( + ['MathKrnl', 'MathKrn2'], ['FastBcknd'], forwardModeOnly, + mockModuleProvider); + + expect(res).toContain('import CORE'); + expect(res).toContain('import CONVERTER'); + + expect(res).toContain('import BACKEND FastBcknd'); + expect(res).toContain('import KERNEL MathKrnl from BACKEND FastBcknd'); + expect(res).toContain('import KERNEL MathKrn2 from BACKEND FastBcknd'); + expect(res).toContain('registerKernel(MathKrnl_FastBcknd)'); + expect(res).toContain('registerKernel(MathKrn2_FastBcknd)'); + + expect(res).not.toContain('GRADIENT'); + }); + + it('two kernels, two backends', () => { + const res = getCustomModuleString( + ['MathKrnl', 'MathKrn2'], ['FastBcknd', 'SlowBcknd'], forwardModeOnly, + mockModuleProvider); + + expect(res).toContain('import CORE'); + expect(res).toContain('import CONVERTER'); + + expect(res).toContain('import BACKEND FastBcknd'); + expect(res).toContain('import KERNEL MathKrnl from BACKEND FastBcknd'); + expect(res).toContain('import KERNEL MathKrn2 from BACKEND FastBcknd'); + expect(res).toContain('registerKernel(MathKrnl_FastBcknd)'); + expect(res).toContain('registerKernel(MathKrn2_FastBcknd)'); + + expect(res).toContain('import BACKEND SlowBcknd'); + expect(res).toContain('import KERNEL MathKrnl from BACKEND SlowBcknd'); + expect(res).toContain('import KERNEL MathKrn2 from BACKEND SlowBcknd'); + expect(res).toContain('registerKernel(MathKrnl_SlowBcknd)'); + expect(res).toContain('registerKernel(MathKrn2_SlowBcknd)'); + + expect(res).not.toContain('GRADIENT'); + }); +}); + +describe('ESM Module Provider forwardModeOnly=false', () => { + const forwardModeOnly = false; + + it('one kernel, one backend', () => { + const res = getCustomModuleString( + ['MathKrnl'], ['FastBcknd'], forwardModeOnly, mockModuleProvider); + + expect(res).toContain('import CORE'); + expect(res).toContain('import CONVERTER'); + + expect(res).toContain('import BACKEND FastBcknd'); + expect(res).toContain('import KERNEL MathKrnl from BACKEND FastBcknd'); + expect(res).toContain('registerKernel(MathKrnl_FastBcknd)'); + + expect(res).toContain('import GRADIENT MathKrnl'); + expect(res).toContain('registerGradient(MathKrnl_GRAD_CONFIG)'); + }); + + it('one kernel, two backend', () => { + const res = getCustomModuleString( + ['MathKrnl'], ['FastBcknd', 'SlowBcknd'], forwardModeOnly, + mockModuleProvider); + + expect(res).toContain('import GRADIENT MathKrnl'); + expect(res).toContain('registerGradient(MathKrnl_GRAD_CONFIG)'); + + const gradIndex = res.indexOf('GRADIENT'); + expect(res.indexOf('GRADIENT', gradIndex + 1)) + .toBe(-1, `Gradient import appears twice in:\n ${res}`); + }); + + it('two kernels, one backend', () => { + const res = getCustomModuleString( + ['MathKrnl', 'MathKrn2'], ['FastBcknd'], forwardModeOnly, + mockModuleProvider); + + expect(res).toContain('import GRADIENT MathKrnl'); + expect(res).toContain('registerGradient(MathKrnl_GRAD_CONFIG)'); + + expect(res).toContain('import GRADIENT MathKrn2'); + expect(res).toContain('registerGradient(MathKrn2_GRAD_CONFIG)'); + }); + + it('two kernels, two backends', () => { + const res = getCustomModuleString( + ['MathKrnl', 'MathKrn2'], ['FastBcknd', 'SlowBcknd'], forwardModeOnly, + mockModuleProvider); + + expect(res).toContain('import GRADIENT MathKrnl'); + expect(res).toContain('registerGradient(MathKrnl_GRAD_CONFIG)'); + + expect(res).toContain('import GRADIENT MathKrn2'); + expect(res).toContain('registerGradient(MathKrn2_GRAD_CONFIG)'); + }); +}); diff --git a/tfjs/tools/custom_bundle/esm_module_provider.ts b/tfjs/tools/custom_bundle/esm_module_provider.ts index 136799c1652..7e3e51fd494 100644 --- a/tfjs/tools/custom_bundle/esm_module_provider.ts +++ b/tfjs/tools/custom_bundle/esm_module_provider.ts @@ -28,7 +28,6 @@ export * from '@tensorflow/tfjs-core/dist/base'; }, importConverterStr() { - // TODO(yassogba). Create a 'base' file for converter, return `export * from '@tensorflow/tfjs-converter';`; }, diff --git a/tfjs/tools/custom_bundle/esm_module_provider_test.ts b/tfjs/tools/custom_bundle/esm_module_provider_test.ts new file mode 100644 index 00000000000..0466c0f67da --- /dev/null +++ b/tfjs/tools/custom_bundle/esm_module_provider_test.ts @@ -0,0 +1,63 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {esmModuleProvider} from './esm_module_provider'; + +describe('ESM Module Provider', () => { + it('importCoreStr', () => { + const res = esmModuleProvider.importCoreStr(); + expect(res).toContain( + // tslint:disable-next-line: max-line-length + `import {registerKernel, registerGradient} from '@tensorflow/tfjs-core/dist/base'`); + expect(res).toContain(`export * from '@tensorflow/tfjs-core/dist/base';`); + }); + + // tslint:disable-next-line: ban + it('importConverterStr', () => { + const res = esmModuleProvider.importConverterStr(); + expect(res).toBe(`export * from '@tensorflow/tfjs-converter';`); + }); + + it('importBackendStr cpu', () => { + const res = esmModuleProvider.importBackendStr('cpu'); + expect(res).toBe(`export * from '@tensorflow/tfjs-backend-cpu/dist/base';`); + }); + + it('importBackendStr webgl', () => { + const res = esmModuleProvider.importBackendStr('webgl'); + expect(res).toBe( + `export * from '@tensorflow/tfjs-backend-webgl/dist/base';`); + }); + + it('importKernelStr Max cpu', () => { + const res = esmModuleProvider.importKernelStr('Max', 'cpu'); + expect(res.importStatement).toContain('import {maxConfig as Max_cpu}'); + expect(res.importStatement) + .toContain(`from '@tensorflow/tfjs-backend-cpu/dist/kernels/Max'`); + + expect(res.kernelConfigId).toBe('Max_cpu'); + }); + + it('importGradientConfigStr Max', () => { + const res = esmModuleProvider.importGradientConfigStr('Max'); + expect(res.importStatement).toContain('import {maxGradConfig}'); + expect(res.importStatement) + .toContain(`from '@tensorflow/tfjs-core/dist/gradients/Max_grad'`); + + expect(res.gradConfigId).toBe('maxGradConfig'); + }); +}); From 672c670415a11b82c0a1955c5e9db1fb33152725 Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Wed, 29 Jul 2020 22:50:29 -0400 Subject: [PATCH 08/18] add more validation --- tfjs/scripts/test-ci.sh | 1 + tfjs/tools/custom_bundle/cli.ts | 21 +++++++++++++++++++-- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/tfjs/scripts/test-ci.sh b/tfjs/scripts/test-ci.sh index 85f5fd0cb50..8316f8414d6 100755 --- a/tfjs/scripts/test-ci.sh +++ b/tfjs/scripts/test-ci.sh @@ -18,6 +18,7 @@ set -e yarn karma start --browsers='bs_firefox_mac,bs_chrome_mac' --singleRun +yarn test-tools # cd integration_tests # yarn benchmark-cloud # Reinstall the following line once https://github.com/tensorflow/tfjs/pull/1663 diff --git a/tfjs/tools/custom_bundle/cli.ts b/tfjs/tools/custom_bundle/cli.ts index 6a24eb42794..60bcbb367fb 100755 --- a/tfjs/tools/custom_bundle/cli.ts +++ b/tfjs/tools/custom_bundle/cli.ts @@ -26,7 +26,7 @@ import * as chalk from 'chalk'; import * as fs from 'fs'; import {getCustomModuleString} from './custom_module'; -import {CustomTFJSBundleConfig} from './types'; +import {CustomTFJSBundleConfig, SupportedBackends} from './types'; import {esmModuleProvider} from './esm_module_provider'; const DEFAULT_CUSTOM_BUNDLE_ARGS: Partial = { @@ -67,7 +67,24 @@ function validateArgs(): CustomTFJSBundleConfig { configFilePath}. Final config:`); console.log(`${JSON.stringify(config, null, 2)}\n`); - return Object.assign({}, DEFAULT_CUSTOM_BUNDLE_ARGS, config); + const finalConfig = Object.assign({}, DEFAULT_CUSTOM_BUNDLE_ARGS, config); + + if (finalConfig.entries.length !== 0) { + bail('Error: config.entries not yet supported'); + } + + if (finalConfig.models.length !== 0) { + bail('Error: config.models not yet supported'); + } + + for (const requestedBackend of finalConfig.backends) { + if (requestedBackend !== SupportedBackends.cpu && + requestedBackend !== SupportedBackends.webgl) { + bail(`Error: Unsupported backend specified '${requestedBackend}'`); + } + } + + return finalConfig; } function getKernelNamesForConfig(config: CustomTFJSBundleConfig) { From ffdbe6c82e673e731c8a0ee7f2255f50d989505d Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Thu, 30 Jul 2020 11:22:36 -0400 Subject: [PATCH 09/18] split out required side effects for core --- tfjs-core/package.json | 4 ++- tfjs-core/src/base.ts | 9 ------- tfjs-core/src/base_side_effects.ts | 27 +++++++++++++++++++ tfjs-core/src/index.ts | 6 +++-- .../custom_bundle/esm_module_provider.ts | 1 + .../custom_bundle/esm_module_provider_test.ts | 2 ++ 6 files changed, 37 insertions(+), 12 deletions(-) create mode 100644 tfjs-core/src/base_side_effects.ts diff --git a/tfjs-core/package.json b/tfjs-core/package.json index bb226ba8468..74ef2a84f50 100644 --- a/tfjs-core/package.json +++ b/tfjs-core/package.json @@ -94,10 +94,12 @@ }, "sideEffects": [ "./dist/index.js", + "./dist/base_side_effects.js", "./dist/engine.js", "./dist/flags.js", "./dist/platforms/*.js", "./dist/register_all_gradients.js", - "./dist/public/chained_ops/*.js" + "./dist/public/chained_ops/*.js", + "./dist/io/*.js" ] } diff --git a/tfjs-core/src/base.ts b/tfjs-core/src/base.ts index 04623c446ae..831f00dee0a 100644 --- a/tfjs-core/src/base.ts +++ b/tfjs-core/src/base.ts @@ -26,15 +26,6 @@ * setOpHandler(ops); */ -// Engine is the global singleton that needs to be initialized before the rest -// of the app. -import './engine'; -// Register backend-agnostic flags. -import './flags'; -// Register platforms -import './platforms/platform_browser'; -import './platforms/platform_node'; - // Serialization. import * as io from './io/io'; import * as math from './math'; diff --git a/tfjs-core/src/base_side_effects.ts b/tfjs-core/src/base_side_effects.ts new file mode 100644 index 00000000000..f645809a7c1 --- /dev/null +++ b/tfjs-core/src/base_side_effects.ts @@ -0,0 +1,27 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +// Required side effectful code for tfjs-core (in any build) + +// Engine is the global singleton that needs to be initialized before the rest +// of the app. +import './engine'; +// Register backend-agnostic flags. +import './flags'; +// Register platforms +import './platforms/platform_browser'; +import './platforms/platform_node'; diff --git a/tfjs-core/src/index.ts b/tfjs-core/src/index.ts index e965553b3d4..971236dbf3c 100644 --- a/tfjs-core/src/index.ts +++ b/tfjs-core/src/index.ts @@ -15,11 +15,13 @@ * ============================================================================= */ +// Required side effectful code. +import './base_side_effects'; // All exports from this package should be in base. export * from './base'; +// TEMP comment these out in this branch. In 3.x these will be removed // Register all the gradients. -import './register_all_gradients'; - +// import './register_all_gradients'; // Import all op chainers and add type info to Tensor. import './public/chained_ops/register_all_chained_ops'; diff --git a/tfjs/tools/custom_bundle/esm_module_provider.ts b/tfjs/tools/custom_bundle/esm_module_provider.ts index 7e3e51fd494..24d31cdfc5d 100644 --- a/tfjs/tools/custom_bundle/esm_module_provider.ts +++ b/tfjs/tools/custom_bundle/esm_module_provider.ts @@ -23,6 +23,7 @@ export const esmModuleProvider: ModuleProvider = { importCoreStr() { return ` import {registerKernel, registerGradient} from '@tensorflow/tfjs-core/dist/base'; +import '@tensorflow/tfjs-core/dist/base_side_effects'; export * from '@tensorflow/tfjs-core/dist/base'; `; }, diff --git a/tfjs/tools/custom_bundle/esm_module_provider_test.ts b/tfjs/tools/custom_bundle/esm_module_provider_test.ts index 0466c0f67da..6c72070eb1e 100644 --- a/tfjs/tools/custom_bundle/esm_module_provider_test.ts +++ b/tfjs/tools/custom_bundle/esm_module_provider_test.ts @@ -23,6 +23,8 @@ describe('ESM Module Provider', () => { expect(res).toContain( // tslint:disable-next-line: max-line-length `import {registerKernel, registerGradient} from '@tensorflow/tfjs-core/dist/base'`); + expect(res).toContain( + `import '@tensorflow/tfjs-core/dist/base_side_effects';`); expect(res).toContain(`export * from '@tensorflow/tfjs-core/dist/base';`); }); From a2869d2f4b00026edd00b6aa945fac8a9b49fe1b Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Fri, 31 Jul 2020 14:57:55 -0400 Subject: [PATCH 10/18] restore gradient registration --- tfjs-core/src/index.ts | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tfjs-core/src/index.ts b/tfjs-core/src/index.ts index 971236dbf3c..52090c4e992 100644 --- a/tfjs-core/src/index.ts +++ b/tfjs-core/src/index.ts @@ -20,8 +20,7 @@ import './base_side_effects'; // All exports from this package should be in base. export * from './base'; -// TEMP comment these out in this branch. In 3.x these will be removed // Register all the gradients. -// import './register_all_gradients'; +import './register_all_gradients'; // Import all op chainers and add type info to Tensor. import './public/chained_ops/register_all_chained_ops'; From c8fff269752abf1a63100c46bf299fb724f5a837 Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Mon, 3 Aug 2020 23:04:38 -0400 Subject: [PATCH 11/18] add kernel_to_ops mapper script --- tfjs-converter/metadata/kernel2op.json | 202 ++++++++++++++++++++++ tfjs-converter/package.json | 1 + tfjs-converter/scripts/kernels_to_ops.ts | 142 +++++++++++++++ tfjs-converter/yarn.lock | 209 ++++++++++++++++++++++- tsconfig.test.json | 2 +- 5 files changed, 553 insertions(+), 3 deletions(-) create mode 100644 tfjs-converter/metadata/kernel2op.json create mode 100644 tfjs-converter/scripts/kernels_to_ops.ts diff --git a/tfjs-converter/metadata/kernel2op.json b/tfjs-converter/metadata/kernel2op.json new file mode 100644 index 00000000000..384fa752884 --- /dev/null +++ b/tfjs-converter/metadata/kernel2op.json @@ -0,0 +1,202 @@ +{ + "Abs": "abs", + "Acos": "acos", + "Acosh": "acosh", + "Add": "add", + "AddN": "addN", + "AddV2": "add", + "All": "all", + "Any": "any", + "ArgMax": "argMax", + "ArgMin": "argMin", + "Asin": "asin", + "Asinh": "asinh", + "Atan": "atan", + "Atan2": "atan2", + "Atanh": "atanh", + "AvgPool": "avgPool", + "AvgPool3D": "avgPool3d", + "BatchMatMul": "matMul", + "BatchMatMulV2": "matMul", + "BatchToSpaceND": "batchToSpaceND", + "BiasAdd": "add", + "BroadcastTo": "broadcastTo", + "Cast": "cast", + "Ceil": "ceil", + "ClipByValue": "clipByValue", + "Complex": "complex", + "ComplexAbs": "abs", + "Concat": "concat", + "ConcatV2": "concat", + "Const": null, + "Conv1D": "conv1d", + "Conv2D": "conv2d", + "Conv2DBackpropInput": "conv2dTranspose", + "Conv2dTranspose": "conv2dTranspose", + "Conv3D": "conv3d", + "Cos": "cos", + "Cosh": "cosh", + "CropAndResize": "image.cropAndResize", + "Cumsum": "cumsum", + "DepthToSpace": "depthToSpace", + "DepthwiseConv2d": "depthwiseConv2d", + "DepthwiseConv2dNative": "depthwiseConv2d", + "Dilation2D": "dilation2d", + "Div": "div", + "DivNoNan": "divNoNan", + "Elu": "elu", + "Enter": null, + "Equal": "equal", + "Erf": "erf", + "Exit": null, + "Exp": "exp", + "ExpandDims": "expandDims", + "Expm1": "expm1", + "FFT": "fft", + "FakeQuantWithMinMaxVars": null, + "Fill": "fill", + "Floor": "floor", + "FloorDiv": "floorDiv", + "FloorMod": "mod", + "FusedBatchNorm": "batchNorm", + "FusedBatchNormV2": "batchNorm", + "FusedBatchNormV3": "batchNorm", + "FusedDepthwiseConv2dNative": null, + "Gather": "gather", + "GatherNd": "gatherND", + "GatherV2": "gather", + "Greater": "greater", + "GreaterEqual": "greaterEqual", + "IFFT": "ifft", + "IRFFT": "irfft", + "Identity": null, + "IdentityN": null, + "If": null, + "Imag": "imag", + "LRN": "localResponseNormalization", + "LeakyRelu": "leakyRelu", + "Less": "less", + "LessEqual": "lessEqual", + "LinSpace": "linspace", + "ListDiff": "setdiff1dAsync", + "Log": "log", + "Log1p": "log1p", + "LogSoftmax": "logSoftmax", + "LogicalAnd": "logicalAnd", + "LogicalNot": "logicalNot", + "LogicalOr": "logicalOr", + "LoopCond": null, + "MatMul": "matMul", + "Max": "max", + "MaxPool": "maxPool", + "MaxPool3D": "maxPool3d", + "MaxPoolWithArgmax": "maxPoolWithArgmax", + "Maximum": "maximum", + "Mean": "mean", + "Merge": null, + "Min": "min", + "Minimum": "minimum", + "Mod": "mod", + "Mul": "mul", + "Multinomial": "multinomial", + "Neg": "neg", + "NextIteration": null, + "NoOp": "scalar", + "NonMaxSuppressionV2": "image.nonMaxSuppressionWithScoreAsync", + "NonMaxSuppressionV3": "image.nonMaxSuppressionWithScoreAsync", + "NonMaxSuppressionV4": "image.nonMaxSuppressionWithScoreAsync", + "NonMaxSuppressionV5": "image.nonMaxSuppressionWithScoreAsync", + "NotEqual": "notEqual", + "OneHot": "oneHot", + "Ones": "ones", + "OnesLike": "onesLike", + "Pack": "tidy", + "Pad": "pad", + "PadV2": "pad", + "Placeholder": null, + "PlaceholderWithDefault": null, + "Pow": "pow", + "Prelu": "prelu", + "Print": null, + "Prod": "prod", + "RFFT": "rfft", + "RandomUniform": "randomUniform", + "Range": "range", + "Rank": "scalar", + "Real": "real", + "RealDiv": "div", + "Reciprocal": "reciprocal", + "Relu": "relu", + "Relu6": "clipByValue", + "Reshape": "reshape", + "ResizeBilinear": "image.resizeBilinear", + "ResizeNearestNeighbor": "image.resizeNearestNeighbor", + "Reverse": "reverse", + "ReverseV2": "reverse", + "Round": "round", + "Rsqrt": "rsqrt", + "ScatterNd": "scatterND", + "Select": "where", + "SelectV2": "where", + "Selu": "selu", + "Shape": "tensor1d", + "ShapeN": "tensor1d", + "Sigmoid": "sigmoid", + "Sign": "sign", + "Sin": "sin", + "Sinh": "sinh", + "Size": "scalar", + "Slice": "slice", + "Snapshot": null, + "Softmax": "softmax", + "Softplus": "softplus", + "SpaceToBatchND": "spaceToBatchND", + "SparseToDense": "sparseToDense", + "Split": "split", + "SplitV": "split", + "Sqrt": "sqrt", + "Square": "square", + "SquaredDifference": "squaredDifference", + "Squeeze": "squeeze", + "StatelessIf": null, + "StatelessWhile": null, + "StopGradient": null, + "StridedSlice": "stridedSlice", + "Sub": "sub", + "Sum": "sum", + "Switch": null, + "Tan": "tan", + "Tanh": "tanh", + "TensorArrayCloseV3": null, + "TensorArrayConcatV3": null, + "TensorArrayGatherV3": null, + "TensorArrayReadV3": null, + "TensorArrayScatterV3": null, + "TensorArraySizeV3": null, + "TensorArraySplitV3": null, + "TensorArrayV3": null, + "TensorArrayWriteV3": null, + "TensorListConcat": null, + "TensorListFromTensor": null, + "TensorListGather": null, + "TensorListGetItem": null, + "TensorListPopBack": null, + "TensorListPushBack": null, + "TensorListReserve": null, + "TensorListScatter": null, + "TensorListScatterV2": null, + "TensorListSetItem": null, + "TensorListSplit": null, + "TensorListStack": null, + "Tile": "tile", + "TopKV2": "topk", + "Transpose": "transpose", + "TruncatedNormal": "truncatedNormal", + "Unpack": "tidy", + "Where": null, + "While": null, + "Zeros": "zeros", + "ZerosLike": "zerosLike", + "_FusedConv2D": null, + "_FusedMatMul": "fused.matMul" +} \ No newline at end of file diff --git a/tfjs-converter/package.json b/tfjs-converter/package.json index 3b8e1812e5f..373a067d7ed 100644 --- a/tfjs-converter/package.json +++ b/tfjs-converter/package.json @@ -48,6 +48,7 @@ "rollup": "~2.3.2", "rollup-plugin-terser": "~5.3.0", "rollup-plugin-visualizer": "~3.3.2", + "ts-morph": "^7.1.3", "ts-node": "~8.8.2", "tslint": "~5.8.0", "tslint-no-circular-imports": "~0.5.0", diff --git a/tfjs-converter/scripts/kernels_to_ops.ts b/tfjs-converter/scripts/kernels_to_ops.ts new file mode 100644 index 00000000000..427fc7fb074 --- /dev/null +++ b/tfjs-converter/scripts/kernels_to_ops.ts @@ -0,0 +1,142 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +/** + * This script generates a mapping of Kernel Names to op names as defined by + * the converter source code. This allows a couple of things for modular builds + * 1. From a model.json file we can create imports for the ops the converter + * will call. + * 2. From those ops we could validate that the kernels we add to the modular + * build match the names of kernels in model.json (this is not necessary + * but is potentially useful for alignment). + * + * This can also be used to keep our supported ops list up to date. + * + * The approach used is to parse the source code of the converter executors + * (src/operations/executors) for the following kind pattern. + * case 'BiasAdd': + * case 'AddV2': + * case 'Add': { + * return [tfc.add( + * (getParamValue('a', node, tensorMap, context) as tfc.Tensor), + * getParamValue('b', node, tensorMap, context) as tfc.Tensor)]; + * } + * + * Case matchers represent kernel names and tfc.* represent the tfjs op that is + * called. This example shows that we need to support fallthrough case + * statements as well. + * + */ + +import * as argparse from 'argparse'; +import * as fs from 'fs'; +import {CaseClause, CaseOrDefaultClause, Project, SourceFile, SwitchStatement, SyntaxKind} from 'ts-morph'; + +const parser = new argparse.ArgumentParser(); + +parser.addArgument( + '--out', {help: 'Path to output JSON to create', required: true}); + +// initialize +const project = new Project({}); + +function getSwitchStatement(source: SourceFile): SwitchStatement { + let switchStatement: SwitchStatement; + source.forEachDescendant((node) => { + if (node.getKindName() === 'SwitchStatement') { + switchStatement = node as SwitchStatement; + } + }); + return switchStatement; +} + +function getKernelMappingForFile(source: SourceFile) { + const switchStatement = getSwitchStatement(source); + if (switchStatement === null) { + throw new Error('No switch statment found in executor'); + } + const caseClauses = switchStatement.getClauses(); + + const kernelsToOp: {[key: string]: string;} = {}; + let currentClauseGroup: string[] = []; + caseClauses.forEach((caseClause: CaseOrDefaultClause) => { + if (caseClause instanceof CaseClause) { + let kernelName; + caseClause.forEachChild(clausePart => { + const kind = clausePart.getKindName(); + if (kind === 'StringLiteral') { + kernelName = clausePart.getText().replace(/\'/g, ''); + currentClauseGroup.push(kernelName); + } + if (kind === 'Block' || kind === 'ReturnStatement') { + const callExprs = + clausePart.getDescendantsOfKind(SyntaxKind.CallExpression); + const tfcCall = callExprs.find(expr => expr.getText().match(/tfc/)); + let tfSymbol = null; + if (tfcCall != null) { + const tfcCallStr = tfcCall.getText(); + console.log('tfcCallStr', tfcCallStr); + const symbolMatcher = /(tfc\.([\w\.]*))\(/; + const matches = tfcCallStr.match(symbolMatcher); + tfSymbol = matches != null ? matches[2] : null; + } + + for (const kern of currentClauseGroup) { + kernelsToOp[kern] = tfSymbol; + } + currentClauseGroup = []; + } + }); + } + }); + + return kernelsToOp; +} + +function getKernelMapping() { + const sourceFiles = project.getSourceFiles(); + + const kernelsToOp: {[key: string]: string;} = {}; + + for (const sourceFile of sourceFiles) { + const mapping = getKernelMappingForFile(sourceFile); + Object.assign(kernelsToOp, mapping); + } + return kernelsToOp; +} + +async function run(outputFilePath: string) { + const EXECUTORS_PATH = 'src/operations/executors/*_executor.ts'; + project.addSourceFilesAtPaths(EXECUTORS_PATH); + + const kernelMapping = getKernelMapping(); + + const pairs: Array<[string, string]> = Object.entries(kernelMapping).sort(); + const sortedKernelMapping: {[key: string]: string;} = {}; + pairs.forEach(([k, v]) => { + sortedKernelMapping[k] = v; + }); + const replacer: null = null; + const space = 2; + fs.writeFileSync( + outputFilePath, JSON.stringify(sortedKernelMapping, replacer, space), + {encoding: 'utf8'}); +} + +const args = parser.parseArgs(); +console.log('Writing output to', args.out); +run(args.out); diff --git a/tfjs-converter/yarn.lock b/tfjs-converter/yarn.lock index 49b9ca0755d..a896b7b5e67 100644 --- a/tfjs-converter/yarn.lock +++ b/tfjs-converter/yarn.lock @@ -23,6 +23,35 @@ chalk "^2.0.0" js-tokens "^4.0.0" +"@dsherret/to-absolute-glob@^2.0.2": + version "2.0.2" + resolved "https://registry.yarnpkg.com/@dsherret/to-absolute-glob/-/to-absolute-glob-2.0.2.tgz#1f6475dc8bd974cea07a2daf3864b317b1dd332c" + integrity sha1-H2R13IvZdM6gei2vOGSzF7HdMyw= + dependencies: + is-absolute "^1.0.0" + is-negated-glob "^1.0.0" + +"@nodelib/fs.scandir@2.1.3": + version "2.1.3" + resolved "https://registry.yarnpkg.com/@nodelib/fs.scandir/-/fs.scandir-2.1.3.tgz#3a582bdb53804c6ba6d146579c46e52130cf4a3b" + integrity sha512-eGmwYQn3gxo4r7jdQnkrrN6bY478C3P+a/y72IJukF8LjB6ZHeB3c+Ehacj3sYeSmUXGlnA67/PmbM9CVwL7Dw== + dependencies: + "@nodelib/fs.stat" "2.0.3" + run-parallel "^1.1.9" + +"@nodelib/fs.stat@2.0.3", "@nodelib/fs.stat@^2.0.2": + version "2.0.3" + resolved "https://registry.yarnpkg.com/@nodelib/fs.stat/-/fs.stat-2.0.3.tgz#34dc5f4cabbc720f4e60f75a747e7ecd6c175bd3" + integrity sha512-bQBFruR2TAwoevBEd/NWMoAAtNGzTRgdrqnYCc7dhzfoNvqPzLyqlEQnzZ3kVnNrSp25iyxE00/3h2fqGAGArA== + +"@nodelib/fs.walk@^1.2.3": + version "1.2.4" + resolved "https://registry.yarnpkg.com/@nodelib/fs.walk/-/fs.walk-1.2.4.tgz#011b9202a70a6366e436ca5c065844528ab04976" + integrity sha512-1V9XOY4rDW0rehzbrcqAmHnz8e7SKvX27gh8Gt2WgB0+pdzdiLV83p72kZPU+jvMbS1qU5mauP2iOvO8rhmurQ== + dependencies: + "@nodelib/fs.scandir" "2.1.3" + fastq "^1.6.0" + "@protobufjs/aspromise@^1.1.1", "@protobufjs/aspromise@^1.1.2": version "1.1.2" resolved "https://registry.yarnpkg.com/@protobufjs/aspromise/-/aspromise-1.1.2.tgz#9b8b0cc663d669a7d8f6f5d0893a14d348f30fbf" @@ -125,6 +154,18 @@ version "0.0.0" uid "" +"@ts-morph/common@~0.5.2": + version "0.5.2" + resolved "https://registry.yarnpkg.com/@ts-morph/common/-/common-0.5.2.tgz#d02c2493c1e07dfd47f247b4f0b72f083fcaea3a" + integrity sha512-eLmfYV6u6gUgHrB9QV9lpuWg3cD60mhXdv0jvM5exWR/Cor8HG+GziFIj2hPEWHJknqzuU4meZd8DTqIzZfDRQ== + dependencies: + "@dsherret/to-absolute-glob" "^2.0.2" + fast-glob "^3.2.2" + fs-extra "^9.0.0" + is-negated-glob "^1.0.0" + multimatch "^4.0.0" + typescript "~3.9.7" + "@types/color-name@^1.1.1": version "1.1.1" resolved "https://registry.yarnpkg.com/@types/color-name/-/color-name-1.1.1.tgz#1c1261bbeaa10a8055bbc5d8ab84b7b2afc846a0" @@ -155,6 +196,11 @@ resolved "https://registry.yarnpkg.com/@types/long/-/long-3.0.32.tgz#f4e5af31e9e9b196d8e5fca8a5e2e20aa3d60b69" integrity sha512-ZXyOOm83p7X8p3s0IYM3VeueNmHpkk/yMlP8CLeOnEcu6hIwPH7YjZBvhQkR0ZFS2DqZAxKtJ/M5fcuv3OU5BA== +"@types/minimatch@^3.0.3": + version "3.0.3" + resolved "https://registry.yarnpkg.com/@types/minimatch/-/minimatch-3.0.3.tgz#3dca0e3f33b200fc7d1139c0cd96c1268cadfd9d" + integrity sha512-tHq6qdbT9U1IRSGf14CL0pUlULksvY9OZ+5eEgl1N7t+OA3tGvNpxJCzuKQlsNgCVwbAs670L1vcVQi8j9HjnA== + "@types/node-fetch@1.6.9": version "1.6.9" resolved "https://registry.yarnpkg.com/@types/node-fetch/-/node-fetch-1.6.9.tgz#a750fb0f4cf2960bf72b462e4c86908022dd69c5" @@ -348,6 +394,11 @@ arr-union@^3.1.0: resolved "https://registry.yarnpkg.com/arr-union/-/arr-union-3.1.0.tgz#e39b09aea9def866a8f206e288af63919bae39c4" integrity sha1-45sJrqne+Gao8gbiiK9jkZuuOcQ= +array-differ@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/array-differ/-/array-differ-3.0.0.tgz#3cbb3d0f316810eafcc47624734237d6aee4ae6b" + integrity sha512-THtfYS6KtME/yIAhKjZ2ul7XI96lQGHRputJQHO80LAWQnuGP4iCIN8vdMRboGbIEYBwU33q8Tch1os2+X0kMg== + array-find-index@^1.0.1: version "1.0.2" resolved "https://registry.yarnpkg.com/array-find-index/-/array-find-index-1.0.2.tgz#df010aa1287e164bbda6f9723b0a96a1ec4187a1" @@ -360,6 +411,11 @@ array-union@^1.0.1: dependencies: array-uniq "^1.0.1" +array-union@^2.1.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/array-union/-/array-union-2.1.0.tgz#b798420adbeb1de828d84acd8a2e23d3efe85e8d" + integrity sha512-HGyxoOTYUyCM6stUe6EJgnd4EoewAI7zMdfqO+kGjnlZmBDz/cR5pf8r/cR4Wq60sL/p0IkcjUEEPwS3GFrIyw== + array-uniq@^1.0.1: version "1.0.3" resolved "https://registry.yarnpkg.com/array-uniq/-/array-uniq-1.0.3.tgz#af6ac877a25cc7f74e058894753858dfdb24fdb6" @@ -380,6 +436,11 @@ arrify@^1.0.0: resolved "https://registry.yarnpkg.com/arrify/-/arrify-1.0.1.tgz#898508da2226f380df904728456849c1501a4b0d" integrity sha1-iYUI2iIm84DfkEcoRWhJwVAaSw0= +arrify@^2.0.1: + version "2.0.1" + resolved "https://registry.yarnpkg.com/arrify/-/arrify-2.0.1.tgz#c9655e9331e0abcd588d2a7cad7e9956f66701fa" + integrity sha512-3duEwti880xqi4eAMN8AyR4a0ByT90zoYdLlevfrvU43vb0YZwZVfxOgxWrLXXXpyugL0hNZc9G6BiB5B3nUug== + asn1.js@^4.0.0: version "4.10.1" resolved "https://registry.yarnpkg.com/asn1.js/-/asn1.js-4.10.1.tgz#b9c2bf5805f1e64aadeed6df3a2bfafb5a73f5a0" @@ -424,6 +485,11 @@ async@^2.6.1, async@^2.6.2: dependencies: lodash "^4.17.11" +at-least-node@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/at-least-node/-/at-least-node-1.0.0.tgz#602cd4b46e844ad4effc92a8011a3c46e0238dc2" + integrity sha512-+q/t7Ekv1EDY2l6Gda6LLiX14rU9TV20Wa3ofeQmwPFZbOMo9DXrLbOjFaaclkXKWidIaopwAObQDqwWtGUjqg== + atob@^2.1.1: version "2.1.2" resolved "https://registry.yarnpkg.com/atob/-/atob-2.1.2.tgz#6d9517eb9e030d2436666651e86bd9f6f13533c9" @@ -1395,6 +1461,11 @@ clone@^1.0.2: resolved "https://registry.yarnpkg.com/clone/-/clone-1.0.4.tgz#da309cc263df15994c688ca902179ca3c7cd7c7e" integrity sha1-2jCcwmPfFZlMaIypAheco8fNfH4= +code-block-writer@^10.1.0: + version "10.1.0" + resolved "https://registry.yarnpkg.com/code-block-writer/-/code-block-writer-10.1.0.tgz#54fc410ebef2af836d9c2314ac40af7d7b37eee9" + integrity sha512-RG9hpXtWFeUWhuUav1YuP/vGcyncW+t90yJLk9fNZs1De2OuHTHKAKThVCokt29PYq5RoJ0QSZaIZ+rvPO23hA== + code-point-at@^1.0.0: version "1.1.0" resolved "https://registry.yarnpkg.com/code-point-at/-/code-point-at-1.1.0.tgz#0d070b4d043a5bea33a2f1a40e2edb3d9a4ccf77" @@ -2048,6 +2119,18 @@ fast-deep-equal@^1.0.0: resolved "https://registry.yarnpkg.com/fast-deep-equal/-/fast-deep-equal-1.1.0.tgz#c053477817c86b51daa853c81e059b733d023614" integrity sha1-wFNHeBfIa1HaqFPIHgWbcz0CNhQ= +fast-glob@^3.2.2: + version "3.2.4" + resolved "https://registry.yarnpkg.com/fast-glob/-/fast-glob-3.2.4.tgz#d20aefbf99579383e7f3cc66529158c9b98554d3" + integrity sha512-kr/Oo6PX51265qeuCYsyGypiO5uJFgBS0jksyG7FUeCyQzNwYnzrNIMR1NXfkZXsMYXYLRAHgISHBz8gQcxKHQ== + dependencies: + "@nodelib/fs.stat" "^2.0.2" + "@nodelib/fs.walk" "^1.2.3" + glob-parent "^5.1.0" + merge2 "^1.3.0" + micromatch "^4.0.2" + picomatch "^2.2.1" + fast-json-stable-stringify@^2.0.0: version "2.0.0" resolved "https://registry.yarnpkg.com/fast-json-stable-stringify/-/fast-json-stable-stringify-2.0.0.tgz#d5142c0caee6b1189f87d3a76111064f86c8bbf2" @@ -2058,6 +2141,13 @@ fast-levenshtein@~2.0.4: resolved "https://registry.yarnpkg.com/fast-levenshtein/-/fast-levenshtein-2.0.6.tgz#3d8a5c66883a16a30ca8643e851f19baa7797917" integrity sha1-PYpcZog6FqMMqGQ+hR8Zuqd5eRc= +fastq@^1.6.0: + version "1.8.0" + resolved "https://registry.yarnpkg.com/fastq/-/fastq-1.8.0.tgz#550e1f9f59bbc65fe185cb6a9b4d95357107f481" + integrity sha512-SMIZoZdLh/fgofivvIkmknUXyPnvxRE3DhtZ5Me3Mrsk5gyPL42F0xr51TdRXskBxHfMp+07bcYzfsYEsSQA9Q== + dependencies: + reusify "^1.0.4" + fill-range@^4.0.0: version "4.0.0" resolved "https://registry.yarnpkg.com/fill-range/-/fill-range-4.0.0.tgz#d544811d428f98eb06a63dc402d2403c328c38f7" @@ -2153,6 +2243,16 @@ fs-extra@^7.0.1: jsonfile "^4.0.0" universalify "^0.1.0" +fs-extra@^9.0.0: + version "9.0.1" + resolved "https://registry.yarnpkg.com/fs-extra/-/fs-extra-9.0.1.tgz#910da0062437ba4c39fedd863f1675ccfefcb9fc" + integrity sha512-h2iAoN838FqAFJY2/qVpzFXy+EBxfVE220PalAqQLDVsFOHLJrZvut5puAbCdNv6WJk+B8ihI+k0c7JK5erwqQ== + dependencies: + at-least-node "^1.0.0" + graceful-fs "^4.2.0" + jsonfile "^6.0.1" + universalify "^1.0.0" + fs-minipass@^1.2.5: version "1.2.5" resolved "https://registry.yarnpkg.com/fs-minipass/-/fs-minipass-1.2.5.tgz#06c277218454ec288df77ada54a03b8702aacb9d" @@ -2230,6 +2330,13 @@ glob-parent@^3.1.0: is-glob "^3.1.0" path-dirname "^1.0.0" +glob-parent@^5.1.0: + version "5.1.1" + resolved "https://registry.yarnpkg.com/glob-parent/-/glob-parent-5.1.1.tgz#b6c1ef417c4e5663ea498f1c45afac6916bbc229" + integrity sha512-FnI+VGOpnlGHWZxthPGR+QhR78fuiK0sNLkHQv+bL9fQi57lNNdquIbna/WrfROrolq8GK5Ek6BiMwqL/voRYQ== + dependencies: + is-glob "^4.0.1" + glob@^5.0.15: version "5.0.15" resolved "https://registry.yarnpkg.com/glob/-/glob-5.0.15.tgz#1bc936b9e02f4a603fcc222ecf7633d30b8b93b1" @@ -2299,6 +2406,11 @@ graceful-fs@^4.1.11, graceful-fs@^4.1.15, graceful-fs@^4.1.2, graceful-fs@^4.1.6 resolved "https://registry.yarnpkg.com/graceful-fs/-/graceful-fs-4.1.15.tgz#ffb703e1066e8a0eeaa4c8b80ba9253eeefbfb00" integrity sha512-6uHUhOPEBgQ24HM+r6b/QwWfZq+yiFcipKFrOFiBEnWdy5sdzYoi+pJeQaPI5qOLRFqWmAXUPQNsielzdLoecA== +graceful-fs@^4.2.0: + version "4.2.4" + resolved "https://registry.yarnpkg.com/graceful-fs/-/graceful-fs-4.2.4.tgz#2256bde14d3632958c465ebc96dc467ca07a29fb" + integrity sha512-WjKPNJF79dtJAVniUlGGWHYGz2jWxT6VhN/4m1NdkbZ2nOsEF+cI1Edgql5zCRhs/VsQYRvrXctxktVXZUkixw== + handlebars@^4.0.1: version "4.1.2" resolved "https://registry.yarnpkg.com/handlebars/-/handlebars-4.1.2.tgz#b6b37c1ced0306b221e094fc7aca3ec23b131b67" @@ -2530,6 +2642,14 @@ invert-kv@^1.0.0: resolved "https://registry.yarnpkg.com/invert-kv/-/invert-kv-1.0.0.tgz#104a8e4aaca6d3d8cd157a8ef8bfab2d7a3ffdb6" integrity sha1-EEqOSqym09jNFXqO+L+rLXo//bY= +is-absolute@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/is-absolute/-/is-absolute-1.0.0.tgz#395e1ae84b11f26ad1795e73c17378e48a301576" + integrity sha512-dOWoqflvcydARa360Gvv18DZ/gRuHKi2NU/wU5X1ZFzdYfH29nkiNZsF3mp4OJ3H4yo9Mx8A/uAGNzpzPN3yBA== + dependencies: + is-relative "^1.0.0" + is-windows "^1.0.1" + is-accessor-descriptor@^0.1.6: version "0.1.6" resolved "https://registry.yarnpkg.com/is-accessor-descriptor/-/is-accessor-descriptor-0.1.6.tgz#a9e12cb3ae8d876727eeef3843f8a0897b5c98d6" @@ -2648,11 +2768,23 @@ is-glob@^4.0.0: dependencies: is-extglob "^2.1.1" +is-glob@^4.0.1: + version "4.0.1" + resolved "https://registry.yarnpkg.com/is-glob/-/is-glob-4.0.1.tgz#7567dbe9f2f5e2467bc77ab83c4a29482407a5dc" + integrity sha512-5G0tKtBTFImOqDnLB2hG6Bp2qcKEFduo4tZu9MT/H6NQv/ghhy30o55ufafxJ/LdH79LLs2Kfrn85TLKyA7BUg== + dependencies: + is-extglob "^2.1.1" + is-module@^1.0.0: version "1.0.0" resolved "https://registry.yarnpkg.com/is-module/-/is-module-1.0.0.tgz#3258fb69f78c14d5b815d664336b4cffb6441591" integrity sha1-Mlj7afeMFNW4FdZkM2tM/7ZEFZE= +is-negated-glob@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/is-negated-glob/-/is-negated-glob-1.0.0.tgz#6910bca5da8c95e784b5751b976cf5a10fee36d2" + integrity sha1-aRC8pdqMleeEtXUbl2z1oQ/uNtI= + is-number@^3.0.0: version "3.0.0" resolved "https://registry.yarnpkg.com/is-number/-/is-number-3.0.0.tgz#24fd6201a4782cf50561c810276afc7d12d71195" @@ -2698,12 +2830,26 @@ is-reference@^1.1.2: dependencies: "@types/estree" "0.0.39" +is-relative@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/is-relative/-/is-relative-1.0.0.tgz#a1bb6935ce8c5dba1e8b9754b9b2dcc020e2260d" + integrity sha512-Kw/ReK0iqwKeu0MITLFuj0jbPAmEiOsIwyIXvvbfa6QfmN9pkD1M+8pdk7Rl/dTKbH34/XBFMbgD4iMJhLQbGA== + dependencies: + is-unc-path "^1.0.0" + +is-unc-path@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/is-unc-path/-/is-unc-path-1.0.0.tgz#d731e8898ed090a12c352ad2eaed5095ad322c9d" + integrity sha512-mrGpVd0fs7WWLfVsStvgF6iEJnbjDFZh9/emhRDcGWTduTfNHd9CHeUwH3gYIjdbwo4On6hunkztwOaAw0yllQ== + dependencies: + unc-path-regex "^0.1.2" + is-utf8@^0.2.0: version "0.2.1" resolved "https://registry.yarnpkg.com/is-utf8/-/is-utf8-0.2.1.tgz#4b0da1442104d1b336340e80797e865cf39f7d72" integrity sha1-Sw2hRCEE0bM2NA6AeX6GXPOffXI= -is-windows@^1.0.2: +is-windows@^1.0.1, is-windows@^1.0.2: version "1.0.2" resolved "https://registry.yarnpkg.com/is-windows/-/is-windows-1.0.2.tgz#d1850eb9791ecd18e6182ce12a30f396634bb19d" integrity sha512-eXK1UInq2bPmjyX6e3VHIzMLobc4J94i4AWn+Hpq3OU5KkrRC96OAcR3PRJ/pGu6m8TRnBHP9dkXQVsT/COVIA== @@ -2840,6 +2986,15 @@ jsonfile@^4.0.0: optionalDependencies: graceful-fs "^4.1.6" +jsonfile@^6.0.1: + version "6.0.1" + resolved "https://registry.yarnpkg.com/jsonfile/-/jsonfile-6.0.1.tgz#98966cba214378c8c84b82e085907b40bf614179" + integrity sha512-jR2b5v7d2vIOust+w3wtFKZIfpC2pnRmFAhAC/BuweZFQR8qZzxH1OyrQ10HmdVYiXWkYUqPVsz91cG7EL2FBg== + dependencies: + universalify "^1.0.0" + optionalDependencies: + graceful-fs "^4.1.6" + karma-browserstack-launcher@~1.4.0: version "1.4.0" resolved "https://registry.yarnpkg.com/karma-browserstack-launcher/-/karma-browserstack-launcher-1.4.0.tgz#22f92e969d2db6cfc00e578708bda39378d5f2ab" @@ -3156,6 +3311,11 @@ merge-stream@^2.0.0: resolved "https://registry.yarnpkg.com/merge-stream/-/merge-stream-2.0.0.tgz#52823629a14dd00c9770fb6ad47dc6310f2c1f60" integrity sha512-abv/qOcuPfk3URPfDzmZU1LKmuw8kT+0nIHvKrKgFrwifol/doWcdA4ZqsWQ8ENrFKkd67Mfpo/LovbIUsbt3w== +merge2@^1.3.0: + version "1.4.1" + resolved "https://registry.yarnpkg.com/merge2/-/merge2-1.4.1.tgz#4368892f885e907455a6fd7dc55c0c9d404990ae" + integrity sha512-8q7VEgMJW4J8tcfVPy8g09NcQwZdbwFEqhe/WZkoIzjn/3TGDwtOCYtXGxA3O8tPzpczCCDgv+P2P5y00ZJOOg== + micromatch@^3.1.10, micromatch@^3.1.4: version "3.1.10" resolved "https://registry.yarnpkg.com/micromatch/-/micromatch-3.1.10.tgz#70859bc95c9840952f359a068a3fc49f9ecfac23" @@ -3280,6 +3440,17 @@ ms@^2.1.1: resolved "https://registry.yarnpkg.com/ms/-/ms-2.1.1.tgz#30a5864eb3ebb0a66f2ebe6d727af06a09d86e0a" integrity sha512-tgp+dl5cGk28utYktBsrFqA7HKgrhgPsg6Z/EfhWI4gl1Hwq8B/GmY/0oXZ6nF8hDVesS/FpnYaD/kOWhYQvyg== +multimatch@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/multimatch/-/multimatch-4.0.0.tgz#8c3c0f6e3e8449ada0af3dd29efb491a375191b3" + integrity sha512-lDmx79y1z6i7RNx0ZGCPq1bzJ6ZoDDKbvh7jxr9SJcWLkShMzXrHbYVpTdnhNM5MXpDUxCQ4DgqVttVXlBgiBQ== + dependencies: + "@types/minimatch" "^3.0.3" + array-differ "^3.0.0" + array-union "^2.1.0" + arrify "^2.0.1" + minimatch "^3.0.4" + nan@^2.9.2: version "2.12.1" resolved "https://registry.yarnpkg.com/nan/-/nan-2.12.1.tgz#7b1aa193e9aa86057e3c7bbd0ac448e770925552" @@ -3675,7 +3846,7 @@ pbkdf2@^3.0.3: safe-buffer "^5.0.1" sha.js "^2.4.8" -picomatch@^2.0.5: +picomatch@^2.0.5, picomatch@^2.2.1: version "2.2.2" resolved "https://registry.yarnpkg.com/picomatch/-/picomatch-2.2.2.tgz#21f333e9b6b8eaff02468f5146ea406d345f4dad" integrity sha512-q0M/9eZHzmr0AulXyPwNfZjtwZ/RBZlbN3K3CErVrk50T2ASYI7Bye0EvekFY3IP1Nt2DHu0re+V2ZHIpMkuWg== @@ -4053,6 +4224,11 @@ ret@~0.1.10: resolved "https://registry.yarnpkg.com/ret/-/ret-0.1.15.tgz#b8a4825d5bdb1fc3f6f53c2bc33f81388681c7bc" integrity sha512-TTlYpa+OL+vMMNG24xSlQGEJ3B/RzEfUlLct7b5G/ytav+wPrplCpVMFuwzXbkecJrb6IYo1iFb0S9v37754mg== +reusify@^1.0.4: + version "1.0.4" + resolved "https://registry.yarnpkg.com/reusify/-/reusify-1.0.4.tgz#90da382b1e126efc02146e90845a88db12925d76" + integrity sha512-U9nH88a3fc/ekCF1l0/UP1IosiuIjyTh7hBvXVMHYgVcfGvt897Xguj2UOLDeI5BG2m7/uwyaLVT6fbtCwTyzw== + rfdc@^1.1.2: version "1.1.4" resolved "https://registry.yarnpkg.com/rfdc/-/rfdc-1.1.4.tgz#ba72cc1367a0ccd9cf81a870b3b58bd3ad07f8c2" @@ -4117,6 +4293,11 @@ rollup@~2.3.2: optionalDependencies: fsevents "~2.1.2" +run-parallel@^1.1.9: + version "1.1.9" + resolved "https://registry.yarnpkg.com/run-parallel/-/run-parallel-1.1.9.tgz#c9dd3a7cf9f4b2c4b6244e173a6ed866e61dd679" + integrity sha512-DEqnSRTDw/Tc3FXf49zedI638Z9onwUotBMiUFKmrO2sdFKIbXamXGQ3Axd4qgphxKB4kw/qP1w5kTxnfU1B9Q== + safe-buffer@^5.0.1, safe-buffer@^5.1.0, safe-buffer@^5.1.1, safe-buffer@^5.1.2, safe-buffer@~5.1.0, safe-buffer@~5.1.1: version "5.1.2" resolved "https://registry.yarnpkg.com/safe-buffer/-/safe-buffer-5.1.2.tgz#991ec69d296e0313747d59bdfd2b745c35f8828d" @@ -4651,6 +4832,15 @@ trim-right@^1.0.1: resolved "https://registry.yarnpkg.com/trim-right/-/trim-right-1.0.1.tgz#cb2e1203067e0c8de1f614094b9fe45704ea6003" integrity sha1-yy4SAwZ+DI3h9hQJS5/kVwTqYAM= +ts-morph@^7.1.3: + version "7.1.3" + resolved "https://registry.yarnpkg.com/ts-morph/-/ts-morph-7.1.3.tgz#b1bdda66e49500227453ef19874bbf9e0f359d04" + integrity sha512-NlfQolw+IT+gFDnRPrkce9h427d6/Vea8S2YdV5C6aAg2CGuWzmkMfcLKmrDqIkU8ZXcI3KqmJAqiaLhEcPXWQ== + dependencies: + "@dsherret/to-absolute-glob" "^2.0.2" + "@ts-morph/common" "~0.5.2" + code-block-writer "^10.1.0" + ts-node@~8.8.2: version "8.8.2" resolved "https://registry.yarnpkg.com/ts-node/-/ts-node-8.8.2.tgz#0b39e690bee39ea5111513a9d2bcdc0bc121755f" @@ -4721,6 +4911,11 @@ typescript@3.5.3: resolved "https://registry.yarnpkg.com/typescript/-/typescript-3.5.3.tgz#c830f657f93f1ea846819e929092f5fe5983e977" integrity sha512-ACzBtm/PhXBDId6a6sDJfroT2pOWt/oOnk4/dElG5G33ZL776N3Y6/6bKZJBFpd+b05F3Ct9qDjMeJmRWtE2/g== +typescript@~3.9.7: + version "3.9.7" + resolved "https://registry.yarnpkg.com/typescript/-/typescript-3.9.7.tgz#98d600a5ebdc38f40cb277522f12dc800e9e25fa" + integrity sha512-BLbiRkiBzAwsjut4x/dsibSTB6yWpwT5qWmC2OfuCg3GgVQCSgMs4vEctYPhsaGtd0AeuuHMkjZ2h2WG8MSzRw== + uglify-js@^3.1.4: version "3.6.0" resolved "https://registry.yarnpkg.com/uglify-js/-/uglify-js-3.6.0.tgz#704681345c53a8b2079fb6cec294b05ead242ff5" @@ -4734,6 +4929,11 @@ ultron@~1.1.0: resolved "https://registry.yarnpkg.com/ultron/-/ultron-1.1.1.tgz#9fe1536a10a664a65266a1e3ccf85fd36302bc9c" integrity sha512-UIEXBNeYmKptWH6z8ZnqTeS8fV74zG0/eRU9VGkpzz+LIJNs8W/zM/L+7ctCkRrgbNnnR0xxw4bKOr0cW0N0Og== +unc-path-regex@^0.1.2: + version "0.1.2" + resolved "https://registry.yarnpkg.com/unc-path-regex/-/unc-path-regex-0.1.2.tgz#e73dd3d7b0d7c5ed86fbac6b0ae7d8c6a69d50fa" + integrity sha1-5z3T17DXxe2G+6xrCufYxqadUPo= + union-value@^1.0.0: version "1.0.0" resolved "https://registry.yarnpkg.com/union-value/-/union-value-1.0.0.tgz#5c71c34cb5bad5dcebe3ea0cd08207ba5aa1aea4" @@ -4749,6 +4949,11 @@ universalify@^0.1.0: resolved "https://registry.yarnpkg.com/universalify/-/universalify-0.1.2.tgz#b646f69be3942dabcecc9d6639c80dc105efaa66" integrity sha512-rBJeI5CXAlmy1pV+617WB9J63U6XcazHHF2f2dbJix4XzpUF0RS3Zbj0FGIOCAva5P/d/GBOYaACQ1w+0azUkg== +universalify@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/universalify/-/universalify-1.0.0.tgz#b61a1da173e8435b2fe3c67d29b9adf8594bd16d" + integrity sha512-rb6X1W158d7pRQBg5gkR8uPaSfiids68LTJQYOtEUhoJUWBdaQHsuT/EUduxXYxcrt4r5PJ4fuHW1MHT6p0qug== + unpipe@1.0.0, unpipe@~1.0.0: version "1.0.0" resolved "https://registry.yarnpkg.com/unpipe/-/unpipe-1.0.0.tgz#b2bf4ee8514aae6165b4817829d21b2ef49904ec" diff --git a/tsconfig.test.json b/tsconfig.test.json index 11126cd76d2..8e3496aced3 100644 --- a/tsconfig.test.json +++ b/tsconfig.test.json @@ -10,7 +10,7 @@ "declaration": false, "target": "es5", "lib": [ - "es2015", + "es2017", "dom" ], "outDir": "./dist", From f8bb30162621c1598cc2113a3b993c2e4d0c4fd5 Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Tue, 4 Aug 2020 00:15:22 -0400 Subject: [PATCH 12/18] allow multiple tfjs op calls per executor block --- tfjs-converter/metadata/kernel2op.json | 718 ++++++++++++++++------- tfjs-converter/scripts/kernels_to_ops.ts | 32 +- 2 files changed, 537 insertions(+), 213 deletions(-) diff --git a/tfjs-converter/metadata/kernel2op.json b/tfjs-converter/metadata/kernel2op.json index 384fa752884..8a5de115876 100644 --- a/tfjs-converter/metadata/kernel2op.json +++ b/tfjs-converter/metadata/kernel2op.json @@ -1,202 +1,520 @@ { - "Abs": "abs", - "Acos": "acos", - "Acosh": "acosh", - "Add": "add", - "AddN": "addN", - "AddV2": "add", - "All": "all", - "Any": "any", - "ArgMax": "argMax", - "ArgMin": "argMin", - "Asin": "asin", - "Asinh": "asinh", - "Atan": "atan", - "Atan2": "atan2", - "Atanh": "atanh", - "AvgPool": "avgPool", - "AvgPool3D": "avgPool3d", - "BatchMatMul": "matMul", - "BatchMatMulV2": "matMul", - "BatchToSpaceND": "batchToSpaceND", - "BiasAdd": "add", - "BroadcastTo": "broadcastTo", - "Cast": "cast", - "Ceil": "ceil", - "ClipByValue": "clipByValue", - "Complex": "complex", - "ComplexAbs": "abs", - "Concat": "concat", - "ConcatV2": "concat", - "Const": null, - "Conv1D": "conv1d", - "Conv2D": "conv2d", - "Conv2DBackpropInput": "conv2dTranspose", - "Conv2dTranspose": "conv2dTranspose", - "Conv3D": "conv3d", - "Cos": "cos", - "Cosh": "cosh", - "CropAndResize": "image.cropAndResize", - "Cumsum": "cumsum", - "DepthToSpace": "depthToSpace", - "DepthwiseConv2d": "depthwiseConv2d", - "DepthwiseConv2dNative": "depthwiseConv2d", - "Dilation2D": "dilation2d", - "Div": "div", - "DivNoNan": "divNoNan", - "Elu": "elu", - "Enter": null, - "Equal": "equal", - "Erf": "erf", - "Exit": null, - "Exp": "exp", - "ExpandDims": "expandDims", - "Expm1": "expm1", - "FFT": "fft", - "FakeQuantWithMinMaxVars": null, - "Fill": "fill", - "Floor": "floor", - "FloorDiv": "floorDiv", - "FloorMod": "mod", - "FusedBatchNorm": "batchNorm", - "FusedBatchNormV2": "batchNorm", - "FusedBatchNormV3": "batchNorm", - "FusedDepthwiseConv2dNative": null, - "Gather": "gather", - "GatherNd": "gatherND", - "GatherV2": "gather", - "Greater": "greater", - "GreaterEqual": "greaterEqual", - "IFFT": "ifft", - "IRFFT": "irfft", - "Identity": null, - "IdentityN": null, - "If": null, - "Imag": "imag", - "LRN": "localResponseNormalization", - "LeakyRelu": "leakyRelu", - "Less": "less", - "LessEqual": "lessEqual", - "LinSpace": "linspace", - "ListDiff": "setdiff1dAsync", - "Log": "log", - "Log1p": "log1p", - "LogSoftmax": "logSoftmax", - "LogicalAnd": "logicalAnd", - "LogicalNot": "logicalNot", - "LogicalOr": "logicalOr", - "LoopCond": null, - "MatMul": "matMul", - "Max": "max", - "MaxPool": "maxPool", - "MaxPool3D": "maxPool3d", - "MaxPoolWithArgmax": "maxPoolWithArgmax", - "Maximum": "maximum", - "Mean": "mean", - "Merge": null, - "Min": "min", - "Minimum": "minimum", - "Mod": "mod", - "Mul": "mul", - "Multinomial": "multinomial", - "Neg": "neg", - "NextIteration": null, - "NoOp": "scalar", - "NonMaxSuppressionV2": "image.nonMaxSuppressionWithScoreAsync", - "NonMaxSuppressionV3": "image.nonMaxSuppressionWithScoreAsync", - "NonMaxSuppressionV4": "image.nonMaxSuppressionWithScoreAsync", - "NonMaxSuppressionV5": "image.nonMaxSuppressionWithScoreAsync", - "NotEqual": "notEqual", - "OneHot": "oneHot", - "Ones": "ones", - "OnesLike": "onesLike", - "Pack": "tidy", - "Pad": "pad", - "PadV2": "pad", - "Placeholder": null, - "PlaceholderWithDefault": null, - "Pow": "pow", - "Prelu": "prelu", - "Print": null, - "Prod": "prod", - "RFFT": "rfft", - "RandomUniform": "randomUniform", - "Range": "range", - "Rank": "scalar", - "Real": "real", - "RealDiv": "div", - "Reciprocal": "reciprocal", - "Relu": "relu", - "Relu6": "clipByValue", - "Reshape": "reshape", - "ResizeBilinear": "image.resizeBilinear", - "ResizeNearestNeighbor": "image.resizeNearestNeighbor", - "Reverse": "reverse", - "ReverseV2": "reverse", - "Round": "round", - "Rsqrt": "rsqrt", - "ScatterNd": "scatterND", - "Select": "where", - "SelectV2": "where", - "Selu": "selu", - "Shape": "tensor1d", - "ShapeN": "tensor1d", - "Sigmoid": "sigmoid", - "Sign": "sign", - "Sin": "sin", - "Sinh": "sinh", - "Size": "scalar", - "Slice": "slice", - "Snapshot": null, - "Softmax": "softmax", - "Softplus": "softplus", - "SpaceToBatchND": "spaceToBatchND", - "SparseToDense": "sparseToDense", - "Split": "split", - "SplitV": "split", - "Sqrt": "sqrt", - "Square": "square", - "SquaredDifference": "squaredDifference", - "Squeeze": "squeeze", - "StatelessIf": null, - "StatelessWhile": null, - "StopGradient": null, - "StridedSlice": "stridedSlice", - "Sub": "sub", - "Sum": "sum", - "Switch": null, - "Tan": "tan", - "Tanh": "tanh", - "TensorArrayCloseV3": null, - "TensorArrayConcatV3": null, - "TensorArrayGatherV3": null, - "TensorArrayReadV3": null, - "TensorArrayScatterV3": null, - "TensorArraySizeV3": null, - "TensorArraySplitV3": null, - "TensorArrayV3": null, - "TensorArrayWriteV3": null, - "TensorListConcat": null, - "TensorListFromTensor": null, - "TensorListGather": null, - "TensorListGetItem": null, - "TensorListPopBack": null, - "TensorListPushBack": null, - "TensorListReserve": null, - "TensorListScatter": null, - "TensorListScatterV2": null, - "TensorListSetItem": null, - "TensorListSplit": null, - "TensorListStack": null, - "Tile": "tile", - "TopKV2": "topk", - "Transpose": "transpose", - "TruncatedNormal": "truncatedNormal", - "Unpack": "tidy", - "Where": null, - "While": null, - "Zeros": "zeros", - "ZerosLike": "zerosLike", - "_FusedConv2D": null, - "_FusedMatMul": "fused.matMul" + "Abs": [ + "abs" + ], + "Acos": [ + "acos" + ], + "Acosh": [ + "acosh" + ], + "Add": [ + "add" + ], + "AddN": [ + "addN" + ], + "AddV2": [ + "add" + ], + "All": [ + "all" + ], + "Any": [ + "any" + ], + "ArgMax": [ + "argMax" + ], + "ArgMin": [ + "argMin" + ], + "Asin": [ + "asin" + ], + "Asinh": [ + "asinh" + ], + "Atan": [ + "atan" + ], + "Atan2": [ + "atan2" + ], + "Atanh": [ + "atanh" + ], + "AvgPool": [ + "avgPool" + ], + "AvgPool3D": [ + "avgPool3d" + ], + "BatchMatMul": [ + "matMul" + ], + "BatchMatMulV2": [ + "matMul" + ], + "BatchToSpaceND": [ + "batchToSpaceND" + ], + "BiasAdd": [ + "add" + ], + "BroadcastTo": [ + "broadcastTo" + ], + "Cast": [ + "cast" + ], + "Ceil": [ + "ceil" + ], + "ClipByValue": [ + "clipByValue" + ], + "Complex": [ + "complex" + ], + "ComplexAbs": [ + "abs" + ], + "Concat": [ + "concat" + ], + "ConcatV2": [ + "concat" + ], + "Const": [], + "Conv1D": [ + "conv1d" + ], + "Conv2D": [ + "conv2d" + ], + "Conv2DBackpropInput": [ + "conv2dTranspose" + ], + "Conv2dTranspose": [ + "conv2dTranspose" + ], + "Conv3D": [ + "conv3d" + ], + "Cos": [ + "cos" + ], + "Cosh": [ + "cosh" + ], + "CropAndResize": [ + "image.cropAndResize" + ], + "Cumsum": [ + "cumsum" + ], + "DepthToSpace": [ + "depthToSpace" + ], + "DepthwiseConv2d": [ + "depthwiseConv2d" + ], + "DepthwiseConv2dNative": [ + "depthwiseConv2d" + ], + "Dilation2D": [ + "dilation2d" + ], + "Div": [ + "div" + ], + "DivNoNan": [ + "divNoNan" + ], + "Elu": [ + "elu" + ], + "Enter": [], + "Equal": [ + "equal" + ], + "Erf": [ + "erf" + ], + "Exit": [], + "Exp": [ + "exp" + ], + "ExpandDims": [ + "expandDims" + ], + "Expm1": [ + "expm1" + ], + "FFT": [ + "fft" + ], + "FakeQuantWithMinMaxVars": [], + "Fill": [ + "fill" + ], + "Floor": [ + "floor" + ], + "FloorDiv": [ + "floorDiv" + ], + "FloorMod": [ + "mod" + ], + "FusedBatchNorm": [ + "batchNorm" + ], + "FusedBatchNormV2": [ + "batchNorm" + ], + "FusedBatchNormV3": [ + "batchNorm" + ], + "FusedDepthwiseConv2dNative": [], + "Gather": [ + "gather" + ], + "GatherNd": [ + "gatherND" + ], + "GatherV2": [ + "gather" + ], + "Greater": [ + "greater" + ], + "GreaterEqual": [ + "greaterEqual" + ], + "IFFT": [ + "ifft" + ], + "IRFFT": [ + "irfft" + ], + "Identity": [], + "IdentityN": [], + "If": [], + "Imag": [ + "imag" + ], + "LRN": [ + "localResponseNormalization" + ], + "LeakyRelu": [ + "leakyRelu" + ], + "Less": [ + "less" + ], + "LessEqual": [ + "lessEqual" + ], + "LinSpace": [ + "linspace" + ], + "ListDiff": [ + "setdiff1dAsync" + ], + "Log": [ + "log" + ], + "Log1p": [ + "log1p" + ], + "LogSoftmax": [ + "logSoftmax" + ], + "LogicalAnd": [ + "logicalAnd" + ], + "LogicalNot": [ + "logicalNot" + ], + "LogicalOr": [ + "logicalOr" + ], + "LoopCond": [], + "MatMul": [ + "matMul" + ], + "Max": [ + "max" + ], + "MaxPool": [ + "maxPool" + ], + "MaxPool3D": [ + "maxPool3d" + ], + "MaxPoolWithArgmax": [ + "maxPoolWithArgmax" + ], + "Maximum": [ + "maximum" + ], + "Mean": [ + "mean" + ], + "Merge": [], + "Min": [ + "min" + ], + "Minimum": [ + "minimum" + ], + "Mod": [ + "mod" + ], + "Mul": [ + "mul" + ], + "Multinomial": [ + "multinomial" + ], + "Neg": [ + "neg" + ], + "NextIteration": [], + "NoOp": [ + "scalar" + ], + "NonMaxSuppressionV2": [ + "image.nonMaxSuppressionAsync" + ], + "NonMaxSuppressionV3": [ + "image.nonMaxSuppressionAsync" + ], + "NonMaxSuppressionV4": [ + "image.nonMaxSuppressionPaddedAsync" + ], + "NonMaxSuppressionV5": [ + "image.nonMaxSuppressionWithScoreAsync" + ], + "NotEqual": [ + "notEqual" + ], + "OneHot": [ + "oneHot" + ], + "Ones": [ + "ones" + ], + "OnesLike": [ + "onesLike" + ], + "Pack": [ + "tidy", + "util.arraysEqual", + "stack" + ], + "Pad": [ + "pad" + ], + "PadV2": [ + "pad" + ], + "Placeholder": [], + "PlaceholderWithDefault": [], + "Pow": [ + "pow" + ], + "Prelu": [ + "prelu" + ], + "Print": [], + "Prod": [ + "prod" + ], + "RFFT": [ + "rfft" + ], + "RandomUniform": [ + "randomUniform" + ], + "Range": [ + "range" + ], + "Rank": [ + "scalar" + ], + "Real": [ + "real" + ], + "RealDiv": [ + "div" + ], + "Reciprocal": [ + "reciprocal" + ], + "Relu": [ + "relu" + ], + "Relu6": [ + "clipByValue" + ], + "Reshape": [ + "reshape" + ], + "ResizeBilinear": [ + "image.resizeBilinear" + ], + "ResizeNearestNeighbor": [ + "image.resizeNearestNeighbor" + ], + "Reverse": [ + "reverse" + ], + "ReverseV2": [ + "reverse" + ], + "Round": [ + "round" + ], + "Rsqrt": [ + "rsqrt" + ], + "ScatterNd": [ + "scatterND" + ], + "Select": [ + "where" + ], + "SelectV2": [ + "where" + ], + "Selu": [ + "selu" + ], + "Shape": [ + "tensor1d" + ], + "ShapeN": [ + "tensor1d" + ], + "Sigmoid": [ + "sigmoid" + ], + "Sign": [ + "sign" + ], + "Sin": [ + "sin" + ], + "Sinh": [ + "sinh" + ], + "Size": [ + "scalar" + ], + "Slice": [ + "slice" + ], + "Snapshot": [], + "Softmax": [ + "softmax" + ], + "Softplus": [ + "softplus" + ], + "SpaceToBatchND": [ + "spaceToBatchND" + ], + "SparseToDense": [ + "sparseToDense" + ], + "Split": [ + "split" + ], + "SplitV": [ + "split" + ], + "Sqrt": [ + "sqrt" + ], + "Square": [ + "square" + ], + "SquaredDifference": [ + "squaredDifference" + ], + "Squeeze": [ + "squeeze" + ], + "StatelessIf": [], + "StatelessWhile": [], + "StopGradient": [], + "StridedSlice": [ + "stridedSlice" + ], + "Sub": [ + "sub" + ], + "Sum": [ + "sum" + ], + "Switch": [], + "Tan": [ + "tan" + ], + "Tanh": [ + "tanh" + ], + "TensorArrayCloseV3": [], + "TensorArrayConcatV3": [], + "TensorArrayGatherV3": [], + "TensorArrayReadV3": [], + "TensorArrayScatterV3": [], + "TensorArraySizeV3": [], + "TensorArraySplitV3": [], + "TensorArrayV3": [], + "TensorArrayWriteV3": [], + "TensorListConcat": [], + "TensorListFromTensor": [], + "TensorListGather": [], + "TensorListGetItem": [], + "TensorListPopBack": [], + "TensorListPushBack": [], + "TensorListReserve": [], + "TensorListScatter": [], + "TensorListScatterV2": [], + "TensorListSetItem": [], + "TensorListSplit": [], + "TensorListStack": [], + "Tile": [ + "tile" + ], + "TopKV2": [ + "topk" + ], + "Transpose": [ + "transpose" + ], + "TruncatedNormal": [ + "truncatedNormal" + ], + "Unpack": [ + "unstack" + ], + "Where": [ + "whereAsync" + ], + "While": [], + "Zeros": [ + "zeros" + ], + "ZerosLike": [ + "zerosLike" + ], + "_FusedConv2D": [], + "_FusedMatMul": [ + "fused.matMul" + ] } \ No newline at end of file diff --git a/tfjs-converter/scripts/kernels_to_ops.ts b/tfjs-converter/scripts/kernels_to_ops.ts index 427fc7fb074..378aebf9fe5 100644 --- a/tfjs-converter/scripts/kernels_to_ops.ts +++ b/tfjs-converter/scripts/kernels_to_ops.ts @@ -71,7 +71,7 @@ function getKernelMappingForFile(source: SourceFile) { } const caseClauses = switchStatement.getClauses(); - const kernelsToOp: {[key: string]: string;} = {}; + const kernelsToOp: {[key: string]: string[];} = {}; let currentClauseGroup: string[] = []; caseClauses.forEach((caseClause: CaseOrDefaultClause) => { if (caseClause instanceof CaseClause) { @@ -85,18 +85,24 @@ function getKernelMappingForFile(source: SourceFile) { if (kind === 'Block' || kind === 'ReturnStatement') { const callExprs = clausePart.getDescendantsOfKind(SyntaxKind.CallExpression); - const tfcCall = callExprs.find(expr => expr.getText().match(/tfc/)); - let tfSymbol = null; - if (tfcCall != null) { + const tfcCallExprs = + callExprs.filter(expr => expr.getText().match(/tfc/)); + const tfSymbols: Set = new Set(); + for (const tfcCall of tfcCallExprs) { const tfcCallStr = tfcCall.getText(); - console.log('tfcCallStr', tfcCallStr); - const symbolMatcher = /(tfc\.([\w\.]*))\(/; - const matches = tfcCallStr.match(symbolMatcher); - tfSymbol = matches != null ? matches[2] : null; + const functionCallMatcher = /(tfc\.([\w\.]*)\()/g; + const matches = tfcCallStr.match(functionCallMatcher); + if (matches != null && matches.length > 0) { + for (const match of matches) { + // extract the method name (and any namespaces used to call it) + const symbolMatcher = /(tfc\.([\w\.]*)\()/; + const symbol = match.match(symbolMatcher)[2]; + tfSymbols.add(symbol); + } + } } - for (const kern of currentClauseGroup) { - kernelsToOp[kern] = tfSymbol; + kernelsToOp[kern] = Array.from(tfSymbols); } currentClauseGroup = []; } @@ -110,7 +116,7 @@ function getKernelMappingForFile(source: SourceFile) { function getKernelMapping() { const sourceFiles = project.getSourceFiles(); - const kernelsToOp: {[key: string]: string;} = {}; + const kernelsToOp: {[key: string]: string[];} = {}; for (const sourceFile of sourceFiles) { const mapping = getKernelMappingForFile(sourceFile); @@ -125,8 +131,8 @@ async function run(outputFilePath: string) { const kernelMapping = getKernelMapping(); - const pairs: Array<[string, string]> = Object.entries(kernelMapping).sort(); - const sortedKernelMapping: {[key: string]: string;} = {}; + const pairs: Array<[string, string[]]> = Object.entries(kernelMapping).sort(); + const sortedKernelMapping: {[key: string]: string[];} = {}; pairs.forEach(([k, v]) => { sortedKernelMapping[k] = v; }); From ec2aa31262004a9e2b89bf47c15e0aba0cc1bd40 Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Tue, 4 Aug 2020 00:21:21 -0400 Subject: [PATCH 13/18] refactor executors to not branch on op name once inside switch statement --- tfjs-converter/metadata/kernel2op.json | 8 +- .../executors/convolution_executor.ts | 474 ++++++++++-------- .../operations/executors/dynamic_executor.ts | 85 ++-- .../executors/slice_join_executor.ts | 10 +- 4 files changed, 319 insertions(+), 258 deletions(-) diff --git a/tfjs-converter/metadata/kernel2op.json b/tfjs-converter/metadata/kernel2op.json index 8a5de115876..2f4016623f5 100644 --- a/tfjs-converter/metadata/kernel2op.json +++ b/tfjs-converter/metadata/kernel2op.json @@ -177,7 +177,9 @@ "FusedBatchNormV3": [ "batchNorm" ], - "FusedDepthwiseConv2dNative": [], + "FusedDepthwiseConv2dNative": [ + "fused.depthwiseConv2d" + ], "Gather": [ "gather" ], @@ -513,7 +515,9 @@ "ZerosLike": [ "zerosLike" ], - "_FusedConv2D": [], + "_FusedConv2D": [ + "fused.conv2d" + ], "_FusedMatMul": [ "fused.matMul" ] diff --git a/tfjs-converter/src/operations/executors/convolution_executor.ts b/tfjs-converter/src/operations/executors/convolution_executor.ts index c59620a7c5c..f68a9091458 100644 --- a/tfjs-converter/src/operations/executors/convolution_executor.ts +++ b/tfjs-converter/src/operations/executors/convolution_executor.ts @@ -23,240 +23,278 @@ import {InternalOpExecutor, Node} from '../types'; import {getPadding, getParamValue} from './utils'; -export const executeOp: InternalOpExecutor = (node: Node, - tensorMap: NamedTensorsMap, - context: ExecutionContext): - tfc.Tensor[] => { - switch (node.op) { - case 'Conv1D': { - const stride = - getParamValue('stride', node, tensorMap, context) as number; - const pad = getParamValue('pad', node, tensorMap, context); - const dataFormat = - (getParamValue('dataFormat', node, tensorMap, context) as string) - .toUpperCase(); - const dilation = - getParamValue('dilation', node, tensorMap, context) as number; - return [tfc.conv1d( - getParamValue('x', node, tensorMap, context) as tfc.Tensor3D, - getParamValue('filter', node, tensorMap, context) as tfc.Tensor3D, - stride, pad as 'valid' | 'same', dataFormat as 'NWC' | 'NCW', - dilation)]; +function fusedConvAndDepthWiseParams( + node: Node, tensorMap: NamedTensorsMap, context: ExecutionContext) { + const [extraOp, activationFunc] = + (getParamValue('fusedOps', node, tensorMap, context) as string[]); + + const isBiasAdd = extraOp === 'biasadd'; + const isPrelu = activationFunc === 'prelu'; + const isBatchNorm = extraOp === 'fusedbatchnorm'; + + const numArgs = + (getParamValue('numArgs', node, tensorMap, context) as number); + if (isBiasAdd) { + if (isPrelu && numArgs !== 2) { + throw new Error( + 'FusedConv2d and DepthwiseConv2d with BiasAdd and Prelu ' + + 'must have two extra arguments: bias and alpha.'); } - case 'Conv2D': { - const stride = - getParamValue('strides', node, tensorMap, context) as number[]; - const pad = getPadding(node, tensorMap, context); - const dataFormat = - (getParamValue('dataFormat', node, tensorMap, context) as string) - .toUpperCase(); - const dilations = - getParamValue('dilations', node, tensorMap, context) as number[]; - return [tfc.conv2d( - getParamValue('x', node, tensorMap, context) as tfc.Tensor3D | - tfc.Tensor4D, - getParamValue('filter', node, tensorMap, context) as tfc.Tensor4D, - [stride[1], stride[2]], pad as 'valid' | 'same', - dataFormat as 'NHWC' | 'NCHW', [dilations[1], dilations[2]])]; + if (!isPrelu && numArgs !== 1) { + throw new Error( + 'FusedConv2d and DepthwiseConv2d with BiasAdd must have ' + + 'one extra argument: bias.'); } - case '_FusedConv2D': - case 'FusedDepthwiseConv2dNative': { - const [extraOp, activationFunc] = - (getParamValue('fusedOps', node, tensorMap, context) as string[]); + } + if (isBatchNorm) { + throw new Error( + 'FusedConv2d and DepthwiseConv2d with FusedBatchNorm is not supported.'); + } + const stride = getParamValue('strides', node, tensorMap, context) as number[]; + const pad = getPadding(node, tensorMap, context); + const dataFormat = + (getParamValue('dataFormat', node, tensorMap, context) as string) + .toUpperCase(); + const dilations = + getParamValue('dilations', node, tensorMap, context) as number[]; + const [biasArg, preluArg] = + getParamValue('args', node, tensorMap, context) as tfc.Tensor[]; - const isBiasAdd = extraOp === 'biasadd'; - const isPrelu = activationFunc === 'prelu'; - const isBatchNorm = extraOp === 'fusedbatchnorm'; + return { + stride, pad, dataFormat, dilations, biasArg, preluArg, activationFunc + } +} - const numArgs = - (getParamValue('numArgs', node, tensorMap, context) as number); - if (isBiasAdd) { - if (isPrelu && numArgs !== 2) { - throw new Error( - 'FusedConv2d and DepthwiseConv2d with BiasAdd and Prelu ' + - 'must have two extra arguments: bias and alpha.'); +export const executeOp: InternalOpExecutor = + (node: Node, tensorMap: NamedTensorsMap, + context: ExecutionContext): tfc.Tensor[] => { + switch (node.op) { + case 'Conv1D': { + const stride = + getParamValue('stride', node, tensorMap, context) as number; + const pad = getParamValue('pad', node, tensorMap, context); + const dataFormat = + (getParamValue('dataFormat', node, tensorMap, context) as string) + .toUpperCase(); + const dilation = + getParamValue('dilation', node, tensorMap, context) as number; + return [tfc.conv1d( + getParamValue('x', node, tensorMap, context) as tfc.Tensor3D, + getParamValue('filter', node, tensorMap, context) as tfc.Tensor3D, + stride, pad as 'valid' | 'same', dataFormat as 'NWC' | 'NCW', + dilation)]; } - if (!isPrelu && numArgs !== 1) { - throw new Error( - 'FusedConv2d and DepthwiseConv2d with BiasAdd must have ' + - 'one extra argument: bias.'); + case 'Conv2D': { + const stride = + getParamValue('strides', node, tensorMap, context) as number[]; + const pad = getPadding(node, tensorMap, context); + const dataFormat = + (getParamValue('dataFormat', node, tensorMap, context) as string) + .toUpperCase(); + const dilations = + getParamValue('dilations', node, tensorMap, context) as number[]; + return [tfc.conv2d( + getParamValue('x', node, tensorMap, context) as tfc.Tensor3D | + tfc.Tensor4D, + getParamValue('filter', node, tensorMap, context) as tfc.Tensor4D, + [stride[1], stride[2]], pad as 'valid' | 'same', + dataFormat as 'NHWC' | 'NCHW', [dilations[1], dilations[2]])]; } - } - if (isBatchNorm) { - throw new Error( - 'FusedConv2d and DepthwiseConv2d with FusedBatchNorm is not supported.'); - } - const stride = - getParamValue('strides', node, tensorMap, context) as number[]; - const pad = getPadding(node, tensorMap, context); - const dataFormat = - (getParamValue('dataFormat', node, tensorMap, context) as string) - .toUpperCase(); - const dilations = - getParamValue('dilations', node, tensorMap, context) as number[]; - const [biasArg, preluArg] = - getParamValue('args', node, tensorMap, context) as tfc.Tensor[]; - const kernelMethod = node.op === '_FusedConv2D' ? - tfc.fused.conv2d : - tfc.fused.depthwiseConv2d; - return [kernelMethod({ - x: getParamValue('x', node, tensorMap, context) as tfc.Tensor3D | - tfc.Tensor4D, - filter: getParamValue('filter', node, tensorMap, context) as - tfc.Tensor4D, - strides: [stride[1], stride[2]], - pad: pad as 'valid' | 'same', - dataFormat: dataFormat as 'NHWC' | 'NCHW', - dilations: [dilations[1], dilations[2]], - bias: biasArg, - activation: activationFunc as tfc.fused.Activation, - preluActivationWeights: preluArg - })]; - } - case 'Conv2DBackpropInput': - case 'Conv2dTranspose': { - const shape = getParamValue( - 'outputShape', node, tensorMap, - context) as [number, number, number] | - [number, number, number, number]; - const stride = - getParamValue('strides', node, tensorMap, context) as number[]; - const pad = getPadding(node, tensorMap, context); - return [tfc.conv2dTranspose( - getParamValue('x', node, tensorMap, context) as tfc.Tensor3D | - tfc.Tensor4D, - getParamValue('filter', node, tensorMap, context) as tfc.Tensor4D, - shape, [stride[1], stride[2]], pad as 'valid' | 'same')]; - } - case 'DepthwiseConv2dNative': - case 'DepthwiseConv2d': { - const stride = - getParamValue('strides', node, tensorMap, context) as number[]; - const pad = getPadding(node, tensorMap, context); - const dilations = - getParamValue('dilations', node, tensorMap, context) as number[]; - const dataFormat = - (getParamValue('dataFormat', node, tensorMap, context) as string) - .toUpperCase(); + case '_FusedConv2D': { + const { + stride, + pad, + dataFormat, + dilations, + biasArg, + preluArg, + activationFunc + } = fusedConvAndDepthWiseParams(node, tensorMap, context); - return [tfc.depthwiseConv2d( - getParamValue('input', node, tensorMap, context) as tfc.Tensor3D | - tfc.Tensor4D, - getParamValue('filter', node, tensorMap, context) as tfc.Tensor4D, - [stride[1], stride[2]], pad as 'valid' | 'same', - dataFormat as 'NHWC' | 'NCHW', [dilations[1], dilations[2]])]; - } - case 'Conv3D': { - const stride = - getParamValue('strides', node, tensorMap, context) as number[]; - const pad = getParamValue('pad', node, tensorMap, context); - const dataFormat = - (getParamValue('dataFormat', node, tensorMap, context) as string) - .toUpperCase(); - const dilations = - getParamValue('dilations', node, tensorMap, context) as number[]; - return [tfc.conv3d( - getParamValue('x', node, tensorMap, context) as tfc.Tensor4D | - tfc.Tensor, - getParamValue('filter', node, tensorMap, context) as - tfc.Tensor, - [stride[1], stride[2], stride[3]], pad as 'valid' | 'same', - dataFormat as 'NDHWC' | 'NCDHW', - [dilations[1], dilations[2], dilations[3]])]; - } - case 'AvgPool': { - const stride = - getParamValue('strides', node, tensorMap, context) as number[]; - const pad = getParamValue('pad', node, tensorMap, context); - const kernelSize = - getParamValue('kernelSize', node, tensorMap, context) as number[]; + return [tfc.fused.conv2d({ + x: getParamValue('x', node, tensorMap, context) as tfc.Tensor3D | + tfc.Tensor4D, + filter: getParamValue('filter', node, tensorMap, context) as + tfc.Tensor4D, + strides: [stride[1], stride[2]], + pad: pad as 'valid' | 'same', + dataFormat: dataFormat as 'NHWC' | 'NCHW', + dilations: [dilations[1], dilations[2]], + bias: biasArg, + activation: activationFunc as tfc.fused.Activation, + preluActivationWeights: preluArg + })]; + } - return [tfc.avgPool( - getParamValue('x', node, tensorMap, context) as tfc.Tensor3D | - tfc.Tensor4D, - [kernelSize[1], kernelSize[2]], [stride[1], stride[2]], - pad as 'valid' | 'same')]; - } - case 'MaxPool': { - const stride = - getParamValue('strides', node, tensorMap, context) as number[]; - const pad = getParamValue('pad', node, tensorMap, context); - const kernelSize = - getParamValue('kernelSize', node, tensorMap, context) as number[]; + case 'FusedDepthwiseConv2dNative': { + const { + stride, + pad, + dataFormat, + dilations, + biasArg, + preluArg, + activationFunc + } = fusedConvAndDepthWiseParams(node, tensorMap, context); - return [tfc.maxPool( - getParamValue('x', node, tensorMap, context) as tfc.Tensor3D | - tfc.Tensor4D, - [kernelSize[1], kernelSize[2]], [stride[1], stride[2]], - pad as 'valid' | 'same')]; - } - case 'MaxPoolWithArgmax': { - const stride = - getParamValue('strides', node, tensorMap, context) as number[]; - const pad = getParamValue('pad', node, tensorMap, context); - const kernelSize = - getParamValue('kernelSize', node, tensorMap, context) as number[]; - const includeBatchInIndex = - getParamValue('includeBatchInIndex', node, tensorMap, context) as - boolean; - const {result, indexes} = tfc.maxPoolWithArgmax( - getParamValue('x', node, tensorMap, context) as tfc.Tensor4D, - [kernelSize[1], kernelSize[2]], [stride[1], stride[2]], - pad as 'valid' | 'same', includeBatchInIndex); - return [result, indexes]; - } - case 'AvgPool3D': { - const stride = - getParamValue('strides', node, tensorMap, context) as number[]; - const pad = getParamValue('pad', node, tensorMap, context); - const kernelSize = - getParamValue('kernelSize', node, tensorMap, context) as number[]; + return [tfc.fused.depthwiseConv2d({ + x: getParamValue('x', node, tensorMap, context) as tfc.Tensor3D | + tfc.Tensor4D, + filter: getParamValue('filter', node, tensorMap, context) as + tfc.Tensor4D, + strides: [stride[1], stride[2]], + pad: pad as 'valid' | 'same', + dataFormat: dataFormat as 'NHWC' | 'NCHW', + dilations: [dilations[1], dilations[2]], + bias: biasArg, + activation: activationFunc as tfc.fused.Activation, + preluActivationWeights: preluArg + })]; + } + case 'Conv2DBackpropInput': + case 'Conv2dTranspose': { + const shape = getParamValue( + 'outputShape', node, tensorMap, + context) as [number, number, number] | + [number, number, number, number]; + const stride = + getParamValue('strides', node, tensorMap, context) as number[]; + const pad = getPadding(node, tensorMap, context); + return [tfc.conv2dTranspose( + getParamValue('x', node, tensorMap, context) as tfc.Tensor3D | + tfc.Tensor4D, + getParamValue('filter', node, tensorMap, context) as tfc.Tensor4D, + shape, [stride[1], stride[2]], pad as 'valid' | 'same')]; + } + case 'DepthwiseConv2dNative': + case 'DepthwiseConv2d': { + const stride = + getParamValue('strides', node, tensorMap, context) as number[]; + const pad = getPadding(node, tensorMap, context); + const dilations = + getParamValue('dilations', node, tensorMap, context) as number[]; + const dataFormat = + (getParamValue('dataFormat', node, tensorMap, context) as string) + .toUpperCase(); - return [tfc.avgPool3d( - getParamValue('x', node, tensorMap, context) as tfc.Tensor5D, - [kernelSize[1], kernelSize[2], kernelSize[3]], - [stride[1], stride[2], stride[3]], pad as 'valid' | 'same')]; - } + return [tfc.depthwiseConv2d( + getParamValue('input', node, tensorMap, context) as tfc.Tensor3D | + tfc.Tensor4D, + getParamValue('filter', node, tensorMap, context) as tfc.Tensor4D, + [stride[1], stride[2]], pad as 'valid' | 'same', + dataFormat as 'NHWC' | 'NCHW', [dilations[1], dilations[2]])]; + } + case 'Conv3D': { + const stride = + getParamValue('strides', node, tensorMap, context) as number[]; + const pad = getParamValue('pad', node, tensorMap, context); + const dataFormat = + (getParamValue('dataFormat', node, tensorMap, context) as string) + .toUpperCase(); + const dilations = + getParamValue('dilations', node, tensorMap, context) as number[]; + return [tfc.conv3d( + getParamValue('x', node, tensorMap, context) as tfc.Tensor4D | + tfc.Tensor, + getParamValue('filter', node, tensorMap, context) as + tfc.Tensor, + [stride[1], stride[2], stride[3]], pad as 'valid' | 'same', + dataFormat as 'NDHWC' | 'NCDHW', + [dilations[1], dilations[2], dilations[3]])]; + } + case 'AvgPool': { + const stride = + getParamValue('strides', node, tensorMap, context) as number[]; + const pad = getParamValue('pad', node, tensorMap, context); + const kernelSize = + getParamValue('kernelSize', node, tensorMap, context) as number[]; + + return [tfc.avgPool( + getParamValue('x', node, tensorMap, context) as tfc.Tensor3D | + tfc.Tensor4D, + [kernelSize[1], kernelSize[2]], [stride[1], stride[2]], + pad as 'valid' | 'same')]; + } + case 'MaxPool': { + const stride = + getParamValue('strides', node, tensorMap, context) as number[]; + const pad = getParamValue('pad', node, tensorMap, context); + const kernelSize = + getParamValue('kernelSize', node, tensorMap, context) as number[]; - case 'MaxPool3D': { - const stride = - getParamValue('strides', node, tensorMap, context) as number[]; - const pad = getParamValue('pad', node, tensorMap, context); - const kernelSize = - getParamValue('kernelSize', node, tensorMap, context) as number[]; + return [tfc.maxPool( + getParamValue('x', node, tensorMap, context) as tfc.Tensor3D | + tfc.Tensor4D, + [kernelSize[1], kernelSize[2]], [stride[1], stride[2]], + pad as 'valid' | 'same')]; + } + case 'MaxPoolWithArgmax': { + const stride = + getParamValue('strides', node, tensorMap, context) as number[]; + const pad = getParamValue('pad', node, tensorMap, context); + const kernelSize = + getParamValue('kernelSize', node, tensorMap, context) as number[]; + const includeBatchInIndex = + getParamValue('includeBatchInIndex', node, tensorMap, context) as + boolean; + const {result, indexes} = tfc.maxPoolWithArgmax( + getParamValue('x', node, tensorMap, context) as tfc.Tensor4D, + [kernelSize[1], kernelSize[2]], [stride[1], stride[2]], + pad as 'valid' | 'same', includeBatchInIndex); + return [result, indexes]; + } + case 'AvgPool3D': { + const stride = + getParamValue('strides', node, tensorMap, context) as number[]; + const pad = getParamValue('pad', node, tensorMap, context); + const kernelSize = + getParamValue('kernelSize', node, tensorMap, context) as number[]; - return [tfc.maxPool3d( - getParamValue('x', node, tensorMap, context) as tfc.Tensor5D, - [kernelSize[1], kernelSize[2], kernelSize[3]], - [stride[1], stride[2], stride[3]], pad as 'valid' | 'same')]; - } + return [tfc.avgPool3d( + getParamValue('x', node, tensorMap, context) as tfc.Tensor5D, + [kernelSize[1], kernelSize[2], kernelSize[3]], + [stride[1], stride[2], stride[3]], pad as 'valid' | 'same')]; + } - case 'Dilation2D': { - const strides = - getParamValue('strides', node, tensorMap, context) as number[]; - const pad = getParamValue('pad', node, tensorMap, context); - const dilations = - getParamValue('dilations', node, tensorMap, context) as number[]; + case 'MaxPool3D': { + const stride = + getParamValue('strides', node, tensorMap, context) as number[]; + const pad = getParamValue('pad', node, tensorMap, context); + const kernelSize = + getParamValue('kernelSize', node, tensorMap, context) as number[]; - // strides: [1, stride_height, stride_width, 1]. - const strideHeight = strides[1]; - const strideWidth = strides[2]; + return [tfc.maxPool3d( + getParamValue('x', node, tensorMap, context) as tfc.Tensor5D, + [kernelSize[1], kernelSize[2], kernelSize[3]], + [stride[1], stride[2], stride[3]], pad as 'valid' | 'same')]; + } - // dilations: [1, dilation_height, dilation_width, 1]. - const dilationHeight = dilations[1]; - const dilationWidth = dilations[2]; + case 'Dilation2D': { + const strides = + getParamValue('strides', node, tensorMap, context) as number[]; + const pad = getParamValue('pad', node, tensorMap, context); + const dilations = + getParamValue('dilations', node, tensorMap, context) as number[]; - return [tfc.dilation2d( - getParamValue('x', node, tensorMap, context) as tfc.Tensor3D | - tfc.Tensor4D, - getParamValue('filter', node, tensorMap, context) as tfc.Tensor3D, - [strideHeight, strideWidth], pad as 'valid' | 'same', - [dilationHeight, dilationWidth], 'NHWC' /* dataFormat */)]; - } + // strides: [1, stride_height, stride_width, 1]. + const strideHeight = strides[1]; + const strideWidth = strides[2]; - default: - throw TypeError(`Node type ${node.op} is not implemented`); - } -}; + // dilations: [1, dilation_height, dilation_width, 1]. + const dilationHeight = dilations[1]; + const dilationWidth = dilations[2]; + + return [tfc.dilation2d( + getParamValue('x', node, tensorMap, context) as tfc.Tensor3D | + tfc.Tensor4D, + getParamValue('filter', node, tensorMap, context) as tfc.Tensor3D, + [strideHeight, strideWidth], pad as 'valid' | 'same', + [dilationHeight, dilationWidth], 'NHWC' /* dataFormat */)]; + } + + default: + throw TypeError(`Node type ${node.op} is not implemented`); + } + }; export const CATEGORY = 'convolution'; diff --git a/tfjs-converter/src/operations/executors/dynamic_executor.ts b/tfjs-converter/src/operations/executors/dynamic_executor.ts index 3c0d4897d25..326140573d0 100644 --- a/tfjs-converter/src/operations/executors/dynamic_executor.ts +++ b/tfjs-converter/src/operations/executors/dynamic_executor.ts @@ -23,47 +23,68 @@ import {InternalOpAsyncExecutor, Node} from '../types'; import {getParamValue} from './utils'; +function nmsParams( + node: Node, tensorMap: NamedTensorsMap, context: ExecutionContext) { + const boxes = getParamValue('boxes', node, tensorMap, context) as tfc.Tensor; + const scores = + getParamValue('scores', node, tensorMap, context) as tfc.Tensor; + const maxOutputSize = + getParamValue('maxOutputSize', node, tensorMap, context) as number; + const iouThreshold = + getParamValue('iouThreshold', node, tensorMap, context) as number; + const scoreThreshold = + getParamValue('scoreThreshold', node, tensorMap, context) as number; + const softNmsSigma = + getParamValue('softNmsSigma', node, tensorMap, context) as number; + + return { + boxes, + scores, + maxOutputSize, + iouThreshold, + scoreThreshold, + softNmsSigma + }; +} + export const executeOp: InternalOpAsyncExecutor = async( node: Node, tensorMap: NamedTensorsMap, context: ExecutionContext): Promise => { switch (node.op) { - case 'NonMaxSuppressionV5': - case 'NonMaxSuppressionV4': - case 'NonMaxSuppressionV3': - case 'NonMaxSuppressionV2': { - const boxes = - getParamValue('boxes', node, tensorMap, context) as tfc.Tensor; - const scores = - getParamValue('scores', node, tensorMap, context) as tfc.Tensor; - const maxOutputSize = - getParamValue('maxOutputSize', node, tensorMap, context) as number; - const iouThreshold = - getParamValue('iouThreshold', node, tensorMap, context) as number; - const scoreThreshold = - getParamValue('scoreThreshold', node, tensorMap, context) as number; - - if (node.op === 'NonMaxSuppressionV5') { - const softNmsSigma = - getParamValue('softNmsSigma', node, tensorMap, context) as number; + case 'NonMaxSuppressionV5': { + const { + boxes, + scores, + maxOutputSize, + iouThreshold, + scoreThreshold, + softNmsSigma + } = nmsParams(node, tensorMap, context); - const result = await tfc.image.nonMaxSuppressionWithScoreAsync( - boxes as tfc.Tensor2D, scores as tfc.Tensor1D, maxOutputSize, - iouThreshold, scoreThreshold, softNmsSigma); + const result = await tfc.image.nonMaxSuppressionWithScoreAsync( + boxes as tfc.Tensor2D, scores as tfc.Tensor1D, maxOutputSize, + iouThreshold, scoreThreshold, softNmsSigma); - return [result.selectedIndices, result.selectedScores]; - } + return [result.selectedIndices, result.selectedScores]; + } + case 'NonMaxSuppressionV4': { + const {boxes, scores, maxOutputSize, iouThreshold, scoreThreshold} = + nmsParams(node, tensorMap, context); - if (node.op === 'NonMaxSuppressionV4') { - const padToMaxOutputSize = - getParamValue('padToMaxOutputSize', node, tensorMap, context) as - boolean; + const padToMaxOutputSize = + getParamValue('padToMaxOutputSize', node, tensorMap, context) as + boolean; - const result = await tfc.image.nonMaxSuppressionPaddedAsync( - boxes as tfc.Tensor2D, scores as tfc.Tensor1D, maxOutputSize, - iouThreshold, scoreThreshold, padToMaxOutputSize); + const result = await tfc.image.nonMaxSuppressionPaddedAsync( + boxes as tfc.Tensor2D, scores as tfc.Tensor1D, maxOutputSize, + iouThreshold, scoreThreshold, padToMaxOutputSize); - return [result.selectedIndices, result.validOutputs]; - } + return [result.selectedIndices, result.validOutputs]; + } + case 'NonMaxSuppressionV3': + case 'NonMaxSuppressionV2': { + const {boxes, scores, maxOutputSize, iouThreshold, scoreThreshold} = + nmsParams(node, tensorMap, context); return [await tfc.image.nonMaxSuppressionAsync( boxes as tfc.Tensor2D, scores as tfc.Tensor1D, maxOutputSize, diff --git a/tfjs-converter/src/operations/executors/slice_join_executor.ts b/tfjs-converter/src/operations/executors/slice_join_executor.ts index 99f8bf22f98..b3b47e9c7a2 100644 --- a/tfjs-converter/src/operations/executors/slice_join_executor.ts +++ b/tfjs-converter/src/operations/executors/slice_join_executor.ts @@ -102,12 +102,10 @@ export const executeOp: InternalOpExecutor = (node: Node, }); } case 'Unpack': { - return tfc.tidy(() => { - const axis = getParamValue('axis', node, tensorMap, context) as number; - const tensor = - getParamValue('tensor', node, tensorMap, context) as tfc.Tensor; - return tfc.unstack(tensor, axis); - }); + const axis = getParamValue('axis', node, tensorMap, context) as number; + const tensor = + getParamValue('tensor', node, tensorMap, context) as tfc.Tensor; + return tfc.unstack(tensor, axis); } case 'Tile': { const reps = getParamValue('reps', node, tensorMap, context) as number[]; From 5ddd280684022cf2138575e21f75bdb792558046 Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Tue, 4 Aug 2020 00:21:46 -0400 Subject: [PATCH 14/18] add tests for kernel to op mapping --- tfjs-converter/package.json | 9 +-- tfjs-converter/src/metadata_test.ts | 90 +++++++++++++++++++++++++++++ 2 files changed, 95 insertions(+), 4 deletions(-) create mode 100644 tfjs-converter/src/metadata_test.ts diff --git a/tfjs-converter/package.json b/tfjs-converter/package.json index 373a067d7ed..b4e10b0da2b 100644 --- a/tfjs-converter/package.json +++ b/tfjs-converter/package.json @@ -56,8 +56,8 @@ "yalc": "~1.0.0-pre.21" }, "scripts": { - "build": "yarn gen-json --test && tsc && yarn bundle", - "build-ci": "yarn gen-json --test && tsc && yarn bundle-ci", + "build": "yarn gen-json --test && yarn gen-kernel2ops && tsc && yarn bundle", + "build-ci": "yarn gen-json --test && yarn gen-kernel2ops && tsc && yarn bundle-ci", "bundle": "rollup -c", "bundle-ci": "rollup -c --ci", "build-core": "cd ../tfjs-core && yarn && yarn build", @@ -70,7 +70,7 @@ "link-local": "yalc link", "publish-local": "yarn build-npm && yalc push", "publish-npm": "npm publish", - "test": "yarn && yarn build-deps && yarn gen-json --test && ts-node -P tsconfig.test.json run_tests.ts", + "test": "yarn && yarn build-deps && yarn gen-json --test && yarn gen-kernel2ops && ts-node -P tsconfig.test.json run_tests.ts", "test-ci": "ts-node --skip-ignore -P tsconfig.test.json run_tests.ts", "test-snippets": "ts-node --skip-ignore -s ./scripts/test_snippets.ts", "lint": "tslint -p . -t verbose", @@ -80,6 +80,7 @@ "model-summary": "ts-node -s ./tools/model_summary.ts", "pb2json": "ts-node -s ./tools/pb2json_converter.ts", "build-pip-package": "yarn gen-json --test && cd python && ./build-pip-package.sh --test /tmp/tfjs-pips", - "run-python-tests": "yarn gen-json --test && cd python && ./run-python-tests.sh" + "run-python-tests": "yarn gen-json --test && cd python && ./run-python-tests.sh", + "gen-kernel2ops": "ts-node -s scripts/kernels_to_ops.ts --out metadata/kernel2op.json" } } diff --git a/tfjs-converter/src/metadata_test.ts b/tfjs-converter/src/metadata_test.ts new file mode 100644 index 00000000000..2df3acaa842 --- /dev/null +++ b/tfjs-converter/src/metadata_test.ts @@ -0,0 +1,90 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +// This test checks the metadata.json to make sure that only the kernels that +// we know we don't map to tfjs ops have empty entries in metadata +// kernel2op.json +describe('kernel2op metadata file', () => { + it('has kernel2op.json', () => { + expect(() => { + // tslint:disable-next-line:no-require-imports + require('../metadata/kernel2op.json'); + }).not.toThrow(); + }); + + it('only known unmapped kernel are unmmapped', () => { + const knownUnmappedKernels = [ + 'Const', + 'Enter', + 'Exit', + 'FakeQuantWithMinMaxVars', + 'Identity', + 'IdentityN', + 'If', + 'LoopCond', + 'Merge', + 'NextIteration', + 'Placeholder', + 'PlaceholderWithDefault', + 'Print', + 'Snapshot', + 'StatelessIf', + 'StatelessWhile', + 'StopGradient', + 'Switch', + 'TensorArrayCloseV3', + 'TensorArrayConcatV3', + 'TensorArrayGatherV3', + 'TensorArrayReadV3', + 'TensorArrayScatterV3', + 'TensorArraySizeV3', + 'TensorArraySplitV3', + 'TensorArrayV3', + 'TensorArrayWriteV3', + 'TensorListConcat', + 'TensorListFromTensor', + 'TensorListGather', + 'TensorListGetItem', + 'TensorListPopBack', + 'TensorListPushBack', + 'TensorListReserve', + 'TensorListScatter', + 'TensorListScatterV2', + 'TensorListSetItem', + 'TensorListSplit', + 'TensorListStack', + 'While', + ]; + // tslint:disable-next-line:no-require-imports + const kernel2op = require('../metadata/kernel2op.json'); + const kernels: string[] = Object.keys(kernel2op); + + for (const kernelName of kernels) { + const tfOps = kernel2op[kernelName]; + if (knownUnmappedKernels.includes(kernelName)) { + expect(tfOps.length) + .toEqual(0, `Kernel "${kernelName}" is expected to be unmapped but + instead maps to ${tfOps}`); + } else { + expect(tfOps.length) + .toBeGreaterThan( + 0, `Kernel ${kernelName} is expected to be mapped to a list + of tf ops`); + } + } + }); +}); From 6f13136e5fcd028b29f1fd94791e7b3beec3d647 Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Tue, 4 Aug 2020 11:40:15 -0400 Subject: [PATCH 15/18] docs and lint --- tfjs-converter/scripts/kernels_to_ops.ts | 26 ++++++++++++++----- .../executors/convolution_executor.ts | 10 +++++-- 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/tfjs-converter/scripts/kernels_to_ops.ts b/tfjs-converter/scripts/kernels_to_ops.ts index 378aebf9fe5..d130b4a1686 100644 --- a/tfjs-converter/scripts/kernels_to_ops.ts +++ b/tfjs-converter/scripts/kernels_to_ops.ts @@ -36,8 +36,8 @@ * getParamValue('b', node, tensorMap, context) as tfc.Tensor)]; * } * - * Case matchers represent kernel names and tfc.* represent the tfjs op that is - * called. This example shows that we need to support fallthrough case + * Case matchers represent kernel names and tfc.*(...) represent the tfjs op(s) + * that are called. This example shows that we need to support fallthrough case * statements as well. * */ @@ -51,10 +51,10 @@ const parser = new argparse.ArgumentParser(); parser.addArgument( '--out', {help: 'Path to output JSON to create', required: true}); -// initialize const project = new Project({}); function getSwitchStatement(source: SourceFile): SwitchStatement { + // Each executor only has one switch statment. let switchStatement: SwitchStatement; source.forEachDescendant((node) => { if (node.getKindName() === 'SwitchStatement') { @@ -64,6 +64,10 @@ function getSwitchStatement(source: SourceFile): SwitchStatement { return switchStatement; } +type KernelMapping = { + [key: string]: string[] +}; + function getKernelMappingForFile(source: SourceFile) { const switchStatement = getSwitchStatement(source); if (switchStatement === null) { @@ -71,8 +75,11 @@ function getKernelMappingForFile(source: SourceFile) { } const caseClauses = switchStatement.getClauses(); - const kernelsToOp: {[key: string]: string[];} = {}; + const kernelsToOp: KernelMapping = {}; let currentClauseGroup: string[] = []; + + // Loop through clauses until you reach one that has a block or return. + // This allows us to coalesce fallthrough case blocks in a switch statement. caseClauses.forEach((caseClause: CaseOrDefaultClause) => { if (caseClause instanceof CaseClause) { let kernelName; @@ -83,6 +90,11 @@ function getKernelMappingForFile(source: SourceFile) { currentClauseGroup.push(kernelName); } if (kind === 'Block' || kind === 'ReturnStatement') { + // We have reached a code block, all the previously captured + // kernels use this block as their execution path. + + // Parse the code block and determing all the tfc.*() function calls + // used. const callExprs = clausePart.getDescendantsOfKind(SyntaxKind.CallExpression); const tfcCallExprs = @@ -104,6 +116,7 @@ function getKernelMappingForFile(source: SourceFile) { for (const kern of currentClauseGroup) { kernelsToOp[kern] = Array.from(tfSymbols); } + // Reset the clause tracker as we are moving to a new set of kernels currentClauseGroup = []; } }); @@ -115,8 +128,7 @@ function getKernelMappingForFile(source: SourceFile) { function getKernelMapping() { const sourceFiles = project.getSourceFiles(); - - const kernelsToOp: {[key: string]: string[];} = {}; + const kernelsToOp: KernelMapping = {}; for (const sourceFile of sourceFiles) { const mapping = getKernelMappingForFile(sourceFile); @@ -132,7 +144,7 @@ async function run(outputFilePath: string) { const kernelMapping = getKernelMapping(); const pairs: Array<[string, string[]]> = Object.entries(kernelMapping).sort(); - const sortedKernelMapping: {[key: string]: string[];} = {}; + const sortedKernelMapping: KernelMapping = {}; pairs.forEach(([k, v]) => { sortedKernelMapping[k] = v; }); diff --git a/tfjs-converter/src/operations/executors/convolution_executor.ts b/tfjs-converter/src/operations/executors/convolution_executor.ts index f68a9091458..f6521cb45a0 100644 --- a/tfjs-converter/src/operations/executors/convolution_executor.ts +++ b/tfjs-converter/src/operations/executors/convolution_executor.ts @@ -61,8 +61,14 @@ function fusedConvAndDepthWiseParams( getParamValue('args', node, tensorMap, context) as tfc.Tensor[]; return { - stride, pad, dataFormat, dilations, biasArg, preluArg, activationFunc - } + stride, + pad, + dataFormat, + dilations, + biasArg, + preluArg, + activationFunc + }; } export const executeOp: InternalOpExecutor = From 41c5ad8dc501d0f20fc3f0371364c1d688c2e7c6 Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Tue, 4 Aug 2020 13:22:41 -0400 Subject: [PATCH 16/18] save --- tfjs-converter/scripts/kernels_to_ops.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tfjs-converter/scripts/kernels_to_ops.ts b/tfjs-converter/scripts/kernels_to_ops.ts index d130b4a1686..a6be511e09b 100644 --- a/tfjs-converter/scripts/kernels_to_ops.ts +++ b/tfjs-converter/scripts/kernels_to_ops.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2020 Google Inc. All Rights Reserved. + * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at From 501fdbf4feaa39f3da6a68ffe67ff07a44e6726b Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Fri, 7 Aug 2020 10:16:50 -0400 Subject: [PATCH 17/18] add argparse as dev dependency --- tfjs-converter/package.json | 1 + tfjs-converter/yarn.lock | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/tfjs-converter/package.json b/tfjs-converter/package.json index b4e10b0da2b..b3a7a60f2ff 100644 --- a/tfjs-converter/package.json +++ b/tfjs-converter/package.json @@ -28,6 +28,7 @@ "@types/long": "~3.0.32", "@types/node-fetch": "1.6.9", "ajv": "~6.3.0", + "argparse": "^1.0.10", "babel-core": "~6.26.3", "babel-plugin-external-helpers": "~6.22.0", "babel-preset-env": "~1.7.0", diff --git a/tfjs-converter/yarn.lock b/tfjs-converter/yarn.lock index a896b7b5e67..eb4ca44d57c 100644 --- a/tfjs-converter/yarn.lock +++ b/tfjs-converter/yarn.lock @@ -372,7 +372,7 @@ arg@^4.1.0: resolved "https://registry.yarnpkg.com/arg/-/arg-4.1.3.tgz#269fc7ad5b8e42cb63c896d5666017261c144089" integrity sha512-58S9QDqG0Xx27YwPSt9fJxivjYl432YCwfDMfZ+71RAqUrZef7LrKQZ3LHLOwCS4FLNBplP533Zx895SeOCHvA== -argparse@^1.0.7: +argparse@^1.0.10, argparse@^1.0.7: version "1.0.10" resolved "https://registry.yarnpkg.com/argparse/-/argparse-1.0.10.tgz#bcd6791ea5ae09725e17e5ad988134cd40b3d911" integrity sha512-o5Roy6tNG4SL/FOkCAN6RzjiakZS25RLYFrcMttJqbdd8BWrnA+fGz57iN5Pb06pvBGvl5gQ0B48dJlslXvoTg== From d639372808148ac873e2364ae99bc6ab6e7029cd Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Fri, 7 Aug 2020 10:56:03 -0400 Subject: [PATCH 18/18] add argparse types --- tfjs-converter/package.json | 1 + tfjs-converter/yarn.lock | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/tfjs-converter/package.json b/tfjs-converter/package.json index b3a7a60f2ff..a7e70ce9de7 100644 --- a/tfjs-converter/package.json +++ b/tfjs-converter/package.json @@ -23,6 +23,7 @@ "@rollup/plugin-typescript": "^3.0.0", "@tensorflow/tfjs-backend-cpu": "link:../tfjs-backend-cpu", "@tensorflow/tfjs-core": "link:../tfjs-core", + "@types/argparse": "^1.0.38", "@types/deep-equal": "^1.0.1", "@types/jasmine": "~2.8.6", "@types/long": "~3.0.32", diff --git a/tfjs-converter/yarn.lock b/tfjs-converter/yarn.lock index eb4ca44d57c..afc28cded91 100644 --- a/tfjs-converter/yarn.lock +++ b/tfjs-converter/yarn.lock @@ -166,6 +166,11 @@ multimatch "^4.0.0" typescript "~3.9.7" +"@types/argparse@^1.0.38": + version "1.0.38" + resolved "https://registry.yarnpkg.com/@types/argparse/-/argparse-1.0.38.tgz#a81fd8606d481f873a3800c6ebae4f1d768a56a9" + integrity sha512-ebDJ9b0e702Yr7pWgB0jzm+CX4Srzz8RcXtLJDJB+BSccqMa36uyH/zUsSYao5+BD1ytv3k3rPYCq4mAE1hsXA== + "@types/color-name@^1.1.1": version "1.1.1" resolved "https://registry.yarnpkg.com/@types/color-name/-/color-name-1.1.1.tgz#1c1261bbeaa10a8055bbc5d8ab84b7b2afc846a0"