From a0d761e40c406328bcb77ca202f96f333a4d2e7b Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Tue, 7 Jul 2020 10:46:15 -0400 Subject: [PATCH 1/7] add script to extract ops to new files update touch_modular_files script to emit a more complete gradient file template. --- tfjs-core/package.json | 1 + tfjs-core/scripts/extract_op.ts | 145 +++++++++++++++ tfjs-core/scripts/touch_modular_op_files.ts | 73 ++++++-- tfjs-core/yarn.lock | 195 +++++++++++++++++++- 4 files changed, 393 insertions(+), 21 deletions(-) create mode 100644 tfjs-core/scripts/extract_op.ts diff --git a/tfjs-core/package.json b/tfjs-core/package.json index 3220b04d1b3..1d73f344c78 100644 --- a/tfjs-core/package.json +++ b/tfjs-core/package.json @@ -42,6 +42,7 @@ "rollup-plugin-terser": "~5.3.0", "rollup-plugin-visualizer": "~3.3.2", "shelljs": "~0.8.3", + "ts-morph": "^7.1.2", "ts-node": "~8.8.2", "tslint": "~5.11.0", "tslint-no-circular-imports": "~0.5.0", diff --git a/tfjs-core/scripts/extract_op.ts b/tfjs-core/scripts/extract_op.ts new file mode 100644 index 00000000000..2fdf1b8edbe --- /dev/null +++ b/tfjs-core/scripts/extract_op.ts @@ -0,0 +1,145 @@ +import * as argparse from 'argparse'; +import {execSync} from 'child_process'; +import * as path from 'path'; +import {FunctionDeclaration, ImportDeclaration, Project, SourceFile, VariableStatement} from 'ts-morph'; + +const parser = new argparse.ArgumentParser(); + +parser.addArgument( + '--op_file', + {help: 'path to op file. e.g. src/ops/unary_ops.ts', required: true}); +parser.addArgument('--ops', { + help: 'comma seprated list of ops to extract (e.g. tanh,tan).' + + ' Skip this param to extract all ops from the file.', + defaultValue: [], +}); + +// initialize +const project = new Project({}); + +interface OpInfo { + variableStatement: VariableStatement; + opFuncName: string; + opIdentifier: string; + opFunc: FunctionDeclaration; +} + +function getImports(sourceFile: SourceFile) { + return sourceFile.getImportDeclarations(); +} + +function getOpExports(file: SourceFile): OpInfo[] { + const variables = file.getVariableStatements(); + const opFuncRegex = /op\(\{(.*_)\}\)/; + const exported = variables.filter( + v => v.isExported() && v.getFullText().match(opFuncRegex)); + const opInfo = exported.map(variable => { + const declaration = variable.getDeclarations()[0]; + const opFuncName = variable.getFullText().match(opFuncRegex)[1]; + const opFunc = getOpFunc(file, opFuncName); + if (opFunc == null) { + console.warn(`Warning: could not find implementation function for ${ + declaration.getName()}`); + return null; + } + return { + variableStatement: variable, + // string with exported name of op + opIdentifier: declaration.getName(), + // string with name of the function that actually implements the op + opFuncName, + // function that implements the op + opFunc, + }; + }); + return opInfo.filter(op => op != null); +} + +function getOpFunc(sourceFile: SourceFile, opFuncName: string) { + return sourceFile.getFunction(opFuncName); +} + +function toSnakeCase(str: string) { + // add exceptions here. + if (str === 'isNaN') { + return 'is_nan'; + } + return str.replace(/[A-Z]/g, (s: string) => `_${s.toLowerCase()}`); +} + +async function moveToNewFile( + opInfo: OpInfo, imports: ImportDeclaration[], sourceFile: SourceFile) { + // + // Move code to a new file + // + const targetFp = `src/ops/${toSnakeCase(opInfo.opIdentifier)}.ts`; + const newOpFile = project.createSourceFile(targetFp, (writer) => { + // By using getFullText here we will also get the copyright notice at the + // begining of the file + const importsStr = imports.map(i => i.getFullText()).join(''); + const functionStr = opInfo.opFunc.getFullText(); + const exportString = opInfo.variableStatement.getFullText(); + const contents = [importsStr, functionStr, exportString].join(''); + writer.write(contents); + }, {overwrite: true}); + newOpFile.fixUnusedIdentifiers(); + await newOpFile.save(); + + // make a test file + // create a test file + const testFilePath = `src/ops/${args.op}_test.ts`; + const command = `touch ${testFilePath}`; + execSync(command); + + // Add export to ops file. + const opsExportFile = project.getSourceFile('src/ops/ops.ts'); + opsExportFile.addExportDeclaration({ + namedExports: [opInfo.opIdentifier], + moduleSpecifier: `./${path.basename(targetFp, '.ts')}`, + }); + + await opsExportFile.save(); +} + +async function run(filePath: string, ops: string[]) { + console.log('ops', ops); + project.addSourceFilesAtPaths(filePath); + // add the ops export file to the project + project.addSourceFilesAtPaths('src/ops/ops.ts'); + const opFile = project.getSourceFile(filePath); + const imports = getImports(opFile); + const opExports = getOpExports(opFile); + + opExports.forEach(async o => { + if (ops.length === 0 || ops.indexOf(o.opIdentifier) !== -1) { + await moveToNewFile(o, imports, opFile); + } + }); + + // Save the ops export file + // const opsExportFile = project.getSourceFile('src/ops/ops.ts'); + // await opsExportFile.save(); + + // Remove the op from the source file and save it + opExports.forEach(async o => { + if (ops.length === 0 || ops.indexOf(o.opIdentifier) !== -1) { + opFile.removeStatement(o.variableStatement.getChildIndex()); + } + }); + opFile.fixUnusedIdentifiers(); + await opFile.save(); +} + +// add source files + +const args = parser.parseArgs(); +let opsToExtract = args.ops; +if (!Array.isArray(opsToExtract)) { + opsToExtract = opsToExtract.split(','); +} +console.log('Extracting from', args.op_file); +if (opsToExtract.length > 0) { + console.log('Only extract: ', opsToExtract); +} + +run(args.op_file, opsToExtract); diff --git a/tfjs-core/scripts/touch_modular_op_files.ts b/tfjs-core/scripts/touch_modular_op_files.ts index d62b214d761..99b3f6e3d78 100644 --- a/tfjs-core/scripts/touch_modular_op_files.ts +++ b/tfjs-core/scripts/touch_modular_op_files.ts @@ -40,11 +40,12 @@ import * as argparse from 'argparse'; import {execSync} from 'child_process'; +import * as fs from 'fs'; const parser = new argparse.ArgumentParser(); parser.addArgument('--op', {help: 'the name of the op'}); -parser.addArgument('--kernel', {help: 'the name of the kernel.'}); +parser.addArgument('--grad', {help: 'the name of the grad.'}); parser.addArgument('--chained', { action: 'storeTrue', defaultValue: false, @@ -55,29 +56,63 @@ async function main() { const args = parser.parseArgs(); console.log('Called touch_modular_op_files with args:', args); - if (args.op == null) { - throw new Error('You must specify an op'); - } - - let filePath = `./src/ops/${args.op}.ts`; - let command = `touch ${filePath}`; - execSync(command); - - // create a test file - filePath = `./src/ops/${args.op}_test.ts`; - command = `touch ${filePath}`; - execSync(command); + if (args.op != null) { + let filePath = `./src/ops/${args.op}.ts`; + let command = `touch ${filePath}`; + execSync(command); - if (args.chained) { - filePath = `./src/public/chained_ops/${args.op}.ts`; + // create a test file + filePath = `./src/ops/${args.op}_test.ts`; command = `touch ${filePath}`; execSync(command); + + if (args.chained) { + filePath = `./src/public/chained_ops/${args.op}.ts`; + command = `touch ${filePath}`; + execSync(command); + } } - if (args.kernel) { - filePath = `./src/gradients/${args.kernel}_grad.ts`; - command = `touch ${filePath}`; - execSync(command); + if (args.grad) { + const downcaseFirstChar = (str: string) => { + return str.charAt(0).toLowerCase() + str.slice(1); + }; + + const gradientFileTemplate = `/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {KernelName, KernelNameAttrs} from '../kernel_names'; +import {GradConfig, NamedAttrMap} from '../kernel_registry'; +import {Tensor} from '../tensor'; + +export const ${downcaseFirstChar(args.grad)}GradConfig: GradConfig = { + kernelName: KernelName, + inputsToSave: [], // UPDATE ME + outputsToSave: [], // UPDATE ME + gradFunc: (dy: Tensor, saved: Tensor[], attrs: NamedAttrMap) => { + const [] = saved; + const {} = attrs as {} as KernelNameAttrs; + return { + }; + } +}; +`; + const filePath = `./src/gradients/${args.grad}_grad.ts`; + fs.writeFileSync(filePath, gradientFileTemplate, {flag: 'a'}); } } diff --git a/tfjs-core/yarn.lock b/tfjs-core/yarn.lock index 21d1f7bef23..cb922f3def6 100644 --- a/tfjs-core/yarn.lock +++ b/tfjs-core/yarn.lock @@ -38,6 +38,35 @@ source-map-support "0.5.9" tsutils "2.27.2" +"@dsherret/to-absolute-glob@^2.0.2": + version "2.0.2" + resolved "https://registry.yarnpkg.com/@dsherret/to-absolute-glob/-/to-absolute-glob-2.0.2.tgz#1f6475dc8bd974cea07a2daf3864b317b1dd332c" + integrity sha1-H2R13IvZdM6gei2vOGSzF7HdMyw= + dependencies: + is-absolute "^1.0.0" + is-negated-glob "^1.0.0" + +"@nodelib/fs.scandir@2.1.3": + version "2.1.3" + resolved "https://registry.yarnpkg.com/@nodelib/fs.scandir/-/fs.scandir-2.1.3.tgz#3a582bdb53804c6ba6d146579c46e52130cf4a3b" + integrity sha512-eGmwYQn3gxo4r7jdQnkrrN6bY478C3P+a/y72IJukF8LjB6ZHeB3c+Ehacj3sYeSmUXGlnA67/PmbM9CVwL7Dw== + dependencies: + "@nodelib/fs.stat" "2.0.3" + run-parallel "^1.1.9" + +"@nodelib/fs.stat@2.0.3", "@nodelib/fs.stat@^2.0.2": + version "2.0.3" + resolved "https://registry.yarnpkg.com/@nodelib/fs.stat/-/fs.stat-2.0.3.tgz#34dc5f4cabbc720f4e60f75a747e7ecd6c175bd3" + integrity sha512-bQBFruR2TAwoevBEd/NWMoAAtNGzTRgdrqnYCc7dhzfoNvqPzLyqlEQnzZ3kVnNrSp25iyxE00/3h2fqGAGArA== + +"@nodelib/fs.walk@^1.2.3": + version "1.2.4" + resolved "https://registry.yarnpkg.com/@nodelib/fs.walk/-/fs.walk-1.2.4.tgz#011b9202a70a6366e436ca5c065844528ab04976" + integrity sha512-1V9XOY4rDW0rehzbrcqAmHnz8e7SKvX27gh8Gt2WgB0+pdzdiLV83p72kZPU+jvMbS1qU5mauP2iOvO8rhmurQ== + dependencies: + "@nodelib/fs.scandir" "2.1.3" + fastq "^1.6.0" + "@protobufjs/aspromise@^1.1.1", "@protobufjs/aspromise@^1.1.2": version "1.1.2" resolved "https://registry.yarnpkg.com/@protobufjs/aspromise/-/aspromise-1.1.2.tgz#9b8b0cc663d669a7d8f6f5d0893a14d348f30fbf" @@ -136,6 +165,18 @@ version "0.0.0" uid "" +"@ts-morph/common@~0.5.1": + version "0.5.1" + resolved "https://registry.yarnpkg.com/@ts-morph/common/-/common-0.5.1.tgz#c85037c9ed420755a68fce613348a94d8ef81a3d" + integrity sha512-0qasHorGK8VfUK20oECpIfmu/B6cwGSNTj2HoNsIKeDE1kB/uCk5jWFHkgBuoZu/3i3ysLOwO9QsFJaRAH65UA== + dependencies: + "@dsherret/to-absolute-glob" "^2.0.2" + fast-glob "^3.2.2" + fs-extra "^9.0.0" + is-negated-glob "^1.0.0" + multimatch "^4.0.0" + typescript "~3.9.2" + "@types/color-name@^1.1.1": version "1.1.1" resolved "https://registry.yarnpkg.com/@types/color-name/-/color-name-1.1.1.tgz#1c1261bbeaa10a8055bbc5d8ab84b7b2afc846a0" @@ -156,6 +197,11 @@ resolved "https://registry.yarnpkg.com/@types/long/-/long-4.0.0.tgz#719551d2352d301ac8b81db732acb6bdc28dbdef" integrity sha512-1w52Nyx4Gq47uuu0EVcsHBxZFJgurQ+rTKS3qMHxR1GY2T8c2AJYd6vZoZ9q1rupaDjU0yT+Jc2XTyXkjeMA+Q== +"@types/minimatch@^3.0.3": + version "3.0.3" + resolved "https://registry.yarnpkg.com/@types/minimatch/-/minimatch-3.0.3.tgz#3dca0e3f33b200fc7d1139c0cd96c1268cadfd9d" + integrity sha512-tHq6qdbT9U1IRSGf14CL0pUlULksvY9OZ+5eEgl1N7t+OA3tGvNpxJCzuKQlsNgCVwbAs670L1vcVQi8j9HjnA== + "@types/node-fetch@~2.1.2": version "2.1.7" resolved "https://registry.yarnpkg.com/@types/node-fetch/-/node-fetch-2.1.7.tgz#0231559340f6e3f3a0608692077d744c87b5b367" @@ -327,6 +373,11 @@ arr-union@^3.1.0: resolved "https://registry.yarnpkg.com/arr-union/-/arr-union-3.1.0.tgz#e39b09aea9def866a8f206e288af63919bae39c4" integrity sha1-45sJrqne+Gao8gbiiK9jkZuuOcQ= +array-differ@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/array-differ/-/array-differ-3.0.0.tgz#3cbb3d0f316810eafcc47624734237d6aee4ae6b" + integrity sha512-THtfYS6KtME/yIAhKjZ2ul7XI96lQGHRputJQHO80LAWQnuGP4iCIN8vdMRboGbIEYBwU33q8Tch1os2+X0kMg== + array-filter@~0.0.0: version "0.0.1" resolved "https://registry.yarnpkg.com/array-filter/-/array-filter-0.0.1.tgz#7da8cf2e26628ed732803581fd21f67cacd2eeec" @@ -347,11 +398,21 @@ array-reduce@~0.0.0: resolved "https://registry.yarnpkg.com/array-reduce/-/array-reduce-0.0.0.tgz#173899d3ffd1c7d9383e4479525dbe278cab5f2b" integrity sha1-FziZ0//Rx9k4PkR5Ul2+J4yrXys= +array-union@^2.1.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/array-union/-/array-union-2.1.0.tgz#b798420adbeb1de828d84acd8a2e23d3efe85e8d" + integrity sha512-HGyxoOTYUyCM6stUe6EJgnd4EoewAI7zMdfqO+kGjnlZmBDz/cR5pf8r/cR4Wq60sL/p0IkcjUEEPwS3GFrIyw== + arraybuffer.slice@~0.0.7: version "0.0.7" resolved "https://registry.yarnpkg.com/arraybuffer.slice/-/arraybuffer.slice-0.0.7.tgz#3bbc4275dd584cc1b10809b89d4e8b63a69e7675" integrity sha512-wGUIVQXuehL5TCqQun8OW81jGzAWycqzFF8lFp+GOM5BXLYj3bKNsYC4daB7n6XjCqxQA/qgTJ+8ANR3acjrog== +arrify@^2.0.1: + version "2.0.1" + resolved "https://registry.yarnpkg.com/arrify/-/arrify-2.0.1.tgz#c9655e9331e0abcd588d2a7cad7e9956f66701fa" + integrity sha512-3duEwti880xqi4eAMN8AyR4a0ByT90zoYdLlevfrvU43vb0YZwZVfxOgxWrLXXXpyugL0hNZc9G6BiB5B3nUug== + asn1.js@^4.0.0: version "4.10.1" resolved "https://registry.yarnpkg.com/asn1.js/-/asn1.js-4.10.1.tgz#b9c2bf5805f1e64aadeed6df3a2bfafb5a73f5a0" @@ -398,6 +459,11 @@ async@^3.0.1: resolved "https://registry.yarnpkg.com/async/-/async-3.1.0.tgz#42b3b12ae1b74927b5217d8c0016baaf62463772" integrity sha512-4vx/aaY6j/j3Lw3fbCHNWP0pPaTCew3F6F3hYyl/tHs/ndmV1q7NW9T5yuJ2XAGwdQrP+6Wu20x06U4APo/iQQ== +at-least-node@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/at-least-node/-/at-least-node-1.0.0.tgz#602cd4b46e844ad4effc92a8011a3c46e0238dc2" + integrity sha512-+q/t7Ekv1EDY2l6Gda6LLiX14rU9TV20Wa3ofeQmwPFZbOMo9DXrLbOjFaaclkXKWidIaopwAObQDqwWtGUjqg== + babel-code-frame@^6.22.0: version "6.26.0" resolved "https://registry.yarnpkg.com/babel-code-frame/-/babel-code-frame-6.26.0.tgz#63fd43f7dc1e3bb7ce35947db8fe369a3f58c74b" @@ -779,6 +845,11 @@ clone@^1.0.2: resolved "https://registry.yarnpkg.com/clone/-/clone-1.0.4.tgz#da309cc263df15994c688ca902179ca3c7cd7c7e" integrity sha1-2jCcwmPfFZlMaIypAheco8fNfH4= +code-block-writer@^10.1.0: + version "10.1.0" + resolved "https://registry.yarnpkg.com/code-block-writer/-/code-block-writer-10.1.0.tgz#54fc410ebef2af836d9c2314ac40af7d7b37eee9" + integrity sha512-RG9hpXtWFeUWhuUav1YuP/vGcyncW+t90yJLk9fNZs1De2OuHTHKAKThVCokt29PYq5RoJ0QSZaIZ+rvPO23hA== + code-point-at@^1.0.0: version "1.1.0" resolved "https://registry.yarnpkg.com/code-point-at/-/code-point-at-1.1.0.tgz#0d070b4d043a5bea33a2f1a40e2edb3d9a4ccf77" @@ -1352,11 +1423,30 @@ extend@^3.0.0: resolved "https://registry.yarnpkg.com/extend/-/extend-3.0.2.tgz#f8b1136b4071fbd8eb140aff858b1019ec2915fa" integrity sha512-fjquC59cD7CyW6urNXK0FBufkZcoiGG80wTuPujX590cB5Ttln20E2UB4S/WARVqhXffZl2LNgS+gQdPIIim/g== +fast-glob@^3.2.2: + version "3.2.4" + resolved "https://registry.yarnpkg.com/fast-glob/-/fast-glob-3.2.4.tgz#d20aefbf99579383e7f3cc66529158c9b98554d3" + integrity sha512-kr/Oo6PX51265qeuCYsyGypiO5uJFgBS0jksyG7FUeCyQzNwYnzrNIMR1NXfkZXsMYXYLRAHgISHBz8gQcxKHQ== + dependencies: + "@nodelib/fs.stat" "^2.0.2" + "@nodelib/fs.walk" "^1.2.3" + glob-parent "^5.1.0" + merge2 "^1.3.0" + micromatch "^4.0.2" + picomatch "^2.2.1" + fast-levenshtein@~2.0.6: version "2.0.6" resolved "https://registry.yarnpkg.com/fast-levenshtein/-/fast-levenshtein-2.0.6.tgz#3d8a5c66883a16a30ca8643e851f19baa7797917" integrity sha1-PYpcZog6FqMMqGQ+hR8Zuqd5eRc= +fastq@^1.6.0: + version "1.8.0" + resolved "https://registry.yarnpkg.com/fastq/-/fastq-1.8.0.tgz#550e1f9f59bbc65fe185cb6a9b4d95357107f481" + integrity sha512-SMIZoZdLh/fgofivvIkmknUXyPnvxRE3DhtZ5Me3Mrsk5gyPL42F0xr51TdRXskBxHfMp+07bcYzfsYEsSQA9Q== + dependencies: + reusify "^1.0.4" + fill-range@^7.0.1: version "7.0.1" resolved "https://registry.yarnpkg.com/fill-range/-/fill-range-7.0.1.tgz#1919a6a7c75fe38b2c7c77e5198535da9acdda40" @@ -1437,6 +1527,16 @@ fs-extra@^8.0.1: jsonfile "^4.0.0" universalify "^0.1.0" +fs-extra@^9.0.0: + version "9.0.1" + resolved "https://registry.yarnpkg.com/fs-extra/-/fs-extra-9.0.1.tgz#910da0062437ba4c39fedd863f1675ccfefcb9fc" + integrity sha512-h2iAoN838FqAFJY2/qVpzFXy+EBxfVE220PalAqQLDVsFOHLJrZvut5puAbCdNv6WJk+B8ihI+k0c7JK5erwqQ== + dependencies: + at-least-node "^1.0.0" + graceful-fs "^4.2.0" + jsonfile "^6.0.1" + universalify "^1.0.0" + fs.realpath@^1.0.0: version "1.0.0" resolved "https://registry.yarnpkg.com/fs.realpath/-/fs.realpath-1.0.0.tgz#1504ad2523158caa40db4a2787cb01411994ea4f" @@ -1484,7 +1584,7 @@ get-stream@^4.0.0: dependencies: pump "^3.0.0" -glob-parent@~5.1.0: +glob-parent@^5.1.0, glob-parent@~5.1.0: version "5.1.1" resolved "https://registry.yarnpkg.com/glob-parent/-/glob-parent-5.1.1.tgz#b6c1ef417c4e5663ea498f1c45afac6916bbc229" integrity sha512-FnI+VGOpnlGHWZxthPGR+QhR78fuiK0sNLkHQv+bL9fQi57lNNdquIbna/WrfROrolq8GK5Ek6BiMwqL/voRYQ== @@ -1727,6 +1827,14 @@ invert-kv@^2.0.0: resolved "https://registry.yarnpkg.com/invert-kv/-/invert-kv-2.0.0.tgz#7393f5afa59ec9ff5f67a27620d11c226e3eec02" integrity sha512-wPVv/y/QQ/Uiirj/vh3oP+1Ww+AWehmi1g5fFWGPF6IpCBCDVrhgHRMvrLfdYcwDh3QJbGXDW4JAuzxElLSqKA== +is-absolute@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/is-absolute/-/is-absolute-1.0.0.tgz#395e1ae84b11f26ad1795e73c17378e48a301576" + integrity sha512-dOWoqflvcydARa360Gvv18DZ/gRuHKi2NU/wU5X1ZFzdYfH29nkiNZsF3mp4OJ3H4yo9Mx8A/uAGNzpzPN3yBA== + dependencies: + is-relative "^1.0.0" + is-windows "^1.0.1" + is-arguments@^1.0.4: version "1.0.4" resolved "https://registry.yarnpkg.com/is-arguments/-/is-arguments-1.0.4.tgz#3faf966c7cba0ff437fb31f6250082fcf0448cf3" @@ -1814,6 +1922,11 @@ is-nan@^1.2.1: dependencies: define-properties "^1.1.1" +is-negated-glob@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/is-negated-glob/-/is-negated-glob-1.0.0.tgz#6910bca5da8c95e784b5751b976cf5a10fee36d2" + integrity sha1-aRC8pdqMleeEtXUbl2z1oQ/uNtI= + is-number@^7.0.0: version "7.0.0" resolved "https://registry.yarnpkg.com/is-number/-/is-number-7.0.0.tgz#7535345b896734d5f80c4d06c50955527a14f12b" @@ -1840,6 +1953,13 @@ is-regex@^1.0.4: dependencies: has "^1.0.1" +is-relative@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/is-relative/-/is-relative-1.0.0.tgz#a1bb6935ce8c5dba1e8b9754b9b2dcc020e2260d" + integrity sha512-Kw/ReK0iqwKeu0MITLFuj0jbPAmEiOsIwyIXvvbfa6QfmN9pkD1M+8pdk7Rl/dTKbH34/XBFMbgD4iMJhLQbGA== + dependencies: + is-unc-path "^1.0.0" + is-stream@^1.1.0: version "1.1.0" resolved "https://registry.yarnpkg.com/is-stream/-/is-stream-1.1.0.tgz#12d4a3dd4e68e0b79ceb8dbc84173ae80d91ca44" @@ -1852,11 +1972,23 @@ is-symbol@^1.0.2: dependencies: has-symbols "^1.0.1" +is-unc-path@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/is-unc-path/-/is-unc-path-1.0.0.tgz#d731e8898ed090a12c352ad2eaed5095ad322c9d" + integrity sha512-mrGpVd0fs7WWLfVsStvgF6iEJnbjDFZh9/emhRDcGWTduTfNHd9CHeUwH3gYIjdbwo4On6hunkztwOaAw0yllQ== + dependencies: + unc-path-regex "^0.1.2" + is-utf8@^0.2.0: version "0.2.1" resolved "https://registry.yarnpkg.com/is-utf8/-/is-utf8-0.2.1.tgz#4b0da1442104d1b336340e80797e865cf39f7d72" integrity sha1-Sw2hRCEE0bM2NA6AeX6GXPOffXI= +is-windows@^1.0.1: + version "1.0.2" + resolved "https://registry.yarnpkg.com/is-windows/-/is-windows-1.0.2.tgz#d1850eb9791ecd18e6182ce12a30f396634bb19d" + integrity sha512-eXK1UInq2bPmjyX6e3VHIzMLobc4J94i4AWn+Hpq3OU5KkrRC96OAcR3PRJ/pGu6m8TRnBHP9dkXQVsT/COVIA== + is-wsl@^1.1.0: version "1.1.0" resolved "https://registry.yarnpkg.com/is-wsl/-/is-wsl-1.1.0.tgz#1f16e4aa22b04d1336b66188a66af3c600c3a66d" @@ -1965,6 +2097,15 @@ jsonfile@^4.0.0: optionalDependencies: graceful-fs "^4.1.6" +jsonfile@^6.0.1: + version "6.0.1" + resolved "https://registry.yarnpkg.com/jsonfile/-/jsonfile-6.0.1.tgz#98966cba214378c8c84b82e085907b40bf614179" + integrity sha512-jR2b5v7d2vIOust+w3wtFKZIfpC2pnRmFAhAC/BuweZFQR8qZzxH1OyrQ10HmdVYiXWkYUqPVsz91cG7EL2FBg== + dependencies: + universalify "^1.0.0" + optionalDependencies: + graceful-fs "^4.1.6" + jsonify@~0.0.0: version "0.0.0" resolved "https://registry.yarnpkg.com/jsonify/-/jsonify-0.0.0.tgz#2c74b6ee41d93ca51b7b5aaee8f503631d252a73" @@ -2261,6 +2402,11 @@ merge-stream@^2.0.0: resolved "https://registry.yarnpkg.com/merge-stream/-/merge-stream-2.0.0.tgz#52823629a14dd00c9770fb6ad47dc6310f2c1f60" integrity sha512-abv/qOcuPfk3URPfDzmZU1LKmuw8kT+0nIHvKrKgFrwifol/doWcdA4ZqsWQ8ENrFKkd67Mfpo/LovbIUsbt3w== +merge2@^1.3.0: + version "1.4.1" + resolved "https://registry.yarnpkg.com/merge2/-/merge2-1.4.1.tgz#4368892f885e907455a6fd7dc55c0c9d404990ae" + integrity sha512-8q7VEgMJW4J8tcfVPy8g09NcQwZdbwFEqhe/WZkoIzjn/3TGDwtOCYtXGxA3O8tPzpczCCDgv+P2P5y00ZJOOg== + micromatch@^4.0.2: version "4.0.2" resolved "https://registry.yarnpkg.com/micromatch/-/micromatch-4.0.2.tgz#4fcb0999bf9fbc2fcbdd212f6d629b9a56c39259" @@ -2360,6 +2506,17 @@ ms@^2.1.1: resolved "https://registry.yarnpkg.com/ms/-/ms-2.1.2.tgz#d09d1f357b443f493382a8eb3ccd183872ae6009" integrity sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w== +multimatch@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/multimatch/-/multimatch-4.0.0.tgz#8c3c0f6e3e8449ada0af3dd29efb491a375191b3" + integrity sha512-lDmx79y1z6i7RNx0ZGCPq1bzJ6ZoDDKbvh7jxr9SJcWLkShMzXrHbYVpTdnhNM5MXpDUxCQ4DgqVttVXlBgiBQ== + dependencies: + "@types/minimatch" "^3.0.3" + array-differ "^3.0.0" + array-union "^2.1.0" + arrify "^2.0.1" + minimatch "^3.0.4" + nanoid@^2.1.6: version "2.1.11" resolved "https://registry.yarnpkg.com/nanoid/-/nanoid-2.1.11.tgz#ec24b8a758d591561531b4176a01e3ab4f0f0280" @@ -2729,7 +2886,7 @@ pbkdf2@^3.0.3: safe-buffer "^5.0.1" sha.js "^2.4.8" -picomatch@^2.0.4, picomatch@^2.0.5, picomatch@^2.0.7: +picomatch@^2.0.4, picomatch@^2.0.5, picomatch@^2.0.7, picomatch@^2.2.1: version "2.2.2" resolved "https://registry.yarnpkg.com/picomatch/-/picomatch-2.2.2.tgz#21f333e9b6b8eaff02468f5146ea406d345f4dad" integrity sha512-q0M/9eZHzmr0AulXyPwNfZjtwZ/RBZlbN3K3CErVrk50T2ASYI7Bye0EvekFY3IP1Nt2DHu0re+V2ZHIpMkuWg== @@ -3043,6 +3200,11 @@ resolve@^1.3.2: dependencies: path-parse "^1.0.6" +reusify@^1.0.4: + version "1.0.4" + resolved "https://registry.yarnpkg.com/reusify/-/reusify-1.0.4.tgz#90da382b1e126efc02146e90845a88db12925d76" + integrity sha512-U9nH88a3fc/ekCF1l0/UP1IosiuIjyTh7hBvXVMHYgVcfGvt897Xguj2UOLDeI5BG2m7/uwyaLVT6fbtCwTyzw== + rfdc@^1.1.4: version "1.1.4" resolved "https://registry.yarnpkg.com/rfdc/-/rfdc-1.1.4.tgz#ba72cc1367a0ccd9cf81a870b3b58bd3ad07f8c2" @@ -3107,6 +3269,11 @@ rollup@~2.3.2: optionalDependencies: fsevents "~2.1.2" +run-parallel@^1.1.9: + version "1.1.9" + resolved "https://registry.yarnpkg.com/run-parallel/-/run-parallel-1.1.9.tgz#c9dd3a7cf9f4b2c4b6244e173a6ed866e61dd679" + integrity sha512-DEqnSRTDw/Tc3FXf49zedI638Z9onwUotBMiUFKmrO2sdFKIbXamXGQ3Axd4qgphxKB4kw/qP1w5kTxnfU1B9Q== + safe-buffer@^5.0.1, safe-buffer@^5.1.0, safe-buffer@^5.1.1, safe-buffer@^5.1.2, safe-buffer@~5.2.0: version "5.2.0" resolved "https://registry.yarnpkg.com/safe-buffer/-/safe-buffer-5.2.0.tgz#b74daec49b1148f88c64b68d49b1e815c1f2f519" @@ -3570,6 +3737,15 @@ trim-newlines@^1.0.0: resolved "https://registry.yarnpkg.com/trim-newlines/-/trim-newlines-1.0.0.tgz#5887966bb582a4503a41eb524f7d35011815a613" integrity sha1-WIeWa7WCpFA6QetST301ARgVphM= +ts-morph@^7.1.2: + version "7.1.2" + resolved "https://registry.yarnpkg.com/ts-morph/-/ts-morph-7.1.2.tgz#7da8c3686b238f89988c3ee658bc5f3953ed8ec7" + integrity sha512-0ggF46muGv3v09Yf8Ce5ykTLiQ8I6hGvdB5ID/3+K4J11nCHo/vTaucqTvdFprJzQALpwQx+9bKi31mTxO0+tw== + dependencies: + "@dsherret/to-absolute-glob" "^2.0.2" + "@ts-morph/common" "~0.5.1" + code-block-writer "^10.1.0" + ts-node@~8.8.2: version "8.8.2" resolved "https://registry.yarnpkg.com/ts-node/-/ts-node-8.8.2.tgz#0b39e690bee39ea5111513a9d2bcdc0bc121755f" @@ -3648,6 +3824,11 @@ typescript@3.5.3: resolved "https://registry.yarnpkg.com/typescript/-/typescript-3.5.3.tgz#c830f657f93f1ea846819e929092f5fe5983e977" integrity sha512-ACzBtm/PhXBDId6a6sDJfroT2pOWt/oOnk4/dElG5G33ZL776N3Y6/6bKZJBFpd+b05F3Ct9qDjMeJmRWtE2/g== +typescript@~3.9.2: + version "3.9.6" + resolved "https://registry.yarnpkg.com/typescript/-/typescript-3.9.6.tgz#8f3e0198a34c3ae17091b35571d3afd31999365a" + integrity sha512-Pspx3oKAPJtjNwE92YS05HQoY7z2SFyOpHo9MqJor3BXAGNaPUs83CuVp9VISFkSjyRfiTpmKuAYGJB7S7hOxw== + uglify-js@^3.1.4: version "3.7.2" resolved "https://registry.yarnpkg.com/uglify-js/-/uglify-js-3.7.2.tgz#cb1a601e67536e9ed094a92dd1e333459643d3f9" @@ -3661,11 +3842,21 @@ ultron@~1.1.0: resolved "https://registry.yarnpkg.com/ultron/-/ultron-1.1.1.tgz#9fe1536a10a664a65266a1e3ccf85fd36302bc9c" integrity sha512-UIEXBNeYmKptWH6z8ZnqTeS8fV74zG0/eRU9VGkpzz+LIJNs8W/zM/L+7ctCkRrgbNnnR0xxw4bKOr0cW0N0Og== +unc-path-regex@^0.1.2: + version "0.1.2" + resolved "https://registry.yarnpkg.com/unc-path-regex/-/unc-path-regex-0.1.2.tgz#e73dd3d7b0d7c5ed86fbac6b0ae7d8c6a69d50fa" + integrity sha1-5z3T17DXxe2G+6xrCufYxqadUPo= + universalify@^0.1.0: version "0.1.2" resolved "https://registry.yarnpkg.com/universalify/-/universalify-0.1.2.tgz#b646f69be3942dabcecc9d6639c80dc105efaa66" integrity sha512-rBJeI5CXAlmy1pV+617WB9J63U6XcazHHF2f2dbJix4XzpUF0RS3Zbj0FGIOCAva5P/d/GBOYaACQ1w+0azUkg== +universalify@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/universalify/-/universalify-1.0.0.tgz#b61a1da173e8435b2fe3c67d29b9adf8594bd16d" + integrity sha512-rb6X1W158d7pRQBg5gkR8uPaSfiids68LTJQYOtEUhoJUWBdaQHsuT/EUduxXYxcrt4r5PJ4fuHW1MHT6p0qug== + unpipe@1.0.0, unpipe@~1.0.0: version "1.0.0" resolved "https://registry.yarnpkg.com/unpipe/-/unpipe-1.0.0.tgz#b2bf4ee8514aae6165b4817829d21b2ef49904ec" From e31720d4d348cc2f093fb8e6d2daf04fb12ca87a Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Tue, 7 Jul 2020 10:52:01 -0400 Subject: [PATCH 2/7] support multipole kernels in touch_modular_op_file --- tfjs-core/scripts/touch_modular_op_files.ts | 23 ++++++++++++++------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/tfjs-core/scripts/touch_modular_op_files.ts b/tfjs-core/scripts/touch_modular_op_files.ts index 99b3f6e3d78..41d23f2d348 100644 --- a/tfjs-core/scripts/touch_modular_op_files.ts +++ b/tfjs-core/scripts/touch_modular_op_files.ts @@ -17,16 +17,16 @@ /** * A helper script to generate empty files typically used in modularizing an op. - * Takes two params: + * params: * --op the name of the op - * --kernel the name of the kernel (this is optional) * --chained is this op part of the chained api (optional) + * --grad the name of the kernel(s) to create gradient files for * * It assumes you run it from tfjs_core. * * Example - * npx ts-node -s scripts/touch_modular_op_files.ts --op "op_name" --kernel \ - * "KernelName" --chained + * npx ts-node -s scripts/touch_modular_op_files.ts --op op_name --grad \ + * KernelName --chained * * Generates the following files (they will be empty) * tfjs_core/src/ops/op_name.ts @@ -36,6 +36,9 @@ * * if --kernel is present * tfjs_core/src/gradients/KernelName_grad.ts + * + * Example 2 (multiple kernels) + * npx ts-node -s scripts/touch_modular_op_files.ts --grad Kernel1,Kernel2 */ import * as argparse from 'argparse'; @@ -78,7 +81,10 @@ async function main() { return str.charAt(0).toLowerCase() + str.slice(1); }; - const gradientFileTemplate = `/** + const kernels: string[] = args.grad.split(','); + + kernels.forEach(kernelName => { + const gradientFileTemplate = `/** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); @@ -99,7 +105,7 @@ import {KernelName, KernelNameAttrs} from '../kernel_names'; import {GradConfig, NamedAttrMap} from '../kernel_registry'; import {Tensor} from '../tensor'; -export const ${downcaseFirstChar(args.grad)}GradConfig: GradConfig = { +export const ${downcaseFirstChar(kernelName)}GradConfig: GradConfig = { kernelName: KernelName, inputsToSave: [], // UPDATE ME outputsToSave: [], // UPDATE ME @@ -111,8 +117,9 @@ export const ${downcaseFirstChar(args.grad)}GradConfig: GradConfig = { } }; `; - const filePath = `./src/gradients/${args.grad}_grad.ts`; - fs.writeFileSync(filePath, gradientFileTemplate, {flag: 'a'}); + const filePath = `./src/gradients/${kernelName}_grad.ts`; + fs.writeFileSync(filePath, gradientFileTemplate, {flag: 'a'}); + }); } } From 7d2cd0a0775f9c09f3d083e2f5f3d1725b99935e Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Tue, 7 Jul 2020 11:38:02 -0400 Subject: [PATCH 3/7] support multiple ops in script --- tfjs-core/scripts/touch_modular_op_files.ts | 25 ++++++++++++--------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/tfjs-core/scripts/touch_modular_op_files.ts b/tfjs-core/scripts/touch_modular_op_files.ts index 41d23f2d348..ae2e8644710 100644 --- a/tfjs-core/scripts/touch_modular_op_files.ts +++ b/tfjs-core/scripts/touch_modular_op_files.ts @@ -60,20 +60,23 @@ async function main() { console.log('Called touch_modular_op_files with args:', args); if (args.op != null) { - let filePath = `./src/ops/${args.op}.ts`; - let command = `touch ${filePath}`; - execSync(command); - - // create a test file - filePath = `./src/ops/${args.op}_test.ts`; - command = `touch ${filePath}`; - execSync(command); + const ops: string[] = args.op.split(','); + ops.forEach(op => { + let filePath = `./src/ops/${op}.ts`; + let command = `touch ${filePath}`; + execSync(command); - if (args.chained) { - filePath = `./src/public/chained_ops/${args.op}.ts`; + // create a test file + filePath = `./src/ops/${op}_test.ts`; command = `touch ${filePath}`; execSync(command); - } + + if (args.chained) { + filePath = `./src/public/chained_ops/${op}.ts`; + command = `touch ${filePath}`; + execSync(command); + } + }); } if (args.grad) { From 0a6c6b9e0b61148252d8cc0076a8572a664acb70 Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Tue, 7 Jul 2020 11:39:20 -0400 Subject: [PATCH 4/7] modularize sign,ceil,floor,neg --- tfjs-core/src/gradients/Atan2_grad.ts | 2 +- tfjs-core/src/gradients/Ceil_grad.ts | 29 ++ tfjs-core/src/gradients/Div_grad.ts | 2 +- tfjs-core/src/gradients/Floor_grad.ts | 28 ++ tfjs-core/src/gradients/Mod_grad.ts | 3 +- tfjs-core/src/gradients/Negate_grad.ts | 28 ++ tfjs-core/src/gradients/Sign_grad.ts | 28 ++ tfjs-core/src/gradients/Sub_grad.ts | 2 +- tfjs-core/src/kernel_names.ts | 14 + tfjs-core/src/ops/ceil.ts | 46 ++++ tfjs-core/src/ops/ceil_test.ts | 93 +++++++ tfjs-core/src/ops/floor.ts | 45 ++++ tfjs-core/src/ops/floor_test.ts | 94 +++++++ tfjs-core/src/ops/log_loss.ts | 4 +- tfjs-core/src/ops/neg.ts | 47 ++++ tfjs-core/src/ops/neg_test.ts | 94 +++++++ tfjs-core/src/ops/ops.ts | 4 + tfjs-core/src/ops/qr.ts | 2 +- tfjs-core/src/ops/sigmoid_cross_entropy.ts | 3 +- tfjs-core/src/ops/sign.ts | 45 ++++ tfjs-core/src/ops/sign_test.ts | 93 +++++++ tfjs-core/src/ops/softmax_cross_entropy.ts | 3 +- tfjs-core/src/ops/unary_ops.ts | 92 ------- tfjs-core/src/ops/unary_ops_test.ts | 294 --------------------- tfjs-core/src/register_all_gradients.ts | 30 ++- tfjs-core/src/tests.ts | 4 + 26 files changed, 724 insertions(+), 405 deletions(-) create mode 100644 tfjs-core/src/gradients/Ceil_grad.ts create mode 100644 tfjs-core/src/gradients/Floor_grad.ts create mode 100644 tfjs-core/src/gradients/Negate_grad.ts create mode 100644 tfjs-core/src/gradients/Sign_grad.ts create mode 100644 tfjs-core/src/ops/ceil.ts create mode 100644 tfjs-core/src/ops/ceil_test.ts create mode 100644 tfjs-core/src/ops/floor.ts create mode 100644 tfjs-core/src/ops/floor_test.ts create mode 100644 tfjs-core/src/ops/neg.ts create mode 100644 tfjs-core/src/ops/neg_test.ts create mode 100644 tfjs-core/src/ops/sign.ts create mode 100644 tfjs-core/src/ops/sign_test.ts diff --git a/tfjs-core/src/gradients/Atan2_grad.ts b/tfjs-core/src/gradients/Atan2_grad.ts index fae7f8b6d85..b94a282be32 100644 --- a/tfjs-core/src/gradients/Atan2_grad.ts +++ b/tfjs-core/src/gradients/Atan2_grad.ts @@ -21,10 +21,10 @@ import {add} from '../ops/add'; import {assertAndGetBroadcastShape, getReductionAxes} from '../ops/broadcast_util'; import {div} from '../ops/div'; import {mul} from '../ops/mul'; +import {neg} from '../ops/neg'; import {reshape} from '../ops/reshape'; import {square} from '../ops/square'; import {sum} from '../ops/sum'; -import {neg} from '../ops/unary_ops'; import {Tensor} from '../tensor'; export const atan2GradConfig: GradConfig = { diff --git a/tfjs-core/src/gradients/Ceil_grad.ts b/tfjs-core/src/gradients/Ceil_grad.ts new file mode 100644 index 00000000000..e3e08144ab8 --- /dev/null +++ b/tfjs-core/src/gradients/Ceil_grad.ts @@ -0,0 +1,29 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Ceil} from '../kernel_names'; +import {GradConfig} from '../kernel_registry'; +import {zerosLike} from '../ops/tensor_ops'; +import {Tensor} from '../tensor'; + +export const ceilGradConfig: GradConfig = { + kernelName: Ceil, + gradFunc: (dy: Tensor) => { + // TODO(manrajgrover): Return null for gradients when backprop supports it. + return {x: () => zerosLike(dy)}; + } +}; diff --git a/tfjs-core/src/gradients/Div_grad.ts b/tfjs-core/src/gradients/Div_grad.ts index 6fe57626df8..6e0b2db35ad 100644 --- a/tfjs-core/src/gradients/Div_grad.ts +++ b/tfjs-core/src/gradients/Div_grad.ts @@ -20,10 +20,10 @@ import {GradConfig} from '../kernel_registry'; import * as broadcast_util from '../ops/broadcast_util'; import {div} from '../ops/div'; import {mul} from '../ops/mul'; +import {neg} from '../ops/neg'; import {reshape} from '../ops/reshape'; import {square} from '../ops/square'; import {sum} from '../ops/sum'; -import {neg} from '../ops/unary_ops'; import {Tensor} from '../tensor'; export const divGradConfig: GradConfig = { diff --git a/tfjs-core/src/gradients/Floor_grad.ts b/tfjs-core/src/gradients/Floor_grad.ts new file mode 100644 index 00000000000..30771c454d5 --- /dev/null +++ b/tfjs-core/src/gradients/Floor_grad.ts @@ -0,0 +1,28 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Floor} from '../kernel_names'; +import {GradConfig} from '../kernel_registry'; +import {zerosLike} from '../ops/tensor_ops'; +import {Tensor} from '../tensor'; + +export const floorGradConfig: GradConfig = { + kernelName: Floor, + gradFunc: (dy: Tensor) => { + return {x: () => zerosLike(dy)}; + } +}; diff --git a/tfjs-core/src/gradients/Mod_grad.ts b/tfjs-core/src/gradients/Mod_grad.ts index b6acd988df8..7b50e7752fc 100644 --- a/tfjs-core/src/gradients/Mod_grad.ts +++ b/tfjs-core/src/gradients/Mod_grad.ts @@ -19,10 +19,11 @@ import {Mod} from '../kernel_names'; import {GradConfig} from '../kernel_registry'; import {assertAndGetBroadcastShape, getReductionAxes} from '../ops/broadcast_util'; import {div} from '../ops/div'; +import {floor} from '../ops/floor'; import {mul} from '../ops/mul'; +import {neg} from '../ops/neg'; import {reshape} from '../ops/reshape'; import {sum} from '../ops/sum'; -import {floor, neg} from '../ops/unary_ops'; import {Tensor} from '../tensor'; export const modGradConfig: GradConfig = { diff --git a/tfjs-core/src/gradients/Negate_grad.ts b/tfjs-core/src/gradients/Negate_grad.ts new file mode 100644 index 00000000000..eb3456716c6 --- /dev/null +++ b/tfjs-core/src/gradients/Negate_grad.ts @@ -0,0 +1,28 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Negate} from '../kernel_names'; +import {GradConfig} from '../kernel_registry'; +import {neg} from '../ops/neg'; +import {Tensor} from '../tensor'; + +export const negateGradConfig: GradConfig = { + kernelName: Negate, + gradFunc: (dy: Tensor) => { + return {x: () => neg(dy)}; + } +}; diff --git a/tfjs-core/src/gradients/Sign_grad.ts b/tfjs-core/src/gradients/Sign_grad.ts new file mode 100644 index 00000000000..f6ac46229ba --- /dev/null +++ b/tfjs-core/src/gradients/Sign_grad.ts @@ -0,0 +1,28 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Sign} from '../kernel_names'; +import {GradConfig} from '../kernel_registry'; +import {zerosLike} from '../ops/tensor_ops'; +import {Tensor} from '../tensor'; + +export const signGradConfig: GradConfig = { + kernelName: Sign, + gradFunc: (dy: Tensor) => { + return {x: () => zerosLike(dy)}; + } +}; diff --git a/tfjs-core/src/gradients/Sub_grad.ts b/tfjs-core/src/gradients/Sub_grad.ts index 6685aa3aab4..c0ea0258eac 100644 --- a/tfjs-core/src/gradients/Sub_grad.ts +++ b/tfjs-core/src/gradients/Sub_grad.ts @@ -17,9 +17,9 @@ import {Sub} from '../kernel_names'; import {GradConfig} from '../kernel_registry'; import * as broadcast_util from '../ops/broadcast_util'; +import {neg} from '../ops/neg'; import {reshape} from '../ops/reshape'; import {sum} from '../ops/sum'; -import {neg} from '../ops/unary_ops'; import {Tensor} from '../tensor'; export const subGradConfig: GradConfig = { diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts index ac5d8e59300..8993803759c 100644 --- a/tfjs-core/src/kernel_names.ts +++ b/tfjs-core/src/kernel_names.ts @@ -119,6 +119,9 @@ export interface BroadCastToAttrs { inputShape: number[]; // for gradient } +export const Ceil = 'Ceil'; +export type CeilInputs = UnaryInputs; + export const Complex = 'Complex'; export type ComplexInputs = Pick; @@ -256,6 +259,9 @@ export type EluGradInputs = Pick; export const Equal = 'Equal'; export type EqualInputs = BinaryInputs; +export const Floor = 'Floor'; +export type FloorInputs = UnaryInputs; + export const FloorDiv = 'FloorDiv'; export type FloorDivInputs = BinaryInputs; @@ -404,6 +410,9 @@ export type ModInputs = BinaryInputs; export const Multiply = 'Multiply'; export type MultiplyInputs = BinaryInputs; +export const Negate = 'Negate'; +export type NegateInputs = UnaryInputs; + export const NotEqual = 'NotEqual'; export type NotEqualInputs = BinaryInputs; @@ -505,6 +514,9 @@ export type SelectV2Inputs = Pick; export const Selu = 'Selu'; export type SeluInputs = Pick; +export const Sign = 'Sign'; +export type SignInputs = UnaryInputs; + export const Sum = 'Sum'; export type SumInputs = Pick; export interface SumAttrs { @@ -547,6 +559,8 @@ export interface TransposeAttrs { perm: number[]; } +export type UnaryInputs = Pick; + export const Unpack = 'Unpack'; export type UnpackInputs = Pick; export interface UnpackAttrs { diff --git a/tfjs-core/src/ops/ceil.ts b/tfjs-core/src/ops/ceil.ts new file mode 100644 index 00000000000..b3d7d10cc79 --- /dev/null +++ b/tfjs-core/src/ops/ceil.ts @@ -0,0 +1,46 @@ +/** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ENGINE} from '../engine'; +import {Ceil, CeilInputs} from '../kernel_names'; +import {Tensor} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; + +import {op} from './operation'; + +/** + * Computes ceiling of input `tf.Tensor` element-wise: `ceil(x)` + * + * ```js + * const x = tf.tensor1d([.6, 1.1, -3.3]); + * + * x.ceil().print(); // or tf.ceil(x) + * ``` + * @param x The input Tensor. + */ +/** @doc {heading: 'Operations', subheading: 'Basic math'} */ +function ceil_(x: T|TensorLike): T { + const $x = convertToTensor(x, 'x', 'ceil'); + + const inputs: CeilInputs = {x: $x}; + return ENGINE.runKernelFunc( + backend => backend.ceil($x), inputs as {} as NamedTensorMap, + null /* grad */, Ceil); +} +export const ceil = op({ceil_}); diff --git a/tfjs-core/src/ops/ceil_test.ts b/tfjs-core/src/ops/ceil_test.ts new file mode 100644 index 00000000000..9bdb4cf3d74 --- /dev/null +++ b/tfjs-core/src/ops/ceil_test.ts @@ -0,0 +1,93 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('ceil', ALL_ENVS, () => { + it('basic', async () => { + const a = tf.tensor1d([1.5, 2.1, -1.4]); + const r = tf.ceil(a); + expectArraysClose(await r.data(), [2, 3, -1]); + }); + + it('propagates NaNs', async () => { + const a = tf.tensor1d([1.5, NaN, -1.4]); + const r = tf.ceil(a); + expectArraysClose(await r.data(), [2, NaN, -1]); + }); + + it('gradients: Scalar', async () => { + const a = tf.scalar(5.2); + const dy = tf.scalar(3); + + const gradients = tf.grad(a => tf.ceil(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [0]); + }); + + it('gradient with clones', async () => { + const a = tf.scalar(5.2); + const dy = tf.scalar(3); + + const gradients = tf.grad(a => tf.ceil(a.clone()).clone())(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [0]); + }); + + it('gradients: Tensor1D', async () => { + const a = tf.tensor1d([-1.1, 2.6, 3, -5.9]); + const dy = tf.tensor1d([1, 2, 3, 4]); + + const gradients = tf.grad(a => tf.ceil(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [0, 0, 0, 0]); + }); + + it('gradients: Tensor2D', async () => { + const a = tf.tensor2d([-3, 1, 2.2, 3], [2, 2]); + const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]); + + const gradients = tf.grad(a => tf.ceil(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [0, 0, 0, 0]); + }); + + it('throws when passed a non-tensor', () => { + expect(() => tf.ceil({} as tf.Tensor)) + .toThrowError(/Argument 'x' passed to 'ceil' must be a Tensor/); + }); + + it('accepts a tensor-like object', async () => { + const r = tf.ceil([1.5, 2.1, -1.4]); + expectArraysClose(await r.data(), [2, 3, -1]); + }); + + it('throws for string tensor', () => { + expect(() => tf.ceil('q')) + .toThrowError(/Argument 'x' passed to 'ceil' must be numeric/); + }); +}); diff --git a/tfjs-core/src/ops/floor.ts b/tfjs-core/src/ops/floor.ts new file mode 100644 index 00000000000..3f5fbcf25e3 --- /dev/null +++ b/tfjs-core/src/ops/floor.ts @@ -0,0 +1,45 @@ +/** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {ENGINE} from '../engine'; +import {Floor, FloorInputs} from '../kernel_names'; +import {Tensor} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; + +import {op} from './operation'; + +/** + * Computes floor of input `tf.Tensor` element-wise: `floor(x)`. + * + * ```js + * const x = tf.tensor1d([.6, 1.1, -3.3]); + * + * x.floor().print(); // or tf.floor(x) + * ``` + * @param x The input tensor. + */ +/** @doc {heading: 'Operations', subheading: 'Basic math'} */ +function floor_(x: T|TensorLike): T { + const $x = convertToTensor(x, 'x', 'floor'); + + const inputs: FloorInputs = {x: $x}; + return ENGINE.runKernelFunc( + backend => backend.floor($x), inputs as {} as NamedTensorMap, + null /* grad */, Floor); +} +export const floor = op({floor_}); diff --git a/tfjs-core/src/ops/floor_test.ts b/tfjs-core/src/ops/floor_test.ts new file mode 100644 index 00000000000..4059d087362 --- /dev/null +++ b/tfjs-core/src/ops/floor_test.ts @@ -0,0 +1,94 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('floor', ALL_ENVS, () => { + it('basic', async () => { + const a = tf.tensor1d([1.5, 2.1, -1.4]); + const r = tf.floor(a); + + expectArraysClose(await r.data(), [1, 2, -2]); + }); + + it('propagates NaNs', async () => { + const a = tf.tensor1d([1.5, NaN, -1.4]); + const r = tf.floor(a); + expectArraysClose(await r.data(), [1, NaN, -2]); + }); + + it('gradients: Scalar', async () => { + const a = tf.scalar(5.2); + const dy = tf.scalar(3); + + const gradients = tf.grad(a => tf.floor(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [0]); + }); + + it('gradient with clones', async () => { + const a = tf.scalar(5.2); + const dy = tf.scalar(3); + + const gradients = tf.grad(a => tf.floor(a.clone()).clone())(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [0]); + }); + + it('gradients: Tensor1D', async () => { + const a = tf.tensor1d([-1.1, 2.6, 3, -5.9]); + const dy = tf.tensor1d([1, 2, 3, 4]); + + const gradients = tf.grad(a => tf.floor(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [0, 0, 0, 0]); + }); + + it('gradients: Tensor2D', async () => { + const a = tf.tensor2d([-3, 1, 2.2, 3], [2, 2]); + const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]); + + const gradients = tf.grad(a => tf.floor(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [0, 0, 0, 0]); + }); + + it('throws when passed a non-tensor', () => { + expect(() => tf.floor({} as tf.Tensor)) + .toThrowError(/Argument 'x' passed to 'floor' must be a Tensor/); + }); + + it('accepts a tensor-like object', async () => { + const r = tf.floor([1.5, 2.1, -1.4]); + expectArraysClose(await r.data(), [1, 2, -2]); + }); + + it('throws for string tensor', () => { + expect(() => tf.floor('q')) + .toThrowError(/Argument 'x' passed to 'floor' must be numeric/); + }); +}); diff --git a/tfjs-core/src/ops/log_loss.ts b/tfjs-core/src/ops/log_loss.ts index 17fe5c75508..0543646c578 100644 --- a/tfjs-core/src/ops/log_loss.ts +++ b/tfjs-core/src/ops/log_loss.ts @@ -19,14 +19,16 @@ import {Tensor} from '../tensor'; import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import {assertShapesMatch} from '../util'; + import {add} from './add'; import {computeWeightedLoss} from './compute_weighted_loss'; import {Reduction} from './loss_ops_utils'; import {mul} from './mul'; +import {neg} from './neg'; import {op} from './operation'; import {sub} from './sub'; import {scalar} from './tensor_ops'; -import {log, neg} from './unary_ops'; +import {log} from './unary_ops'; /** * Computes the log loss between two tensors. diff --git a/tfjs-core/src/ops/neg.ts b/tfjs-core/src/ops/neg.ts new file mode 100644 index 00000000000..a7943b2b0b4 --- /dev/null +++ b/tfjs-core/src/ops/neg.ts @@ -0,0 +1,47 @@ +/** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ENGINE} from '../engine'; +import {Negate, NegateInputs} from '../kernel_names'; +import {Tensor} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; + +import {op} from './operation'; + +/** + * Computes `-1 * x` element-wise. + * + * ```js + * const x = tf.tensor2d([1, 2, -2, 0], [2, 2]); + * + * x.neg().print(); // or tf.neg(x) + * ``` + * + * @param x The input tensor. + */ +/** @doc {heading: 'Operations', subheading: 'Basic math'} */ +function neg_(x: T|TensorLike): T { + const $x = convertToTensor(x, 'x', 'neg'); + + const inputs: NegateInputs = {x: $x}; + return ENGINE.runKernelFunc( + backend => backend.neg($x), inputs as {} as NamedTensorMap, + null /* grad */, Negate); +} +export const neg = op({neg_}); diff --git a/tfjs-core/src/ops/neg_test.ts b/tfjs-core/src/ops/neg_test.ts new file mode 100644 index 00000000000..a44e81c882d --- /dev/null +++ b/tfjs-core/src/ops/neg_test.ts @@ -0,0 +1,94 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('neg', ALL_ENVS, () => { + it('basic', async () => { + const a = tf.tensor1d([1, -3, 2, 7, -4]); + const result = tf.neg(a); + expectArraysClose(await result.data(), [-1, 3, -2, -7, 4]); + }); + + it('propagate NaNs', async () => { + const a = tf.tensor1d([1, -3, 2, 7, NaN]); + const result = tf.neg(a); + const expected = [-1, 3, -2, -7, NaN]; + expectArraysClose(await result.data(), expected); + }); + + it('gradients: Scalar', async () => { + const a = tf.scalar(4); + const dy = tf.scalar(8); + + const da = tf.grad(a => tf.neg(a))(a, dy); + + expect(da.shape).toEqual(a.shape); + expect(da.dtype).toEqual('float32'); + expectArraysClose(await da.data(), [8 * -1]); + }); + + it('gradients: Scalar', async () => { + const a = tf.scalar(4); + const dy = tf.scalar(8); + + const da = tf.grad(a => tf.neg(a.clone()).clone())(a, dy); + + expect(da.shape).toEqual(a.shape); + expect(da.dtype).toEqual('float32'); + expectArraysClose(await da.data(), [8 * -1]); + }); + + it('gradients: Tensor1D', async () => { + const a = tf.tensor1d([1, 2, -3, 5]); + const dy = tf.tensor1d([1, 2, 3, 4]); + + const da = tf.grad(a => tf.neg(a))(a, dy); + + expect(da.shape).toEqual(a.shape); + expect(da.dtype).toEqual('float32'); + expectArraysClose(await da.data(), [1 * -1, 2 * -1, 3 * -1, 4 * -1]); + }); + + it('gradients: Tensor2D', async () => { + const a = tf.tensor2d([3, -1, -2, 3], [2, 2]); + const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]); + + const da = tf.grad(a => tf.neg(a))(a, dy); + + expect(da.shape).toEqual(a.shape); + expect(da.dtype).toEqual('float32'); + expectArraysClose(await da.data(), [1 * -1, 2 * -1, 3 * -1, 4 * -1]); + }); + + it('throws when passed a non-tensor', () => { + expect(() => tf.neg({} as tf.Tensor)) + .toThrowError(/Argument 'x' passed to 'neg' must be a Tensor/); + }); + + it('accepts a tensor-like object', async () => { + const result = tf.neg([1, -3, 2, 7, -4]); + expectArraysClose(await result.data(), [-1, 3, -2, -7, 4]); + }); + + it('throws for string tensor', () => { + expect(() => tf.neg('q')) + .toThrowError(/Argument 'x' passed to 'neg' must be numeric/); + }); +}); diff --git a/tfjs-core/src/ops/ops.ts b/tfjs-core/src/ops/ops.ts index 5cce7903ad0..a23cbd9f1c6 100644 --- a/tfjs-core/src/ops/ops.ts +++ b/tfjs-core/src/ops/ops.ts @@ -205,3 +205,7 @@ const losses = { // Second level exports. export {image, linalg, losses, spectral, fused, signal}; +export {sign} from './sign'; +export {neg} from './neg'; +export {ceil} from './ceil'; +export {floor} from './floor'; diff --git a/tfjs-core/src/ops/qr.ts b/tfjs-core/src/ops/qr.ts index b0e95b1cb1c..31dd3573936 100644 --- a/tfjs-core/src/ops/qr.ts +++ b/tfjs-core/src/ops/qr.ts @@ -26,6 +26,7 @@ import {eye} from './eye'; import {greater} from './greater'; import {matMul} from './mat_mul'; import {mul} from './mul'; +import {neg} from './neg'; import {norm} from './norm'; import {op} from './operation'; import {reshape} from './reshape'; @@ -34,7 +35,6 @@ import {stack} from './stack'; import {sub} from './sub'; import {tensor2d} from './tensor_ops'; import {transpose} from './transpose'; -import {neg} from './unary_ops'; import {unstack} from './unstack'; import {where} from './where'; diff --git a/tfjs-core/src/ops/sigmoid_cross_entropy.ts b/tfjs-core/src/ops/sigmoid_cross_entropy.ts index dc6f8d100d2..36123ab567c 100644 --- a/tfjs-core/src/ops/sigmoid_cross_entropy.ts +++ b/tfjs-core/src/ops/sigmoid_cross_entropy.ts @@ -24,11 +24,12 @@ import {add} from './add'; import {computeWeightedLoss} from './compute_weighted_loss'; import {Reduction} from './loss_ops_utils'; import {mul} from './mul'; +import {neg} from './neg'; import {op} from './operation'; import {relu} from './relu'; import {sub} from './sub'; import {scalar} from './tensor_ops'; -import {abs, exp, log1p, neg} from './unary_ops'; +import {abs, exp, log1p} from './unary_ops'; function sigmoidCrossEntropyWithLogits_( labels: T|TensorLike, logits: T|TensorLike): O { diff --git a/tfjs-core/src/ops/sign.ts b/tfjs-core/src/ops/sign.ts new file mode 100644 index 00000000000..de547d2c25d --- /dev/null +++ b/tfjs-core/src/ops/sign.ts @@ -0,0 +1,45 @@ +/** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ENGINE} from '../engine'; +import {Sign, SignInputs} from '../kernel_names'; +import {Tensor} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; + +import {op} from './operation'; + +/** + * Returns an element-wise indication of the sign of a number. + * + * ```js + * const x = tf.tensor1d([.6, 1.1, -3.3, NaN, 0]); + * + * x.sign().print(); // or tf.sign(x) + * ``` + * @param x The input Tensor. + */ +/** @doc {heading: 'Operations', subheading: 'Basic math'} */ +function sign_(x: T|TensorLike): T { + const $x = convertToTensor(x, 'x', 'sign'); + const inputs: SignInputs = {x: $x}; + return ENGINE.runKernelFunc( + backend => backend.sign($x), inputs as {} as NamedTensorMap, + null /* grad */, Sign); +} +export const sign = op({sign_}); diff --git a/tfjs-core/src/ops/sign_test.ts b/tfjs-core/src/ops/sign_test.ts new file mode 100644 index 00000000000..40410b3f924 --- /dev/null +++ b/tfjs-core/src/ops/sign_test.ts @@ -0,0 +1,93 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('sign', ALL_ENVS, () => { + it('basic', async () => { + const a = tf.tensor1d([1.5, 0, NaN, -1.4]); + const r = tf.sign(a); + expectArraysClose(await r.data(), [1, 0, 0, -1]); + }); + + it('does not propagate NaNs', async () => { + const a = tf.tensor1d([1.5, NaN, -1.4]); + const r = tf.sign(a); + expectArraysClose(await r.data(), [1, 0, -1]); + }); + + it('gradients: Scalar', async () => { + const a = tf.scalar(5.2); + const dy = tf.scalar(3); + + const gradients = tf.grad(a => tf.sign(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [0]); + }); + + it('gradient with clones', async () => { + const a = tf.scalar(5.2); + const dy = tf.scalar(3); + + const gradients = tf.grad(a => tf.sign(a.clone()).clone())(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [0]); + }); + + it('gradients: Tensor1D', async () => { + const a = tf.tensor1d([-1.1, 2.6, 3, -5.9]); + const dy = tf.tensor1d([-1, 1, 1, -1]); + + const gradients = tf.grad(a => tf.sign(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [0, 0, 0, 0]); + }); + + it('gradients: Tensor2D', async () => { + const a = tf.tensor2d([-3, 1, 2.2, 3], [2, 2]); + const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]); + + const gradients = tf.grad(a => tf.sign(a))(a, dy); + + expect(gradients.shape).toEqual(a.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(await gradients.data(), [0, 0, 0, 0]); + }); + + it('throws when passed a non-tensor', () => { + expect(() => tf.sign({} as tf.Tensor)) + .toThrowError(/Argument 'x' passed to 'sign' must be a Tensor/); + }); + + it('accepts a tensor-like object', async () => { + const r = tf.sign([1.5, 0, NaN, -1.4]); + expectArraysClose(await r.data(), [1, 0, 0, -1]); + }); + + it('throws for string tensor', () => { + expect(() => tf.sign('q')) + .toThrowError(/Argument 'x' passed to 'sign' must be numeric/); + }); +}); diff --git a/tfjs-core/src/ops/softmax_cross_entropy.ts b/tfjs-core/src/ops/softmax_cross_entropy.ts index 25158f7cc37..b86f01f1db9 100644 --- a/tfjs-core/src/ops/softmax_cross_entropy.ts +++ b/tfjs-core/src/ops/softmax_cross_entropy.ts @@ -29,12 +29,13 @@ import {div} from './div'; import {logSumExp} from './log_sum_exp'; import {Reduction} from './loss_ops_utils'; import {mul} from './mul'; +import {neg} from './neg'; import {op} from './operation'; import {reshape} from './reshape'; import {sub} from './sub'; import {sum} from './sum'; import {scalar} from './tensor_ops'; -import {exp, neg} from './unary_ops'; +import {exp} from './unary_ops'; /** * Computes softmax cross entropy between logits and labels. diff --git a/tfjs-core/src/ops/unary_ops.ts b/tfjs-core/src/ops/unary_ops.ts index f1a5059da7a..3c2441a3066 100644 --- a/tfjs-core/src/ops/unary_ops.ts +++ b/tfjs-core/src/ops/unary_ops.ts @@ -23,94 +23,6 @@ import * as util from '../util'; import {op} from './operation'; import {scalar, zerosLike} from './tensor_ops'; -/** - * Computes `-1 * x` element-wise. - * - * ```js - * const x = tf.tensor2d([1, 2, -2, 0], [2, 2]); - * - * x.neg().print(); // or tf.neg(x) - * ``` - * - * @param x The input tensor. - */ -/** @doc {heading: 'Operations', subheading: 'Basic math'} */ -function neg_(x: T|TensorLike): T { - const $x = convertToTensor(x, 'x', 'neg'); - - const grad = (dy: T) => { - return {x: () => dy.neg()}; - }; - - const attrs = {}; - const inputsToSave = [$x]; - return ENGINE.runKernelFunc( - backend => backend.neg($x), {x: $x}, grad, 'Neg', attrs, inputsToSave); -} - -/** - * Computes ceiling of input `tf.Tensor` element-wise: `ceil(x)` - * - * ```js - * const x = tf.tensor1d([.6, 1.1, -3.3]); - * - * x.ceil().print(); // or tf.ceil(x) - * ``` - * @param x The input Tensor. - */ -/** @doc {heading: 'Operations', subheading: 'Basic math'} */ -function ceil_(x: T|TensorLike): T { - const $x = convertToTensor(x, 'x', 'ceil'); - - // TODO(manrajgrover): Return null for gradients when backprop supports it. - const grad = (dy: T) => { - return {$x: () => zerosLike(dy)}; - }; - return ENGINE.runKernelFunc(backend => backend.ceil($x), {$x}, grad); -} - -/** - * Computes floor of input `tf.Tensor` element-wise: `floor(x)`. - * - * ```js - * const x = tf.tensor1d([.6, 1.1, -3.3]); - * - * x.floor().print(); // or tf.floor(x) - * ``` - * @param x The input tensor. - */ -/** @doc {heading: 'Operations', subheading: 'Basic math'} */ -function floor_(x: T|TensorLike): T { - const $x = convertToTensor(x, 'x', 'floor'); - - // TODO(nsthorat): Let gradients be null for cases where we want to stop - // backpropgation. - const grad = (dy: T) => { - return {$x: () => zerosLike(dy)}; - }; - return ENGINE.runKernelFunc(backend => backend.floor($x), {$x}, grad); -} - -/** - * Returns an element-wise indication of the sign of a number. - * - * ```js - * const x = tf.tensor1d([.6, 1.1, -3.3, NaN, 0]); - * - * x.sign().print(); // or tf.sign(x) - * ``` - * @param x The input Tensor. - */ -/** @doc {heading: 'Operations', subheading: 'Basic math'} */ -function sign_(x: T|TensorLike): T { - const $x = convertToTensor(x, 'x', 'sign'); - - const grad = (dy: T) => { - return {$x: () => zerosLike(dy)}; - }; - return ENGINE.runKernelFunc(backend => backend.sign($x), {$x}, grad); -} - /** * RReturns which elements of x are NaN. * @@ -932,23 +844,19 @@ export const asin = op({asin_}); export const asinh = op({asinh_}); export const atan = op({atan_}); export const atanh = op({atanh_}); -export const ceil = op({ceil_}); export const clipByValue = op({clipByValue_}); export const cos = op({cos_}); export const cosh = op({cosh_}); export const erf = op({erf_}); export const exp = op({exp_}); export const expm1 = op({expm1_}); -export const floor = op({floor_}); export const log = op({log_}); export const log1p = op({log1p_}); export const logSigmoid = op({logSigmoid_}); -export const neg = op({neg_}); export const reciprocal = op({reciprocal_}); export const round = op({round_}); export const rsqrt = op({rsqrt_}); export const sigmoid = op({sigmoid_}); -export const sign = op({sign_}); export const isNaN = op({isNaN_}); export const isInf = op({isInf_}); export const isFinite = op({isFinite_}); diff --git a/tfjs-core/src/ops/unary_ops_test.ts b/tfjs-core/src/ops/unary_ops_test.ts index a6d944a183b..e7538511175 100644 --- a/tfjs-core/src/ops/unary_ops_test.ts +++ b/tfjs-core/src/ops/unary_ops_test.ts @@ -253,80 +253,6 @@ describeWithFlags('step', ALL_ENVS, () => { }); }); -describeWithFlags('neg', ALL_ENVS, () => { - it('basic', async () => { - const a = tf.tensor1d([1, -3, 2, 7, -4]); - const result = tf.neg(a); - expectArraysClose(await result.data(), [-1, 3, -2, -7, 4]); - }); - - it('propagate NaNs', async () => { - const a = tf.tensor1d([1, -3, 2, 7, NaN]); - const result = tf.neg(a); - const expected = [-1, 3, -2, -7, NaN]; - expectArraysClose(await result.data(), expected); - }); - - it('gradients: Scalar', async () => { - const a = tf.scalar(4); - const dy = tf.scalar(8); - - const da = tf.grad(a => tf.neg(a))(a, dy); - - expect(da.shape).toEqual(a.shape); - expect(da.dtype).toEqual('float32'); - expectArraysClose(await da.data(), [8 * -1]); - }); - - it('gradients: Scalar', async () => { - const a = tf.scalar(4); - const dy = tf.scalar(8); - - const da = tf.grad(a => tf.neg(a.clone()).clone())(a, dy); - - expect(da.shape).toEqual(a.shape); - expect(da.dtype).toEqual('float32'); - expectArraysClose(await da.data(), [8 * -1]); - }); - - it('gradients: Tensor1D', async () => { - const a = tf.tensor1d([1, 2, -3, 5]); - const dy = tf.tensor1d([1, 2, 3, 4]); - - const da = tf.grad(a => tf.neg(a))(a, dy); - - expect(da.shape).toEqual(a.shape); - expect(da.dtype).toEqual('float32'); - expectArraysClose(await da.data(), [1 * -1, 2 * -1, 3 * -1, 4 * -1]); - }); - - it('gradients: Tensor2D', async () => { - const a = tf.tensor2d([3, -1, -2, 3], [2, 2]); - const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]); - - const da = tf.grad(a => tf.neg(a))(a, dy); - - expect(da.shape).toEqual(a.shape); - expect(da.dtype).toEqual('float32'); - expectArraysClose(await da.data(), [1 * -1, 2 * -1, 3 * -1, 4 * -1]); - }); - - it('throws when passed a non-tensor', () => { - expect(() => tf.neg({} as tf.Tensor)) - .toThrowError(/Argument 'x' passed to 'neg' must be a Tensor/); - }); - - it('accepts a tensor-like object', async () => { - const result = tf.neg([1, -3, 2, 7, -4]); - expectArraysClose(await result.data(), [-1, 3, -2, -7, 4]); - }); - - it('throws for string tensor', () => { - expect(() => tf.neg('q')) - .toThrowError(/Argument 'x' passed to 'neg' must be numeric/); - }); -}); - describeWithFlags('sigmoid', ALL_ENVS, () => { it('basic', async () => { const values = [1, -3, 2, 7, -4]; @@ -1230,226 +1156,6 @@ describeWithFlags('log1p', ALL_ENVS, () => { }); }); -describeWithFlags('ceil', ALL_ENVS, () => { - it('basic', async () => { - const a = tf.tensor1d([1.5, 2.1, -1.4]); - const r = tf.ceil(a); - expectArraysClose(await r.data(), [2, 3, -1]); - }); - - it('propagates NaNs', async () => { - const a = tf.tensor1d([1.5, NaN, -1.4]); - const r = tf.ceil(a); - expectArraysClose(await r.data(), [2, NaN, -1]); - }); - - it('gradients: Scalar', async () => { - const a = tf.scalar(5.2); - const dy = tf.scalar(3); - - const gradients = tf.grad(a => tf.ceil(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [0]); - }); - - it('gradient with clones', async () => { - const a = tf.scalar(5.2); - const dy = tf.scalar(3); - - const gradients = tf.grad(a => tf.ceil(a.clone()).clone())(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [0]); - }); - - it('gradients: Tensor1D', async () => { - const a = tf.tensor1d([-1.1, 2.6, 3, -5.9]); - const dy = tf.tensor1d([1, 2, 3, 4]); - - const gradients = tf.grad(a => tf.ceil(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [0, 0, 0, 0]); - }); - - it('gradients: Tensor2D', async () => { - const a = tf.tensor2d([-3, 1, 2.2, 3], [2, 2]); - const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]); - - const gradients = tf.grad(a => tf.ceil(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [0, 0, 0, 0]); - }); - - it('throws when passed a non-tensor', () => { - expect(() => tf.ceil({} as tf.Tensor)) - .toThrowError(/Argument 'x' passed to 'ceil' must be a Tensor/); - }); - - it('accepts a tensor-like object', async () => { - const r = tf.ceil([1.5, 2.1, -1.4]); - expectArraysClose(await r.data(), [2, 3, -1]); - }); - - it('throws for string tensor', () => { - expect(() => tf.ceil('q')) - .toThrowError(/Argument 'x' passed to 'ceil' must be numeric/); - }); -}); - -describeWithFlags('floor', ALL_ENVS, () => { - it('basic', async () => { - const a = tf.tensor1d([1.5, 2.1, -1.4]); - const r = tf.floor(a); - - expectArraysClose(await r.data(), [1, 2, -2]); - }); - - it('propagates NaNs', async () => { - const a = tf.tensor1d([1.5, NaN, -1.4]); - const r = tf.floor(a); - expectArraysClose(await r.data(), [1, NaN, -2]); - }); - - it('gradients: Scalar', async () => { - const a = tf.scalar(5.2); - const dy = tf.scalar(3); - - const gradients = tf.grad(a => tf.floor(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [0]); - }); - - it('gradient with clones', async () => { - const a = tf.scalar(5.2); - const dy = tf.scalar(3); - - const gradients = tf.grad(a => tf.floor(a.clone()).clone())(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [0]); - }); - - it('gradients: Tensor1D', async () => { - const a = tf.tensor1d([-1.1, 2.6, 3, -5.9]); - const dy = tf.tensor1d([1, 2, 3, 4]); - - const gradients = tf.grad(a => tf.floor(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [0, 0, 0, 0]); - }); - - it('gradients: Tensor2D', async () => { - const a = tf.tensor2d([-3, 1, 2.2, 3], [2, 2]); - const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]); - - const gradients = tf.grad(a => tf.floor(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [0, 0, 0, 0]); - }); - - it('throws when passed a non-tensor', () => { - expect(() => tf.floor({} as tf.Tensor)) - .toThrowError(/Argument 'x' passed to 'floor' must be a Tensor/); - }); - - it('accepts a tensor-like object', async () => { - const r = tf.floor([1.5, 2.1, -1.4]); - expectArraysClose(await r.data(), [1, 2, -2]); - }); - - it('throws for string tensor', () => { - expect(() => tf.floor('q')) - .toThrowError(/Argument 'x' passed to 'floor' must be numeric/); - }); -}); - -describeWithFlags('sign', ALL_ENVS, () => { - it('basic', async () => { - const a = tf.tensor1d([1.5, 0, NaN, -1.4]); - const r = tf.sign(a); - expectArraysClose(await r.data(), [1, 0, 0, -1]); - }); - - it('does not propagate NaNs', async () => { - const a = tf.tensor1d([1.5, NaN, -1.4]); - const r = tf.sign(a); - expectArraysClose(await r.data(), [1, 0, -1]); - }); - - it('gradients: Scalar', async () => { - const a = tf.scalar(5.2); - const dy = tf.scalar(3); - - const gradients = tf.grad(a => tf.sign(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [0]); - }); - - it('gradient with clones', async () => { - const a = tf.scalar(5.2); - const dy = tf.scalar(3); - - const gradients = tf.grad(a => tf.sign(a.clone()).clone())(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [0]); - }); - - it('gradients: Tensor1D', async () => { - const a = tf.tensor1d([-1.1, 2.6, 3, -5.9]); - const dy = tf.tensor1d([-1, 1, 1, -1]); - - const gradients = tf.grad(a => tf.sign(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [0, 0, 0, 0]); - }); - - it('gradients: Tensor2D', async () => { - const a = tf.tensor2d([-3, 1, 2.2, 3], [2, 2]); - const dy = tf.tensor2d([1, 2, 3, 4], [2, 2]); - - const gradients = tf.grad(a => tf.sign(a))(a, dy); - - expect(gradients.shape).toEqual(a.shape); - expect(gradients.dtype).toEqual('float32'); - expectArraysClose(await gradients.data(), [0, 0, 0, 0]); - }); - - it('throws when passed a non-tensor', () => { - expect(() => tf.sign({} as tf.Tensor)) - .toThrowError(/Argument 'x' passed to 'sign' must be a Tensor/); - }); - - it('accepts a tensor-like object', async () => { - const r = tf.sign([1.5, 0, NaN, -1.4]); - expectArraysClose(await r.data(), [1, 0, 0, -1]); - }); - - it('throws for string tensor', () => { - expect(() => tf.sign('q')) - .toThrowError(/Argument 'x' passed to 'sign' must be numeric/); - }); -}); - describeWithFlags('isNaN', ALL_ENVS, () => { it('basic', async () => { const a = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]); diff --git a/tfjs-core/src/register_all_gradients.ts b/tfjs-core/src/register_all_gradients.ts index f3968bdb284..dd274e5a2f2 100644 --- a/tfjs-core/src/register_all_gradients.ts +++ b/tfjs-core/src/register_all_gradients.ts @@ -24,6 +24,7 @@ import {avgPoolGradConfig} from './gradients/AvgPool_grad'; import {batchMatMulGradConfig} from './gradients/BatchMatMul_grad'; import {batchToSpaceNDGradConfig} from './gradients/BatchToSpaceND_grad'; import {broadcastToGradConfig} from './gradients/BroadcastTo_grad'; +import {ceilGradConfig} from './gradients/Ceil_grad'; import {concatGradConfig} from './gradients/Concat_grad'; import {conv2DGradConfig} from './gradients/Conv2D_grad'; import {conv2DBackpropInputGradConfig} from './gradients/Conv2DBackpropInput_grad'; @@ -33,6 +34,7 @@ import {depthwiseConv2dNativeGradConfig} from './gradients/DepthwiseConv2dNative import {dilation2dGradConfig} from './gradients/Dilation2D_grad'; import {divGradConfig} from './gradients/Div_grad'; import {eluGradConfig} from './gradients/Elu_grad'; +import {floorGradConfig} from './gradients/Floor_grad'; import {floorDivGradConfig} from './gradients/FloorDiv_grad'; import {fusedBatchNormGradConfig} from './gradients/FusedBatchNorm_grad'; import {greaterEqualGradConfig} from './gradients/GreaterEqual_grad'; @@ -46,6 +48,7 @@ import {minGradConfig} from './gradients/Min_grad'; import {minimumGradConfig} from './gradients/Minimum_grad'; import {modGradConfig} from './gradients/Mod_grad'; import {multiplyGradConfig} from './gradients/Multiply_grad'; +import {negateGradConfig} from './gradients/Negate_grad'; import {oneHotGradConfig} from './gradients/OneHot_grad'; import {padV2GradConfig} from './gradients/PadV2_grad'; import {powGradConfig} from './gradients/Pow_grad'; @@ -58,6 +61,7 @@ import {resizeNearestNeighborGradConfig} from './gradients/ResizeNearestNeighbor import {reverseGradConfig} from './gradients/Reverse_grad'; import {selectV2PoolGradConfig} from './gradients/SelectV2_grad'; import {seluGradConfig} from './gradients/Selu_grad'; +import {signGradConfig} from './gradients/Sign_grad'; import {spaceToBatchNDGradConfig} from './gradients/SpaceToBatchND_grad'; import {splitVGradConfig} from './gradients/SplitV_grad'; import {squareGradConfig} from './gradients/Square_grad'; @@ -77,14 +81,15 @@ const gradConfigs: GradConfig[] = [ argMaxGradConfig, argMinGradConfig, atan2GradConfig, - avgPoolGradConfig, avgPool3DGradConfig, + avgPoolGradConfig, batchMatMulGradConfig, batchToSpaceNDGradConfig, broadcastToGradConfig, + ceilGradConfig, concatGradConfig, - conv2DGradConfig, conv2DBackpropInputGradConfig, + conv2DGradConfig, conv3DGradConfig, cumsumGradConfig, depthwiseConv2dNativeGradConfig, @@ -92,39 +97,42 @@ const gradConfigs: GradConfig[] = [ divGradConfig, eluGradConfig, floorDivGradConfig, + floorGradConfig, fusedBatchNormGradConfig, greaterEqualGradConfig, identityGradConfig, lrnGradConfig, - oneHotGradConfig, - padV2GradConfig, - splitVGradConfig, maxGradConfig, - spaceToBatchNDGradConfig, maxGradConfig, - minGradConfig, maximumGradConfig, - maxPoolGradConfig, maxPool3DGradConfig, + maxPoolGradConfig, + minGradConfig, minimumGradConfig, modGradConfig, multiplyGradConfig, + negateGradConfig, + oneHotGradConfig, oneHotGradConfig, padV2GradConfig, + padV2GradConfig, powGradConfig, preluGradConfig, + relu6GradConfig, reluGradConfig, reshapeGradConfig, resizeBilinearGradConfig, resizeNearestNeighborGradConfig, - relu6GradConfig, reverseGradConfig, - seluGradConfig, selectV2PoolGradConfig, + seluGradConfig, + signGradConfig, + spaceToBatchNDGradConfig, spaceToBatchNDGradConfig, splitVGradConfig, - squareGradConfig, + splitVGradConfig, squaredDifferenceGradConfig, + squareGradConfig, subGradConfig, sumGradConfig, tileGradConfig, diff --git a/tfjs-core/src/tests.ts b/tfjs-core/src/tests.ts index f330796dac6..2b31dbbf67d 100644 --- a/tfjs-core/src/tests.ts +++ b/tfjs-core/src/tests.ts @@ -58,6 +58,7 @@ import './ops/binary_ops_test'; import './ops/boolean_mask_test'; import './ops/broadcast_to_test'; import './ops/broadcast_util_test'; +import './ops/ceil_test'; import './ops/clone_test'; import './ops/compare_ops_test'; import './ops/complex_ops_test'; @@ -85,6 +86,7 @@ import './ops/elu_test'; import './ops/equal_test'; import './ops/expand_dims_test'; import './ops/eye_test'; +import './ops/floor_test'; import './ops/fused_test'; import './ops/gather_nd_test'; import './ops/gram_schmidt_test'; @@ -115,6 +117,7 @@ import './ops/moments_test'; import './ops/moving_average_test'; import './ops/multi_rnn_cell_test'; import './ops/multinomial_test'; +import './ops/neg_test'; import './ops/non_max_suppression_async_test'; import './ops/non_max_suppression_test'; import './ops/norm_test'; @@ -142,6 +145,7 @@ import './ops/scatter_nd_test'; import './ops/segment_ops_test'; import './ops/selu_test'; import './ops/sigmoid_cross_entropy_test'; +import './ops/sign_test'; import './ops/signal_ops_test'; import './ops/slice_test'; import './ops/slice_util_test'; From 7eee6c42e9de5e80ccdb1dd4c587e84e7ee45754 Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Tue, 7 Jul 2020 11:42:16 -0400 Subject: [PATCH 5/7] save --- tfjs-core/src/ops/ops.ts | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tfjs-core/src/ops/ops.ts b/tfjs-core/src/ops/ops.ts index a23cbd9f1c6..e3d06732dd5 100644 --- a/tfjs-core/src/ops/ops.ts +++ b/tfjs-core/src/ops/ops.ts @@ -32,6 +32,7 @@ export {batchNorm2d} from './batchnorm2d'; export {batchNorm3d} from './batchnorm3d'; export {batchNorm4d} from './batchnorm4d'; export {broadcastTo} from './broadcast_to'; +export {ceil} from './ceil'; export {clone} from './clone'; export {complex} from './complex'; export {concat} from './concat'; @@ -57,6 +58,7 @@ export {equal} from './equal'; export {expandDims} from './expand_dims'; export {eye} from './eye'; export {fill} from './fill'; +export {floor} from './floor'; export {floorDiv} from './floorDiv'; export {greater} from './greater'; export {greaterEqual} from './greater_equal'; @@ -84,6 +86,7 @@ export {moments} from './moments'; export {mul} from './mul'; export {LSTMCellFunc, multiRNNCell} from './multi_rnn_cell'; export {multinomial} from './multinomial'; +export {neg} from './neg'; export {notEqual} from './not_equal'; export {oneHot} from './one_hot'; export {outerProduct} from './outer_product'; @@ -111,6 +114,7 @@ export {reverse3d} from './reverse_3d'; export {reverse4d} from './reverse_4d'; export {selu} from './selu'; export {separableConv2d} from './separable_conv2d'; +export {sign} from './sign'; export {spaceToBatchND} from './space_to_batch_nd'; export {split} from './split'; export {square} from './square'; @@ -205,7 +209,3 @@ const losses = { // Second level exports. export {image, linalg, losses, spectral, fused, signal}; -export {sign} from './sign'; -export {neg} from './neg'; -export {ceil} from './ceil'; -export {floor} from './floor'; From 9305b1ba006798351b26a367aa1259ce96dc8c0c Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Tue, 7 Jul 2020 11:43:04 -0400 Subject: [PATCH 6/7] save --- tfjs-core/scripts/extract_op.ts | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tfjs-core/scripts/extract_op.ts b/tfjs-core/scripts/extract_op.ts index 2fdf1b8edbe..1b2fba037f9 100644 --- a/tfjs-core/scripts/extract_op.ts +++ b/tfjs-core/scripts/extract_op.ts @@ -116,10 +116,6 @@ async function run(filePath: string, ops: string[]) { } }); - // Save the ops export file - // const opsExportFile = project.getSourceFile('src/ops/ops.ts'); - // await opsExportFile.save(); - // Remove the op from the source file and save it opExports.forEach(async o => { if (ops.length === 0 || ops.indexOf(o.opIdentifier) !== -1) { From cbb1b83af1c4de511a4d7a6f558e3418366dbb96 Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Tue, 7 Jul 2020 12:51:15 -0400 Subject: [PATCH 7/7] update wasm negate --- tfjs-backend-wasm/src/cc/BUILD | 6 +++--- tfjs-backend-wasm/src/cc/kernels/{Neg.cc => Negate.cc} | 4 ++-- tfjs-backend-wasm/src/kernels/{Neg.ts => Negate.ts} | 2 +- tfjs-backend-wasm/src/kernels/all_kernels.ts | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) rename tfjs-backend-wasm/src/cc/kernels/{Neg.cc => Negate.cc} (88%) rename tfjs-backend-wasm/src/kernels/{Neg.ts => Negate.ts} (96%) diff --git a/tfjs-backend-wasm/src/cc/BUILD b/tfjs-backend-wasm/src/cc/BUILD index 0ef8f64fc2c..ec2b78b2dd4 100644 --- a/tfjs-backend-wasm/src/cc/BUILD +++ b/tfjs-backend-wasm/src/cc/BUILD @@ -213,7 +213,7 @@ tfjs_cc_library( ":Min", ":Minimum", ":Multiply", - ":Neg", + ":Negate", ":NonMaxSuppressionV3", ":NonMaxSuppressionV5", ":NotEqual", @@ -634,8 +634,8 @@ tfjs_cc_library( ) tfjs_cc_library( - name = "Neg", - srcs = ["kernels/Neg.cc"], + name = "Negate", + srcs = ["kernels/Negate.cc"], deps = [ ":backend", ":unary", diff --git a/tfjs-backend-wasm/src/cc/kernels/Neg.cc b/tfjs-backend-wasm/src/cc/kernels/Negate.cc similarity index 88% rename from tfjs-backend-wasm/src/cc/kernels/Neg.cc rename to tfjs-backend-wasm/src/cc/kernels/Negate.cc index 7570e1cc810..130c301cdd6 100644 --- a/tfjs-backend-wasm/src/cc/kernels/Neg.cc +++ b/tfjs-backend-wasm/src/cc/kernels/Negate.cc @@ -22,7 +22,7 @@ #include "src/cc/unary.h" namespace { -inline float neg(const float val) { return -val; } +inline float negate(const float val) { return -val; } } // namespace namespace tfjs { @@ -33,7 +33,7 @@ extern "C" { #ifdef __EMSCRIPTEN__ EMSCRIPTEN_KEEPALIVE #endif -void Neg(const int x_id, const int out_id) { unary(x_id, out_id, neg); } +void Negate(const int x_id, const int out_id) { unary(x_id, out_id, negate); } } // extern "C" } // namespace wasm diff --git a/tfjs-backend-wasm/src/kernels/Neg.ts b/tfjs-backend-wasm/src/kernels/Negate.ts similarity index 96% rename from tfjs-backend-wasm/src/kernels/Neg.ts rename to tfjs-backend-wasm/src/kernels/Negate.ts index b226e799331..4e1cce04b27 100644 --- a/tfjs-backend-wasm/src/kernels/Neg.ts +++ b/tfjs-backend-wasm/src/kernels/Negate.ts @@ -16,4 +16,4 @@ */ import {registerUnaryKernel} from './unary_kernel'; -registerUnaryKernel('Neg'); +registerUnaryKernel('Negate'); diff --git a/tfjs-backend-wasm/src/kernels/all_kernels.ts b/tfjs-backend-wasm/src/kernels/all_kernels.ts index ee032cfc119..b411e73634e 100644 --- a/tfjs-backend-wasm/src/kernels/all_kernels.ts +++ b/tfjs-backend-wasm/src/kernels/all_kernels.ts @@ -55,7 +55,7 @@ import './MaxPool'; import './Min'; import './Minimum'; import './Muliply'; -import './Neg'; +import './Negate'; import './NonMaxSuppressionV3'; import './NonMaxSuppressionV5'; import './NotEqual';