diff --git a/tfjs-backend-cpu/package.json b/tfjs-backend-cpu/package.json index 8fed43cf62d..86efac109d8 100644 --- a/tfjs-backend-cpu/package.json +++ b/tfjs-backend-cpu/package.json @@ -70,5 +70,10 @@ "browser": { "util": false, "crypto": false - } + }, + "sideEffects": [ + "./dist/register_all_kernels.js", + "./dist/base.js", + "./dist/index.js" + ] } diff --git a/tfjs-backend-cpu/src/base.ts b/tfjs-backend-cpu/src/base.ts index 17fa0f9f063..60d137700c1 100644 --- a/tfjs-backend-cpu/src/base.ts +++ b/tfjs-backend-cpu/src/base.ts @@ -17,10 +17,15 @@ /* * base.ts contains all the exports from tfjs-backend-cpu - * that do not trigger side effects. + * without auto-kernel registration */ +import {registerBackend} from '@tensorflow/tfjs-core'; +import {MathBackendCPU} from './backend_cpu'; import * as shared from './shared'; export {MathBackendCPU} from './backend_cpu'; export {version as version_cpu} from './version'; export {shared}; + +// Side effects for default initialization of MathBackendCPU +registerBackend('cpu', () => new MathBackendCPU(), 1 /* priority */); diff --git a/tfjs-backend-cpu/src/index.ts b/tfjs-backend-cpu/src/index.ts index 53ca24399c5..abf89246ec3 100644 --- a/tfjs-backend-cpu/src/index.ts +++ b/tfjs-backend-cpu/src/index.ts @@ -15,12 +15,6 @@ * ============================================================================= */ -import {registerBackend} from '@tensorflow/tfjs-core'; -import {MathBackendCPU} from './base'; - -// Side effects for default initialization of MathBackendCPU -registerBackend('cpu', () => new MathBackendCPU(), 1 /* priority */); -import './register_all_kernels'; - // All exports from this package should be in base. export * from './base'; +import './register_all_kernels'; diff --git a/tfjs-backend-webgl/package.json b/tfjs-backend-webgl/package.json index 64b91654b52..f2de6562434 100644 --- a/tfjs-backend-webgl/package.json +++ b/tfjs-backend-webgl/package.json @@ -80,5 +80,11 @@ "browser": { "util": false, "crypto": false - } + }, + "sideEffects": [ + "./dist/register_all_kernels.js", + "./dist/flags_webgl.js", + "./dist/base.js", + "./dist/index.js" + ] } diff --git a/tfjs-backend-webgl/src/base.ts b/tfjs-backend-webgl/src/base.ts new file mode 100644 index 00000000000..f2412eab16a --- /dev/null +++ b/tfjs-backend-webgl/src/base.ts @@ -0,0 +1,33 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +// base.ts is the webgl backend without auto kernel registration. + +import {device_util, registerBackend} from '@tensorflow/tfjs-core'; +import {MathBackendWebGL} from './backend_webgl'; +export {version as version_webgl} from './version'; + +if (device_util.isBrowser()) { + registerBackend('webgl', () => new MathBackendWebGL(), 2 /* priority */); +} + +// Export webgl utilities +export * from './webgl'; + +// Export forceHalfFlost under webgl namespace for the union bundle. +import {forceHalfFloat} from './webgl'; +export const webgl = {forceHalfFloat}; diff --git a/tfjs-backend-webgl/src/index.ts b/tfjs-backend-webgl/src/index.ts index 14f5863c480..abf89246ec3 100644 --- a/tfjs-backend-webgl/src/index.ts +++ b/tfjs-backend-webgl/src/index.ts @@ -15,18 +15,6 @@ * ============================================================================= */ -import {device_util, registerBackend} from '@tensorflow/tfjs-core'; -import {MathBackendWebGL} from './backend_webgl'; -export {version as version_webgl} from './version'; - -if (device_util.isBrowser()) { - registerBackend('webgl', () => new MathBackendWebGL(), 2 /* priority */); -} +// All exports from this package should be in base. +export * from './base'; import './register_all_kernels'; - -// Export webgl utilities -export * from './webgl'; - -// Export forceHalfFlost under webgl namespace for the union bundle. -import {forceHalfFloat} from './webgl'; -export const webgl = {forceHalfFloat}; diff --git a/tfjs-backend-webgl/src/kernel_utils/shared.ts b/tfjs-backend-webgl/src/kernel_utils/shared.ts index c5b51211eef..0ac71e981c0 100644 --- a/tfjs-backend-webgl/src/kernel_utils/shared.ts +++ b/tfjs-backend-webgl/src/kernel_utils/shared.ts @@ -18,7 +18,7 @@ // Import shared functionality from tfjs-backend-cpu without triggering // side effects. // tslint:disable-next-line: no-imports-from-dist -import {shared} from '@tensorflow/tfjs-backend-cpu/dist/base'; +import * as shared from '@tensorflow/tfjs-backend-cpu/dist/shared'; const {maxImpl: maxImplCPU, transposeImpl: transposeImplCPU} = shared; diff --git a/tfjs-core/package.json b/tfjs-core/package.json index 1d73f344c78..74ef2a84f50 100644 --- a/tfjs-core/package.json +++ b/tfjs-core/package.json @@ -91,5 +91,15 @@ "node-fetch": false, "util": false, "crypto": false - } + }, + "sideEffects": [ + "./dist/index.js", + "./dist/base_side_effects.js", + "./dist/engine.js", + "./dist/flags.js", + "./dist/platforms/*.js", + "./dist/register_all_gradients.js", + "./dist/public/chained_ops/*.js", + "./dist/io/*.js" + ] } diff --git a/tfjs-core/src/base.ts b/tfjs-core/src/base.ts new file mode 100644 index 00000000000..831f00dee0a --- /dev/null +++ b/tfjs-core/src/base.ts @@ -0,0 +1,101 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +// base.ts is tfjs-core without auto registration of gradients or chained ops. + +/** + * @fileoverview + * @suppress {partialAlias} Optimization disabled due to passing the module + * object into a function below: + * + * import * as ops from './ops/ops'; + * setOpHandler(ops); + */ + +// Serialization. +import * as io from './io/io'; +import * as math from './math'; +import * as browser from './ops/browser'; +import * as gather_util from './ops/gather_nd_util'; +import * as scatter_util from './ops/scatter_nd_util'; +import * as slice_util from './ops/slice_util'; +import * as serialization from './serialization'; +import {setOpHandler} from './tensor'; +import * as tensor_util from './tensor_util'; +import * as test_util from './test_util'; +import * as util from './util'; +import {version} from './version'; + +export {InferenceModel, MetaGraph, MetaGraphInfo, ModelPredictConfig, ModelTensorInfo, SavedModelTensorInfo, SignatureDef, SignatureDefInfo} from './model_types'; +// Optimizers. +export {AdadeltaOptimizer} from './optimizers/adadelta_optimizer'; +export {AdagradOptimizer} from './optimizers/adagrad_optimizer'; +export {AdamOptimizer} from './optimizers/adam_optimizer'; +export {AdamaxOptimizer} from './optimizers/adamax_optimizer'; +export {MomentumOptimizer} from './optimizers/momentum_optimizer'; +export {Optimizer} from './optimizers/optimizer'; +export {RMSPropOptimizer} from './optimizers/rmsprop_optimizer'; +export {SGDOptimizer} from './optimizers/sgd_optimizer'; +export {Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer, Variable} from './tensor'; +export {GradSaveFunc, NamedTensorMap, TensorContainer, TensorContainerArray, TensorContainerObject} from './tensor_types'; +export {BackendValues, DataType, DataTypeMap, DataValues, NumericDataType, PixelData, Rank, RecursiveArray, ScalarLike, ShapeMap, sumOutType, TensorLike, TypedArray, upcastType} from './types'; + +export * from './ops/ops'; +export {Reduction} from './ops/loss_ops_utils'; + +export * from './train'; +export * from './globals'; +export * from './kernel_registry'; +export {customGrad, grad, grads, valueAndGrad, valueAndGrads, variableGrads} from './gradients'; + +export {TimingInfo, MemoryInfo, ForwardFunc} from './engine'; +export {Environment, env, ENV} from './environment'; +export {Platform} from './platforms/platform'; + +export {version as version_core}; + +// Top-level method exports. +export {nextFrame} from './browser_util'; + +// Second level exports. +import * as backend_util from './backends/backend_util'; +import * as device_util from './device_util'; +export { + browser, + io, + math, + serialization, + test_util, + util, + backend_util, + tensor_util, + slice_util, + gather_util, + scatter_util, + device_util +}; + +import * as kernel_impls from './backends/kernel_impls'; +export {kernel_impls}; +// Backend specific. +export {KernelBackend, BackendTimingInfo, DataMover, DataStorage} from './backends/backend'; + +import * as ops from './ops/ops'; +setOpHandler(ops); + +// Export all kernel names / info. +export * from './kernel_names'; diff --git a/tfjs-core/src/base_side_effects.ts b/tfjs-core/src/base_side_effects.ts new file mode 100644 index 00000000000..f645809a7c1 --- /dev/null +++ b/tfjs-core/src/base_side_effects.ts @@ -0,0 +1,27 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +// Required side effectful code for tfjs-core (in any build) + +// Engine is the global singleton that needs to be initialized before the rest +// of the app. +import './engine'; +// Register backend-agnostic flags. +import './flags'; +// Register platforms +import './platforms/platform_browser'; +import './platforms/platform_node'; diff --git a/tfjs-core/src/index.ts b/tfjs-core/src/index.ts index 49a95123436..52090c4e992 100644 --- a/tfjs-core/src/index.ts +++ b/tfjs-core/src/index.ts @@ -15,98 +15,12 @@ * ============================================================================= */ -/** - * @fileoverview - * @suppress {partialAlias} Optimization disabled due to passing the module - * object into a function below: - * - * import * as ops from './ops/ops'; - * setOpHandler(ops); - */ +// Required side effectful code. +import './base_side_effects'; +// All exports from this package should be in base. +export * from './base'; -// Engine is the global singleton that needs to be initialized before the rest -// of the app. -import './engine'; -// Register backend-agnostic flags. -import './flags'; // Register all the gradients. import './register_all_gradients'; -import './platforms/platform_browser'; -import './platforms/platform_node'; - -// Serialization. -import * as io from './io/io'; -import * as math from './math'; -import * as browser from './ops/browser'; -import * as gather_util from './ops/gather_nd_util'; -import * as scatter_util from './ops/scatter_nd_util'; -import * as slice_util from './ops/slice_util'; -import * as serialization from './serialization'; -import {setOpHandler} from './tensor'; -import * as tensor_util from './tensor_util'; -import * as test_util from './test_util'; -import * as util from './util'; -import {version} from './version'; - -export {InferenceModel, MetaGraph, MetaGraphInfo, ModelPredictConfig, ModelTensorInfo, SavedModelTensorInfo, SignatureDef, SignatureDefInfo} from './model_types'; -// Optimizers. -export {AdadeltaOptimizer} from './optimizers/adadelta_optimizer'; -export {AdagradOptimizer} from './optimizers/adagrad_optimizer'; -export {AdamOptimizer} from './optimizers/adam_optimizer'; -export {AdamaxOptimizer} from './optimizers/adamax_optimizer'; -export {MomentumOptimizer} from './optimizers/momentum_optimizer'; -export {Optimizer} from './optimizers/optimizer'; -export {RMSPropOptimizer} from './optimizers/rmsprop_optimizer'; -export {SGDOptimizer} from './optimizers/sgd_optimizer'; -export {Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer, Variable} from './tensor'; -export {GradSaveFunc, NamedTensorMap, TensorContainer, TensorContainerArray, TensorContainerObject} from './tensor_types'; -export {BackendValues, DataType, DataTypeMap, DataValues, NumericDataType, PixelData, Rank, RecursiveArray, ScalarLike, ShapeMap, sumOutType, TensorLike, TypedArray, upcastType} from './types'; - -export * from './ops/ops'; -export {Reduction} from './ops/loss_ops_utils'; - -export * from './train'; -export * from './globals'; -export * from './kernel_registry'; -export {customGrad, grad, grads, valueAndGrad, valueAndGrads, variableGrads} from './gradients'; - -export {TimingInfo, MemoryInfo, ForwardFunc} from './engine'; -export {Environment, env, ENV} from './environment'; -export {Platform} from './platforms/platform'; - -export {version as version_core}; - -// Top-level method exports. -export {nextFrame} from './browser_util'; - -// Second level exports. -import * as backend_util from './backends/backend_util'; -import * as device_util from './device_util'; -export { - browser, - io, - math, - serialization, - test_util, - util, - backend_util, - tensor_util, - slice_util, - gather_util, - scatter_util, - device_util -}; - -import * as kernel_impls from './backends/kernel_impls'; -export {kernel_impls}; -// Backend specific. -export {KernelBackend, BackendTimingInfo, DataMover, DataStorage} from './backends/backend'; - -import * as ops from './ops/ops'; -setOpHandler(ops); - -// Export all kernel names / info. -export * from './kernel_names'; - // Import all op chainers and add type info to Tensor. import './public/chained_ops/register_all_chained_ops'; diff --git a/tfjs-core/src/ops/browser.ts b/tfjs-core/src/ops/browser.ts index 22aba683829..e060db10fa1 100644 --- a/tfjs-core/src/ops/browser.ts +++ b/tfjs-core/src/ops/browser.ts @@ -23,6 +23,9 @@ import {NamedTensorMap} from '../tensor_types'; import {convertToTensor} from '../tensor_util_env'; import {PixelData, TensorLike} from '../types'; +import {cast} from './cast'; +import {max} from './max'; +import {min} from './min'; import {op} from './operation'; import {tensor3d} from './tensor3d'; @@ -177,7 +180,7 @@ export async function toPixels( if (!(img instanceof Tensor)) { // Assume int32 if user passed a native array. const originalImgTensor = $img; - $img = originalImgTensor.toInt(); + $img = cast(originalImgTensor, 'int32'); originalImgTensor.dispose(); } if ($img.rank !== 2 && $img.rank !== 3) { @@ -194,26 +197,26 @@ export async function toPixels( } const data = await $img.data(); - const minTensor = $img.min(); - const maxTensor = $img.max(); + const minTensor = min($img); + const maxTensor = max($img); const vals = await Promise.all([minTensor.data(), maxTensor.data()]); const minVals = vals[0]; const maxVals = vals[1]; - const min = minVals[0]; - const max = maxVals[0]; + const minVal = minVals[0]; + const maxVal = maxVals[0]; minTensor.dispose(); maxTensor.dispose(); if ($img.dtype === 'float32') { - if (min < 0 || max > 1) { + if (minVal < 0 || maxVal > 1) { throw new Error( `Tensor values for a float32 Tensor must be in the ` + - `range [0 - 1] but got range [${min} - ${max}].`); + `range [0 - 1] but got range [${minVal} - ${maxVal}].`); } } else if ($img.dtype === 'int32') { - if (min < 0 || max > 255) { + if (minVal < 0 || maxVal > 255) { throw new Error( `Tensor values for a int32 Tensor must be in the ` + - `range [0 - 255] but got range [${min} - ${max}].`); + `range [0 - 255] but got range [${minVal} - ${maxVal}].`); } } else { throw new Error( diff --git a/tfjs/package.json b/tfjs/package.json index 97c0c28745b..63683a0df16 100644 --- a/tfjs/package.json +++ b/tfjs/package.json @@ -14,6 +14,9 @@ "url": "https://github.com/tensorflow/tfjs.git" }, "license": "Apache-2.0", + "bin": { + "tfjs-custom-bundle": "dist/tools/custom_bundle/cli.js" + }, "devDependencies": { "@babel/core": "^7.9.0", "@babel/preset-env": "^7.9.5", @@ -67,12 +70,15 @@ "build-backend-webgl-ci": "cd ../tfjs-backend-webgl && yarn && yarn build-ci", "build-deps": "yarn build-core && yarn build-layers && yarn build-converter && yarn build-data && yarn build-backend-cpu && yarn build-backend-webgl", "build-deps-ci": "yarn build-core-ci && yarn build-layers-ci && yarn build-converter-ci && yarn build-data-ci && yarn build-backend-cpu-ci && yarn build-backend-webgl-ci", + "build-cli": "tsc --project ./tools/custom_bundle/tsconfig.json && chmod +x ./dist/tools/custom_bundle/cli.js", + "run-custom-build": "ts-node -s ./src/tools/custom_bundle/cli.ts", "build-npm": "./scripts/build-npm.sh", "link-local": "yalc link", "publish-local": "yarn build-npm && yalc push", "publish-npm": "npm publish", "lint": "tslint -p . -t verbose", "test": "yarn && yarn build-deps && yarn build && karma start", + "test-tools": "ts-node --project ./tools/custom_bundle/tsconfig.json run_tools_tests.ts", "test-ci": "./scripts/test-ci.sh" }, "dependencies": { @@ -82,6 +88,8 @@ "@tensorflow/tfjs-core": "link:../tfjs-core", "@tensorflow/tfjs-data": "link:../tfjs-data", "@tensorflow/tfjs-layers": "link:../tfjs-layers", + "argparse": "^1.0.10", + "chalk": "^4.1.0", "core-js": "3", "regenerator-runtime": "^0.13.5" } diff --git a/tfjs/run_tools_tests.ts b/tfjs/run_tools_tests.ts new file mode 100644 index 00000000000..cd7e46bc933 --- /dev/null +++ b/tfjs/run_tools_tests.ts @@ -0,0 +1,32 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +// tslint:disable-next-line:no-require-imports +const jasmineCtor = require('jasmine'); +// tslint:disable-next-line:no-require-imports + +Error.stackTraceLimit = Infinity; + +process.on('unhandledRejection', e => { + throw e; +}); +const toolsTests = './tools/**/*_test.ts'; + +const runner = new jasmineCtor(); +runner.loadConfig({spec_files: [toolsTests], random: false}); + +runner.execute(); diff --git a/tfjs/scripts/test-ci.sh b/tfjs/scripts/test-ci.sh index 85f5fd0cb50..8316f8414d6 100755 --- a/tfjs/scripts/test-ci.sh +++ b/tfjs/scripts/test-ci.sh @@ -18,6 +18,7 @@ set -e yarn karma start --browsers='bs_firefox_mac,bs_chrome_mac' --singleRun +yarn test-tools # cd integration_tests # yarn benchmark-cloud # Reinstall the following line once https://github.com/tensorflow/tfjs/pull/1663 diff --git a/tfjs/tools/custom_bundle/cli.ts b/tfjs/tools/custom_bundle/cli.ts new file mode 100755 index 00000000000..d94fc76ca14 --- /dev/null +++ b/tfjs/tools/custom_bundle/cli.ts @@ -0,0 +1,113 @@ +#!/usr/bin/env node + +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +/** + * Entry point for cli tool to build custom tfjs bundles + */ + +import * as argparse from 'argparse'; +import * as chalk from 'chalk'; +import * as fs from 'fs'; + +import {getCustomModuleString} from './custom_module'; +import {CustomTFJSBundleConfig, SupportedBackends} from './types'; +import {esmModuleProvider} from './esm_module_provider'; + +const DEFAULT_CUSTOM_BUNDLE_ARGS: Partial = { + entries: [], + models: [], + kernels: [], + forwardModeOnly: true, + backends: ['cpu', 'webgl'], +}; + +const parser = new argparse.ArgumentParser(); +parser.addArgument( + '--config', {help: 'path to custom bundle config file.', required: true}); + +function bail(errorMsg: string) { + console.log(chalk.red(errorMsg)); + process.exit(1); +} + +function validateArgs(): CustomTFJSBundleConfig { + const args = parser.parseArgs(); + const configFilePath = args.config; + if (!fs.existsSync(configFilePath)) { + bail(`Error: config file does not exist at ${configFilePath}`); + } + let config; + try { + config = JSON.parse(fs.readFileSync(configFilePath, 'utf-8')); + } catch (error) { + bail(`Error could not read/parse JSON config file. \n ${error.message}`); + } + + if (config.outputPath == null) { + bail('Error: config must specify "outputPath" property'); + } + + console.log(`Using custom bundle configuration from ${ + configFilePath}. Final config:`); + const replacer: null = null; + const space = 2; + console.log(`${JSON.stringify(config, replacer, space)}\n`); + + const finalConfig = Object.assign({}, DEFAULT_CUSTOM_BUNDLE_ARGS, config); + + if (finalConfig.entries.length !== 0) { + bail('Error: config.entries not yet supported'); + } + + if (finalConfig.models.length !== 0) { + bail('Error: config.models not yet supported'); + } + + for (const requestedBackend of finalConfig.backends) { + if (requestedBackend !== SupportedBackends.cpu && + requestedBackend !== SupportedBackends.webgl) { + bail(`Error: Unsupported backend specified '${requestedBackend}'`); + } + } + + return finalConfig; +} + +function getKernelNamesForConfig(config: CustomTFJSBundleConfig) { + // Later on this will do a union of kernels from entries, models and kernels, + // (and kernels used by the converter itself) Currently we only support + // directly listing kernels. remember that this also needs to handle + // kernels used by gradients if forwardModeOnly is false. + return config.kernels; +} +function produceCustomTFJSModule( + kernels: string[], backends: string[], forwardModeOnly: boolean, + outputPath: string) { + const moduleStr = getCustomModuleString( + kernels, backends, forwardModeOnly, esmModuleProvider); + + console.log(`Writing custom tfjs module to ${outputPath}`); + fs.writeFileSync(outputPath, moduleStr); +} + +const customConfig = validateArgs(); +const kernelsToInclude = getKernelNamesForConfig(customConfig); +produceCustomTFJSModule( + kernelsToInclude, customConfig.backends, customConfig.forwardModeOnly, + customConfig.outputPath); diff --git a/tfjs/tools/custom_bundle/custom_module.ts b/tfjs/tools/custom_bundle/custom_module.ts new file mode 100644 index 00000000000..9725e9f6483 --- /dev/null +++ b/tfjs/tools/custom_bundle/custom_module.ts @@ -0,0 +1,63 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ModuleProvider} from './types'; + +export function getCustomModuleString( + kernels: string[], + backends: string[], + forwardModeOnly: boolean, + moduleProvider: ModuleProvider, + ): string { + const result: string[] = []; + + addLine(result, moduleProvider.importCoreStr()); + addLine(result, moduleProvider.importConverterStr()); + + for (const backend of backends) { + addLine(result, `\n//backend = ${backend}`); + addLine(result, moduleProvider.importBackendStr(backend)); + for (const kernelName of kernels) { + const kernelImport = moduleProvider.importKernelStr(kernelName, backend); + addLine(result, kernelImport.importStatement); + addLine(result, registerKernelStr(kernelImport.kernelConfigId)); + } + } + + if (!forwardModeOnly) { + addLine(result, `\n//Gradients`); + for (const kernelName of kernels) { + const gradImport = moduleProvider.importGradientConfigStr(kernelName); + addLine(result, gradImport.importStatement); + addLine(result, registerGradientConfigStr(gradImport.gradConfigId)); + } + } + + return result.join('\n'); +} + +function addLine(target: string[], line: string) { + target.push(line); +} + +function registerKernelStr(kernelConfigId: string) { + return `registerKernel(${kernelConfigId});`; +} + +function registerGradientConfigStr(gradConfigId: string) { + return `registerGradient(${gradConfigId});`; +} diff --git a/tfjs/tools/custom_bundle/custom_module_test.ts b/tfjs/tools/custom_bundle/custom_module_test.ts new file mode 100644 index 00000000000..ce7549fc738 --- /dev/null +++ b/tfjs/tools/custom_bundle/custom_module_test.ts @@ -0,0 +1,165 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {getCustomModuleString} from './custom_module'; +import {ModuleProvider} from './types'; + +const mockModuleProvider: ModuleProvider = { + importCoreStr: () => 'import CORE', + importConverterStr: () => 'import CONVERTER', + importBackendStr: (name: string) => `import BACKEND ${name}`, + importKernelStr: (kernelName: string, backend: string) => ({ + importStatement: `import KERNEL ${kernelName} from BACKEND ${backend}`, + kernelConfigId: `${kernelName}_${backend}` + }), + importGradientConfigStr: (kernel: string) => ({ + importStatement: `import GRADIENT ${kernel}`, + gradConfigId: `${kernel}_GRAD_CONFIG`, + }), +}; + +describe('ESM Module Provider forwardModeOnly=true', () => { + const forwardModeOnly = true; + it('one kernel, one backend', () => { + const res = getCustomModuleString( + ['MathKrnl'], ['FastBcknd'], forwardModeOnly, mockModuleProvider); + + expect(res).toContain('import CORE'); + expect(res).toContain('import CONVERTER'); + + expect(res).toContain('import BACKEND FastBcknd'); + expect(res).toContain('import KERNEL MathKrnl from BACKEND FastBcknd'); + expect(res).toContain('registerKernel(MathKrnl_FastBcknd)'); + + expect(res).not.toContain('GRADIENT'); + }); + + it('one kernel, two backend', () => { + const res = getCustomModuleString( + ['MathKrnl'], ['FastBcknd', 'SlowBcknd'], forwardModeOnly, + mockModuleProvider); + + expect(res).toContain('import CORE'); + expect(res).toContain('import CONVERTER'); + + expect(res).toContain('import BACKEND FastBcknd'); + expect(res).toContain('import KERNEL MathKrnl from BACKEND FastBcknd'); + expect(res).toContain('registerKernel(MathKrnl_FastBcknd)'); + + expect(res).toContain('import BACKEND SlowBcknd'); + expect(res).toContain('import KERNEL MathKrnl from BACKEND SlowBcknd'); + expect(res).toContain('registerKernel(MathKrnl_SlowBcknd)'); + + expect(res).not.toContain('GRADIENT'); + }); + + it('two kernels, one backend', () => { + const res = getCustomModuleString( + ['MathKrnl', 'MathKrn2'], ['FastBcknd'], forwardModeOnly, + mockModuleProvider); + + expect(res).toContain('import CORE'); + expect(res).toContain('import CONVERTER'); + + expect(res).toContain('import BACKEND FastBcknd'); + expect(res).toContain('import KERNEL MathKrnl from BACKEND FastBcknd'); + expect(res).toContain('import KERNEL MathKrn2 from BACKEND FastBcknd'); + expect(res).toContain('registerKernel(MathKrnl_FastBcknd)'); + expect(res).toContain('registerKernel(MathKrn2_FastBcknd)'); + + expect(res).not.toContain('GRADIENT'); + }); + + it('two kernels, two backends', () => { + const res = getCustomModuleString( + ['MathKrnl', 'MathKrn2'], ['FastBcknd', 'SlowBcknd'], forwardModeOnly, + mockModuleProvider); + + expect(res).toContain('import CORE'); + expect(res).toContain('import CONVERTER'); + + expect(res).toContain('import BACKEND FastBcknd'); + expect(res).toContain('import KERNEL MathKrnl from BACKEND FastBcknd'); + expect(res).toContain('import KERNEL MathKrn2 from BACKEND FastBcknd'); + expect(res).toContain('registerKernel(MathKrnl_FastBcknd)'); + expect(res).toContain('registerKernel(MathKrn2_FastBcknd)'); + + expect(res).toContain('import BACKEND SlowBcknd'); + expect(res).toContain('import KERNEL MathKrnl from BACKEND SlowBcknd'); + expect(res).toContain('import KERNEL MathKrn2 from BACKEND SlowBcknd'); + expect(res).toContain('registerKernel(MathKrnl_SlowBcknd)'); + expect(res).toContain('registerKernel(MathKrn2_SlowBcknd)'); + + expect(res).not.toContain('GRADIENT'); + }); +}); + +describe('ESM Module Provider forwardModeOnly=false', () => { + const forwardModeOnly = false; + + it('one kernel, one backend', () => { + const res = getCustomModuleString( + ['MathKrnl'], ['FastBcknd'], forwardModeOnly, mockModuleProvider); + + expect(res).toContain('import CORE'); + expect(res).toContain('import CONVERTER'); + + expect(res).toContain('import BACKEND FastBcknd'); + expect(res).toContain('import KERNEL MathKrnl from BACKEND FastBcknd'); + expect(res).toContain('registerKernel(MathKrnl_FastBcknd)'); + + expect(res).toContain('import GRADIENT MathKrnl'); + expect(res).toContain('registerGradient(MathKrnl_GRAD_CONFIG)'); + }); + + it('one kernel, two backend', () => { + const res = getCustomModuleString( + ['MathKrnl'], ['FastBcknd', 'SlowBcknd'], forwardModeOnly, + mockModuleProvider); + + expect(res).toContain('import GRADIENT MathKrnl'); + expect(res).toContain('registerGradient(MathKrnl_GRAD_CONFIG)'); + + const gradIndex = res.indexOf('GRADIENT'); + expect(res.indexOf('GRADIENT', gradIndex + 1)) + .toBe(-1, `Gradient import appears twice in:\n ${res}`); + }); + + it('two kernels, one backend', () => { + const res = getCustomModuleString( + ['MathKrnl', 'MathKrn2'], ['FastBcknd'], forwardModeOnly, + mockModuleProvider); + + expect(res).toContain('import GRADIENT MathKrnl'); + expect(res).toContain('registerGradient(MathKrnl_GRAD_CONFIG)'); + + expect(res).toContain('import GRADIENT MathKrn2'); + expect(res).toContain('registerGradient(MathKrn2_GRAD_CONFIG)'); + }); + + it('two kernels, two backends', () => { + const res = getCustomModuleString( + ['MathKrnl', 'MathKrn2'], ['FastBcknd', 'SlowBcknd'], forwardModeOnly, + mockModuleProvider); + + expect(res).toContain('import GRADIENT MathKrnl'); + expect(res).toContain('registerGradient(MathKrnl_GRAD_CONFIG)'); + + expect(res).toContain('import GRADIENT MathKrn2'); + expect(res).toContain('registerGradient(MathKrn2_GRAD_CONFIG)'); + }); +}); diff --git a/tfjs/tools/custom_bundle/esm_module_provider.ts b/tfjs/tools/custom_bundle/esm_module_provider.ts new file mode 100644 index 00000000000..24d31cdfc5d --- /dev/null +++ b/tfjs/tools/custom_bundle/esm_module_provider.ts @@ -0,0 +1,81 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {ModuleProvider, SupportedBackend} from './types'; + +/** + * A module provider to generate custom esm modules. + */ +export const esmModuleProvider: ModuleProvider = { + importCoreStr() { + return ` +import {registerKernel, registerGradient} from '@tensorflow/tfjs-core/dist/base'; +import '@tensorflow/tfjs-core/dist/base_side_effects'; +export * from '@tensorflow/tfjs-core/dist/base'; + `; + }, + + importConverterStr() { + return `export * from '@tensorflow/tfjs-converter';`; + }, + + importBackendStr(backend: SupportedBackend) { + const backendPkg = getBackendPath(backend); + return `export * from '${backendPkg}/dist/base';`; + }, + + importKernelStr(kernelName: string, backend: SupportedBackend) { + // TODO(yassogba) validate whether the target file referenced by + // importStatement exists and warn the user if it doesn't. That could happen + // here or in an earlier validation phase that uses this function + + const backendPkg = getBackendPath(backend); + const kernelConfigId = `${kernelName}_${backend}`; + const importStatement = + `import {${kernelNameToVariableName(kernelName)}Config as ${ + kernelConfigId}} from '${backendPkg}/dist/kernels/${kernelName}';`; + + return {importStatement, kernelConfigId}; + }, + + importGradientConfigStr(kernelName: string) { + // TODO(yassogba) validate whether the target file referenced by + // importStatement exists and warn the user if it doesn't. That could happen + // here or in an earlier validation phase that uses this function + + const gradConfigId = `${kernelNameToVariableName(kernelName)}GradConfig`; + const importStatement = + `import {${gradConfigId}} from '@tensorflow/tfjs-core/dist/gradients/${ + kernelName}_grad';`; + + return {importStatement, gradConfigId}; + } +}; + +function getBackendPath(backend: SupportedBackend) { + switch (backend) { + case 'cpu': + return '@tensorflow/tfjs-backend-cpu'; + case 'webgl': + return '@tensorflow/tfjs-backend-webgl'; + default: + throw new Error(`Unsupported backend ${backend}`); + } +} + +function kernelNameToVariableName(kernelName: string) { + return kernelName.charAt(0).toLowerCase() + kernelName.slice(1); +} diff --git a/tfjs/tools/custom_bundle/esm_module_provider_test.ts b/tfjs/tools/custom_bundle/esm_module_provider_test.ts new file mode 100644 index 00000000000..aac2acf5c6b --- /dev/null +++ b/tfjs/tools/custom_bundle/esm_module_provider_test.ts @@ -0,0 +1,64 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {esmModuleProvider} from './esm_module_provider'; + +describe('ESM Module Provider', () => { + it('importCoreStr', () => { + const res = esmModuleProvider.importCoreStr(); + expect(res).toContain( + // tslint:disable-next-line: max-line-length + `import {registerKernel, registerGradient} from '@tensorflow/tfjs-core/dist/base'`); + expect(res).toContain( + `import '@tensorflow/tfjs-core/dist/base_side_effects';`); + expect(res).toContain(`export * from '@tensorflow/tfjs-core/dist/base';`); + }); + + it('importConverterStr', () => { + const res = esmModuleProvider.importConverterStr(); + expect(res).toBe(`export * from '@tensorflow/tfjs-converter';`); + }); + + it('importBackendStr cpu', () => { + const res = esmModuleProvider.importBackendStr('cpu'); + expect(res).toBe(`export * from '@tensorflow/tfjs-backend-cpu/dist/base';`); + }); + + it('importBackendStr webgl', () => { + const res = esmModuleProvider.importBackendStr('webgl'); + expect(res).toBe( + `export * from '@tensorflow/tfjs-backend-webgl/dist/base';`); + }); + + it('importKernelStr Max cpu', () => { + const res = esmModuleProvider.importKernelStr('Max', 'cpu'); + expect(res.importStatement).toContain('import {maxConfig as Max_cpu}'); + expect(res.importStatement) + .toContain(`from '@tensorflow/tfjs-backend-cpu/dist/kernels/Max'`); + + expect(res.kernelConfigId).toBe('Max_cpu'); + }); + + it('importGradientConfigStr Max', () => { + const res = esmModuleProvider.importGradientConfigStr('Max'); + expect(res.importStatement).toContain('import {maxGradConfig}'); + expect(res.importStatement) + .toContain(`from '@tensorflow/tfjs-core/dist/gradients/Max_grad'`); + + expect(res.gradConfigId).toBe('maxGradConfig'); + }); +}); diff --git a/tfjs/tools/custom_bundle/tsconfig.json b/tfjs/tools/custom_bundle/tsconfig.json new file mode 100644 index 00000000000..a981dfa447b --- /dev/null +++ b/tfjs/tools/custom_bundle/tsconfig.json @@ -0,0 +1,10 @@ +{ + "extends": "../../tsconfig.test", + "include": [ + "../../tools/" + ], + "compilerOptions": { + "module": "commonjs", + "outDir": "../../dist/tools/custom_bundle" + } +} diff --git a/tfjs/tools/custom_bundle/types.ts b/tfjs/tools/custom_bundle/types.ts new file mode 100644 index 00000000000..7fb3989b776 --- /dev/null +++ b/tfjs/tools/custom_bundle/types.ts @@ -0,0 +1,45 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +export enum SupportedBackends { + cpu = 'cpu', + webgl = 'webgl' +} +export type SupportedBackend = keyof typeof SupportedBackends; + +export interface CustomTFJSBundleConfig { + entries?: string[]; // paths to javascript files to walk + models?: string[]; // paths to model.json files to walk + backends?: SupportedBackend[]; // backends to include/use kernels from + forwardModeOnly?: boolean; // whether to drop gradients + outputPath: string; // path to output file + kernels?: string[]; // Kernels to include +} + +// Interface for an object that can provide functionality to generate +// a correct custom module for that build environment (e.g. node vs g3). +export interface ModuleProvider { + importCoreStr: () => string; + importConverterStr: () => string; + importBackendStr: (backendPkg: string) => string; + importKernelStr: (kernelName: string, backend: string) => { + importStatement: string, kernelConfigId: string + }; + importGradientConfigStr: (kernelName: string) => { + importStatement: string, gradConfigId: string + }; +} diff --git a/tfjs/yarn.lock b/tfjs/yarn.lock index c8aa0e0d28d..d66381a954b 100644 --- a/tfjs/yarn.lock +++ b/tfjs/yarn.lock @@ -975,7 +975,7 @@ ansi-styles@^3.2.1: dependencies: color-convert "^1.9.0" -ansi-styles@^4.0.0: +ansi-styles@^4.0.0, ansi-styles@^4.1.0: version "4.2.1" resolved "https://registry.yarnpkg.com/ansi-styles/-/ansi-styles-4.2.1.tgz#90ae75c424d008d2624c5bf29ead3177ebfcf359" integrity sha512-9VGjrMsG1vePxcSweQsN20KY/c4zN0h9fLjqAbwbPfahM3t+NL+M9HC8xeXG2I8pX5NoamTGNuomEUFI7fcUjA== @@ -1001,7 +1001,7 @@ arg@^4.1.0: resolved "https://registry.yarnpkg.com/arg/-/arg-4.1.3.tgz#269fc7ad5b8e42cb63c896d5666017261c144089" integrity sha512-58S9QDqG0Xx27YwPSt9fJxivjYl432YCwfDMfZ+71RAqUrZef7LrKQZ3LHLOwCS4FLNBplP533Zx895SeOCHvA== -argparse@^1.0.7: +argparse@^1.0.10, argparse@^1.0.7: version "1.0.10" resolved "https://registry.yarnpkg.com/argparse/-/argparse-1.0.10.tgz#bcd6791ea5ae09725e17e5ad988134cd40b3d911" integrity sha512-o5Roy6tNG4SL/FOkCAN6RzjiakZS25RLYFrcMttJqbdd8BWrnA+fGz57iN5Pb06pvBGvl5gQ0B48dJlslXvoTg== @@ -1446,6 +1446,14 @@ chalk@^2.0.0, chalk@^2.3.0, chalk@^2.4.1: escape-string-regexp "^1.0.5" supports-color "^5.3.0" +chalk@^4.1.0: + version "4.1.0" + resolved "https://registry.yarnpkg.com/chalk/-/chalk-4.1.0.tgz#4e14870a618d9e2edd97dd8345fd9d9dc315646a" + integrity sha512-qwx12AxXe2Q5xQ43Ac//I6v5aXTipYrSESdOgzrN+9XjgEpyjpKuvSGaN4qE93f7TQTlerQQ8S+EQ0EyDoVL1A== + dependencies: + ansi-styles "^4.1.0" + supports-color "^7.1.0" + chokidar@^3.0.0: version "3.4.0" resolved "https://registry.yarnpkg.com/chokidar/-/chokidar-3.4.0.tgz#b30611423ce376357c765b9b8f904b9fba3c0be8" @@ -2305,6 +2313,11 @@ has-flag@^3.0.0: resolved "https://registry.yarnpkg.com/has-flag/-/has-flag-3.0.0.tgz#b5d454dc2199ae225699f3467e5a07f3b955bafd" integrity sha1-tdRU3CGZriJWmfNGfloH87lVuv0= +has-flag@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/has-flag/-/has-flag-4.0.0.tgz#944771fd9c81c81265c4d6941860da06bb59479b" + integrity sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ== + has-symbols@^1.0.0: version "1.0.0" resolved "https://registry.yarnpkg.com/has-symbols/-/has-symbols-1.0.0.tgz#ba1a8f1af2a0fc39650f5c850367704122063b44" @@ -4261,6 +4274,13 @@ supports-color@^6.1.0: dependencies: has-flag "^3.0.0" +supports-color@^7.1.0: + version "7.1.0" + resolved "https://registry.yarnpkg.com/supports-color/-/supports-color-7.1.0.tgz#68e32591df73e25ad1c4b49108a2ec507962bfd1" + integrity sha512-oRSIpR8pxT1Wr2FquTNnGet79b3BWljqOuoW/h4oBhxJ/HUbX5nX6JSruTkvXDCFMwDPvsaTTbvMLKZWSy0R5g== + dependencies: + has-flag "^4.0.0" + terser@^4.6.2: version "4.6.11" resolved "https://registry.yarnpkg.com/terser/-/terser-4.6.11.tgz#12ff99fdd62a26de2a82f508515407eb6ccd8a9f"