From 3351dab0443584a004c15813a16dfba1ed730159 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Thu, 28 May 2020 09:11:05 -0400 Subject: [PATCH 01/13] save --- scripts/publish-npm.ts | 4 ++-- tfjs-backend-wasm/README.md | 6 ++--- tfjs-backend-wasm/scripts/build-ci.sh | 4 ++-- tfjs-core/benchmarks/index.html | 30 ++++++++++++++++++++----- tfjs-core/benchmarks/modelConfig.js | 4 ++-- tfjs-core/benchmarks/util.js | 32 +++++++++++++++++++++++++++ 6 files changed, 66 insertions(+), 14 deletions(-) diff --git a/scripts/publish-npm.ts b/scripts/publish-npm.ts index 70d88ff0bcd..52d00f4b699 100755 --- a/scripts/publish-npm.ts +++ b/scripts/publish-npm.ts @@ -123,8 +123,8 @@ async function main() { shell.cd('..'); $('git clone https://github.com/emscripten-core/emsdk.git'); shell.cd('./emsdk'); - $('./emsdk install 1.39.1'); - $('./emsdk activate 1.39.1'); + $('./emsdk install 1.39.16'); + $('./emsdk activate 1.39.16'); shell.cd('..'); shell.cd(pkg); } diff --git a/tfjs-backend-wasm/README.md b/tfjs-backend-wasm/README.md index 513e9efd92f..fb2e86e132f 100644 --- a/tfjs-backend-wasm/README.md +++ b/tfjs-backend-wasm/README.md @@ -182,13 +182,13 @@ We'd love your feedback as we develop this backend! Please file an issue ## Emscripten installation -Install the Emscripten SDK (version 1.39.1): +Install the Emscripten SDK (version 1.39.16): ```sh git clone https://github.com/emscripten-core/emsdk.git cd emsdk -./emsdk install 1.39.1 -./emsdk activate 1.39.1 +./emsdk install 1.39.16 +./emsdk activate 1.39.16 ``` ## Prepare the environment diff --git a/tfjs-backend-wasm/scripts/build-ci.sh b/tfjs-backend-wasm/scripts/build-ci.sh index f1b242e4908..3f6161e2019 100755 --- a/tfjs-backend-wasm/scripts/build-ci.sh +++ b/tfjs-backend-wasm/scripts/build-ci.sh @@ -29,10 +29,10 @@ do [ $i -gt 0 ] && echo "Retry in 15 seconds, count: $i" && sleep 15 # If install is successful, $? will hold 0 and execution will break from the # loop. 
- ./emsdk install 1.39.1 && break + ./emsdk install 1.39.16 && break done -./emsdk activate 1.39.1 +./emsdk activate 1.39.16 source ./emsdk_env.sh cd .. diff --git a/tfjs-core/benchmarks/index.html b/tfjs-core/benchmarks/index.html index 73afdccff83..d4d6fdf8673 100644 --- a/tfjs-core/benchmarks/index.html +++ b/tfjs-core/benchmarks/index.html @@ -19,7 +19,7 @@ TensorFlow.js Model Benchmark - + @@ -58,6 +58,8 @@

TensorFlow.js Model Benchmark

+ + @@ -77,6 +79,19 @@

TensorFlow.js Model Benchmark

run: (v) => { runBenchmark(); }, + testCorrectness: async () => { + await loadModel(); + tf.setBackend('cpu'); + const referencePrediction = predict(model); + const referenceData = await getPredictionData(referencePrediction); + + await tf.setBackend(state.backend); + const prediction = predict(model); + const predictionData = await getPredictionData(prediction); + + const match = arraysClose(referenceData, predictionData); + appendRow(timeTable, `Prediction matches CPU: ${match}`); + }, backend: 'wasm', kernelTiming: 'aggregate', }; @@ -163,7 +178,8 @@

TensorFlow.js Model Benchmark

appendRow(timeTable, '1st inference', printTime(elapsed)); } - async function loadAndRecordTime(benchmark) { + async function loadModel() { + const benchmark = benchmarks[state.benchmark]; await showMsg('Loading the model'); const start = performance.now(); if (benchmark.model == null) { @@ -173,7 +189,10 @@

TensorFlow.js Model Benchmark

model = benchmark.model; } predict = benchmark.predictFunc(); + await showMsg(null); + } + async function loadAndRecordTime() { const elapsed = performance.now() - start; await showMsg(null); @@ -345,9 +364,8 @@

TensorFlow.js Model Benchmark

} async function runBenchmark() { - const benchmark = benchmarks[state.benchmark]; - await setupTable(); - await loadAndRecordTime(benchmark); + await loadModel(); + await loadAndRecordTime(); await warmUpAndRecordTime(); await showMsg('Waiting for GC'); await sleep(1000); @@ -375,9 +393,11 @@

TensorFlow.js Model Benchmark

}); gui.add(state, 'kernelTiming', ['aggregate', 'individual']); gui.add(state, 'run'); + gui.add(state, 'testCorrectness').name('test correctness'); showVersions(); await showEnvironment(); + await setupTable(); } onPageLoad(); diff --git a/tfjs-core/benchmarks/modelConfig.js b/tfjs-core/benchmarks/modelConfig.js index d75caec96d4..92d028ab64b 100644 --- a/tfjs-core/benchmarks/modelConfig.js +++ b/tfjs-core/benchmarks/modelConfig.js @@ -79,8 +79,8 @@ const benchmarks = { return tf.loadGraphModel(url); }, predictFunc: () => { - const zeros = tf.zeros([1, 224, 224, 3]); - return model => model.predict(zeros); + const input = tf.randomNormal([1, 224, 224, 3]); + return model => model.predict(input); } }, 'mesh_128': { diff --git a/tfjs-core/benchmarks/util.js b/tfjs-core/benchmarks/util.js index 125ddcaa2d4..7191c48f47b 100644 --- a/tfjs-core/benchmarks/util.js +++ b/tfjs-core/benchmarks/util.js @@ -15,6 +15,38 @@ * ============================================================================= */ +async function getPredictionData(prediction) { + let output = prediction; + if (output instanceof Promise) { + output = await output; + } + if (output instanceof tf.Tensor) { + output = await output.data(); + } + return output; +} + +function arraysClose(n1, n2) { + const epsilon = 1e-3; + + if (n1 === n2) { + return true; + } + if (n1 == null || n2 == null) { + return false; + } + + if (n1.length !== n2.length) { + return false; + } + for (let i = 0; i < n1.length; i++) { + if (Math.abs(n1[i] - n2[i]) > epsilon) { + return false; + } + } + return true; +} + function printTime(elapsed) { return elapsed.toFixed(1) + ' ms'; } From 3b919a186a8d819335de2260f571917cbbcc9203 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Thu, 28 May 2020 09:13:01 -0400 Subject: [PATCH 02/13] clean --- tfjs-core/benchmarks/util.js | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tfjs-core/benchmarks/util.js b/tfjs-core/benchmarks/util.js index 7191c48f47b..643dd20d69e 100644 --- 
a/tfjs-core/benchmarks/util.js +++ b/tfjs-core/benchmarks/util.js @@ -26,16 +26,14 @@ async function getPredictionData(prediction) { return output; } +const epsilon = 1e-3; function arraysClose(n1, n2) { - const epsilon = 1e-3; - if (n1 === n2) { return true; } if (n1 == null || n2 == null) { return false; } - if (n1.length !== n2.length) { return false; } From 2d6f753f08a8da32766164ca7063d257f5beb1fa Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Thu, 28 May 2020 09:39:27 -0400 Subject: [PATCH 03/13] load --- tfjs-backend-wasm/WORKSPACE | 37 +- tfjs-core/benchmarks/tf-backend-cpu.js | 3670 +++ tfjs-core/benchmarks/tf-backend-wasm.js | 2941 ++ tfjs-core/benchmarks/tf-core.js | 26220 ++++++++++++++++++ tfjs-core/benchmarks/tfjs-backend-wasm.wasm | Bin 0 -> 149561 bytes 5 files changed, 32849 insertions(+), 19 deletions(-) create mode 100644 tfjs-core/benchmarks/tf-backend-cpu.js create mode 100644 tfjs-core/benchmarks/tf-backend-wasm.js create mode 100644 tfjs-core/benchmarks/tf-core.js create mode 100644 tfjs-core/benchmarks/tfjs-backend-wasm.wasm diff --git a/tfjs-backend-wasm/WORKSPACE b/tfjs-backend-wasm/WORKSPACE index 10d6da61368..676fa58c3ed 100644 --- a/tfjs-backend-wasm/WORKSPACE +++ b/tfjs-backend-wasm/WORKSPACE @@ -8,9 +8,9 @@ emsdk_configure(name = "emsdk") git_repository( name = "xnnpack", - commit = "15d1f511d37a8dad1ab7a80cfefd7014accf72ac", + commit = "1e5f80293b3c0197aaf44f3adb9329401fd36ed4", remote = "https://github.com/google/XNNPACK.git", - shallow_since = "1582560423 -0800", + shallow_since = "1590511147 -0700", ) # The libraries below are transitive dependencies of XNNPACK that we need to @@ -20,32 +20,30 @@ git_repository( http_archive( name = "FP16", build_file = "@xnnpack//third_party:FP16.BUILD", - sha256 = "9764297a339ad73b0717331a2c3e9c42a52105cd04cab62cb160e2b4598d2ea6", - strip_prefix = "FP16-ba1d31f5eed2eb4a69e4dea3870a68c7c95f998f", + sha256 = "0d56bb92f649ec294dbccb13e04865e3c82933b6f6735d1d7145de45da700156", + strip_prefix = 
"FP16-3c54eacb74f6f5e39077300c5564156c424d77ba", urls = [ - "https://github.com/Maratyszcza/FP16/archive/ba1d31f5eed2eb4a69e4dea3870a68c7c95f998f.tar.gz", + "https://github.com/Maratyszcza/FP16/archive/3c54eacb74f6f5e39077300c5564156c424d77ba.zip", ], ) # FXdiv library, used for repeated integer division by the same factor http_archive( name = "FXdiv", - build_file = "@xnnpack//third_party:FXdiv.BUILD", - sha256 = "7d3215bea832fe77091ec5666200b91156df6724da1e348205078346325fc45e", - strip_prefix = "FXdiv-f8c5354679ec2597792bc70a9e06eff50c508b9a", + sha256 = "ab7dfb08829bee33dca38405d647868fb214ac685e379ec7ef2bebcd234cd44d", + strip_prefix = "FXdiv-b408327ac2a15ec3e43352421954f5b1967701d1", urls = [ - "https://github.com/Maratyszcza/FXdiv/archive/f8c5354679ec2597792bc70a9e06eff50c508b9a.tar.gz", + "https://github.com/Maratyszcza/FXdiv/archive/b408327ac2a15ec3e43352421954f5b1967701d1.zip", ], ) # pthreadpool library, used for parallelization http_archive( name = "pthreadpool", - build_file = "@xnnpack//third_party:pthreadpool.BUILD", - sha256 = "c2328fdf9e48ac9b928953bcbc442eb14402d393e4cfae0541581a3d39efca9d", - strip_prefix = "pthreadpool-0e275fe56094626349c55a524ea8b71a85daa64b", + sha256 = "af8c518b6ec65dca216143ddf5ef9d2e6b133123f9a47a24841ef447c5d91bd1", + strip_prefix = "pthreadpool-6525d8bb736b323eb4df9e4f3afdd3a8458d1a20", urls = [ - "https://github.com/Maratyszcza/pthreadpool/archive/0e275fe56094626349c55a524ea8b71a85daa64b.tar.gz", + "https://github.com/Maratyszcza/pthreadpool/archive/6525d8bb736b323eb4df9e4f3afdd3a8458d1a20.zip", ], ) @@ -60,15 +58,16 @@ http_archive( ], ) + # cpuinfo library, used for detecting processor characteristics http_archive( name = "cpuinfo", build_file = "@xnnpack//third_party:cpuinfo.BUILD", patches = ["@xnnpack//third_party:cpuinfo.patch"], - sha256 = "3f2dc1970f397a0e59db72f9fca6ff144b216895c1d606f6c94a507c1e53a025", - strip_prefix = "cpuinfo-d5e37adf1406cf899d7d9ec1d317c47506ccb970", + sha256 = 
"ea56c399a4f6ca5f749e71acb6a7bfdc653eb65d8f658cb2e414a2fcdca1fe8b", + strip_prefix = "cpuinfo-c2092219e7c874783a00a62edb94ddc672f57ab3", urls = [ - "https://github.com/pytorch/cpuinfo/archive/d5e37adf1406cf899d7d9ec1d317c47506ccb970.tar.gz", + "https://github.com/pytorch/cpuinfo/archive/c2092219e7c874783a00a62edb94ddc672f57ab3.zip", ], ) @@ -76,10 +75,10 @@ http_archive( http_archive( name = "psimd", build_file = "@xnnpack//third_party:psimd.BUILD", - sha256 = "c621f9bb1ff9ab8f0fa4a04f3239d13b345a6e865318d7b464aa80531a1abb2c", - strip_prefix = "psimd-88882f601f8179e1987b7e7cf4a8012c9080ad44", + sha256 = "dc615342bcbe51ca885323e51b68b90ed9bb9fa7df0f4419dbfa0297d5e837b7", + strip_prefix = "psimd-072586a71b55b7f8c584153d223e95687148a900", urls = [ - "https://github.com/Maratyszcza/psimd/archive/88882f601f8179e1987b7e7cf4a8012c9080ad44.tar.gz", + "https://github.com/Maratyszcza/psimd/archive/072586a71b55b7f8c584153d223e95687148a900.zip", ], ) diff --git a/tfjs-core/benchmarks/tf-backend-cpu.js b/tfjs-core/benchmarks/tf-backend-cpu.js new file mode 100644 index 00000000000..3ab42671e52 --- /dev/null +++ b/tfjs-core/benchmarks/tf-backend-cpu.js @@ -0,0 +1,3670 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +(function (global, factory) { + typeof exports === 'object' && typeof module !== 'undefined' ? 
factory(exports, require('@tensorflow/tfjs-core'), require('seedrandom')) : + typeof define === 'function' && define.amd ? define(['exports', '@tensorflow/tfjs-core', 'seedrandom'], factory) : + (global = global || self, factory(global.tf = global.tf || {}, global.tf, global.seedrandom)); +}(this, (function (exports, tf, seedrandom) { 'use strict'; + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function maxImpl(aVals, reduceSize, outShape, dtype) { + var vals = tf.util.getTypedArrayFromDType(dtype, tf.util.sizeFromShape(outShape)); + for (var i = 0; i < vals.length; ++i) { + var offset = i * reduceSize; + var max = aVals[offset]; + for (var j = 0; j < reduceSize; ++j) { + var value = aVals[offset + j]; + if (value > max) { + max = value; + } + } + vals[i] = max; + } + return vals; + } + + /** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function transposeImpl(xVals, xShape, dtype, perm, newShape) { + var xRank = xShape.length; + var xSize = tf.util.sizeFromShape(xShape); + var xStrides = tf.util.computeStrides(xShape); + var newStrides = tf.util.computeStrides(newShape); + var result = tf.util.getTypedArrayFromDType(dtype, tf.util.sizeFromShape(newShape)); + for (var i = 0; i < xSize; ++i) { + var loc = tf.util.indexToLoc(i, xRank, xStrides); + // Permute location. + var newLoc = new Array(loc.length); + for (var i_1 = 0; i_1 < newLoc.length; i_1++) { + newLoc[i_1] = loc[perm[i_1]]; + } + var newIndex = tf.util.locToIndex(newLoc, xRank, newStrides); + result[newIndex] = xVals[i]; + } + return result; + } + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + + var shared = { + __proto__: null, + maxImpl: maxImpl, + transposeImpl: transposeImpl + }; + + /*! ***************************************************************************** + Copyright (c) Microsoft Corporation. All rights reserved. + Licensed under the Apache License, Version 2.0 (the "License"); you may not use + this file except in compliance with the License. You may obtain a copy of the + License at http://www.apache.org/licenses/LICENSE-2.0 + + THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED + WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, + MERCHANTABLITY OR NON-INFRINGEMENT. + + See the Apache Version 2.0 License for specific language governing permissions + and limitations under the License. + ***************************************************************************** */ + /* global Reflect, Promise */ + + var extendStatics = function(d, b) { + extendStatics = Object.setPrototypeOf || + ({ __proto__: [] } instanceof Array && function (d, b) { d.__proto__ = b; }) || + function (d, b) { for (var p in b) if (b.hasOwnProperty(p)) d[p] = b[p]; }; + return extendStatics(d, b); + }; + + function __extends(d, b) { + extendStatics(d, b); + function __() { this.constructor = d; } + d.prototype = b === null ? Object.create(b) : (__.prototype = b.prototype, new __()); + } + + function __awaiter(thisArg, _arguments, P, generator) { + function adopt(value) { return value instanceof P ? value : new P(function (resolve) { resolve(value); }); } + return new (P || (P = Promise))(function (resolve, reject) { + function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } } + function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } } + function step(result) { result.done ? 
resolve(result.value) : adopt(result.value).then(fulfilled, rejected); } + step((generator = generator.apply(thisArg, _arguments || [])).next()); + }); + } + + function __generator(thisArg, body) { + var _ = { label: 0, sent: function() { if (t[0] & 1) throw t[1]; return t[1]; }, trys: [], ops: [] }, f, y, t, g; + return g = { next: verb(0), "throw": verb(1), "return": verb(2) }, typeof Symbol === "function" && (g[Symbol.iterator] = function() { return this; }), g; + function verb(n) { return function (v) { return step([n, v]); }; } + function step(op) { + if (f) throw new TypeError("Generator is already executing."); + while (_) try { + if (f = 1, y && (t = op[0] & 2 ? y["return"] : op[0] ? y["throw"] || ((t = y["return"]) && t.call(y), 0) : y.next) && !(t = t.call(y, op[1])).done) return t; + if (y = 0, t) op = [op[0] & 2, t.value]; + switch (op[0]) { + case 0: case 1: t = op; break; + case 4: _.label++; return { value: op[1], done: false }; + case 5: _.label++; y = op[1]; op = [0]; continue; + case 7: op = _.ops.pop(); _.trys.pop(); continue; + default: + if (!(t = _.trys, t = t.length > 0 && t[t.length - 1]) && (op[0] === 6 || op[0] === 2)) { _ = 0; continue; } + if (op[0] === 3 && (!t || (op[1] > t[0] && op[1] < t[3]))) { _.label = op[1]; break; } + if (op[0] === 6 && _.label < t[1]) { _.label = t[1]; t = op; break; } + if (t && _.label < t[2]) { _.label = t[2]; _.ops.push(op); break; } + if (t[2]) _.ops.pop(); + _.trys.pop(); continue; + } + op = body.call(thisArg, _); + } catch (e) { op = [6, e]; y = 0; } finally { f = t = 0; } + if (op[0] & 5) throw op[1]; return { value: op[0] ? op[1] : void 0, done: true }; + } + } + + /** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function assertNotComplex(tensor, opName) { + if (!Array.isArray(tensor)) { + tensor = [tensor]; + } + tensor.forEach(function (t) { + if (t != null) { + tf.util.assert(t.dtype !== 'complex64', function () { return opName + " does not support complex64 tensors in the CPU backend."; }); + } + }); + } + + /** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + function pool(xValues, xShape, dtype, strides, convInfo, poolType) { + var strideHeight = convInfo.strideHeight; + var strideWidth = convInfo.strideWidth; + var dilationHeight = convInfo.dilationHeight; + var dilationWidth = convInfo.dilationWidth; + var effectiveFilterHeight = convInfo.effectiveFilterHeight; + var effectiveFilterWidth = convInfo.effectiveFilterWidth; + var padTop = convInfo.padInfo.top; + var padLeft = convInfo.padInfo.left; + var initialValue = (poolType === 'max' ? Number.NEGATIVE_INFINITY : + Number.POSITIVE_INFINITY); + var output = tf.buffer(convInfo.outShape, dtype); + var outputVals = output.values; + var outputBatchStrides = convInfo.outShape[1] * convInfo.outShape[2] * convInfo.outShape[3]; + var outputRowStrides = convInfo.outShape[2] * convInfo.outShape[3]; + var outputColStrides = convInfo.outShape[3]; + for (var b = 0; b < convInfo.batchSize; ++b) { + var outputBatchOffset = b * outputBatchStrides; + var inputBatchOffset = b * strides[0]; + for (var d = 0; d < convInfo.inChannels; ++d) { + for (var yR = 0; yR < convInfo.outHeight; ++yR) { + var xRCorner = yR * strideHeight - padTop; + var xRMin = Math.max(0, xRCorner); + var xRMax = Math.min(convInfo.inHeight, effectiveFilterHeight + xRCorner); + var outputRowOffset = outputBatchOffset + yR * outputRowStrides; + for (var yC = 0; yC < convInfo.outWidth; ++yC) { + var xCCorner = yC * strideWidth - padLeft; + var xCMin = Math.max(0, xCCorner); + var xCMax = Math.min(convInfo.inWidth, effectiveFilterWidth + xCCorner); + var minMaxValue = initialValue; + var avgValue = 0; + var count = 0; + for (var xR = xRMin; xR < xRMax; xR += dilationHeight) { + var xROffset = inputBatchOffset + xR * strides[1]; + for (var xC = xCMin; xC < xCMax; xC += dilationWidth) { + var xCOffset = xROffset + xC * strides[2]; + var pixel = xValues[xCOffset + d]; + if ((poolType === 'max' && pixel > minMaxValue)) { + minMaxValue 
= pixel; + } + else if (poolType === 'avg') { + avgValue += pixel; + count++; + } + } + if (isNaN(minMaxValue)) { + break; + } + } + var outputOffset = outputRowOffset + yC * outputColStrides + d; + outputVals[outputOffset] = + poolType === 'avg' ? avgValue / count : minMaxValue; + } + } + } + } + return output; + } + function maxPoolPositions(xValues, xShape, dtype, convInfo, flattenPositions, includeBatchInIndex) { + if (flattenPositions === void 0) { flattenPositions = false; } + if (includeBatchInIndex === void 0) { includeBatchInIndex = false; } + var maxPositions = tf.buffer(convInfo.outShape, 'int32'); + var strideHeight = convInfo.strideHeight; + var strideWidth = convInfo.strideWidth; + var dilationHeight = convInfo.dilationHeight; + var dilationWidth = convInfo.dilationWidth; + var effectiveFilterHeight = convInfo.effectiveFilterHeight; + var effectiveFilterWidth = convInfo.effectiveFilterWidth; + var padTop = convInfo.padInfo.top; + var padLeft = convInfo.padInfo.left; + var xBuf = tf.buffer(xShape, dtype, xValues); + for (var b = 0; b < convInfo.batchSize; ++b) { + for (var d = 0; d < convInfo.inChannels; ++d) { + for (var yR = 0; yR < convInfo.outHeight; ++yR) { + var xRCorner = yR * strideHeight - padTop; + var xRMin = xRCorner; + while (xRMin < 0) { + xRMin += dilationHeight; + } + // const xRMin = Math.max(0, xRCorner); + var xRMax = Math.min(convInfo.inHeight, effectiveFilterHeight + xRCorner); + for (var yC = 0; yC < convInfo.outWidth; ++yC) { + var xCCorner = yC * strideWidth - padLeft; + var xCMin = xCCorner; + while (xCMin < 0) { + xCMin += dilationWidth; + } + var xCMax = Math.min(convInfo.inWidth, effectiveFilterWidth + xCCorner); + var maxValue = Number.NEGATIVE_INFINITY; + var maxPosition = -1; + for (var xR = xRMin; xR < xRMax; xR += dilationHeight) { + var wR = xR - xRCorner; + for (var xC = xCMin; xC < xCMax; xC += dilationWidth) { + var wC = xC - xCCorner; + var pixel = xBuf.get(b, xR, xC, d); + if (pixel > maxValue) { + maxValue = 
pixel; + if (flattenPositions) { + maxPosition = includeBatchInIndex ? + ((b * convInfo.inHeight + xR) * convInfo.inWidth + xC) * + convInfo.inChannels + + d : + (xR * convInfo.inWidth + xC) * convInfo.inChannels + d; + } + else { + maxPosition = wR * effectiveFilterWidth + wC; + } + } + } + } + maxPositions.set(maxPosition, b, yR, yC, d); + } + } + } + } + return maxPositions; + } + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + var nonMaxSuppressionV3 = tf.kernel_impls.nonMaxSuppressionV3; + var split = tf.kernel_impls.split; + var tile = tf.kernel_impls.tile; + var topkImpl = tf.kernel_impls.topkImpl; + var whereImpl = tf.kernel_impls.whereImpl; + function mapActivation(backend, x, activation, preluActivationWeights) { + if (activation === 'linear') { + return backend.linear(x); + } + else if (activation === 'relu') { + return backend.relu(x); + } + else if (activation === 'elu') { + return backend.elu(x); + } + else if (activation === 'relu6') { + return backend.relu6(x); + } + else if (activation === 'prelu') { + return backend.prelu(x, preluActivationWeights); + } + throw new Error("Activation " + activation + " has not been implemented for the CPU backend."); + } + var MathBackendCPU = /** @class */ (function (_super) { + __extends(MathBackendCPU, _super); + function MathBackendCPU() { + var _this = _super.call(this) || this; + _this.blockSize = 48; + _this.firstUse = true; + _this.data = new tf.DataStorage(_this, tf.engine()); + return _this; + } + MathBackendCPU.prototype.write = function (values, shape, dtype) { + if (this.firstUse) { + this.firstUse = false; + if (tf.env().get('IS_NODE')) { + tf.backend_util.warn('\n============================\n' + + 'Hi there 👋. Looks like you are running TensorFlow.js in ' + + 'Node.js. To speed things up dramatically, install our node ' + + 'backend, which binds to TensorFlow C++, by running ' + + 'npm i @tensorflow/tfjs-node, ' + + 'or npm i @tensorflow/tfjs-node-gpu if you have CUDA. ' + + 'Then call require(\'@tensorflow/tfjs-node\'); (-gpu ' + + 'suffix for CUDA) at the start of your program. ' + + 'Visit https://github.com/tensorflow/tfjs-node for more details.' 
+ + '\n============================'); + } + } + var dataId = {}; + this.data.set(dataId, { values: values, dtype: dtype }); + return dataId; + }; + MathBackendCPU.prototype.move = function (dataId, values, shape, dtype) { + this.data.set(dataId, { values: values, dtype: dtype }); + }; + MathBackendCPU.prototype.numDataIds = function () { + return this.data.numDataIds(); + }; + MathBackendCPU.prototype.read = function (dataId) { + return __awaiter(this, void 0, void 0, function () { + return __generator(this, function (_a) { + return [2 /*return*/, this.readSync(dataId)]; + }); + }); + }; + MathBackendCPU.prototype.readSync = function (dataId) { + var _a = this.data.get(dataId), dtype = _a.dtype, complexTensors = _a.complexTensors; + if (dtype === 'complex64') { + var realValues = this.readSync(complexTensors.real.dataId); + var imagValues = this.readSync(complexTensors.imag.dataId); + return tf.backend_util.mergeRealAndImagArrays(realValues, imagValues); + } + return this.data.get(dataId).values; + }; + MathBackendCPU.prototype.bufferSync = function (t) { + var data = this.readSync(t.dataId); + var decodedData = data; + if (t.dtype === 'string') { + try { + // Decode the bytes into string. 
+ decodedData = data.map(function (d) { return tf.util.decodeString(d); }); + } + catch (_a) { + throw new Error('Failed to decode encoded string bytes into utf-8'); + } + } + return tf.buffer(t.shape, t.dtype, decodedData); + }; + MathBackendCPU.prototype.makeOutput = function (values, shape, dtype) { + var dataId = this.write(values, shape, dtype); + return tf.engine().makeTensorFromDataId(dataId, shape, dtype, this); + }; + MathBackendCPU.prototype.disposeData = function (dataId) { + if (this.data.has(dataId)) { + var complexTensors = this.data.get(dataId).complexTensors; + if (complexTensors != null) { + complexTensors.real.dispose(); + complexTensors.imag.dispose(); + } + this.data.delete(dataId); + } + }; + MathBackendCPU.prototype.time = function (f) { + return __awaiter(this, void 0, void 0, function () { + var start, kernelMs; + return __generator(this, function (_a) { + start = tf.util.now(); + f(); + kernelMs = tf.util.now() - start; + return [2 /*return*/, { kernelMs: kernelMs }]; + }); + }); + }; + MathBackendCPU.prototype.memory = function () { + return { + // Unreliable due to automatic gc. The numbers above are cumulative. + unreliable: true, + reasons: ['The reported memory is an upper bound. Due to automatic garbage ' + + 'collection, the true allocated memory may be less.'] + }; + }; + MathBackendCPU.prototype.complex = function (real, imag) { + var result = this.makeOutput(null, real.shape, 'complex64'); + var resultData = this.data.get(result.dataId); + // The backend owns the reference to the underlying real and imaginary + // clones. These will explicitly get disposed when the complex tensor is + // disposed. 
+ resultData.complexTensors = { + real: tf.engine().keep(real.clone()), + imag: tf.engine().keep(imag.clone()) + }; + return result; + }; + MathBackendCPU.prototype.real = function (input) { + var resultData = this.data.get(input.dataId); + return resultData.complexTensors.real.clone(); + }; + MathBackendCPU.prototype.imag = function (input) { + var resultData = this.data.get(input.dataId); + return resultData.complexTensors.imag.clone(); + }; + MathBackendCPU.prototype.slice = function (x, begin, size) { + assertNotComplex(x, 'slice'); + var isContinous = tf.slice_util.isSliceContinous(x.shape, begin, size); + if (isContinous) { + var flatOffset = tf.slice_util.computeFlatOffset(begin, x.strides); + var length_1 = tf.util.sizeFromShape(size); + var vals = this.readSync(x.dataId); + return tf.tensor(vals.subarray(flatOffset, flatOffset + length_1), size, x.dtype); + } + var buffer = tf.buffer(size, x.dtype); + var xBuf = this.bufferSync(x); + for (var i = 0; i < buffer.size; ++i) { + var loc = buffer.indexToLoc(i); + var xLoc = loc.map(function (idx, j) { return idx + begin[j]; }); + buffer.values[i] = xBuf.get.apply(xBuf, xLoc); + } + return buffer.toTensor(); + }; + MathBackendCPU.prototype.stridedSlice = function (x, begin, end, strides) { + assertNotComplex(x, 'stridedSlice'); + var outShape = tf.slice_util.computeOutShape(begin, end, strides); + if (outShape.some(function (axis) { return axis === 0; })) { + return tf.tensor([], outShape); + } + var buffer = tf.buffer(outShape, x.dtype); + var xBuf = this.bufferSync(x); + for (var i = 0; i < buffer.size; i++) { + var loc = buffer.indexToLoc(i); + var newLoc = new Array(loc.length); + for (var j = 0; j < newLoc.length; j++) { + newLoc[j] = loc[j] * strides[j] + begin[j]; + } + buffer.set.apply(buffer, [xBuf.get.apply(xBuf, newLoc)].concat(loc)); + } + return buffer.toTensor(); + }; + MathBackendCPU.prototype.diag = function (x) { + var xVals = this.readSync(x.dataId); + var buffer = tf.buffer([x.size, x.size], 
x.dtype); + var vals = buffer.values; + for (var i = 0; i < xVals.length; i++) { + vals[i * x.size + i] = xVals[i]; + } + return buffer.toTensor(); + }; + MathBackendCPU.prototype.unstack = function (x, axis) { + var num = x.shape[axis]; + var outShape = new Array(x.rank - 1); + var outIndex = 0; + for (var i = 0; i < x.rank; i++) { + if (i !== axis) { + outShape[outIndex++] = x.shape[i]; + } + } + var begin = new Array(x.rank).fill(0); + var size = x.shape.slice(); + size[axis] = 1; + var res = new Array(num); + for (var i = 0; i < res.length; i++) { + begin[axis] = i; + res[i] = this.slice(x, begin, size).reshape(outShape); + } + return res; + }; + MathBackendCPU.prototype.reverse = function (x, axis) { + assertNotComplex(x, 'reverse'); + var buffer = tf.buffer(x.shape, x.dtype); + var xBuf = this.bufferSync(x); + var _loop_1 = function (i) { + var outLoc = buffer.indexToLoc(i); + var inLoc = outLoc.slice(); + axis.forEach(function (ax) { return inLoc[ax] = x.shape[ax] - 1 - inLoc[ax]; }); + buffer.set.apply(buffer, [xBuf.get.apply(xBuf, inLoc)].concat(outLoc)); + }; + for (var i = 0; i < buffer.size; i++) { + _loop_1(i); + } + return buffer.toTensor(); + }; + MathBackendCPU.prototype.concat = function (tensors, axis) { + var _this = this; + if (tensors[0].dtype === 'complex64') { + var reals = tensors.map(function (t) { return tf.real(t); }); + var imags = tensors.map(function (t) { return tf.imag(t); }); + return tf.complex(this.concat(reals, axis), this.concat(imags, axis)); + } + var tensors2D = tensors.map(function (t) { + var innerSize = tf.util.sizeFromShape(t.shape.slice(axis)); + return t.as2D(-1, innerSize); + }); + var outShape = tf.backend_util.computeOutShape(tensors2D.map(function (t) { return t.shape; }), 1 /* axis + */); + var values = tf.buffer(outShape, tensors[0].dtype) + .values; + if (tensors2D[0].shape[0] === 1) { + // Use built-in TypedArray.set() method for speed. 
+ var offset_1 = 0; + tensors2D.forEach(function (t) { + values.set(_this.readSync(t.dataId), offset_1); + offset_1 += t.size; + }); + } + else { + var colOffset_1 = 0; + tensors2D.forEach(function (t) { + var tVals = _this.readSync(t.dataId); + var tIdx = 0; + for (var row = 0; row < t.shape[0]; ++row) { + var resIdx = row * outShape[1] + colOffset_1; + for (var col = 0; col < t.shape[1]; ++col) { + values[resIdx + col] = tVals[tIdx++]; + } + } + colOffset_1 += t.shape[1]; + }); + } + var finalOutShape = tf.backend_util.computeOutShape(tensors.map(function (t) { return t.shape; }), axis); + return tf.tensor(values, finalOutShape, tensors[0].dtype); + }; + MathBackendCPU.prototype.neg = function (x) { + assertNotComplex(x, 'neg'); + return this.multiply(tf.scalar(-1), x); + }; + MathBackendCPU.prototype.add = function (a, b) { + if (a.dtype === 'complex64' || b.dtype === 'complex64') { + return this.broadcastedBinaryComplexOp(a.cast('complex64'), b.cast('complex64'), function (aReal, aImag, bReal, bImag) { + return { real: aReal + bReal, imag: aImag + bImag }; + }); + } + return this.broadcastedBinaryOp(a, b, tf.upcastType(a.dtype, b.dtype), function (aValue, bValue) { return aValue + bValue; }); + }; + MathBackendCPU.prototype.addN = function (tensors) { + var _this = this; + assertNotComplex(tensors, 'addN'); + var vals = tensors.map(function (t) { return _this.readSync(t.dataId); }); + var result = tf.buffer(tensors[0].shape, tensors[0].dtype); + var resultVals = result.values; + for (var i = 0; i < tensors.length; i++) { + var currVals = vals[i]; + for (var j = 0; j < resultVals.length; j++) { + resultVals[j] += currVals[j]; + } + } + return result.toTensor(); + }; + MathBackendCPU.prototype.softmax = function (logits, dim) { + var axes = tf.util.parseAxisParam([dim], logits.shape); + // TODO(annxingyuan): Call maxImpl rather than op as part of softmax kernel + // modularization. 
+ var maxLogit = tf.max(logits, axes); + var expandedShape = tf.backend_util.expandShapeToKeepDim(maxLogit.shape, axes); + var a = this.subtract(logits, maxLogit.reshape(expandedShape)); + var b = this.exp(a); + var sumExp = this.sum(b, axes).reshape(expandedShape); + // TODO(annxingyuan): Call divImpl rather than op as part of softmax + // kernel modularization. + return tf.div(b, sumExp); + }; + MathBackendCPU.prototype.subtract = function (a, b) { + if (a.dtype === 'complex64' || b.dtype === 'complex64') { + return this.broadcastedBinaryComplexOp(a.cast('complex64'), b.cast('complex64'), function (aReal, aImag, bReal, bImag) { + return { real: aReal - bReal, imag: aImag - bImag }; + }); + } + return this.broadcastedBinaryOp(a, b, tf.upcastType(a.dtype, b.dtype), function (aValue, bValue) { return aValue - bValue; }); + }; + MathBackendCPU.prototype.pow = function (a, b) { + assertNotComplex([a, b], 'pow'); + return this.broadcastedBinaryOp(a, b, a.dtype, function (aValue, bValue) { return Math.pow(aValue, bValue); }); + }; + MathBackendCPU.prototype.batchMatMul = function (a, b, transposeA, transposeB) { + assertNotComplex([a, b], 'matMul'); + var sharedDim = transposeA ? a.shape[1] : a.shape[2]; + var leftDim = transposeA ? a.shape[2] : a.shape[1]; + var rightDim = transposeB ? b.shape[1] : b.shape[2]; + var batchDim = a.shape[0]; + var aValues = this.readSync(a.dataId); + var bValues = this.readSync(b.dataId); + var _a = transposeA ? + [a.strides[0], 1, a.strides[1]] : + [a.strides[0], a.strides[1], 1], aBatch = _a[0], aOuterStep = _a[1], aInnerStep = _a[2]; + var _b = transposeB ? 
+ [1, b.strides[1], b.strides[0]] : + [b.strides[1], 1, b.strides[0]], bInnerStep = _b[0], bOuterStep = _b[1], bBatch = _b[2]; + var size = leftDim * rightDim; + var result = tf.buffer([batchDim, leftDim, rightDim], a.dtype); + var resVals = result.values; + var blockSize = this.blockSize; + for (var b_1 = 0; b_1 < batchDim; b_1++) { + for (var i0 = 0; i0 < leftDim; i0 += blockSize) { + for (var j0 = 0; j0 < rightDim; j0 += blockSize) { + for (var k0 = 0; k0 < sharedDim; k0 += blockSize) { + // for when blockSize doesn't evenly divide the input + var iBlock = Math.min(i0 + blockSize, leftDim); + var jBlock = Math.min(j0 + blockSize, rightDim); + var kBlock = Math.min(k0 + blockSize, sharedDim); + for (var i = i0; i < iBlock; i++) { + for (var j = j0; j < jBlock; j++) { + var sum = 0.0; + for (var k = k0; k < kBlock; k++) { + sum += aValues[b_1 * aBatch + i * aOuterStep + k * aInnerStep] * + bValues[k * bInnerStep + j * bOuterStep + b_1 * bBatch]; + } + resVals[b_1 * size + (i * rightDim + j)] += sum; + } + } + } + } + } + } + return result.toTensor(); + }; + MathBackendCPU.prototype.fusedBatchMatMul = function (_a) { + var a = _a.a, b = _a.b, transposeA = _a.transposeA, transposeB = _a.transposeB, bias = _a.bias, activation = _a.activation, preluActivationWeights = _a.preluActivationWeights; + var result = this.batchMatMul(a, b, transposeA, transposeB); + if (bias) { + result = this.add(result, bias); + } + if (activation) { + result = + mapActivation(this, result, activation, preluActivationWeights); + } + return result; + }; + MathBackendCPU.prototype.multiply = function (a, b) { + if (a.dtype === 'complex64' || b.dtype === 'complex64') { + return this.broadcastedBinaryComplexOp(a.cast('complex64'), b.cast('complex64'), function (aReal, aImag, bReal, bImag) { + return { + real: aReal * bReal - aImag * bImag, + imag: aReal * bImag + aImag * bReal + }; + }); + } + return this.broadcastedBinaryOp(a, b, tf.upcastType(a.dtype, b.dtype), function (aValue, bValue) { 
return aValue * bValue; }); + }; + MathBackendCPU.prototype.floorDiv = function (a, b) { + assertNotComplex([a, b], 'floorDiv'); + var op = function (a, b) { return Math.floor(a / b); }; + var outputDtype = 'int32'; + return this.broadcastedBinaryOp(a, b, outputDtype, op); + }; + MathBackendCPU.prototype.sum = function (x, axes) { + assertNotComplex(x, 'sum'); + tf.backend_util.assertAxesAreInnerMostDims('sum', axes, x.rank); + var _a = tf.backend_util.computeOutAndReduceShapes(x.shape, axes), outShape = _a[0], reduceShape = _a[1]; + var resultDtype = tf.upcastType(x.dtype, 'int32'); + var result = tf.zeros(outShape, resultDtype); + var reduceSize = tf.util.sizeFromShape(reduceShape); + var vals = this.readSync(result.dataId); + var aVals = this.readSync(x.dataId); + for (var i = 0; i < vals.length; ++i) { + var offset = i * reduceSize; + var sum = 0; + for (var j = 0; j < reduceSize; ++j) { + sum += aVals[offset + j]; + } + vals[i] = sum; + } + return result; + }; + MathBackendCPU.prototype.prod = function (x, axes) { + assertNotComplex(x, 'sum'); + var _a = tf.backend_util.computeOutAndReduceShapes(x.shape, axes), outShape = _a[0], reduceShape = _a[1]; + var resultDtype = tf.upcastType(x.dtype, 'int32'); + var result = tf.zeros(outShape, resultDtype); + var reduceSize = tf.util.sizeFromShape(reduceShape); + var vals = this.readSync(result.dataId); + var aVals = this.readSync(x.dataId); + for (var i = 0; i < vals.length; ++i) { + var offset = i * reduceSize; + var prod = 1; + for (var j = 0; j < reduceSize; ++j) { + prod *= aVals[offset + j]; + } + vals[i] = prod; + } + return result; + }; + MathBackendCPU.prototype.unsortedSegmentSum = function (x, segmentIds, numSegments) { + assertNotComplex(x, 'unsortedSegmentSum'); + var res = []; + // Reshape the segment id's so that they can be broadcast with + // x. 
The new shape should be [segmentIds.shape, 1, ..., 1] + var numIters = x.rank - segmentIds.rank; + for (var i = 0; i < numIters; ++i) { + segmentIds = segmentIds.expandDims(i + 1); + } + for (var i = 0; i < numSegments; ++i) { + var segmentId = tf.scalar(i, 'int32'); + var mask = tf.equal(segmentId, segmentIds).asType('float32'); + var sum = mask.mul(x).sum(0); + res.push(sum); + } + return tf.stack(res); + }; + MathBackendCPU.prototype.argMin = function (x, axis) { + assertNotComplex(x, 'argMin'); + var axes = [axis]; + tf.backend_util.assertAxesAreInnerMostDims('argMin', axes, x.rank); + var _a = tf.backend_util.computeOutAndReduceShapes(x.shape, axes), outShape = _a[0], reduceShape = _a[1]; + var result = tf.zeros(outShape, 'int32'); + var reduceSize = tf.util.sizeFromShape(reduceShape); + var vals = this.readSync(result.dataId); + var aVals = this.readSync(x.dataId); + for (var i = 0; i < vals.length; ++i) { + var offset = i * reduceSize; + var min = aVals[offset]; + var minIndex = 0; + for (var j = 0; j < reduceSize; ++j) { + var value = aVals[offset + j]; + if (value < min) { + min = value; + minIndex = j; + } + } + vals[i] = minIndex; + } + return result; + }; + MathBackendCPU.prototype.argMax = function (x, axis) { + assertNotComplex(x, 'argMax'); + var axes = [axis]; + tf.backend_util.assertAxesAreInnerMostDims('argMax', axes, x.rank); + var _a = tf.backend_util.computeOutAndReduceShapes(x.shape, axes), outShape = _a[0], reduceShape = _a[1]; + var result = tf.zeros(outShape, 'int32'); + var reduceSize = tf.util.sizeFromShape(reduceShape); + var vals = this.readSync(result.dataId); + var aVals = this.readSync(x.dataId); + for (var i = 0; i < vals.length; ++i) { + var offset = i * reduceSize; + var max_1 = aVals[offset]; + var maxIndex = 0; + for (var j = 0; j < reduceSize; ++j) { + var value = aVals[offset + j]; + if (value > max_1) { + max_1 = value; + maxIndex = j; + } + } + vals[i] = maxIndex; + } + return result; + }; + MathBackendCPU.prototype.cumsum = 
function (x, axis, exclusive, reverse) { + assertNotComplex(x, 'cumsum'); + if (axis !== x.rank - 1) { + throw new Error("backend.cumsum in CPU expects an inner-most axis=" + (x.rank - 1) + " " + + ("but got axis=" + axis)); + } + var resultDtype = tf.upcastType(x.dtype, 'int32'); + var result = tf.zeros(x.shape, resultDtype); + var vals = this.readSync(result.dataId); + var aVals = this.readSync(x.dataId); + var finalDim = x.shape[x.rank - 1]; + var indexAdjuster = reverse ? + function (i, j) { return i + finalDim - j - 1; } : + function (i, j) { return i + j; }; + for (var i = 0; i < aVals.length; i += finalDim) { + for (var j = 0; j < finalDim; j++) { + var idx = indexAdjuster(i, j); + if (j === 0) { + vals[idx] = exclusive ? 0 : aVals[idx]; + } + else { + var prevIdx = indexAdjuster(i, j - 1); + vals[idx] = exclusive ? aVals[prevIdx] + vals[prevIdx] : + aVals[idx] + vals[prevIdx]; + } + } + } + return result; + }; + MathBackendCPU.prototype.equal = function (a, b) { + assertNotComplex([a, b], 'equal'); + return this.broadcastedBinaryOp(a, b, 'bool', function (aVal, bVal) { + return (aVal === bVal) ? 1 : 0; + }); + }; + MathBackendCPU.prototype.notEqual = function (a, b) { + assertNotComplex([a, b], 'notEqual'); + return this.broadcastedBinaryOp(a, b, 'bool', function (aVal, bVal) { + return (aVal !== bVal) ? 1 : 0; + }); + }; + MathBackendCPU.prototype.less = function (a, b) { + assertNotComplex([a, b], 'less'); + return this.broadcastedBinaryOp(a, b, 'bool', function (aVal, bVal) { + return (aVal < bVal) ? 1 : 0; + }); + }; + MathBackendCPU.prototype.lessEqual = function (a, b) { + assertNotComplex([a, b], 'lessEqual'); + return this.broadcastedBinaryOp(a, b, 'bool', function (aVal, bVal) { + return (aVal <= bVal) ? 1 : 0; + }); + }; + MathBackendCPU.prototype.greater = function (a, b) { + assertNotComplex([a, b], 'greater'); + return this.broadcastedBinaryOp(a, b, 'bool', function (aVal, bVal) { + return (aVal > bVal) ? 
1 : 0; + }); + }; + MathBackendCPU.prototype.greaterEqual = function (a, b) { + assertNotComplex([a, b], 'greaterEqual'); + return this.broadcastedBinaryOp(a, b, 'bool', function (aVal, bVal) { + return (aVal >= bVal) ? 1 : 0; + }); + }; + MathBackendCPU.prototype.logicalNot = function (x) { + assertNotComplex(x, 'logicalNot'); + var values = this.readSync(x.dataId); + var newValues = new Uint8Array(values.length); + for (var i = 0; i < values.length; ++i) { + newValues[i] = values[i] ? 0 : 1; + } + return this.makeOutput(newValues, x.shape, 'bool'); + }; + MathBackendCPU.prototype.logicalAnd = function (a, b) { + assertNotComplex([a, b], 'logicalAnd'); + return this.broadcastedBinaryOp(a, b, 'bool', function (aVal, bVal) { + return aVal && bVal; + }); + }; + MathBackendCPU.prototype.logicalOr = function (a, b) { + assertNotComplex([a, b], 'logicalOr'); + return this.broadcastedBinaryOp(a, b, 'bool', function (aVal, bVal) { + return aVal || bVal; + }); + }; + MathBackendCPU.prototype.select = function (condition, a, b) { + assertNotComplex([condition, a, b], 'select'); + var values = this.readSync(condition.dataId); + var aValues = this.readSync(a.dataId); + var bValues = this.readSync(b.dataId); + var result = tf.zeros(a.shape, tf.upcastType(a.dtype, b.dtype)); + var newValues = this.readSync(result.dataId); + var index = 0; + var offset = condition.rank === 0 || condition.rank > 1 || a.rank === 1 ? 
+ 1 : + tf.util.sizeFromShape(a.shape.slice(1)); + for (var i = 0; i < values.length; i++) { + for (var j = 0; j < offset; j++) { + if (values[i] === 1) { + newValues[index++] = aValues[i]; + } + else { + newValues[index++] = bValues[i]; + } + } + } + return result; + }; + MathBackendCPU.prototype.where = function (condition) { + assertNotComplex([condition], 'where'); + var condVals = this.readSync(condition.dataId); + return whereImpl(condition.shape, condVals); + }; + MathBackendCPU.prototype.topk = function (x, k, sorted) { + assertNotComplex(x, 'topk'); + var xVals = this.readSync(x.dataId); + return topkImpl(xVals, x.shape, x.dtype, k, sorted); + }; + MathBackendCPU.prototype.min = function (x, axes) { + assertNotComplex(x, 'min'); + tf.backend_util.assertAxesAreInnerMostDims('min', axes, x.rank); + var _a = tf.backend_util.computeOutAndReduceShapes(x.shape, axes), outShape = _a[0], reduceShape = _a[1]; + var result = tf.zeros(outShape, x.dtype); + var reduceSize = tf.util.sizeFromShape(reduceShape); + var vals = this.readSync(result.dataId); + var aVals = this.readSync(x.dataId); + for (var i = 0; i < vals.length; ++i) { + var offset = i * reduceSize; + var min = aVals[offset]; + for (var j = 0; j < reduceSize; ++j) { + var value = aVals[offset + j]; + if (value < min) { + min = value; + } + } + vals[i] = min; + } + return result; + }; + MathBackendCPU.prototype.minimum = function (a, b) { + assertNotComplex([a, b], 'minimum'); + return this.broadcastedBinaryOp(a, b, a.dtype, function (aVal, bVal) { return Math.min(aVal, bVal); }); + }; + MathBackendCPU.prototype.mod = function (a, b) { + assertNotComplex([a, b], 'mod'); + return this.broadcastedBinaryOp(a, b, a.dtype, function (aVal, bVal) { + var rem = aVal % bVal; + if ((aVal < 0 && bVal < 0) || (aVal >= 0 && bVal >= 0)) { + return rem; + } + else { + return (rem + bVal) % bVal; + } + }); + }; + MathBackendCPU.prototype.maximum = function (a, b) { + assertNotComplex([a, b], 'maximum'); + return 
this.broadcastedBinaryOp(a, b, a.dtype, function (aVal, bVal) { return Math.max(aVal, bVal); }); + }; + MathBackendCPU.prototype.all = function (x, axes) { + assertNotComplex(x, 'all'); + tf.backend_util.assertAxesAreInnerMostDims('all', axes, x.rank); + var _a = tf.backend_util.computeOutAndReduceShapes(x.shape, axes), outShape = _a[0], reduceShape = _a[1]; + var result = tf.zeros(outShape, x.dtype); + var reduceSize = tf.util.sizeFromShape(reduceShape); + var vals = this.readSync(result.dataId); + var aVals = this.readSync(x.dataId); + for (var i = 0; i < vals.length; ++i) { + var offset = i * reduceSize; + var all = aVals[offset]; + for (var j = 0; j < reduceSize; ++j) { + var value = aVals[offset + j]; + all = all && value; + } + vals[i] = all; + } + return result; + }; + MathBackendCPU.prototype.any = function (x, axes) { + assertNotComplex(x, 'any'); + tf.backend_util.assertAxesAreInnerMostDims('any', axes, x.rank); + var _a = tf.backend_util.computeOutAndReduceShapes(x.shape, axes), outShape = _a[0], reduceShape = _a[1]; + var result = tf.zeros(outShape, x.dtype); + var reduceSize = tf.util.sizeFromShape(reduceShape); + var vals = this.readSync(result.dataId); + var aVals = this.readSync(x.dataId); + for (var i = 0; i < vals.length; ++i) { + var offset = i * reduceSize; + var anyVal = aVals[offset]; + for (var j = 0; j < reduceSize; ++j) { + var value = aVals[offset + j]; + anyVal = anyVal || value; + } + vals[i] = anyVal; + } + return result; + }; + MathBackendCPU.prototype.squaredDifference = function (a, b) { + assertNotComplex([a, b], 'squaredDifference'); + return this.broadcastedBinaryOp(a, b, a.dtype, function (aVal, bVal) { + var diff = aVal - bVal; + return diff * diff; + }); + }; + MathBackendCPU.prototype.ceil = function (x) { + assertNotComplex(x, 'ceil'); + var values = this.readSync(x.dataId); + var newValues = new Float32Array(values.length); + for (var i = 0; i < values.length; ++i) { + newValues[i] = Math.ceil(values[i]); + } + return 
this.makeOutput(newValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.floor = function (x) { + assertNotComplex(x, 'floor'); + var values = this.readSync(x.dataId); + var newValues = new Float32Array(values.length); + for (var i = 0; i < values.length; ++i) { + newValues[i] = Math.floor(values[i]); + } + return this.makeOutput(newValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.sign = function (x) { + assertNotComplex(x, 'x'); + var values = this.readSync(x.dataId); + var newValues = new Float32Array(values.length); + for (var i = 0; i < values.length; ++i) { + if (values[i] < 0) { + newValues[i] = -1; + } + else if (values[i] > 0) { + newValues[i] = 1; + } + else { + newValues[i] = 0; + } + } + return this.makeOutput(newValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.isNaN = function (x) { + assertNotComplex(x, 'x'); + var values = this.readSync(x.dataId); + var newValues = new Uint8Array(values.length); + for (var i = 0; i < values.length; ++i) { + if (Number.isNaN(values[i])) { + newValues[i] = 1; + } + } + return this.makeOutput(newValues, x.shape, 'bool'); + }; + MathBackendCPU.prototype.isInf = function (x) { + assertNotComplex(x, 'x'); + var values = this.readSync(x.dataId); + var newValues = new Uint8Array(values.length); + for (var i = 0; i < values.length; ++i) { + if (Math.abs(values[i]) === Infinity) { + newValues[i] = 1; + } + } + return this.makeOutput(newValues, x.shape, 'bool'); + }; + MathBackendCPU.prototype.isFinite = function (x) { + assertNotComplex(x, 'x'); + var values = this.readSync(x.dataId); + var newValues = new Uint8Array(values.length); + for (var i = 0; i < values.length; ++i) { + if (Number.isFinite(values[i])) { + newValues[i] = 1; + } + } + return this.makeOutput(newValues, x.shape, 'bool'); + }; + MathBackendCPU.prototype.round = function (x) { + assertNotComplex(x, 'round'); + var values = this.readSync(x.dataId); + var newValues = new Float32Array(values.length); + for (var i = 0; i < 
values.length; ++i) { + // The algorithm is based on banker's rounding. + var base = Math.floor(values[i]); + if (values[i] - base < 0.5) { + newValues[i] = Math.floor(values[i]); + } + else if (values[i] - base > 0.5) { + newValues[i] = Math.ceil(values[i]); + } + else { + if (base % 2.0 === 0.0) { + newValues[i] = base; + } + else { + newValues[i] = base + 1.0; + } + } + } + return this.makeOutput(newValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.exp = function (x) { + assertNotComplex(x, 'exp'); + var values = this.readSync(x.dataId); + var newValues = new Float32Array(values.length); + for (var i = 0; i < values.length; ++i) { + newValues[i] = Math.exp(values[i]); + } + return this.makeOutput(newValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.expm1 = function (x) { + assertNotComplex(x, 'expm1'); + var values = this.readSync(x.dataId); + var newValues = new Float32Array(values.length); + for (var i = 0; i < values.length; ++i) { + newValues[i] = Math.expm1(values[i]); + } + return this.makeOutput(newValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.log = function (x) { + assertNotComplex(x, 'log'); + var values = this.readSync(x.dataId); + var newValues = new Float32Array(values.length); + for (var i = 0; i < values.length; ++i) { + var value = values[i]; + newValues[i] = Math.log(value); + } + return this.makeOutput(newValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.log1p = function (x) { + assertNotComplex(x, 'log1p'); + var values = this.readSync(x.dataId); + var newValues = new Float32Array(values.length); + for (var i = 0; i < values.length; ++i) { + var value = values[i]; + newValues[i] = Math.log1p(value); + } + return this.makeOutput(newValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.sqrt = function (x) { + assertNotComplex(x, 'sqrt'); + var values = this.readSync(x.dataId); + var newValues = new Float32Array(values.length); + for (var i = 0; i < values.length; ++i) { + var value = 
values[i]; + newValues[i] = Math.sqrt(value); + } + return this.makeOutput(newValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.rsqrt = function (x) { + assertNotComplex(x, 'rsqrt'); + var values = this.readSync(x.dataId); + var newValues = new Float32Array(values.length); + for (var i = 0; i < values.length; ++i) { + var value = values[i]; + newValues[i] = 1 / Math.sqrt(value); + } + return this.makeOutput(newValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.reciprocal = function (x) { + assertNotComplex(x, 'reciprocal'); + var values = this.readSync(x.dataId); + var newValues = new Float32Array(values.length); + for (var i = 0; i < values.length; ++i) { + newValues[i] = 1 / values[i]; + } + return this.makeOutput(newValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.linear = function (x) { + return x; + }; + MathBackendCPU.prototype.relu = function (x) { + assertNotComplex(x, 'relu'); + var res = tf.zeros(x.shape, x.dtype); + var resVals = this.readSync(res.dataId); + var inVals = this.readSync(x.dataId); + for (var i = 0; i < inVals.length; ++i) { + resVals[i] = Math.max(0, inVals[i]); + } + return res; + }; + MathBackendCPU.prototype.relu6 = function (x) { + assertNotComplex(x, 'relu'); + var res = tf.zeros(x.shape, x.dtype); + var resVals = this.readSync(res.dataId); + var inVals = this.readSync(x.dataId); + for (var i = 0; i < inVals.length; ++i) { + resVals[i] = Math.min(Math.max(0, inVals[i]), 6); + } + return res; + }; + MathBackendCPU.prototype.prelu = function (x, a) { + assertNotComplex([x, a], 'prelu'); + return this.broadcastedBinaryOp(x, a, x.dtype, function (xValue, aValue) { return xValue < 0 ? 
aValue * xValue : xValue; }); + }; + MathBackendCPU.prototype.elu = function (x) { + assertNotComplex(x, 'elu'); + var resultValues = new Float32Array(x.size); + var values = this.readSync(x.dataId); + for (var i = 0; i < values.length; ++i) { + var v = values[i]; + if (v >= 0) { + resultValues[i] = v; + } + else { + resultValues[i] = (Math.exp(v) - 1); + } + } + return this.makeOutput(resultValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.eluDer = function (dy, y) { + assertNotComplex([dy, y], 'eluDer'); + var resultValues = new Float32Array(y.size); + var values = this.readSync(y.dataId); + var dyValues = this.readSync(dy.dataId); + for (var i = 0; i < values.length; ++i) { + var v = values[i]; + if (v >= 1) { + resultValues[i] = dyValues[i]; + } + else { + resultValues[i] = dyValues[i] * (v + 1); + } + } + return this.makeOutput(resultValues, y.shape, 'float32'); + }; + MathBackendCPU.prototype.selu = function (x) { + assertNotComplex(x, 'selu'); + // Stable and Attracting Fixed Point (0, 1) for Normalized Weights. + // see: https://arxiv.org/abs/1706.02515 + var scaleAlpha = tf.backend_util.SELU_SCALEALPHA; + var scale = tf.backend_util.SELU_SCALE; + var resultValues = new Float32Array(x.size); + var values = this.readSync(x.dataId); + for (var i = 0; i < values.length; ++i) { + var v = values[i]; + if (v >= 0) { + resultValues[i] = scale * v; + } + else { + resultValues[i] = scaleAlpha * (Math.exp(v) - 1); + } + } + return this.makeOutput(resultValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.clip = function (x, min, max) { + assertNotComplex(x, 'clip'); + var resultValues = new Float32Array(x.size); + var values = this.readSync(x.dataId); + for (var i = 0; i < values.length; ++i) { + var v = values[i]; + resultValues[i] = v > max ? max : (v < min ? 
min : v); + } + return this.makeOutput(resultValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.abs = function (x) { + var resultValues = new Float32Array(x.size); + var values = this.readSync(x.dataId); + for (var i = 0; i < values.length; ++i) { + resultValues[i] = Math.abs(values[i]); + } + return this.makeOutput(resultValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.complexAbs = function (x) { + var resultValues = new Float32Array(x.size); + var values = this.readSync(x.dataId); + for (var i = 0; i < x.size; ++i) { + var real = values[i * 2]; + var imag = values[i * 2 + 1]; + resultValues[i] = Math.hypot(real, imag); + } + return this.makeOutput(resultValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.int = function (x) { + assertNotComplex(x, 'int'); + var resultValues = new Int32Array(x.size); + var values = this.readSync(x.dataId); + for (var i = 0; i < values.length; ++i) { + resultValues[i] = values[i]; + } + return this.makeOutput(resultValues, x.shape, 'int32'); + }; + MathBackendCPU.prototype.sigmoid = function (x) { + assertNotComplex(x, 'sigmoid'); + var resultValues = new Float32Array(x.size); + var values = this.readSync(x.dataId); + for (var i = 0; i < values.length; ++i) { + resultValues[i] = 1 / (1 + Math.exp(-values[i])); + } + return this.makeOutput(resultValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.softplus = function (x) { + assertNotComplex(x, 'softplus'); + // mirrors the implementation of tf.nn.softplus: https://goo.gl/vkcvwX + // epsilon is the difference between 1.0 and the next representable float. 
+ // For a single precision 32 bit float this should be 2^-23, see: + // https://math.byu.edu/~schow/work/IEEEFloatingPoint.htm + var epsilon = 1.1920928955078125e-7; + var threshold = Math.log(epsilon) + 2.0; + var resultValues = new Float32Array(x.size); + var values = this.readSync(x.dataId); + for (var i = 0; i < values.length; ++i) { + // Value above which exp(x) may overflow, but softplus(x) == x + // is within machine epsilon. + var tooLarge = values[i] > -threshold; + // Value below which exp(x) may underflow, but softplus(x) == exp(x) + // is within machine epsilon. + var tooSmall = values[i] < threshold; + var expX = Math.exp(values[i]); + var result = void 0; + if (tooSmall) { + result = expX; + } + else if (tooLarge) { + result = values[i]; + } + else { + result = Math.log(1.0 + expX); + } + resultValues[i] = result; + } + return this.makeOutput(resultValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.sin = function (x) { + assertNotComplex(x, 'sin'); + var resultValues = new Float32Array(x.size); + var values = this.readSync(x.dataId); + for (var i = 0; i < values.length; ++i) { + resultValues[i] = Math.sin(values[i]); + } + return this.makeOutput(resultValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.cos = function (x) { + assertNotComplex(x, 'cos'); + var resultValues = new Float32Array(x.size); + var values = this.readSync(x.dataId); + for (var i = 0; i < values.length; ++i) { + resultValues[i] = Math.cos(values[i]); + } + return this.makeOutput(resultValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.tan = function (x) { + assertNotComplex(x, 'tan'); + var resultValues = new Float32Array(x.size); + var values = this.readSync(x.dataId); + for (var i = 0; i < values.length; ++i) { + resultValues[i] = Math.tan(values[i]); + } + return this.makeOutput(resultValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.asin = function (x) { + assertNotComplex(x, 'asin'); + var resultValues = new Float32Array(x.size); + var 
values = this.readSync(x.dataId); + for (var i = 0; i < values.length; ++i) { + resultValues[i] = Math.asin(values[i]); + } + return this.makeOutput(resultValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.acos = function (x) { + assertNotComplex(x, 'acos'); + var resultValues = new Float32Array(x.size); + var values = this.readSync(x.dataId); + for (var i = 0; i < values.length; ++i) { + resultValues[i] = Math.acos(values[i]); + } + return this.makeOutput(resultValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.atan = function (x) { + assertNotComplex(x, 'atan'); + var resultValues = new Float32Array(x.size); + var values = this.readSync(x.dataId); + for (var i = 0; i < values.length; ++i) { + resultValues[i] = Math.atan(values[i]); + } + return this.makeOutput(resultValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.atan2 = function (a, b) { + assertNotComplex([a, b], 'atan2'); + return this.broadcastedBinaryOp(a, b, a.dtype, function (aValue, bValue) { return Math.atan2(aValue, bValue); }); + }; + MathBackendCPU.prototype.sinh = function (x) { + assertNotComplex(x, 'sinh'); + var resultValues = new Float32Array(x.size); + var values = this.readSync(x.dataId); + for (var i = 0; i < values.length; ++i) { + resultValues[i] = Math.sinh(values[i]); + } + return this.makeOutput(resultValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.cosh = function (x) { + assertNotComplex(x, 'cosh'); + var resultValues = new Float32Array(x.size); + var values = this.readSync(x.dataId); + for (var i = 0; i < values.length; ++i) { + resultValues[i] = Math.cosh(values[i]); + } + return this.makeOutput(resultValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.tanh = function (x) { + assertNotComplex(x, 'tanh'); + var resultValues = new Float32Array(x.size); + var values = this.readSync(x.dataId); + for (var i = 0; i < values.length; ++i) { + resultValues[i] = tf.util.tanh(values[i]); + } + return this.makeOutput(resultValues, x.shape, 
'float32'); + }; + MathBackendCPU.prototype.asinh = function (x) { + assertNotComplex(x, 'asinh'); + var resultValues = new Float32Array(x.size); + var values = this.readSync(x.dataId); + for (var i = 0; i < values.length; ++i) { + resultValues[i] = Math.asinh(values[i]); + } + return this.makeOutput(resultValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.acosh = function (x) { + assertNotComplex(x, 'acosh'); + var resultValues = new Float32Array(x.size); + var values = this.readSync(x.dataId); + for (var i = 0; i < values.length; ++i) { + resultValues[i] = Math.acosh(values[i]); + } + return this.makeOutput(resultValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.atanh = function (x) { + assertNotComplex(x, 'atanh'); + var resultValues = new Float32Array(x.size); + var values = this.readSync(x.dataId); + for (var i = 0; i < values.length; ++i) { + resultValues[i] = Math.atanh(values[i]); + } + return this.makeOutput(resultValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.erf = function (x) { + assertNotComplex(x, 'erf'); + var resultValues = new Float32Array(x.size); + var values = this.readSync(x.dataId); + var p = tf.backend_util.ERF_P; + var a1 = tf.backend_util.ERF_A1; + var a2 = tf.backend_util.ERF_A2; + var a3 = tf.backend_util.ERF_A3; + var a4 = tf.backend_util.ERF_A4; + var a5 = tf.backend_util.ERF_A5; + for (var i = 0; i < values.length; ++i) { + var sign = Math.sign(values[i]); + var v = Math.abs(values[i]); + var t = 1.0 / (1.0 + p * v); + resultValues[i] = sign * + (1.0 - + (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * + Math.exp(-v * v)); + } + return this.makeOutput(resultValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.step = function (x, alpha) { + if (alpha === void 0) { alpha = 0; } + assertNotComplex(x, 'step'); + var resultValues = new Float32Array(x.size); + var values = this.readSync(x.dataId); + for (var i = 0; i < values.length; ++i) { + var value = values[i]; + if (isNaN(value)) { + 
resultValues[i] = NaN; + } + else { + resultValues[i] = value > 0 ? 1 : alpha; + } + } + return this.makeOutput(resultValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.fusedConv2d = function (_a) { + var input = _a.input, filter = _a.filter, convInfo = _a.convInfo, bias = _a.bias, activation = _a.activation, preluActivationWeights = _a.preluActivationWeights; + var result = this.conv2d(input, filter, convInfo); + if (bias) { + result = this.add(result, bias); + } + if (activation) { + result = + mapActivation(this, result, activation, preluActivationWeights); + } + return result; + }; + MathBackendCPU.prototype.conv2d = function (x, filter, convInfo) { + assertNotComplex([x, filter], 'conv2d'); + var filterHeight = convInfo.filterHeight; + var filterWidth = convInfo.filterWidth; + var dilationHeight = convInfo.dilationHeight; + var dilationWidth = convInfo.dilationWidth; + var padLeft = convInfo.padInfo.left; + var padTop = convInfo.padInfo.top; + var isChannelsLast = convInfo.dataFormat === 'channelsLast'; + var y = tf.buffer(convInfo.outShape, x.dtype); + var xBatchStride = x.strides[0]; + var xRowStride = isChannelsLast ? x.strides[1] : x.strides[2]; + var xColStride = isChannelsLast ? x.strides[2] : 1; + var xChannelStride = isChannelsLast ? 1 : x.strides[1]; + var yBatchStride = y.strides[0]; + var yRowStride = isChannelsLast ? y.strides[1] : y.strides[2]; + var yColStride = isChannelsLast ? y.strides[2] : 1; + var yChannelStride = isChannelsLast ? 
1 : y.strides[1]; + var xVals = this.readSync(x.dataId); + var wVals = this.readSync(filter.dataId); + var yVals = y.values; + for (var b = 0; b < convInfo.batchSize; ++b) { + var xOffset1 = b * xBatchStride; + var yOffset1 = b * yBatchStride; + for (var yR = 0; yR < convInfo.outHeight; ++yR) { + var yOffset2 = yOffset1 + yR * yRowStride; + var xRCorner = yR * convInfo.strideHeight - padTop; + for (var wR = 0; wR < filterHeight; wR++) { + var xR = xRCorner + wR * dilationHeight; + if (xR < 0 || xR >= convInfo.inHeight) { + continue; + } + var wOffset1 = wR * filter.strides[0]; + var xOffset2 = xOffset1 + xR * xRowStride; + for (var yC = 0; yC < convInfo.outWidth; ++yC) { + var yOffset3 = yOffset2 + yC * yColStride; + var xCCorner = yC * convInfo.strideWidth - padLeft; + for (var wC = 0; wC < filterWidth; wC++) { + var xC = xCCorner + wC * dilationWidth; + if (xC < 0 || xC >= convInfo.inWidth) { + continue; + } + var wOffset2 = wOffset1 + wC * filter.strides[1]; + var xOffset3 = xOffset2 + xC * xColStride; + var wOffset3 = wOffset2; + for (var d1 = 0; d1 < convInfo.inChannels; ++d1) { + var xVal = xVals[xOffset3 + d1 * xChannelStride]; + for (var d2 = 0; d2 < convInfo.outChannels; ++d2) { + yVals[yOffset3 + d2 * yChannelStride] += + xVal * wVals[wOffset3 + d2]; + } + wOffset3 += convInfo.outChannels; + } + } + } + } + } + } + return y.toTensor(); + }; + MathBackendCPU.prototype.conv3d = function (x, filter, convInfo) { + var filterDepth = convInfo.filterDepth; + var filterHeight = convInfo.filterHeight; + var filterWidth = convInfo.filterWidth; + var dilationDepth = convInfo.dilationDepth; + var dilationHeight = convInfo.dilationHeight; + var dilationWidth = convInfo.dilationWidth; + var padFront = convInfo.padInfo.front; + var padLeft = convInfo.padInfo.left; + var padTop = convInfo.padInfo.top; + var y = tf.buffer(convInfo.outShape, x.dtype); + var xVals = this.readSync(x.dataId); + var wVals = this.readSync(filter.dataId); + var yVals = y.values; + for (var b = 
0; b < convInfo.batchSize; ++b) { + var xOffset1 = b * x.strides[0]; + var yOffset1 = b * y.strides[0]; + for (var yF = 0; yF < convInfo.outDepth; ++yF) { + var yOffset2 = yOffset1 + yF * y.strides[1]; + var xFCorner = yF * convInfo.strideDepth - padFront; + for (var wF = 0; wF < filterDepth; wF++) { + var xF = xFCorner + wF * dilationDepth; + if (xF < 0 || xF >= convInfo.inDepth) { + continue; + } + var wOffset1 = wF * filter.strides[0]; + var xOffset2 = xOffset1 + xF * x.strides[1]; + for (var yR = 0; yR < convInfo.outHeight; ++yR) { + var yOffset3 = yOffset2 + yR * y.strides[2]; + var xRCorner = yR * convInfo.strideHeight - padTop; + for (var wR = 0; wR < filterHeight; wR++) { + var xR = xRCorner + wR * dilationHeight; + if (xR < 0 || xR >= convInfo.inHeight) { + continue; + } + var wOffset2 = wOffset1 + wR * filter.strides[1]; + var xOffset3 = xOffset2 + xR * x.strides[2]; + for (var yC = 0; yC < convInfo.outWidth; ++yC) { + var yOffset4 = yOffset3 + yC * convInfo.outChannels; + var xCCorner = yC * convInfo.strideWidth - padLeft; + for (var wC = 0; wC < filterWidth; wC++) { + var xC = xCCorner + wC * dilationWidth; + if (xC < 0 || xC >= convInfo.inWidth) { + continue; + } + var wOffset3 = wOffset2 + wC * filter.strides[2]; + var xOffset4 = xOffset3 + xC * convInfo.inChannels; + var wOffset4 = wOffset3; + for (var d1 = 0; d1 < convInfo.inChannels; ++d1) { + var xVal = xVals[xOffset4 + d1]; + for (var d2 = 0; d2 < convInfo.outChannels; ++d2) { + yVals[yOffset4 + d2] += xVal * wVals[wOffset4 + d2]; + } + wOffset4 += convInfo.outChannels; + } + } + } + } + } + } + } + } + return y.toTensor(); + }; + MathBackendCPU.prototype.conv2dDerInput = function (dy, filter, convInfo) { + assertNotComplex([dy, filter], 'conv2dDerInput'); + var dx = tf.buffer(convInfo.inShape, 'float32'); + var dxValues = dx.values; + var dyValues = this.readSync(dy.dataId); + var fltValues = this.readSync(filter.dataId); + var _a = filter.strides, fltS0 = _a[0], fltS1 = _a[1], fltS2 = _a[2]; + 
var batchSize = convInfo.batchSize, filterHeight = convInfo.filterHeight, filterWidth = convInfo.filterWidth, inChannels = convInfo.inChannels, inHeight = convInfo.inHeight, inWidth = convInfo.inWidth, outChannels = convInfo.outChannels, outHeight = convInfo.outHeight, outWidth = convInfo.outWidth, strideHeight = convInfo.strideHeight, strideWidth = convInfo.strideWidth, dataFormat = convInfo.dataFormat; + var topPad = filterHeight - 1 - convInfo.padInfo.top; + var leftPad = filterWidth - 1 - convInfo.padInfo.left; + var isChannelsLast = dataFormat === 'channelsLast'; + var xBatchStride = dx.strides[0]; + var xRowStride = isChannelsLast ? dx.strides[1] : dx.strides[2]; + var xColStride = isChannelsLast ? dx.strides[2] : 1; + var xChannelStride = isChannelsLast ? 1 : dx.strides[1]; + var yBatchStride = dy.strides[0]; + var yRowStride = isChannelsLast ? dy.strides[1] : dy.strides[2]; + var yColStride = isChannelsLast ? dy.strides[2] : 1; + var yChannelStride = isChannelsLast ? 1 : dy.strides[1]; + for (var b = 0; b < batchSize; ++b) { + for (var d1 = 0; d1 < inChannels; ++d1) { + for (var xR = 0; xR < inHeight; ++xR) { + var xRCorner = xR - topPad; + var xRMin = Math.max(0, Math.ceil(xRCorner / strideHeight)); + var yRMax = Math.min(outHeight, (filterHeight + xRCorner) / strideHeight); + for (var xC = 0; xC < inWidth; ++xC) { + var xCCorner = xC - leftPad; + var xCMin = Math.max(0, Math.ceil(xCCorner / strideWidth)); + var yCMax = Math.min(outWidth, (filterWidth + xCCorner) / strideWidth); + var dotProd = 0; + for (var yR = xRMin; yR < yRMax; ++yR) { + var wR = yR * strideHeight - xRCorner; + for (var yC = xCMin; yC < yCMax; ++yC) { + var wC = yC * strideWidth - xCCorner; + var dyOffset = yBatchStride * b + yRowStride * yR + yColStride * yC; + var fltOffset = fltS0 * (filterHeight - 1 - wR) + + fltS1 * (filterWidth - 1 - wC) + fltS2 * d1; + for (var d2 = 0; d2 < outChannels; ++d2) { + var pixel = dyValues[dyOffset + yChannelStride * d2]; + var weight = 
fltValues[fltOffset + d2]; + dotProd += pixel * weight; + } + } + } + var dxOffset = xBatchStride * b + xRowStride * xR + + xColStride * xC + xChannelStride * d1; + dxValues[dxOffset] = dotProd; + } + } + } + } + return dx.toTensor(); + }; + MathBackendCPU.prototype.conv3dDerInput = function (dy, filter, convInfo) { + var dx = tf.buffer(convInfo.inShape, 'float32'); + var dxValues = dx.values; + var _a = dx.strides, dxS0 = _a[0], dxS1 = _a[1], dxS2 = _a[2], dxS3 = _a[3]; + var dyValues = this.readSync(dy.dataId); + var _b = dy.strides, dyS0 = _b[0], dyS1 = _b[1], dyS2 = _b[2], dyS3 = _b[3]; + var fltValues = this.readSync(filter.dataId); + var _c = filter.strides, fltS0 = _c[0], fltS1 = _c[1], fltS2 = _c[2], fltS3 = _c[3]; + var batchSize = convInfo.batchSize, filterDepth = convInfo.filterDepth, filterHeight = convInfo.filterHeight, filterWidth = convInfo.filterWidth, inChannels = convInfo.inChannels, inDepth = convInfo.inDepth, inHeight = convInfo.inHeight, inWidth = convInfo.inWidth, outChannels = convInfo.outChannels, outDepth = convInfo.outDepth, outHeight = convInfo.outHeight, outWidth = convInfo.outWidth, strideDepth = convInfo.strideDepth, strideHeight = convInfo.strideHeight, strideWidth = convInfo.strideWidth; + var frontPad = filterDepth - 1 - convInfo.padInfo.front; + var topPad = filterHeight - 1 - convInfo.padInfo.top; + var leftPad = filterWidth - 1 - convInfo.padInfo.left; + for (var b = 0; b < batchSize; ++b) { + for (var d1 = 0; d1 < inChannels; ++d1) { + // Frames of depth + for (var xF = 0; xF < inDepth; ++xF) { + var xFCorner = xF - frontPad; + var xFMin = Math.max(0, Math.ceil(xFCorner / strideDepth)); + var yFMax = Math.min(outDepth, (filterDepth + xFCorner) / strideDepth); + // Rows as per standard 2d matrix notation + for (var xR = 0; xR < inHeight; ++xR) { + var xRCorner = xR - topPad; + var xRMin = Math.max(0, Math.ceil(xRCorner / strideHeight)); + var yRMax = Math.min(outHeight, (filterHeight + xRCorner) / strideHeight); + // Columns as 
per standard 2d matrix notation + for (var xC = 0; xC < inWidth; ++xC) { + var xCCorner = xC - leftPad; + var xCMin = Math.max(0, Math.ceil(xCCorner / strideWidth)); + var yCMax = Math.min(outWidth, (filterWidth + xCCorner) / strideWidth); + var dotProd = 0; + for (var yF = xFMin; yF < yFMax; ++yF) { + var wF = yF * strideDepth - xFCorner; + for (var yR = xRMin; yR < yRMax; ++yR) { + var wR = yR * strideHeight - xRCorner; + for (var yC = xCMin; yC < yCMax; ++yC) { + var wC = yC * strideWidth - xCCorner; + var dyOffset = dyS0 * b + dyS1 * yF + dyS2 * yR + dyS3 * yC; + var fltOffset = fltS0 * (filterDepth - 1 - wF) + + fltS1 * (filterHeight - 1 - wR) + + fltS2 * (filterWidth - 1 - wC) + fltS3 * d1; + for (var d2 = 0; d2 < outChannels; ++d2) { + var pixel = dyValues[dyOffset + d2]; + var weight = fltValues[fltOffset + d2]; + dotProd += pixel * weight; + } + } + } + } + dxValues[dxS0 * b + dxS1 * xF + dxS2 * xR + dxS3 * xC + d1] = + dotProd; + } + } + } + } + } + return dx.toTensor(); + }; + MathBackendCPU.prototype.conv2dDerFilter = function (x, dy, convInfo) { + assertNotComplex([x, dy], 'conv2dDerFilter'); + var strideHeight = convInfo.strideHeight; + var strideWidth = convInfo.strideWidth; + var filterHeight = convInfo.filterHeight; + var filterWidth = convInfo.filterWidth; + var isChannelsLast = convInfo.dataFormat === 'channelsLast'; + var dW = tf.buffer(convInfo.filterShape, 'float32'); + var leftPad = convInfo.padInfo.left; + var topPad = convInfo.padInfo.top; + var xBuf = this.bufferSync(x); + var dyBuf = this.bufferSync(dy); + for (var wR = 0; wR < filterHeight; ++wR) { + var yRMin = Math.max(0, Math.ceil((topPad - wR) / strideHeight)); + var yRMax = Math.min(convInfo.outHeight, (convInfo.inHeight + topPad - wR) / strideHeight); + for (var wC = 0; wC < filterWidth; ++wC) { + var yCMin = Math.max(0, Math.ceil((leftPad - wC) / strideWidth)); + var yCMax = Math.min(convInfo.outWidth, (convInfo.inWidth + leftPad - wC) / strideWidth); + for (var d1 = 0; d1 < 
convInfo.inChannels; ++d1) { + for (var d2 = 0; d2 < convInfo.outChannels; ++d2) { + // Need to convolve. + var dotProd = 0; + for (var b = 0; b < convInfo.batchSize; ++b) { + for (var yR = yRMin; yR < yRMax; ++yR) { + var xR = wR + yR * strideHeight - topPad; + for (var yC = yCMin; yC < yCMax; ++yC) { + var xC = wC + yC * strideWidth - leftPad; + if (isChannelsLast) { + dotProd += + xBuf.get(b, xR, xC, d1) * dyBuf.get(b, yR, yC, d2); + } + else { + dotProd += + xBuf.get(b, d1, xR, xC) * dyBuf.get(b, d2, yR, yC); + } + } + } + } + dW.set(dotProd, wR, wC, d1, d2); + } + } + } + } + return dW.toTensor(); + }; + MathBackendCPU.prototype.conv3dDerFilter = function (x, dy, convInfo) { + var strideDepth = convInfo.strideDepth; + var strideHeight = convInfo.strideHeight; + var strideWidth = convInfo.strideWidth; + var filterDepth = convInfo.filterDepth; + var filterHeight = convInfo.filterHeight; + var filterWidth = convInfo.filterWidth; + var dw = tf.buffer(convInfo.filterShape, 'float32'); + var dwValues = dw.values; + var _a = dw.strides, dwS0 = _a[0], dwS1 = _a[1], dwS2 = _a[2], dwS3 = _a[3]; + var dyValues = this.readSync(dy.dataId); + var _b = dy.strides, dyS0 = _b[0], dyS1 = _b[1], dyS2 = _b[2], dyS3 = _b[3]; + var xValues = this.readSync(x.dataId); + var _c = x.strides, xS0 = _c[0], xS1 = _c[1], xS2 = _c[2], xS3 = _c[3]; + var frontPad = convInfo.padInfo.front; + var leftPad = convInfo.padInfo.left; + var topPad = convInfo.padInfo.top; + for (var wF = 0; wF < filterDepth; ++wF) { + var yFMin = Math.max(0, Math.ceil((frontPad - wF) / strideDepth)); + var yFMax = Math.min(convInfo.outDepth, (convInfo.inDepth + frontPad - wF) / strideDepth); + var wOffset1 = wF * dwS0; + for (var wR = 0; wR < filterHeight; ++wR) { + var yRMin = Math.max(0, Math.ceil((topPad - wR) / strideHeight)); + var yRMax = Math.min(convInfo.outHeight, (convInfo.inHeight + topPad - wR) / strideHeight); + var wOffset2 = wR * dwS1 + wOffset1; + for (var wC = 0; wC < filterWidth; ++wC) { + var yCMin 
= Math.max(0, Math.ceil((leftPad - wC) / strideWidth)); + var yCMax = Math.min(convInfo.outWidth, (convInfo.inWidth + leftPad - wC) / strideWidth); + var wOffset3 = wC * dwS2 + wOffset2; + for (var d1 = 0; d1 < convInfo.inChannels; ++d1) { + var wOffset4 = d1 * dwS3 + wOffset3; + for (var d2 = 0; d2 < convInfo.outChannels; ++d2) { + var dotProd = 0; + for (var b = 0; b < convInfo.batchSize; ++b) { + var xOffset1 = b * xS0; + var yOffset1 = b * dyS0; + for (var yF = yFMin; yF < yFMax; ++yF) { + var xF = wF + yF * strideDepth - frontPad; + var xOffset2 = xF * xS1 + xOffset1; + var yOffset2 = yF * dyS1 + yOffset1; + for (var yR = yRMin; yR < yRMax; ++yR) { + var xR = wR + yR * strideHeight - topPad; + var xOffset3 = xR * xS2 + xOffset2; + var yOffset3 = yR * dyS2 + yOffset2; + for (var yC = yCMin; yC < yCMax; ++yC) { + var xC = wC + yC * strideWidth - leftPad; + var xOffset4 = xC * xS3 + xOffset3; + var yOffset4 = yC * dyS3 + yOffset3; + dotProd += + xValues[xOffset4 + d1] * dyValues[yOffset4 + d2]; + } + } + } + } + dwValues[wOffset4 + d2] = dotProd; + } + } + } + } + } + return dw.toTensor(); + }; + MathBackendCPU.prototype.fusedDepthwiseConv2D = function (_a) { + var input = _a.input, filter = _a.filter, convInfo = _a.convInfo, bias = _a.bias, activation = _a.activation, preluActivationWeights = _a.preluActivationWeights; + var result = this.depthwiseConv2D(input, filter, convInfo); + if (bias) { + result = this.add(result, bias); + } + if (activation) { + result = + mapActivation(this, result, activation, preluActivationWeights); + } + return result; + }; + MathBackendCPU.prototype.depthwiseConv2D = function (x, filter, convInfo) { + assertNotComplex([x, filter], 'depthwiseConv2D'); + var filterHeight = convInfo.filterHeight; + var filterWidth = convInfo.filterWidth; + var dilationHeight = convInfo.dilationHeight; + var dilationWidth = convInfo.dilationWidth; + var padLeft = convInfo.padInfo.left; + var padTop = convInfo.padInfo.top; + var chMul = 
convInfo.outChannels / convInfo.inChannels; + var y = tf.buffer(convInfo.outShape, x.dtype); + var xVals = this.readSync(x.dataId); + var wVals = this.readSync(filter.dataId); + var yVals = y.values; + for (var b = 0; b < convInfo.batchSize; ++b) { + var xOffset1 = b * x.strides[0]; + var yOffset1 = b * y.strides[0]; + for (var yR = 0; yR < convInfo.outHeight; ++yR) { + var yOffset2 = yOffset1 + yR * y.strides[1]; + var xRCorner = yR * convInfo.strideHeight - padLeft; + for (var wR = 0; wR < filterHeight; ++wR) { + var xR = xRCorner + wR * dilationHeight; + if (xR < 0 || xR >= convInfo.inHeight) { + continue; + } + var wOffset1 = wR * filter.strides[0]; + var xOffset2 = xOffset1 + xR * x.strides[1]; + for (var yC = 0; yC < convInfo.outWidth; ++yC) { + var yOffset3 = yOffset2 + yC * y.strides[2]; + var xCCorner = yC * convInfo.strideWidth - padTop; + for (var wC = 0; wC < filterWidth; ++wC) { + var xC = xCCorner + wC * dilationWidth; + if (xC < 0 || xC >= convInfo.inWidth) { + continue; + } + var wOffset2 = wOffset1 + wC * filter.strides[1]; + var xOffset3 = xOffset2 + xC * convInfo.inChannels; + var yOffset4 = yOffset3; + var wOffset3 = wOffset2; + for (var d1 = 0; d1 < convInfo.inChannels; ++d1) { + var xVal = xVals[xOffset3 + d1]; + for (var q = 0; q < chMul; ++q) { + yVals[yOffset4 + q] += xVal * wVals[wOffset3 + q]; + } + yOffset4 += chMul; + wOffset3 += chMul; + } + } + } + } + } + } + return y.toTensor(); + }; + MathBackendCPU.prototype.depthwiseConv2DDerInput = function (dy, filter, convInfo) { + assertNotComplex([dy, filter], 'depthwiseConv2DDerInput'); + var dx = tf.buffer(convInfo.inShape, 'float32'); + var dxValues = dx.values; + var _a = dx.strides, dxS0 = _a[0], dxS1 = _a[1], dxS2 = _a[2]; + var dyValues = this.readSync(dy.dataId); + var _b = dy.strides, dyS0 = _b[0], dyS1 = _b[1], dyS2 = _b[2]; + var fltValues = this.readSync(filter.dataId); + var _c = filter.strides, fltS0 = _c[0], fltS1 = _c[1], fltS2 = _c[2]; + var batchSize = convInfo.batchSize, 
filterHeight = convInfo.filterHeight, filterWidth = convInfo.filterWidth, inChannels = convInfo.inChannels, inHeight = convInfo.inHeight, inWidth = convInfo.inWidth, outChannels = convInfo.outChannels, outHeight = convInfo.outHeight, outWidth = convInfo.outWidth, strideHeight = convInfo.strideHeight, strideWidth = convInfo.strideWidth; + var topPad = filterHeight - 1 - convInfo.padInfo.top; + var leftPad = filterWidth - 1 - convInfo.padInfo.left; + var chMul = outChannels / inChannels; + for (var b = 0; b < batchSize; ++b) { + for (var d1 = 0; d1 < inChannels; ++d1) { + for (var xR = 0; xR < inHeight; ++xR) { + var xRCorner = xR - topPad; + var xRMin = Math.max(0, Math.ceil(xRCorner / strideHeight)); + var yRMax = Math.min(outHeight, (filterHeight + xRCorner) / strideHeight); + for (var xC = 0; xC < inWidth; ++xC) { + var xCCorner = xC - leftPad; + var xCMin = Math.max(0, Math.ceil(xCCorner / strideWidth)); + var yCMax = Math.min(outWidth, (filterWidth + xCCorner) / strideWidth); + var dotProd = 0; + for (var yR = xRMin; yR < yRMax; ++yR) { + var wR = yR * strideHeight - xRCorner; + for (var yC = xCMin; yC < yCMax; ++yC) { + var wC = yC * strideWidth - xCCorner; + var dyOffset = dyS0 * b + dyS1 * yR + dyS2 * yC; + var fltOffset = fltS0 * (filterHeight - 1 - wR) + + fltS1 * (filterWidth - 1 - wC) + fltS2 * d1; + for (var dm = 0; dm < chMul; ++dm) { + var d2 = d1 * chMul + dm; + var pixel = dyValues[dyOffset + d2]; + var weight = fltValues[fltOffset + dm]; + dotProd += pixel * weight; + } + } + } + dxValues[dxS0 * b + dxS1 * xR + dxS2 * xC + d1] = dotProd; + } + } + } + } + return dx.toTensor(); + }; + MathBackendCPU.prototype.depthwiseConv2DDerFilter = function (x, dy, convInfo) { + assertNotComplex([x, dy], 'depthwiseConv2DDerFilter'); + var strideHeight = convInfo.strideHeight; + var strideWidth = convInfo.strideWidth; + var filterHeight = convInfo.filterHeight; + var filterWidth = convInfo.filterWidth; + var dW = tf.buffer(convInfo.filterShape, 'float32'); + var 
leftPad = convInfo.padInfo.left; + var topPad = convInfo.padInfo.top; + var chMul = convInfo.outChannels / convInfo.inChannels; + var xBuf = this.bufferSync(x); + var dyBuf = this.bufferSync(dy); + for (var wR = 0; wR < filterHeight; ++wR) { + var yRMin = Math.max(0, Math.ceil((topPad - wR) / strideHeight)); + var yRMax = Math.min(convInfo.outHeight, (convInfo.inHeight + topPad - wR) / strideHeight); + for (var wC = 0; wC < filterWidth; ++wC) { + var yCMin = Math.max(0, Math.ceil((leftPad - wC) / strideWidth)); + var yCMax = Math.min(convInfo.outWidth, (convInfo.inWidth + leftPad - wC) / strideWidth); + for (var d2 = 0; d2 < convInfo.outChannels; ++d2) { + var d1 = Math.trunc(d2 / chMul); + var dm = d2 % chMul; + var dotProd = 0; + for (var b = 0; b < convInfo.batchSize; ++b) { + for (var yR = yRMin; yR < yRMax; ++yR) { + var xR = wR + yR * strideHeight - topPad; + for (var yC = yCMin; yC < yCMax; ++yC) { + var xC = wC + yC * strideWidth - leftPad; + dotProd += xBuf.get(b, xR, xC, d1) * dyBuf.get(b, yR, yC, d2); + } + } + } + dW.set(dotProd, wR, wC, d1, dm); + } + } + } + return dW.toTensor(); + }; + MathBackendCPU.prototype.tile = function (x, reps) { + assertNotComplex(x, 'tile'); + return tile(this.bufferSync(x), reps); + }; + MathBackendCPU.prototype.pad = function (x, paddings, constantValue) { + assertNotComplex(x, 'pad'); + var outShape = paddings.map(function (p, i) { return p[0] /* beforePad */ + x.shape[i] + p[1]; } /* afterPad */); + var start = paddings.map(function (p) { return p[0]; }); + var xBuffer = this.bufferSync(x); + var buffer = tf.buffer(outShape, x.dtype); + if (constantValue !== 0) { + buffer.values.fill(constantValue); + } + for (var i = 0; i < x.size; i++) { + var coords = xBuffer.indexToLoc(i); + var outCoords = coords.map(function (c, i) { return c + start[i]; }); + buffer.set.apply(buffer, [xBuffer.get.apply(xBuffer, coords)].concat(outCoords)); + } + return buffer.toTensor(); + }; + MathBackendCPU.prototype.gather = function (x, 
indices, axis) { + assertNotComplex([x, indices], 'gather'); + var newShape = x.shape.slice(); + var indicesValues = this.readSync(indices.dataId); + newShape[axis] = indicesValues.length; + var result = tf.buffer(newShape, x.dtype); + var xBuf = this.bufferSync(x); + for (var i = 0; i < result.size; ++i) { + var newLoc = result.indexToLoc(i); + var originalLoc = newLoc.slice(); + originalLoc[axis] = indicesValues[newLoc[axis]]; + var originalIndex = xBuf.locToIndex(originalLoc); + result.values[i] = xBuf.values[originalIndex]; + } + return result.toTensor(); + }; + MathBackendCPU.prototype.batchToSpaceND = function (x, blockShape, crops) { + assertNotComplex([x], 'batchToSpaceND'); + var prod = blockShape.reduce(function (a, b) { return a * b; }); + var reshaped = tf.backend_util.getReshaped(x.shape, blockShape, prod); + var permuted = tf.backend_util.getPermuted(reshaped.length, blockShape.length); + var reshapedPermuted = tf.backend_util.getReshapedPermuted(x.shape, blockShape, prod); + var sliceBeginCoords = tf.backend_util.getSliceBeginCoords(crops, blockShape.length); + var sliceSize = tf.backend_util.getSliceSize(reshapedPermuted, crops, blockShape.length); + return tf.transpose(x.reshape(reshaped), permuted) + .reshape(reshapedPermuted) + .slice(sliceBeginCoords, sliceSize); + }; + MathBackendCPU.prototype.spaceToBatchND = function (x, blockShape, paddings) { + assertNotComplex([x], 'spaceToBatchND'); + var prod = blockShape.reduce(function (a, b) { return a * b; }); + var completePaddings = [[0, 0]]; + completePaddings.push.apply(completePaddings, paddings); + for (var i = 1 + blockShape.length; i < x.shape.length; ++i) { + completePaddings.push([0, 0]); + } + var paddedX = x.pad(completePaddings); + var reshapedPaddedShape = tf.backend_util.getReshaped(paddedX.shape, blockShape, prod, false); + var permutedReshapedPaddedPermutation = tf.backend_util.getPermuted(reshapedPaddedShape.length, blockShape.length, false); + var flattenShape = 
tf.backend_util.getReshapedPermuted(paddedX.shape, blockShape, prod, false); + return tf.transpose(paddedX.reshape(reshapedPaddedShape), permutedReshapedPaddedPermutation) + .reshape(flattenShape); + }; + MathBackendCPU.prototype.maxPool = function (x, convInfo) { + assertNotComplex(x, 'maxPool'); + var xValues = this.readSync(x.dataId); + return pool(xValues, x.shape, x.dtype, x.strides, convInfo, 'max') + .toTensor(); + }; + MathBackendCPU.prototype.maxPoolBackprop = function (dy, x, y, convInfo) { + assertNotComplex([x, y], 'maxPoolBackprop'); + var xValues = this.readSync(x.dataId); + var maxPosBuf = tf.buffer(convInfo.outShape, x.dtype, maxPoolPositions(xValues, x.shape, x.dtype, convInfo).values); + var strideHeight = convInfo.strideHeight; + var strideWidth = convInfo.strideWidth; + var dilationHeight = convInfo.dilationHeight; + var dilationWidth = convInfo.dilationWidth; + var effectiveFilterHeight = convInfo.effectiveFilterHeight; + var effectiveFilterWidth = convInfo.effectiveFilterWidth; + var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left; + var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top; + var dx = tf.buffer(x.shape, 'float32'); + var dyBuf = this.bufferSync(dy); + for (var b = 0; b < convInfo.batchSize; ++b) { + for (var d = 0; d < convInfo.inChannels; ++d) { + for (var dxR = 0; dxR < convInfo.inHeight; ++dxR) { + for (var dxC = 0; dxC < convInfo.inWidth; ++dxC) { + // Shader code begins. 
+ var dyRCorner = dxR - padTop; + var dyCCorner = dxC - padLeft; + var dotProd = 0; + for (var wR = 0; wR < effectiveFilterHeight; wR += dilationHeight) { + var dyR = (dyRCorner + wR) / strideHeight; + if (dyR < 0 || dyR >= convInfo.outHeight || + Math.floor(dyR) !== dyR) { + continue; + } + for (var wC = 0; wC < effectiveFilterWidth; wC += dilationWidth) { + var dyC = (dyCCorner + wC) / strideWidth; + if (dyC < 0 || dyC >= convInfo.outWidth || + Math.floor(dyC) !== dyC) { + continue; + } + var maxPos = effectiveFilterHeight * effectiveFilterWidth - + 1 - maxPosBuf.get(b, dyR, dyC, d); + var curPos = wR * effectiveFilterWidth + wC; + var mask = maxPos === curPos ? 1 : 0; + if (mask === 0) { + continue; + } + var pixel = dyBuf.get(b, dyR, dyC, d); + dotProd += pixel * mask; + } + } + dx.set(dotProd, b, dxR, dxC, d); + } + } + } + } + return dx.toTensor(); + }; + MathBackendCPU.prototype.avgPoolBackprop = function (dy, x, convInfo) { + assertNotComplex([dy, x], 'avgPoolBackprop'); + var strideHeight = convInfo.strideHeight; + var strideWidth = convInfo.strideWidth; + var filterHeight = convInfo.filterHeight; + var filterWidth = convInfo.filterWidth; + var dilationHeight = convInfo.dilationHeight; + var dilationWidth = convInfo.dilationWidth; + var effectiveFilterHeight = convInfo.effectiveFilterHeight; + var effectiveFilterWidth = convInfo.effectiveFilterWidth; + var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left; + var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top; + var dx = tf.buffer(x.shape, 'float32'); + var avgMultiplier = 1 / (filterHeight * filterWidth); + var dyBuf = this.bufferSync(dy); + for (var b = 0; b < convInfo.batchSize; ++b) { + for (var d = 0; d < convInfo.inChannels; ++d) { + for (var dxR = 0; dxR < convInfo.inHeight; ++dxR) { + for (var dxC = 0; dxC < convInfo.inWidth; ++dxC) { + // Shader code begins. 
+ var dyRCorner = dxR - padTop; + var dyCCorner = dxC - padLeft; + var dotProd = 0; + for (var wR = 0; wR < effectiveFilterHeight; wR += dilationHeight) { + var dyR = (dyRCorner + wR) / strideHeight; + if (dyR < 0 || dyR >= convInfo.outHeight || + Math.floor(dyR) !== dyR) { + continue; + } + for (var wC = 0; wC < effectiveFilterWidth; wC += dilationWidth) { + var dyC = (dyCCorner + wC) / strideWidth; + if (dyC < 0 || dyC >= convInfo.outWidth || + Math.floor(dyC) !== dyC) { + continue; + } + var pixel = dyBuf.get(b, dyR, dyC, d); + dotProd += pixel; + } + } + dx.set(dotProd * avgMultiplier, b, dxR, dxC, d); + } + } + } + } + return dx.toTensor(); + }; + MathBackendCPU.prototype.pool3d = function (x, convInfo, poolType) { + assertNotComplex(x, 'pool3d'); + var strideDepth = convInfo.strideDepth; + var strideHeight = convInfo.strideHeight; + var strideWidth = convInfo.strideWidth; + var dilationDepth = convInfo.dilationDepth; + var dilationHeight = convInfo.dilationHeight; + var dilationWidth = convInfo.dilationWidth; + var effectiveFilterDepth = convInfo.effectiveFilterDepth; + var effectiveFilterHeight = convInfo.effectiveFilterHeight; + var effectiveFilterWidth = convInfo.effectiveFilterWidth; + var padFront = convInfo.padInfo.front; + var padTop = convInfo.padInfo.top; + var padLeft = convInfo.padInfo.left; + var initialValue = (poolType === 'max' ? 
Number.NEGATIVE_INFINITY : + Number.POSITIVE_INFINITY); + var xValues = this.readSync(x.dataId); + var output = tf.buffer(convInfo.outShape, x.dtype); + var outputVals = output.values; + var outputBatchStrides = convInfo.outShape[1] * convInfo.outShape[2] * + convInfo.outShape[3] * convInfo.outShape[4]; + var outputDepthStrides = convInfo.outShape[2] * convInfo.outShape[3] * convInfo.outShape[4]; + var outputRowStrides = convInfo.outShape[3] * convInfo.outShape[4]; + var outputColStrides = convInfo.outShape[4]; + for (var batch = 0; batch < convInfo.batchSize; ++batch) { + var outputBatchOffset = batch * outputBatchStrides; + var inputBatchOffset = batch * x.strides[0]; + for (var channel = 0; channel < convInfo.inChannels; ++channel) { + for (var yDepth = 0; yDepth < convInfo.outDepth; ++yDepth) { + var xDepthCorner = yDepth * strideDepth - padFront; + var xDepthMin = xDepthCorner; + while (xDepthMin < 0) { + xDepthMin += dilationDepth; + } + var xDepthMax = Math.min(convInfo.inDepth, effectiveFilterDepth + xDepthCorner); + var outputDepthOffset = outputBatchOffset + yDepth * outputDepthStrides; + for (var yRow = 0; yRow < convInfo.outHeight; ++yRow) { + var xRowCorner = yRow * strideHeight - padTop; + var xRowMin = xRowCorner; + while (xRowMin < 0) { + xRowMin += dilationHeight; + } + var xRowMax = Math.min(convInfo.inHeight, effectiveFilterHeight + xRowCorner); + var outputRowOffset = outputDepthOffset + yRow * outputRowStrides; + for (var yCol = 0; yCol < convInfo.outWidth; ++yCol) { + var xColCorner = yCol * strideWidth - padLeft; + var xColMin = xColCorner; + while (xColMin < 0) { + xColMin += dilationWidth; + } + var xColMax = Math.min(convInfo.inWidth, effectiveFilterWidth + xColCorner); + // Shader code begins + var outputColOffset = outputRowOffset + yCol * outputColStrides; + var minMaxValue = initialValue; + var avgValue = 0; + var count = 0; + for (var xDepth = xDepthMin; xDepth < xDepthMax; xDepth += dilationDepth) { + var xDepthOffset = 
inputBatchOffset + xDepth * x.strides[1]; + for (var xRow = xRowMin; xRow < xRowMax; xRow += dilationHeight) { + var xRowOffset = xDepthOffset + xRow * x.strides[2]; + for (var xCol = xColMin; xCol < xColMax; xCol += dilationWidth) { + var xColOffset = xRowOffset + xCol * x.strides[3]; + var pixel = xValues[xColOffset + channel]; + if ((poolType === 'max' && pixel > minMaxValue)) { + minMaxValue = pixel; + } + else if (poolType === 'avg') { + avgValue += pixel; + count++; + } + if (isNaN(minMaxValue)) { + break; + } + } + if (isNaN(minMaxValue)) { + break; + } + } + if (isNaN(minMaxValue)) { + break; + } + } + var outputOffset = outputColOffset + channel; + outputVals[outputOffset] = + poolType === 'avg' ? avgValue / count : minMaxValue; + } + } + } + } + } + return output.toTensor(); + }; + MathBackendCPU.prototype.avgPool3d = function (x, convInfo) { + assertNotComplex(x, 'avgPool3d'); + return this.pool3d(x, convInfo, 'avg').toFloat(); + }; + MathBackendCPU.prototype.avgPool3dBackprop = function (dy, x, convInfo) { + assertNotComplex([dy, x], 'avgPool3dBackprop'); + var strideDepth = convInfo.strideDepth; + var strideHeight = convInfo.strideHeight; + var strideWidth = convInfo.strideWidth; + var filterDepth = convInfo.filterDepth; + var filterHeight = convInfo.filterHeight; + var filterWidth = convInfo.filterWidth; + var dilationDepth = convInfo.dilationDepth; + var dilationHeight = convInfo.dilationHeight; + var dilationWidth = convInfo.dilationWidth; + var effectiveFilterDepth = convInfo.effectiveFilterDepth; + var effectiveFilterHeight = convInfo.effectiveFilterHeight; + var effectiveFilterWidth = convInfo.effectiveFilterWidth; + var padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front; + var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left; + var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top; + var dx = tf.buffer(x.shape, 'float32'); + var avgMultiplier = 1 / (filterDepth * filterHeight * filterWidth); + var dyBuf = 
this.bufferSync(dy); + for (var batch = 0; batch < convInfo.batchSize; ++batch) { + for (var channel = 0; channel < convInfo.inChannels; ++channel) { + for (var dxDepth = 0; dxDepth < convInfo.inDepth; ++dxDepth) { + for (var dxRow = 0; dxRow < convInfo.inHeight; ++dxRow) { + for (var dxCol = 0; dxCol < convInfo.inWidth; ++dxCol) { + // Shader code begins. + var dyDepthCorner = dxDepth - padFront; + var dyRowCorner = dxRow - padTop; + var dyColCorner = dxCol - padLeft; + var dotProd = 0; + for (var wDepth = 0; wDepth < effectiveFilterDepth; wDepth += dilationDepth) { + var dyDepth = (dyDepthCorner + wDepth) / strideDepth; + if (dyDepth < 0 || dyDepth >= convInfo.outDepth || + Math.floor(dyDepth) !== dyDepth) { + continue; + } + for (var wRow = 0; wRow < effectiveFilterHeight; wRow += dilationHeight) { + var dyRow = (dyRowCorner + wRow) / strideHeight; + if (dyRow < 0 || dyRow >= convInfo.outHeight || + Math.floor(dyRow) !== dyRow) { + continue; + } + for (var wCol = 0; wCol < effectiveFilterWidth; wCol += dilationWidth) { + var dyCol = (dyColCorner + wCol) / strideWidth; + if (dyCol < 0 || dyCol >= convInfo.outWidth || + Math.floor(dyCol) !== dyCol) { + continue; + } + var pixel = dyBuf.get(batch, dyDepth, dyRow, dyCol, channel); + dotProd += pixel; + } + } + } + dx.set(dotProd * avgMultiplier, batch, dxDepth, dxRow, dxCol, channel); + } + } + } + } + } + return dx.toTensor(); + }; + MathBackendCPU.prototype.maxPool3d = function (x, convInfo) { + assertNotComplex(x, 'maxPool3d'); + return this.pool3d(x, convInfo, 'max').toFloat(); + }; + MathBackendCPU.prototype.maxPool3dPositions = function (x, convInfo) { + var maxPositions = tf.buffer(convInfo.outShape, 'int32'); + var strideDepth = convInfo.strideDepth; + var strideHeight = convInfo.strideHeight; + var strideWidth = convInfo.strideWidth; + var dilationDepth = convInfo.dilationDepth; + var dilationHeight = convInfo.dilationHeight; + var dilationWidth = convInfo.dilationWidth; + var effectiveFilterDepth = 
convInfo.effectiveFilterDepth; + var effectiveFilterHeight = convInfo.effectiveFilterHeight; + var effectiveFilterWidth = convInfo.effectiveFilterWidth; + var padFront = convInfo.padInfo.front; + var padTop = convInfo.padInfo.top; + var padLeft = convInfo.padInfo.left; + var xBuf = this.bufferSync(x); + for (var batch = 0; batch < convInfo.batchSize; ++batch) { + for (var channel = 0; channel < convInfo.inChannels; ++channel) { + for (var yDepth = 0; yDepth < convInfo.outDepth; ++yDepth) { + var xDepthCorner = yDepth * strideDepth - padFront; + var xDepthMin = xDepthCorner; + while (xDepthMin < 0) { + xDepthMin += dilationDepth; + } + var xDepthMax = Math.min(convInfo.inDepth, effectiveFilterDepth + xDepthCorner); + for (var yRow = 0; yRow < convInfo.outHeight; ++yRow) { + var xRowCorner = yRow * strideHeight - padTop; + var xRowMin = xRowCorner; + while (xRowMin < 0) { + xRowMin += dilationHeight; + } + var xRowMax = Math.min(convInfo.inHeight, effectiveFilterHeight + xRowCorner); + for (var yCol = 0; yCol < convInfo.outWidth; ++yCol) { + var xColCorner = yCol * strideWidth - padLeft; + var xColMin = xColCorner; + while (xColMin < 0) { + xColMin += dilationWidth; + } + var xColMax = Math.min(convInfo.inWidth, effectiveFilterWidth + xColCorner); + // Shader code begins + var maxValue = Number.NEGATIVE_INFINITY; + var maxPosition = -1; + for (var xDepth = xDepthMin; xDepth < xDepthMax; xDepth += dilationDepth) { + var wDepth = xDepth - xDepthCorner; + for (var xRow = xRowMin; xRow < xRowMax; xRow += dilationHeight) { + var wRow = xRow - xRowCorner; + for (var xCol = xColMin; xCol < xColMax; xCol += dilationWidth) { + var wCol = xCol - xColCorner; + var pixel = xBuf.get(batch, xDepth, xRow, xCol, channel); + if (pixel >= maxValue) { + maxValue = pixel; + maxPosition = wDepth * effectiveFilterHeight * + effectiveFilterWidth + + wRow * effectiveFilterHeight + wCol; + } + } + } + } + maxPositions.set(maxPosition, batch, yDepth, yRow, yCol, channel); + } + } + } + } + } 
+ return maxPositions.toTensor(); + }; + MathBackendCPU.prototype.maxPool3dBackprop = function (dy, x, y, convInfo) { + assertNotComplex([x, y], 'maxPool3dBackprop'); + var maxPositions = this.maxPool3dPositions(x, convInfo); + var strideDepth = convInfo.strideDepth; + var strideHeight = convInfo.strideHeight; + var strideWidth = convInfo.strideWidth; + var dilationDepth = convInfo.dilationDepth; + var dilationHeight = convInfo.dilationHeight; + var dilationWidth = convInfo.dilationWidth; + var effectiveFilterDepth = convInfo.effectiveFilterDepth; + var effectiveFilterHeight = convInfo.effectiveFilterHeight; + var effectiveFilterWidth = convInfo.effectiveFilterWidth; + var padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front; + var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left; + var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top; + var dx = tf.buffer(x.shape, 'float32'); + var maxPosBuf = this.bufferSync(maxPositions); + var dyBuf = this.bufferSync(dy); + for (var batch = 0; batch < convInfo.batchSize; ++batch) { + for (var channel = 0; channel < convInfo.inChannels; ++channel) { + for (var dxDepth = 0; dxDepth < convInfo.inDepth; ++dxDepth) { + for (var dxRow = 0; dxRow < convInfo.inHeight; ++dxRow) { + for (var dxCol = 0; dxCol < convInfo.inWidth; ++dxCol) { + // Shader code begins + var dyDepthCorner = dxDepth - padFront; + var dyRowCorner = dxRow - padTop; + var dyColCorner = dxCol - padLeft; + var dotProd = 0; + for (var wDepth = 0; wDepth < effectiveFilterDepth; wDepth += dilationDepth) { + var dyDepth = (dyDepthCorner + wDepth) / strideDepth; + if (dyDepth < 0 || dyDepth >= convInfo.outDepth || + Math.floor(dyDepth) !== dyDepth) { + continue; + } + for (var wRow = 0; wRow < effectiveFilterHeight; wRow += dilationHeight) { + var dyRow = (dyRowCorner + wRow) / strideHeight; + if (dyRow < 0 || dyRow >= convInfo.outHeight || + Math.floor(dyRow) !== dyRow) { + continue; + } + for (var wCol = 0; wCol < effectiveFilterWidth; wCol += 
dilationWidth) { + var dyCol = (dyColCorner + wCol) / strideWidth; + if (dyCol < 0 || dyCol >= convInfo.outWidth || + Math.floor(dyCol) !== dyCol) { + continue; + } + var maxPos = effectiveFilterDepth * + effectiveFilterHeight * effectiveFilterWidth - + 1 - + maxPosBuf.get(batch, dyDepth, dyRow, dyCol, channel); + var curPos = wDepth * effectiveFilterHeight * effectiveFilterWidth + + wRow * effectiveFilterWidth + wCol; + var mask = maxPos === curPos ? 1 : 0; + if (mask === 0) { + continue; + } + var pixel = dyBuf.get(batch, dyDepth, dyRow, dyCol, channel); + dotProd += pixel * mask; + } + } + } + dx.set(dotProd, batch, dxDepth, dxRow, dxCol, channel); + } + } + } + } + } + return dx.toTensor(); + }; + MathBackendCPU.prototype.cast = function (x, dtype) { + return tf.backend_util.castTensor(x, dtype, this); + }; + MathBackendCPU.prototype.reshape = function (x, shape) { + return tf.backend_util.reshapeTensor(x, shape); + }; + MathBackendCPU.prototype.avgPool = function (x, convInfo) { + assertNotComplex(x, 'avgPool'); + assertNotComplex(x, 'maxPool'); + var xValues = this.readSync(x.dataId); + return pool(xValues, x.shape, x.dtype, x.strides, convInfo, 'avg') + .toTensor() + .toFloat(); + }; + MathBackendCPU.prototype.resizeBilinear = function (x, newHeight, newWidth, alignCorners) { + assertNotComplex(x, 'resizeBilinear'); + var _a = x.shape, batch = _a[0], oldHeight = _a[1], oldWidth = _a[2], numChannels = _a[3]; + var xValues = this.readSync(x.dataId); + var result = new Float32Array(tf.util.sizeFromShape([batch, newHeight, newWidth, numChannels])); + var effectiveInputSize = [ + (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight, + (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth + ]; + var effectiveOutputSize = [ + (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight, + (alignCorners && newWidth > 1) ? 
newWidth - 1 : newWidth + ]; + var outputIdx = 0; + var effectiveRowSizeRatio = effectiveInputSize[0] / effectiveOutputSize[0]; + var effectiveColSizeRatio = effectiveInputSize[1] / effectiveOutputSize[1]; + for (var b = 0; b < batch; b++) { + for (var r = 0; r < newHeight; r++) { + var sourceFracRow = effectiveRowSizeRatio * r; + var sourceRowFloor = Math.floor(sourceFracRow); + var rowFrac = sourceFracRow - sourceRowFloor; + var sourceRowCeil = Math.min(oldHeight - 1, Math.ceil(sourceFracRow)); + var topRowOffset = b * x.strides[0] + sourceRowFloor * x.strides[1]; + var botRowOffset = b * x.strides[0] + sourceRowCeil * x.strides[1]; + for (var c = 0; c < newWidth; c++) { + var sourceFracCol = effectiveColSizeRatio * c; + var sourceColFloor = Math.floor(sourceFracCol); + var colFrac = sourceFracCol - sourceColFloor; + var sourceColCeil = Math.min(oldWidth - 1, Math.ceil(sourceFracCol)); + var topLeftOffest = topRowOffset + sourceColFloor * x.strides[2]; + var botLeftOffset = botRowOffset + sourceColFloor * x.strides[2]; + var topRightOffset = topRowOffset + sourceColCeil * x.strides[2]; + var botRightOffest = botRowOffset + sourceColCeil * x.strides[2]; + for (var d = 0; d < numChannels; d++) { + // Begin shader. + // Compute the fractional index of the source. 
+ var topLeft = xValues[topLeftOffest + d]; + var bottomLeft = xValues[botLeftOffset + d]; + var topRight = xValues[topRightOffset + d]; + var bottomRight = xValues[botRightOffest + d]; + var top_1 = topLeft + (topRight - topLeft) * colFrac; + var bottom = bottomLeft + (bottomRight - bottomLeft) * colFrac; + var newValue = top_1 + (bottom - top_1) * rowFrac; + result[outputIdx++] = newValue; + } + } + } + } + return tf.tensor(result, [batch, newHeight, newWidth, numChannels]); + }; + MathBackendCPU.prototype.resizeBilinearBackprop = function (dy, x, alignCorners) { + assertNotComplex([dy, x], 'resizeBilinearBackprop'); + var _a = x.shape, batch = _a[0], xHeight = _a[1], xWidth = _a[2], depth = _a[3]; + var _b = dy.shape, yHeight = _b[1], yWidth = _b[2]; + var output = new Float32Array(batch * xHeight * xWidth * depth); + // In the backwards pass, we want to find the pixels that were generated + // for each pixel in the input image the forward pass and add the + // corresponding coefficient from dy to the gradient (with some + // interpolation). + var effectiveXSize = [ + (alignCorners && yHeight > 1) ? xHeight - 1 : xHeight, + (alignCorners && yWidth > 1) ? xWidth - 1 : xWidth + ]; + var effectiveYSize = [ + (alignCorners && yHeight > 1) ? yHeight - 1 : yHeight, + (alignCorners && yWidth > 1) ? 
yWidth - 1 : yWidth + ]; + var heightScale = effectiveXSize[0] / effectiveYSize[0]; + var widthScale = effectiveXSize[1] / effectiveYSize[1]; + // Reference implementation + // tslint:disable-next-line:max-line-length + // https://github.com/tensorflow/tensorflow/blob/3039375c86a5bbc9610c7725dcaa95d635f87ba2/tensorflow/core/kernels/resize_bilinear_op.cc#L275 + var dyValues = this.readSync(dy.dataId); + var offset = 0; + for (var b = 0; b < batch; b++) { + var bOffset = b * x.strides[0]; + for (var r = 0; r < yHeight; r++) { + var dxR = r * heightScale; + var topDxRIndex = Math.floor(dxR); + var bottomDxRIndex = Math.min(Math.ceil(dxR), xHeight - 1); + var topDxROffset = bOffset + topDxRIndex * x.strides[1]; + var bottomDxROffset = bOffset + bottomDxRIndex * x.strides[1]; + var dxRLerp = dxR - topDxRIndex; + var inverseDxRLerp = 1.0 - dxRLerp; + for (var c = 0; c < yWidth; c++) { + var dxC = c * widthScale; + var leftDxCIndex = Math.floor(dxC); + var rightDxCIndex = Math.min(Math.ceil(dxC), xWidth - 1); + var dxCLerp = dxC - leftDxCIndex; + var inverseDxCLerp = 1.0 - dxCLerp; + var topLeftRCOffset = topDxROffset + leftDxCIndex * x.strides[2]; + var topRightRCOffset = topDxROffset + rightDxCIndex * x.strides[2]; + var bottomLeftRCOffset = bottomDxROffset + leftDxCIndex * x.strides[2]; + var bottomRightRCOffset = bottomDxROffset + rightDxCIndex * x.strides[2]; + var inverseDxRLerpTimesInverseDxCLerp = inverseDxRLerp * inverseDxCLerp; + var inverseDxRLerpTimesDxCLerp = inverseDxRLerp * dxCLerp; + var dxRLerpTimesInverseDxCLerp = dxRLerp * inverseDxCLerp; + var dxRLerpTimesDxCLerp = dxRLerp * dxCLerp; + for (var d = 0; d < depth; d++) { + var dyVal = dyValues[offset++]; + output[topLeftRCOffset + d] += + dyVal * inverseDxRLerpTimesInverseDxCLerp; + output[topRightRCOffset + d] += dyVal * inverseDxRLerpTimesDxCLerp; + output[bottomLeftRCOffset + d] += + dyVal * dxRLerpTimesInverseDxCLerp; + output[bottomRightRCOffset + d] += dyVal * dxRLerpTimesDxCLerp; + } + } + } + } + 
return tf.tensor4d(output, [batch, xWidth, xHeight, depth], x.dtype); + }; + MathBackendCPU.prototype.resizeNearestNeighbor = function (x, newHeight, newWidth, alignCorners) { + assertNotComplex(x, 'resizeNearestNeighbor'); + var _a = x.shape, batch = _a[0], oldHeight = _a[1], oldWidth = _a[2], numChannels = _a[3]; + var xValues = this.readSync(x.dataId); + var output = new Float32Array(batch * newHeight * newWidth * numChannels); + var effectiveInputSize = [ + (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight, + (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth + ]; + var effectiveOutputSize = [ + (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight, + (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth + ]; + var effectiveRowSizeRatio = effectiveInputSize[0] / effectiveOutputSize[0]; + var effectiveColSizeRatio = effectiveInputSize[1] / effectiveOutputSize[1]; + var outputOffset = 0; + for (var b = 0; b < batch; b++) { + var batchOffset = b * x.strides[0]; + for (var r = 0; r < newHeight; r++) { + var sourceFracRow = effectiveRowSizeRatio * r; + var sourceNearestRow = Math.min(oldHeight - 1, alignCorners ? Math.round(sourceFracRow) : + Math.floor(sourceFracRow)); + var rowOffset = batchOffset + sourceNearestRow * x.strides[1]; + for (var c = 0; c < newWidth; c++) { + var sourceFracCol = effectiveColSizeRatio * c; + var sourceNearestCol = Math.min(oldWidth - 1, alignCorners ? Math.round(sourceFracCol) : + Math.floor(sourceFracCol)); + var colOffset = rowOffset + sourceNearestCol * x.strides[2]; + for (var d = 0; d < numChannels; d++) { + // Begin shader. + // Compute the fractional index of the source. 
+ var newVal = xValues[colOffset + d]; + output[outputOffset++] = newVal; + } + } + } + } + return tf.tensor(output, [batch, newHeight, newWidth, numChannels], x.dtype); + }; + MathBackendCPU.prototype.resizeNearestNeighborBackprop = function (dy, x, alignCorners) { + assertNotComplex([dy, x], 'resizeNearestNeighborBackprop'); + var _a = x.shape, batch = _a[0], xHeight = _a[1], xWidth = _a[2], depth = _a[3]; + var _b = dy.shape, yHeight = _b[1], yWidth = _b[2]; + var output = new Float32Array(batch * xHeight * xWidth * depth); + var dyValues = this.readSync(dy.dataId); + // In the backwards pass, we want to find the pixels that were generated + // for each pixel in the input image the forward pass + var effectiveXSize = [ + (alignCorners && yHeight > 1) ? xHeight - 1 : xHeight, + (alignCorners && yWidth > 1) ? xWidth - 1 : xWidth + ]; + var effectiveYSize = [ + (alignCorners && yHeight > 1) ? yHeight - 1 : yHeight, + (alignCorners && yWidth > 1) ? yWidth - 1 : yWidth + ]; + var heightScale = effectiveXSize[0] / effectiveYSize[0]; + var widthScale = effectiveXSize[1] / effectiveYSize[1]; + var invHeightScale = 1 / heightScale; + var invWidthScale = 1 / widthScale; + // This defines the size of the window of values around a particular + // index in dy that we want to search for contributions to dx. + var winHeight = (Math.ceil(invHeightScale) * 2) + 2; + var winWidth = (Math.ceil(invWidthScale) * 2) + 2; + // Loop over the output space. 
+ for (var b = 0; b < batch; b++) { + var batchOffset = b * x.strides[0]; + for (var r = 0; r < xHeight; r++) { + var rowOffset = batchOffset + r * x.strides[1]; + // Compute bounds for where in dy we will look + var startRLerp = Math.floor(r * invHeightScale); + var startDyR = Math.floor(startRLerp - (winHeight / 2)); + for (var c = 0; c < xWidth; c++) { + var colOffset = rowOffset + c * x.strides[2]; + // Compute bounds for where in dy we will look + var startCLerp = Math.floor(c * invWidthScale); + var startDyC = Math.floor(startCLerp - (winWidth / 2)); + for (var d = 0; d < depth; d++) { + var accum = 0; + // loop over dy + for (var dyRIndex = 0; dyRIndex < winHeight; dyRIndex++) { + var dyR = dyRIndex + startDyR; + // Guard against the window exceeding the bounds of dy + if (dyR < 0 || dyR >= yHeight) { + continue; + } + var dyROffset = batchOffset + dyR * dy.strides[1]; + var sourceFracRow = dyR * heightScale; + var sourceNearestRow = Math.min(xHeight - 1, alignCorners ? Math.round(sourceFracRow) : + Math.floor(sourceFracRow)); + if (r !== sourceNearestRow) { + continue; + } + for (var dyCIndex = 0; dyCIndex < winWidth; dyCIndex++) { + var dyC = dyCIndex + startDyC; + // Guard against the window exceeding the bounds of dy + if (dyC < 0 || dyC >= yWidth) { + continue; + } + var dyCOffset = dyROffset + dyC * dy.strides[2]; + var sourceFracCol = dyC * widthScale; + var sourceNearestCol = Math.min(xWidth - 1, alignCorners ? 
Math.round(sourceFracCol) : + Math.floor(sourceFracCol)); + if (c === sourceNearestCol) { + accum += dyValues[dyCOffset + d]; + } + } + } + output[colOffset + d] = accum; + } + } + } + } + return tf.tensor4d(output, x.shape, x.dtype); + }; + MathBackendCPU.prototype.batchNorm = function (x, mean, variance, offset, scale, varianceEpsilon) { + assertNotComplex([x, mean, variance, scale, offset], 'batchNorm'); + var xVals = this.readSync(x.dataId); + var mVals = this.readSync(mean.dataId); + var varVals = this.readSync(variance.dataId); + var sVals = scale ? this.readSync(scale.dataId) : + new Float32Array([1]); + var offVals = offset ? this.readSync(offset.dataId) : + new Float32Array([0]); + var outVals = new Float32Array(xVals.length); + var offValsLength = offVals.length; + var sValsLength = sVals.length; + var varValsLength = varVals.length; + var mValsLength = mVals.length; + var offi = 0; + var mi = 0; + var si = 0; + var vi = 0; + for (var i = 0; i < xVals.length; ++i) { + outVals[i] = offVals[offi++] + + (xVals[i] - mVals[mi++]) * sVals[si++] / + Math.sqrt(varVals[vi++] + varianceEpsilon); + if (offi >= offValsLength) { + offi = 0; + } + if (mi >= mValsLength) { + mi = 0; + } + if (si >= sValsLength) { + si = 0; + } + if (vi >= varValsLength) { + vi = 0; + } + } + return tf.tensor4d(outVals, x.shape); + }; + MathBackendCPU.prototype.localResponseNormalization4D = function (x, depthRadius, bias, alpha, beta) { + assertNotComplex(x, 'localResponseNormalization4D'); + var channels = x.shape[3]; + var maxD = channels - 1; + var xValues = this.readSync(x.dataId); + var size = x.size; + var result = new Float32Array(size); + function sumAcrossChannels(offset) { + var currentChannel = offset % channels; + var beginSumOffset = offset - currentChannel + Math.max(0, currentChannel - depthRadius); + var endSumOffset = offset - currentChannel + + Math.min(currentChannel + depthRadius, maxD); + var sum = 0.0; + for (; beginSumOffset <= endSumOffset; beginSumOffset++) { + 
var z = xValues[beginSumOffset]; + sum += z * z; + } + return sum; + } + for (var offset = 0; offset < size; offset++) { + var sum = sumAcrossChannels(offset); + var val = xValues[offset] * Math.pow(bias + alpha * sum, -beta); + result[offset] = val; + } + return tf.tensor4d(result, x.shape); + }; + MathBackendCPU.prototype.LRNGrad = function (dy, inputImage, outputImage, depthRadius, bias, alpha, beta) { + assertNotComplex(dy, 'LRNGrad'); + var channels = dy.shape[3]; + var dyValues = this.readSync(dy.dataId); + var inputImageValues = this.readSync(inputImage.dataId); + var outputImageValues = this.readSync(outputImage.dataId); + var result = new Float32Array(dy.size); + var size = dy.size; + for (var offset = 0; offset < size; offset++) { + var currentChannel = offset % channels; + var depthBegin = (offset - currentChannel) + Math.max(0, currentChannel - depthRadius); + var depthEnd = (offset - currentChannel) + + Math.min(channels, currentChannel + depthRadius + 1); + var norm = 0; + for (var k = depthBegin; k < depthEnd; k++) { + norm += Math.pow(inputImageValues[k], 2); + } + norm = alpha * norm + bias; + for (var k = depthBegin; k < depthEnd; k++) { + var dyi = -2 * alpha * beta * inputImageValues[k] * + outputImageValues[offset] / norm; + if (offset === k) { + dyi += Math.pow(norm, -beta); + } + dyi *= dyValues[offset]; + result[k] += dyi; + } + } + return tf.tensor4d(result, dy.shape); + }; + MathBackendCPU.prototype.multinomial = function (logits, normalized, numSamples, seed) { + assertNotComplex(logits, 'multinomial'); + var probabilities = normalized ? logits : tf.softmax(logits); + var batchSize = probabilities.shape[0]; + var numEvents = probabilities.shape[1]; + var res = tf.zeros([batchSize, numSamples], 'int32'); + var resVals = this.readSync(res.dataId); + var probVals = this.readSync(probabilities.dataId); + for (var b = 0; b < batchSize; ++b) { + var offset = b * numEvents; + // The cdf won't include the last event. 
It will be implicit if no other + // event happened. + var cdf = new Float32Array(numEvents - 1); + cdf[0] = probVals[offset]; + for (var event_1 = 1; event_1 < cdf.length; ++event_1) { + cdf[event_1] = cdf[event_1 - 1] + probVals[offset + event_1]; + } + var random = seedrandom.alea(seed.toString()); + var outOffset = b * numSamples; + for (var sampleId = 0; sampleId < numSamples; ++sampleId) { + var r = random(); + // Assume last event happened by default. + resVals[outOffset + sampleId] = cdf.length; + for (var event_2 = 0; event_2 < cdf.length; event_2++) { + if (r < cdf[event_2]) { + resVals[outOffset + sampleId] = event_2; + break; + } + } + } + } + return res; + }; + MathBackendCPU.prototype.oneHot = function (indices, depth, onValue, offValue) { + assertNotComplex(indices, 'oneHot'); + var res = new Float32Array(indices.size * depth); + res.fill(offValue); + var indicesVal = this.readSync(indices.dataId); + for (var event_3 = 0; event_3 < indices.size; ++event_3) { + if (indicesVal[event_3] >= 0 && indicesVal[event_3] < depth) { + res[event_3 * depth + indicesVal[event_3]] = onValue; + } + } + return tf.tensor2d(res, [indices.size, depth], 'int32'); + }; + MathBackendCPU.prototype.nonMaxSuppression = function (boxes, scores, maxOutputSize, iouThreshold, scoreThreshold) { + assertNotComplex(boxes, 'nonMaxSuppression'); + var boxesVals = this.readSync(boxes.dataId); + var scoresVals = this.readSync(scores.dataId); + return nonMaxSuppressionV3(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold); + }; + MathBackendCPU.prototype.fft = function (x) { + return this.fftBatch(x, false); + }; + MathBackendCPU.prototype.ifft = function (x) { + return this.fftBatch(x, true); + }; + /** + * Calculate FFT of inner most elements of batch tensor. + */ + MathBackendCPU.prototype.fftBatch = function (x, inverse) { + var batch = x.shape[0]; + var innerDim = x.shape[1]; + // Collects real and imaginary values separately. 
+ var realResult = tf.buffer(x.shape, 'float32'); + var imagResult = tf.buffer(x.shape, 'float32'); + var real = tf.real(x).as2D(batch, innerDim); + var imag = tf.imag(x).as2D(batch, innerDim); + for (var b = 0; b < batch; b++) { + // TODO: Support slice ops for complex type. + var r = real.slice([b, 0], [1, innerDim]); + var i = imag.slice([b, 0], [1, innerDim]); + var input = tf.complex(r, i); + // Run FFT by batch element. + var res = this.readSync(this.fftImpl(input, inverse).dataId); + for (var d = 0; d < innerDim; d++) { + var c = tf.backend_util.getComplexWithIndex(res, d); + realResult.values[b * innerDim + d] = c.real; + imagResult.values[b * innerDim + d] = c.imag; + } + } + var t = tf.complex(realResult.toTensor(), imagResult.toTensor()); + return t.as2D(batch, innerDim); + }; + MathBackendCPU.prototype.fftImpl = function (x, inverse) { + var x1D = x.as1D(); + var n = x1D.size; + if (this.isExponentOf2(n)) { + var result = this.fftRadix2(x1D, n, inverse).as2D(x.shape[0], x.shape[1]); + if (inverse) { + result = tf.complex(tf.real(result).div(tf.scalar(n)), tf.imag(result).div(tf.scalar(n))); + } + return result; + } + else { + var data = this.readSync(x.dataId); + var rawOutput = this.fourierTransformByMatmul(data, n, inverse); + var output = tf.backend_util.splitRealAndImagArrays(rawOutput); + return tf.complex(output.real, output.imag).as2D(x.shape[0], x.shape[1]); + } + }; + MathBackendCPU.prototype.isExponentOf2 = function (size) { + return (size & size - 1) === 0; + }; + // FFT using Cooley-Tukey algorithm on radix 2 dimensional input. 
+ MathBackendCPU.prototype.fftRadix2 = function (input, size, inverse) { + if (size === 1) { + return input; + } + var data = this.readSync(input.dataId); + var half = size / 2; + var evenComplex = tf.backend_util.complexWithEvenIndex(data); + var evenTensor = tf.complex(evenComplex.real, evenComplex.imag).as1D(); + var oddComplex = tf.backend_util.complexWithOddIndex(data); + var oddTensor = tf.complex(oddComplex.real, oddComplex.imag).as1D(); + // Recursive call for half part of original input. + evenTensor = this.fftRadix2(evenTensor, half, inverse); + oddTensor = this.fftRadix2(oddTensor, half, inverse); + var e = tf.backend_util.exponents(size, inverse); + var exponent = tf.complex(e.real, e.imag).mul(oddTensor); + var addPart = evenTensor.add(exponent); + var subPart = evenTensor.sub(exponent); + var realTensor = tf.real(addPart).concat(tf.real(subPart)); + var imagTensor = tf.imag(addPart).concat(tf.imag(subPart)); + return tf.complex(realTensor, imagTensor).as1D(); + }; + // Calculate fourier transform by multplying sinusoid matrix. + MathBackendCPU.prototype.fourierTransformByMatmul = function (data, size, inverse) { + var ret = new Float32Array(size * 2); + // TODO: Use matmul instead once it supports complex64 type. + for (var r = 0; r < size; r++) { + var real = 0.0; + var imag = 0.0; + for (var c = 0; c < size; c++) { + var e = tf.backend_util.exponent(r * c, size, inverse); + var term = tf.backend_util.getComplexWithIndex(data, c); + real += term.real * e.real - term.imag * e.imag; + imag += term.real * e.imag + term.imag * e.real; + } + if (inverse) { + real /= size; + imag /= size; + } + tf.backend_util.assignToTypedArray(ret, real, imag, r); + } + return ret; + }; + MathBackendCPU.prototype.depthToSpace = function (x, blockSize, dataFormat) { + tf.util.assert(dataFormat === 'NHWC', function () { return "Only NHWC dataFormat supported on CPU for depthToSpace. 
Got " + dataFormat; }); + tf.util.assert(blockSize > 1, function () { + return "blockSize should be > 1 for depthToSpace, but was: " + blockSize; + }); + var batchSize = x.shape[0]; + var inputHeight = x.shape[1]; + var inputWidth = x.shape[2]; + var inputDepth = x.shape[3]; + var outputHeight = inputHeight * blockSize; + var outputWidth = inputWidth * blockSize; + var outputDepth = inputDepth / (blockSize * blockSize); + var xValues = this.readSync(x.dataId); + var result = new Float32Array(batchSize * outputHeight * outputWidth * outputDepth); + var outputIdx = 0; + for (var b = 0; b < batchSize; ++b) { + for (var h = 0; h < outputHeight; ++h) { + var inH = Math.floor(h / blockSize); + var offsetH = (h % blockSize); + for (var w = 0; w < outputWidth; ++w) { + var inW = Math.floor(w / blockSize); + var offsetW = (w % blockSize); + var offsetD = (offsetH * blockSize + offsetW) * outputDepth; + for (var d = 0; d < outputDepth; ++d) { + var inD = d + offsetD; + var inputIdx = inD + inputDepth * (inW + inputWidth * (inH + inputHeight * b)); + result[outputIdx++] = xValues[inputIdx]; + } + } + } + } + return tf.tensor4d(result, [batchSize, outputHeight, outputWidth, outputDepth]); + }; + MathBackendCPU.prototype.broadcastedBinaryOp = function (a, b, dtype, op) { + var newShape = tf.backend_util.assertAndGetBroadcastShape(a.shape, b.shape); + var result = tf.buffer(newShape, dtype); + var aVals = this.readSync(a.dataId); + var bVals = this.readSync(b.dataId); + var aBroadcastDims = tf.backend_util.getBroadcastDims(a.shape, newShape); + var bBroadcastDims = tf.backend_util.getBroadcastDims(b.shape, newShape); + var resVals = result.values; + if (aBroadcastDims.length + bBroadcastDims.length === 0) { + for (var i = 0; i < resVals.length; ++i) { + resVals[i] = op(aVals[i % aVals.length], bVals[i % bVals.length]); + } + } + else { + var aBuf = this.bufferSync(a); + var bBuf = this.bufferSync(b); + var _loop_2 = function (i) { + var loc = result.indexToLoc(i); + var aLoc = 
loc.slice(-a.rank); + aBroadcastDims.forEach(function (d) { return aLoc[d] = 0; }); + var aIndex = aBuf.locToIndex(aLoc); + var bLoc = loc.slice(-b.rank); + bBroadcastDims.forEach(function (d) { return bLoc[d] = 0; }); + var bIndex = bBuf.locToIndex(bLoc); + resVals[i] = op(aVals[aIndex], bVals[bIndex]); + }; + for (var i = 0; i < resVals.length; ++i) { + _loop_2(i); + } + } + return result.toTensor(); + }; + MathBackendCPU.prototype.broadcastedBinaryComplexOp = function (a, b, op) { + var newShape = tf.backend_util.assertAndGetBroadcastShape(a.shape, b.shape); + var realResult = tf.buffer(newShape, 'float32'); + var imagResult = tf.buffer(newShape, 'float32'); + var aVals = this.readSync(a.dataId); + var bVals = this.readSync(b.dataId); + var aBroadcastDims = tf.backend_util.getBroadcastDims(a.shape, newShape); + var bBroadcastDims = tf.backend_util.getBroadcastDims(b.shape, newShape); + var realVals = realResult.values; + var imagVals = imagResult.values; + if (aBroadcastDims.length + bBroadcastDims.length === 0) { + for (var i = 0; i < realVals.length; i++) { + var aIdx = i % aVals.length; + var bIdx = i % bVals.length; + var result = op(aVals[aIdx * 2], aVals[aIdx * 2 + 1], bVals[bIdx * 2], bVals[bIdx * 2 + 1]); + realVals[i] = result.real; + imagVals[i] = result.imag; + } + } + else { + var aRealBuf = this.bufferSync(this.data.get(a.dataId).complexTensors.real); + var bRealBuf = this.bufferSync(this.data.get(b.dataId).complexTensors.real); + var _loop_3 = function (i) { + var loc = realResult.indexToLoc(i); + var aLoc = loc.slice(-a.rank); + aBroadcastDims.forEach(function (d) { return aLoc[d] = 0; }); + var aIndex = aRealBuf.locToIndex(aLoc); + var bLoc = loc.slice(-b.rank); + bBroadcastDims.forEach(function (d) { return bLoc[d] = 0; }); + var bIndex = bRealBuf.locToIndex(bLoc); + var opResult = op(aVals[aIndex * 2], aVals[aIndex * 2 + 1], bVals[bIndex * 2], bVals[bIndex * 2 + 1]); + realVals[i] = opResult.real; + imagVals[i] = opResult.imag; + }; + for (var 
i = 0; i < realVals.length; i++) { + _loop_3(i); + } + } + return this.complex(realResult.toTensor(), imagResult.toTensor()); + }; + MathBackendCPU.prototype.split = function (x, sizeSplits, axis) { + return split(x, sizeSplits, axis); + }; + MathBackendCPU.prototype.dispose = function () { }; + MathBackendCPU.prototype.floatPrecision = function () { + return 32; + }; + /** Returns the smallest representable number. */ + MathBackendCPU.prototype.epsilon = function () { + return _super.prototype.epsilon.call(this); + }; + MathBackendCPU.prototype.cropAndResize = function (images, boxes, boxIndex, cropSize, method, extrapolationValue) { + var _a = images.shape, batch = _a[0], imageHeight = _a[1], imageWidth = _a[2], numChannels = _a[3]; + var numBoxes = boxes.shape[0]; + var cropHeight = cropSize[0], cropWidth = cropSize[1]; + var output = tf.buffer([numBoxes, cropHeight, cropWidth, numChannels], 'float32'); + var boxVals = this.readSync(boxes.dataId); + var boxIndVals = this.readSync(boxIndex.dataId); + var imageVals = this.readSync(images.dataId); + var inStride = images.strides; // to calculate flat indexes into image + var outStride = output.strides; // to calculate flat indexes into output + // Reference implementation + // tslint:disable-next-line:max-line-length + // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/crop_and_resize_op.cc + for (var b = 0; b < numBoxes; b++) { + var startInd = b * 4; + var y1 = boxVals[startInd]; + var x1 = boxVals[startInd + 1]; + var y2 = boxVals[startInd + 2]; + var x2 = boxVals[startInd + 3]; + var bInd = boxIndVals[b]; + if (bInd >= batch) { + continue; + } + var heightScale = (cropHeight > 1) ? + (y2 - y1) * (imageHeight - 1) / (cropHeight - 1) : + 0; + var widthScale = (cropWidth > 1) ? (x2 - x1) * (imageWidth - 1) / (cropWidth - 1) : 0; + for (var y = 0; y < cropHeight; y++) { + var yInd = (cropHeight > 1) ? 
+ y1 * (imageHeight - 1) + y * (heightScale) : + 0.5 * (y1 + y2) * (imageHeight - 1); + if (yInd < 0 || yInd > imageHeight - 1) { + for (var x = 0; x < cropWidth; x++) { + for (var c = 0; c < numChannels; c++) { + var ind = c + x * outStride[2] + y * outStride[1] + b * outStride[0]; + output.values[ind] = extrapolationValue; + } + } + continue; + } + if (method === 'bilinear') { + var topInd = Math.floor(yInd); + var bottomInd = Math.ceil(yInd); + var yLerp = yInd - topInd; + for (var x = 0; x < cropWidth; x++) { + var xInd = (cropWidth > 1) ? + x1 * (imageWidth - 1) + x * widthScale : + 0.5 * (x1 + x2) * (imageWidth - 1); + if (xInd < 0 || xInd > imageWidth - 1) { + for (var c = 0; c < numChannels; c++) { + var ind = c + x * outStride[2] + y * outStride[1] + b * outStride[0]; + output.values[ind] = extrapolationValue; + } + continue; + } + var leftInd = Math.floor(xInd); + var rightInd = Math.ceil(xInd); + var xLerp = xInd - leftInd; + for (var c = 0; c < numChannels; c++) { + var ind = c + leftInd * inStride[2] + topInd * inStride[1] + + bInd * inStride[0]; + var topLeft = imageVals[ind]; + ind = c + rightInd * inStride[2] + topInd * inStride[1] + + bInd * inStride[0]; + var topRight = imageVals[ind]; + ind = c + leftInd * inStride[2] + bottomInd * inStride[1] + + bInd * inStride[0]; + var bottomLeft = imageVals[ind]; + ind = c + rightInd * inStride[2] + bottomInd * inStride[1] + + bInd * inStride[0]; + var bottomRight = imageVals[ind]; + var top_2 = topLeft + (topRight - topLeft) * xLerp; + var bottom = bottomLeft + (bottomRight - bottomLeft) * xLerp; + ind = c + x * outStride[2] + y * outStride[1] + b * outStride[0]; + output.values[ind] = top_2 + ((bottom - top_2) * yLerp); + } + } + } + else { // method == "nearest" + for (var x = 0; x < cropWidth; ++x) { + var xInd = (cropWidth > 1) ? 
+ x1 * (imageWidth - 1) + x * widthScale : + 0.5 * (x1 + x2) * (imageWidth - 1); + if (xInd < 0 || xInd > imageWidth - 1) { + for (var c = 0; c < numChannels; c++) { + var ind = c + x * outStride[2] + y * outStride[1] + b * outStride[0]; + output.values[ind] = extrapolationValue; + } + continue; + } + var closestX = Math.round(xInd); + var closestY = Math.round(yInd); + for (var c = 0; c < numChannels; c++) { + var inInd = c + closestX * inStride[2] + + closestY * inStride[1] + bInd * inStride[0]; + var outInd = c + x * outStride[2] + y * outStride[1] + b * outStride[0]; + output.values[outInd] = imageVals[inInd]; + } + } + } + } + } + return output.toTensor(); + }; + MathBackendCPU.prototype.sparseToDense = function (sparseIndices, sparseValues, outputShape, defaultValue) { + var _a = tf.backend_util.calculateShapes(sparseValues, sparseIndices, outputShape), sliceRank = _a.sliceRank, numUpdates = _a.numUpdates, sliceSize = _a.sliceSize, strides = _a.strides, outputSize = _a.outputSize; + var sumDupeIndices = false; + return this.scatter(sparseIndices, sparseValues, outputShape, outputSize, sliceSize, numUpdates, sliceRank, strides, defaultValue, sumDupeIndices); + }; + MathBackendCPU.prototype.gatherND = function (x, indices) { + var indicesShape = indices.shape; + var sliceRank = indicesShape[indicesShape.length - 1]; + var _a = tf.backend_util.prepareAndValidate(x, indices), resultShape = _a[0], numSlices = _a[1], sliceSize = _a[2], strides = _a[3]; + if (numSlices === 0) { + return tf.tensor([], resultShape, x.dtype); + } + var buffer = new tf.TensorBuffer([numSlices, sliceSize], x.dtype); + var indicesData = this.readSync(indices.dataId); + var xData = this.readSync(x.dataId); + for (var i = 0; i < numSlices; i++) { + var index = []; + var flattenIndex = 0; + for (var j = 0; j < sliceRank; j++) { + var dim = indicesData[i * sliceRank + j]; + flattenIndex += dim * strides[j]; + index.push(dim); + } + if (flattenIndex < 0 || flattenIndex >= x.size / sliceSize) { 
+ throw new Error("Invalid indices: " + index + " does not index into " + x.shape); + } + for (var k = 0; k < sliceSize; k++) { + buffer.values[i * sliceSize + k] = xData[flattenIndex * sliceSize + k]; + } + } + return buffer.toTensor().reshape(resultShape); + }; + MathBackendCPU.prototype.scatterND = function (indices, updates, shape) { + var _a = tf.backend_util.calculateShapes(updates, indices, shape), sliceRank = _a.sliceRank, numUpdates = _a.numUpdates, sliceSize = _a.sliceSize, strides = _a.strides, outputSize = _a.outputSize; + var defaultValue = tf.scalar(0); + var sumDupeIndices = true; + return this.scatter(indices, updates, shape, outputSize, sliceSize, numUpdates, sliceRank, strides, defaultValue, sumDupeIndices); + }; + MathBackendCPU.prototype.fill = function (shape, value, dtype) { + dtype = dtype || tf.util.inferDtype(value); + var values = tf.util.getArrayFromDType(dtype, tf.util.sizeFromShape(shape)); + values.fill(value); + return tf.engine().makeTensor(values, shape, dtype, this); + }; + MathBackendCPU.prototype.onesLike = function (x) { + if (x.dtype === 'string') { + throw new Error('onesLike is not supported for string tensors'); + } + else { + return this.fill(x.shape, 1, x.dtype); + } + }; + MathBackendCPU.prototype.zerosLike = function (x) { + var values = tf.util.getArrayFromDType(x.dtype, tf.util.sizeFromShape(x.shape)); + return this.makeOutput(values, x.shape, x.dtype); + }; + MathBackendCPU.prototype.linspace = function (start, stop, num) { + return tf.backend_util.linspaceImpl(start, stop, num); + }; + MathBackendCPU.prototype.scatter = function (indices, updates, shape, outputSize, sliceSize, numUpdates, sliceRank, strides, defaultValue, sumDupeIndices) { + var flattenShape = [outputSize / sliceSize, sliceSize]; + var indicesData = this.readSync(indices.dataId); + var updatesData = this.readSync(updates.dataId); + if (outputSize === 0) { + return tf.tensor([], shape, updates.dtype); + } + var buffer = new 
tf.TensorBuffer(flattenShape, updates.dtype); + buffer.values.fill(this.readSync(defaultValue.dataId)[0]); + for (var i = 0; i < numUpdates; i++) { + var index = []; + var flattenIndex = 0; + for (var j = 0; j < sliceRank; j++) { + var dim = indicesData[i * sliceRank + j]; + index.push(dim); + flattenIndex += dim * strides[j]; + } + if (flattenIndex < 0 || flattenIndex >= outputSize / sliceSize) { + throw new Error("Invalid indices: " + index + " does not index into " + shape); + } + for (var k = 0; k < sliceSize; k++) { + if (sumDupeIndices) { + buffer.values[flattenIndex * sliceSize + k] += + updatesData[i * sliceSize + k]; + } + else { + buffer.values[flattenIndex * sliceSize + k] = updates.rank === 0 ? + updatesData[0] : + updatesData[i * sliceSize + k]; + } + } + } + return buffer.toTensor().reshape(shape); + }; + return MathBackendCPU; + }(tf.KernelBackend)); + + /** @license See the LICENSE file. */ + // This code is auto-generated, do not modify this file! + var version = '0.0.0'; + + /** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + function createBinaryKernelConfig(name, op) { + return { + kernelName: name, + backendName: 'cpu', + kernelFunc: function (_a) { + var inputs = _a.inputs, backend = _a.backend; + var _b = inputs, a = _b.a, b = _b.b; + var cpuBackend = backend; + assertNotComplex([a, b], name); + var aVals = cpuBackend.data.get(a.dataId).values; + var bVals = cpuBackend.data.get(b.dataId).values; + var _c = op(a.shape, b.shape, aVals, bVals, a.dtype), resultData = _c[0], resultShape = _c[1]; + var dataId = cpuBackend.write(resultData, resultShape, a.dtype); + return { dataId: dataId, shape: resultShape, dtype: a.dtype }; + } + }; + } + function createBinaryKernelImpl(op) { + return function (aShape, bShape, aVals, bVals, dtype) { + var newShape = tf.backend_util.assertAndGetBroadcastShape(aShape, bShape); + var resultRank = newShape.length; + var resultStrides = tf.util.computeStrides(newShape); + var resultSize = tf.util.sizeFromShape(newShape); + var result = tf.util.getTypedArrayFromDType(dtype, resultSize); + var aRank = aShape.length; + var bRank = bShape.length; + var aStrides = tf.util.computeStrides(aShape); + var bStrides = tf.util.computeStrides(bShape); + var aBroadcastDims = tf.backend_util.getBroadcastDims(aShape, newShape); + var bBroadcastDims = tf.backend_util.getBroadcastDims(bShape, newShape); + if (aBroadcastDims.length + bBroadcastDims.length === 0) { + for (var i = 0; i < result.length; ++i) { + result[i] = op(aVals[i % aVals.length], bVals[i % bVals.length]); + } + } + else { + var _loop_1 = function (i) { + var loc = tf.util.indexToLoc(i, resultRank, resultStrides); + var aLoc = loc.slice(-aRank); + aBroadcastDims.forEach(function (d) { return aLoc[d] = 0; }); + var aIndex = tf.util.locToIndex(aLoc, aRank, aStrides); + var bLoc = loc.slice(-bRank); + bBroadcastDims.forEach(function (d) { return bLoc[d] = 0; }); + var bIndex = tf.util.locToIndex(bLoc, bRank, bStrides); + 
result[i] = op(aVals[aIndex], bVals[bIndex]); + }; + for (var i = 0; i < result.length; ++i) { + _loop_1(i); + } + } + return [result, newShape]; + }; + } + + /** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var divImpl = createBinaryKernelImpl(function (a, b) { return a / b; }); + + /** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var divConfig = createBinaryKernelConfig(tf.Div, divImpl); + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var maxConfig = { + kernelName: tf.Max, + backendName: 'cpu', + kernelFunc: function (_a) { + var inputs = _a.inputs, attrs = _a.attrs, backend = _a.backend; + var x = inputs.x; + var reductionIndices = attrs.reductionIndices; + var cpuBackend = backend; + var xShape = x.shape; + var xRank = xShape.length; + var origAxes = tf.util.parseAxisParam(reductionIndices, xShape); + var axes = origAxes; + var permutedAxes = tf.backend_util.getAxesPermutation(axes, xRank); + var xVals = cpuBackend.data.get(x.dataId).values; + if (permutedAxes != null) { + var newShape = new Array(xRank); + for (var i = 0; i < newShape.length; i++) { + newShape[i] = xShape[permutedAxes[i]]; + } + xVals = transposeImpl(xVals, xShape, x.dtype, permutedAxes, newShape); + axes = tf.backend_util.getInnerMostAxes(axes.length, xRank); + xShape = newShape; + } + assertNotComplex(x, 'max'); + tf.backend_util.assertAxesAreInnerMostDims('max', axes, xRank); + var _b = tf.backend_util.computeOutAndReduceShapes(xShape, axes), maxOutShape = _b[0], reduceShape = _b[1]; + var reduceSize = tf.util.sizeFromShape(reduceShape); + var result = maxImpl(xVals, reduceSize, maxOutShape, x.dtype); + var dataId = cpuBackend.write(result, maxOutShape, x.dtype); + return { dataId: dataId, shape: maxOutShape, dtype: x.dtype }; + } + }; + + /** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. 
+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function maxPoolWithArgmaxImpl(xValues, xShape, dtype, includeBatchInIndex, convInfo) { + var strides = tf.util.computeStrides(xShape); + var maxPools = pool(xValues, xShape, dtype, strides, convInfo, 'max'); + var maxPositions = maxPoolPositions(xValues, xShape, dtype, convInfo, true, includeBatchInIndex); + return [maxPools.values, maxPositions.values]; + } + + /** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + var maxPoolWithArgmaxConfig = { + kernelName: tf.MaxPoolWithArgmax, + backendName: 'cpu', + kernelFunc: function (_a) { + var inputs = _a.inputs, attrs = _a.attrs, backend = _a.backend; + var x = inputs.x; + var _b = attrs, filterSize = _b.filterSize, strides = _b.strides, pad = _b.pad, includeBatchInIndex = _b.includeBatchInIndex; + var cpuBackend = backend; + assertNotComplex(x, 'MaxPoolWithArgmax'); + var values = cpuBackend.data.get(x.dataId).values; + var convInfo = tf.backend_util.computePool2DInfo(x.shape, filterSize, strides, [1, 1], pad); + var _c = maxPoolWithArgmaxImpl(values, x.shape, x.dtype, includeBatchInIndex, convInfo), pooled = _c[0], indexes = _c[1]; + var pooledDataId = cpuBackend.write(pooled, convInfo.outShape, x.dtype); + var indexesDataId = cpuBackend.write(indexes, convInfo.outShape, x.dtype); + return [ + { dataId: pooledDataId, shape: convInfo.outShape, dtype: x.dtype }, + { dataId: indexesDataId, shape: convInfo.outShape, dtype: 'int32' } + ]; + } + }; + + /** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + var nonMaxSuppressionV5 = tf.kernel_impls.nonMaxSuppressionV5; + var nonMaxSuppressionV5Config = { + kernelName: tf.NonMaxSuppressionV5, + backendName: 'cpu', + kernelFunc: function (_a) { + var inputs = _a.inputs, backend = _a.backend, attrs = _a.attrs; + var _b = inputs, boxes = _b.boxes, scores = _b.scores; + var _c = attrs, maxOutputSize = _c.maxOutputSize, iouThreshold = _c.iouThreshold, scoreThreshold = _c.scoreThreshold, softNmsSigma = _c.softNmsSigma; + var cpuBackend = backend; + assertNotComplex(boxes, 'NonMaxSuppressionWithScore'); + var boxesVals = cpuBackend.data.get(boxes.dataId).values; + var scoresVals = cpuBackend.data.get(scores.dataId).values; + var maxOutputSizeVal = maxOutputSize; + var iouThresholdVal = iouThreshold; + var scoreThresholdVal = scoreThreshold; + var softNmsSigmaVal = softNmsSigma; + var _d = nonMaxSuppressionV5(boxesVals, scoresVals, maxOutputSizeVal, iouThresholdVal, scoreThresholdVal, softNmsSigmaVal), selectedIndices = _d.selectedIndices, selectedScores = _d.selectedScores; + return [selectedIndices, selectedScores]; + } + }; + + /** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + var squareConfig = { + kernelName: tf.Square, + backendName: 'cpu', + kernelFunc: function (_a) { + var inputs = _a.inputs, backend = _a.backend; + var x = inputs.x; + var cpuBackend = backend; + assertNotComplex(x, 'square'); + var values = cpuBackend.data.get(x.dataId).values; + var newValues = new Float32Array(values.length); + for (var i = 0; i < values.length; ++i) { + var value = values[i]; + newValues[i] = value * value; + } + var dataId = cpuBackend.write(newValues, x.shape, x.dtype); + return { dataId: dataId, shape: x.shape, dtype: x.dtype }; + } + }; + + /** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var squaredDifferenceImpl = createBinaryKernelImpl(function (aVal, bVal) { + var diff = aVal - bVal; + return diff * diff; + }); + var squaredDifferenceConfig = createBinaryKernelConfig(tf.SquaredDifference, squaredDifferenceImpl); + + /** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var transposeConfig = { + kernelName: tf.Transpose, + backendName: 'cpu', + kernelFunc: function (_a) { + var inputs = _a.inputs, attrs = _a.attrs, backend = _a.backend; + var x = inputs.x; + var perm = attrs.perm; + var cpuBackend = backend; + assertNotComplex(x, 'transpose'); + var xRank = x.shape.length; + var newShape = new Array(xRank); + for (var i = 0; i < newShape.length; i++) { + newShape[i] = x.shape[perm[i]]; + } + var values = cpuBackend.data.get(x.dataId).values; + var result = transposeImpl(values, x.shape, x.dtype, perm, newShape); + var dataId = cpuBackend.write(result, newShape, x.dtype); + return { dataId: dataId, shape: newShape, dtype: x.dtype }; + } + }; + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + // List all kernel configs here + var kernelConfigs = [ + nonMaxSuppressionV5Config, squareConfig, squaredDifferenceConfig, divConfig, + transposeConfig, maxPoolWithArgmaxConfig, maxConfig + ]; + for (var _i = 0, kernelConfigs_1 = kernelConfigs; _i < kernelConfigs_1.length; _i++) { + var kernelConfig = kernelConfigs_1[_i]; + tf.registerKernel(kernelConfig); + } + + /** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + // Side effects for default initialization of MathBackendCPU + tf.registerBackend('cpu', function () { return new MathBackendCPU(); }, 1 /* priority */); + + exports.MathBackendCPU = MathBackendCPU; + exports.shared = shared; + exports.version_cpu = version; + + Object.defineProperty(exports, '__esModule', { value: true }); + +}))); +//# sourceMappingURL=tf-backend-cpu.js.map diff --git a/tfjs-core/benchmarks/tf-backend-wasm.js b/tfjs-core/benchmarks/tf-backend-wasm.js new file mode 100644 index 00000000000..21228a6e61b --- /dev/null +++ b/tfjs-core/benchmarks/tf-backend-wasm.js @@ -0,0 +1,2941 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +(function (global, factory) { + typeof exports === 'object' && typeof module !== 'undefined' ? factory(exports, require('@tensorflow/tfjs-core'), require('path'), require('fs')) : + typeof define === 'function' && define.amd ? define(['exports', '@tensorflow/tfjs-core', 'path', 'fs'], factory) : + (global = global || self, factory((global.tf = global.tf || {}, global.tf.wasm = global.tf.wasm || {}), global.tf, global.path, global.fs)); +}(this, (function (exports, tfjsCore, path, fs) { 'use strict'; + + path = path && path.hasOwnProperty('default') ? path['default'] : path; + fs = fs && fs.hasOwnProperty('default') ? fs['default'] : fs; + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + // This enum must align with the enum defined in cc/backend.h. 
+ var CppDType; + (function (CppDType) { + CppDType[CppDType["float32"] = 0] = "float32"; + CppDType[CppDType["int32"] = 1] = "int32"; + CppDType[CppDType["bool"] = 2] = "bool"; + CppDType[CppDType["string"] = 3] = "string"; + CppDType[CppDType["complex64"] = 4] = "complex64"; + })(CppDType || (CppDType = {})); + // Must match enum in cc/fusable_activations.h. + var FusableActivation; + (function (FusableActivation) { + FusableActivation[FusableActivation["linear"] = 0] = "linear"; + FusableActivation[FusableActivation["relu"] = 1] = "relu"; + FusableActivation[FusableActivation["relu6"] = 2] = "relu6"; + FusableActivation[FusableActivation["prelu"] = 3] = "prelu"; + })(FusableActivation || (FusableActivation = {})); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + let wasmFusedMatMul; + function setup(backend) { + wasmFusedMatMul = backend.wasm.cwrap('_FusedMatMul', null /* void */, [ + 'number', + 'array', + 'number', + 'number', + 'array', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number' // out_id + ]); + } + function fusedBatchMatMul(args) { + const { inputs, backend, attrs } = args; + const { a, b, bias, preluActivationWeights } = inputs; + if (a.dtype !== 'float32' || b.dtype !== 'float32') { + throw new Error(`_FusedMatMul for non non-float32 tensors not yet supported.`); + } + const { transposeA, transposeB, activation } = attrs; + const aId = backend.dataIdMap.get(a.dataId).id; + const bId = backend.dataIdMap.get(b.dataId).id; + let biasId = 0; + if (bias != null) { + const biasData = backend.dataIdMap.get(bias.dataId); + if (biasData.shape.length !== 1) { + throw new Error(`_FusedMatMul only supports rank-1 bias but got ` + + `rank ${biasData.shape.length}.`); + } + biasId = biasData.id; + } + const preluActivationWeightsId = preluActivationWeights == null ? + 0 : + backend.dataIdMap.get(preluActivationWeights.dataId).id; + const fusedActivation = FusableActivation[activation]; + if (fusedActivation == null) { + throw new Error(`${activation} activation not yet supported for FusedConv2D ` + + `in the wasm backend.`); + } + const leftDim = transposeA ? a.shape[2] : a.shape[1]; + const rightDim = transposeB ? 
b.shape[1] : b.shape[2]; + const batchDim = a.shape[0]; + const out = backend.makeOutput([batchDim, leftDim, rightDim], a.dtype); + const outId = backend.dataIdMap.get(out.dataId).id; + const aShapeBytes = new Uint8Array(new Int32Array(a.shape).buffer); + const bShapeBytes = new Uint8Array(new Int32Array(b.shape).buffer); + wasmFusedMatMul(aId, aShapeBytes, a.shape.length, bId, bShapeBytes, b.shape.length, transposeA, transposeB, fusedActivation, biasId, preluActivationWeightsId, outId); + return out; + } + tfjsCore.registerKernel({ + kernelName: '_FusedMatMul', + backendName: 'wasm', + setupFunc: setup, + kernelFunc: fusedBatchMatMul + }); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function registerUnaryKernel(kernelName) { + let wasmFunc; + function setupFunc(backend) { + wasmFunc = + backend.wasm.cwrap(kernelName, null /* void */, ['number', 'number']); + } + function kernelFunc(args) { + const { backend, inputs: { x } } = args; + const xId = backend.dataIdMap.get(x.dataId).id; + const out = backend.makeOutput(x.shape, x.dtype); + const outId = backend.dataIdMap.get(out.dataId).id; + // Short-circuit zero-sized tensors. 
+ if (tfjsCore.util.sizeFromShape(out.shape) === 0) { + return out; + } + wasmFunc(xId, outId); + return out; + } + tfjsCore.registerKernel({ kernelName, backendName: 'wasm', setupFunc, kernelFunc }); + } + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + registerUnaryKernel('Abs'); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + function registerBinaryKernel(kernelName, supportsFullBroadcast, dtype) { + let wasmFunc; + function setupFunc(backend) { + wasmFunc = backend.wasm.cwrap(kernelName, null /* void */, [ + 'number', + 'array', + 'number', + 'number', + 'array', + 'number', + 'number', + 'number' // out_id + ]); + } + function kernelFunc(args) { + const { backend, inputs } = args; + const { a, b } = inputs; + const aId = backend.dataIdMap.get(a.dataId).id; + const bId = backend.dataIdMap.get(b.dataId).id; + const outputType = dtype != null ? dtype : a.dtype; + const newShape = tfjsCore.backend_util.assertAndGetBroadcastShape(a.shape, b.shape); + const out = backend.makeOutput(newShape, outputType); + // Short-circuit zero-sized tensors. + if (tfjsCore.util.sizeFromShape(newShape) === 0) { + return out; + } + const aShapeBytes = new Uint8Array(new Int32Array(a.shape).buffer); + const bShapeBytes = new Uint8Array(new Int32Array(b.shape).buffer); + const outId = backend.dataIdMap.get(out.dataId).id; + const kernelFunc = () => wasmFunc(aId, aShapeBytes, a.shape.length, bId, bShapeBytes, b.shape.length, CppDType[a.dtype], outId); + if (supportsFullBroadcast) { + kernelFunc(); + return out; + } + const aBroadcastDims = tfjsCore.backend_util.getBroadcastDims(a.shape, newShape); + const bBroadcastDims = tfjsCore.backend_util.getBroadcastDims(b.shape, newShape); + const loopsOverAllOfA = aBroadcastDims.every((v, i) => v === i); + const loopsOverAllOfB = bBroadcastDims.every((v, i) => v === i); + if (loopsOverAllOfA && loopsOverAllOfB) { + kernelFunc(); + return out; + } + else { + throw new Error(`Broadcasting along outer dims is not yet ` + + `supported for ${kernelName}.`); + } + } + tfjsCore.registerKernel({ kernelName, backendName: 'wasm', setupFunc, kernelFunc }); + } + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// 'Add' supports broadcasting along outer dims in the wasm binary.
const supportsFullBroadcast = true;
registerBinaryKernel('Add', supportsFullBroadcast);

/**
 * @license
 * Copyright 2019 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// AddN: sums a variable number of input tensors.
let wasmFunc;
function setupFunc(backend) {
    wasmFunc = backend.wasm.cwrap('AddN', null /* void */, [
        'array',
        'number',
        'number',
        'number',
    ]);
}
function addn(args) {
    const { inputs, backend } = args;
    // Output takes inputs[0]'s shape/dtype; presumably all inputs match --
    // TODO confirm the op layer enforces this (not checked here).
    const out = backend.makeOutput(inputs[0].shape, inputs[0].dtype);
    // Short-circuit zero-sized tensors.
    if (tfjsCore.util.sizeFromShape(out.shape) === 0) {
        return out;
    }
    // The list of input buffer ids crosses into wasm as raw int32 bytes.
    const inputIds = inputs.map(x => backend.dataIdMap.get(x.dataId).id);
    const inputIdsBytes = new Uint8Array(new Int32Array(inputIds).buffer);
    const outId = backend.dataIdMap.get(out.dataId).id;
    wasmFunc(inputIdsBytes, inputIds.length, CppDType[out.dtype], outId);
    return out;
}
tfjsCore.registerKernel({
    kernelName: 'AddN',
    backendName: 'wasm',
    setupFunc,
    kernelFunc: addn,
});

/**
 * @license
 * Copyright 2019 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
let wasmFunc$1;
function setup$1(backend) {
    wasmFunc$1 = backend.wasm.cwrap('ArgMax', null /* void */, [
        'number',
        'number',
        'number',
        'number',
        'number' // out_id
    ]);
}
// ArgMax: index of the max value along `axis`, output dtype int32.
function argmax(args) {
    const { inputs: { x }, backend, attrs: { axis } } = args;
    // NOTE(review): the output shape drops the LAST dim, while innerSize reads
    // x.shape[axis]. These only agree when axis is the last axis -- presumably
    // guaranteed upstream; verify against callers.
    const outShape = x.shape.slice(0, -1);
    const out = backend.makeOutput(outShape, 'int32');
    const xId = backend.dataIdMap.get(x.dataId).id;
    const outId = backend.dataIdMap.get(out.dataId).id;
    const outerSize = tfjsCore.util.sizeFromShape(out.shape);
    const innerSize = x.shape[axis];
    wasmFunc$1(xId, CppDType[x.dtype], outerSize, innerSize, outId);
    return out;
}
tfjsCore.registerKernel({
    kernelName: 'ArgMax',
    backendName: 'wasm',
    kernelFunc: argmax,
    setupFunc: setup$1
});

/**
 * @license
 * Copyright 2019 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
+ * ============================================================================= + */ + let wasmAvgPool; + function setup$2(backend) { + wasmAvgPool = backend.wasm.cwrap('AvgPool', null /* void */, [ + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + ]); + } + function avgPool(args) { + const { inputs, attrs, backend } = args; + const x = inputs.x; + const xId = backend.dataIdMap.get(x.dataId).id; + const { filterSize, strides, pad, dimRoundingMode } = attrs; + const convInfo = tfjsCore.backend_util.computePool2DInfo(x.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode); + const filterHeight = convInfo.filterHeight; + const filterWidth = convInfo.filterWidth; + const padTop = convInfo.padInfo.top; + const padRight = convInfo.padInfo.right; + const padBottom = convInfo.padInfo.bottom; + const padLeft = convInfo.padInfo.left; + const strideHeight = convInfo.strideHeight; + const strideWidth = convInfo.strideWidth; + const channels = convInfo.inChannels; + if (convInfo.dataFormat !== 'channelsLast') { + throw new Error(`wasm backend does not support dataFormat:'` + + `${convInfo.dataFormat}'. Please use 'channelsLast'.`); + } + if (convInfo.dilationWidth !== 1 || convInfo.dilationHeight !== 1) { + throw new Error(`was backend only supports average pooling with dilation = [1, 1], ` + + `got [${convInfo.dilationHeight}, ${convInfo.dilationWidth}].`); + } + const out = backend.makeOutput(convInfo.outShape, 'float32'); + const outId = backend.dataIdMap.get(out.dataId).id; + wasmAvgPool(xId, x.shape[0], x.shape[1], x.shape[2], filterHeight, filterWidth, padTop, padRight, padBottom, padLeft, strideHeight, strideWidth, channels, outId); + return out; + } + tfjsCore.registerKernel({ + kernelName: 'AvgPool', + backendName: 'wasm', + setupFunc: setup$2, + kernelFunc: avgPool + }); + + /** + * @license + * Copyright 2019 Google Inc. 
All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + let wasmBatchMatMul; + function setup$3(backend) { + wasmBatchMatMul = backend.wasm.cwrap('BatchMatMul', null /* void */, [ + 'number', + 'array', + 'number', + 'number', + 'array', + 'number', + 'number', + 'number', + 'number' // out_id + ]); + } + function batchMatMul(args) { + const { inputs, backend, attrs } = args; + const { a, b } = inputs; + if (a.dtype !== 'float32' || b.dtype !== 'float32') { + throw new Error(`BatchMatMul for non non-float32 tensors not yet supported.`); + } + const { transposeA, transposeB } = attrs; + const aId = backend.dataIdMap.get(a.dataId).id; + const bId = backend.dataIdMap.get(b.dataId).id; + const leftDim = transposeA ? a.shape[2] : a.shape[1]; + const rightDim = transposeB ? 
b.shape[1] : b.shape[2]; + const batchDim = a.shape[0]; + const out = backend.makeOutput([batchDim, leftDim, rightDim], a.dtype); + const outId = backend.dataIdMap.get(out.dataId).id; + const aShapeBytes = new Uint8Array(new Int32Array(a.shape).buffer); + const bShapeBytes = new Uint8Array(new Int32Array(b.shape).buffer); + wasmBatchMatMul(aId, aShapeBytes, a.shape.length, bId, bShapeBytes, b.shape.length, transposeA, transposeB, outId); + return out; + } + tfjsCore.registerKernel({ + kernelName: 'BatchMatMul', + backendName: 'wasm', + setupFunc: setup$3, + kernelFunc: batchMatMul + }); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function cast(args) { + const { inputs: { x }, attrs: { dtype }, backend } = args; + const out = backend.makeOutput(x.shape, dtype); + const inVals = backend.typedArrayFromHeap(x); + const outVals = backend.typedArrayFromHeap(out); + outVals.set(inVals); + return out; + } + tfjsCore.registerKernel({ + kernelName: 'Cast', + backendName: 'wasm', + kernelFunc: cast, + }); + + /** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + let wasmClip; + function setup$4(backend) { + wasmClip = backend.wasm.cwrap('ClipByValue', null /* void */, [ + 'number', + 'number', + 'number', + 'number' // out_id + ]); + } + function clip(args) { + const { inputs, backend, attrs } = args; + const { x } = inputs; + const { min, max } = attrs; + const xId = backend.dataIdMap.get(x.dataId).id; + const out = backend.makeOutput(x.shape, 'float32'); + const outId = backend.dataIdMap.get(out.dataId).id; + wasmClip(xId, min, max, outId); + return out; + } + tfjsCore.registerKernel({ + kernelName: 'ClipByValue', + backendName: 'wasm', + setupFunc: setup$4, + kernelFunc: clip + }); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
 * =============================================================================
 */
// Concat: joins tensors along `axis` with a pure-JS copy (no wasm call).
function concat(args) {
    const { inputs, backend, attrs: { axis } } = args;
    const outShape = tfjsCore.backend_util.computeOutShape(inputs.map(t => t.shape), axis);
    const out = backend.makeOutput(outShape, inputs[0].dtype);
    // Dims before `axis` collapse into a single batch dim; everything from
    // `axis` onward is one contiguous run per input.
    const batchDim = tfjsCore.util.sizeFromShape(inputs[0].shape.slice(0, axis));
    let sumInnerDims = 0;
    const innerDims = inputs.map(input => {
        const innerDim = tfjsCore.util.sizeFromShape(input.shape.slice(axis));
        sumInnerDims += innerDim;
        return innerDim;
    });
    const inVals = inputs.map(input => backend.typedArrayFromHeap(input));
    const outVals = backend.typedArrayFromHeap(out);
    // For each batch row, copy every input's contiguous run back-to-back.
    for (let b = 0; b < batchDim; b++) {
        let outOffset = b * sumInnerDims;
        for (let i = 0; i < inVals.length; i++) {
            const innerDim = innerDims[i];
            const inOffset = b * innerDim;
            const vals = inVals[i].subarray(inOffset, inOffset + innerDim);
            outVals.set(vals, outOffset);
            outOffset += innerDim;
        }
    }
    return out;
}
tfjsCore.registerKernel({
    kernelName: 'Concat',
    backendName: 'wasm',
    kernelFunc: concat,
});

/**
 * @license
 * Copyright 2019 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
let wasmConv2d;
function setup$5(backend) {
    wasmConv2d = backend.wasm.cwrap('Conv2D', null /* void */, [
        'number',
        'number',
        'number',
        'number',
        'number',
        'number',
        'number',
        'number',
        'number',
        'number',
        'number',
        'number',
        'number',
        'number',
        'number',
        'number',
        'number',
        'number',
        'number',
    ]);
}
// Conv2D kernel: NHWC 2-D convolution, float32 output; 'channelsLast' only.
function conv2d(args) {
    const { inputs, attrs, backend } = args;
    const { x, filter } = inputs;
    const xId = backend.dataIdMap.get(x.dataId).id;
    const filterId = backend.dataIdMap.get(filter.dataId).id;
    const { strides, dilations, pad, dimRoundingMode, dataFormat } = attrs;
    const $dataFormat = tfjsCore.backend_util.convertConv2DDataFormat(dataFormat);
    const convInfo = tfjsCore.backend_util.computeConv2DInfo(x.shape, filter.shape, strides, dilations, pad, dimRoundingMode, false, $dataFormat);
    const filterHeight = convInfo.filterHeight;
    const filterWidth = convInfo.filterWidth;
    const padTop = convInfo.padInfo.top;
    const padRight = convInfo.padInfo.right;
    const padBottom = convInfo.padInfo.bottom;
    const padLeft = convInfo.padInfo.left;
    const dilationHeight = convInfo.dilationHeight;
    const dilationWidth = convInfo.dilationWidth;
    const strideHeight = convInfo.strideHeight;
    const strideWidth = convInfo.strideWidth;
    const inputChannels = convInfo.inChannels;
    const outputChannels = convInfo.outChannels;
    const isSamePad = convInfo.padInfo.type === 'SAME' ? 1 : 0;
    if (convInfo.dataFormat !== 'channelsLast') {
        throw new Error(`wasm backend Conv2D does not support dataFormat:'` +
            `${convInfo.dataFormat}'. Please use 'channelsLast'.`);
    }
    const out = backend.makeOutput(convInfo.outShape, 'float32');
    const outId = backend.dataIdMap.get(out.dataId).id;
    wasmConv2d(xId, x.shape[0], x.shape[1], x.shape[2], filterId, filterHeight, filterWidth, padTop, padRight, padBottom, padLeft, isSamePad, dilationHeight, dilationWidth, strideHeight, strideWidth, inputChannels, outputChannels, outId);
    return out;
}
tfjsCore.registerKernel({
    kernelName: 'Conv2D',
    backendName: 'wasm',
    setupFunc: setup$5,
    kernelFunc: conv2d
});

/**
 * @license
 * Copyright 2019 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
registerUnaryKernel('Cos');

/**
 * @license
 * Copyright 2019 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// Must match enum in CropAndResize.cc
var InterpolationMethod;
(function (InterpolationMethod) {
    InterpolationMethod[InterpolationMethod["bilinear"] = 0] = "bilinear";
    InterpolationMethod[InterpolationMethod["nearest"] = 1] = "nearest";
})(InterpolationMethod || (InterpolationMethod = {}));
let wasmCropAndResize;
function setup$6(backend) {
    wasmCropAndResize = backend.wasm.cwrap('CropAndResize', null /*void*/, [
        'number',
        'number',
        'number',
        'number',
        'array',
        'number',
        'number',
        'number',
        'number',
        'number' // out id
    ]);
}
// CropAndResize: crops boxes out of `images` and resizes each crop to
// `cropSize`. Non-float32 images are cast to float32 first (and the
// temporary is disposed afterwards).
function cropAndResize(args) {
    const { backend, inputs, attrs } = args;
    const { method, extrapolationValue, cropSize } = attrs;
    const { images, boxes, boxInd } = inputs;
    const numBoxes = boxes.shape[0];
    const [cropHeight, cropWidth] = cropSize;
    const outShape = [numBoxes, cropHeight, cropWidth, images.shape[3]];
    let imagesData = backend.dataIdMap.get(images.dataId);
    let castedData;
    if (images.dtype !== 'float32') {
        // The wasm kernel only operates on float32 data.
        castedData =
            cast({ backend, inputs: { x: images }, attrs: { dtype: 'float32' } });
        imagesData = backend.dataIdMap.get(castedData.dataId);
    }
    const imagesId = imagesData.id;
    const boxesId = backend.dataIdMap.get(boxes.dataId).id;
    const boxIndId = backend.dataIdMap.get(boxInd.dataId).id;
    const out = backend.makeOutput(outShape, 'float32');
    const outId = backend.dataIdMap.get(out.dataId).id;
    const imagesShapeBytes = new Uint8Array(new Int32Array(images.shape).buffer);
    wasmCropAndResize(imagesId, boxesId, boxIndId, numBoxes, imagesShapeBytes, cropHeight, cropWidth, InterpolationMethod[method], extrapolationValue, outId);
    if (castedData != null) {
        // Free the float32 temporary created above.
        backend.disposeData(castedData.dataId);
    }
    return out;
}
tfjsCore.registerKernel({
    kernelName: 'CropAndResize',
    backendName: 'wasm',
    setupFunc: setup$6,
    kernelFunc: cropAndResize
});

/**
 * @license
 * Copyright 2019 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
let wasmDepthwiseConv2d;
function setup$7(backend) {
    wasmDepthwiseConv2d =
        backend.wasm.cwrap('DepthwiseConv2dNative', null /* void */, [
            'number',
            'number',
            'number',
            'number',
            'number',
            'number',
            'number',
            'number',
            'number',
            'number',
            'number',
            'number',
            'number',
            'number',
            'number',
            'number',
            'number',
            'number',
            'number',
        ]);
}
// DepthwiseConv2dNative: NHWC depthwise convolution, float32 output;
// 'channelsLast' only.
function depthwiseConv2d(args) {
    const { inputs, attrs, backend } = args;
    const { x, filter } = inputs;
    const xId = backend.dataIdMap.get(x.dataId).id;
    const filterId = backend.dataIdMap.get(filter.dataId).id;
    const { strides, dilations, pad, dimRoundingMode } = attrs;
    // Missing dilations default to [1, 1].
    const $dilations = dilations == null ? [1, 1] : dilations;
    const convInfo = tfjsCore.backend_util.computeConv2DInfo(x.shape, filter.shape, strides, $dilations, pad, dimRoundingMode, true /* depthwise */);
    const filterHeight = convInfo.filterHeight;
    const filterWidth = convInfo.filterWidth;
    const padTop = convInfo.padInfo.top;
    const padRight = convInfo.padInfo.right;
    const padBottom = convInfo.padInfo.bottom;
    const padLeft = convInfo.padInfo.left;
    const dilationHeight = convInfo.dilationHeight;
    const dilationWidth = convInfo.dilationWidth;
    const strideHeight = convInfo.strideHeight;
    const strideWidth = convInfo.strideWidth;
    const inputChannels = convInfo.inChannels;
    const outputChannels = convInfo.outChannels;
    const isSamePad = convInfo.padInfo.type === 'SAME' ? 1 : 0;
    if (convInfo.dataFormat !== 'channelsLast') {
        throw new Error(`wasm backend DepthwiseConv2dNative does not support dataFormat:'` +
            `${convInfo.dataFormat}'. Please use 'channelsLast'.`);
    }
    const out = backend.makeOutput(convInfo.outShape, 'float32');
    const outId = backend.dataIdMap.get(out.dataId).id;
    wasmDepthwiseConv2d(xId, x.shape[0], x.shape[1], x.shape[2], filterId, filterHeight, filterWidth, padTop, padRight, padBottom, padLeft, isSamePad, dilationHeight, dilationWidth, strideHeight, strideWidth, inputChannels, outputChannels, outId);
    return out;
}
tfjsCore.registerKernel({
    kernelName: 'DepthwiseConv2dNative',
    backendName: 'wasm',
    setupFunc: setup$7,
    kernelFunc: depthwiseConv2d
});

/**
 * @license
 * Copyright 2019 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// 'Div' does not support broadcasting along outer dims in the wasm binary.
const supportsFullBroadcast$1 = false;
registerBinaryKernel('Div', supportsFullBroadcast$1);

/**
 * @license
 * Copyright 2019 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
registerUnaryKernel('Exp');

/**
 * @license
 * Copyright 2019 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
const supportsFullBroadcast$2 = false;
registerBinaryKernel('FloorDiv', supportsFullBroadcast$2);

/**
 * @license
 * Copyright 2019 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
let wasmBatchNorm;
function setup$8(backend) {
    wasmBatchNorm = backend.wasm.cwrap('FusedBatchNorm', null /* void */, ['number', 'number', 'number', 'number', 'number', 'number', 'number']);
}
// FusedBatchNorm: (x - mean) / sqrt(variance + eps) * scale + offset.
// Missing offset/scale are signalled to wasm with id 0.
function fusedBatchNorm(args) {
    const { backend, inputs, attrs } = args;
    const { varianceEpsilon } = attrs;
    const { x, mean, variance, offset, scale } = inputs;
    const xId = backend.dataIdMap.get(x.dataId).id;
    const meanId = backend.dataIdMap.get(mean.dataId).id;
    const varianceId = backend.dataIdMap.get(variance.dataId).id;
    const offsetId = offset != null ? backend.dataIdMap.get(offset.dataId).id : 0;
    const scaleId = scale != null ? backend.dataIdMap.get(scale.dataId).id : 0;
    const out = backend.makeOutput(x.shape, x.dtype);
    // Short-circuit zero-sized tensors.
    if (tfjsCore.util.sizeFromShape(x.shape) === 0) {
        return out;
    }
    const outId = backend.dataIdMap.get(out.dataId).id;
    wasmBatchNorm(xId, meanId, varianceId, offsetId, scaleId, varianceEpsilon, outId);
    return out;
}
tfjsCore.registerKernel({
    kernelName: 'FusedBatchNorm',
    backendName: 'wasm',
    setupFunc: setup$8,
    kernelFunc: fusedBatchNorm
});

/**
 * @license
 * Copyright 2019 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
let wasmFusedConv2d;
function setup$9(backend) {
    wasmFusedConv2d = backend.wasm.cwrap('FusedConv2D', null /* void */, [
        'number',
        'number',
        'number',
        'number',
        'number',
        'number',
        'number',
        'number',
        'number',
        'number',
        'number',
        'number',
        'number',
        'number',
        'number',
        'number',
        'number',
        'number',
        'number',
        'number',
        'number',
        'number',
    ]);
}
// FusedConv2D: Conv2D fused with bias-add and an activation in one wasm
// call. NOTE: unlike the plain Conv2D kernel, `convInfo` arrives
// precomputed in attrs here.
function fusedConv2d(args) {
    const { inputs, attrs, backend } = args;
    const { convInfo, activation } = attrs;
    // FusableActivation is defined earlier in this bundle (outside this chunk).
    const fusedActivation = FusableActivation[activation];
    if (fusedActivation == null) {
        throw new Error(`${activation} activation not yet supported for FusedConv2D ` +
            `in the wasm backend.`);
    }
    const { x, filter, bias, preluActivationWeights } = inputs;
    const xId = backend.dataIdMap.get(x.dataId).id;
    const filterId = backend.dataIdMap.get(filter.dataId).id;
    const outputChannels = convInfo.outChannels;
    // Absent bias is signalled to wasm with id 0.
    let biasId = 0;
    if (bias != null) {
        const biasData = backend.dataIdMap.get(bias.dataId);
        if (biasData.shape.length !== 1) {
            throw new Error(`FusedConv2D only supports rank-1 bias but got ` +
                `rank ${biasData.shape.length}.`);
        }
        if (biasData.shape[0] !== outputChannels) {
            throw new Error(`FusedConv2D bias shape (${biasData.shape}) does not ` +
                `match the number of output channels (${outputChannels})`);
        }
        biasId = biasData.id;
    }
    const filterHeight = convInfo.filterHeight;
    const filterWidth = convInfo.filterWidth;
    const padTop = convInfo.padInfo.top;
    const padRight = convInfo.padInfo.right;
    const padBottom = convInfo.padInfo.bottom;
    const padLeft = convInfo.padInfo.left;
    const dilationHeight = convInfo.dilationHeight;
    const dilationWidth = convInfo.dilationWidth;
    const strideHeight = convInfo.strideHeight;
    const strideWidth = convInfo.strideWidth;
    const inputChannels = convInfo.inChannels;
    const isSamePad = convInfo.padInfo.type === 'SAME' ? 1 : 0;
    const batchSize = convInfo.batchSize;
    const inHeight = convInfo.inHeight;
    const inWidth = convInfo.inWidth;
    if (convInfo.dataFormat !== 'channelsLast') {
        throw new Error(`wasm backend FusedConv2D does not support dataFormat:'` +
            `${convInfo.dataFormat}'. Please use 'channelsLast'.`);
    }
    const out = backend.makeOutput(convInfo.outShape, 'float32');
    const outId = backend.dataIdMap.get(out.dataId).id;
    const preluActivationWeightsId = preluActivationWeights == null ?
        0 :
        backend.dataIdMap.get(preluActivationWeights.dataId).id;
    wasmFusedConv2d(xId, batchSize, inHeight, inWidth, filterId, filterHeight, filterWidth, biasId, padTop, padRight, padBottom, padLeft, isSamePad, dilationHeight, dilationWidth, strideHeight, strideWidth, inputChannels, outputChannels, fusedActivation, preluActivationWeightsId, outId);
    return out;
}
tfjsCore.registerKernel({
    kernelName: 'FusedConv2D',
    backendName: 'wasm',
    setupFunc: setup$9,
    kernelFunc: fusedConv2d
});

/**
 * @license
 * Copyright 2019 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
let wasmFusedDepthwiseConv2d;
function setup$a(backend) {
    wasmFusedDepthwiseConv2d =
        backend.wasm.cwrap('FusedDepthwiseConv2D', null /* void */, [
            'number',
            'number',
            'number',
            'number',
            'number',
            'number',
            'number',
            'number',
            'number',
            'number',
            'number',
            'number',
            'number',
            'number',
            'number',
            'number',
            'number',
            'number',
            'number',
            'number',
            'number',
            'number',
        ]);
}
// FusedDepthwiseConv2D: depthwise convolution fused with bias-add and an
// activation in one wasm call; `convInfo` arrives precomputed in attrs.
function fusedDepthwiseConv2d(args) {
    const { inputs, attrs, backend } = args;
    const { convInfo, activation } = attrs;
    // FusableActivation is defined earlier in this bundle (outside this chunk).
    const fusedActivation = FusableActivation[activation];
    if (fusedActivation == null) {
        throw new Error(`${activation} activation not yet supported for FusedDepthwiseConv2D ` +
            `in the wasm backend.`);
    }
    const { x, filter, bias, preluActivationWeights } = inputs;
    const xId = backend.dataIdMap.get(x.dataId).id;
    const filterId = backend.dataIdMap.get(filter.dataId).id;
    const outputChannels = convInfo.outChannels;
    // Absent bias is signalled to wasm with id 0.
    let biasId = 0;
    if (bias != null) {
        const biasData = backend.dataIdMap.get(bias.dataId);
        if (biasData.shape.length !== 1) {
            throw new Error(`FusedDepthwiseConv2D only supports rank-1 bias but got ` +
                `rank ${biasData.shape.length}.`);
        }
        if (biasData.shape[0] !== outputChannels) {
            throw new Error(`FusedDepthwiseConv2D bias shape (${biasData.shape}) does not ` +
                `match the number of output channels (${outputChannels})`);
        }
        biasId = biasData.id;
    }
    const filterHeight = convInfo.filterHeight;
    const filterWidth = convInfo.filterWidth;
    const padTop = convInfo.padInfo.top;
    const padRight = convInfo.padInfo.right;
    const padBottom = convInfo.padInfo.bottom;
    const padLeft = convInfo.padInfo.left;
    const dilationHeight = convInfo.dilationHeight;
    const dilationWidth = convInfo.dilationWidth;
    const strideHeight = convInfo.strideHeight;
    const strideWidth = convInfo.strideWidth;
    const inputChannels = convInfo.inChannels;
    const isSamePad = convInfo.padInfo.type === 'SAME' ? 1 : 0;
    const batchSize = convInfo.batchSize;
    const inHeight = convInfo.inHeight;
    const inWidth = convInfo.inWidth;
    if (convInfo.dataFormat !== 'channelsLast') {
        throw new Error(`wasm backend FusedDepthwiseConv2D does not support dataFormat:'` +
            `${convInfo.dataFormat}'. Please use 'channelsLast'.`);
    }
    const out = backend.makeOutput(convInfo.outShape, 'float32');
    const outId = backend.dataIdMap.get(out.dataId).id;
    const preluActivationWeightsId = preluActivationWeights == null ?
        0 :
        backend.dataIdMap.get(preluActivationWeights.dataId).id;
    wasmFusedDepthwiseConv2d(xId, batchSize, inHeight, inWidth, filterId, filterHeight, filterWidth, biasId, padTop, padRight, padBottom, padLeft, isSamePad, dilationHeight, dilationWidth, strideHeight, strideWidth, inputChannels, outputChannels, fusedActivation, preluActivationWeightsId, outId);
    return out;
}
tfjsCore.registerKernel({
    kernelName: 'FusedDepthwiseConv2D',
    backendName: 'wasm',
    setupFunc: setup$a,
    kernelFunc: fusedDepthwiseConv2d
});

/**
 * @license
 * Copyright 2019 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// Gather: selects slices of `x` along `axis` at positions given by `indices`.
let wasmGather;
// Binds the C++ `Gather` symbol once the WASM module is ready.
function setup$b(backend) {
    wasmGather = backend.wasm.cwrap('Gather', null /*void*/, [
        'number', // xId
        'number', // dtype (CppDType code)
        'array', // x strides (raw bytes of an Int32Array)
        'number', // stridesSize
        'number', // indicesId
        'number', // axis
        'array', // out strides (raw bytes of an Int32Array)
        'number' // outId
    ]);
}
// Kernel for 'Gather'. The output shape is x.shape with the `axis` dimension
// replaced by the total number of indices.
function gather(args) {
    const { backend, inputs, attrs } = args;
    const { x, indices } = inputs;
    const { axis } = attrs;
    const newShape = x.shape.slice();
    newShape[axis] = tfjsCore.util.sizeFromShape(indices.shape);
    const stridesSize = x.shape.length - 1;
    const out = backend.makeOutput(newShape, x.dtype);
    // Empty input: the output is already empty; skip the WASM call.
    if (tfjsCore.util.sizeFromShape(x.shape) === 0) {
        return out;
    }
    const xData = backend.dataIdMap.get(x.dataId);
    const xId = xData.id;
    const indicesData = backend.dataIdMap.get(indices.dataId);
    const indicesId = indicesData.id;
    const outId = backend.dataIdMap.get(out.dataId).id;
    // Strides are marshalled as the raw bytes of Int32Arrays ('array' params).
    const xStridesBytes = new Uint8Array(new Int32Array(tfjsCore.util.computeStrides(x.shape)).buffer);
    const outStridesBytes = new Uint8Array(new Int32Array(tfjsCore.util.computeStrides(newShape)).buffer);
    wasmGather(xId, CppDType[x.dtype], xStridesBytes, stridesSize, indicesId, axis, outStridesBytes, outId);
    return out;
}
tfjsCore.registerKernel({
    kernelName: 'Gather',
    backendName: 'wasm',
    setupFunc: setup$b,
    kernelFunc: gather
});

/**
 * @license
 * Copyright 2019 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// GatherNd: gathers slices from `x` using multidimensional `indices`.
let wasmGatherNd;
// Binds the C++ `GatherNd` symbol once the WASM module is ready.
function setup$c(backend) {
    wasmGatherNd = backend.wasm.cwrap('GatherNd', null /*void*/, [
        'number', // xId
        'number', // dtype (CppDType code)
        'number', // indicesId
        'number', // numSlices
        'number', // sliceRank
        'number', // sliceSize
        'array', // strides (raw bytes of an Int32Array)
        'number' // outId
    ]);
}
// Kernel for 'GatherNd'. Shape/stride bookkeeping is delegated to
// gather_util.prepareAndValidate from tfjs-core.
function gatherNd(args) {
    const { backend, inputs } = args;
    const { x, indices } = inputs;
    const [resultShape, numSlices, sliceSize, strides] = tfjsCore.gather_util.prepareAndValidate(x, indices);
    const out = backend.makeOutput(resultShape, x.dtype);
    // No slices selected: return the (empty) output without calling into WASM.
    if (numSlices === 0) {
        return out;
    }
    // The innermost dimension of `indices` gives the rank of each index tuple.
    const indicesShape = indices.shape;
    const sliceRank = indicesShape[indicesShape.length - 1];
    const xData = backend.dataIdMap.get(x.dataId);
    const xId = xData.id;
    const indicesData = backend.dataIdMap.get(indices.dataId);
    const indicesId = indicesData.id;
    const stridesBytes = new Uint8Array(new Int32Array(strides).buffer);
    const outId = backend.dataIdMap.get(out.dataId).id;
    wasmGatherNd(xId, CppDType[x.dtype], indicesId, numSlices, sliceRank, sliceSize, stridesBytes, outId);
    return out;
}
tfjsCore.registerKernel({
    kernelName: 'GatherNd',
    backendName: 'wasm',
    setupFunc: setup$c,
    kernelFunc: gatherNd
});

/**
 * @license
 * Copyright 2019 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// Registers the 'Greater' binary kernel; output dtype forced to 'bool'.
// The flag tells the shared helper that full broadcasting is not supported.
const supportsFullBroadcast$3 = false;
registerBinaryKernel('Greater', supportsFullBroadcast$3, 'bool');

/**
 * @license
 * Copyright 2019 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// Registers the 'GreaterEqual' binary kernel; output dtype forced to 'bool'.
const supportsFullBroadcast$4 = false;
registerBinaryKernel('GreaterEqual', supportsFullBroadcast$4, 'bool');

/**
 * @license
 * Copyright 2019 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// Registers the 'Less' binary kernel; output dtype forced to 'bool'.
const supportsFullBroadcast$5 = false;
registerBinaryKernel('Less', supportsFullBroadcast$5, 'bool');

/**
 * @license
 * Copyright 2019 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// Registers the 'LessEqual' binary kernel; output dtype forced to 'bool'.
const supportsFullBroadcast$6 = false;
registerBinaryKernel('LessEqual', supportsFullBroadcast$6, 'bool');

/**
 * @license
 * Copyright 2019 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// Registers the 'Log' unary kernel.
registerUnaryKernel('Log');

/**
 * @license
 * Copyright 2019 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// Registers the 'LogicalAnd' binary kernel; output dtype forced to 'bool'.
const supportsFullBroadcast$7 = false;
registerBinaryKernel('LogicalAnd', supportsFullBroadcast$7, 'bool');

/**
 * @license
 * Copyright 2019 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
+ * ============================================================================= + */ + let wasmMax; + function setup$d(backend) { + wasmMax = + backend.wasm.cwrap('Max', null /*void*/, ['number, number, number']); + } + function max(args) { + const { backend, inputs, attrs } = args; + const { reductionIndices } = attrs; + const { x } = inputs; + const xId = backend.dataIdMap.get(x.dataId).id; + const origAxes = tfjsCore.util.parseAxisParam(reductionIndices, x.shape); + tfjsCore.backend_util.assertAxesAreInnerMostDims('max', origAxes, x.shape.length); + const [outShape, reduceShape] = tfjsCore.backend_util.computeOutAndReduceShapes(x.shape, origAxes); + const reduceSize = tfjsCore.util.sizeFromShape(reduceShape); + const out = backend.makeOutput(outShape, x.dtype); + if (tfjsCore.util.sizeFromShape(x.shape) === 0) { + return out; + } + const outId = backend.dataIdMap.get(out.dataId).id; + wasmMax(xId, reduceSize, outId); + return out; + } + tfjsCore.registerKernel({ kernelName: tfjsCore.Max, backendName: 'wasm', setupFunc: setup$d, kernelFunc: max }); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + const supportsFullBroadcast$8 = false; + registerBinaryKernel('Maximum', supportsFullBroadcast$8); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// MaxPool: 2D max pooling delegated to the WASM kernel.
let wasmMaxPool;
// Binds the C++ `MaxPool` symbol once the WASM module is ready.
// One type entry per argument, in the order of the call site below.
function setup$e(backend) {
    wasmMaxPool = backend.wasm.cwrap('MaxPool', null /* void */, [
        'number', // xId
        'number', // batchSize (x.shape[0])
        'number', // inHeight (x.shape[1])
        'number', // inWidth (x.shape[2])
        'number', // filterHeight
        'number', // filterWidth
        'number', // padTop
        'number', // padRight
        'number', // padBottom
        'number', // padLeft
        'number', // dilationHeight
        'number', // dilationWidth
        'number', // strideHeight
        'number', // strideWidth
        'number', // inputChannels
        'number', // outputChannels
        'number', // outId
    ]);
}
// Kernel for 'MaxPool'. Pool geometry is computed by tfjs-core's
// computePool2DInfo with dilations fixed at 1.
function maxPool(args) {
    const { inputs, attrs, backend } = args;
    const x = inputs.x;
    const xId = backend.dataIdMap.get(x.dataId).id;
    const { filterSize, strides, pad, dimRoundingMode } = attrs;
    const convInfo = tfjsCore.backend_util.computePool2DInfo(x.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode);
    const filterHeight = convInfo.filterHeight;
    const filterWidth = convInfo.filterWidth;
    const padTop = convInfo.padInfo.top;
    const padRight = convInfo.padInfo.right;
    const padBottom = convInfo.padInfo.bottom;
    const padLeft = convInfo.padInfo.left;
    const dilationHeight = convInfo.dilationHeight;
    const dilationWidth = convInfo.dilationWidth;
    const strideHeight = convInfo.strideHeight;
    const strideWidth = convInfo.strideWidth;
    const inputChannels = convInfo.inChannels;
    const outputChannels = convInfo.outChannels;
    if (convInfo.dataFormat !== 'channelsLast') {
        throw new Error(`wasm backend does not support dataFormat:'` +
            `${convInfo.dataFormat}'. Please use 'channelsLast'.`);
    }
    const out = backend.makeOutput(convInfo.outShape, 'float32');
    const outId = backend.dataIdMap.get(out.dataId).id;
    wasmMaxPool(xId, x.shape[0], x.shape[1], x.shape[2], filterHeight, filterWidth, padTop, padRight, padBottom, padLeft, dilationHeight, dilationWidth, strideHeight, strideWidth, inputChannels, outputChannels, outId);
    return out;
}
tfjsCore.registerKernel({
    kernelName: 'MaxPool',
    backendName: 'wasm',
    setupFunc: setup$e,
    kernelFunc: maxPool
});

/**
 * @license
 * Copyright 2019 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
+ * ============================================================================= + */ + let wasmMin; + function setup$f(backend) { + wasmMin = + backend.wasm.cwrap('Min', null /*void*/, ['number, number, number']); + } + function min(args) { + const { backend, inputs, attrs } = args; + const { axes } = attrs; + const { x } = inputs; + const xId = backend.dataIdMap.get(x.dataId).id; + tfjsCore.backend_util.assertAxesAreInnerMostDims('min', axes, x.shape.length); + const [outShape, reduceShape] = tfjsCore.backend_util.computeOutAndReduceShapes(x.shape, axes); + const reduceSize = tfjsCore.util.sizeFromShape(reduceShape); + const out = backend.makeOutput(outShape, x.dtype); + if (tfjsCore.util.sizeFromShape(x.shape) === 0) { + return out; + } + const outId = backend.dataIdMap.get(out.dataId).id; + wasmMin(xId, reduceSize, outId); + return out; + } + tfjsCore.registerKernel({ + kernelName: 'Min', + backendName: 'wasm', + setupFunc: setup$f, + kernelFunc: min + }); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + const supportsFullBroadcast$9 = false; + registerBinaryKernel('Minimum', supportsFullBroadcast$9); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// Registers the 'Mul' binary kernel; unlike the comparison kernels above,
// this one advertises full-broadcast support to the shared helper.
const supportsFullBroadcast$a = true;
registerBinaryKernel('Mul', supportsFullBroadcast$a);

/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// Registers the 'Neg' unary kernel.
registerUnaryKernel('Neg');

/**
 * @license
 * Copyright 2019 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Parse the result of the c++ method, which has the shape equivalent to + * `Result`. + */ + function parseResultStruct(backend, resOffset) { + const result = new Int32Array(backend.wasm.HEAPU8.buffer, resOffset, 3); + const pSelectedIndices = result[0]; + const selectedSize = result[1]; + const pSelectedScores = result[2]; + // Since the result was allocated on the heap, we have to delete it. + backend.wasm._free(resOffset); + return { pSelectedIndices, selectedSize, pSelectedScores }; + } + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
 * =============================================================================
 */
// NonMaxSuppressionV3: the WASM entry returns a pointer to a Result struct
// (see parseResultStruct).
let wasmFunc$2;
// Binds the C++ `NonMaxSuppressionV3` symbol once the WASM module is ready.
function setup$g(backend) {
    wasmFunc$2 = backend.wasm.cwrap('NonMaxSuppressionV3', 'number', // Result*
    [
        'number', // boxesId
        'number', // scoresId
        'number', // maxOutputSize
        'number', // iouThreshold
        'number', // scoreThreshold
    ]);
}
// Kernel for 'NonMaxSuppressionV3'. Returns an int32 tensor of selected box
// indices created from the heap offset returned by the WASM call.
function kernelFunc(args) {
    const { backend, inputs, attrs } = args;
    const { iouThreshold, maxOutputSize, scoreThreshold } = attrs;
    const { boxes, scores } = inputs;
    const boxesId = backend.dataIdMap.get(boxes.dataId).id;
    const scoresId = backend.dataIdMap.get(scores.dataId).id;
    const resOffset = wasmFunc$2(boxesId, scoresId, maxOutputSize, iouThreshold, scoreThreshold);
    const { pSelectedIndices, selectedSize, pSelectedScores } = parseResultStruct(backend, resOffset);
    // Since we are not using scores for V3, we have to delete it from the heap.
    backend.wasm._free(pSelectedScores);
    const selectedIndicesTensor = backend.makeOutput([selectedSize], 'int32', pSelectedIndices);
    return selectedIndicesTensor;
}
tfjsCore.registerKernel({
    kernelName: 'NonMaxSuppressionV3',
    backendName: 'wasm',
    setupFunc: setup$g,
    kernelFunc,
});

/**
 * @license
 * Copyright 2019 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// NonMaxSuppressionV5: like V3 but with Soft-NMS sigma; the WASM entry
// returns a pointer to a Result struct (see parseResultStruct).
let wasmFunc$3;
// Binds the C++ `NonMaxSuppressionV5` symbol once the WASM module is ready.
function setup$h(backend) {
    wasmFunc$3 = backend.wasm.cwrap('NonMaxSuppressionV5', 'number', // Result*
    [
        'number', // boxesId
        'number', // scoresId
        'number', // maxOutputSize
        'number', // iouThreshold
        'number', // scoreThreshold
        'number', // softNmsSigma
    ]);
}
// Kernel for 'NonMaxSuppressionV5'. Returns [selectedIndices (int32),
// selectedScores (float32)], both created from heap offsets returned by the
// WASM call.
function kernelFunc$1(args) {
    const { backend, inputs, attrs } = args;
    const { iouThreshold, maxOutputSize, scoreThreshold, softNmsSigma } = attrs;
    const { boxes, scores } = inputs;
    const boxesId = backend.dataIdMap.get(boxes.dataId).id;
    const scoresId = backend.dataIdMap.get(scores.dataId).id;
    const resOffset = wasmFunc$3(boxesId, scoresId, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma);
    const { pSelectedIndices, selectedSize, pSelectedScores, } = parseResultStruct(backend, resOffset);
    const selectedIndicesTensor = backend.makeOutput([selectedSize], 'int32', pSelectedIndices);
    const selectedScoresTensor = backend.makeOutput([selectedSize], 'float32', pSelectedScores);
    return [selectedIndicesTensor, selectedScoresTensor];
}
tfjsCore.registerKernel({
    kernelName: 'NonMaxSuppressionV5',
    backendName: 'wasm',
    setupFunc: setup$h,
    kernelFunc: kernelFunc$1,
});

/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// Registers the 'NotEqual' binary kernel; output dtype forced to 'bool'.
const supportsFullBroadcast$b = false;
registerBinaryKernel('NotEqual', supportsFullBroadcast$b, 'bool');

/**
 * @license
 * Copyright 2020 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// Kernel for 'OnesLike': allocates an output with x's shape/dtype and fills
// its heap-backed typed array with ones. Implemented entirely in JS — no
// setup function and no call into C++.
function onesLike(args) {
    const { inputs: { x }, backend } = args;
    const out = backend.makeOutput(x.shape, x.dtype);
    const outVals = backend.typedArrayFromHeap(out);
    outVals.fill(1);
    return out;
}
tfjsCore.registerKernel({
    kernelName: 'OnesLike',
    backendName: 'wasm',
    kernelFunc: onesLike,
});

/**
 * @license
 * Copyright 2019 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// PadV2: constant-value padding delegated to the WASM kernel.
let wasmPadV2;
// Binds the C++ `PadV2` symbol once the WASM module is ready.
function setup$i(backend) {
    wasmPadV2 = backend.wasm.cwrap('PadV2', null /* void */, [
        'number', // xId
        'array', // x.shape (raw bytes of an Int32Array)
        'number', // x.shape.length
        'number', // dtype (CppDType code)
        'array', // paddings, flattened (raw bytes of an Int32Array)
        'number', // constantValue
        'number', // outId
    ]);
}
// Kernel for 'PadV2': pads x with `constantValue` using per-dimension
// [before, after] amounts from `paddings`.
function pad(args) {
    const { inputs: { x }, backend, attrs: { paddings, constantValue } } = args;
    // Output dim i = beforePad + inputDim + afterPad.
    const outShape = paddings.map((p, i) => p[0] /* beforePad */ + x.shape[i] + p[1] /* afterPad */);
    const xId = backend.dataIdMap.get(x.dataId).id;
    const out = backend.makeOutput(outShape, x.dtype);
    const outId = backend.dataIdMap.get(out.dataId).id;
    const xShapeBytes = new Uint8Array(new Int32Array(x.shape).buffer);
    // Flatten [[b0, a0], [b1, a1], ...] before marshalling as bytes.
    const paddingsFlat = [].concat(...paddings);
    const paddingsBytes = new Uint8Array(new Int32Array(paddingsFlat).buffer);
    wasmPadV2(xId, xShapeBytes, x.shape.length, CppDType[x.dtype], paddingsBytes, constantValue, outId);
    return out;
}
tfjsCore.registerKernel({
    kernelName: 'PadV2',
    backendName: 'wasm',
    kernelFunc: pad,
    setupFunc: setup$i
});

/**
 * @license
 * Copyright 2019 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// Registers the 'Pow' binary kernel (full broadcasting disabled).
const supportsFullBroadcast$c = false;
registerBinaryKernel('Pow', supportsFullBroadcast$c);

/**
 * @license
 * Copyright 2019 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// Prelu: forwards x and the `alpha` weights tensor to the WASM kernel.
let wasmPrelu;
// Binds the C++ `Prelu` symbol once the WASM module is ready.
function setup$j(backend) {
    wasmPrelu = backend.wasm.cwrap('Prelu', null /* void */, [
        'number', // x_id
        'number', // weights_id
        'number' // out_id
    ]);
}
// Kernel for 'Prelu'. Note the output is allocated as float32 regardless of
// x's dtype.
function prelu(args) {
    const { inputs, backend } = args;
    const { x, alpha } = inputs;
    const xId = backend.dataIdMap.get(x.dataId).id;
    const weightsId = backend.dataIdMap.get(alpha.dataId).id;
    const out = backend.makeOutput(x.shape, 'float32');
    const outId = backend.dataIdMap.get(out.dataId).id;
    wasmPrelu(xId, weightsId, outId);
    return out;
}
tfjsCore.registerKernel({
    kernelName: 'Prelu',
    backendName: 'wasm',
    setupFunc: setup$j,
    kernelFunc: prelu
});

/**
 * @license
 * Copyright 2019 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// Registers the 'Relu' unary kernel.
registerUnaryKernel('Relu');

/**
 * @license
 * Copyright 2019 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// Registers the 'Relu6' unary kernel.
registerUnaryKernel('Relu6');

/**
 * @license
 * Copyright 2019 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// Kernel for 'Reshape': a metadata-only op. The returned tensor info shares
// the input's dataId (same underlying buffer) and only swaps in the new
// shape; no data is copied and no WASM call is made.
function reshape(args) {
    const { inputs: { x }, attrs: { shape } } = args;
    return { dataId: x.dataId, shape, dtype: x.dtype };
}
tfjsCore.registerKernel({
    kernelName: 'Reshape',
    backendName: 'wasm',
    kernelFunc: reshape,
});

/**
 * @license
 * Copyright 2019 Google Inc. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + let wasmResizeBilinear; + function setup$k(backend) { + wasmResizeBilinear = backend.wasm.cwrap('ResizeBilinear', null /*void*/, [ + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number' // outId + ]); + } + function resizeBilinear(args) { + const { backend, inputs, attrs } = args; + const { x } = inputs; + const { alignCorners, newHeight, newWidth } = attrs; + const [batch, oldHeight, oldWidth, numChannels] = x.shape; + const outShape = [batch, newHeight, newWidth, numChannels]; + let xData = backend.dataIdMap.get(x.dataId); + let castedData; + if (xData.dtype !== 'float32') { + castedData = cast({ backend, inputs: { x }, attrs: { dtype: 'float32' } }); + xData = backend.dataIdMap.get(castedData.dataId); + } + const xId = xData.id; + const out = backend.makeOutput(outShape, 'float32'); + if (tfjsCore.util.sizeFromShape(x.shape) === 0) { + return out; + } + const outId = backend.dataIdMap.get(out.dataId).id; + wasmResizeBilinear(xId, batch, oldHeight, oldWidth, numChannels, newHeight, newWidth, alignCorners ? 
1 : 0, outId); + if (castedData != null) { + backend.disposeData(castedData.dataId); + } + return out; + } + tfjsCore.registerKernel({ + kernelName: 'ResizeBilinear', + backendName: 'wasm', + setupFunc: setup$k, + kernelFunc: resizeBilinear + }); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + registerUnaryKernel('Rsqrt'); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + let wasmScatterNd; + function setup$l(backend) { + wasmScatterNd = backend.wasm.cwrap('ScatterNd', null /*void*/, [ + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'array', + 'number', + 'number' // outId + ]); + } + function scatterNd(args) { + const { backend, inputs, attrs } = args; + const { indices, updates } = inputs; + const { shape } = attrs; + const out = backend.makeOutput(shape, updates.dtype); + if (tfjsCore.util.sizeFromShape(shape) === 0) { + return out; + } + const { sliceRank, numUpdates, sliceSize, strides, outputSize } = tfjsCore.scatter_util.calculateShapes(updates, indices, shape); + const indicesData = backend.dataIdMap.get(indices.dataId); + const indicesId = indicesData.id; + const updatesData = backend.dataIdMap.get(updates.dataId); + const updatesId = updatesData.id; + const stridesBytes = new Uint8Array(new Int32Array(strides).buffer); + const outId = backend.dataIdMap.get(out.dataId).id; + wasmScatterNd(indicesId, updatesId, CppDType[updates.dtype], sliceRank, numUpdates, sliceSize, stridesBytes, outputSize, outId); + return out; + } + tfjsCore.registerKernel({ + kernelName: 'ScatterNd', + backendName: 'wasm', + setupFunc: setup$l, + kernelFunc: scatterNd + }); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + let wasmFunc$4; + function setup$m(backend) { + wasmFunc$4 = + backend.wasm.cwrap('Sigmoid', null /* void */, ['number', 'number']); + } + function sigmoid(args) { + const { backend, inputs: { x } } = args; + const xId = backend.dataIdMap.get(x.dataId).id; + const out = backend.makeOutput(x.shape, x.dtype); + const outId = backend.dataIdMap.get(out.dataId).id; + // Short-circuit zero-sized tensors. + if (tfjsCore.util.sizeFromShape(out.shape) === 0) { + return out; + } + wasmFunc$4(xId, outId); + return out; + } + tfjsCore.registerKernel({ + kernelName: 'Sigmoid', + backendName: 'wasm', + setupFunc: setup$m, + kernelFunc: sigmoid + }); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + registerUnaryKernel('Sin'); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function slice(args) { + const { inputs: { x }, attrs: { begin, size }, backend } = args; + const isContinous = tfjsCore.slice_util.isSliceContinous(x.shape, begin, size); + const xVals = backend.typedArrayFromHeap(x); + const out = backend.makeOutput(size, x.dtype); + const outVals = backend.typedArrayFromHeap(out); + const xStrides = tfjsCore.util.computeStrides(x.shape); + if (isContinous) { + const flatOffset = tfjsCore.slice_util.computeFlatOffset(begin, xStrides); + outVals.set(xVals.subarray(flatOffset, flatOffset + tfjsCore.util.sizeFromShape(size))); + return out; + } + const rank = x.shape.length; + if (rank === 2) { + slice2d(xVals, xStrides[0], outVals, begin, size); + } + else if (rank === 3) { + slice3d(xVals, xStrides[0], xStrides[1], outVals, begin, size); + } + else if (rank === 4) { + slice4d(xVals, xStrides[0], xStrides[1], xStrides[2], outVals, begin, size); + } + else { + genericSliceSlow(xVals, x, outVals, begin, size); + } + return out; + } + function slice2d(xVals, xStride, outVals, begin, size) { + let outOffset = 0; + const beginI = begin[0]; + const beginJ = begin[1]; + const endI = beginI + size[0]; + for (let i = beginI; i < endI; i++) { + const xOffset = i * xStride + beginJ; + outVals.set(xVals.subarray(xOffset, xOffset + size[1]), outOffset); + outOffset += size[1]; + } + } + function slice3d(xVals, xStride1, xStride2, outVals, begin, size) { + let outOffset = 0; + const beginI = begin[0]; + const beginJ = begin[1]; + const 
beginK = begin[2]; + const endI = beginI + size[0]; + const endJ = beginJ + size[1]; + for (let i = beginI; i < endI; i++) { + for (let j = beginJ; j < endJ; j++) { + const xOffset = i * xStride1 + j * xStride2 + beginK; + outVals.set(xVals.subarray(xOffset, xOffset + size[2]), outOffset); + outOffset += size[2]; + } + } + } + function slice4d(xVals, xStride1, xStride2, xStride3, outVals, begin, size) { + let outOffset = 0; + const beginI = begin[0]; + const beginJ = begin[1]; + const beginK = begin[2]; + const endI = beginI + size[0]; + const endJ = beginJ + size[1]; + const endK = beginK + size[2]; + const beginL = begin[3]; + for (let i = beginI; i < endI; i++) { + for (let j = beginJ; j < endJ; j++) { + for (let k = beginK; k < endK; k++) { + const xOffset = i * xStride1 + j * xStride2 + k * xStride3 + beginL; + outVals.set(xVals.subarray(xOffset, xOffset + size[3]), outOffset); + outOffset += size[3]; + } + } + } + } + function genericSliceSlow(xVals, xInfo, outVals, begin, size) { + const outBuf = tfjsCore.buffer(size, xInfo.dtype, outVals); + const xBuf = tfjsCore.buffer(xInfo.shape, xInfo.dtype, xVals); + for (let i = 0; i < outBuf.size; ++i) { + const loc = outBuf.indexToLoc(i); + const xLoc = loc.map((idx, j) => idx + begin[j]); + outVals[i] = xBuf.get(...xLoc); + } + } + tfjsCore.registerKernel({ + kernelName: 'Slice', + backendName: 'wasm', + kernelFunc: slice, + }); + + /** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + let wasmFunc$5; + function setup$n(backend) { + wasmFunc$5 = backend.wasm.cwrap('Softmax', null /* void */, [ + 'number', + 'number', + 'number', + 'number' // batch + ]); + } + function softmax(args) { + const { backend, inputs: { logits }, attrs: { dim } } = args; + const xId = backend.dataIdMap.get(logits.dataId).id; + const out = backend.makeOutput(logits.shape, logits.dtype); + const outId = backend.dataIdMap.get(out.dataId).id; + const channels = logits.shape[dim]; + const batch = tfjsCore.util.sizeFromShape(logits.shape) / channels; + // Short-circuit zero-sized tensors. + if (tfjsCore.util.sizeFromShape(out.shape) === 0) { + return out; + } + wasmFunc$5(xId, outId, channels, batch); + return out; + } + tfjsCore.registerKernel({ + kernelName: 'Softmax', + backendName: 'wasm', + setupFunc: setup$n, + kernelFunc: softmax + }); + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + function split(args) { + const { inputs, attrs, backend } = args; + const { x } = inputs; + const { numOrSizeSplits, axis } = attrs; + const $axis = tfjsCore.util.parseAxisParam(axis, x.shape)[0]; + let splitSizes; + if (typeof (numOrSizeSplits) === 'number') { + splitSizes = + new Array(numOrSizeSplits).fill(x.shape[$axis] / numOrSizeSplits); + } + else { + splitSizes = numOrSizeSplits; + } + const begin = new Array(x.shape.length).fill(0); + const size = x.shape.slice(); + return splitSizes.map(s => { + const xSliceSize = [...size]; + xSliceSize[$axis] = s; + const xSlice = slice({ inputs: { x }, attrs: { begin, size: xSliceSize }, backend }); + begin[$axis] += s; + return xSlice; + }); + } + tfjsCore.registerKernel({ kernelName: tfjsCore.SplitV, backendName: 'wasm', kernelFunc: split }); + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + registerUnaryKernel('Sqrt'); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + registerUnaryKernel('Square'); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + const supportsFullBroadcast$d = true; + registerBinaryKernel('Sub', supportsFullBroadcast$d); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + let wasmSum; + function setup$o(backend) { + wasmSum = + backend.wasm.cwrap('Sum', null /*void*/, ['number, number, number']); + } + function sum(args) { + const { backend, inputs, attrs } = args; + const { axes } = attrs; + const { x } = inputs; + const xId = backend.dataIdMap.get(x.dataId).id; + tfjsCore.backend_util.assertAxesAreInnerMostDims('sum', axes, x.shape.length); + const [outShape, reduceShape] = tfjsCore.backend_util.computeOutAndReduceShapes(x.shape, axes); + const reduceSize = tfjsCore.util.sizeFromShape(reduceShape); + const out = backend.makeOutput(outShape, x.dtype); + if (tfjsCore.util.sizeFromShape(x.shape) === 0) { + return out; + } + const outId = backend.dataIdMap.get(out.dataId).id; + wasmSum(xId, reduceSize, outId); + return out; + } + tfjsCore.registerKernel({ + kernelName: 'Sum', + backendName: 'wasm', + setupFunc: setup$o, + kernelFunc: sum + }); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + registerUnaryKernel('Tanh'); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + let wasmTile; + function setup$p(backend) { + wasmTile = backend.wasm.cwrap('Tile', null /* void */, [ + 'number', + 'array', + 'number', + 'array', + 'number', + 'number' // out_id + ]); + } + function tile(args) { + const { inputs, backend, attrs } = args; + const { x } = inputs; + const xId = backend.dataIdMap.get(x.dataId).id; + const { reps } = attrs; + const newShape = new Array(x.shape.length); + for (let i = 0; i < newShape.length; i++) { + newShape[i] = x.shape[i] * reps[i]; + } + const xShapeBytes = new Uint8Array(new Int32Array(x.shape).buffer); + const newShapeBytes = new Uint8Array(new Int32Array(newShape).buffer); + const out = backend.makeOutput(newShape, x.dtype); + const outId = backend.dataIdMap.get(out.dataId).id; + wasmTile(xId, xShapeBytes, x.shape.length, newShapeBytes, newShape.length, CppDType[out.dtype], outId); + return out; + } + tfjsCore.registerKernel({ + kernelName: 'Tile', + backendName: 'wasm', + setupFunc: setup$p, + kernelFunc: tile + }); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + let wasmTranspose; + function setup$q(backend) { + wasmTranspose = backend.wasm.cwrap('Transpose', null /* void */, [ + 'number', + 'array', + 'number', + 'number', + 'number', + 'array', + 'number', + ]); + } + function transpose(args) { + const { inputs, backend, attrs } = args; + // Reduce any dimensions with size one. Lower-rank transpose kernel performs + // better due to simpler memory access pattern. + const [reducedShape, perm] = removeOneSizeDims(inputs.x.shape, attrs.perm); + const x = { + dataId: inputs.x.dataId, + shape: reducedShape, + dtype: inputs.x.dtype + }; + let permIsNoOp = true; + for (let i = 0; i < perm.length; i++) { + if (perm[i] !== i) { + permIsNoOp = false; + } + } + const outShape = computeOutShape(inputs.x.shape, attrs.perm); + if (permIsNoOp) { + return { dataId: x.dataId, shape: outShape, dtype: x.dtype }; + } + const out = backend.makeOutput(outShape, x.dtype); + const xId = backend.dataIdMap.get(x.dataId).id; + const outId = backend.dataIdMap.get(out.dataId).id; + const permBytes = new Uint8Array(new Int32Array(perm).buffer); + const xShapeBytes = new Uint8Array(new Int32Array(x.shape).buffer); + wasmTranspose(xId, xShapeBytes, x.shape.length, CppDType[x.dtype], outId, permBytes, perm.length); + return out; + } + function computeOutShape(inShape, perm) { + const outShape = new Array(inShape.length); + for (let i = 0; i < outShape.length; i++) { + outShape[i] = inShape[perm[i]]; + } + return outShape; + } + function 
removeOneSizeDims(shape, perm) { + const newShape = []; + const newPerm = []; + for (let i = 0; i < shape.length; ++i) { + if (shape[i] !== 1) { + newShape.push(shape[i]); + } + if (shape[perm[i]] !== 1) { + newPerm.push(perm[i]); + } + } + for (let i = 0; i < newPerm.length; ++i) { + let minValIdx = -1; + for (let j = 0; j < newPerm.length; ++j) { + if (newPerm[j] >= i && + (minValIdx === -1 || newPerm[minValIdx] > newPerm[j])) { + minValIdx = j; + } + } + newPerm[minValIdx] = i; + } + return [newShape, newPerm]; + } + tfjsCore.registerKernel({ + kernelName: 'Transpose', + backendName: 'wasm', + kernelFunc: transpose, + setupFunc: setup$q, + }); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + function unpack(args) { + const { inputs: { x }, backend, attrs: { axis } } = args; + const numOutputs = x.shape[axis]; + const rank = x.shape.length; + const outShape = new Array(rank - 1); + let outIndex = 0; + for (let i = 0; i < rank; i++) { + if (i !== axis) { + outShape[outIndex++] = x.shape[i]; + } + } + const outs = new Array(numOutputs); + const begin = new Array(rank).fill(0); + const size = x.shape.slice(); + size[axis] = 1; + for (let i = 0; i < outs.length; i++) { + begin[axis] = i; + outs[i] = slice({ inputs: { x }, attrs: { begin, size }, backend }); + } + return outs.map(({ dataId, dtype }) => ({ dataId, dtype, shape: outShape })); + } + tfjsCore.registerKernel({ + kernelName: 'Unpack', + backendName: 'wasm', + kernelFunc: unpack, + }); + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + function zerosLike(args) { + const { inputs: { x }, backend } = args; + const out = backend.makeOutput(x.shape, x.dtype); + const outVals = backend.typedArrayFromHeap(out); + outVals.fill(0); + return out; + } + tfjsCore.registerKernel({ + kernelName: 'ZerosLike', + backendName: 'wasm', + kernelFunc: zerosLike, + }); + + function createCommonjsModule(fn, module) { + return module = { exports: {} }, fn(module, module.exports), module.exports; + } + + var tfjsBackendWasm = createCommonjsModule(function (module, exports) { + var WasmBackendModule = (function() { + var _scriptDir = typeof document !== 'undefined' && document.currentScript ? document.currentScript.src : undefined; + if (typeof __filename !== 'undefined') _scriptDir = _scriptDir || __filename; + return ( + function(WasmBackendModule) { + WasmBackendModule = WasmBackendModule || {}; + + var Module=typeof WasmBackendModule!=="undefined"?WasmBackendModule:{};var moduleOverrides={};var key;for(key in Module){if(Module.hasOwnProperty(key)){moduleOverrides[key]=Module[key];}}var arguments_=[];var thisProgram="./this.program";var quit_=function(status,toThrow){throw toThrow};var ENVIRONMENT_IS_WEB=false;var ENVIRONMENT_IS_WORKER=false;var ENVIRONMENT_IS_NODE=false;var ENVIRONMENT_IS_SHELL=false;ENVIRONMENT_IS_WEB=typeof window==="object";ENVIRONMENT_IS_WORKER=typeof importScripts==="function";ENVIRONMENT_IS_NODE=typeof process==="object"&&typeof process.versions==="object"&&typeof process.versions.node==="string";ENVIRONMENT_IS_SHELL=!ENVIRONMENT_IS_WEB&&!ENVIRONMENT_IS_NODE&&!ENVIRONMENT_IS_WORKER;var scriptDirectory="";function locateFile(path){if(Module["locateFile"]){return Module["locateFile"](path,scriptDirectory)}return scriptDirectory+path}var read_,readBinary;var nodeFS;var 
nodePath;if(ENVIRONMENT_IS_NODE){if(ENVIRONMENT_IS_WORKER){scriptDirectory=path.dirname(scriptDirectory)+"/";}else{scriptDirectory=__dirname+"/";}read_=function shell_read(filename,binary){if(!nodeFS)nodeFS=fs;if(!nodePath)nodePath=path;filename=nodePath["normalize"](filename);return nodeFS["readFileSync"](filename,binary?null:"utf8")};readBinary=function readBinary(filename){var ret=read_(filename,true);if(!ret.buffer){ret=new Uint8Array(ret);}assert(ret.buffer);return ret};if(process["argv"].length>1){thisProgram=process["argv"][1].replace(/\\/g,"/");}arguments_=process["argv"].slice(2);process["on"]("uncaughtException",function(ex){if(!(ex instanceof ExitStatus)){throw ex}});process["on"]("unhandledRejection",abort);quit_=function(status){process["exit"](status);};Module["inspect"]=function(){return "[Emscripten Module object]"};}else if(ENVIRONMENT_IS_SHELL){if(typeof read!="undefined"){read_=function shell_read(f){return read(f)};}readBinary=function readBinary(f){var data;if(typeof readbuffer==="function"){return new Uint8Array(readbuffer(f))}data=read(f,"binary");assert(typeof data==="object");return data};if(typeof scriptArgs!="undefined"){arguments_=scriptArgs;}else if(typeof arguments!="undefined"){arguments_=arguments;}if(typeof quit==="function"){quit_=function(status){quit(status);};}if(typeof print!=="undefined"){if(typeof console==="undefined")console={};console.log=print;console.warn=console.error=typeof printErr!=="undefined"?printErr:print;}}else if(ENVIRONMENT_IS_WEB||ENVIRONMENT_IS_WORKER){if(ENVIRONMENT_IS_WORKER){scriptDirectory=self.location.href;}else if(document.currentScript){scriptDirectory=document.currentScript.src;}if(_scriptDir){scriptDirectory=_scriptDir;}if(scriptDirectory.indexOf("blob:")!==0){scriptDirectory=scriptDirectory.substr(0,scriptDirectory.lastIndexOf("/")+1);}else{scriptDirectory="";}{read_=function shell_read(url){var xhr=new XMLHttpRequest;xhr.open("GET",url,false);xhr.send(null);return 
xhr.responseText};if(ENVIRONMENT_IS_WORKER){readBinary=function readBinary(url){var xhr=new XMLHttpRequest;xhr.open("GET",url,false);xhr.responseType="arraybuffer";xhr.send(null);return new Uint8Array(xhr.response)};}}}var out=Module["print"]||console.log.bind(console);var err=Module["printErr"]||console.warn.bind(console);for(key in moduleOverrides){if(moduleOverrides.hasOwnProperty(key)){Module[key]=moduleOverrides[key];}}moduleOverrides=null;if(Module["arguments"])arguments_=Module["arguments"];if(Module["thisProgram"])thisProgram=Module["thisProgram"];if(Module["quit"])quit_=Module["quit"];var wasmBinary;if(Module["wasmBinary"])wasmBinary=Module["wasmBinary"];var noExitRuntime;if(Module["noExitRuntime"])noExitRuntime=Module["noExitRuntime"];if(typeof WebAssembly!=="object"){err("no native wasm support detected");}var wasmMemory;var wasmTable=new WebAssembly.Table({"initial":124,"maximum":124+0,"element":"anyfunc"});var ABORT=false;function assert(condition,text){if(!condition){abort("Assertion failed: "+text);}}function getCFunc(ident){var func=Module["_"+ident];assert(func,"Cannot call unknown function "+ident+", make sure it is exported");return func}function ccall(ident,returnType,argTypes,args,opts){var toC={"string":function(str){var ret=0;if(str!==null&&str!==undefined&&str!==0){var len=(str.length<<2)+1;ret=stackAlloc(len);stringToUTF8(str,ret,len);}return ret},"array":function(arr){var ret=stackAlloc(arr.length);writeArrayToMemory(arr,ret);return ret}};function convertReturnValue(ret){if(returnType==="string")return UTF8ToString(ret);if(returnType==="boolean")return Boolean(ret);return ret}var func=getCFunc(ident);var cArgs=[];var stack=0;if(args){for(var i=0;i=endIdx))++endPtr;if(endPtr-idx>16&&heap.subarray&&UTF8Decoder){return UTF8Decoder.decode(heap.subarray(idx,endPtr))}else{var str="";while(idx>10,56320|ch&1023);}}}return str}function UTF8ToString(ptr,maxBytesToRead){return ptr?UTF8ArrayToString(HEAPU8,ptr,maxBytesToRead):""}function 
stringToUTF8Array(str,heap,outIdx,maxBytesToWrite){if(!(maxBytesToWrite>0))return 0;var startIdx=outIdx;var endIdx=outIdx+maxBytesToWrite-1;for(var i=0;i=55296&&u<=57343){var u1=str.charCodeAt(++i);u=65536+((u&1023)<<10)|u1&1023;}if(u<=127){if(outIdx>=endIdx)break;heap[outIdx++]=u;}else if(u<=2047){if(outIdx+1>=endIdx)break;heap[outIdx++]=192|u>>6;heap[outIdx++]=128|u&63;}else if(u<=65535){if(outIdx+2>=endIdx)break;heap[outIdx++]=224|u>>12;heap[outIdx++]=128|u>>6&63;heap[outIdx++]=128|u&63;}else{if(outIdx+3>=endIdx)break;heap[outIdx++]=240|u>>18;heap[outIdx++]=128|u>>12&63;heap[outIdx++]=128|u>>6&63;heap[outIdx++]=128|u&63;}}heap[outIdx]=0;return outIdx-startIdx}function stringToUTF8(str,outPtr,maxBytesToWrite){return stringToUTF8Array(str,HEAPU8,outPtr,maxBytesToWrite)}var UTF16Decoder=typeof TextDecoder!=="undefined"?new TextDecoder("utf-16le"):undefined;function writeArrayToMemory(array,buffer){HEAP8.set(array,buffer);}var WASM_PAGE_SIZE=65536;function alignUp(x,multiple){if(x%multiple>0){x+=multiple-x%multiple;}return x}var buffer,HEAP8,HEAPU8,HEAP16,HEAPU16,HEAP32,HEAPU32,HEAPF32,HEAPF64;function updateGlobalBufferAndViews(buf){buffer=buf;Module["HEAP8"]=HEAP8=new Int8Array(buf);Module["HEAP16"]=HEAP16=new Int16Array(buf);Module["HEAP32"]=HEAP32=new Int32Array(buf);Module["HEAPU8"]=HEAPU8=new Uint8Array(buf);Module["HEAPU16"]=HEAPU16=new Uint16Array(buf);Module["HEAPU32"]=HEAPU32=new Uint32Array(buf);Module["HEAPF32"]=HEAPF32=new Float32Array(buf);Module["HEAPF64"]=HEAPF64=new Float64Array(buf);}var DYNAMIC_BASE=5254224,DYNAMICTOP_PTR=11184;var INITIAL_INITIAL_MEMORY=Module["INITIAL_MEMORY"]||16777216;if(Module["wasmMemory"]){wasmMemory=Module["wasmMemory"];}else{wasmMemory=new WebAssembly.Memory({"initial":INITIAL_INITIAL_MEMORY/WASM_PAGE_SIZE,"maximum":2147483648/WASM_PAGE_SIZE});}if(wasmMemory){buffer=wasmMemory.buffer;}INITIAL_INITIAL_MEMORY=buffer.byteLength;updateGlobalBufferAndViews(buffer);HEAP32[DYNAMICTOP_PTR>>2]=DYNAMIC_BASE;function 
callRuntimeCallbacks(callbacks){while(callbacks.length>0){var callback=callbacks.shift();if(typeof callback=="function"){callback(Module);continue}var func=callback.func;if(typeof func==="number"){if(callback.arg===undefined){Module["dynCall_v"](func);}else{Module["dynCall_vi"](func,callback.arg);}}else{func(callback.arg===undefined?null:callback.arg);}}}var __ATPRERUN__=[];var __ATINIT__=[];var __ATMAIN__=[];var __ATPOSTRUN__=[];function preRun(){if(Module["preRun"]){if(typeof Module["preRun"]=="function")Module["preRun"]=[Module["preRun"]];while(Module["preRun"].length){addOnPreRun(Module["preRun"].shift());}}callRuntimeCallbacks(__ATPRERUN__);}function initRuntime(){callRuntimeCallbacks(__ATINIT__);}function preMain(){callRuntimeCallbacks(__ATMAIN__);}function postRun(){if(Module["postRun"]){if(typeof Module["postRun"]=="function")Module["postRun"]=[Module["postRun"]];while(Module["postRun"].length){addOnPostRun(Module["postRun"].shift());}}callRuntimeCallbacks(__ATPOSTRUN__);}function addOnPreRun(cb){__ATPRERUN__.unshift(cb);}function addOnPostRun(cb){__ATPOSTRUN__.unshift(cb);}var Math_ceil=Math.ceil;var Math_floor=Math.floor;var runDependencies=0;var runDependencyWatcher=null;var dependenciesFulfilled=null;function addRunDependency(id){runDependencies++;if(Module["monitorRunDependencies"]){Module["monitorRunDependencies"](runDependencies);}}function removeRunDependency(id){runDependencies--;if(Module["monitorRunDependencies"]){Module["monitorRunDependencies"](runDependencies);}if(runDependencies==0){if(runDependencyWatcher!==null){clearInterval(runDependencyWatcher);runDependencyWatcher=null;}if(dependenciesFulfilled){var callback=dependenciesFulfilled;dependenciesFulfilled=null;callback();}}}Module["preloadedImages"]={};Module["preloadedAudios"]={};function abort(what){if(Module["onAbort"]){Module["onAbort"](what);}what+="";out(what);err(what);ABORT=true;what="abort("+what+"). 
Build with -s ASSERTIONS=1 for more info.";throw new WebAssembly.RuntimeError(what)}function hasPrefix(str,prefix){return String.prototype.startsWith?str.startsWith(prefix):str.indexOf(prefix)===0}var dataURIPrefix="data:application/octet-stream;base64,";function isDataURI(filename){return hasPrefix(filename,dataURIPrefix)}var fileURIPrefix="file://";function isFileURI(filename){return hasPrefix(filename,fileURIPrefix)}var wasmBinaryFile="tfjs-backend-wasm.wasm";if(!isDataURI(wasmBinaryFile)){wasmBinaryFile=locateFile(wasmBinaryFile);}function getBinary(){try{if(wasmBinary){return new Uint8Array(wasmBinary)}if(readBinary){return readBinary(wasmBinaryFile)}else{throw "both async and sync fetching of the wasm failed"}}catch(err){abort(err);}}function getBinaryPromise(){if(!wasmBinary&&(ENVIRONMENT_IS_WEB||ENVIRONMENT_IS_WORKER)&&typeof fetch==="function"&&!isFileURI(wasmBinaryFile)){return fetch(wasmBinaryFile,{credentials:"same-origin"}).then(function(response){if(!response["ok"]){throw "failed to load wasm binary file at '"+wasmBinaryFile+"'"}return response["arrayBuffer"]()}).catch(function(){return getBinary()})}return new Promise(function(resolve,reject){resolve(getBinary());})}function createWasm(){var info={"a":asmLibraryArg};function receiveInstance(instance,module){var exports=instance.exports;Module["asm"]=exports;removeRunDependency();}addRunDependency();function receiveInstantiatedSource(output){receiveInstance(output["instance"]);}function instantiateArrayBuffer(receiver){return getBinaryPromise().then(function(binary){return WebAssembly.instantiate(binary,info)}).then(receiver,function(reason){err("failed to asynchronously prepare wasm: "+reason);abort(reason);})}function instantiateAsync(){if(!wasmBinary&&typeof WebAssembly.instantiateStreaming==="function"&&!isDataURI(wasmBinaryFile)&&!isFileURI(wasmBinaryFile)&&typeof fetch==="function"){fetch(wasmBinaryFile,{credentials:"same-origin"}).then(function(response){var 
result=WebAssembly.instantiateStreaming(response,info);return result.then(receiveInstantiatedSource,function(reason){err("wasm streaming compile failed: "+reason);err("falling back to ArrayBuffer instantiation");instantiateArrayBuffer(receiveInstantiatedSource);})});}else{return instantiateArrayBuffer(receiveInstantiatedSource)}}if(Module["instantiateWasm"]){try{var exports=Module["instantiateWasm"](info,receiveInstance);return exports}catch(e){err("Module.instantiateWasm callback failed with error: "+e);return false}}instantiateAsync();return {}}__ATINIT__.push({func:function(){___wasm_call_ctors();}});function _abort(){abort();}function _emscripten_memcpy_big(dest,src,num){HEAPU8.copyWithin(dest,src,src+num);}function _emscripten_get_heap_size(){return HEAPU8.length}function emscripten_realloc_buffer(size){try{wasmMemory.grow(size-buffer.byteLength+65535>>>16);updateGlobalBufferAndViews(wasmMemory.buffer);return 1}catch(e){}}function _emscripten_resize_heap(requestedSize){var oldSize=_emscripten_get_heap_size();var PAGE_MULTIPLE=65536;var maxHeapSize=2147483648;if(requestedSize>maxHeapSize){return false}var minHeapSize=16777216;for(var cutDown=1;cutDown<=4;cutDown*=2){var overGrownHeapSize=oldSize*(1+.2/cutDown);overGrownHeapSize=Math.min(overGrownHeapSize,requestedSize+100663296);var newSize=Math.min(maxHeapSize,alignUp(Math.max(minHeapSize,requestedSize,overGrownHeapSize),PAGE_MULTIPLE));var replacement=emscripten_realloc_buffer(newSize);if(replacement){return true}}return false}var SYSCALLS={mappings:{},buffers:[null,[],[]],printChar:function(stream,curr){var buffer=SYSCALLS.buffers[stream];if(curr===0||curr===10){(stream===1?out:err)(UTF8ArrayToString(buffer,0));buffer.length=0;}else{buffer.push(curr);}},varargs:undefined,get:function(){SYSCALLS.varargs+=4;var ret=HEAP32[SYSCALLS.varargs-4>>2];return ret},getStr:function(ptr){var ret=UTF8ToString(ptr);return ret},get64:function(low,high){return low}};function _fd_close(fd){return 0}function 
_fd_seek(fd,offset_low,offset_high,whence,newOffset){}function _fd_write(fd,iov,iovcnt,pnum){var num=0;for(var i=0;i>2];var len=HEAP32[iov+(i*8+4)>>2];for(var j=0;j>2]=num;return 0}function _roundf(d){d=+d;return d>=+0?+Math_floor(d+ +.5):+Math_ceil(d-+.5)}var asmLibraryArg={"a":_abort,"e":_emscripten_memcpy_big,"f":_emscripten_resize_heap,"g":_fd_close,"d":_fd_seek,"c":_fd_write,"memory":wasmMemory,"b":_roundf,"table":wasmTable};var asm=createWasm();Module["asm"]=asm;var ___wasm_call_ctors=Module["___wasm_call_ctors"]=function(){return (___wasm_call_ctors=Module["___wasm_call_ctors"]=Module["asm"]["h"]).apply(null,arguments)};var _init=Module["_init"]=function(){return (_init=Module["_init"]=Module["asm"]["i"]).apply(null,arguments)};var _register_tensor=Module["_register_tensor"]=function(){return (_register_tensor=Module["_register_tensor"]=Module["asm"]["j"]).apply(null,arguments)};var _dispose_data=Module["_dispose_data"]=function(){return (_dispose_data=Module["_dispose_data"]=Module["asm"]["k"]).apply(null,arguments)};var _dispose=Module["_dispose"]=function(){return (_dispose=Module["_dispose"]=Module["asm"]["l"]).apply(null,arguments)};var _Abs=Module["_Abs"]=function(){return (_Abs=Module["_Abs"]=Module["asm"]["m"]).apply(null,arguments)};var _Add=Module["_Add"]=function(){return (_Add=Module["_Add"]=Module["asm"]["n"]).apply(null,arguments)};var _AddN=Module["_AddN"]=function(){return (_AddN=Module["_AddN"]=Module["asm"]["o"]).apply(null,arguments)};var _ArgMax=Module["_ArgMax"]=function(){return (_ArgMax=Module["_ArgMax"]=Module["asm"]["p"]).apply(null,arguments)};var _AvgPool=Module["_AvgPool"]=function(){return (_AvgPool=Module["_AvgPool"]=Module["asm"]["q"]).apply(null,arguments)};var _BatchMatMul=Module["_BatchMatMul"]=function(){return (_BatchMatMul=Module["_BatchMatMul"]=Module["asm"]["r"]).apply(null,arguments)};var _ClipByValue=Module["_ClipByValue"]=function(){return 
(_ClipByValue=Module["_ClipByValue"]=Module["asm"]["s"]).apply(null,arguments)};var _Conv2D=Module["_Conv2D"]=function(){return (_Conv2D=Module["_Conv2D"]=Module["asm"]["t"]).apply(null,arguments)};var _Cos=Module["_Cos"]=function(){return (_Cos=Module["_Cos"]=Module["asm"]["u"]).apply(null,arguments)};var _CropAndResize=Module["_CropAndResize"]=function(){return (_CropAndResize=Module["_CropAndResize"]=Module["asm"]["v"]).apply(null,arguments)};var _DepthwiseConv2dNative=Module["_DepthwiseConv2dNative"]=function(){return (_DepthwiseConv2dNative=Module["_DepthwiseConv2dNative"]=Module["asm"]["w"]).apply(null,arguments)};var _Div=Module["_Div"]=function(){return (_Div=Module["_Div"]=Module["asm"]["x"]).apply(null,arguments)};var _Exp=Module["_Exp"]=function(){return (_Exp=Module["_Exp"]=Module["asm"]["y"]).apply(null,arguments)};var _FloorDiv=Module["_FloorDiv"]=function(){return (_FloorDiv=Module["_FloorDiv"]=Module["asm"]["z"]).apply(null,arguments)};var _FusedBatchNorm=Module["_FusedBatchNorm"]=function(){return (_FusedBatchNorm=Module["_FusedBatchNorm"]=Module["asm"]["A"]).apply(null,arguments)};var _FusedConv2D=Module["_FusedConv2D"]=function(){return (_FusedConv2D=Module["_FusedConv2D"]=Module["asm"]["B"]).apply(null,arguments)};var _FusedDepthwiseConv2D=Module["_FusedDepthwiseConv2D"]=function(){return (_FusedDepthwiseConv2D=Module["_FusedDepthwiseConv2D"]=Module["asm"]["C"]).apply(null,arguments)};var _Gather=Module["_Gather"]=function(){return (_Gather=Module["_Gather"]=Module["asm"]["D"]).apply(null,arguments)};var _GatherNd=Module["_GatherNd"]=function(){return (_GatherNd=Module["_GatherNd"]=Module["asm"]["E"]).apply(null,arguments)};var _Greater=Module["_Greater"]=function(){return (_Greater=Module["_Greater"]=Module["asm"]["F"]).apply(null,arguments)};var _GreaterEqual=Module["_GreaterEqual"]=function(){return (_GreaterEqual=Module["_GreaterEqual"]=Module["asm"]["G"]).apply(null,arguments)};var _Less=Module["_Less"]=function(){return 
(_Less=Module["_Less"]=Module["asm"]["H"]).apply(null,arguments)};var _LessEqual=Module["_LessEqual"]=function(){return (_LessEqual=Module["_LessEqual"]=Module["asm"]["I"]).apply(null,arguments)};var _Log=Module["_Log"]=function(){return (_Log=Module["_Log"]=Module["asm"]["J"]).apply(null,arguments)};var _LogicalAnd=Module["_LogicalAnd"]=function(){return (_LogicalAnd=Module["_LogicalAnd"]=Module["asm"]["K"]).apply(null,arguments)};var _Max=Module["_Max"]=function(){return (_Max=Module["_Max"]=Module["asm"]["L"]).apply(null,arguments)};var _MaxPool=Module["_MaxPool"]=function(){return (_MaxPool=Module["_MaxPool"]=Module["asm"]["M"]).apply(null,arguments)};var _Maximum=Module["_Maximum"]=function(){return (_Maximum=Module["_Maximum"]=Module["asm"]["N"]).apply(null,arguments)};var _Min=Module["_Min"]=function(){return (_Min=Module["_Min"]=Module["asm"]["O"]).apply(null,arguments)};var _Minimum=Module["_Minimum"]=function(){return (_Minimum=Module["_Minimum"]=Module["asm"]["P"]).apply(null,arguments)};var _Mul=Module["_Mul"]=function(){return (_Mul=Module["_Mul"]=Module["asm"]["Q"]).apply(null,arguments)};var _Neg=Module["_Neg"]=function(){return (_Neg=Module["_Neg"]=Module["asm"]["R"]).apply(null,arguments)};var _NonMaxSuppressionV3=Module["_NonMaxSuppressionV3"]=function(){return (_NonMaxSuppressionV3=Module["_NonMaxSuppressionV3"]=Module["asm"]["S"]).apply(null,arguments)};var _NonMaxSuppressionV5=Module["_NonMaxSuppressionV5"]=function(){return (_NonMaxSuppressionV5=Module["_NonMaxSuppressionV5"]=Module["asm"]["T"]).apply(null,arguments)};var _NotEqual=Module["_NotEqual"]=function(){return (_NotEqual=Module["_NotEqual"]=Module["asm"]["U"]).apply(null,arguments)};var _PadV2=Module["_PadV2"]=function(){return (_PadV2=Module["_PadV2"]=Module["asm"]["V"]).apply(null,arguments)};var _Pow=Module["_Pow"]=function(){return (_Pow=Module["_Pow"]=Module["asm"]["W"]).apply(null,arguments)};var _Prelu=Module["_Prelu"]=function(){return 
(_Prelu=Module["_Prelu"]=Module["asm"]["X"]).apply(null,arguments)};var _Relu=Module["_Relu"]=function(){return (_Relu=Module["_Relu"]=Module["asm"]["Y"]).apply(null,arguments)};var _Relu6=Module["_Relu6"]=function(){return (_Relu6=Module["_Relu6"]=Module["asm"]["Z"]).apply(null,arguments)};var _ResizeBilinear=Module["_ResizeBilinear"]=function(){return (_ResizeBilinear=Module["_ResizeBilinear"]=Module["asm"]["_"]).apply(null,arguments)};var _Rsqrt=Module["_Rsqrt"]=function(){return (_Rsqrt=Module["_Rsqrt"]=Module["asm"]["$"]).apply(null,arguments)};var _ScatterNd=Module["_ScatterNd"]=function(){return (_ScatterNd=Module["_ScatterNd"]=Module["asm"]["aa"]).apply(null,arguments)};var _Sigmoid=Module["_Sigmoid"]=function(){return (_Sigmoid=Module["_Sigmoid"]=Module["asm"]["ba"]).apply(null,arguments)};var _Sin=Module["_Sin"]=function(){return (_Sin=Module["_Sin"]=Module["asm"]["ca"]).apply(null,arguments)};var _Softmax=Module["_Softmax"]=function(){return (_Softmax=Module["_Softmax"]=Module["asm"]["da"]).apply(null,arguments)};var _Sqrt=Module["_Sqrt"]=function(){return (_Sqrt=Module["_Sqrt"]=Module["asm"]["ea"]).apply(null,arguments)};var _Square=Module["_Square"]=function(){return (_Square=Module["_Square"]=Module["asm"]["fa"]).apply(null,arguments)};var _Sub=Module["_Sub"]=function(){return (_Sub=Module["_Sub"]=Module["asm"]["ga"]).apply(null,arguments)};var _Sum=Module["_Sum"]=function(){return (_Sum=Module["_Sum"]=Module["asm"]["ha"]).apply(null,arguments)};var _Tanh=Module["_Tanh"]=function(){return (_Tanh=Module["_Tanh"]=Module["asm"]["ia"]).apply(null,arguments)};var _Tile=Module["_Tile"]=function(){return (_Tile=Module["_Tile"]=Module["asm"]["ja"]).apply(null,arguments)};var _Transpose=Module["_Transpose"]=function(){return (_Transpose=Module["_Transpose"]=Module["asm"]["ka"]).apply(null,arguments)};var __FusedMatMul=Module["__FusedMatMul"]=function(){return (__FusedMatMul=Module["__FusedMatMul"]=Module["asm"]["la"]).apply(null,arguments)};var 
_malloc=Module["_malloc"]=function(){return (_malloc=Module["_malloc"]=Module["asm"]["ma"]).apply(null,arguments)};var _free=Module["_free"]=function(){return (_free=Module["_free"]=Module["asm"]["na"]).apply(null,arguments)};var stackSave=Module["stackSave"]=function(){return (stackSave=Module["stackSave"]=Module["asm"]["oa"]).apply(null,arguments)};var stackAlloc=Module["stackAlloc"]=function(){return (stackAlloc=Module["stackAlloc"]=Module["asm"]["pa"]).apply(null,arguments)};var stackRestore=Module["stackRestore"]=function(){return (stackRestore=Module["stackRestore"]=Module["asm"]["qa"]).apply(null,arguments)};var dynCall_vi=Module["dynCall_vi"]=function(){return (dynCall_vi=Module["dynCall_vi"]=Module["asm"]["ra"]).apply(null,arguments)};var dynCall_v=Module["dynCall_v"]=function(){return (dynCall_v=Module["dynCall_v"]=Module["asm"]["sa"]).apply(null,arguments)};Module["asm"]=asm;Module["cwrap"]=cwrap;var calledRun;Module["then"]=function(func){if(calledRun){func(Module);}else{var old=Module["onRuntimeInitialized"];Module["onRuntimeInitialized"]=function(){if(old)old();func(Module);};}return Module};function ExitStatus(status){this.name="ExitStatus";this.message="Program terminated with exit("+status+")";this.status=status;}dependenciesFulfilled=function runCaller(){if(!calledRun)run();if(!calledRun)dependenciesFulfilled=runCaller;};function run(args){if(runDependencies>0){return}preRun();if(runDependencies>0)return;function doRun(){if(calledRun)return;calledRun=true;Module["calledRun"]=true;if(ABORT)return;initRuntime();preMain();if(Module["onRuntimeInitialized"])Module["onRuntimeInitialized"]();postRun();}if(Module["setStatus"]){Module["setStatus"]("Running...");setTimeout(function(){setTimeout(function(){Module["setStatus"]("");},1);doRun();},1);}else{doRun();}}Module["run"]=run;if(Module["preInit"]){if(typeof 
Module["preInit"]=="function")Module["preInit"]=[Module["preInit"]];while(Module["preInit"].length>0){Module["preInit"].pop()();}}noExitRuntime=true;run(); + + + return WasmBackendModule + } + ); + })(); + module.exports = WasmBackendModule; + }); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + const WASM_PRIORITY = 2; + class BackendWasm extends tfjsCore.KernelBackend { + constructor(wasm) { + super(); + this.wasm = wasm; + // 0 is reserved for null data ids. 
+ this.dataIdNextNumber = 1; + this.wasm.tfjs.init(); + this.dataIdMap = new tfjsCore.DataStorage(this, tfjsCore.engine()); + } + write(values, shape, dtype) { + const dataId = {}; + this.move(dataId, values, shape, dtype); + return dataId; + } + numDataIds() { + return this.dataIdMap.numDataIds(); + } + async time(f) { + const start = tfjsCore.util.now(); + f(); + const kernelMs = tfjsCore.util.now() - start; + return { kernelMs }; + } + move(dataId, values, shape, dtype) { + const id = this.dataIdNextNumber++; + if (dtype === 'string') { + const stringBytes = values; + this.dataIdMap.set(dataId, { id, stringBytes, shape, dtype, memoryOffset: null }); + return; + } + const size = tfjsCore.util.sizeFromShape(shape); + const numBytes = size * tfjsCore.util.bytesPerElement(dtype); + const memoryOffset = this.wasm._malloc(numBytes); + this.dataIdMap.set(dataId, { id, memoryOffset, shape, dtype }); + this.wasm.tfjs.registerTensor(id, size, memoryOffset); + if (values != null) { + this.wasm.HEAPU8.set(new Uint8Array(values.buffer, values.byteOffset, numBytes), memoryOffset); + } + } + async read(dataId) { + return this.readSync(dataId); + } + readSync(dataId) { + const { memoryOffset, dtype, shape, stringBytes } = this.dataIdMap.get(dataId); + if (dtype === 'string') { + return stringBytes; + } + const bytes = this.wasm.HEAPU8.slice(memoryOffset, memoryOffset + tfjsCore.util.sizeFromShape(shape) * tfjsCore.util.bytesPerElement(dtype)); + return typedArrayFromBuffer(bytes.buffer, dtype); + } + disposeData(dataId) { + const data = this.dataIdMap.get(dataId); + this.wasm._free(data.memoryOffset); + this.wasm.tfjs.disposeData(data.id); + this.dataIdMap.delete(dataId); + } + floatPrecision() { + return 32; + } + // Returns the memory offset of a tensor. Useful for debugging and unit + // testing. 
+ getMemoryOffset(dataId) { + return this.dataIdMap.get(dataId).memoryOffset; + } + dispose() { + this.wasm.tfjs.dispose(); + this.wasm = null; + } + memory() { + return { unreliable: false }; + } + /** + * Make a tensor info for the output of an op. If `memoryOffset` is not + * present, this method allocates memory on the WASM heap. If `memoryOffset` + * is present, the memory was allocated elsewhere (in c++) and we just record + * the pointer where that memory lives. + */ + makeOutput(shape, dtype, memoryOffset) { + let dataId; + if (memoryOffset == null) { + dataId = this.write(null /* values */, shape, dtype); + } + else { + dataId = {}; + const id = this.dataIdNextNumber++; + this.dataIdMap.set(dataId, { id, memoryOffset, shape, dtype }); + const size = tfjsCore.util.sizeFromShape(shape); + this.wasm.tfjs.registerTensor(id, size, memoryOffset); + } + return { dataId, shape, dtype }; + } + typedArrayFromHeap({ shape, dtype, dataId }) { + const buffer = this.wasm.HEAPU8.buffer; + const { memoryOffset } = this.dataIdMap.get(dataId); + const size = tfjsCore.util.sizeFromShape(shape); + switch (dtype) { + case 'float32': + return new Float32Array(buffer, memoryOffset, size); + case 'int32': + return new Int32Array(buffer, memoryOffset, size); + case 'bool': + return new Uint8Array(buffer, memoryOffset, size); + default: + throw new Error(`Uknown dtype ${dtype}`); + } + } + } + tfjsCore.registerBackend('wasm', async () => { + const { wasm } = await init(); + return new BackendWasm(wasm); + }, WASM_PRIORITY); + function createInstantiateWasmFunc(path) { + // tslint:disable-next-line:no-any + return (imports, callback) => { + tfjsCore.util.fetch(path, { credentials: 'same-origin' }).then((response) => { + if (!response['ok']) { + imports.env.a(`failed to load wasm binary file at '${path}'`); + } + response.arrayBuffer().then(binary => { + WebAssembly.instantiate(binary, imports).then(output => { + callback(output.instance); + }); + }); + }); + return {}; + }; + } + 
/** + * Initializes the wasm module and creates the js <--> wasm bridge. + * + * NOTE: We wrap the wasm module in a object with property 'wasm' instead of + * returning Promise to avoid freezing Chrome (last tested + * in Chrome 76). + */ + async function init() { + return new Promise((resolve, reject) => { + const factoryConfig = {}; + if (wasmPath != null) { + factoryConfig.locateFile = (path, prefix) => { + if (path.endsWith('.wasm')) { + return wasmPath; + } + return prefix + path; + }; + // use wasm instantiateWasm override when system fetch is not available. + // For detail references + // https://github.com/emscripten-core/emscripten/blob/2bca083cbbd5a4133db61fbd74d04f7feecfa907/tests/manual_wasm_instantiate.html#L170 + if (customFetch) { + factoryConfig.instantiateWasm = createInstantiateWasmFunc(wasmPath); + } + } + const wasm = tfjsBackendWasm(factoryConfig); + const voidReturnType = null; + // Using the tfjs namespace to avoid conflict with emscripten's API. + wasm.tfjs = { + init: wasm.cwrap('init', null, []), + registerTensor: wasm.cwrap('register_tensor', null, [ + 'number', + 'number', + 'number', + ]), + disposeData: wasm.cwrap('dispose_data', voidReturnType, ['number']), + dispose: wasm.cwrap('dispose', voidReturnType, []), + }; + let initialized = false; + wasm.onRuntimeInitialized = () => { + initialized = true; + initAborted = false; + resolve({ wasm }); + }; + wasm.onAbort = () => { + if (initialized) { + // Emscripten already called console.warn so no need to double log. + return; + } + if (initAborted) { + // Emscripten calls `onAbort` twice, resulting in double error + // messages. + return; + } + initAborted = true; + const rejectMsg = 'Make sure the server can serve the `.wasm` file relative to the ' + + 'bundled js file. 
For more details see https://github.com/tensorflow/tfjs/blob/master/tfjs-backend-wasm/README.md#using-bundlers'; + reject({ message: rejectMsg }); + }; + }); + } + function typedArrayFromBuffer(buffer, dtype) { + switch (dtype) { + case 'float32': + return new Float32Array(buffer); + case 'int32': + return new Int32Array(buffer); + case 'bool': + return new Uint8Array(buffer); + default: + throw new Error(`Unknown dtype ${dtype}`); + } + } + let wasmPath = null; + let initAborted = false; + let customFetch = false; + /** + * Sets the path to the `.wasm` file which will be fetched when the wasm + * backend is initialized. See + * https://github.com/tensorflow/tfjs/blob/master/tfjs-backend-wasm/README.md#using-bundlers + * for more details. + * @param path wasm file path or url + * @param usePlatformFetch optional boolean to use platform fetch to download + * the wasm file, default to false. + */ + /** @doc {heading: 'Environment', namespace: 'wasm'} */ + function setWasmPath(path, usePlatformFetch = false) { + if (initAborted) { + throw new Error('The WASM backend was already initialized. Make sure you call ' + + '`setWasmPath()` before you call `tf.setBackend()` or `tf.ready()`'); + } + wasmPath = path; + customFetch = usePlatformFetch; + } + + /** @license See the LICENSE file. */ + // This code is auto-generated, do not modify this file! + const version = '0.0.0'; + + exports.BackendWasm = BackendWasm; + exports.setWasmPath = setWasmPath; + exports.version_wasm = version; + + Object.defineProperty(exports, '__esModule', { value: true }); + +}))); +//# sourceMappingURL=tf-backend-wasm.js.map diff --git a/tfjs-core/benchmarks/tf-core.js b/tfjs-core/benchmarks/tf-core.js new file mode 100644 index 00000000000..3a0cfeedcfc --- /dev/null +++ b/tfjs-core/benchmarks/tf-core.js @@ -0,0 +1,26220 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. 
+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +(function (global, factory) { + typeof exports === 'object' && typeof module !== 'undefined' ? factory(exports) : + typeof define === 'function' && define.amd ? define(['exports'], factory) : + (global = global || self, factory(global.tf = global.tf || {})); +}(this, (function (exports) { 'use strict'; + + /*! ***************************************************************************** + Copyright (c) Microsoft Corporation. All rights reserved. + Licensed under the Apache License, Version 2.0 (the "License"); you may not use + this file except in compliance with the License. You may obtain a copy of the + License at http://www.apache.org/licenses/LICENSE-2.0 + + THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED + WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, + MERCHANTABLITY OR NON-INFRINGEMENT. + + See the Apache Version 2.0 License for specific language governing permissions + and limitations under the License. 
+ ***************************************************************************** */ + /* global Reflect, Promise */ + + var extendStatics = function(d, b) { + extendStatics = Object.setPrototypeOf || + ({ __proto__: [] } instanceof Array && function (d, b) { d.__proto__ = b; }) || + function (d, b) { for (var p in b) if (b.hasOwnProperty(p)) d[p] = b[p]; }; + return extendStatics(d, b); + }; + + function __extends(d, b) { + extendStatics(d, b); + function __() { this.constructor = d; } + d.prototype = b === null ? Object.create(b) : (__.prototype = b.prototype, new __()); + } + + function __awaiter(thisArg, _arguments, P, generator) { + return new (P || (P = Promise))(function (resolve, reject) { + function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } } + function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } } + function step(result) { result.done ? resolve(result.value) : new P(function (resolve) { resolve(result.value); }).then(fulfilled, rejected); } + step((generator = generator.apply(thisArg, _arguments || [])).next()); + }); + } + + function __generator(thisArg, body) { + var _ = { label: 0, sent: function() { if (t[0] & 1) throw t[1]; return t[1]; }, trys: [], ops: [] }, f, y, t, g; + return g = { next: verb(0), "throw": verb(1), "return": verb(2) }, typeof Symbol === "function" && (g[Symbol.iterator] = function() { return this; }), g; + function verb(n) { return function (v) { return step([n, v]); }; } + function step(op) { + if (f) throw new TypeError("Generator is already executing."); + while (_) try { + if (f = 1, y && (t = op[0] & 2 ? y["return"] : op[0] ? 
y["throw"] || ((t = y["return"]) && t.call(y), 0) : y.next) && !(t = t.call(y, op[1])).done) return t; + if (y = 0, t) op = [op[0] & 2, t.value]; + switch (op[0]) { + case 0: case 1: t = op; break; + case 4: _.label++; return { value: op[1], done: false }; + case 5: _.label++; y = op[1]; op = [0]; continue; + case 7: op = _.ops.pop(); _.trys.pop(); continue; + default: + if (!(t = _.trys, t = t.length > 0 && t[t.length - 1]) && (op[0] === 6 || op[0] === 2)) { _ = 0; continue; } + if (op[0] === 3 && (!t || (op[1] > t[0] && op[1] < t[3]))) { _.label = op[1]; break; } + if (op[0] === 6 && _.label < t[1]) { _.label = t[1]; t = op; break; } + if (t && _.label < t[2]) { _.label = t[2]; _.ops.push(op); break; } + if (t[2]) _.ops.pop(); + _.trys.pop(); continue; + } + op = body.call(thisArg, _); + } catch (e) { op = [6, e]; y = 0; } finally { f = t = 0; } + if (op[0] & 5) throw op[1]; return { value: op[0] ? op[1] : void 0, done: true }; + } + } + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + // Expects flags from URL in the format ?tfjsflags=FLAG1:1,FLAG2:true. + var TENSORFLOWJS_FLAGS_PREFIX = 'tfjsflags'; + /** + * The environment contains evaluated flags as well as the registered platform. + * This is always used as a global singleton and can be retrieved with + * `tf.env()`. 
+ */ + /** @doc {heading: 'Environment'} */ + var Environment = /** @class */ (function () { + // tslint:disable-next-line: no-any + function Environment(global) { + this.global = global; + this.flags = {}; + this.flagRegistry = {}; + this.urlFlags = {}; + this.populateURLFlags(); + } + Environment.prototype.setPlatform = function (platformName, platform) { + if (this.platform != null) { + console.warn("Platform " + this.platformName + " has already been set. " + + ("Overwriting the platform with " + platform + ".")); + } + this.platformName = platformName; + this.platform = platform; + }; + Environment.prototype.registerFlag = function (flagName, evaluationFn, setHook) { + this.flagRegistry[flagName] = { evaluationFn: evaluationFn, setHook: setHook }; + // Override the flag value from the URL. This has to happen here because the + // environment is initialized before flags get registered. + if (this.urlFlags[flagName] != null) { + var flagValue = this.urlFlags[flagName]; + console.warn("Setting feature override from URL " + flagName + ": " + flagValue + "."); + this.set(flagName, flagValue); + } + }; + Environment.prototype.get = function (flagName) { + if (flagName in this.flags) { + return this.flags[flagName]; + } + this.flags[flagName] = this.evaluateFlag(flagName); + return this.flags[flagName]; + }; + Environment.prototype.getNumber = function (flagName) { + return this.get(flagName); + }; + Environment.prototype.getBool = function (flagName) { + return this.get(flagName); + }; + Environment.prototype.getFlags = function () { + return this.flags; + }; + Object.defineProperty(Environment.prototype, "features", { + // For backwards compatibility. 
+ get: function () { + return this.flags; + }, + enumerable: true, + configurable: true + }); + Environment.prototype.set = function (flagName, value) { + if (this.flagRegistry[flagName] == null) { + throw new Error("Cannot set flag " + flagName + " as it has not been registered."); + } + this.flags[flagName] = value; + if (this.flagRegistry[flagName].setHook != null) { + this.flagRegistry[flagName].setHook(value); + } + }; + Environment.prototype.evaluateFlag = function (flagName) { + if (this.flagRegistry[flagName] == null) { + throw new Error("Cannot evaluate flag '" + flagName + "': no evaluation function found."); + } + return this.flagRegistry[flagName].evaluationFn(); + }; + Environment.prototype.setFlags = function (flags) { + this.flags = Object.assign({}, flags); + }; + Environment.prototype.reset = function () { + this.flags = {}; + this.urlFlags = {}; + this.populateURLFlags(); + }; + Environment.prototype.populateURLFlags = function () { + var _this = this; + if (typeof this.global === 'undefined' || + typeof this.global.location === 'undefined' || + typeof this.global.location.search === 'undefined') { + return; + } + var urlParams = getQueryParams(this.global.location.search); + if (TENSORFLOWJS_FLAGS_PREFIX in urlParams) { + var keyValues = urlParams[TENSORFLOWJS_FLAGS_PREFIX].split(','); + keyValues.forEach(function (keyValue) { + var _a = keyValue.split(':'), key = _a[0], value = _a[1]; + _this.urlFlags[key] = parseValue(key, value); + }); + } + }; + return Environment; + }()); + function getQueryParams(queryString) { + var params = {}; + queryString.replace(/[?&]([^=?&]+)(?:=([^&]*))?/g, function (s) { + var t = []; + for (var _i = 1; _i < arguments.length; _i++) { + t[_i - 1] = arguments[_i]; + } + decodeParam(params, t[0], t[1]); + return t.join('='); + }); + return params; + } + function decodeParam(params, name, value) { + params[decodeURIComponent(name)] = decodeURIComponent(value || ''); + } + function parseValue(flagName, value) { + value 
= value.toLowerCase(); + if (value === 'true' || value === 'false') { + return value === 'true'; + } + else if ("" + +value === value) { + return +value; + } + throw new Error("Could not parse value flag value " + value + " for flag " + flagName + "."); + } + /** + * Returns the current environment (a global singleton). + * + * The environment object contains the evaluated feature values as well as the + * active platform. + */ + /** @doc {heading: 'Environment'} */ + function env() { + return exports.ENV; + } + exports.ENV = null; + function setEnvironmentGlobal(environment) { + exports.ENV = environment; + } + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + // Note that the identifier globalNameSpace is scoped to this module, but will + // always resolve to the same global object regardless of how the module is + // resolved. 
+ // tslint:disable-next-line:no-any + var globalNameSpace; + // tslint:disable-next-line:no-any + function getGlobalNamespace() { + if (globalNameSpace == null) { + // tslint:disable-next-line:no-any + var ns = void 0; + if (typeof (window) !== 'undefined') { + ns = window; + } + else if (typeof (global) !== 'undefined') { + ns = global; + } + else if (typeof (process) !== 'undefined') { + ns = process; + } + else if (typeof (self) !== 'undefined') { + ns = self; + } + else { + throw new Error('Could not find a global object'); + } + globalNameSpace = ns; + } + return globalNameSpace; + } + // tslint:disable-next-line:no-any + function getGlobalMap() { + var ns = getGlobalNamespace(); + if (ns._tfGlobals == null) { + ns._tfGlobals = new Map(); + } + return ns._tfGlobals; + } + /** + * Returns a globally accessible 'singleton' object. + * + * @param key the name of the object + * @param init a function to initialize to initialize this object + * the first time it is fetched. + */ + function getGlobal(key, init) { + var globalMap = getGlobalMap(); + if (globalMap.has(key)) { + return globalMap.get(key); + } + else { + var singleton = init(); + globalMap.set(key, singleton); + return globalMap.get(key); + } + } + + /** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + var kernelRegistry = getGlobal('kernelRegistry', function () { return new Map(); }); + var gradRegistry = getGlobal('gradRegistry', function () { return new Map(); }); + /** + * Returns the kernel function (code) associated with the provided names. + * + * @param kernelName The official name of the kernel. + * @param backendName The official name of the backend. + */ + function getKernel(kernelName, backendName) { + var key = makeKey(kernelName, backendName); + return kernelRegistry.get(key); + } + /** + * Returns the registered gradient info associated with the provided kernel. + * @param kernelName The official TF kernel name. + */ + function getGradient(kernelName) { + return gradRegistry.get(kernelName); + } + function getKernelsForBackend(backendName) { + var it = kernelRegistry.entries(); + var result = []; + while (true) { + var _a = it.next(), done = _a.done, value = _a.value; + if (done) { + break; + } + var key = value[0], config = value[1]; + var backend = key.split('_')[0]; + if (backend === backendName) { + result.push(config); + } + } + return result; + } + /** + * Registers the function (forward pass) for the kernel in a global registry. + * + * @param config A config object with the following properties: + * - `kernelName` The official name of the kernel. + * - `backendName` The official name of the backend. + * - `kernelFunc` The function to run during the forward pass of the kernel. + * - `setupFunc` Optional. Gets called once, after the backend initializes. + * - `disposeFunc` Optional. Gets called once, right before the backend is + * disposed. 
+ */ + function registerKernel(config) { + var kernelName = config.kernelName, backendName = config.backendName; + var key = makeKey(kernelName, backendName); + if (kernelRegistry.has(key)) { + console.warn("The kernel '" + kernelName + "' for backend " + + ("'" + backendName + "' is already registered")); + } + kernelRegistry.set(key, config); + } + /** + * Registers a gradient function for a given kernel in the global registry, + * to be used during the back-propagation of that kernel. + * + * @param config An object with the following properties: + * - `kernelName` The name of the kernel that the gradient function is for. + * - `gradFunc` The function to run during back-propagation. + */ + function registerGradient(config) { + var kernelName = config.kernelName; + if (gradRegistry.has(kernelName)) { + console.warn("Overriding the gradient for '" + kernelName + "'"); + } + gradRegistry.set(kernelName, config); + } + /** + * Removes the kernel function from the registry. + * + * @param kernelName The official name of the kernel. + * @param backendName The official name of the backend. + * + */ + function unregisterKernel(kernelName, backendName) { + var key = makeKey(kernelName, backendName); + if (!kernelRegistry.has(key)) { + throw new Error("The kernel '" + kernelName + "' for backend " + + ("'" + backendName + "' is not registered")); + } + kernelRegistry.delete(key); + } + /** Removes the registered gradient from the global registry. */ + function unregisterGradient(kernelName) { + if (!gradRegistry.has(kernelName)) { + throw new Error("The gradient '" + kernelName + "' for backend is not registered"); + } + gradRegistry.delete(kernelName); + } + function makeKey(kernelName, backendName) { + return backendName + "_" + kernelName; + } + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Shuffles the array in-place using Fisher-Yates algorithm. + * + * ```js + * const a = [1, 2, 3, 4, 5]; + * tf.util.shuffle(a); + * console.log(a); + * ``` + * + * @param array The array to shuffle in-place. + */ + /** @doc {heading: 'Util', namespace: 'util'} */ + // tslint:disable-next-line:no-any + function shuffle(array) { + var counter = array.length; + var temp = 0; + var index = 0; + // While there are elements in the array + while (counter > 0) { + // Pick a random index + index = (Math.random() * counter) | 0; + // Decrease counter by 1 + counter--; + // And swap the last element with it + temp = array[counter]; + array[counter] = array[index]; + array[index] = temp; + } + } + /** Clamps a value to a specified range. */ + function clamp(min, x, max) { + return Math.max(min, Math.min(x, max)); + } + function nearestLargerEven(val) { + return val % 2 === 0 ? val : val + 1; + } + function sum(arr) { + var sum = 0; + for (var i = 0; i < arr.length; i++) { + sum += arr[i]; + } + return sum; + } + /** + * Returns a sample from a uniform [a, b) distribution. + * + * @param a The minimum support (inclusive). + * @param b The maximum support (exclusive). + * @return A pseudorandom number on the half-open interval [a,b). + */ + function randUniform(a, b) { + var r = Math.random(); + return (b * r) + (1 - r) * a; + } + /** Returns the squared Euclidean distance between two vectors. 
*/ + function distSquared(a, b) { + var result = 0; + for (var i = 0; i < a.length; i++) { + var diff = Number(a[i]) - Number(b[i]); + result += diff * diff; + } + return result; + } + /** + * Asserts that the expression is true. Otherwise throws an error with the + * provided message. + * + * ```js + * const x = 2; + * tf.util.assert(x === 2, 'x is not 2'); + * ``` + * + * @param expr The expression to assert (as a boolean). + * @param msg A function that returns the message to report when throwing an + * error. We use a function for performance reasons. + */ + /** @doc {heading: 'Util', namespace: 'util'} */ + function assert(expr, msg) { + if (!expr) { + throw new Error(typeof msg === 'string' ? msg : msg()); + } + } + function assertShapesMatch(shapeA, shapeB, errorMessagePrefix) { + if (errorMessagePrefix === void 0) { errorMessagePrefix = ''; } + assert(arraysEqual(shapeA, shapeB), function () { return errorMessagePrefix + (" Shapes " + shapeA + " and " + shapeB + " must match"); }); + } + function assertNonNull(a) { + assert(a != null, function () { return "The input to the tensor constructor must be a non-null value."; }); + } + // NOTE: We explicitly type out what T extends instead of any so that + // util.flatten on a nested array of number doesn't try to infer T as a + // number[][], causing us to explicitly type util.flatten(). + /** + * Flattens an arbitrarily nested array. + * + * ```js + * const a = [[1, 2], [3, 4], [5, [6, [7]]]]; + * const flat = tf.util.flatten(a); + * console.log(flat); + * ``` + * + * @param arr The nested array to flatten. + * @param result The destination array which holds the elements. + * @param skipTypedArray If true, avoids flattening the typed arrays. Defaults + * to false. 
+ */ + /** @doc {heading: 'Util', namespace: 'util'} */ + function flatten(arr, result, skipTypedArray) { + if (result === void 0) { result = []; } + if (skipTypedArray === void 0) { skipTypedArray = false; } + if (result == null) { + result = []; + } + if (Array.isArray(arr) || isTypedArray(arr) && !skipTypedArray) { + for (var i = 0; i < arr.length; ++i) { + flatten(arr[i], result, skipTypedArray); + } + } + else { + result.push(arr); + } + return result; + } + /** + * Returns the size (number of elements) of the tensor given its shape. + * + * ```js + * const shape = [3, 4, 2]; + * const size = tf.util.sizeFromShape(shape); + * console.log(size); + * ``` + */ + /** @doc {heading: 'Util', namespace: 'util'} */ + function sizeFromShape(shape) { + if (shape.length === 0) { + // Scalar. + return 1; + } + var size = shape[0]; + for (var i = 1; i < shape.length; i++) { + size *= shape[i]; + } + return size; + } + function isScalarShape(shape) { + return shape.length === 0; + } + function arraysEqual(n1, n2) { + if (n1 === n2) { + return true; + } + if (n1 == null || n2 == null) { + return false; + } + if (n1.length !== n2.length) { + return false; + } + for (var i = 0; i < n1.length; i++) { + if (n1[i] !== n2[i]) { + return false; + } + } + return true; + } + function isInt(a) { + return a % 1 === 0; + } + function tanh(x) { + // tslint:disable-next-line:no-any + if (Math.tanh != null) { + // tslint:disable-next-line:no-any + return Math.tanh(x); + } + if (x === Infinity) { + return 1; + } + else if (x === -Infinity) { + return -1; + } + else { + var e2x = Math.exp(2 * x); + return (e2x - 1) / (e2x + 1); + } + } + function sizeToSquarishShape(size) { + var width = Math.ceil(Math.sqrt(size)); + return [width, Math.ceil(size / width)]; + } + /** + * Creates a new array with randomized indicies to a given quantity. 
+ * + * ```js + * const randomTen = tf.util.createShuffledIndices(10); + * console.log(randomTen); + * ``` + * + * @param number Quantity of how many shuffled indicies to create. + */ + /** @doc {heading: 'Util', namespace: 'util'} */ + function createShuffledIndices(n) { + var shuffledIndices = new Uint32Array(n); + for (var i = 0; i < n; ++i) { + shuffledIndices[i] = i; + } + shuffle(shuffledIndices); + return shuffledIndices; + } + function rightPad(a, size) { + if (size <= a.length) { + return a; + } + return a + ' '.repeat(size - a.length); + } + function repeatedTry(checkFn, delayFn, maxCounter) { + if (delayFn === void 0) { delayFn = function (counter) { return 0; }; } + return new Promise(function (resolve, reject) { + var tryCount = 0; + var tryFn = function () { + if (checkFn()) { + resolve(); + return; + } + tryCount++; + var nextBackoff = delayFn(tryCount); + if (maxCounter != null && tryCount >= maxCounter) { + reject(); + return; + } + setTimeout(tryFn, nextBackoff); + }; + tryFn(); + }); + } + /** + * Given the full size of the array and a shape that may contain -1 as the + * implicit dimension, returns the inferred shape where -1 is replaced. + * E.g. For shape=[2, -1, 3] and size=24, it will return [2, 4, 3]. + * + * @param shape The shape, which may contain -1 in some dimension. + * @param size The full size (number of elements) of the array. + * @return The inferred shape where -1 is replaced with the inferred size. + */ + function inferFromImplicitShape(shape, size) { + var shapeProd = 1; + var implicitIdx = -1; + for (var i = 0; i < shape.length; ++i) { + if (shape[i] >= 0) { + shapeProd *= shape[i]; + } + else if (shape[i] === -1) { + if (implicitIdx !== -1) { + throw Error("Shapes can only have 1 implicit size. " + + ("Found -1 at dim " + implicitIdx + " and dim " + i)); + } + implicitIdx = i; + } + else if (shape[i] < 0) { + throw Error("Shapes can not be < 0. 
Found " + shape[i] + " at dim " + i); + } + } + if (implicitIdx === -1) { + if (size > 0 && size !== shapeProd) { + throw Error("Size(" + size + ") must match the product of shape " + shape); + } + return shape; + } + if (shapeProd === 0) { + throw Error("Cannot infer the missing size in [" + shape + "] when " + + "there are 0 elements"); + } + if (size % shapeProd !== 0) { + throw Error("The implicit shape can't be a fractional number. " + + ("Got " + size + " / " + shapeProd)); + } + var newShape = shape.slice(); + newShape[implicitIdx] = size / shapeProd; + return newShape; + } + function parseAxisParam(axis, shape) { + var rank = shape.length; + // Normalize input + axis = axis == null ? shape.map(function (s, i) { return i; }) : [].concat(axis); + // Check for valid range + assert(axis.every(function (ax) { return ax >= -rank && ax < rank; }), function () { + return "All values in axis param must be in range [-" + rank + ", " + rank + ") but " + + ("got axis " + axis); + }); + // Check for only integers + assert(axis.every(function (ax) { return isInt(ax); }), function () { return "All values in axis param must be integers but " + + ("got axis " + axis); }); + // Handle negative axis. + return axis.map(function (a) { return a < 0 ? rank + a : a; }); + } + /** Reduces the shape by removing all dimensions of shape 1. */ + function squeezeShape(shape, axis) { + var newShape = []; + var keptDims = []; + var isEmptyArray = axis != null && Array.isArray(axis) && axis.length === 0; + var axes = (axis == null || isEmptyArray) ? 
+ null : + parseAxisParam(axis, shape).sort(); + var j = 0; + for (var i = 0; i < shape.length; ++i) { + if (axes != null) { + if (axes[j] === i && shape[i] !== 1) { + throw new Error("Can't squeeze axis " + i + " since its dim '" + shape[i] + "' is not 1"); + } + if ((axes[j] == null || axes[j] > i) && shape[i] === 1) { + newShape.push(shape[i]); + keptDims.push(i); + } + if (axes[j] <= i) { + j++; + } + } + if (shape[i] !== 1) { + newShape.push(shape[i]); + keptDims.push(i); + } + } + return { newShape: newShape, keptDims: keptDims }; + } + function getTypedArrayFromDType(dtype, size) { + var values = null; + if (dtype == null || dtype === 'float32') { + values = new Float32Array(size); + } + else if (dtype === 'int32') { + values = new Int32Array(size); + } + else if (dtype === 'bool') { + values = new Uint8Array(size); + } + else { + throw new Error("Unknown data type " + dtype); + } + return values; + } + function getArrayFromDType(dtype, size) { + var values = null; + if (dtype == null || dtype === 'float32') { + values = new Float32Array(size); + } + else if (dtype === 'int32') { + values = new Int32Array(size); + } + else if (dtype === 'bool') { + values = new Uint8Array(size); + } + else if (dtype === 'string') { + values = new Array(size); + } + else { + throw new Error("Unknown data type " + dtype); + } + return values; + } + function checkConversionForErrors(vals, dtype) { + for (var i = 0; i < vals.length; i++) { + var num = vals[i]; + if (isNaN(num) || !isFinite(num)) { + throw Error("A tensor of type " + dtype + " being uploaded contains " + num + "."); + } + } + } + /** Returns true if the dtype is valid. */ + function isValidDtype(dtype) { + return dtype === 'bool' || dtype === 'complex64' || dtype === 'float32' || + dtype === 'int32' || dtype === 'string'; + } + /** + * Returns true if the new type can't encode the old type without loss of + * precision. 
+ */ + function hasEncodingLoss(oldType, newType) { + if (newType === 'complex64') { + return false; + } + if (newType === 'float32' && oldType !== 'complex64') { + return false; + } + if (newType === 'int32' && oldType !== 'float32' && oldType !== 'complex64') { + return false; + } + if (newType === 'bool' && oldType === 'bool') { + return false; + } + return true; + } + function isTypedArray(a) { + return a instanceof Float32Array || a instanceof Int32Array || + a instanceof Uint8Array; + } + function bytesPerElement(dtype) { + if (dtype === 'float32' || dtype === 'int32') { + return 4; + } + else if (dtype === 'complex64') { + return 8; + } + else if (dtype === 'bool') { + return 1; + } + else { + throw new Error("Unknown dtype " + dtype); + } + } + /** + * Returns the approximate number of bytes allocated in the string array - 2 + * bytes per character. Computing the exact bytes for a native string in JS is + * not possible since it depends on the encoding of the html page that serves + * the website. + */ + function bytesFromStringArray(arr) { + if (arr == null) { + return 0; + } + var bytes = 0; + arr.forEach(function (x) { return bytes += x.length; }); + return bytes; + } + /** Returns true if the value is a string. 
*/ + function isString(value) { + return typeof value === 'string' || value instanceof String; + } + function isBoolean(value) { + return typeof value === 'boolean'; + } + function isNumber(value) { + return typeof value === 'number'; + } + function inferDtype(values) { + if (Array.isArray(values)) { + return inferDtype(values[0]); + } + if (values instanceof Float32Array) { + return 'float32'; + } + else if (values instanceof Int32Array || values instanceof Uint8Array) { + return 'int32'; + } + else if (isNumber(values)) { + return 'float32'; + } + else if (isString(values)) { + return 'string'; + } + else if (isBoolean(values)) { + return 'bool'; + } + return 'float32'; + } + function isFunction(f) { + return !!(f && f.constructor && f.call && f.apply); + } + function nearestDivisor(size, start) { + for (var i = start; i < size; ++i) { + if (size % i === 0) { + return i; + } + } + return size; + } + function computeStrides(shape) { + var rank = shape.length; + if (rank < 2) { + return []; + } + // Last dimension has implicit stride of 1, thus having D-1 (instead of D) + // strides. 
+ var strides = new Array(rank - 1); + strides[rank - 2] = shape[rank - 1]; + for (var i = rank - 3; i >= 0; --i) { + strides[i] = strides[i + 1] * shape[i + 1]; + } + return strides; + } + function toTypedArray(a, dtype, debugMode) { + if (dtype === 'string') { + throw new Error('Cannot convert a string[] to a TypedArray'); + } + if (Array.isArray(a)) { + a = flatten(a); + } + if (debugMode) { + checkConversionForErrors(a, dtype); + } + if (noConversionNeeded(a, dtype)) { + return a; + } + if (dtype == null || dtype === 'float32' || dtype === 'complex64') { + return new Float32Array(a); + } + else if (dtype === 'int32') { + return new Int32Array(a); + } + else if (dtype === 'bool') { + var bool = new Uint8Array(a.length); + for (var i = 0; i < bool.length; ++i) { + if (Math.round(a[i]) !== 0) { + bool[i] = 1; + } + } + return bool; + } + else { + throw new Error("Unknown data type " + dtype); + } + } + function createNestedArray(offset, shape, a) { + var ret = new Array(); + if (shape.length === 1) { + var d = shape[0]; + for (var i = 0; i < d; i++) { + ret[i] = a[offset + i]; + } + } + else { + var d = shape[0]; + var rest = shape.slice(1); + var len = rest.reduce(function (acc, c) { return acc * c; }); + for (var i = 0; i < d; i++) { + ret[i] = createNestedArray(offset + i * len, rest, a); + } + } + return ret; + } + // Provide a nested array of TypedArray in given shape. + function toNestedArray(shape, a) { + if (shape.length === 0) { + // Scalar type should return a single number. + return a[0]; + } + var size = shape.reduce(function (acc, c) { return acc * c; }); + if (size === 0) { + // A tensor with shape zero should be turned into empty list. 
+ return []; + } + if (size !== a.length) { + throw new Error("[" + shape + "] does not match the input size."); + } + return createNestedArray(0, shape, a); + } + function noConversionNeeded(a, dtype) { + return (a instanceof Float32Array && dtype === 'float32') || + (a instanceof Int32Array && dtype === 'int32') || + (a instanceof Uint8Array && dtype === 'bool'); + } + function makeOnesTypedArray(size, dtype) { + var array = makeZerosTypedArray(size, dtype); + for (var i = 0; i < array.length; i++) { + array[i] = 1; + } + return array; + } + function makeZerosTypedArray(size, dtype) { + if (dtype == null || dtype === 'float32' || dtype === 'complex64') { + return new Float32Array(size); + } + else if (dtype === 'int32') { + return new Int32Array(size); + } + else if (dtype === 'bool') { + return new Uint8Array(size); + } + else { + throw new Error("Unknown data type " + dtype); + } + } + /** + * Returns the current high-resolution time in milliseconds relative to an + * arbitrary time in the past. It works across different platforms (node.js, + * browsers). + * + * ```js + * console.log(tf.util.now()); + * ``` + */ + /** @doc {heading: 'Util', namespace: 'util'} */ + function now() { + return env().platform.now(); + } + function assertNonNegativeIntegerDimensions(shape) { + shape.forEach(function (dimSize) { + assert(Number.isInteger(dimSize) && dimSize >= 0, function () { + return "Tensor must have a shape comprised of positive integers but got " + + ("shape [" + shape + "]."); + }); + }); + } + /** + * Returns a platform-specific implementation of + * [`fetch`](https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API). + * + * If `fetch` is defined on the global object (`window`, `process`, etc.), + * `tf.util.fetch` returns that function. + * + * If not, `tf.util.fetch` returns a platform-specific solution. 
+ * + * ```js + * const resource = await tf.util.fetch('https://unpkg.com/@tensorflow/tfjs'); + * // handle response + * ``` + */ + /** @doc {heading: 'Util'} */ + function fetch$1(path, requestInits) { + return env().platform.fetch(path, requestInits); + } + /** + * Encodes the provided string into bytes using the provided encoding scheme. + * + * @param s The string to encode. + * @param encoding The encoding scheme. Defaults to utf-8. + * + */ + /** @doc {heading: 'Util'} */ + function encodeString(s, encoding) { + if (encoding === void 0) { encoding = 'utf-8'; } + encoding = encoding || 'utf-8'; + return env().platform.encode(s, encoding); + } + /** + * Decodes the provided bytes into a string using the provided encoding scheme. + * @param bytes The bytes to decode. + * + * @param encoding The encoding scheme. Defaults to utf-8. + */ + /** @doc {heading: 'Util'} */ + function decodeString(bytes, encoding) { + if (encoding === void 0) { encoding = 'utf-8'; } + encoding = encoding || 'utf-8'; + return env().platform.decode(bytes, encoding); + } + /** + * Computes flat index for a given location (multidimentionsal index) in a + * Tensor/multidimensional array. + * + * @param locs Location in the tensor. + * @param rank Rank of the tensor. + * @param strides Tensor strides. + */ + function locToIndex(locs, rank, strides) { + if (rank === 0) { + return 0; + } + else if (rank === 1) { + return locs[0]; + } + var index = locs[locs.length - 1]; + for (var i = 0; i < locs.length - 1; ++i) { + index += strides[i] * locs[i]; + } + return index; + } + /** + * Computes the location (multidimensional index) in a tensor/multidimentional + * array for a given flat index. + * + * @param index Index in flat array. + * @param rank Rank of tensor. + * @param strides Strides of tensor. 
+ */ + function indexToLoc(index, rank, strides) { + if (rank === 0) { + return []; + } + else if (rank === 1) { + return [index]; + } + var locs = new Array(rank); + for (var i = 0; i < locs.length - 1; ++i) { + locs[i] = Math.floor(index / strides[i]); + index -= locs[i] * strides[i]; + } + locs[locs.length - 1] = index; + return locs; + } + + var util = { + __proto__: null, + shuffle: shuffle, + clamp: clamp, + nearestLargerEven: nearestLargerEven, + sum: sum, + randUniform: randUniform, + distSquared: distSquared, + assert: assert, + assertShapesMatch: assertShapesMatch, + assertNonNull: assertNonNull, + flatten: flatten, + sizeFromShape: sizeFromShape, + isScalarShape: isScalarShape, + arraysEqual: arraysEqual, + isInt: isInt, + tanh: tanh, + sizeToSquarishShape: sizeToSquarishShape, + createShuffledIndices: createShuffledIndices, + rightPad: rightPad, + repeatedTry: repeatedTry, + inferFromImplicitShape: inferFromImplicitShape, + parseAxisParam: parseAxisParam, + squeezeShape: squeezeShape, + getTypedArrayFromDType: getTypedArrayFromDType, + getArrayFromDType: getArrayFromDType, + checkConversionForErrors: checkConversionForErrors, + isValidDtype: isValidDtype, + hasEncodingLoss: hasEncodingLoss, + isTypedArray: isTypedArray, + bytesPerElement: bytesPerElement, + bytesFromStringArray: bytesFromStringArray, + isString: isString, + isBoolean: isBoolean, + isNumber: isNumber, + inferDtype: inferDtype, + isFunction: isFunction, + nearestDivisor: nearestDivisor, + computeStrides: computeStrides, + toTypedArray: toTypedArray, + toNestedArray: toNestedArray, + makeOnesTypedArray: makeOnesTypedArray, + makeZerosTypedArray: makeZerosTypedArray, + now: now, + assertNonNegativeIntegerDimensions: assertNonNegativeIntegerDimensions, + fetch: fetch$1, + encodeString: encodeString, + decodeString: decodeString, + locToIndex: locToIndex, + indexToLoc: indexToLoc + }; + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. 
+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var Profiler = /** @class */ (function () { + function Profiler(backendTimer, logger) { + this.backendTimer = backendTimer; + this.logger = logger; + if (logger == null) { + this.logger = new Logger(); + } + } + Profiler.prototype.profileKernel = function (kernelName, inputs, f) { + var _this = this; + var outputs; + var holdResultWrapperFn = function () { + outputs = f(); + }; + var timer = this.backendTimer.time(holdResultWrapperFn); + outputs.forEach(function (r) { + // Dangling promise here because we don't want to propagate up + // asynchronicity. + r.data().then(function (vals) { + checkComputationForErrors(vals, r.dtype, kernelName); + timer.then(function (timing) { + var extraInfo = ''; + if (timing.getExtraProfileInfo != null) { + extraInfo = timing.getExtraProfileInfo(); + } + _this.logger.logKernelProfile(kernelName, r, vals, timing.kernelMs, inputs, extraInfo); + }); + }); + }); + return outputs; + }; + return Profiler; + }()); + function checkComputationForErrors(vals, dtype, kernelName) { + if (dtype !== 'float32') { + // Only floating point computations will generate NaN values + return false; + } + for (var i = 0; i < vals.length; i++) { + var num = vals[i]; + if (isNaN(num) || !isFinite(num)) { + // Throwing custom exception so behavior is testable. 
+ console.warn("Found " + num + " in the result of '" + kernelName + "'"); + return true; + } + } + return false; + } + var Logger = /** @class */ (function () { + function Logger() { + } + Logger.prototype.logKernelProfile = function (name, result, vals, timeMs, inputs, extraInfo) { + var time = typeof timeMs === 'number' ? rightPad(timeMs + "ms", 9) : + timeMs['error']; + var paddedName = rightPad(name, 25); + var rank = result.rank; + var size = result.size; + var shape = rightPad(result.shape.toString(), 14); + var inputShapesDescription = ''; + for (var name_1 in inputs) { + var input = inputs[name_1]; + // The input might be a non-tensor (e.g HTMLImageElement), in which case + // we claim the output shape as input shape. + var inputShape = input.shape || result.shape; + var inputRank = inputShape.length; + inputShapesDescription += + name_1 + ": " + inputRank + "D " + (inputRank > 0 ? inputShape : '') + " "; + } + console.log("%c" + paddedName + "\t%c" + time + "\t%c" + rank + "D " + shape + "\t%c" + size + "\t%c" + inputShapesDescription + "\t%c" + extraInfo, 'font-weight:bold', 'color:red', 'color:blue', 'color: orange', 'color: green', 'color: steelblue'); + }; + return Logger; + }()); + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + /** + * Computes a list of TapeNodes that connect x to y, filtering everything else + * out and preserving the order of the original tape elements. + * + * @param tape The tape elements to filter. + * @param xs The input Tensors. + * @param y The output Tensor. + */ + function getFilteredNodesXToY(tape, xs, y) { + // Forward pass to compute all the nodes and Tensors that are transitively a + // function of x. + var tensorsFromX = {}; + var nodesFromX = {}; + for (var i = 0; i < xs.length; i++) { + tensorsFromX[xs[i].id] = true; + } + for (var i = 0; i < tape.length; i++) { + var node = tape[i]; + var nodeInputs = node.inputs; + for (var inputName in nodeInputs) { + var input = nodeInputs[inputName]; + var anyInputFromX = false; + for (var j = 0; j < xs.length; j++) { + if (tensorsFromX[input.id]) { + node.outputs.forEach(function (output) { return tensorsFromX[output.id] = true; }); + anyInputFromX = true; + nodesFromX[node.id] = true; + break; + } + } + if (anyInputFromX) { + break; + } + } + } + // Backward pass to find all of the nodes and Tensors that lead to y. + var tensorsLeadToY = {}; + tensorsLeadToY[y.id] = true; + var nodesToY = {}; + for (var i = tape.length - 1; i >= 0; i--) { + var node = tape[i]; + var nodeInputs = node.inputs; + // If any of the outputs lead to y, mark all of the inputs as leading to y. + for (var j = 0; j < node.outputs.length; j++) { + if (tensorsLeadToY[node.outputs[j].id]) { + for (var inputName in nodeInputs) { + tensorsLeadToY[nodeInputs[inputName].id] = true; + nodesToY[node.id] = true; + } + break; + } + } + } + // Return the paths that come from x and lead to y. + var filteredTape = []; + for (var i = 0; i < tape.length; i++) { + var node = tape[i]; + if (nodesFromX[node.id] && nodesToY[node.id]) { + // Prune the inputs from the node that aren't a function of x. 
+ var prunedInputs = {}; + for (var inputName in node.inputs) { + var nodeInput = node.inputs[inputName]; + if (tensorsFromX[nodeInput.id]) { + prunedInputs[inputName] = nodeInput; + } + } + // Copy the node and overwrite inputsAndArgs to the pruned version. + var prunedNode = Object.assign({}, node); + prunedNode.inputs = prunedInputs; + prunedNode.outputs = node.outputs; + filteredTape.push(prunedNode); + } + } + return filteredTape; + } + /** + * Backpropagate gradients through the filtered TapeNodes. + * + * @param tensorAccumulatedGradientMap A map of Tensor to its gradient. This map + * is mutated by this method. + * @param filteredTape The filtered TapeNodes to backprop through. + */ + function backpropagateGradients(tensorAccumulatedGradientMap, filteredTape, tidy) { + var _loop_1 = function (i) { + var node = filteredTape[i]; + var dys = []; + node.outputs.forEach(function (o) { + var gradTensor = tensorAccumulatedGradientMap[o.id]; + if (gradTensor != null) { + dys.push(gradTensor); + } + else { + // This particular output is not in the back-propagation subgraph, so it + // does not affect the final output, thus we put null for its dy. + dys.push(null); + } + }); + if (node.gradient == null) { + throw new Error("Cannot compute gradient: gradient function not found " + + ("for " + node.kernelName + ".")); + } + // Backprop dy through this node and accumulate gradients over the inputs. + var inputGradients = node.gradient(dys); + var _loop_2 = function (inputName) { + if (!(inputName in inputGradients)) { + throw new Error("Cannot backprop through input " + inputName + ". " + + ("Available gradients found: " + Object.keys(inputGradients) + ".")); + } + // Call the gradient function. + var dx = tidy(function () { return inputGradients[inputName](); }); + if (dx.dtype !== 'float32') { + throw new Error("Error in gradient for op " + node.kernelName + ". 
The gradient of input " + + (inputName + " must have 'float32' dtype, but has '" + dx.dtype + "'")); + } + var x = node.inputs[inputName]; + if (!arraysEqual(dx.shape, x.shape)) { + throw new Error("Error in gradient for op " + node.kernelName + ". The gradient of input " + + ("'" + inputName + "' has shape '" + dx.shape + "', which does not match ") + + ("the shape of the input '" + x.shape + "'")); + } + if (tensorAccumulatedGradientMap[x.id] == null) { + tensorAccumulatedGradientMap[x.id] = dx; + } + else { + var curGradient = tensorAccumulatedGradientMap[x.id]; + tensorAccumulatedGradientMap[x.id] = curGradient.add(dx); + curGradient.dispose(); + } + }; + for (var inputName in node.inputs) { + _loop_2(inputName); + } + }; + // Walk the tape backward and keep a map of Tensor to its gradient. + for (var i = filteredTape.length - 1; i >= 0; i--) { + _loop_1(i); + } + } + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + // Maximum number of values before we decide to show ellipsis. + var FORMAT_LIMIT_NUM_VALS = 20; + // Number of first and last values to show when displaying a, b,...,y, z. + var FORMAT_NUM_FIRST_LAST_VALS = 3; + // Number of significant digits to show. 
+ var FORMAT_NUM_SIG_DIGITS = 7; + function tensorToString(vals, shape, dtype, verbose) { + var strides = computeStrides(shape); + var padPerCol = computeMaxSizePerColumn(vals, shape, dtype, strides); + var rank = shape.length; + var valsLines = subTensorToString(vals, shape, dtype, strides, padPerCol); + var lines = ['Tensor']; + if (verbose) { + lines.push(" dtype: " + dtype); + lines.push(" rank: " + rank); + lines.push(" shape: [" + shape + "]"); + lines.push(" values:"); + } + lines.push(valsLines.map(function (l) { return ' ' + l; }).join('\n')); + return lines.join('\n'); + } + function computeMaxSizePerColumn(vals, shape, dtype, strides) { + var n = sizeFromShape(shape); + var numCols = strides[strides.length - 1]; + var padPerCol = new Array(numCols).fill(0); + var rank = shape.length; + var valuesOrTuples = dtype === 'complex64' ? createComplexTuples(vals) : vals; + if (rank > 1) { + for (var row = 0; row < n / numCols; row++) { + var offset = row * numCols; + for (var j = 0; j < numCols; j++) { + padPerCol[j] = Math.max(padPerCol[j], valToString(valuesOrTuples[offset + j], 0, dtype).length); + } + } + } + return padPerCol; + } + function valToString(val, pad, dtype) { + var valStr; + if (Array.isArray(val)) { + valStr = parseFloat(val[0].toFixed(FORMAT_NUM_SIG_DIGITS)) + " + " + + (parseFloat(val[1].toFixed(FORMAT_NUM_SIG_DIGITS)) + "j"); + } + else if (isString(val)) { + valStr = "'" + val + "'"; + } + else if (dtype === 'bool') { + valStr = boolNumToString(val); + } + else { + valStr = parseFloat(val.toFixed(FORMAT_NUM_SIG_DIGITS)).toString(); + } + return rightPad(valStr, pad); + } + function boolNumToString(v) { + return v === 0 ? 'false' : 'true'; + } + function subTensorToString(vals, shape, dtype, strides, padPerCol, isLast) { + if (isLast === void 0) { isLast = true; } + var storagePerElement = dtype === 'complex64' ? 
2 : 1; + var size = shape[0]; + var rank = shape.length; + if (rank === 0) { + if (dtype === 'complex64') { + var complexTuple = createComplexTuples(vals); + return [valToString(complexTuple[0], 0, dtype)]; + } + if (dtype === 'bool') { + return [boolNumToString(vals[0])]; + } + return [vals[0].toString()]; + } + if (rank === 1) { + if (size > FORMAT_LIMIT_NUM_VALS) { + var firstValsSize = FORMAT_NUM_FIRST_LAST_VALS * storagePerElement; + var firstVals = Array.from(vals.slice(0, firstValsSize)); + var lastVals = Array.from(vals.slice((size - FORMAT_NUM_FIRST_LAST_VALS) * storagePerElement, size * storagePerElement)); + if (dtype === 'complex64') { + firstVals = createComplexTuples(firstVals); + lastVals = createComplexTuples(lastVals); + } + return [ + '[' + + firstVals.map(function (x, i) { return valToString(x, padPerCol[i], dtype); }) + .join(', ') + + ', ..., ' + + lastVals + .map(function (x, i) { return valToString(x, padPerCol[size - FORMAT_NUM_FIRST_LAST_VALS + i], dtype); }) + .join(', ') + + ']' + ]; + } + var displayVals = dtype === 'complex64' ? createComplexTuples(vals) : + Array.from(vals); + return [ + '[' + + displayVals.map(function (x, i) { return valToString(x, padPerCol[i], dtype); }) + .join(', ') + + ']' + ]; + } + // The array is rank 2 or more. 
+ var subshape = shape.slice(1); + var substrides = strides.slice(1); + var stride = strides[0] * storagePerElement; + var lines = []; + if (size > FORMAT_LIMIT_NUM_VALS) { + for (var i = 0; i < FORMAT_NUM_FIRST_LAST_VALS; i++) { + var start = i * stride; + var end = start + stride; + lines.push.apply(lines, subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, false /* isLast */)); + } + lines.push('...'); + for (var i = size - FORMAT_NUM_FIRST_LAST_VALS; i < size; i++) { + var start = i * stride; + var end = start + stride; + lines.push.apply(lines, subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, i === size - 1 /* isLast */)); + } + } + else { + for (var i = 0; i < size; i++) { + var start = i * stride; + var end = start + stride; + lines.push.apply(lines, subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, i === size - 1 /* isLast */)); + } + } + var sep = rank === 2 ? ',' : ''; + lines[0] = '[' + lines[0] + sep; + for (var i = 1; i < lines.length - 1; i++) { + lines[i] = ' ' + lines[i] + sep; + } + var newLineSep = ',\n'; + for (var i = 2; i < rank; i++) { + newLineSep += '\n'; + } + lines[lines.length - 1] = + ' ' + lines[lines.length - 1] + ']' + (isLast ? '' : newLineSep); + return lines; + } + function createComplexTuples(vals) { + var complexTuples = []; + for (var i = 0; i < vals.length; i += 2) { + complexTuples.push([vals[i], vals[i + 1]]); + } + return complexTuples; + } + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * A mutable object, similar to `tf.Tensor`, that allows users to set values + * at locations before converting to an immutable `tf.Tensor`. + * + * See `tf.buffer` for creating a tensor buffer. + */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + var TensorBuffer = /** @class */ (function () { + function TensorBuffer(shape, dtype, values) { + var _this = this; + this.dtype = dtype; + this.shape = shape.slice(); + this.size = sizeFromShape(shape); + if (values != null) { + var n_1 = values.length; + assert(n_1 === this.size, function () { return "Length of values '" + n_1 + "' does not match the size " + + ("inferred by the shape '" + _this.size + "'."); }); + } + if (dtype === 'complex64') { + throw new Error("complex64 dtype TensorBuffers are not supported. Please create " + + "a TensorBuffer for the real and imaginary parts separately and " + + "call tf.complex(real, imag)."); + } + this.values = values || getArrayFromDType(dtype, this.size); + this.strides = computeStrides(shape); + } + /** + * Sets a value in the buffer at a given location. + * + * @param value The value to set. + * @param locs The location indices. 
+ */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + TensorBuffer.prototype.set = function (value) { + var _this = this; + var locs = []; + for (var _i = 1; _i < arguments.length; _i++) { + locs[_i - 1] = arguments[_i]; + } + if (locs.length === 0) { + locs = [0]; + } + assert(locs.length === this.rank, function () { return "The number of provided coordinates (" + locs.length + ") must " + + ("match the rank (" + _this.rank + ")"); }); + var index = this.locToIndex(locs); + this.values[index] = value; + }; + /** + * Returns the value in the buffer at the provided location. + * + * @param locs The location indices. + */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + TensorBuffer.prototype.get = function () { + var locs = []; + for (var _i = 0; _i < arguments.length; _i++) { + locs[_i] = arguments[_i]; + } + if (locs.length === 0) { + locs = [0]; + } + var i = 0; + for (var _a = 0, locs_1 = locs; _a < locs_1.length; _a++) { + var loc = locs_1[_a]; + if (loc < 0 || loc >= this.shape[i]) { + var msg = "Requested out of range element at " + locs + ". 
" + + (" Buffer shape=" + this.shape); + throw new Error(msg); + } + i++; + } + var index = locs[locs.length - 1]; + for (var i_1 = 0; i_1 < locs.length - 1; ++i_1) { + index += this.strides[i_1] * locs[i_1]; + } + return this.values[index]; + }; + TensorBuffer.prototype.locToIndex = function (locs) { + if (this.rank === 0) { + return 0; + } + else if (this.rank === 1) { + return locs[0]; + } + var index = locs[locs.length - 1]; + for (var i = 0; i < locs.length - 1; ++i) { + index += this.strides[i] * locs[i]; + } + return index; + }; + TensorBuffer.prototype.indexToLoc = function (index) { + if (this.rank === 0) { + return []; + } + else if (this.rank === 1) { + return [index]; + } + var locs = new Array(this.shape.length); + for (var i = 0; i < locs.length - 1; ++i) { + locs[i] = Math.floor(index / this.strides[i]); + index -= locs[i] * this.strides[i]; + } + locs[locs.length - 1] = index; + return locs; + }; + Object.defineProperty(TensorBuffer.prototype, "rank", { + get: function () { + return this.shape.length; + }, + enumerable: true, + configurable: true + }); + /** + * Creates an immutable `tf.Tensor` object from the buffer. + */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + TensorBuffer.prototype.toTensor = function () { + return trackerFn().makeTensor(this.values, this.shape, this.dtype); + }; + return TensorBuffer; + }()); + // For tracking tensor creation and disposal. + var trackerFn = null; + // Used by chaining methods to call into ops. + var opHandler = null; + /** + * An external consumer can register itself as the tensor tracker. This way + * the Tensor class can notify the tracker for every tensor created and + * disposed. + */ + function setTensorTracker(fn) { + trackerFn = fn; + } + /** + * An external consumer can register itself as the op handler. This way the + * Tensor class can have chaining methods that call into ops via the op + * handler. 
+ */ + function setOpHandler(handler) { + opHandler = handler; + } + /** + * A `tf.Tensor` object represents an immutable, multidimensional array of + * numbers that has a shape and a data type. + * + * See `tf.tensor` for details on how to create a `tf.Tensor`. + */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + var Tensor = /** @class */ (function () { + function Tensor(shape, dtype, dataId, id) { + /** Whether this tensor has been globally kept. */ + this.kept = false; + this.isDisposedInternal = false; + this.shape = shape.slice(); + this.dtype = dtype || 'float32'; + this.size = sizeFromShape(shape); + this.strides = computeStrides(shape); + this.dataId = dataId; + this.id = id; + this.rankType = (this.rank < 5 ? this.rank.toString() : 'higher'); + } + /** Flatten a Tensor to a 1D array. */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.flatten = function () { + this.throwIfDisposed(); + return this.as1D(); + }; + /** Converts a size-1 `tf.Tensor` to a `tf.Scalar`. */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.asScalar = function () { + this.throwIfDisposed(); + assert(this.size === 1, function () { return 'The array must have only 1 element.'; }); + return this.reshape([]); + }; + /** Converts a `tf.Tensor` to a `tf.Tensor1D`. */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.as1D = function () { + this.throwIfDisposed(); + return this.reshape([this.size]); + }; + /** + * Converts a `tf.Tensor` to a `tf.Tensor2D`. + * + * @param rows Number of rows in `tf.Tensor2D`. + * @param columns Number of columns in `tf.Tensor2D`. + */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.as2D = function (rows, columns) { + this.throwIfDisposed(); + return this.reshape([rows, columns]); + }; + /** + * Converts a `tf.Tensor` to a `tf.Tensor3D`. + * + * @param rows Number of rows in `tf.Tensor3D`. 
+ * @param columns Number of columns in `tf.Tensor3D`. + * @param depth Depth of `tf.Tensor3D`. + */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.as3D = function (rows, columns, depth) { + this.throwIfDisposed(); + return this.reshape([rows, columns, depth]); + }; + /** + * Converts a `tf.Tensor` to a `tf.Tensor4D`. + * + * @param rows Number of rows in `tf.Tensor4D`. + * @param columns Number of columns in `tf.Tensor4D`. + * @param depth Depth of `tf.Tensor4D`. + * @param depth2 4th dimension of `tf.Tensor4D`. + */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.as4D = function (rows, columns, depth, depth2) { + this.throwIfDisposed(); + return this.reshape([rows, columns, depth, depth2]); + }; + /** + * Converts a `tf.Tensor` to a `tf.Tensor5D`. + * + * @param rows Number of rows in `tf.Tensor5D`. + * @param columns Number of columns in `tf.Tensor5D`. + * @param depth Depth of `tf.Tensor5D`. + * @param depth2 4th dimension of `tf.Tensor5D`. + * @param depth3 5th dimension of 'tf.Tensor5D' + */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.as5D = function (rows, columns, depth, depth2, depth3) { + this.throwIfDisposed(); + return this.reshape([rows, columns, depth, depth2, depth3]); + }; + /** + * Casts a `tf.Tensor` to a specified dtype. + * + * @param dtype Data-type to cast the tensor to. + */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.asType = function (dtype) { + this.throwIfDisposed(); + return opHandler.cast(this, dtype); + }; + Object.defineProperty(Tensor.prototype, "rank", { + get: function () { + return this.shape.length; + }, + enumerable: true, + configurable: true + }); + /** + * Returns a promise of `tf.TensorBuffer` that holds the underlying data. 
+ */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.buffer = function () { + return __awaiter(this, void 0, void 0, function () { + var vals; + return __generator(this, function (_a) { + switch (_a.label) { + case 0: return [4 /*yield*/, this.data()]; + case 1: + vals = _a.sent(); + return [2 /*return*/, opHandler.buffer(this.shape, this.dtype, vals)]; + } + }); + }); + }; + /** Returns a `tf.TensorBuffer` that holds the underlying data. */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.bufferSync = function () { + return opHandler.buffer(this.shape, this.dtype, this.dataSync()); + }; + /** + * Returns the tensor data as a nested array. The transfer of data is done + * asynchronously. + */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.array = function () { + return __awaiter(this, void 0, void 0, function () { + var vals; + return __generator(this, function (_a) { + switch (_a.label) { + case 0: return [4 /*yield*/, this.data()]; + case 1: + vals = _a.sent(); + return [2 /*return*/, toNestedArray(this.shape, vals)]; + } + }); + }); + }; + /** + * Returns the tensor data as a nested array. The transfer of data is done + * synchronously. + */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.arraySync = function () { + return toNestedArray(this.shape, this.dataSync()); + }; + /** + * Asynchronously downloads the values from the `tf.Tensor`. Returns a + * promise of `TypedArray` that resolves when the computation has finished. 
+ */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.data = function () { + return __awaiter(this, void 0, void 0, function () { + var data, bytes; + return __generator(this, function (_a) { + switch (_a.label) { + case 0: + this.throwIfDisposed(); + data = trackerFn().read(this.dataId); + if (!(this.dtype === 'string')) return [3 /*break*/, 2]; + return [4 /*yield*/, data]; + case 1: + bytes = _a.sent(); + try { + return [2 /*return*/, bytes.map(function (b) { return decodeString(b); })]; + } + catch (_b) { + throw new Error('Failed to decode the string bytes into utf-8. ' + + 'To get the original bytes, call tensor.bytes().'); + } + _a.label = 2; + case 2: return [2 /*return*/, data]; + } + }); + }); + }; + /** + * Synchronously downloads the values from the `tf.Tensor`. This blocks the + * UI thread until the values are ready, which can cause performance issues. + */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.dataSync = function () { + this.throwIfDisposed(); + var data = trackerFn().readSync(this.dataId); + if (this.dtype === 'string') { + try { + return data.map(function (b) { return decodeString(b); }); + } + catch (_a) { + throw new Error('Failed to decode the string bytes into utf-8. ' + + 'To get the original bytes, call tensor.bytes().'); + } + } + return data; + }; + /** Returns the underlying bytes of the tensor's data. */ + Tensor.prototype.bytes = function () { + return __awaiter(this, void 0, void 0, function () { + var data; + return __generator(this, function (_a) { + switch (_a.label) { + case 0: + this.throwIfDisposed(); + return [4 /*yield*/, trackerFn().read(this.dataId)]; + case 1: + data = _a.sent(); + if (this.dtype === 'string') { + return [2 /*return*/, data]; + } + else { + return [2 /*return*/, new Uint8Array(data.buffer)]; + } + } + }); + }); + }; + /** + * Disposes `tf.Tensor` from memory. 
+ */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.dispose = function () { + if (this.isDisposed) { + return; + } + trackerFn().disposeTensor(this); + this.isDisposedInternal = true; + }; + Object.defineProperty(Tensor.prototype, "isDisposed", { + get: function () { + return this.isDisposedInternal; + }, + enumerable: true, + configurable: true + }); + Tensor.prototype.throwIfDisposed = function () { + if (this.isDisposed) { + throw new Error("Tensor is disposed."); + } + }; + /** Casts the array to type `float32` */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.toFloat = function () { + return this.asType('float32'); + }; + /** Casts the array to type `int32` */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.toInt = function () { + return this.asType('int32'); + }; + /** Casts the array to type `bool` */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.toBool = function () { + return this.asType('bool'); + }; + /** + * Prints the `tf.Tensor`. See `tf.print` for details. + * + * @param verbose Whether to print verbose information about the tensor, + * including dtype and size. + */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.print = function (verbose) { + if (verbose === void 0) { verbose = false; } + return opHandler.print(this, verbose); + }; + /** + * Reshapes the tensor into the provided shape. + * See `tf.reshape` for more details. + * + * @param newShape An array of integers defining the output tensor shape. + */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.reshape = function (newShape) { + this.throwIfDisposed(); + return opHandler.reshape(this, newShape); + }; + /** + * Reshapes the tensor into the shape of the provided tensor. + * + * @param x The tensor of required shape. 
+ */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.reshapeAs = function (x) { + this.throwIfDisposed(); + return this.reshape(x.shape); + }; + /** + * Returns a `tf.Tensor` that has expanded rank, by inserting a dimension + * into the tensor's shape. See `tf.expandDims` for details. + * + * @param axis The dimension index at which to insert shape of 1. Defaults to + * 0 (the first dimension). + */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.expandDims = function (axis) { + if (axis === void 0) { axis = 0; } + return opHandler.expandDims(this, axis); + }; + /** + * Returns a `tf.Tensor` with dimensions of size 1 removed from the shape. + * See `tf.squeeze` for more details. + * + * @param axis A list of numbers. If specified, only squeezes the + * dimensions listed. The dimension index starts at 0. It is an error to + * squeeze a dimension that is not 1. + */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.squeeze = function (axis) { + this.throwIfDisposed(); + return opHandler.squeeze(this, axis); + }; + /** Returns a copy of the tensor. See `tf.clone` for details. */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.clone = function () { + this.throwIfDisposed(); + return opHandler.clone(this); + }; + /** + * Returns a human-readable description of the tensor. Useful for logging. + */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.toString = function (verbose) { + if (verbose === void 0) { verbose = false; } + var vals = this.dataSync(); + return tensorToString(vals, this.shape, this.dtype, verbose); + }; + // Below is chain API that is not exposed to docs to avoid repetition. To + // expose a method, move it above this comment and add @doc and jsdoc. 
+ Tensor.prototype.gather = function (indices, axis) { + if (axis === void 0) { axis = 0; } + this.throwIfDisposed(); + return opHandler.gather(this, indices, axis); + }; + Tensor.prototype.norm = function (ord, axis, keepDims) { + if (ord === void 0) { ord = 'euclidean'; } + if (axis === void 0) { axis = null; } + if (keepDims === void 0) { keepDims = false; } + this.throwIfDisposed(); + return opHandler.norm(this, ord, axis, keepDims); + }; + Tensor.prototype.slice = function (begin, size) { + this.throwIfDisposed(); + return opHandler.slice(this, begin, size); + }; + Tensor.prototype.reverse = function (axis) { + this.throwIfDisposed(); + return opHandler.reverse(this, axis); + }; + Tensor.prototype.stack = function (x, axis) { + if (axis === void 0) { axis = 0; } + return opHandler.stack([this, x], axis); + }; + Tensor.prototype.unstack = function (axis) { + if (axis === void 0) { axis = 0; } + return opHandler.unstack(this, axis); + }; + // Reduction ops. + Tensor.prototype.all = function (axis, keepDims) { + if (axis === void 0) { axis = null; } + if (keepDims === void 0) { keepDims = false; } + this.throwIfDisposed(); + return opHandler.all(this, axis, keepDims); + }; + Tensor.prototype.any = function (axis, keepDims) { + if (axis === void 0) { axis = null; } + if (keepDims === void 0) { keepDims = false; } + this.throwIfDisposed(); + return opHandler.any(this, axis, keepDims); + }; + Tensor.prototype.logSumExp = function (axis, keepDims) { + if (axis === void 0) { axis = null; } + if (keepDims === void 0) { keepDims = false; } + this.throwIfDisposed(); + return opHandler.logSumExp(this, axis, keepDims); + }; + Tensor.prototype.sum = function (axis, keepDims) { + if (axis === void 0) { axis = null; } + if (keepDims === void 0) { keepDims = false; } + this.throwIfDisposed(); + return opHandler.sum(this, axis, keepDims); + }; + Tensor.prototype.prod = function (axis, keepDims) { + if (axis === void 0) { axis = null; } + if (keepDims === void 0) { keepDims = 
false; } + this.throwIfDisposed(); + return opHandler.prod(this, axis, keepDims); + }; + Tensor.prototype.mean = function (axis, keepDims) { + if (axis === void 0) { axis = null; } + if (keepDims === void 0) { keepDims = false; } + this.throwIfDisposed(); + return opHandler.mean(this, axis, keepDims); + }; + Tensor.prototype.min = function (axis, keepDims) { + if (axis === void 0) { axis = null; } + if (keepDims === void 0) { keepDims = false; } + this.throwIfDisposed(); + return opHandler.min(this, axis, keepDims); + }; + Tensor.prototype.argMin = function (axis) { + if (axis === void 0) { axis = null; } + this.throwIfDisposed(); + return opHandler.argMin(this, axis); + }; + Tensor.prototype.argMax = function (axis) { + if (axis === void 0) { axis = null; } + this.throwIfDisposed(); + return opHandler.argMax(this, axis); + }; + // Transformations + Tensor.prototype.cast = function (dtype) { + this.throwIfDisposed(); + return opHandler.cast(this, dtype); + }; + // Binary ops. + /** + * @deprecated strict variants of ops have been deprecated + */ + Tensor.prototype.addStrict = function (x) { + this.throwIfDisposed(); + return opHandler.addStrict(this, x); + }; + Tensor.prototype.atan2 = function (x) { + this.throwIfDisposed(); + return opHandler.atan2(this, x); + }; + /** + * @deprecated strict variants of ops have been deprecated + */ + Tensor.prototype.subStrict = function (x) { + this.throwIfDisposed(); + return opHandler.subStrict(this, x); + }; + Tensor.prototype.pow = function (exp) { + this.throwIfDisposed(); + return opHandler.pow(this, exp); + }; + /** + * @deprecated strict variants of ops have been deprecated + */ + Tensor.prototype.powStrict = function (exp) { + this.throwIfDisposed(); + return opHandler.powStrict(this, exp); + }; + Tensor.prototype.mul = function (x) { + this.throwIfDisposed(); + return opHandler.mul(this, x); + }; + /** + * @deprecated strict variants of ops have been deprecated + */ + Tensor.prototype.mulStrict = function (x) { + 
this.throwIfDisposed(); + return opHandler.mulStrict(this, x); + }; + Tensor.prototype.floorDiv = function (x) { + this.throwIfDisposed(); + return opHandler.floorDiv(this, x); + }; + /** + * @deprecated strict variants of ops have been deprecated + */ + Tensor.prototype.divStrict = function (x) { + this.throwIfDisposed(); + return opHandler.divStrict(this, x); + }; + Tensor.prototype.minimum = function (x) { + this.throwIfDisposed(); + return opHandler.minimum(this, x); + }; + /** + * @deprecated strict variants of ops have been deprecated + */ + Tensor.prototype.minimumStrict = function (x) { + this.throwIfDisposed(); + return opHandler.minimumStrict(this, x); + }; + Tensor.prototype.maximum = function (x) { + this.throwIfDisposed(); + return opHandler.maximum(this, x); + }; + /** + * @deprecated strict variants of ops have been deprecated + */ + Tensor.prototype.maximumStrict = function (x) { + this.throwIfDisposed(); + return opHandler.maximumStrict(this, x); + }; + Tensor.prototype.mod = function (x) { + this.throwIfDisposed(); + return opHandler.mod(this, x); + }; + /** + * @deprecated strict variants of ops have been deprecated + */ + Tensor.prototype.modStrict = function (x) { + this.throwIfDisposed(); + return opHandler.modStrict(this, x); + }; + /** + * @deprecated strict variants of ops have been deprecated + */ + Tensor.prototype.squaredDifferenceStrict = function (x) { + this.throwIfDisposed(); + return opHandler.squaredDifferenceStrict(this, x); + }; + // Compare ops. 
+ /** + * @deprecated strict variants of ops have been deprecated + */ + Tensor.prototype.notEqualStrict = function (x) { + this.throwIfDisposed(); + return opHandler.notEqualStrict(this, x); + }; + /** + * @deprecated strict variants of ops have been deprecated + */ + Tensor.prototype.lessStrict = function (x) { + this.throwIfDisposed(); + return opHandler.lessStrict(this, x); + }; + /** + * @deprecated strict variants of ops have been deprecated + */ + Tensor.prototype.equalStrict = function (x) { + this.throwIfDisposed(); + return opHandler.equalStrict(this, x); + }; + /** + * @deprecated strict variants of ops have been deprecated + */ + Tensor.prototype.lessEqualStrict = function (x) { + this.throwIfDisposed(); + return opHandler.lessEqualStrict(this, x); + }; + /** + * @deprecated strict variants of ops have been deprecated + */ + Tensor.prototype.greaterStrict = function (x) { + this.throwIfDisposed(); + return opHandler.greaterStrict(this, x); + }; + /** + * @deprecated strict variants of ops have been deprecated + */ + Tensor.prototype.greaterEqualStrict = function (x) { + this.throwIfDisposed(); + return opHandler.greaterEqualStrict(this, x); + }; + // Compare ops. + Tensor.prototype.logicalAnd = function (x) { + this.throwIfDisposed(); + return opHandler.logicalAnd(this, x); + }; + Tensor.prototype.logicalOr = function (x) { + this.throwIfDisposed(); + return opHandler.logicalOr(this, x); + }; + Tensor.prototype.logicalNot = function () { + this.throwIfDisposed(); + return opHandler.logicalNot(this); + }; + Tensor.prototype.logicalXor = function (x) { + this.throwIfDisposed(); + return opHandler.logicalXor(this, x); + }; + Tensor.prototype.where = function (condition, x) { + this.throwIfDisposed(); + return opHandler.where(condition, this, x); + }; + // Unary ops. 
+ Tensor.prototype.neg = function () { + this.throwIfDisposed(); + return opHandler.neg(this); + }; + Tensor.prototype.ceil = function () { + this.throwIfDisposed(); + return opHandler.ceil(this); + }; + Tensor.prototype.floor = function () { + this.throwIfDisposed(); + return opHandler.floor(this); + }; + Tensor.prototype.sign = function () { + this.throwIfDisposed(); + return opHandler.sign(this); + }; + Tensor.prototype.isNaN = function () { + this.throwIfDisposed(); + return opHandler.isNaN(this); + }; + Tensor.prototype.isInf = function () { + this.throwIfDisposed(); + return opHandler.isInf(this); + }; + Tensor.prototype.isFinite = function () { + this.throwIfDisposed(); + return opHandler.isFinite(this); + }; + Tensor.prototype.exp = function () { + this.throwIfDisposed(); + return opHandler.exp(this); + }; + Tensor.prototype.expm1 = function () { + this.throwIfDisposed(); + return opHandler.expm1(this); + }; + Tensor.prototype.log = function () { + this.throwIfDisposed(); + return opHandler.log(this); + }; + Tensor.prototype.log1p = function () { + this.throwIfDisposed(); + return opHandler.log1p(this); + }; + Tensor.prototype.sqrt = function () { + this.throwIfDisposed(); + return opHandler.sqrt(this); + }; + Tensor.prototype.rsqrt = function () { + this.throwIfDisposed(); + return opHandler.rsqrt(this); + }; + Tensor.prototype.square = function () { + this.throwIfDisposed(); + return opHandler.square(this); + }; + Tensor.prototype.reciprocal = function () { + this.throwIfDisposed(); + return opHandler.reciprocal(this); + }; + Tensor.prototype.abs = function () { + this.throwIfDisposed(); + return opHandler.abs(this); + }; + Tensor.prototype.clipByValue = function (min, max) { + this.throwIfDisposed(); + return opHandler.clipByValue(this, min, max); + }; + Tensor.prototype.relu = function () { + this.throwIfDisposed(); + return opHandler.relu(this); + }; + Tensor.prototype.relu6 = function () { + this.throwIfDisposed(); + return opHandler.relu6(this); + }; 
+ Tensor.prototype.elu = function () { + this.throwIfDisposed(); + return opHandler.elu(this); + }; + Tensor.prototype.selu = function () { + this.throwIfDisposed(); + return opHandler.selu(this); + }; + Tensor.prototype.leakyRelu = function (alpha) { + if (alpha === void 0) { alpha = 0.2; } + this.throwIfDisposed(); + return opHandler.leakyRelu(this, alpha); + }; + Tensor.prototype.prelu = function (alpha) { + this.throwIfDisposed(); + return opHandler.prelu(this, alpha); + }; + Tensor.prototype.sigmoid = function () { + this.throwIfDisposed(); + return opHandler.sigmoid(this); + }; + Tensor.prototype.logSigmoid = function () { + this.throwIfDisposed(); + return opHandler.logSigmoid(this); + }; + Tensor.prototype.softplus = function () { + this.throwIfDisposed(); + return opHandler.softplus(this); + }; + Tensor.prototype.zerosLike = function () { + this.throwIfDisposed(); + return opHandler.zerosLike(this); + }; + Tensor.prototype.onesLike = function () { + this.throwIfDisposed(); + return opHandler.onesLike(this); + }; + Tensor.prototype.sin = function () { + this.throwIfDisposed(); + return opHandler.sin(this); + }; + Tensor.prototype.cos = function () { + this.throwIfDisposed(); + return opHandler.cos(this); + }; + Tensor.prototype.tan = function () { + this.throwIfDisposed(); + return opHandler.tan(this); + }; + Tensor.prototype.asin = function () { + this.throwIfDisposed(); + return opHandler.asin(this); + }; + Tensor.prototype.acos = function () { + this.throwIfDisposed(); + return opHandler.acos(this); + }; + Tensor.prototype.atan = function () { + this.throwIfDisposed(); + return opHandler.atan(this); + }; + Tensor.prototype.sinh = function () { + this.throwIfDisposed(); + return opHandler.sinh(this); + }; + Tensor.prototype.cosh = function () { + this.throwIfDisposed(); + return opHandler.cosh(this); + }; + Tensor.prototype.tanh = function () { + this.throwIfDisposed(); + return opHandler.tanh(this); + }; + Tensor.prototype.asinh = function () { + 
this.throwIfDisposed(); + return opHandler.asinh(this); + }; + Tensor.prototype.acosh = function () { + this.throwIfDisposed(); + return opHandler.acosh(this); + }; + Tensor.prototype.atanh = function () { + this.throwIfDisposed(); + return opHandler.atanh(this); + }; + Tensor.prototype.erf = function () { + this.throwIfDisposed(); + return opHandler.erf(this); + }; + Tensor.prototype.round = function () { + this.throwIfDisposed(); + return opHandler.round(this); + }; + Tensor.prototype.step = function (alpha) { + if (alpha === void 0) { alpha = 0.0; } + this.throwIfDisposed(); + return opHandler.step(this, alpha); + }; + Tensor.prototype.softmax = function (dim) { + if (dim === void 0) { dim = -1; } + this.throwIfDisposed(); + return opHandler.softmax(this, dim); + }; + Tensor.prototype.logSoftmax = function (axis) { + if (axis === void 0) { axis = -1; } + this.throwIfDisposed(); + return opHandler.logSoftmax(this, axis); + }; + // Image ops. + Tensor.prototype.resizeBilinear = function (newShape2D, alignCorners) { + if (alignCorners === void 0) { alignCorners = false; } + this.throwIfDisposed(); + return opHandler.image.resizeBilinear(this, newShape2D, alignCorners); + }; + Tensor.prototype.resizeNearestNeighbor = function (newShape2D, alignCorners) { + if (alignCorners === void 0) { alignCorners = false; } + this.throwIfDisposed(); + return opHandler.image.resizeNearestNeighbor(this, newShape2D, alignCorners); + }; + // Pooling. 
+ Tensor.prototype.variable = function (trainable, name, dtype) { + if (trainable === void 0) { trainable = true; } + this.throwIfDisposed(); + return trackerFn().makeVariable(this, trainable, name, dtype); + }; + Tensor.prototype.unsortedSegmentSum = function (segmentIds, numSegments) { + this.throwIfDisposed(); + return opHandler.unsortedSegmentSum(this, segmentIds, numSegments); + }; + Tensor.prototype.topk = function (k, sorted) { + if (k === void 0) { k = 1; } + if (sorted === void 0) { sorted = true; } + this.throwIfDisposed(); + return opHandler.topk(this, k, sorted); + }; + Tensor.prototype.stridedSlice = function (begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask) { + if (beginMask === void 0) { beginMask = 0; } + if (endMask === void 0) { endMask = 0; } + if (ellipsisMask === void 0) { ellipsisMask = 0; } + if (newAxisMask === void 0) { newAxisMask = 0; } + if (shrinkAxisMask === void 0) { shrinkAxisMask = 0; } + this.throwIfDisposed(); + return opHandler.stridedSlice(this, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask); + }; + Tensor.prototype.fft = function () { + this.throwIfDisposed(); + return opHandler.spectral.fft(this); + }; + Tensor.prototype.ifft = function () { + this.throwIfDisposed(); + return opHandler.spectral.ifft(this); + }; + Tensor.prototype.rfft = function () { + this.throwIfDisposed(); + return opHandler.spectral.rfft(this); + }; + Tensor.prototype.irfft = function () { + this.throwIfDisposed(); + return opHandler.spectral.irfft(this); + }; + return Tensor; + }()); + Object.defineProperty(Tensor, Symbol.hasInstance, { + value: function (instance) { + return !!instance && instance.dataId != null && instance.shape != null && + instance.dtype != null; + } + }); + /** + * A mutable `tf.Tensor`, useful for persisting state, e.g. for training. 
+ */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + var Variable = /** @class */ (function (_super) { + __extends(Variable, _super); + function Variable(initialValue, trainable, name, tensorId) { + var _this = _super.call(this, initialValue.shape, initialValue.dtype, initialValue.dataId, tensorId) || this; + _this.trainable = trainable; + _this.name = name; + return _this; + } + /** + * Assign a new `tf.Tensor` to this variable. The new `tf.Tensor` must have + * the same shape and dtype as the old `tf.Tensor`. + * + * @param newValue New tensor to be assigned to this variable. + */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Variable.prototype.assign = function (newValue) { + if (newValue.dtype !== this.dtype) { + throw new Error("dtype of the new value (" + newValue.dtype + ") and " + + ("previous value (" + this.dtype + ") must match")); + } + if (!arraysEqual(newValue.shape, this.shape)) { + throw new Error("shape of the new value (" + newValue.shape + ") and " + + ("previous value (" + this.shape + ") must match")); + } + trackerFn().disposeTensor(this); + this.dataId = newValue.dataId; + trackerFn().incRef(this, null /* backend */); + }; + Variable.prototype.dispose = function () { + trackerFn().disposeVariable(this); + this.isDisposedInternal = true; + }; + return Variable; + }(Tensor)); + Object.defineProperty(Variable, Symbol.hasInstance, { + value: function (instance) { + return instance instanceof Tensor && instance.assign != null && + instance.assign instanceof Function; + } + }); + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + (function (Rank) { + Rank["R0"] = "R0"; + Rank["R1"] = "R1"; + Rank["R2"] = "R2"; + Rank["R3"] = "R3"; + Rank["R4"] = "R4"; + Rank["R5"] = "R5"; + Rank["R6"] = "R6"; + })(exports.Rank || (exports.Rank = {})); + // Looks for upcasting types. Used, for example, in operations with mixed dtype + // inputs. + var UpcastInt32AndMap; + (function (UpcastInt32AndMap) { + UpcastInt32AndMap["float32"] = "float32"; + UpcastInt32AndMap["int32"] = "int32"; + UpcastInt32AndMap["bool"] = "int32"; + UpcastInt32AndMap["complex64"] = "complex64"; + })(UpcastInt32AndMap || (UpcastInt32AndMap = {})); + var UpcastBoolAndMap; + (function (UpcastBoolAndMap) { + UpcastBoolAndMap["float32"] = "float32"; + UpcastBoolAndMap["int32"] = "int32"; + UpcastBoolAndMap["bool"] = "bool"; + UpcastBoolAndMap["complex64"] = "complex64"; + })(UpcastBoolAndMap || (UpcastBoolAndMap = {})); + var UpcastFloat32AndMap; + (function (UpcastFloat32AndMap) { + UpcastFloat32AndMap["float32"] = "float32"; + UpcastFloat32AndMap["int32"] = "float32"; + UpcastFloat32AndMap["bool"] = "float32"; + UpcastFloat32AndMap["complex64"] = "complex64"; + })(UpcastFloat32AndMap || (UpcastFloat32AndMap = {})); + var UpcastComplex64AndMap; + (function (UpcastComplex64AndMap) { + UpcastComplex64AndMap["float32"] = "complex64"; + UpcastComplex64AndMap["int32"] = "complex64"; + UpcastComplex64AndMap["bool"] = "complex64"; + UpcastComplex64AndMap["complex64"] = "complex64"; + })(UpcastComplex64AndMap || (UpcastComplex64AndMap = 
{})); + var upcastTypeMap = { + 'float32': UpcastFloat32AndMap, + 'int32': UpcastInt32AndMap, + 'bool': UpcastBoolAndMap, + 'complex64': UpcastComplex64AndMap + }; + function upcastType(typeA, typeB) { + if (typeA === 'string' || typeB === 'string') { + if (typeA === 'string' && typeB === 'string') { + return 'string'; + } + throw new Error("Can not upcast " + typeA + " with " + typeB); + } + return upcastTypeMap[typeA][typeB]; + } + /** Returns the output type after summation. */ + function sumOutType(type) { + return upcastType(type, 'int32'); + } + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function makeTypesMatch(a, b) { + if (a.dtype === b.dtype) { + return [a, b]; + } + var dtype = upcastType(a.dtype, b.dtype); + return [a.cast(dtype), b.cast(dtype)]; + } + function assertTypesMatch(a, b) { + assert(a.dtype === b.dtype, function () { return "The dtypes of the first(" + a.dtype + ") and" + + (" second(" + b.dtype + ") input must match"); }); + } + function isTensorInList(tensor, tensorList) { + return tensorList.some(function (x) { return x.id === tensor.id; }); + } + /** + * Extracts any `Tensor`s found within the provided object. + * + * @param container an object that may be a `Tensor` or may directly contain + * `Tensor`s, such as a `Tensor[]` or `{key: Tensor, ...}`. 
In general it + * is safe to pass any object here, except that `Promise`s are not + * supported. + * @returns An array of `Tensors` found within the passed object. If the + * argument is simply a `Tensor', a list containing that `Tensor` is + * returned. If the object is not a `Tensor` or does not + * contain `Tensors`, an empty list is returned. + */ + function getTensorsInContainer(result) { + var list = []; + var seen = new Set(); + walkTensorContainer(result, list, seen); + return list; + } + function walkTensorContainer(container, list, seen) { + if (container == null) { + return; + } + if (container instanceof Tensor) { + list.push(container); + return; + } + if (!isIterable(container)) { + return; + } + // Iteration over keys works also for arrays. + var iterable = container; + for (var k in iterable) { + var val = iterable[k]; + if (!seen.has(val)) { + seen.add(val); + walkTensorContainer(val, list, seen); + } + } + } + // tslint:disable-next-line:no-any + function isIterable(obj) { + return Array.isArray(obj) || typeof obj === 'object'; + } + + var tensor_util = { + __proto__: null, + makeTypesMatch: makeTypesMatch, + assertTypesMatch: assertTypesMatch, + isTensorInList: isTensorInList, + getTensorsInContainer: getTensorsInContainer + }; + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + var EngineState = /** @class */ (function () { + function EngineState() { + // Public since optimizers will use it. + this.registeredVariables = {}; + this.nextTapeNodeId = 0; + this.numBytes = 0; + this.numTensors = 0; + this.numStringTensors = 0; + this.numDataBuffers = 0; + // Number of nested tf.grad() statements when computing higher-order + // gradients. E.g. `1` for first-order gradients and `2` for second-order + // gradients. Used to track if the tape should be removed after a backprop. + this.gradientDepth = 0; + // Number of nested kernel calls. When kernel depth is greater than 1, we turn + // off the tape. + this.kernelDepth = 0; + this.scopeStack = []; + /** + * Keeps track of the number of data moves during a kernel execution. We + * maintain a stack since kernels can call other kernels, recursively. + */ + this.numDataMovesStack = []; + this.nextScopeId = 0; + this.tensorInfo = new WeakMap(); + this.profiling = false; + this.activeProfile = { newBytes: 0, newTensors: 0, peakBytes: 0, kernels: [], result: null }; + } + EngineState.prototype.dispose = function () { + for (var variableName in this.registeredVariables) { + this.registeredVariables[variableName].dispose(); + } + }; + return EngineState; + }()); + var Engine = /** @class */ (function () { + function Engine(ENV) { + this.ENV = ENV; + this.registry = {}; + this.registryFactory = {}; + this.pendingBackendInitId = 0; + this.state = new EngineState(); + } + Engine.prototype.ready = function () { + return __awaiter(this, void 0, void 0, function () { + var sortedBackends, i, backendName, success; + return __generator(this, function (_a) { + switch (_a.label) { + case 0: + if (this.pendingBackendInit != null) { + return [2 /*return*/, this.pendingBackendInit.then(function () { })]; + } + if (this.backendInstance != null) { + return [2 /*return*/]; + } + sortedBackends = this.getSortedBackends(); + i = 0; + 
_a.label = 1; + case 1: + if (!(i < sortedBackends.length)) return [3 /*break*/, 5]; + backendName = sortedBackends[i]; + return [4 /*yield*/, this.initializeBackend(backendName).success]; + case 2: + success = _a.sent(); + if (!success) return [3 /*break*/, 4]; + return [4 /*yield*/, this.setBackend(backendName)]; + case 3: + _a.sent(); + return [2 /*return*/]; + case 4: + i++; + return [3 /*break*/, 1]; + case 5: throw new Error("Could not initialize any backends, all backend initializations " + + "failed."); + } + }); + }); + }; + Object.defineProperty(Engine.prototype, "backend", { + get: function () { + if (this.pendingBackendInit != null) { + throw new Error("Backend '" + this.backendName + "' has not yet been initialized. Make " + + "sure to await tf.ready() or await tf.setBackend() before calling " + + "other methods"); + } + if (this.backendInstance == null) { + var _a = this.initializeBackendsAndReturnBest(), name_1 = _a.name, asyncInit = _a.asyncInit; + if (asyncInit) { + throw new Error("The highest priority backend '" + name_1 + "' has not yet been " + + "initialized. Make sure to await tf.ready() or " + + "await tf.setBackend() before calling other methods"); + } + this.setBackend(name_1); + } + return this.backendInstance; + }, + enumerable: true, + configurable: true + }); + Engine.prototype.backendNames = function () { + return Object.keys(this.registryFactory); + }; + Engine.prototype.findBackend = function (backendName) { + if (!(backendName in this.registry)) { + // If the backend hasn't been initialized but we have a registry entry for + // it, initialize it and return it. + if (backendName in this.registryFactory) { + var asyncInit = this.initializeBackend(backendName).asyncInit; + if (asyncInit) { + // Backend is not ready yet. 
+ return null; + } + } + else { + return null; + } + } + return this.registry[backendName]; + }; + Engine.prototype.findBackendFactory = function (backendName) { + if (!(backendName in this.registryFactory)) { + return null; + } + return this.registryFactory[backendName].factory; + }; + Engine.prototype.registerBackend = function (backendName, factory, priority) { + if (priority === void 0) { priority = 1; } + if (backendName in this.registryFactory) { + console.warn(backendName + " backend was already registered. " + + "Reusing existing backend factory."); + return false; + } + this.registryFactory[backendName] = { factory: factory, priority: priority }; + return true; + }; + Engine.prototype.setBackend = function (backendName) { + return __awaiter(this, void 0, void 0, function () { + var _a, success, asyncInit, result, _b; + return __generator(this, function (_c) { + switch (_c.label) { + case 0: + if (this.registryFactory[backendName] == null) { + throw new Error("Backend name '" + backendName + "' not found in registry"); + } + this.backendName = backendName; + if (!(this.registry[backendName] == null)) return [3 /*break*/, 4]; + this.backendInstance = null; + _a = this.initializeBackend(backendName), success = _a.success, asyncInit = _a.asyncInit; + if (!asyncInit) return [3 /*break*/, 2]; + return [4 /*yield*/, success]; + case 1: + _b = _c.sent(); + return [3 /*break*/, 3]; + case 2: + _b = success; + _c.label = 3; + case 3: + result = _b; + if (!result) { + return [2 /*return*/, false]; + } + _c.label = 4; + case 4: + this.backendInstance = this.registry[backendName]; + this.setupRegisteredKernels(); + // Reset the profiler. 
+ this.profiler = new Profiler(this.backendInstance); + return [2 /*return*/, true]; + } + }); + }); + }; + Engine.prototype.setupRegisteredKernels = function () { + var _this = this; + var kernels = getKernelsForBackend(this.backendName); + kernels.forEach(function (kernel) { + if (kernel.setupFunc != null) { + kernel.setupFunc(_this.backendInstance); + } + }); + }; + Engine.prototype.disposeRegisteredKernels = function (backendName) { + var _this = this; + var kernels = getKernelsForBackend(backendName); + kernels.forEach(function (kernel) { + if (kernel.disposeFunc != null) { + kernel.disposeFunc(_this.registry[backendName]); + } + }); + }; + /** + * Initializes a backend by looking up the backend name in the factory + * registry and calling the factory method. Returns a boolean representing + * whether the initialization of the backend suceeded. Throws an error if + * there is no backend in the factory registry. + */ + Engine.prototype.initializeBackend = function (backendName) { + var _this = this; + var registryFactoryEntry = this.registryFactory[backendName]; + if (registryFactoryEntry == null) { + throw new Error("Cannot initialize backend " + backendName + ", no registration found."); + } + try { + var backend = registryFactoryEntry.factory(); + // Test if the factory returns a promise. + if (Promise.resolve(backend) === backend) { + var promiseId_1 = ++this.pendingBackendInitId; + var success = backend + .then(function (backendInstance) { + // Outdated promise. Another backend was set in the meantime. + if (promiseId_1 < _this.pendingBackendInitId) { + return false; + } + _this.registry[backendName] = backendInstance; + _this.pendingBackendInit = null; + return true; + }) + .catch(function (err) { + // Outdated promise. Another backend was set in the meantime. 
+ if (promiseId_1 < _this.pendingBackendInitId) { + return false; + } + _this.pendingBackendInit = null; + console.warn("Initialization of backend " + backendName + " failed"); + console.warn(err.stack || err.message); + return false; + }); + this.pendingBackendInit = success; + return { success: success, asyncInit: true }; + } + else { + this.registry[backendName] = backend; + return { success: true, asyncInit: false }; + } + } + catch (err) { + console.warn("Initialization of backend " + backendName + " failed"); + console.warn(err.stack || err.message); + return { success: false, asyncInit: false }; + } + }; + Engine.prototype.removeBackend = function (backendName) { + if (!(backendName in this.registryFactory)) { + throw new Error(backendName + " backend not found in registry"); + } + if (this.backendName === backendName && this.pendingBackendInit != null) { + // There is a pending promise of the backend we want to remove. Make it + // obsolete. + this.pendingBackendInitId++; + } + if (backendName in this.registry) { + this.disposeRegisteredKernels(backendName); + this.registry[backendName].dispose(); + delete this.registry[backendName]; + } + delete this.registryFactory[backendName]; + // Unset the backend if it is active. + if (this.backendName === backendName) { + this.pendingBackendInit = null; + this.backendName = null; + this.backendInstance = null; + } + }; + Engine.prototype.getSortedBackends = function () { + var _this = this; + if (Object.keys(this.registryFactory).length === 0) { + throw new Error('No backend found in registry.'); + } + return Object.keys(this.registryFactory).sort(function (a, b) { + // Highest priority comes first. 
+ return _this.registryFactory[b].priority - + _this.registryFactory[a].priority; + }); + }; + Engine.prototype.initializeBackendsAndReturnBest = function () { + var sortedBackends = this.getSortedBackends(); + for (var i = 0; i < sortedBackends.length; i++) { + var backendName = sortedBackends[i]; + var _a = this.initializeBackend(backendName), success = _a.success, asyncInit = _a.asyncInit; + if (asyncInit || success) { + return { name: backendName, asyncInit: asyncInit }; + } + } + throw new Error("Could not initialize any backends, all backend initializations " + + "failed."); + }; + Engine.prototype.moveData = function (backend, dataId) { + var info = this.state.tensorInfo.get(dataId); + var srcBackend = info.backend; + var values = this.readSync(dataId); + // Delete the tensor from the old backend and move it to the new + // backend. + srcBackend.disposeData(dataId); + info.backend = backend; + backend.move(dataId, values, info.shape, info.dtype); + if (this.shouldCheckForMemLeaks()) { + // Track the number of moves during a kernel execution to correctly + // detect memory leaks. + this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1]++; + } + }; + Engine.prototype.tidy = function (nameOrFn, fn) { + var _this = this; + var name = null; + if (fn == null) { + // Called with only 1 argument. + if (typeof nameOrFn !== 'function') { + throw new Error('Please provide a function to tidy()'); + } + fn = nameOrFn; + } + else { + // Called with 2 arguments. + if (typeof nameOrFn !== 'string' && !(nameOrFn instanceof String)) { + throw new Error('When calling with two arguments, the first argument ' + + 'to tidy() must be a string'); + } + if (typeof fn !== 'function') { + throw new Error('When calling with two arguments, the 2nd argument ' + + 'to tidy() must be a function'); + } + name = nameOrFn; + // TODO(nsthorat,smilkov): Do operation logging and performance + // profiling. 
+ } + var result; + return this.scopedRun(function () { return _this.startScope(name); }, function () { return _this.endScope(result); }, function () { + result = fn(); + if (result instanceof Promise) { + console.error('Cannot return a Promise inside of tidy.'); + } + return result; + }); + }; + Engine.prototype.scopedRun = function (start, end, f) { + start(); + try { + var res = f(); + end(); + return res; + } + catch (ex) { + end(); + throw ex; + } + }; + Engine.prototype.nextTensorId = function () { + return Engine.nextTensorId++; + }; + Engine.prototype.nextVariableId = function () { + return Engine.nextVariableId++; + }; + /** + * This method is called instead of the public-facing tensor.clone() when + * saving a tensor for backwards pass. It makes sure to add the clone + * operation to the tape regardless of being called inside a kernel + * execution. + * + * This method will go away once all kernels are modularized since we won't + * need to turn off the tape inside runKernel(). + */ + Engine.prototype.clone = function (x) { + var y = this.makeTensorFromDataId(x.dataId, x.shape, x.dtype); + var inputs = { x: x }; + var grad = function (dy) { return ({ x: function () { return dy.toFloat(); } }); }; + var saved = []; + this.addTapeNode(this.state.activeScope.name, inputs, [y], grad, saved, {}); + return y; + }; + /** + * Execute a kernel with the given name and return the output tensor. + * + * @param kernelName The name of the kernel to execute. + * @param inputs A map of input names to tensors. + * @param attrs A map of attribute names to their values. An attribute is a + * primitive (non-tensor) input to the kernel. + * @param inputsToSave A list of tensors, inputs to save for the backprop + * computation. + * @param outputsToSave A list of booleans, specifying which output to save + * for the backprop computation. These are booleans since the output + * tensors are not visible to the user. 
+ */ + Engine.prototype.runKernel = function (kernelName, inputs, attrs, inputsToSave, outputsToSave) { + var forwardFunc = null; + var backwardsFunc = null; + // Call runKernel as a stop-gap until we modularize all kernels. + // Once we modularize all kernels, we will remove the existing + // `runKernelFunc`. + return this.runKernelFunc(forwardFunc, inputs, backwardsFunc, kernelName, attrs, inputsToSave, outputsToSave); + }; + Engine.prototype.shouldCheckForMemLeaks = function () { + return this.ENV.getBool('IS_TEST'); + }; + Engine.prototype.checkKernelForMemLeak = function (kernelName, numDataIdsBefore, outInfos) { + var numDataIdsAfter = this.backend.numDataIds(); + // Count the number of data ids associated with the result of the kernel. + var numOutputDataIds = 0; + outInfos.forEach(function (info) { + // Complex numbers allocate 3 data ids, one for 'real', one for + // 'imaginary', and one for the container that holds the former two. + numOutputDataIds += (info.dtype === 'complex64' ? 3 : 1); + }); + // Account for the number of moves during kernel execution. A "data move" + // can happen in the middle of a kernel execution, placing a new (key,value) + // pair in the data storage. Since data moves have net zero effect (we + // always remove the data from the old backend), we have to cancel them out + // when detecting memory leaks. + var numMoves = this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1]; + var dataIdsLeaked = numDataIdsAfter - numDataIdsBefore - numOutputDataIds - numMoves; + if (dataIdsLeaked > 0) { + throw new Error("Backend '" + this.backendName + "' has an internal memory leak " + + ("(" + dataIdsLeaked + " data ids) after running '" + kernelName + "'")); + } + }; + /** + * @deprecated Use `runKernel` for newly added kernels. Keep using this method + * only for kernels that are not yet fully modularized. 
+ */ + Engine.prototype.runKernelFunc = function (forwardFunc, inputs, backwardsFunc, kernelName, attrs, inputsToSave, outputsToSave) { + var _this = this; + var outputs; + var saved = []; + var isTapeOn = this.isTapeOn(); + if (kernelName == null) { + kernelName = + this.state.activeScope != null ? this.state.activeScope.name : ''; + } + var startingBytecount = this.state.numBytes; + var startingNumTensors = this.state.numTensors; + if (this.shouldCheckForMemLeaks()) { + this.state.numDataMovesStack.push(0); + } + var kernelFunc; + var kernel = getKernel(kernelName, this.backendName); + var out; + if (kernel != null) { + kernelFunc = function () { + var numDataIdsBefore = _this.backend.numDataIds(); + out = kernel.kernelFunc({ inputs: inputs, attrs: attrs, backend: _this.backend }); + var outInfos = Array.isArray(out) ? out : [out]; + if (_this.shouldCheckForMemLeaks()) { + _this.checkKernelForMemLeak(kernelName, numDataIdsBefore, outInfos); + } + var outTensors = outInfos.map(function (_a) { + var dataId = _a.dataId, shape = _a.shape, dtype = _a.dtype; + return _this.makeTensorFromDataId(dataId, shape, dtype); + }); + // Save the inputs and outputs. + // Do not save unless we are recording to the tape. Otherwise it would + // cause a mem leak since we would never run backprop, which disposes + // the kept tensors. + if (isTapeOn) { + var tensorsToSave = _this.getTensorsForGradient(kernelName, inputs, outTensors); + if (tensorsToSave == null) { + // Fallback for ops that call runKernelFunc and pass in + // inputsToSave and outputsToSave. Currently this is the set of ops + // with kernel support in the WASM backend. Once those ops and + // respective gradients are modularised we can remove this path. 
+ if (outputsToSave == null) { + outputsToSave = []; + } + var outsToSave = outTensors.filter(function (_, i) { return outputsToSave[i]; }); + tensorsToSave = (inputsToSave || []).slice().concat(outsToSave); + } + saved = _this.saveTensorsForBackwardMode(tensorsToSave); + } + return outTensors; + }; + } + else { + var saveFunc_1 = function (tensors) { + // Do not save unless we are recording to the tape. Otherwise it would + // cause a mem leak since we would never run backprop, which disposes + // the kept tensors. + if (!isTapeOn) { + return; + } + saved = tensors.map(function (tensor) { return _this.keep(_this.clone(tensor)); }); + }; + kernelFunc = function () { + var numDataIdsBefore = _this.backend.numDataIds(); + out = _this.tidy(function () { return forwardFunc(_this.backend, saveFunc_1); }); + var outs = (Array.isArray(out) ? out : [out]); + if (_this.shouldCheckForMemLeaks()) { + _this.checkKernelForMemLeak(kernelName, numDataIdsBefore, outs); + } + return outs; + }; + } + // Stop recording to a tape when running a kernel. + this.scopedRun(function () { return _this.state.kernelDepth++; }, function () { return _this.state.kernelDepth--; }, function () { + if (!_this.ENV.getBool('DEBUG')) { + outputs = kernelFunc(); + } + else { + outputs = _this.profiler.profileKernel(kernelName, inputs, function () { return kernelFunc(); }); + } + }); + if (isTapeOn) { + this.addTapeNode(kernelName, inputs, outputs, backwardsFunc, saved, attrs); + } + if (this.state.profiling) { + this.state.activeProfile.kernels.push({ + name: kernelName, + bytesAdded: this.state.numBytes - startingBytecount, + totalBytesSnapshot: this.state.numBytes, + tensorsAdded: this.state.numTensors - startingNumTensors, + totalTensorsSnapshot: this.state.numTensors, + inputShapes: Object.keys(inputs).map(function (key) { return inputs[key].shape; }), + outputShapes: outputs.map(function (item) { return item.shape; }) + }); + } + return (Array.isArray(out) ? 
outputs : outputs[0]); + }; + /** + * Saves tensors used in forward mode for use in backward mode. + * + * @param tensors the list of tensors to save. + */ + Engine.prototype.saveTensorsForBackwardMode = function (tensors) { + var _this = this; + var saved = tensors.map(function (tensor) { return _this.keep(_this.clone(tensor)); }); + return saved; + }; + /** + * Returns a list of tensors to save for a given gradient calculation. + * + * Returns undefined if their is no registered gradient for this kernel in the + * gradient registry. + * + * @param kernelName name of kernel to look up gradient for. + * @param inputs a map of input tensors. + * @param outputs an array of output tensors from forward mode of kernel. + */ + Engine.prototype.getTensorsForGradient = function (kernelName, inputs, outputs) { + var gradConfig = getGradient(kernelName); + if (gradConfig != null) { + var inputsToSave = gradConfig.inputsToSave || []; + var outputsToSave_1 = gradConfig.outputsToSave || []; + // If saveAllInputs is true, all inputs will be saved. Otherwise, inputs + // specified in inputsToSave will be saved. + var inputTensorsToSave = void 0; + if (gradConfig.saveAllInputs) { + assert(Array.isArray(inputs), function () { return 'saveAllInputs is true, expected inputs to be an array.'; }); + inputTensorsToSave = Object.keys(inputs).map(function (key) { return inputs[key]; }); + } + else { + inputTensorsToSave = inputsToSave.map(function (inputName) { return inputs[inputName]; }); + } + var outputTensorsToSave = outputs.filter(function (_, i) { return outputsToSave_1[i]; }); + return inputTensorsToSave.concat(outputTensorsToSave); + } + // TODO(yassogba) throw exception here once all runkernelFunc calls with + // inputsToSave/outputsToSave are removed + return null; + }; + /** + * Internal method used by public APIs for tensor creation. Makes a new + * tensor with the provided shape, dtype and values. 
It always + * creates a new data id and writes the values to the underlying backend. + */ + Engine.prototype.makeTensor = function (values, shape, dtype, backend) { + if (values == null) { + throw new Error('Values passed to engine.makeTensor() are null'); + } + dtype = dtype || 'float32'; + backend = backend || this.backend; + var backendVals = values; + if (dtype === 'string' && isString(values[0])) { + backendVals = values.map(function (d) { return encodeString(d); }); + } + var dataId = backend.write(backendVals, shape, dtype); + var t = new Tensor(shape, dtype, dataId, this.nextTensorId()); + this.incRef(t, backend); + // Count bytes for string tensors. + if (dtype === 'string') { + var info = this.state.tensorInfo.get(dataId); + var newBytes = bytesFromStringArray(backendVals); + this.state.numBytes += newBytes - info.bytes; + info.bytes = newBytes; + } + return t; + }; + /** + * Internal method used by backends. Makes a new tensor + * that is a wrapper around an existing data id. It doesn't create + * a new data id, only increments the ref count used in memory tracking. 
+ */ + Engine.prototype.makeTensorFromDataId = function (dataId, shape, dtype, backend) { + dtype = dtype || 'float32'; + var t = new Tensor(shape, dtype, dataId, this.nextTensorId()); + this.incRef(t, backend); + return t; + }; + Engine.prototype.makeVariable = function (initialValue, trainable, name, dtype) { + if (trainable === void 0) { trainable = true; } + name = name || this.nextVariableId().toString(); + if (dtype != null && dtype !== initialValue.dtype) { + initialValue = initialValue.asType(dtype); + } + var v = new Variable(initialValue, trainable, name, this.nextTensorId()); + if (this.state.registeredVariables[v.name] != null) { + throw new Error("Variable with name " + v.name + " was already registered"); + } + this.state.registeredVariables[v.name] = v; + this.incRef(v, this.backend); + return v; + }; + Engine.prototype.incRef = function (a, backend) { + var refCount = this.state.tensorInfo.has(a.dataId) ? + this.state.tensorInfo.get(a.dataId).refCount : + 0; + this.state.numTensors++; + if (a.dtype === 'string') { + this.state.numStringTensors++; + } + if (refCount === 0) { + this.state.numDataBuffers++; + // Bytes for complex numbers are counted by their components. Bytes for + // string tensors are counted when writing values. 
+ var bytes = 0; + if (a.dtype !== 'complex64' && a.dtype !== 'string') { + bytes = a.size * bytesPerElement(a.dtype); + } + this.state.tensorInfo.set(a.dataId, { + backend: backend || this.backend, + dtype: a.dtype, + shape: a.shape, + bytes: bytes, + refCount: 0 + }); + this.state.numBytes += bytes; + } + this.state.tensorInfo.get(a.dataId).refCount++; + if (!(a instanceof Variable)) { + this.track(a); + } + }; + Engine.prototype.disposeTensor = function (a) { + if (!this.state.tensorInfo.has(a.dataId)) { + return; + } + this.state.numTensors--; + if (a.dtype === 'string') { + this.state.numStringTensors--; + } + var info = this.state.tensorInfo.get(a.dataId); + var refCount = info.refCount; + if (refCount <= 1) { + // Don't count bytes for complex numbers as they are counted by their + // components. + if (a.dtype !== 'complex64') { + this.state.numBytes -= info.bytes; + } + this.state.numDataBuffers--; + info.backend.disposeData(a.dataId); + this.state.tensorInfo.delete(a.dataId); + } + else { + this.state.tensorInfo.get(a.dataId).refCount--; + } + // TODO(nsthorat): Construct an error and save the stack trace for + // debugging when in debug mode. Creating a stack trace is too expensive + // to do unconditionally. 
+ }; + Engine.prototype.disposeVariables = function () { + for (var varName in this.state.registeredVariables) { + var v = this.state.registeredVariables[varName]; + this.disposeVariable(v); + } + }; + Engine.prototype.disposeVariable = function (v) { + this.disposeTensor(v); + if (this.state.registeredVariables[v.name] != null) { + delete this.state.registeredVariables[v.name]; + } + }; + Engine.prototype.memory = function () { + var info = this.backend.memory(); + info.numTensors = this.state.numTensors; + info.numDataBuffers = this.state.numDataBuffers; + info.numBytes = this.state.numBytes; + if (this.state.numStringTensors > 0) { + info.unreliable = true; + if (info.reasons == null) { + info.reasons = []; + } + info.reasons.push('Memory usage by string tensors is approximate ' + + '(2 bytes per character)'); + } + return info; + }; + Engine.prototype.profile = function (query) { + return __awaiter(this, void 0, void 0, function () { + var startBytes, startNumTensors; + return __generator(this, function (_a) { + this.state.profiling = true; + startBytes = this.state.numBytes; + startNumTensors = this.state.numTensors; + this.state.activeProfile.kernels = []; + this.state.activeProfile.result = query(); + this.state.profiling = false; + this.state.activeProfile.peakBytes = Math.max.apply(Math, this.state.activeProfile.kernels.map(function (d) { return d.totalBytesSnapshot; })); + this.state.activeProfile.newBytes = this.state.numBytes - startBytes; + this.state.activeProfile.newTensors = + this.state.numTensors - startNumTensors; + return [2 /*return*/, this.state.activeProfile]; + }); + }); + }; + Engine.prototype.isTapeOn = function () { + return this.state.gradientDepth > 0 && this.state.kernelDepth === 0; + }; + Engine.prototype.addTapeNode = function (kernelName, inputs, outputs, gradientsFunc, saved, attrs) { + var _this = this; + var tapeNode = { id: this.state.nextTapeNodeId++, kernelName: kernelName, inputs: inputs, outputs: outputs, saved: saved }; + 
var gradConfig = getGradient(kernelName); + if (gradConfig != null) { + gradientsFunc = gradConfig.gradFunc; + } + if (gradientsFunc != null) { + tapeNode.gradient = function (dys) { + // TODO(smilkov): To optimize back-prop, pass dys that are not used in + // the backprop graph to the user as null instead of zeros + dys = dys.map(function (dy, i) { + if (dy == null) { + var output = outputs[i]; + var vals = makeZerosTypedArray(output.size, output.dtype); + return _this.makeTensor(vals, output.shape, output.dtype); + } + return dy; + }); + // Grad functions of ops with single outputs expect a dy, while ops + // with multiple outputs expect dys (array of dy). + return gradientsFunc(dys.length > 1 ? dys : dys[0], saved, attrs); + }; + } + this.state.activeTape.push(tapeNode); + }; + Engine.prototype.keep = function (result) { + result.kept = true; + return result; + }; + Engine.prototype.startTape = function () { + if (this.state.gradientDepth === 0) { + this.state.activeTape = []; + } + this.state.gradientDepth++; + }; + Engine.prototype.endTape = function () { + this.state.gradientDepth--; + }; + /** + * Start a scope. Use this with endScope() to achieve the same functionality + * as scope() without the need for a function closure. + */ + Engine.prototype.startScope = function (name) { + var scopeInfo = { + track: [], + name: 'unnamed scope', + id: this.state.nextScopeId++ + }; + if (name) { + scopeInfo.name = name; + } + this.state.scopeStack.push(scopeInfo); + this.state.activeScope = scopeInfo; + }; + /** + * End a scope. Use this with startScope() to achieve the same functionality + * as scope() without the need for a function closure. + */ + Engine.prototype.endScope = function (result) { + var _this = this; + var tensorsToTrackInParent = getTensorsInContainer(result); + var tensorsToTrackInParentSet = new Set(tensorsToTrackInParent.map(function (t) { return t.id; })); + // Dispose the arrays tracked in this scope. 
+ for (var i = 0; i < this.state.activeScope.track.length; i++) { + var tensor = this.state.activeScope.track[i]; + if (!tensor.kept && !tensorsToTrackInParentSet.has(tensor.id)) { + tensor.dispose(); + } + } + var oldScope = this.state.scopeStack.pop(); + this.state.activeScope = this.state.scopeStack.length === 0 ? + null : + this.state.scopeStack[this.state.scopeStack.length - 1]; + // Track the current result in the parent scope. + tensorsToTrackInParent.forEach(function (tensor) { + // Only track the tensor if was allocated in the inner scope and is not + // globally kept. + if (!tensor.kept && tensor.scopeId === oldScope.id) { + _this.track(tensor); + } + }); + }; + /** + * Returns gradients of `f` with respect to each of the `xs`. The gradients + * returned are of the same length as `xs`, but some might be null if `f` + * was not a function of that `x`. It also takes optional dy to multiply the + * gradient, which defaults to `1`. + */ + Engine.prototype.gradients = function (f, xs, dy, allowNoGradients) { + var _this = this; + if (allowNoGradients === void 0) { allowNoGradients = false; } + assert(xs.length > 0, function () { return 'gradients() received an empty list of xs.'; }); + if (dy != null && dy.dtype !== 'float32') { + throw new Error("dy must have 'float32' dtype, but has '" + dy.dtype + "'"); + } + var y = this.scopedRun(function () { return _this.startTape(); }, function () { return _this.endTape(); }, function () { return _this.tidy('forward', f); }); + assert(y instanceof Tensor, function () { return 'The result y returned by f() must be a tensor.'; }); + // Filter out the nodes that don't connect x => y. + var filteredTape = getFilteredNodesXToY(this.state.activeTape, xs, y); + if (!allowNoGradients && filteredTape.length === 0 && xs.length > 0) { + throw new Error('Cannot compute gradient of y=f(x) with respect to x. 
Make sure ' + + 'that the f you passed encloses all operations that lead from x ' + + 'to y.'); + } + return this.tidy('backward', function () { + var accumulatedGradientMap = {}; + accumulatedGradientMap[y.id] = (dy == null) ? ones(y.shape) : dy; + // Backprop gradients through the filtered nodes. + backpropagateGradients(accumulatedGradientMap, filteredTape, + // Pass the tidy function to avoid circular dep with `tape.ts`. + function (f) { return _this.tidy(f); }); + var grads = xs.map(function (x) { return accumulatedGradientMap[x.id]; }); + if (_this.state.gradientDepth === 0) { + // This means that we are not computing higher-order gradients + // and can clean up the tape. + _this.state.activeTape.forEach(function (node) { + for (var _i = 0, _a = node.saved; _i < _a.length; _i++) { + var tensor = _a[_i]; + tensor.dispose(); + } + }); + _this.state.activeTape = null; + } + return { value: y, grads: grads }; + }); + }; + Engine.prototype.customGrad = function (f) { + var _this = this; + assert(isFunction(f), function () { return 'The f passed in customGrad(f) must be a function.'; }); + return function () { + var inputs = []; + for (var _i = 0; _i < arguments.length; _i++) { + inputs[_i] = arguments[_i]; + } + assert(inputs.every(function (t) { return t instanceof Tensor; }), function () { return 'The args passed in customGrad(f)(x1, x2,...) 
must all be ' + + 'tensors'; }); + var res; + var inputMap = {}; + inputs.forEach(function (input, i) { + inputMap[i] = input; + }); + return _this.runKernelFunc(function (_, save) { + res = f.apply(void 0, inputs.concat([save])); + assert(res.value instanceof Tensor, function () { return 'The function f passed in customGrad(f) must return an ' + + 'object where `obj.value` is a tensor'; }); + assert(isFunction(res.gradFunc), function () { return 'The function f passed in customGrad(f) must return an ' + + 'object where `obj.gradFunc` is a function.'; }); + return res.value; + }, inputMap, function (dy, saved) { + var gradRes = res.gradFunc(dy, saved); + var grads = Array.isArray(gradRes) ? gradRes : [gradRes]; + assert(grads.length === inputs.length, function () { return 'The function f passed in customGrad(f) must return an ' + + 'object where `obj.gradFunc` is a function that returns ' + + 'the same number of tensors as inputs passed to f(...).'; }); + assert(grads.every(function (t) { return t instanceof Tensor; }), function () { return 'The function f passed in customGrad(f) must return an ' + + 'object where `obj.gradFunc` is a function that returns ' + + 'a list of only tensors.'; }); + var gradMap = {}; + grads.forEach(function (grad, i) { + gradMap[i] = function () { return grad; }; + }); + return gradMap; + }); + }; + }; + Engine.prototype.readSync = function (dataId) { + // Route the read to the correct backend. + var info = this.state.tensorInfo.get(dataId); + return info.backend.readSync(dataId); + }; + Engine.prototype.read = function (dataId) { + // Route the read to the correct backend. 
+ var info = this.state.tensorInfo.get(dataId); + return info.backend.read(dataId); + }; + Engine.prototype.time = function (query) { + return __awaiter(this, void 0, void 0, function () { + var start, timingInfo; + return __generator(this, function (_a) { + switch (_a.label) { + case 0: + start = now(); + return [4 /*yield*/, this.backend.time(query)]; + case 1: + timingInfo = _a.sent(); + timingInfo.wallMs = now() - start; + return [2 /*return*/, timingInfo]; + } + }); + }); + }; + /** + * Tracks a Tensor in the current scope to be automatically cleaned up + * when the current scope ends, and returns the value. + * + * @param result The Tensor to track in the current scope. + */ + Engine.prototype.track = function (result) { + if (this.state.activeScope != null) { + result.scopeId = this.state.activeScope.id; + this.state.activeScope.track.push(result); + } + return result; + }; + Object.defineProperty(Engine.prototype, "registeredVariables", { + get: function () { + return this.state.registeredVariables; + }, + enumerable: true, + configurable: true + }); + /** + * Resets the engine state. Removes all backends but does not remove + * registered backend factories. + */ + Engine.prototype.reset = function () { + // Make any pending promise obsolete. 
+ this.pendingBackendInitId++; + this.state.dispose(); + this.ENV.reset(); + this.state = new EngineState(); + for (var backendName in this.registry) { + this.disposeRegisteredKernels(backendName); + this.registry[backendName].dispose(); + delete this.registry[backendName]; + } + this.backendName = null; + this.backendInstance = null; + this.pendingBackendInit = null; + }; + Engine.nextTensorId = 0; + Engine.nextVariableId = 0; + return Engine; + }()); + function ones(shape) { + var values = makeOnesTypedArray(sizeFromShape(shape), 'float32'); + return ENGINE.makeTensor(values, shape, 'float32'); + } + function getOrMakeEngine() { + var ns = getGlobalNamespace(); + if (ns._tfengine == null) { + var environment = new Environment(ns); + ns._tfengine = new Engine(environment); + } + setEnvironmentGlobal(ns._tfengine.ENV); + // Tell the current tensor interface that the global engine is responsible + // for tracking. + setTensorTracker(function () { return ns._tfengine; }); + return ns._tfengine; + } + var ENGINE = getOrMakeEngine(); + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + // tslint:disable-next-line:no-any + function _isNavigatorDefined() { + return typeof navigator !== 'undefined' && navigator != null; + } + function isMobile() { + if (_isNavigatorDefined()) { + // tslint:disable-next-line:no-any + var a = navigator.userAgent || navigator.vendor || window.opera; + // tslint:disable-next-line:max-line-length + return /(android|bb\d+|meego).+mobile|avantgo|bada\/|blackberry|blazer|compal|elaine|fennec|hiptop|iemobile|ip(hone|od)|iris|kindle|lge |maemo|midp|mmp|mobile.+firefox|netfront|opera m(ob|in)i|palm( os)?|phone|p(ixi|re)\/|plucker|pocket|psp|series(4|6)0|symbian|treo|up\.(browser|link)|vodafone|wap|windows ce|xda|xiino/i + .test(a) || + // tslint:disable-next-line:max-line-length + /1207|6310|6590|3gso|4thp|50[1-6]i|770s|802s|a wa|abac|ac(er|oo|s\-)|ai(ko|rn)|al(av|ca|co)|amoi|an(ex|ny|yw)|aptu|ar(ch|go)|as(te|us)|attw|au(di|\-m|r |s )|avan|be(ck|ll|nq)|bi(lb|rd)|bl(ac|az)|br(e|v)w|bumb|bw\-(n|u)|c55\/|capi|ccwa|cdm\-|cell|chtm|cldc|cmd\-|co(mp|nd)|craw|da(it|ll|ng)|dbte|dc\-s|devi|dica|dmob|do(c|p)o|ds(12|\-d)|el(49|ai)|em(l2|ul)|er(ic|k0)|esl8|ez([4-7]0|os|wa|ze)|fetc|fly(\-|_)|g1 u|g560|gene|gf\-5|g\-mo|go(\.w|od)|gr(ad|un)|haie|hcit|hd\-(m|p|t)|hei\-|hi(pt|ta)|hp( i|ip)|hs\-c|ht(c(\-| |_|a|g|p|s|t)|tp)|hu(aw|tc)|i\-(20|go|ma)|i230|iac( |\-|\/)|ibro|idea|ig01|ikom|im1k|inno|ipaq|iris|ja(t|v)a|jbro|jemu|jigs|kddi|keji|kgt( |\/)|klon|kpt |kwc\-|kyo(c|k)|le(no|xi)|lg( g|\/(k|l|u)|50|54|\-[a-w])|libw|lynx|m1\-w|m3ga|m50\/|ma(te|ui|xo)|mc(01|21|ca)|m\-cr|me(rc|ri)|mi(o8|oa|ts)|mmef|mo(01|02|bi|de|do|t(\-| |o|v)|zz)|mt(50|p1|v 
)|mwbp|mywa|n10[0-2]|n20[2-3]|n30(0|2)|n50(0|2|5)|n7(0(0|1)|10)|ne((c|m)\-|on|tf|wf|wg|wt)|nok(6|i)|nzph|o2im|op(ti|wv)|oran|owg1|p800|pan(a|d|t)|pdxg|pg(13|\-([1-8]|c))|phil|pire|pl(ay|uc)|pn\-2|po(ck|rt|se)|prox|psio|pt\-g|qa\-a|qc(07|12|21|32|60|\-[2-7]|i\-)|qtek|r380|r600|raks|rim9|ro(ve|zo)|s55\/|sa(ge|ma|mm|ms|ny|va)|sc(01|h\-|oo|p\-)|sdk\/|se(c(\-|0|1)|47|mc|nd|ri)|sgh\-|shar|sie(\-|m)|sk\-0|sl(45|id)|sm(al|ar|b3|it|t5)|so(ft|ny)|sp(01|h\-|v\-|v )|sy(01|mb)|t2(18|50)|t6(00|10|18)|ta(gt|lk)|tcl\-|tdg\-|tel(i|m)|tim\-|t\-mo|to(pl|sh)|ts(70|m\-|m3|m5)|tx\-9|up(\.b|g1|si)|utst|v400|v750|veri|vi(rg|te)|vk(40|5[0-3]|\-v)|vm40|voda|vulc|vx(52|53|60|61|70|80|81|83|85|98)|w3c(\-| )|webc|whit|wi(g |nc|nw)|wmlb|wonu|x700|yas\-|your|zeto|zte\-/i + .test(a.substr(0, 4)); + } + return false; + } + function isBrowser() { + return (typeof window !== 'undefined' && window.document != null) || + //@ts-ignore + (typeof WorkerGlobalScope !== 'undefined'); + } + + var device_util = { + __proto__: null, + isMobile: isMobile, + isBrowser: isBrowser + }; + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var ENV = env(); + /** + * This file contains environment-related flag registrations. + */ + /** Whether to enable debug mode. 
*/ + ENV.registerFlag('DEBUG', function () { return false; }, function (debugValue) { + if (debugValue) { + console.warn('Debugging mode is ON. The output of every math call will ' + + 'be downloaded to CPU and checked for NaNs. ' + + 'This significantly impacts performance.'); + } + }); + /** Whether we are in a browser (as versus, say, node.js) environment. */ + ENV.registerFlag('IS_BROWSER', function () { return isBrowser(); }); + /** Whether we are in a browser (as versus, say, node.js) environment. */ + ENV.registerFlag('IS_NODE', function () { return (typeof process !== 'undefined') && + (typeof process.versions !== 'undefined') && + (typeof process.versions.node !== 'undefined'); }); + /** Whether this browser is Chrome. */ + ENV.registerFlag('IS_CHROME', function () { return typeof navigator !== 'undefined' && navigator != null && + navigator.userAgent != null && /Chrome/.test(navigator.userAgent) && + /Google Inc/.test(navigator.vendor); }); + /** + * True when the environment is "production" where we disable safety checks + * to gain performance. + */ + ENV.registerFlag('PROD', function () { return false; }); + /** + * Whether to do sanity checks when inferring a shape from user-provided + * values, used when creating a new tensor. + */ + ENV.registerFlag('TENSORLIKE_CHECK_SHAPE_CONSISTENCY', function () { return ENV.getBool('DEBUG'); }); + /** Whether deprecation warnings are enabled. */ + ENV.registerFlag('DEPRECATION_WARNINGS_ENABLED', function () { return true; }); + /** True if running unit tests. 
*/ + ENV.registerFlag('IS_TEST', function () { return false; }); + + var Add = 'Add'; + var AddN = 'AddN'; + var AvgPool = 'AvgPool'; + var AvgPoolBackprop = 'AvgPoolBackprop'; + var AvgPool3D = 'AvgPool3D'; + var AvgPool3DBackprop = 'AvgPool3DBackprop'; + var BatchMatMul = 'BatchMatMul'; + var BatchToSpaceND = 'BatchToSpaceND'; + var BroadcastTo = 'BroadcastTo'; + var Concat = 'Concat'; + var Conv2D = 'Conv2D'; + var Conv2DBackpropFilter = 'Conv2DBackpropFilter'; + var Conv2DBackpropInput = 'Conv2DBackpropInput'; + var Conv3D = 'Conv3D'; + var Conv3DBackpropFilterV2 = 'Conv3DBackpropFilterV2'; + var Conv3DBackpropInputV2 = 'Conv3DBackpropInputV2'; + var Cumsum = 'Cumsum'; + var DepthToSpace = 'DepthToSpace'; + var DepthwiseConv2dNative = 'DepthwiseConv2dNative'; + var DepthwiseConv2dNativeBackpropFilter = 'DepthwiseConv2dNativeBackpropFilter'; + var DepthwiseConv2dNativeBackpropInput = 'DepthwiseConv2dNativeBackpropInput'; + var Diag = 'Diag'; + var Div = 'Div'; + var Equal = 'Equal'; + var FusedBatchNorm = 'FusedBatchNorm'; + var Greater = 'Greater'; + var GreaterEqual = 'GreaterEqual'; + var Identity = 'Identity'; + var Less = 'Less'; + var LessEqual = 'LessEqual'; + var LRN = 'LRN'; + var LRNBackprop = 'LRNBackprop'; + var MaxPool = 'MaxPool'; + var MaxPoolBackprop = 'MaxPoolBackprop'; + var MaxPool3D = 'MaxPool3D'; + var MaxPool3DBackprop = 'MaxPool3DBackprop'; + var MaxPoolWithArgmax = 'MaxPoolWithArgmax'; + var NotEqual = 'NotEqual'; + var NonMaxSuppressionV3 = 'NonMaxSuppressionV3'; + var NonMaxSuppressionV5 = 'NonMaxSuppressionV5'; + var Max = 'Max'; + var OneHot = 'OneHot'; + var PadV2 = 'PadV2'; + var Pool = 'Pool'; + var SpaceToBatchND = 'SpaceToBatchND'; + var SplitV = 'SplitV'; + var SquaredDifference = 'SquaredDifference'; + var Square = 'Square'; + var Sub = 'Sub'; + var Tile = 'Tile'; + var Transpose = 'Transpose'; + /** + * TensorFlow.js-only kernels + */ + var FromPixels = 'FromPixels'; + + /** + * @license + * Copyright 2017 Google Inc. 
All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Returns the dimensions in the input shape that are broadcasted to + * produce the provided output shape. + * + * The returned dimensions are 0-indexed and sorted. An example: + * inShape = [4, 1, 3] + * outShape = [5, 4, 3, 3] + * result = [1]. Dimension 1 (2nd dimension of input) gets broadcasted 1 => 3. + */ + function getBroadcastDims(inShape, outShape) { + var inRank = inShape.length; + var dims = []; + for (var i = 0; i < inRank; i++) { + var dim = inRank - 1 - i; + var a = inShape[dim] || 1; + var b = outShape[outShape.length - 1 - i] || 1; + if (b > 1 && a === 1) { + dims.unshift(dim); + } + } + return dims; + } + /** + * Returns the axes in the output space that should be reduced to produce + * the input space. 
+ */ + function getReductionAxes(inShape, outShape) { + var result = []; + for (var i = 0; i < outShape.length; i++) { + var inDim = inShape[inShape.length - i - 1]; + var outAxis = outShape.length - i - 1; + var outDim = outShape[outAxis]; + if (inDim == null || (inDim === 1 && outDim > 1)) { + result.unshift(outAxis); + } + } + return result; + } + function assertAndGetBroadcastShape(shapeA, shapeB) { + var result = []; + var l = Math.max(shapeA.length, shapeB.length); + for (var i = 0; i < l; i++) { + var a = shapeA[shapeA.length - i - 1]; + if (a == null) { + a = 1; + } + var b = shapeB[shapeB.length - i - 1]; + if (b == null) { + b = 1; + } + if (a === 1) { + result.unshift(b); + } + else if (b === 1) { + result.unshift(a); + } + else if (a !== b) { + var errMsg = "Operands could not be broadcast together with shapes " + + (shapeA + " and " + shapeB + "."); + throw Error(errMsg); + } + else { + result.unshift(a); + } + } + return result; + } + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + var addGradConfig = { + kernelName: Add, + inputsToSave: ['a', 'b'], + gradFunc: function (dy, saved) { + var a = saved[0], b = saved[1]; + var outShape = assertAndGetBroadcastShape(a.shape, b.shape); + var derA = function () { + var res = dy; + var reduceAxes = getReductionAxes(a.shape, outShape); + if (reduceAxes.length > 0) { + res = res.sum(reduceAxes); + } + return res.reshape(a.shape); + }; + var derB = function () { + var res = dy; + var reduceAxes = getReductionAxes(b.shape, outShape); + if (reduceAxes.length > 0) { + res = res.sum(reduceAxes); + } + return res.reshape(b.shape); + }; + return { a: derA, b: derB }; + } + }; + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var addNGradConfig = { + kernelName: AddN, + saveAllInputs: true, + gradFunc: function (dy, saved) { + var ders = {}; + saved.forEach(function (_, i) { + ders[i] = function () { return dy.clone(); }; + }); + return ders; + } + }; + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function inferShape(val, dtype) { + var firstElem = val; + if (isTypedArray(val)) { + return dtype === 'string' ? [] : [val.length]; + } + if (!Array.isArray(val)) { + return []; // Scalar. + } + var shape = []; + while (Array.isArray(firstElem) || + isTypedArray(firstElem) && dtype !== 'string') { + shape.push(firstElem.length); + firstElem = firstElem[0]; + } + if (Array.isArray(val) && + env().getBool('TENSORLIKE_CHECK_SHAPE_CONSISTENCY')) { + deepAssertShapeConsistency(val, shape, []); + } + return shape; + } + function deepAssertShapeConsistency(val, shape, indices) { + indices = indices || []; + if (!(Array.isArray(val)) && !isTypedArray(val)) { + assert(shape.length === 0, function () { return "Element arr[" + indices.join('][') + "] is a primitive, " + + ("but should be an array/TypedArray of " + shape[0] + " elements"); }); + return; + } + assert(shape.length > 0, function () { return "Element arr[" + indices.join('][') + "] should be a primitive, " + + ("but is an array of " + val.length + " elements"); }); + assert(val.length === shape[0], function () { return "Element arr[" + indices.join('][') + "] should have " + shape[0] + " " + + ("elements, but has " + val.length + " elements"); }); + var subShape = shape.slice(1); + for (var i = 0; i < val.length; ++i) { + deepAssertShapeConsistency(val[i], subShape, indices.concat(i)); + } + } + function assertDtype(expectedDtype, actualDType, argName, functionName) { + if (expectedDtype == null) { + return; 
+ } + if (expectedDtype !== 'numeric' && expectedDtype !== actualDType || + expectedDtype === 'numeric' && actualDType === 'string') { + throw new Error("Argument '" + argName + "' passed to '" + functionName + "' must " + + ("be " + expectedDtype + " tensor, but got " + actualDType + " tensor")); + } + } + function convertToTensor(x, argName, functionName, parseAsDtype) { + if (parseAsDtype === void 0) { parseAsDtype = 'numeric'; } + if (x instanceof Tensor) { + assertDtype(parseAsDtype, x.dtype, argName, functionName); + return x; + } + var inferredDtype = inferDtype(x); + // If the user expects a bool/int/float, use that info to update the + // inferredDtype when it is not a string. + if (inferredDtype !== 'string' && + ['bool', 'int32', 'float32'].indexOf(parseAsDtype) >= 0) { + inferredDtype = parseAsDtype; + } + assertDtype(parseAsDtype, inferredDtype, argName, functionName); + if ((x == null) || + (!isTypedArray(x) && !Array.isArray(x) && typeof x !== 'number' && + typeof x !== 'boolean' && typeof x !== 'string')) { + var type = x == null ? 'null' : x.constructor.name; + throw new Error("Argument '" + argName + "' passed to '" + functionName + "' must be a " + + ("Tensor or TensorLike, but got '" + type + "'")); + } + var inferredShape = inferShape(x, inferredDtype); + if (!isTypedArray(x) && !Array.isArray(x)) { + x = [x]; + } + var skipTypedArray = true; + var values = inferredDtype !== 'string' ? 
+ toTypedArray(x, inferredDtype, env().getBool('DEBUG')) : + flatten(x, [], skipTypedArray); + return ENGINE.makeTensor(values, inferredShape, inferredDtype); + } + function convertToTensorArray(arg, argName, functionName, parseAsDtype) { + if (parseAsDtype === void 0) { parseAsDtype = 'numeric'; } + if (!Array.isArray(arg)) { + throw new Error("Argument " + argName + " passed to " + functionName + " must be a " + + '`Tensor[]` or `TensorLike[]`'); + } + var tensors = arg; + return tensors.map(function (t, i) { return convertToTensor(t, argName + "[" + i + "]", functionName); }, parseAsDtype); + } + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + function assertParamsConsistent(shapes, axis) { + var rank = shapes[0].length; + shapes.forEach(function (shape, i) { + assert(shape.length === rank, function () { + return "Error in concat" + rank + "D: rank of tensors[" + i + "] must be the same " + + ("as the rank of the rest (" + rank + ")"); + }); + }); + assert(axis >= 0 && axis < rank, function () { return "Error in concat" + rank + "D: axis must be between 0 and " + (rank - 1) + "."; }); + var firstShape = shapes[0]; + shapes.forEach(function (shape, i) { + for (var r = 0; r < rank; r++) { + assert((r === axis) || (shape[r] === firstShape[r]), function () { return "Error in concat" + rank + "D: Shape of tensors[" + i + "] (" + shape + ") " + + ("does not match the shape of the rest (" + firstShape + ") ") + + ("along the non-concatenated axis " + i + "."); }); + } + }); + } + function computeOutShape(shapes, axis) { + var outputShape = shapes[0].slice(); + for (var i = 1; i < shapes.length; i++) { + outputShape[axis] += shapes[i][axis]; + } + return outputShape; + } + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Used for wrapping functions that perform math operations on + * Tensors. 
The function will be wrapped in a named scope that cleans all + * memory usage after the function is done. + */ + function op(f) { + var keys = Object.keys(f); + if (keys.length !== 1) { + throw new Error("Please provide an object with a single key " + + "(operation name) mapping to a function. Got an object with " + + (keys.length + " keys.")); + } + var opName = keys[0]; + var fn = f[opName]; + // Strip the underscore from the end of the function name. + if (opName.endsWith('_')) { + opName = opName.substring(0, opName.length - 1); + } + // tslint:disable-next-line:no-any + var f2 = function () { + var args = []; + for (var _i = 0; _i < arguments.length; _i++) { + args[_i] = arguments[_i]; + } + ENGINE.startScope(opName); + try { + var result = fn.apply(void 0, args); + if (result instanceof Promise) { + console.error('Cannot return a Promise inside of tidy.'); + } + ENGINE.endScope(result); + return result; + } + catch (ex) { + ENGINE.endScope(null); + throw ex; + } + }; + Object.defineProperty(f2, 'name', { value: opName, configurable: true }); + // tslint:disable-next-line:no-any + return f2; + } + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Converts two real numbers to a complex number. 
+ * + * Given a tensor `real` representing the real part of a complex number, and a + * tensor `imag` representing the imaginary part of a complex number, this + * operation returns complex numbers elementwise of the form [r0, i0, r1, i1], + * where r represents the real part and i represents the imag part. + * + * The input tensors real and imag must have the same shape. + * + * ```js + * const real = tf.tensor1d([2.25, 3.25]); + * const imag = tf.tensor1d([4.75, 5.75]); + * const complex = tf.complex(real, imag); + * + * complex.print(); + * ``` + */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + function complex_(real, imag) { + var $real = convertToTensor(real, 'real', 'complex'); + var $imag = convertToTensor(imag, 'imag', 'complex'); + assertShapesMatch($real.shape, $imag.shape, "real and imag shapes, " + $real.shape + " and " + $imag.shape + ", " + + "must match in call to tf.complex()."); + return ENGINE.runKernelFunc(function (backend) { return backend.complex($real, $imag); }, { $real: $real, $imag: $imag }); + } + /** + * Returns the real part of a complex (or real) tensor. + * + * Given a tensor input, this operation returns a tensor of type float that is + * the real part of each element in input considered as a complex number. + * + * If the input is real, it simply makes a clone. + * + * ```js + * const x = tf.complex([-2.25, 3.25], [4.75, 5.75]); + * tf.real(x).print(); + * ``` + */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + function real_(input) { + var $input = convertToTensor(input, 'input', 'real'); + return ENGINE.runKernelFunc(function (backend) { return backend.real($input); }, { $input: $input }); + } + /** + * Returns the imaginary part of a complex (or real) tensor. + * + * Given a tensor input, this operation returns a tensor of type float that is + * the imaginary part of each element in input considered as a complex number. + * If input is real, a tensor of all zeros is returned. 
+ * + * ```js + * const x = tf.complex([-2.25, 3.25], [4.75, 5.75]); + * tf.imag(x).print(); + * ``` + */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + function imag_(input) { + var $input = convertToTensor(input, 'input', 'imag'); + return ENGINE.runKernelFunc(function (backend) { return backend.imag($input); }, { $input: $input }); + } + var complex = op({ complex_: complex_ }); + var real = op({ real_: real_ }); + var imag = op({ imag_: imag_ }); + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Creates a `tf.Tensor` with the provided values, shape and dtype. + * + * ```js + * // Pass an array of values to create a vector. + * tf.tensor([1, 2, 3, 4]).print(); + * ``` + * + * ```js + * // Pass a nested array of values to make a matrix or a higher + * // dimensional tensor. + * tf.tensor([[1, 2], [3, 4]]).print(); + * ``` + * + * ```js + * // Pass a flat array and specify a shape yourself. + * tf.tensor([1, 2, 3, 4], [2, 2]).print(); + * ``` + * + * @param values The values of the tensor. Can be nested array of numbers, + * or a flat array, or a `TypedArray`. If the values are strings, + * they will be encoded as utf-8 and kept as `Uint8Array[]`. + * @param shape The shape of the tensor. Optional. If not provided, + * it is inferred from `values`. + * @param dtype The data type. 
+ */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + function tensor(values, shape, dtype) { + var inferredShape = inferShape(values, dtype); + return makeTensor(values, shape, inferredShape, dtype); + } + /** This is shared code across all tensor creation methods. */ + function makeTensor(values, shape, inferredShape, dtype) { + if (dtype == null) { + dtype = inferDtype(values); + } + if (dtype === 'complex64') { + throw new Error("Cannot construct a complex64 tensor directly. " + + "Please use tf.complex(real, imag)."); + } + if (!isTypedArray(values) && !Array.isArray(values) && + typeof values !== 'number' && typeof values !== 'boolean' && + typeof values !== 'string') { + throw new Error('values passed to tensor(values) must be a number/boolean/string or ' + + 'an array of numbers/booleans/strings, or a TypedArray'); + } + if (shape != null) { + assertNonNegativeIntegerDimensions(shape); + var providedSize_1 = sizeFromShape(shape); + var inferredSize_1 = sizeFromShape(inferredShape); + assert(providedSize_1 === inferredSize_1, function () { + return "Based on the provided shape, [" + shape + "], the tensor should have " + + (providedSize_1 + " values but has " + inferredSize_1); + }); + for (var i = 0; i < inferredShape.length; ++i) { + var inferred = inferredShape[i]; + var flatDimsDontMatch = i === inferredShape.length - 1 ? + inferred !== sizeFromShape(shape.slice(i)) : + true; + assert(inferredShape[i] === shape[i] || !flatDimsDontMatch, function () { return "Error creating a new Tensor. Inferred shape " + + ("(" + inferredShape + ") does not match the provided ") + + ("shape (" + shape + "). "); }); + } + } + if (!isTypedArray(values) && !Array.isArray(values)) { + values = [values]; + } + shape = shape || inferredShape; + values = dtype !== 'string' ? 
+ toTypedArray(values, dtype, env().getBool('DEBUG')) : + flatten(values, [], true); + return ENGINE.makeTensor(values, shape, dtype); + } + /** + * Creates rank-0 `tf.Tensor` (scalar) with the provided value and dtype. + * + * The same functionality can be achieved with `tf.tensor`, but in general + * we recommend using `tf.scalar` as it makes the code more readable. + * + * ```js + * tf.scalar(3.14).print(); + * ``` + * + * @param value The value of the scalar. + * @param dtype The data type. + */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + function scalar(value, dtype) { + if (((isTypedArray(value) && dtype !== 'string') || Array.isArray(value)) && + dtype !== 'complex64') { + throw new Error('Error creating a new Scalar: value must be a primitive ' + + '(number|boolean|string)'); + } + if (dtype === 'string' && isTypedArray(value) && + !(value instanceof Uint8Array)) { + throw new Error('When making a scalar from encoded string, ' + + 'the value must be `Uint8Array`.'); + } + var shape = []; + var inferredShape = []; + return makeTensor(value, shape, inferredShape, dtype); + } + /** + * Creates rank-1 `tf.Tensor` with the provided values, shape and dtype. + * + * The same functionality can be achieved with `tf.tensor`, but in general + * we recommend using `tf.tensor1d` as it makes the code more readable. + * + * ```js + * tf.tensor1d([1, 2, 3]).print(); + * ``` + * + * @param values The values of the tensor. Can be array of numbers, + * or a `TypedArray`. + * @param dtype The data type. + */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + function tensor1d(values, dtype) { + assertNonNull(values); + var inferredShape = inferShape(values, dtype); + if (inferredShape.length !== 1) { + throw new Error('tensor1d() requires values to be a flat/TypedArray'); + } + var shape = null; + return makeTensor(values, shape, inferredShape, dtype); + } + /** + * Creates rank-2 `tf.Tensor` with the provided values, shape and dtype. 
+ * + * The same functionality can be achieved with `tf.tensor`, but in general + * we recommend using `tf.tensor2d` as it makes the code more readable. + * + * ```js + * // Pass a nested array. + * tf.tensor2d([[1, 2], [3, 4]]).print(); + * ``` + * ```js + * // Pass a flat array and specify a shape. + * tf.tensor2d([1, 2, 3, 4], [2, 2]).print(); + * ``` + * + * @param values The values of the tensor. Can be nested array of numbers, + * or a flat array, or a `TypedArray`. + * @param shape The shape of the tensor. If not provided, it is inferred from + * `values`. + * @param dtype The data type. + */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + function tensor2d(values, shape, dtype) { + assertNonNull(values); + if (shape != null && shape.length !== 2) { + throw new Error('tensor2d() requires shape to have two numbers'); + } + var inferredShape = inferShape(values, dtype); + if (inferredShape.length !== 2 && inferredShape.length !== 1) { + throw new Error('tensor2d() requires values to be number[][] or flat/TypedArray'); + } + if (inferredShape.length === 1 && shape == null) { + throw new Error('tensor2d() requires shape to be provided when `values` ' + + 'are a flat/TypedArray'); + } + return makeTensor(values, shape, inferredShape, dtype); + } + /** + * Creates rank-3 `tf.Tensor` with the provided values, shape and dtype. + * + * The same functionality can be achieved with `tf.tensor`, but in general + * we recommend using `tf.tensor3d` as it makes the code more readable. + * + * ```js + * // Pass a nested array. + * tf.tensor3d([[[1], [2]], [[3], [4]]]).print(); + * ``` + * ```js + * // Pass a flat array and specify a shape. + * tf.tensor3d([1, 2, 3, 4], [2, 2, 1]).print(); + * ``` + * + * @param values The values of the tensor. Can be nested array of numbers, + * or a flat array, or a `TypedArray`. + * @param shape The shape of the tensor. If not provided, it is inferred from + * `values`. + * @param dtype The data type. 
+ */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + function tensor3d(values, shape, dtype) { + assertNonNull(values); + if (shape != null && shape.length !== 3) { + throw new Error('tensor3d() requires shape to have three numbers'); + } + var inferredShape = inferShape(values, dtype); + if (inferredShape.length !== 3 && inferredShape.length !== 1) { + throw new Error('tensor3d() requires values to be number[][][] or flat/TypedArray'); + } + if (inferredShape.length === 1 && shape == null) { + throw new Error('tensor3d() requires shape to be provided when `values` ' + + 'are a flat array'); + } + return makeTensor(values, shape, inferredShape, dtype); + } + /** + * Creates rank-4 `tf.Tensor` with the provided values, shape and dtype. + * + * The same functionality can be achieved with `tf.tensor`, but in general + * we recommend using `tf.tensor4d` as it makes the code more readable. + * + * ```js + * // Pass a nested array. + * tf.tensor4d([[[[1], [2]], [[3], [4]]]]).print(); + * ``` + * ```js + * // Pass a flat array and specify a shape. + * tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]).print(); + * ``` + * + * @param values The values of the tensor. Can be nested array of numbers, + * or a flat array, or a `TypedArray`. + * @param shape The shape of the tensor. Optional. If not provided, + * it is inferred from `values`. + * @param dtype The data type. 
+ */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + function tensor4d(values, shape, dtype) { + assertNonNull(values); + if (shape != null && shape.length !== 4) { + throw new Error('tensor4d() requires shape to have four numbers'); + } + var inferredShape = inferShape(values, dtype); + if (inferredShape.length !== 4 && inferredShape.length !== 1) { + throw new Error('tensor4d() requires values to be number[][][][] or flat/TypedArray'); + } + if (inferredShape.length === 1 && shape == null) { + throw new Error('tensor4d() requires shape to be provided when `values` ' + + 'are a flat array'); + } + return makeTensor(values, shape, inferredShape, dtype); + } + /** + * Creates rank-5 `tf.Tensor` with the provided values, shape and dtype. + * + * The same functionality can be achieved with `tf.tensor`, but in general + * we recommend using `tf.tensor5d` as it makes the code more readable. + * + * ```js + * // Pass a nested array. + * tf.tensor5d([[[[[1], [2]], [[3], [4]]]]]).print(); + * ``` + * ```js + * // Pass a flat array and specify a shape. + * tf.tensor5d([1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 2, 2, 1]).print(); + * ``` + * + * @param values The values of the tensor. Can be nested array of numbers, + * or a flat array, or a `TypedArray`. + * @param shape The shape of the tensor. Optional. If not provided, + * it is inferred from `values`. + * @param dtype The data type. 
+ */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + function tensor5d(values, shape, dtype) { + assertNonNull(values); + if (shape != null && shape.length !== 5) { + throw new Error('tensor5d() requires shape to have five numbers'); + } + var inferredShape = inferShape(values, dtype); + if (inferredShape.length !== 5 && inferredShape.length !== 1) { + throw new Error('tensor5d() requires values to be ' + + 'number[][][][][] or flat/TypedArray'); + } + if (inferredShape.length === 1 && shape == null) { + throw new Error('tensor5d() requires shape to be provided when `values` ' + + 'are a flat array'); + } + return makeTensor(values, shape, inferredShape, dtype); + } + /** + * Creates rank-6 `tf.Tensor` with the provided values, shape and dtype. + * + * The same functionality can be achieved with `tf.tensor`, but in general + * we recommend using `tf.tensor6d` as it makes the code more readable. + * + * ```js + * // Pass a nested array. + * tf.tensor6d([[[[[[1],[2]],[[3],[4]]],[[[5],[6]],[[7],[8]]]]]]).print(); + * ``` + * ```js + * // Pass a flat array and specify a shape. + * tf.tensor6d([1, 2, 3, 4, 5, 6, 7, 8], [1, 1, 2, 2, 2, 1]).print(); + * ``` + * + * @param values The values of the tensor. Can be nested array of numbers, + * or a flat array, or a `TypedArray`. + * @param shape The shape of the tensor. Optional. If not provided, + * it is inferred from `values`. + * @param dtype The data type. 
+ */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + function tensor6d(values, shape, dtype) { + assertNonNull(values); + if (shape != null && shape.length !== 6) { + throw new Error('tensor6d() requires shape to have six numbers'); + } + var inferredShape = inferShape(values, dtype); + if (inferredShape.length !== 6 && inferredShape.length !== 1) { + throw new Error('tensor6d() requires values to be number[][][][][][] or ' + + 'flat/TypedArray'); + } + if (inferredShape.length === 1 && shape == null) { + throw new Error('tensor6d() requires shape to be provided when `values` ' + + 'are a flat array'); + } + shape = shape || + inferredShape; + return makeTensor(values, shape, inferredShape, dtype); + } + /** + * Creates a new variable with the provided initial value. + * ```js + * const x = tf.variable(tf.tensor([1, 2, 3])); + * x.assign(tf.tensor([4, 5, 6])); + * + * x.print(); + * ``` + * + * @param initialValue Initial value for the tensor. + * @param trainable If true, optimizers are allowed to update it. + * @param name Name of the variable. Defaults to a unique id. + * @param dtype If set, initialValue will be converted to the given type. + */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + function variable(initialValue, trainable, name, dtype) { + if (trainable === void 0) { trainable = true; } + return ENGINE.makeVariable(initialValue, trainable, name, dtype); + } + /** + * Creates a `tf.Tensor` with all elements set to 1. + * + * ```js + * tf.ones([2, 2]).print(); + * ``` + * + * @param shape An array of integers defining the output tensor shape. + * @param dtype The type of an element in the resulting tensor. Defaults to + * 'float'. 
+ */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + function ones$1(shape, dtype) { + if (dtype === void 0) { dtype = 'float32'; } + if (dtype === 'complex64') { + var real_1 = ones$1(shape, 'float32'); + var imag_1 = zeros(shape, 'float32'); + return complex(real_1, imag_1); + } + var values = makeOnesTypedArray(sizeFromShape(shape), dtype); + return ENGINE.makeTensor(values, shape, dtype); + } + /** + * Creates a `tf.Tensor` with all elements set to 0. + * + * ```js + * tf.zeros([2, 2]).print(); + * ``` + * + * @param shape An array of integers defining the output tensor shape. + * @param dtype The type of an element in the resulting tensor. Can + * be 'float32', 'int32' or 'bool'. Defaults to 'float'. + */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + function zeros(shape, dtype) { + if (dtype === void 0) { dtype = 'float32'; } + if (dtype === 'complex64') { + var real_2 = zeros(shape, 'float32'); + var imag_2 = zeros(shape, 'float32'); + return complex(real_2, imag_2); + } + var values = makeZerosTypedArray(sizeFromShape(shape), dtype); + return ENGINE.makeTensor(values, shape, dtype); + } + /** + * Creates a `tf.Tensor` filled with a scalar value. + * + * ```js + * tf.fill([2, 2], 4).print(); + * ``` + * + * @param shape An array of integers defining the output tensor shape. + * @param value The scalar value to fill the tensor with. + * @param dtype The type of an element in the resulting tensor. Defaults to + * 'float'. + */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + function fill(shape, value, dtype) { + return ENGINE.runKernelFunc(function (backend) { return backend.fill(shape, value, dtype); }, {}); + } + /** + * Creates a `tf.Tensor` with all elements set to 1 with the same shape as the + * given tensor. + * + * ```js + * const x = tf.tensor([1, 2]); + * tf.onesLike(x).print(); + * ``` + * @param x A tensor. 
+ */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + function onesLike_(x) { + var $x = convertToTensor(x, 'x', 'onesLike'); + if ($x.dtype === 'complex64') { + var r = onesLike(real($x)); + var i = zerosLike(imag($x)); + return complex(r, i); + } + var der = function (dy, saved) { return ({ x: function () { return zerosLike(dy); } }); }; + return ENGINE.runKernelFunc(function (backend) { return backend.onesLike($x); }, { x: $x }, der, 'OnesLike'); + } + /** + * Creates a `tf.Tensor` with all elements set to 0 with the same shape as the + * given tensor. + * + * ```js + * const x = tf.tensor([1, 2]); + * tf.zerosLike(x).print(); + * ``` + * + * @param x The tensor of required shape. + */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + function zerosLike_(x) { + var $x = convertToTensor(x, 'x', 'zerosLike'); + var der = function (dy, saved) { return ({ x: function () { return zerosLike(dy); } }); }; + return ENGINE.runKernelFunc(function (backend) { return backend.zerosLike($x); }, { x: $x }, der, 'ZerosLike'); + } + /** + * Return an evenly spaced sequence of numbers over the given interval. + * + * ```js + * tf.linspace(0, 9, 10).print(); + * ``` + * @param start The start value of the sequence. + * @param stop The end value of the sequence. + * @param num The number of values to generate. + */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + function linspace(start, stop, num) { + if (num <= 0) { + throw new Error('The number of values should be positive.'); + } + return ENGINE.runKernelFunc(function (backend) { return backend.linspace(start, stop, num); }, {}); + } + /** + * Creates a new `tf.Tensor1D` filled with the numbers in the range provided. + * + * The tensor is a is half-open interval meaning it includes start, but + * excludes stop. Decrementing ranges and negative step values are also + * supported. 
+ * + * ```js + * tf.range(0, 9, 2).print(); + * ``` + * + * @param start An integer start value + * @param stop An integer stop value + * @param step An integer increment (will default to 1 or -1) + * @param dtype The data type of the output tensor. Defaults to 'float32'. + */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + function range(start, stop, step, dtype) { + if (step === void 0) { step = 1; } + if (dtype === void 0) { dtype = 'float32'; } + if (step === 0) { + throw new Error('Cannot have a step of zero'); + } + var sameStartStop = start === stop; + var increasingRangeNegativeStep = start < stop && step < 0; + var decreasingRangePositiveStep = stop < start && step > 1; + if (sameStartStop || increasingRangeNegativeStep || + decreasingRangePositiveStep) { + return zeros([0], dtype); + } + var numElements = Math.abs(Math.ceil((stop - start) / step)); + var values = makeZerosTypedArray(numElements, dtype); + if (stop < start && step === 1) { + // Auto adjust the step's sign if it hasn't been set + // (or was set to 1) + step = -1; + } + values[0] = start; + for (var i = 1; i < values.length; i++) { + values[i] = values[i - 1] + step; + } + return tensor1d(values, dtype); + } + var onesLike = op({ onesLike_: onesLike_ }); + var zerosLike = op({ zerosLike_: zerosLike_ }); + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + /** + * Concatenates a list of `tf.Tensor`s along a given axis. + * + * The tensors ranks and types must match, and their sizes must match in all + * dimensions except `axis`. + * + * Also available are stricter rank-specific methods that assert that + * `tensors` are of the given rank: + * - `tf.concat1d` + * - `tf.concat2d` + * - `tf.concat3d` + * - `tf.concat4d` + * + * Except `tf.concat1d` (which does not have axis param), all methods have + * same signature as this method. + * + * ```js + * const a = tf.tensor1d([1, 2]); + * const b = tf.tensor1d([3, 4]); + * a.concat(b).print(); // or a.concat(b) + * ``` + * + * ```js + * const a = tf.tensor1d([1, 2]); + * const b = tf.tensor1d([3, 4]); + * const c = tf.tensor1d([5, 6]); + * tf.concat([a, b, c]).print(); + * ``` + * + * ```js + * const a = tf.tensor2d([[1, 2], [10, 20]]); + * const b = tf.tensor2d([[3, 4], [30, 40]]); + * const axis = 1; + * tf.concat([a, b], axis).print(); + * ``` + * @param tensors A list of tensors to concatenate. + * @param axis The axis to concate along. Defaults to 0 (the first dim). + */ + /** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */ + function concat_(tensors, axis) { + if (axis === void 0) { axis = 0; } + assert(tensors.length >= 1, function () { return 'Pass at least one tensor to concat'; }); + var $tensors = convertToTensorArray(tensors, 'tensors', 'concat'); + if ($tensors[0].dtype === 'complex64') { + $tensors.forEach(function (tensor) { + if (tensor.dtype !== 'complex64') { + throw new Error("Cannot concatenate complex64 tensors with a tensor\n with dtype " + tensor.dtype + ". 
"); + } + }); + } + var $axis = parseAxisParam(axis, $tensors[0].shape)[0]; + var outShape = computeOutShape($tensors.map(function (t) { return t.shape; }), $axis); + if (sizeFromShape(outShape) === 0) { + return tensor([], outShape); + } + // Keep only non-empty tensors (ignore tensors with 0 in their shape). + $tensors = $tensors.filter(function (t) { return t.size > 0; }); + if ($tensors.length === 1) { + return $tensors[0]; + } + var shapes = $tensors.map(function (t) { return t.shape; }); + assertParamsConsistent(shapes, $axis); + var forward = function (backend, save) { + var $axis = parseAxisParam(axis, $tensors[0].shape)[0]; + var res = backend.concat($tensors, $axis); + save($tensors); + return res; + }; + var inputs = $tensors; + var attr = { axis: axis }; + return ENGINE.runKernelFunc(forward, inputs, null /* grad */, Concat, attr); + } + var concat = op({ concat_: concat_ }); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Reshapes a `tf.Tensor` to a given shape. + * + * Given an input tensor, returns a new tensor with the same values as the + * input tensor with shape `shape`. + * + * If one component of shape is the special value -1, the size of that + * dimension is computed so that the total size remains constant. In + * particular, a shape of [-1] flattens into 1-D. 
At most one component of + * shape can be -1. + * + * If shape is 1-D or higher, then the operation returns a tensor with shape + * shape filled with the values of tensor. In this case, the number of + * elements implied by shape must be the same as the number of elements in + * tensor. + * + * ```js + * const x = tf.tensor1d([1, 2, 3, 4]); + * x.reshape([2, 2]).print(); + * ``` + * + * @param x The input tensor to be reshaped. + * @param shape An array of integers defining the output tensor shape. + */ + /** @doc {heading: 'Tensors', subheading: 'Transformations'} */ + function reshape_(x, shape) { + var $x = convertToTensor(x, 'x', 'reshape', null); + shape = inferFromImplicitShape(shape, $x.size); + assert($x.size === sizeFromShape(shape), function () { return 'new shape and old shape must have the same number of elements.'; }); + var grad = function (dy) { + return { x: function () { return dy.reshape($x.shape); } }; + }; + var attrs = { shape: shape }; + return ENGINE.runKernelFunc(function (backend) { return backend.reshape($x, shape); }, { x: $x }, grad, 'Reshape', attrs); + } + /** + * Removes dimensions of size 1 from the shape of a `tf.Tensor`. + * + * ```js + * const x = tf.tensor([1, 2, 3, 4], [1, 1, 4]); + * x.squeeze().print(); + * ``` + * + * @param x The input tensor to be squeezed. + * @param axis An optional list of numbers. If specified, only + * squeezes the dimensions listed. The dimension index starts at 0. It + * is an error to squeeze a dimension that is not 1. + */ + /** @doc {heading: 'Tensors', subheading: 'Transformations'} */ + function squeeze_(x, axis) { + var $x = convertToTensor(x, 'x', 'squeeze'); + return reshape($x, squeezeShape($x.shape, axis).newShape); + } + /** + * Casts a `tf.Tensor` to a new dtype. + * + * ```js + * const x = tf.tensor1d([1.5, 2.5, 3]); + * tf.cast(x, 'int32').print(); + * ``` + * @param x The input tensor to be casted. + * @param dtype The dtype to cast the input tensor to. 
/**
 * Casts a `tf.Tensor` to a new dtype.
 *
 * ```js
 * const x = tf.tensor1d([1.5, 2.5, 3]);
 * tf.cast(x, 'int32').print();
 * ```
 * @param x The input tensor to be casted.
 * @param dtype The dtype to cast the input tensor to.
 */
/** @doc {heading: 'Tensors', subheading: 'Transformations'} */
function cast_(x, dtype) {
    var $x = convertToTensor(x, 'x', 'cast');
    // Reject dtypes the runtime does not know about.
    if (!isValidDtype(dtype)) {
        throw new Error("Failed to cast to unknown dtype " + dtype);
    }
    // A cast may not cross the string/non-string boundary in either direction.
    var castToString = dtype === 'string';
    var castFromString = $x.dtype === 'string';
    if (castToString !== castFromString) {
        throw new Error('Only strings can be casted to strings');
    }
    // Casting is value-preserving for gradient purposes: dy flows straight
    // through (cloned so the caller owns the result).
    var gradient = function (dy) {
        return { x: function () { return dy.clone(); } };
    };
    var attrs = { dtype: dtype };
    return ENGINE.runKernelFunc(
        function (backend) { return backend.cast($x, dtype); },
        { x: $x }, gradient, 'Cast', attrs);
}
/**
 * Stacks a list of rank-`R` `tf.Tensor`s into one rank-`(R+1)` `tf.Tensor`.
 *
 * ```js
 * const a = tf.tensor1d([1, 2]);
 * const b = tf.tensor1d([3, 4]);
 * const c = tf.tensor1d([5, 6]);
 * tf.stack([a, b, c]).print();
 * ```
 *
 * @param tensors A list of tensor objects with the same shape and dtype.
 * @param axis The axis to stack along. Defaults to 0 (the first dim).
 */
/** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */
function stack_(tensors, axis) {
    if (axis === void 0) { axis = 0; }
    var $tensors = convertToTensorArray(tensors, 'tensors', 'stack');
    assert($tensors.length >= 1, function () { return 'Pass at least one tensor to tf.stack'; });
    // A single tensor only needs the new dimension inserted.
    if ($tensors.length === 1) {
        return $tensors[0].expandDims(axis);
    }
    var first = $tensors[0];
    var rank = first.rank;
    var shape = first.shape;
    var dtype = first.dtype;
    assert(axis <= rank, function () { return 'Axis must be <= rank of the tensor'; });
    // Validate all shapes first, then all dtypes, so the first reported
    // failure matches the original ordering of checks.
    for (var i = 0; i < $tensors.length; i++) {
        assertShapesMatch(shape, $tensors[i].shape, 'All tensors passed to stack must have matching shapes');
    }
    for (var j = 0; j < $tensors.length; j++) {
        assert(dtype === $tensors[j].dtype, function () { return 'All tensors passed to stack must have matching dtypes'; });
    }
    // stack == expandDims on every input followed by a concat along `axis`.
    var expanded = $tensors.map(function (t) { return t.expandDims(axis); });
    return concat(expanded, axis);
}
/**
 * Unstacks a `tf.Tensor` of rank-`R` into a list of rank-`(R-1)` `tf.Tensor`s.
 *
 * ```js
 * const a = tf.tensor2d([1, 2, 3, 4], [2, 2]);
 *
 * tf.unstack(a).forEach(tensor => tensor.print());
 * ```
 *
 * @param x A tensor object.
 * @param axis The axis to unstack along. Defaults to 0 (the first dim).
 */
/** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */
function unstack_(x, axis) {
    if (axis === void 0) { axis = 0; }
    axis = axis || 0;
    var $x = convertToTensor(x, 'x', 'unstack');
    var numDims = $x.shape.length;
    assert(axis >= -numDims && axis < numDims, function () {
        return "Axis = " + axis + " is not in [-" + numDims + ", " + numDims + ")";
    });
    // Normalize a negative axis to its positive equivalent.
    if (axis < 0) {
        axis += numDims;
    }
    // Gradient of unstack is stack along the same axis.
    var gradient = function (dy) {
        return { x: function () { return stack(dy, axis); } };
    };
    var attrs = { axis: axis };
    return ENGINE.runKernelFunc(
        function (backend) { return backend.unstack($x, axis); },
        { x: $x }, gradient, 'Unpack', attrs);
}
/**
 * Returns a `tf.Tensor` that has expanded rank, by inserting a dimension
 * into the tensor's shape.
 *
 * ```js
 * const x = tf.tensor1d([1, 2, 3, 4]);
 * const axis = 1;
 * x.expandDims(axis).print();
 * ```
 *
 * @param x The input tensor whose dimensions to be expanded.
 * @param axis The dimension index at which to insert shape of `1`. Defaults
 *     to 0 (the first dimension).
 */
/** @doc {heading: 'Tensors', subheading: 'Transformations'} */
function expandDims_(x, axis) {
    if (axis === void 0) { axis = 0; }
    var $x = convertToTensor(x, 'x', 'expandDims', null);
    assert(axis <= $x.rank, function () { return 'Axis must be <= rank of the tensor'; });
    var newShape = $x.shape.slice();
    if (axis < 0) {
        // Negative value is counted from the tail of rank.
        assert(-($x.rank + 1) <= axis, function () { return "Axis must be in the interval [" + -($x.rank + 1) + ", " + $x.rank + "]"; });
        axis = $x.rank + axis + 1;
    }
    newShape.splice(axis, 0, 1);
    return reshape($x, newShape);
}
/**
 * Computes the difference between two lists of numbers.
 *
 * Given a Tensor `x` and a Tensor `y`, this operation returns a Tensor `out`
 * that represents all values that are in `x` but not in `y`. The returned
 * Tensor `out` is sorted in the same order that the numbers appear in `x`
 * (duplicates are preserved). This operation also returns a Tensor indices
 * that represents the position of each out element in `x`. In other words:
 *
 * `out[i] = x[idx[i]] for i in [0, 1, ..., out.length - 1]`
 *
 * ```js
 * const x = [1, 2, 3, 4, 5, 6];
 * const y = [1, 3, 5];
 *
 * const [out, indices] = await tf.setdiff1dAsync(x, y);
 * out.print(); // [2, 4, 6]
 * indices.print(); // [1, 3, 5]
 * ```
 *
 * @param x 1-D Tensor. Values to keep.
 * @param y 1-D Tensor. Must have the same type as x. Values to exclude in the
 *     output.
 * @returns Promise of Tensor tuple [out, indices].
 *  out: Tensor with the same type as x.
 *  indices: A Tensor of type int32.
 */
/** @doc {heading: 'Tensors', subheading: 'Transformations'} */
function setdiff1dAsync_(x, y) {
    // Transpiled async function: __awaiter/__generator drive the body as a
    // label-based state machine (cases 0..2 correspond to the code between
    // the two awaits of $x.data() and $y.data()).
    return __awaiter(this, void 0, void 0, function () {
        var $x, $y, xVals, yVals, ySet, outputSize, i, buffer, indices, i, p;
        return __generator(this, function (_a) {
            switch (_a.label) {
                case 0:
                    $x = convertToTensor(x, 'x', 'setdiff1d');
                    $y = convertToTensor(y, 'y', 'setdiff1d');
                    // Both operands must share a dtype and be 1-D.
                    assert($x.dtype === $y.dtype, function () { return "x and y should have the same dtype, but got x (" + $x.dtype + ") and y (" + $y.dtype + ")."; });
                    assert($x.rank === 1, function () { return "x should be 1D tensor, but got x (" + $x.shape + ")."; });
                    assert($y.rank === 1, function () { return "y should be 1D tensor, but got y (" + $y.shape + ")."; });
                    // await $x.data()
                    return [4 /*yield*/, $x.data()];
                case 1:
                    xVals = _a.sent();
                    // await $y.data()
                    return [4 /*yield*/, $y.data()];
                case 2:
                    yVals = _a.sent();
                    // Set membership gives O(1) exclusion tests per element.
                    ySet = new Set(yVals);
                    // First pass: count survivors so both output buffers can be
                    // allocated at their exact final size.
                    outputSize = 0;
                    for (i = 0; i < xVals.length; i++) {
                        if (!ySet.has(xVals[i])) {
                            outputSize++;
                        }
                    }
                    buffer = new TensorBuffer([outputSize], $x.dtype);
                    indices = new TensorBuffer([outputSize], 'int32');
                    // Second pass: copy each surviving value and record its
                    // original position within x.
                    for (i = 0, p = 0; i < xVals.length; i++) {
                        if (!ySet.has(xVals[i])) {
                            buffer.values[p] = xVals[i];
                            indices.values[p] = i;
                            p++;
                        }
                    }
                    return [2 /*return*/, [buffer.toTensor(), indices.toTensor()]];
            }
        });
    });
}
/**
 * Creates an empty `tf.TensorBuffer` with the specified `shape` and `dtype`.
 *
 * The values are stored in CPU as `TypedArray`. Fill the buffer using
 * `buffer.set()`, or by modifying directly `buffer.values`.
 *
 * When done, call `buffer.toTensor()` to get an immutable `tf.Tensor` with
 * those values.
 *
 * ```js
 * // Create a buffer and set values at particular indices.
 * const buffer = tf.buffer([2, 2]);
 * buffer.set(3, 0, 0);
 * buffer.set(5, 1, 0);
 *
 * // Convert the buffer back to a tensor.
 * buffer.toTensor().print();
 * ```
 *
 * @param shape An array of integers defining the output tensor shape.
 * @param dtype The dtype of the buffer. Defaults to 'float32'.
 * @param values The values of the buffer as `TypedArray`. Defaults to
 *     zeros.
 */
/** @doc {heading: 'Tensors', subheading: 'Creation'} */
function buffer(shape, dtype, values) {
    if (dtype === void 0) { dtype = 'float32'; }
    // Also covers an explicitly-passed null/undefined dtype, which the
    // default-parameter emulation above does not.
    dtype = dtype || 'float32';
    assertNonNegativeIntegerDimensions(shape);
    return new TensorBuffer(shape, dtype, values);
}
/**
 * Prints information about the `tf.Tensor` including its data.
 *
 * ```js
 * const verbose = true;
 * tf.tensor2d([1, 2, 3, 4], [2, 2]).print(verbose);
 * ```
 * @param x The tensor to be printed.
 * @param verbose Whether to print verbose information about the ` Tensor`,
 *     including dtype and size.
 */
/** @doc {heading: 'Tensors', subheading: 'Creation'} */
function print(x, verbose) {
    if (verbose === void 0) { verbose = false; }
    var rendered = x.toString(verbose);
    console.log(rendered);
}
// Public op wrappers around the kernel-level implementations above.
var cast = op({ cast_: cast_ });
var expandDims = op({ expandDims_: expandDims_ });
var reshape = op({ reshape_: reshape_ });
var squeeze = op({ squeeze_: squeeze_ });
var stack = op({ stack_: stack_ });
var unstack = op({ unstack_: unstack_ });
var setdiff1dAsync = setdiff1dAsync_;

/**
 * @license
 * Copyright 2017 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * Computes the information for a forward pass of a 2D pooling operation.
 */
function computePool2DInfo(inShape, filterSize, strides, dilations, pad, roundingMode, dataFormat) {
    if (dataFormat === void 0) { dataFormat = 'channelsLast'; }
    var kernel = parseTupleParam(filterSize);
    var filterHeight = kernel[0];
    var filterWidth = kernel[1];
    // Pooling never mixes channels, so the synthetic "filter" used for the
    // conv-info computation has inChannels == outChannels.
    var channelDim;
    if (dataFormat === 'channelsLast') {
        channelDim = inShape[3];
    }
    else if (dataFormat === 'channelsFirst') {
        channelDim = inShape[1];
    }
    else {
        throw new Error("Unknown dataFormat " + dataFormat);
    }
    var filterShape = [filterHeight, filterWidth, channelDim, channelDim];
    return computeConv2DInfo(inShape, filterShape, strides, dilations, pad, roundingMode, false /* depthwise */, dataFormat);
}
/**
 * Computes the information for a forward pass of a pooling3D operation.
 */
function computePool3DInfo(inShape, filterSize, strides, dilations, pad, roundingMode, dataFormat) {
    if (dataFormat === void 0) { dataFormat = 'NDHWC'; }
    var kernel = parse3TupleParam(filterSize);
    var filterDepth = kernel[0];
    var filterHeight = kernel[1];
    var filterWidth = kernel[2];
    // Translate the TF-style format string into the internal format name and
    // pick the channel dimension; pooling keeps inChannels == outChannels.
    var $dataFormat;
    var channelDim;
    if (dataFormat === 'NDHWC') {
        $dataFormat = 'channelsLast';
        channelDim = inShape[4];
    }
    else if (dataFormat === 'NCDHW') {
        $dataFormat = 'channelsFirst';
        channelDim = inShape[1];
    }
    else {
        throw new Error("Unknown dataFormat " + dataFormat);
    }
    var filterShape = [filterDepth, filterHeight, filterWidth, channelDim, channelDim];
    return computeConv3DInfo(inShape, filterShape, strides, dilations, pad, false /* depthwise */, $dataFormat, roundingMode);
}
/**
 * Computes the information for a forward pass of a convolution/pooling
 * operation.
 */
function computeConv2DInfo(inShape, filterShape, strides, dilations, pad, roundingMode, depthwise, dataFormat) {
    if (depthwise === void 0) { depthwise = false; }
    if (dataFormat === void 0) { dataFormat = 'channelsLast'; }
    var batchSize = -1;
    var inHeight = -1;
    var inWidth = -1;
    var inChannels = -1;
    if (dataFormat === 'channelsLast') {
        batchSize = inShape[0];
        inHeight = inShape[1];
        inWidth = inShape[2];
        inChannels = inShape[3];
    }
    else if (dataFormat === 'channelsFirst') {
        batchSize = inShape[0];
        inChannels = inShape[1];
        inHeight = inShape[2];
        inWidth = inShape[3];
    }
    else {
        throw new Error("Unknown dataFormat " + dataFormat);
    }
    var filterHeight = filterShape[0];
    var filterWidth = filterShape[1];
    var filterChannels = filterShape[3];
    var strideDims = parseTupleParam(strides);
    var strideHeight = strideDims[0];
    var strideWidth = strideDims[1];
    var dilationDims = parseTupleParam(dilations);
    var dilationHeight = dilationDims[0];
    var dilationWidth = dilationDims[1];
    // Dilation enlarges the footprint of the filter on the input.
    var effectiveFilterHeight = getEffectiveFilterSize(filterHeight, dilationHeight);
    var effectiveFilterWidth = getEffectiveFilterSize(filterWidth, dilationWidth);
    var padOut = getPadAndOutInfo(pad, inHeight, inWidth, strideHeight, strideWidth, effectiveFilterHeight, effectiveFilterWidth, roundingMode);
    var padInfo = padOut.padInfo;
    var outHeight = padOut.outHeight;
    var outWidth = padOut.outWidth;
    // Depthwise convolutions multiply channels; regular ones replace them.
    var outChannels = depthwise ? filterChannels * inChannels : filterChannels;
    var outShape;
    if (dataFormat === 'channelsFirst') {
        outShape = [batchSize, outChannels, outHeight, outWidth];
    }
    else {
        // channelsLast is the only remaining value (validated above).
        outShape = [batchSize, outHeight, outWidth, outChannels];
    }
    return {
        batchSize: batchSize,
        dataFormat: dataFormat,
        inHeight: inHeight,
        inWidth: inWidth,
        inChannels: inChannels,
        outHeight: outHeight,
        outWidth: outWidth,
        outChannels: outChannels,
        padInfo: padInfo,
        strideHeight: strideHeight,
        strideWidth: strideWidth,
        filterHeight: filterHeight,
        filterWidth: filterWidth,
        effectiveFilterHeight: effectiveFilterHeight,
        effectiveFilterWidth: effectiveFilterWidth,
        dilationHeight: dilationHeight,
        dilationWidth: dilationWidth,
        inShape: inShape,
        outShape: outShape,
        filterShape: filterShape
    };
}
/**
 * Computes the information for a forward pass of a 3D convolution/pooling
 * operation.
 */
function computeConv3DInfo(inShape, filterShape, strides, dilations, pad, depthwise, dataFormat, roundingMode) {
    if (depthwise === void 0) { depthwise = false; }
    if (dataFormat === void 0) { dataFormat = 'channelsLast'; }
    var _a = [-1, -1, -1, -1, -1], batchSize = _a[0], inDepth = _a[1], inHeight = _a[2], inWidth = _a[3], inChannels = _a[4];
    if (dataFormat === 'channelsLast') {
        batchSize = inShape[0], inDepth = inShape[1], inHeight = inShape[2], inWidth = inShape[3], inChannels = inShape[4];
    }
    else if (dataFormat === 'channelsFirst') {
        batchSize = inShape[0], inChannels = inShape[1], inDepth = inShape[2], inHeight = inShape[3], inWidth = inShape[4];
    }
    else {
        throw new Error("Unknown dataFormat " + dataFormat);
    }
    var filterDepth = filterShape[0], filterHeight = filterShape[1], filterWidth = filterShape[2], filterChannels = filterShape[4];
    var _b = parse3TupleParam(strides), strideDepth = _b[0], strideHeight = _b[1], strideWidth = _b[2];
    var _c = parse3TupleParam(dilations), dilationDepth = _c[0], dilationHeight = _c[1], dilationWidth = _c[2];
    // Dilation enlarges the footprint of the filter on the input.
    var effectiveFilterDepth = getEffectiveFilterSize(filterDepth, dilationDepth);
    var effectiveFilterHeight = getEffectiveFilterSize(filterHeight, dilationHeight);
    var effectiveFilterWidth = getEffectiveFilterSize(filterWidth, dilationWidth);
    var _d = get3DPadAndOutInfo(pad, inDepth, inHeight, inWidth, strideDepth, strideHeight, strideWidth, effectiveFilterDepth, effectiveFilterHeight, effectiveFilterWidth, roundingMode), padInfo = _d.padInfo, outDepth = _d.outDepth, outHeight = _d.outHeight, outWidth = _d.outWidth;
    // Depthwise convolutions multiply channels; regular ones replace them.
    var outChannels = depthwise ? filterChannels * inChannels : filterChannels;
    var outShape;
    if (dataFormat === 'channelsFirst') {
        outShape = [batchSize, outChannels, outDepth, outHeight, outWidth];
    }
    else if (dataFormat === 'channelsLast') {
        outShape = [batchSize, outDepth, outHeight, outWidth, outChannels];
    }
    return {
        batchSize: batchSize,
        dataFormat: dataFormat,
        inDepth: inDepth,
        inHeight: inHeight,
        inWidth: inWidth,
        inChannels: inChannels,
        outDepth: outDepth,
        outHeight: outHeight,
        outWidth: outWidth,
        outChannels: outChannels,
        padInfo: padInfo,
        strideDepth: strideDepth,
        strideHeight: strideHeight,
        strideWidth: strideWidth,
        filterDepth: filterDepth,
        filterHeight: filterHeight,
        filterWidth: filterWidth,
        effectiveFilterDepth: effectiveFilterDepth,
        effectiveFilterHeight: effectiveFilterHeight,
        effectiveFilterWidth: effectiveFilterWidth,
        dilationDepth: dilationDepth,
        dilationHeight: dilationHeight,
        dilationWidth: dilationWidth,
        inShape: inShape,
        outShape: outShape,
        filterShape: filterShape
    };
}
/**
 * Computes [rows, cols] of the output of a 2D window sliding over an input of
 * shape [rows, cols] with the given stride and (symmetric) zero padding.
 * Errors if the result is fractional and no rounding mode is given.
 */
function computeOutputShape2D(inShape, fieldSize, stride, zeroPad, roundingMode) {
    if (zeroPad == null) {
        zeroPad = computeDefaultPad(inShape, fieldSize, stride);
    }
    var inputRows = inShape[0];
    var inputCols = inShape[1];
    var outputRows = conditionalRound((inputRows - fieldSize + 2 * zeroPad) / stride + 1, roundingMode);
    assert(isInt(outputRows), function () { return "The output # of rows (" + outputRows + ") must be an integer. " +
        "Change the stride and/or zero pad parameters"; });
    var outputCols = conditionalRound((inputCols - fieldSize + 2 * zeroPad) / stride + 1, roundingMode);
    assert(isInt(outputCols), function () { return "The output # of columns (" + outputCols + ") must be an integer. " +
        "Change the stride and/or zero pad parameters"; });
    return [outputRows, outputCols];
}
/**
 * Same as computeOutputShape2D but for a 3D (depth/rows/cols) window; the
 * channel count is passed through unchanged as the 4th output entry.
 */
function computeOutputShape4D(inShape, fieldSize, outChannels, stride, zeroPad, roundingMode) {
    if (zeroPad == null) {
        zeroPad = computeDefaultPad(inShape, fieldSize, stride);
    }
    var inputDepth = inShape[0];
    var inputRows = inShape[1];
    var inputCols = inShape[2];
    var outputDepths = conditionalRound((inputDepth - fieldSize + 2 * zeroPad) / stride + 1, roundingMode);
    assert(isInt(outputDepths), function () { return "The output # of depths (" + outputDepths + ") must be an integer. " +
        "Change the stride and/or zero pad parameters"; });
    var outputRows = conditionalRound((inputRows - fieldSize + 2 * zeroPad) / stride + 1, roundingMode);
    assert(isInt(outputRows), function () { return "The output # of rows (" + outputRows + ") must be an integer. " +
        "Change the stride and/or zero pad parameters"; });
    var outputCols = conditionalRound((inputCols - fieldSize + 2 * zeroPad) / stride + 1, roundingMode);
    assert(isInt(outputCols), function () { return "The output # of columns (" + outputCols + ") must be an integer. " +
        "Change the stride and/or zero pad parameters"; });
    return [outputDepths, outputRows, outputCols, outChannels];
}
/**
 * The padding that keeps the output the same size as the input for stride 1
 * ("SAME"-style), derived from the first spatial dimension of `inputShape`.
 */
function computeDefaultPad(inputShape, fieldSize, stride, dilation) {
    if (dilation === void 0) { dilation = 1; }
    var effectiveFieldSize = getEffectiveFilterSize(fieldSize, dilation);
    return Math.floor((inputShape[0] * (stride - 1) - stride + effectiveFieldSize) / 2);
}
/**
 * Normalizes a number|[number, number]|[number, number, number] parameter to
 * a 3-tuple. A bare number is broadcast to all three entries; a 2-tuple gets
 * a trailing 1 (callers that only consume two entries are unaffected).
 */
function parseTupleParam(param) {
    if (typeof param === 'number') {
        return [param, param, param];
    }
    if (param.length === 2) {
        return [param[0], param[1], 1];
    }
    return param;
}
/** Broadcasts a bare number to a 3-tuple; passes 3-tuples through. */
function parse3TupleParam(param) {
    return typeof param === 'number' ? [param, param, param] : param;
}
/* See https://www.tensorflow.org/api_docs/python/tf/nn/atrous_conv2d
 * Atrous convolution is equivalent to standard convolution with upsampled
 * filters with effective_filter_height =
 * filter_height + (filter_height - 1) * (dilation - 1)
 * and effective_filter_width =
 * filter_width + (filter_width - 1) * (dilation - 1),
 * produced by inserting dilation - 1 zeros along consecutive elements across
 * the filters' spatial dimensions.
 * When there is a dilation, this converts a filter dimension to the
 * effective filter dimension, so it can be used in a standard convolution.
 */
function getEffectiveFilterSize(filterSize, dilation) {
    if (dilation <= 1) {
        return filterSize;
    }
    return filterSize + (filterSize - 1) * (dilation - 1);
}
/**
 * Resolves a 2D pad specification (number | 'same' | 'valid') into explicit
 * per-side padding plus the output spatial dimensions.
 */
function getPadAndOutInfo(pad, inHeight, inWidth, strideHeight, strideWidth, filterHeight, filterWidth, roundingMode) {
    var padInfo;
    var outHeight;
    var outWidth;
    if (typeof pad === 'number') {
        var padType = (pad === 0) ? 'VALID' : 'NUMBER';
        padInfo = { top: pad, bottom: pad, left: pad, right: pad, type: padType };
        var outShape = computeOutputShape2D([inHeight, inWidth], filterHeight, strideHeight, pad, roundingMode);
        outHeight = outShape[0];
        outWidth = outShape[1];
    }
    else if (pad === 'same') {
        outHeight = Math.ceil(inHeight / strideHeight);
        outWidth = Math.ceil(inWidth / strideWidth);
        // Total padding per axis, clamped at 0; any odd remainder goes to the
        // bottom/right side (TensorFlow's SAME convention).
        var padAlongHeight = Math.max(0, (outHeight - 1) * strideHeight + filterHeight - inHeight);
        var padAlongWidth = Math.max(0, (outWidth - 1) * strideWidth + filterWidth - inWidth);
        var top_1 = Math.floor(padAlongHeight / 2);
        var bottom = padAlongHeight - top_1;
        var left = Math.floor(padAlongWidth / 2);
        var right = padAlongWidth - left;
        padInfo = { top: top_1, bottom: bottom, left: left, right: right, type: 'SAME' };
    }
    else if (pad === 'valid') {
        padInfo = { top: 0, bottom: 0, left: 0, right: 0, type: 'VALID' };
        outHeight = Math.ceil((inHeight - filterHeight + 1) / strideHeight);
        outWidth = Math.ceil((inWidth - filterWidth + 1) / strideWidth);
    }
    else {
        throw new Error("Unknown padding parameter: " + pad);
    }
    return { padInfo: padInfo, outHeight: outHeight, outWidth: outWidth };
}
/**
 * Resolves a 3D pad specification (number | 'same' | 'valid') into explicit
 * per-side padding plus the output spatial dimensions.
 */
function get3DPadAndOutInfo(pad, inDepth, inHeight, inWidth, strideDepth, strideHeight, strideWidth, filterDepth, filterHeight, filterWidth, roundingMode) {
    var padInfo;
    var outDepth;
    var outHeight;
    var outWidth;
    if (typeof pad === 'number') {
        var padType = (pad === 0) ? 'VALID' : 'NUMBER';
        padInfo = {
            top: pad,
            bottom: pad,
            left: pad,
            right: pad,
            front: pad,
            back: pad,
            type: padType
        };
        var outShape = computeOutputShape4D([inDepth, inHeight, inWidth, 1], filterDepth, 1, strideDepth, pad, roundingMode);
        outDepth = outShape[0];
        outHeight = outShape[1];
        outWidth = outShape[2];
    }
    else if (pad === 'same') {
        outDepth = Math.ceil(inDepth / strideDepth);
        outHeight = Math.ceil(inHeight / strideHeight);
        outWidth = Math.ceil(inWidth / strideWidth);
        // BUGFIX: clamp the total padding at 0, matching getPadAndOutInfo's 2D
        // behavior and TensorFlow's SAME-padding definition. Without the clamp
        // a stride larger than the filter produced negative padding.
        var padAlongDepth = Math.max(0, (outDepth - 1) * strideDepth + filterDepth - inDepth);
        var padAlongHeight = Math.max(0, (outHeight - 1) * strideHeight + filterHeight - inHeight);
        var padAlongWidth = Math.max(0, (outWidth - 1) * strideWidth + filterWidth - inWidth);
        var front = Math.floor(padAlongDepth / 2);
        var back = padAlongDepth - front;
        var top_2 = Math.floor(padAlongHeight / 2);
        var bottom = padAlongHeight - top_2;
        var left = Math.floor(padAlongWidth / 2);
        var right = padAlongWidth - left;
        padInfo = { top: top_2, bottom: bottom, left: left, right: right, front: front, back: back, type: 'SAME' };
    }
    else if (pad === 'valid') {
        padInfo = {
            top: 0,
            bottom: 0,
            left: 0,
            right: 0,
            front: 0,
            back: 0,
            type: 'VALID'
        };
        outDepth = Math.ceil((inDepth - filterDepth + 1) / strideDepth);
        outHeight = Math.ceil((inHeight - filterHeight + 1) / strideHeight);
        outWidth = Math.ceil((inWidth - filterWidth + 1) / strideWidth);
    }
    else {
        throw new Error("Unknown padding parameter: " + pad);
    }
    return { padInfo: padInfo, outDepth: outDepth, outHeight: outHeight, outWidth: outWidth };
}
/**
 * Rounds a value depending on the rounding mode
 * @param value
 * @param roundingMode
 */
function conditionalRound(value, roundingMode) {
    // No mode (undefined/null/'') means: keep the exact value.
    if (!roundingMode) {
        return value;
    }
    switch (roundingMode) {
        case 'round':
            // used for Caffe Conv
            return Math.round(value);
        case 'ceil':
            // used for Caffe Pool
            return Math.ceil(value);
        case 'floor':
            return Math.floor(value);
        default:
            throw new Error("Unknown roundingMode " + roundingMode);
    }
}
/** True iff every entry of the (broadcast) 3-tuple equals 1. */
function tupleValuesAreOne(param) {
    var _a = parseTupleParam(param), dimA = _a[0], dimB = _a[1], dimC = _a[2];
    return dimA === 1 && dimB === 1 && dimC === 1;
}
/** True iff strides are all 1 or dilations are all 1 (conv precondition). */
function eitherStridesOrDilationsAreOne(strides, dilations) {
    return tupleValuesAreOne(strides) || tupleValuesAreOne(dilations);
}
/**
 * Convert Conv2D dataFormat from 'NHWC'|'NCHW' to
 * 'channelsLast'|'channelsFirst'
 * @param dataFormat in 'NHWC'|'NCHW' mode
 * @return dataFormat in 'channelsLast'|'channelsFirst' mode
 * @throws unknown dataFormat
 */
function convertConv2DDataFormat(dataFormat) {
    if (dataFormat === 'NHWC') {
        return 'channelsLast';
    }
    else if (dataFormat === 'NCHW') {
        return 'channelsFirst';
    }
    else {
        throw new Error("Unknown dataFormat " + dataFormat);
    }
}

/**
 * @license
 * Copyright 2020 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
+ * If `filterSize` is a single number,
+ * then `filterDepth == filterHeight == filterWidth`.
+ * @param strides The strides of the pooling:
+ * `[strideDepth, strideHeight, strideWidth]`. If
+ * `strides` is a single number, then
+ * `strideDepth == strideHeight == strideWidth`.
+ * @param dilations Deprecated, this field will be gone in v3.0.0. The dilation
+ * rates: `[dilationDepth, dilationHeight, dilationWidth]`
+ * in which we sample input values across the depth, height and width
+ * dimensions in dilated pooling.
+ * Defaults to `[1, 1, 1]`. If `dilations` is a single number,
+ * then `dilationDepth == dilationHeight == dilationWidth`.
+ * If it is greater than 1, then all values of `strides` must be 1.
+ * @param pad A string from: 'same', 'valid'. The type of padding algorithm
+ * used in the forward prop of the op.
+ * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. The
+ * rounding mode used when computing output dimensions if pad is a
+ * number. If none is provided, it will not round and error if the output
+ * is of fractional size.
/**
 * Computes the backprop of a 3d avg pool.
 *
 * @param dy The dy error, of rank 5 of shape
 *     [batchSize, depth, height, width, channels].
 * @param input The original input image, of rank 5 or rank 4 of shape
 *     [batchSize, depth, height, width, channels].
 * @param filterSize The filter size:
 *     `[filterDepth, filterHeight, filterWidth]`. If `filterSize` is a single
 *     number, then `filterDepth == filterHeight == filterWidth`.
 * @param strides The strides of the pooling:
 *     `[strideDepth, strideHeight, strideWidth]`. If `strides` is a single
 *     number, then `strideDepth == strideHeight == strideWidth`.
 * @param dilations Deprecated, this field will be gone in v3.0.0.
 *     Defaults to `[1, 1, 1]`. If `dilations` is a single number, then
 *     `dilationDepth == dilationHeight == dilationWidth`. If it is greater
 *     than 1, then all values of `strides` must be 1.
 * @param pad A string from: 'same', 'valid'. The type of padding algorithm
 *     used in the forward prop of the op.
 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. The
 *     rounding mode used when computing output dimensions if pad is a
 *     number. If none is provided, it will not round and error if the output
 *     is of fractional size.
 */
function avgPool3dBackprop_(dy, input, filterSize, strides, dilations, pad, dimRoundingMode) {
    if (dilations === void 0) { dilations = [1, 1, 1]; }
    var $dy = convertToTensor(dy, 'dy', 'avgPool3dBackprop');
    var $input = convertToTensor(input, 'input', 'avgPool3dBackprop');
    var dy5D = $dy;
    var input5D = $input;
    var reshapedTo5D = false;
    // Rank-4 inputs are treated as a batch of one image.
    if ($input.rank === 4) {
        reshapedTo5D = true;
        dy5D = reshape($dy, [1, $dy.shape[0], $dy.shape[1], $dy.shape[2], $dy.shape[3]]);
        input5D = reshape($input, [
            1, $input.shape[0], $input.shape[1], $input.shape[2], $input.shape[3]
        ]);
    }
    assert(dy5D.rank === 5, function () { return "Error in avgPool3dBackprop: dy must be rank 5 but got rank " +
        (dy5D.rank + "."); });
    assert(input5D.rank === 5, function () { return "Error in avgPool3dBackprop: input must be rank 5 but got rank " +
        (input5D.rank + "."); });
    assert(eitherStridesOrDilationsAreOne(strides, dilations), function () { return 'Error in avgPool3dBackprop: Either strides or dilations ' +
        ("must be 1. Got strides " + strides + " and dilations '" + dilations + "'"); });
    if (dimRoundingMode != null) {
        // BUGFIX: this message previously said "maxPool3dBackprop" — a
        // copy-paste from the max-pool variant; it now names this op.
        assert(isInt(pad), function () { return "Error in avgPool3dBackprop: pad must be an integer when " +
            ("using, dimRoundingMode " + dimRoundingMode + " but got pad " + pad + "."); });
    }
    var forward = function (backend) {
        var convInfo = computePool3DInfo(input5D.shape, filterSize, strides, dilations, pad, dimRoundingMode);
        return backend.avgPool3dBackprop(dy5D, input5D, convInfo);
    };
    var inputs = { dy: dy5D, input: input5D };
    var attrs = { filterSize: filterSize, strides: strides, dilations: dilations, pad: pad, dimRoundingMode: dimRoundingMode };
    var res = ENGINE.runKernelFunc(forward, inputs, null /* grad */, AvgPool3DBackprop, attrs);
    // Undo the rank-4 -> rank-5 promotion before returning to the caller.
    if (reshapedTo5D) {
        return reshape(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]);
    }
    return res;
}
var avgPool3dBackprop = op({ avgPool3dBackprop_: avgPool3dBackprop_ });

/**
 * @license
 * Copyright 2020 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * Gradient config for the AvgPool3D kernel: routes dy through
 * avgPool3dBackprop using the attrs recorded in the forward pass.
 */
var avgPool3DGradConfig = {
    kernelName: AvgPool3D,
    inputsToSave: ['x'],
    gradFunc: function (dy, saved, attrs) {
        var input = saved[0];
        var filterSize = attrs.filterSize;
        var strides = attrs.strides;
        var dilations = attrs.dilations;
        var pad = attrs.pad;
        var dimRoundingMode = attrs.dimRoundingMode;
        // dilations is deprecated and may be absent; default to no dilation.
        var $dilations = dilations == null ? [1, 1, 1] : dilations;
        return {
            x: function () {
                return avgPool3dBackprop(dy, input, filterSize, strides, $dilations, pad, dimRoundingMode);
            }
        };
    }
};

/**
 * @license
 * Copyright 2020 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * Computes the backprop of an 2D avg pool.
 *
 * @param dy The dy error, of rank 4 or rank 3 of shape
 *     [batchSize, height, width, channels]. If rank 3, batch of 1 is
 *     assumed.
 * @param input The input image, of rank 4 or rank 3 of shape
 *     [batchSize, height, width, channels]. If rank 3, batch of 1 is
 *     assumed.
 * @param filterSize The filter size: `[filterHeight, filterWidth]`. If
 *     `filterSize` is a single number, then `filterHeight == filterWidth`.
 * @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If
 *     `strides` is a single number, then `strideHeight == strideWidth`.
 * @param pad A string from: 'same', 'valid'. The type of padding algorithm
 *     used in the forward prop of the op.
 */
function avgPoolBackprop_(dy, input, filterSize, strides, pad) {
    var $dy = convertToTensor(dy, 'dy', 'avgPoolBackprop');
    var $input = convertToTensor(input, 'input', 'avgPoolBackprop');
    assert($input.rank === $dy.rank, function () { return "Rank of input (" + $input.rank + ") does not match rank of dy (" + $dy.rank + ")"; });
    var input4D = $input;
    var dy4D = $dy;
    var reshapedTo4D = false;
    // Rank-3 inputs are treated as a batch of one image.
    if ($input.rank === 3) {
        reshapedTo4D = true;
        input4D = reshape($input, [1, $input.shape[0], $input.shape[1], $input.shape[2]]);
        dy4D = reshape($dy, [1, $dy.shape[0], $dy.shape[1], $dy.shape[2]]);
    }
    assert(dy4D.rank === 4, function () { return "Error in avgPoolBackprop: dy must be rank 4 but got rank " + dy4D.rank + "."; });
    assert(input4D.rank === 4, function () { return "Error in avgPoolBackprop: input must be rank 4 but got rank " + input4D.rank + "."; });
    var forwardFunc = function (backend) {
        var convInfo = computePool2DInfo(input4D.shape, filterSize, strides, 1 /* dilations */, pad);
        return backend.avgPoolBackprop(dy4D, input4D, convInfo);
    };
    var inputs = { dy: dy4D, input: input4D };
    var attrs = { filterSize: filterSize, strides: strides, pad: pad };
    var res = ENGINE.runKernelFunc(forwardFunc, inputs, null, AvgPoolBackprop, attrs);
    // Undo the rank-3 -> rank-4 promotion before returning to the caller.
    if (reshapedTo4D) {
        return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
    }
    return res;
}
var avgPoolBackprop = op({ avgPoolBackprop_: avgPoolBackprop_ });

/**
 * @license
 * Copyright 2020 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * Gradient config for the AvgPool kernel: routes dy through avgPoolBackprop
 * using the attrs recorded in the forward pass.
 */
var avgPoolGradConfig = {
    kernelName: AvgPool,
    inputsToSave: ['x'],
    gradFunc: function (dy, saved, attrs) {
        var input = saved[0];
        var filterSize = attrs.filterSize;
        var strides = attrs.strides;
        var pad = attrs.pad;
        return {
            x: function () { return avgPoolBackprop(dy, input, filterSize, strides, pad); }
        };
    }
};
/**
 * Computes the dot product of two matrices, A * B. These must be matrices.
 *
 * ```js
 * const a = tf.tensor2d([1, 2], [1, 2]);
 * const b = tf.tensor2d([1, 2, 3, 4], [2, 2]);
 *
 * a.matMul(b).print();  // or tf.matMul(a, b)
 * ```
 * @param a First matrix in dot product operation.
 * @param b Second matrix in dot product operation.
 * @param transposeA If true, `a` is transposed before multiplication.
 * @param transposeB If true, `b` is transposed before multiplication.
 */
/** @doc {heading: 'Operations', subheading: 'Matrices'} */
function matMul_(a, b, transposeA, transposeB) {
    if (transposeA === void 0) { transposeA = false; }
    if (transposeB === void 0) { transposeB = false; }
    var $a = convertToTensor(a, 'a', 'matMul');
    var $b = convertToTensor(b, 'b', 'matMul');
    var matched = makeTypesMatch($a, $b);
    $a = matched[0];
    $b = matched[1];
    assert($a.rank >= 2 && $b.rank >= 2 && $a.rank === $b.rank, function () { return "Error in matMul: inputs must have the same rank of at least 2, " +
        ("got ranks " + $a.rank + " and " + $b.rank + "."); });
    // The inner (contracted) and outer (free) dimensions of each operand swap
    // when that operand is transposed.
    var innerShapeA = transposeA ? $a.shape[$a.rank - 2] : $a.shape[$a.rank - 1];
    var outerShapeA = transposeA ? $a.shape[$a.rank - 1] : $a.shape[$a.rank - 2];
    var innerShapeB = transposeB ? $b.shape[$b.rank - 1] : $b.shape[$b.rank - 2];
    var outerShapeB = transposeB ? $b.shape[$b.rank - 2] : $b.shape[$b.rank - 1];
    // All leading (batch) dimensions must agree between the operands.
    var outerDimsA = $a.shape.slice(0, -2);
    var outerDimsB = $b.shape.slice(0, -2);
    var batchDimA = sizeFromShape(outerDimsA);
    var batchDimB = sizeFromShape(outerDimsB);
    assert(arraysEqual(outerDimsA, outerDimsB), function () { return "Error in matMul: outer dimensions (" + outerDimsA + ") and (" +
        (outerDimsB + ") of Tensors with shapes " + $a.shape + " and ") +
        ($b.shape + " must match."); });
    assert(innerShapeA === innerShapeB, function () { return "Error in matMul: inner shapes (" + innerShapeA + ") and (" +
        (innerShapeB + ") of Tensors with shapes " + $a.shape + " and ") +
        ($b.shape + " and transposeA=" + transposeA) +
        (" and transposeB=" + transposeB + " must match."); });
    var outShape = $a.shape.slice(0, -2).concat([outerShapeA, outerShapeB]);
    // Collapse the batch dimensions into one so the backend always sees a
    // rank-3 batched matmul.
    var a3D;
    if (transposeA) {
        a3D = reshape($a, [batchDimA, innerShapeA, outerShapeA]);
    }
    else {
        a3D = reshape($a, [batchDimA, outerShapeA, innerShapeA]);
    }
    var b3D;
    if (transposeB) {
        b3D = reshape($b, [batchDimB, outerShapeB, innerShapeB]);
    }
    else {
        b3D = reshape($b, [batchDimB, innerShapeB, outerShapeB]);
    }
    var forwardFunc = function (backend, save) {
        // The rank-3 operands are saved for the gradient computation.
        save([a3D, b3D]);
        return backend.batchMatMul(a3D, b3D, transposeA, transposeB);
    };
    var inputs = { a: a3D, b: b3D };
    var attrs = { transposeA: transposeA, transposeB: transposeB };
    var res = ENGINE.runKernelFunc(forwardFunc, inputs, null /* grad */, BatchMatMul, attrs);
    // Restore the original batch dimensions on the result.
    return reshape(res, outShape);
}
var matMul = op({ matMul_: matMul_ });

/**
 * @license
 * Copyright 2020 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * Gradient config for BatchMatMul. For C = A x B (with optional transposes),
 * each branch below is the standard matmul gradient specialized to the
 * transpose flags recorded in the forward pass.
 */
var batchMatMulGradConfig = {
    kernelName: BatchMatMul,
    inputsToSave: ['a', 'b'],
    gradFunc: function (dy, saved, attrs) {
        var a = saved[0];
        var b = saved[1];
        var transposeA = attrs.transposeA;
        var transposeB = attrs.transposeB;
        if (!transposeA) {
            if (!transposeB) {
                // C = A x B: dA = dy x B^T, dB = A^T x dy.
                return {
                    a: function () { return matMul(dy, b, false, true); },
                    b: function () { return matMul(a, dy, true, false); }
                };
            }
            // C = A x B^T: dA = dy x B, dB = dy^T x A.
            return {
                a: function () { return matMul(dy, b, false, false); },
                b: function () { return matMul(dy, a, true, false); }
            };
        }
        if (!transposeB) {
            // C = A^T x B: dA = B x dy^T, dB = A x dy.
            return {
                a: function () { return matMul(b, dy, false, true); },
                b: function () { return matMul(a, dy, false, false); }
            };
        }
        // C = A^T x B^T: dA = B^T x dy^T, dB = dy^T x A^T.
        return {
            a: function () { return matMul(b, dy, true, true); },
            b: function () { return matMul(dy, a, true, true); }
        };
    }
};

/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
+ * ============================================================================= + */ + /** + * This operation divides "spatial" dimensions `[1, ..., M]` of the input into + * a grid of blocks of shape `blockShape`, and interleaves these blocks with + * the "batch" dimension (0) such that in the output, the spatial + * dimensions `[1, ..., M]` correspond to the position within the grid, + * and the batch dimension combines both the position within a spatial block + * and the original batch position. Prior to division into blocks, + * the spatial dimensions of the input are optionally zero padded + * according to `paddings`. See below for a precise description. + * + * ```js + * const x = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]); + * const blockShape = [2, 2]; + * const paddings = [[0, 0], [0, 0]]; + * + * x.spaceToBatchND(blockShape, paddings).print(); + * ``` + * + * @param x A `tf.Tensor`. N-D with `x.shape` = `[batch] + spatialShape + + * remainingShape`, where spatialShape has `M` dimensions. + * @param blockShape A 1-D array. Must have shape `[M]`, all values must + * be >= 1. + * @param paddings A 2-D array. Must have shape `[M, 2]`, all values must be >= + * 0. `paddings[i] = [padStart, padEnd]` specifies the amount to zero-pad + * from input dimension `i + 1`, which corresponds to spatial dimension `i`. It + * is required that + * `(inputShape[i + 1] + padStart + padEnd) % blockShape[i] === 0` + * + * This operation is equivalent to the following steps: + * + * 1. Zero-pad the start and end of dimensions `[1, ..., M]` of the input + * according to `paddings` to produce `padded` of shape paddedShape. + * + * 2. Reshape `padded` to `reshapedPadded` of shape: + * `[batch] + [paddedShape[1] / blockShape[0], blockShape[0], ..., + * paddedShape[M] / blockShape[M-1], blockShape[M-1]] + remainingShape` + * + * 3. 
Permute dimensions of `reshapedPadded` to produce `permutedReshapedPadded` + * of shape: `blockShape + [batch] + [paddedShape[1] / blockShape[0], ..., + * paddedShape[M] / blockShape[M-1]] + remainingShape` + * + * 4. Reshape `permutedReshapedPadded` to flatten `blockShape` into the + * batch dimension, producing an output tensor of shape: + * `[batch * prod(blockShape)] + [paddedShape[1] / blockShape[0], ..., + * paddedShape[M] / blockShape[M-1]] + remainingShape` + */ + /** @doc {heading: 'Tensors', subheading: 'Transformations'} */ + function spaceToBatchND_(x, blockShape, paddings) { + var $x = convertToTensor(x, 'x', 'spaceToBatchND'); + assert($x.rank >= 1 + blockShape.length, function () { return "input rank " + $x.rank + " should be > than [blockShape] " + blockShape.length; }); + assert(paddings.length === blockShape.length, function () { return "paddings.shape[0] " + paddings.length + " must be equal to [blockShape] " + blockShape.length; }); + assert($x.shape.reduce(function (a, b, i) { + if (i > 0 && i <= blockShape.length) { + return a && + ((b + paddings[i - 1][0] + paddings[i - 1][1]) % + blockShape[i - 1] === + 0); + } + return a; + }, true), function () { return "input spatial dimensions " + $x.shape.slice(1) + " with paddings " + paddings.toString() + " must be divisible by blockShapes " + blockShape.toString(); }); + var forward = function (backend) { + return backend.spaceToBatchND($x, blockShape, paddings); + }; + var inputs = { x: $x }; + var attrs = { blockShape: blockShape, paddings: paddings }; + return ENGINE.runKernelFunc(forward, inputs, null /* gradient */, SpaceToBatchND, attrs); + } + var spaceToBatchND = op({ spaceToBatchND_: spaceToBatchND_ }); + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var batchToSpaceNDGradConfig = { + kernelName: BatchToSpaceND, + gradFunc: function (dy, saved, attrs) { + var _a = attrs, blockShape = _a.blockShape, crops = _a.crops; + return { x: function () { return spaceToBatchND(dy, blockShape, crops); } }; + } + }; + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Provided `f(x)`, returns another function `g(x, dy?)`, which gives the + * gradient of `f(x)` with respect to `x`. + * + * If `dy` is provided, the gradient of `f(x).mul(dy).sum()` with respect to + * `x` is computed instead. `f(x)` must take a single tensor `x` and return a + * single tensor `y`. If `f()` takes multiple inputs, use `tf.grads` instead. 
+ * + * ```js + * // f(x) = x ^ 2 + * const f = x => x.square(); + * // f'(x) = 2x + * const g = tf.grad(f); + * + * const x = tf.tensor1d([2, 3]); + * g(x).print(); + * ``` + * + * ```js + * // f(x) = x ^ 3 + * const f = x => x.pow(tf.scalar(3, 'int32')); + * // f'(x) = 3x ^ 2 + * const g = tf.grad(f); + * // f''(x) = 6x + * const gg = tf.grad(g); + * + * const x = tf.tensor1d([2, 3]); + * gg(x).print(); + * ``` + * + * @param f The function f(x), to compute gradient for. + */ + /** @doc {heading: 'Training', subheading: 'Gradients'} */ + function grad(f) { + assert(isFunction(f), function () { return 'The f passed in grad(f) must be a function'; }); + return function (x, dy) { + // x can be of any dtype, thus null as the last argument. + var $x = convertToTensor(x, 'x', 'tf.grad', null); + var $dy = (dy != null) ? convertToTensor(dy, 'dy', 'tf.grad') : null; + return ENGINE.tidy(function () { + var _a = ENGINE.gradients(function () { return f($x); }, [$x], $dy), value = _a.value, grads = _a.grads; + if ($dy != null) { + assertShapesMatch(value.shape, $dy.shape, 'The shape of dy passed in grad(f)(x, dy) must match the shape ' + + 'returned by f(x)'); + } + checkGrads(grads); + return grads[0]; + }); + }; + } + /** + * Provided `f(x1, x2,...)`, returns another function `g([x1, x2,...], dy?)`, + * which gives an array of gradients of `f()` with respect to each input + * [`x1`,`x2`,...]. + * + * If `dy` is passed when calling `g()`, the gradient of + * `f(x1,...).mul(dy).sum()` with respect to each input is computed instead. + * The provided `f` must take one or more tensors and return a single tensor + * `y`. If `f()` takes a single input, we recommend using `tf.grad` instead. 
+ * + * ```js + * // f(a, b) = a * b + * const f = (a, b) => a.mul(b); + * // df / da = b, df / db = a + * const g = tf.grads(f); + * + * const a = tf.tensor1d([2, 3]); + * const b = tf.tensor1d([-2, -3]); + * const [da, db] = g([a, b]); + * console.log('da'); + * da.print(); + * console.log('db'); + * db.print(); + * ``` + * + * @param f The function `f(x1, x2,...)` to compute gradients for. + */ + /** @doc {heading: 'Training', subheading: 'Gradients'} */ + function grads(f) { + assert(isFunction(f), function () { return 'The f passed in grads(f) must be a function'; }); + return function (args, dy) { + assert(Array.isArray(args), function () { return 'The args passed in grads(f)(args) must be an array ' + + 'of `Tensor`s or `TensorLike`s'; }); + // args can be of any dtype, thus null as the last argument. + var $args = convertToTensorArray(args, 'args', 'tf.grads', null); + var $dy = (dy != null) ? convertToTensor(dy, 'dy', 'tf.grads') : null; + return ENGINE.tidy(function () { + var _a = ENGINE.gradients(function () { return f.apply(void 0, $args); }, $args, $dy), value = _a.value, grads = _a.grads; + if ($dy != null) { + assertShapesMatch(value.shape, $dy.shape, 'The shape of dy passed in grads(f)([x1,...], dy) must ' + + 'match the shape returned by f([x1,...])'); + } + checkGrads(grads); + return grads; + }); + }; + } + /** + * Like `tf.grad`, but also returns the value of `f()`. Useful when `f()` + * returns a metric you want to show. + * + * The result is a rich object with the following properties: + * - grad: The gradient of `f(x)` w.r.t `x` (result of `tf.grad`). + * - value: The value returned by `f(x)`. 
+ * + * ```js + * // f(x) = x ^ 2 + * const f = x => x.square(); + * // f'(x) = 2x + * const g = tf.valueAndGrad(f); + * + * const x = tf.tensor1d([2, 3]); + * const {value, grad} = g(x); + * + * console.log('value'); + * value.print(); + * console.log('grad'); + * grad.print(); + * ``` + */ + /** @doc {heading: 'Training', subheading: 'Gradients'} */ + function valueAndGrad(f) { + assert(isFunction(f), function () { return 'The f passed in valueAndGrad(f) must be a function'; }); + return function (x, dy) { + assert(x instanceof Tensor, function () { return 'The x passed in valueAndGrad(f)(x) must be a tensor'; }); + assert(dy == null || dy instanceof Tensor, function () { return 'The dy passed in valueAndGrad(f)(x, dy) must be a tensor'; }); + var _a = ENGINE.gradients(function () { return f(x); }, [x], dy), grads = _a.grads, value = _a.value; + checkGrads(grads); + return { grad: grads[0], value: value }; + }; + } + /** + * Like `tf.grads`, but returns also the value of `f()`. Useful when `f()` + * returns a metric you want to show. + * + * The result is a rich object with the following properties: + * - grads: The gradients of `f()` w.r.t each input (result of `tf.grads`). + * - value: The value returned by `f(x)`. 
+ * + * ```js + * // f(a, b) = a * b + * const f = (a, b) => a.mul(b); + * // df/da = b, df/db = a + * const g = tf.valueAndGrads(f); + * + * const a = tf.tensor1d([2, 3]); + * const b = tf.tensor1d([-2, -3]); + * const {value, grads} = g([a, b]); + * + * const [da, db] = grads; + * + * console.log('value'); + * value.print(); + * + * console.log('da'); + * da.print(); + * console.log('db'); + * db.print(); + * ``` + */ + /** @doc {heading: 'Training', subheading: 'Gradients'} */ + function valueAndGrads(f) { + assert(isFunction(f), function () { return 'The f passed in valueAndGrads(f) must be a function'; }); + return function (args, dy) { + assert(Array.isArray(args) && args.every(function (arg) { return arg instanceof Tensor; }), function () { return 'The args passed in valueAndGrads(f)(args) must be array of ' + + 'tensors'; }); + assert(dy == null || dy instanceof Tensor, function () { return 'The dy passed in valueAndGrads(f)(args, dy) must be a tensor'; }); + var res = ENGINE.gradients(function () { return f.apply(void 0, args); }, args, dy); + if (dy != null) { + assertShapesMatch(res.value.shape, dy.shape, 'The shape of dy passed in valueAndGrads(f)([x1,...], dy) must ' + + 'match the shape returned by f([x1,...])'); + } + checkGrads(res.grads); + return res; + }; + } + /** + * Computes and returns the gradient of f(x) with respect to the list of + * trainable variables provided by `varList`. If no list is provided, it + * defaults to all trainable variables. + * + * ```js + * const a = tf.variable(tf.tensor1d([3, 4])); + * const b = tf.variable(tf.tensor1d([5, 6])); + * const x = tf.tensor1d([1, 2]); + * + * // f(a, b) = a * x ^ 2 + b * x + * const f = () => a.mul(x.square()).add(b.mul(x)).sum(); + * // df/da = x ^ 2, df/db = x + * const {value, grads} = tf.variableGrads(f); + * + * Object.keys(grads).forEach(varName => grads[varName].print()); + * ``` + * + * @param f The function to execute. f() should return a scalar. 
+ * @param varList The list of variables to compute the gradients with respect + * to. Defaults to all trainable variables. + * @returns An object with the following keys and values: + * - `value`: The value of the function `f`. + * - `grads`: A map from the names of the variables to the gradients. + * If the `varList` argument is provided explicitly and contains a subset of + * non-trainable variables, this map in the return value will contain keys + * that map the names of the non-trainable variables to `null`. + */ + /** @doc {heading: 'Training', subheading: 'Gradients'} */ + function variableGrads(f, varList) { + assert(isFunction(f), function () { return 'The f passed in variableGrads(f) must be a function'; }); + assert(varList == null || + Array.isArray(varList) && varList.every(function (v) { return v instanceof Variable; }), function () { + return 'The varList passed in variableGrads(f, varList) must be an array ' + + 'of variables'; + }); + var specifiedVarList = varList != null; + if (!specifiedVarList) { + // Get all of the trainable variables. + varList = []; + for (var varName in ENGINE.registeredVariables) { + varList.push(ENGINE.registeredVariables[varName]); + } + } + var specifiedNonTrainable = specifiedVarList ? varList.filter(function (variable) { return !variable.trainable; }) : null; + // Prune non-trainable variables. 
+ var originalVarCount = varList.length; + varList = varList.filter(function (variable) { return variable.trainable; }); + assert(varList.length > 0, function () { return "variableGrads() expects at least one of the input variables to " + + ("be trainable, but none of the " + originalVarCount + " variables is ") + + "trainable."; }); + var allowNoGradients = true; + var _a = ENGINE.gradients(f, varList, null, allowNoGradients), value = _a.value, grads = _a.grads; + assert(grads.some(function (g) { return g != null; }), function () { return 'Cannot find a connection between any variable and the result of ' + + 'the loss function y=f(x). Please make sure the operations that ' + + 'use variables are inside the function f passed to minimize().'; }); + assert(value.rank === 0, function () { return "The f passed in variableGrads(f) must return a scalar, but it " + + ("returned a rank-" + value.rank + " tensor"); }); + var namedGrads = {}; + varList.forEach(function (v, i) { + if (grads[i] != null) { + namedGrads[v.name] = grads[i]; + } + }); + if (specifiedNonTrainable != null) { + // If varList is explicitly provided and contains non-trainable values, + // add them to the returned gradients with `null` values. + specifiedNonTrainable.forEach(function (v) { return namedGrads[v.name] = null; }); + } + return { value: value, grads: namedGrads }; + } + /** + * Overrides the gradient computation of a function `f`. + * + * Takes a function + * `f(...inputs, save) => {value: Tensor, gradFunc: (dy, saved) => Tensor[]}` + * and returns another function `g(...inputs)` which takes the same inputs as + * `f`. When called, `g` returns `f().value`. In backward mode, custom gradients + * with respect to each input of `f` are computed using `f().gradFunc`. + * + * The `save` function passsed to `f` should be used for saving tensors needed + * in the gradient. And the `saved` passed to the `gradFunc` is a + * `NamedTensorMap`, which contains those saved tensor. 
+ * + * ```js + * const customOp = tf.customGrad((x, save) => { + * // Save x to make sure it's available later for the gradient. + * save([x]); + * // Override gradient of our custom x ^ 2 op to be dy * abs(x); + * return { + * value: x.square(), + * // Note `saved.x` which points to the `x` we saved earlier. + * gradFunc: (dy, saved) => [dy.mul(saved[0].abs())] + * }; + * }); + * + * const x = tf.tensor1d([-1, -2, 3]); + * const dx = tf.grad(x => customOp(x)); + * + * console.log(`f(x):`); + * customOp(x).print(); + * console.log(`f'(x):`); + * dx(x).print(); + * ``` + * + * @param f The function to evaluate in forward mode, which should return + * `{value: Tensor, gradFunc: (dy, saved) => Tensor[]}`, where `gradFunc` + * returns the custom gradients of `f` with respect to its inputs. + */ + /** @doc {heading: 'Training', subheading: 'Gradients'} */ + function customGrad(f) { + return ENGINE.customGrad(f); + } + function checkGrads(grads) { + var numNullGradients = grads.filter(function (g) { return g == null; }).length; + if (numNullGradients > 0) { + throw new Error("Cannot compute gradient of y=f(x) with respect to x. Make sure that\n the f you passed encloses all operations that lead from x to y."); + } + } + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + /** + * Returns true if the axis specifies the inner most dimensions of the + * array. + */ + function axesAreInnerMostDims(axes, rank) { + for (var i = 0; i < axes.length; ++i) { + if (axes[axes.length - i - 1] !== rank - 1 - i) { + return false; + } + } + return true; + } + function combineLocations(outputLoc, reduceLoc, axes) { + var rank = outputLoc.length + reduceLoc.length; + var loc = []; + var outIdx = 0; + var reduceIdx = 0; + for (var dim = 0; dim < rank; dim++) { + if (axes.indexOf(dim) === -1) { + loc.push(outputLoc[outIdx++]); + } + else { + loc.push(reduceLoc[reduceIdx++]); + } + } + return loc; + } + function computeOutAndReduceShapes(aShape, axes) { + var outShape = []; + var rank = aShape.length; + for (var dim = 0; dim < rank; dim++) { + if (axes.indexOf(dim) === -1) { + outShape.push(aShape[dim]); + } + } + var reduceShape = axes.map(function (dim) { return aShape[dim]; }); + return [outShape, reduceShape]; + } + function expandShapeToKeepDim(shape, axes) { + var reduceSubShape = axes.map(function (x) { return 1; }); + return combineLocations(shape, reduceSubShape, axes); + } + function assertAxesAreInnerMostDims(msg, axes, rank) { + assert(axesAreInnerMostDims(axes, rank), function () { return msg + " supports only inner-most axes for now. " + + ("Got axes " + axes + " and rank-" + rank + " input."); }); + } + /** + * Returns the axes permutation to be used with `tf.transpose`, if such + * permutation is necessary. Otherwise it returns null. This method is used by + * operations that operate only on inner-most axes. 
+ */ + function getAxesPermutation(axes, rank) { + if (axesAreInnerMostDims(axes, rank)) { + return null; + } + var result = []; + for (var i = 0; i < rank; ++i) { + if (axes.indexOf(i) === -1) { + result.push(i); + } + } + axes.forEach(function (axis) { return result.push(axis); }); + return result; + } + /** Returns the axes permutation that undoes the original permutation. */ + function getUndoAxesPermutation(axes) { + return axes.map(function (axis, i) { return [i, axis]; }) + .sort(function (a, b) { return a[1] - b[1]; }) + .map(function (x) { return x[0]; }); + } + function getInnerMostAxes(numAxes, rank) { + var res = []; + for (var i = rank - numAxes; i < rank; ++i) { + res.push(i); + } + return res; + } + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Gradient helper function for the min and max operations. + */ + function gradForMinAndMax(dy, y, xOrig, origAxes, permutedAxes) { + if (y.rank < xOrig.rank) { + y = y.reshape(expandShapeToKeepDim(y.shape, origAxes)); + } + if (dy.rank < xOrig.rank) { + dy = dy.reshape(expandShapeToKeepDim(dy.shape, origAxes)); + } + return { + x: function () { + var dx = dy.mul(xOrig.equal(y).cast(dy.dtype)); + return permutedAxes == null ? dx : dx.transpose(permutedAxes); + } + }; + } + + /** + * @license + * Copyright 2018 Google Inc. 
All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Computes the log(sum(exp(elements across the reduction dimensions)). + * + * Reduces the input along the dimensions given in `axis`. Unless `keepDims` + * is true, the rank of the array is reduced by 1 for each entry in `axis`. + * If `keepDims` is true, the reduced dimensions are retained with length 1. + * If `axis` has no entries, all dimensions are reduced, and an array with a + * single element is returned. + * + * ```js + * const x = tf.tensor1d([1, 2, 3]); + * + * x.logSumExp().print(); // or tf.logSumExp(x) + * ``` + * + * ```js + * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]); + * + * const axis = 1; + * x.logSumExp(axis).print(); // or tf.logSumExp(a, axis) + * ``` + * @param x The input tensor. + * @param axis The dimension(s) to reduce. If null (the default), + * reduces all dimensions. + * @param keepDims If true, retains reduced dimensions with length + * of 1. Defaults to false. 
+ */ + /** @doc {heading: 'Operations', subheading: 'Reduction'} */ + function logSumExp_(x, axis, keepDims) { + if (axis === void 0) { axis = null; } + if (keepDims === void 0) { keepDims = false; } + var $x = convertToTensor(x, 'x', 'logSumExp'); + var axes = parseAxisParam(axis, $x.shape); + var xMax = $x.max(axes, true /* keepDims */); + var a = $x.sub(xMax); + var b = a.exp(); + var c = b.sum(axes); + var d = c.log(); + var res = xMax.reshape(d.shape).add(d); + if (keepDims) { + var newShape = expandShapeToKeepDim(res.shape, axes); + return res.reshape(newShape); + } + return res; + } + /** + * Computes the sum of elements across dimensions of a `tf.Tensor`. + * + * Reduces the input along the dimensions given in `axes`. Unless `keepDims` + * is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in + * `axes`. If `keepDims` is true, the reduced dimensions are retained with + * length 1. If axes has no entries, all dimensions are reduced, and a + * `tf.Tensor` with a single element is returned. + * + * ```js + * const x = tf.tensor1d([1, 2, 3]); + * + * x.sum().print(); // or tf.sum(x) + * ``` + * + * ```js + * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]); + * + * const axis = 1; + * x.sum(axis).print(); // or tf.sum(x, axis) + * ``` + * + * @param x The input tensor to compute the sum over. If the dtype is `bool` + * it will be converted to `int32` and the output dtype will be `int32`. + * @param axis The dimension(s) to reduce. By default it reduces + * all dimensions. + * @param keepDims If true, retains reduced dimensions with size 1. 
+ */ + /** @doc {heading: 'Operations', subheading: 'Reduction'} */ + function sum_(x, axis, keepDims) { + if (axis === void 0) { axis = null; } + if (keepDims === void 0) { keepDims = false; } + var $x = convertToTensor(x, 'x', 'sum'); + if ($x.dtype === 'bool') { + $x = $x.toInt(); + } + var axes = parseAxisParam(axis, $x.shape); + // Use a custom gradient to bypass 2 gradient backprops since sum is used + // extremely often. + var customOp = customGrad(function (x) { + var permutation = getAxesPermutation(axes, x.rank); + var reductionAxes = axes; + var permutedX = x; + if (permutation != null) { + permutedX = x.transpose(permutation); + reductionAxes = getInnerMostAxes(reductionAxes.length, x.rank); + } + var gradFunc = function (dy) { + var expandedDyShape = x.shape.slice(); + axes.forEach(function (axis) { + expandedDyShape[axis] = 1; + }); + var expandedDy = dy.reshape(expandedDyShape); + var derX = expandedDy.mul(ones$1(x.shape, 'float32')); + return derX; + }; + var gradInputs = function (dy) { + return { x: function () { return gradFunc(dy); } }; + }; + var attrs = { axes: reductionAxes }; + var value = ENGINE.runKernelFunc(function (backend) { return backend.sum(permutedX, reductionAxes); }, { x: permutedX }, gradInputs, 'Sum', attrs); + if (keepDims) { + var newShape = expandShapeToKeepDim(value.shape, axes); + value = value.reshape(newShape); + } + return { value: value, gradFunc: gradFunc }; + }); + return customOp($x); + } + /** + * Computes the product of elements across dimensions of a `tf.Tensor`. + * + * Reduces the input along the dimensions given in `axes`. Unless `keepDims` + * is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in + * `axes`. If `keepDims` is true, the reduced dimensions are retained with + * length 1. If `axes` has no entries, all dimensions are reduced, and a + * `tf.Tensor` with a single element is returned. 
+ * + * ```js + * const x = tf.tensor1d([1, 2, 3]); + * + * x.prod().print(); // or tf.prod(x) + * ``` + * + * ```js + * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]); + * + * const axis = 1; + * x.prod(axis).print(); // or tf.prod(x, axis) + * ``` + * + * @param x The input tensor to compute the product over. If the dtype is `bool` + * it will be converted to `int32` and the output dtype will be `int32`. + * @param axis The dimension(s) to reduce. By default it reduces + * all dimensions. + * @param keepDims If true, retains reduced dimensions with size 1. + */ + /** @doc {heading: 'Operations', subheading: 'Reduction'} */ + function prod_(x, axis, keepDims) { + if (axis === void 0) { axis = null; } + if (keepDims === void 0) { keepDims = false; } + var $x = convertToTensor(x, 'x', 'prod'); + if ($x.dtype === 'bool') { + $x = $x.toInt(); + } + var axes = parseAxisParam(axis, $x.shape); + var permutation = getAxesPermutation(axes, $x.rank); + var reductionAxes = axes; + var permutedX = $x; + if (permutation != null) { + permutedX = $x.transpose(permutation); + reductionAxes = getInnerMostAxes(reductionAxes.length, $x.rank); + } + var value = ENGINE.runKernelFunc(function (backend) { return backend.prod(permutedX, reductionAxes); }, { permutedX: permutedX }); + if (keepDims) { + var newShape = expandShapeToKeepDim(value.shape, axes); + value = value.reshape(newShape); + } + return value; + } + /** + * Computes the mean of elements across dimensions of a `tf.Tensor`. + * + * Reduces `x` along the dimensions given in `axis`. Unless `keepDims` is + * true, the rank of the `tf.Tensor` is reduced by 1 for each entry in `axis`. + * If `keepDims` is true, the reduced dimensions are retained with length 1. + * If `axis` has no entries, all dimensions are reduced, and a `tf.Tensor` with + * a single element is returned. 
+ * + * ```js + * const x = tf.tensor1d([1, 2, 3]); + * + * x.mean().print(); // or tf.mean(a) + * ``` + * + * ```js + * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]); + * + * const axis = 1; + * x.mean(axis).print(); // or tf.mean(x, axis) + * ``` + * + * @param x The input tensor. + * @param axis The dimension(s) to reduce. By default it reduces + * all dimensions. + * @param keepDims If true, retains reduced dimensions with size 1. + */ + /** @doc {heading: 'Operations', subheading: 'Reduction'} */ + function mean_(x, axis, keepDims) { + if (axis === void 0) { axis = null; } + if (keepDims === void 0) { keepDims = false; } + var $x = convertToTensor(x, 'x', 'mean'); + var axes = parseAxisParam(axis, $x.shape); + var shapes = computeOutAndReduceShapes($x.shape, axes); + var reduceShape = shapes[1]; + var reduceSize = sizeFromShape(reduceShape); + // Use a custom gradient to bypass 2 gradient backprops since mean is used + // extremely often. + var customOp = customGrad(function (x) { + var reduceSizeScalar = scalar(reduceSize); + // Cast if needed. + var xReduce = reduceSizeScalar.dtype === x.dtype ? x : x.cast(reduceSizeScalar.dtype); + var res = xReduce.div(reduceSizeScalar); + var value = res.sum(axis, keepDims); + var gradFunc = function (dy) { + var expandedDyShape = x.shape.slice(); + axes.forEach(function (axis) { + expandedDyShape[axis] = 1; + }); + var expandedDy = dy.reshape(expandedDyShape); + var derX = expandedDy.mul(ones$1(x.shape, 'float32')).div(reduceSize); + return derX; + }; + return { value: value, gradFunc: gradFunc }; + }); + return customOp($x); + } + /** + * Computes the minimum value from the input. + * + * Reduces the input along the dimensions given in `axes`. Unless `keepDims` + * is true, the rank of the array is reduced by 1 for each entry in `axes`. + * If `keepDims` is true, the reduced dimensions are retained with length 1. + * If `axes` has no entries, all dimensions are reduced, and an array with a + * single element is returned. 
+ * + * ```js + * const x = tf.tensor1d([1, 2, 3]); + * + * x.min().print(); // or tf.min(x) + * ``` + * + * ```js + * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]); + * + * const axis = 1; + * x.min(axis).print(); // or tf.min(x, axis) + * ``` + * + * @param x The input Tensor. + * @param axis The dimension(s) to reduce. By default it reduces + * all dimensions. + * @param keepDims If true, retains reduced dimensions with size 1. + */ + /** @doc {heading: 'Operations', subheading: 'Reduction'} */ + function min_(x, axis, keepDims) { + if (axis === void 0) { axis = null; } + if (keepDims === void 0) { keepDims = false; } + var $x = convertToTensor(x, 'x', 'min'); + var xOrig = $x; + var origAxes = parseAxisParam(axis, $x.shape); + var axes = origAxes; + var permutedAxes = getAxesPermutation(axes, $x.rank); + if (permutedAxes != null) { + $x = $x.transpose(permutedAxes); + axes = getInnerMostAxes(axes.length, $x.rank); + } + var grad = function (dy, saved) { + return gradForMinAndMax(dy, saved[1], saved[0], origAxes, permutedAxes); + }; + var inputsToSave = [$x]; + var outputsToSave = [true]; + var res = ENGINE.runKernelFunc(function (backend, save) { + var y = backend.min($x, axes); + save([xOrig, y]); + return y; + }, { x: $x }, grad, 'Min', { axes: axes }, inputsToSave, outputsToSave); + if (keepDims) { + var newShape = expandShapeToKeepDim(res.shape, origAxes); + res = res.reshape(newShape); + } + return res; + } + /** + * Returns the indices of the minimum values along an `axis`. + * + * The result has the same shape as `input` with the dimension along `axis` + * removed. + * + * ```js + * const x = tf.tensor1d([1, 2, 3]); + * + * x.argMin().print(); // or tf.argMin(x) + * ``` + * + * ```js + * const x = tf.tensor2d([1, 2, 4, 3], [2, 2]); + * + * const axis = 1; + * x.argMin(axis).print(); // or tf.argMin(x, axis) + * ``` + * + * @param x The input tensor. + * @param axis The dimension to reduce. Defaults to 0 (outer-most dimension). 
+ * + */ + /** @doc {heading: 'Operations', subheading: 'Reduction'} */ + function argMin_(x, axis) { + if (axis === void 0) { axis = 0; } + var $x = convertToTensor(x, 'x', 'argMin'); + if (axis == null) { + axis = 0; + } + var axes = parseAxisParam(axis, $x.shape); + var permutedAxes = getAxesPermutation(axes, $x.rank); + if (permutedAxes != null) { + $x = $x.transpose(permutedAxes); + axes = getInnerMostAxes(axes.length, $x.rank); + } + var grad = function (dy, saved) { + var $x = saved[0]; + return { $x: function () { return zerosLike($x); } }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.argMin($x, axes[0]); + save([$x]); + return res; + }, { $x: $x }, grad); + } + /** + * Returns the indices of the maximum values along an `axis`. + * + * The result has the same shape as `input` with the dimension along `axis` + * removed. + * + * ```js + * const x = tf.tensor1d([1, 2, 3]); + * + * x.argMax().print(); // or tf.argMax(x) + * ``` + * + * ```js + * const x = tf.tensor2d([1, 2, 4, 3], [2, 2]); + * + * const axis = 1; + * x.argMax(axis).print(); // or tf.argMax(x, axis) + * ``` + * + * @param x The input tensor. + * @param axis The dimension to reduce. Defaults to 0 (outer-most dimension). 
+ */ + /** @doc {heading: 'Operations', subheading: 'Reduction'} */ + function argMax_(x, axis) { + if (axis === void 0) { axis = 0; } + var $x = convertToTensor(x, 'x', 'argMax'); + if (axis == null) { + axis = 0; + } + var axes = parseAxisParam(axis, $x.shape); + var permutedAxes = getAxesPermutation(axes, $x.rank); + if (permutedAxes != null) { + $x = $x.transpose(permutedAxes); + axes = getInnerMostAxes(axes.length, $x.rank); + } + var grad = function (dy, saved) { + var $x = saved[0]; + return { x: function () { return zerosLike($x); } }; + }; + var attrs = { axis: axes[0] }; + var inputsToSave = [$x]; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.argMax($x, axes[0]); + save([$x]); + return res; + }, { x: $x }, grad, 'ArgMax', attrs, inputsToSave); + } + /** + * Computes the logical and of elements across dimensions of a `tf.Tensor`. + * + * Reduces the input along the dimensions given in `axes`. Unless `keepDims` + * is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in + * `axes`. If `keepDims` is true, the reduced dimensions are retained with + * length 1. If `axes` has no entries, all dimensions are reduced, and an + * `tf.Tensor` with a single element is returned. + * + * ```js + * const x = tf.tensor1d([1, 1, 1], 'bool'); + * + * x.all().print(); // or tf.all(x) + * ``` + * + * ```js + * const x = tf.tensor2d([1, 1, 0, 0], [2, 2], 'bool'); + * + * const axis = 1; + * x.all(axis).print(); // or tf.all(x, axis) + * ``` + * + * @param x The input tensor. Must be of dtype bool. + * @param axis The dimension(s) to reduce. By default it reduces + * all dimensions. + * @param keepDims If true, retains reduced dimensions with size 1. 
+ */ + /** @doc {heading: 'Operations', subheading: 'Reduction'} */ + function all_(x, axis, keepDims) { + if (axis === void 0) { axis = null; } + if (keepDims === void 0) { keepDims = false; } + var $x = convertToTensor(x, 'x', 'all', 'bool'); + var origAxes = parseAxisParam(axis, $x.shape); + var axes = origAxes; + var permutedAxes = getAxesPermutation(axes, $x.rank); + if (permutedAxes != null) { + $x = $x.transpose(permutedAxes); + axes = getInnerMostAxes(axes.length, $x.rank); + } + var res = ENGINE.runKernelFunc(function (backend) { return backend.all($x, axes); }, { $x: $x }); + if (keepDims) { + var newShape = expandShapeToKeepDim(res.shape, origAxes); + return res.reshape(newShape); + } + return res; + } + /** + * Computes the logical or of elements across dimensions of a `tf.Tensor`. + * + * Reduces the input along the dimensions given in `axes`. Unless `keepDims` + * is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in + * `axes`. If `keepDims` is true, the reduced dimensions are retained with + * length 1. If `axes` has no entries, all dimensions are reduced, and an + * `tf.Tensor` with a single element is returned. + * + * ```js + * const x = tf.tensor1d([1, 1, 1], 'bool'); + * + * x.any().print(); // or tf.any(x) + * ``` + * + * ```js + * const x = tf.tensor2d([1, 1, 0, 0], [2, 2], 'bool'); + * + * const axis = 1; + * x.any(axis).print(); // or tf.any(x, axis) + * ``` + * + * @param x The input tensor. Must be of dtype bool. + * @param axis The dimension(s) to reduce. By default it reduces + * all dimensions. + * @param keepDims If true, retains reduced dimensions with size 1. 
+ */ + /** @doc {heading: 'Operations', subheading: 'Reduction'} */ + function any_(x, axis, keepDims) { + if (axis === void 0) { axis = null; } + if (keepDims === void 0) { keepDims = false; } + var $x = convertToTensor(x, 'x', 'any', 'bool'); + var origAxes = parseAxisParam(axis, $x.shape); + var axes = origAxes; + var permutedAxes = getAxesPermutation(axes, $x.rank); + if (permutedAxes != null) { + $x = $x.transpose(permutedAxes); + axes = getInnerMostAxes(axes.length, $x.rank); + } + var res = ENGINE.runKernelFunc(function (backend) { return backend.any($x, axes); }, { $x: $x }); + if (keepDims) { + var newShape = expandShapeToKeepDim(res.shape, origAxes); + return res.reshape(newShape); + } + return res; + } + /** + * Calculates the mean and variance of `x`. The mean and variance are + * calculated by aggregating the contents of `x` across `axes`. If `x` is + * 1-D and `axes = [0]` this is just the mean and variance of a vector. + * + * @param x The input tensor. + * @param axis The dimension(s) along with to compute mean and + * variance. By default it reduces all dimensions. + * @param keepDims If true, the moments have the same dimensionality as the + * input. + * @return An object with two keys: `mean` and `variance`. 
+ */ + /** @doc {heading: 'Operations', subheading: 'Normalization'} */ + function moments_(x, axis, keepDims) { + if (axis === void 0) { axis = null; } + if (keepDims === void 0) { keepDims = false; } + x = convertToTensor(x, 'x', 'moments'); + var axes = parseAxisParam(axis, x.shape); + var mean = x.mean(axes, keepDims); + var keepDimsShape = mean.shape; + if (!keepDims) { + keepDimsShape = expandShapeToKeepDim(mean.shape, axes); + } + var devSquared = x.toFloat().sub(mean.reshape(keepDimsShape)).square(); + var variance = devSquared.mean(axes, keepDims); + return { mean: mean, variance: variance }; + } + var all = op({ all_: all_ }); + // tslint:disable-next-line:variable-name + var any = op({ any_: any_ }); + var argMax = op({ argMax_: argMax_ }); + var argMin = op({ argMin_: argMin_ }); + var logSumExp = op({ logSumExp_: logSumExp_ }); + var mean = op({ mean_: mean_ }); + var min = op({ min_: min_ }); + var moments = op({ moments_: moments_ }); + var sum$1 = op({ sum_: sum_ }); + var prod = op({ prod_: prod_ }); + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + var broadcastToGradConfig = { + kernelName: BroadcastTo, + gradFunc: function (dy, saved, attrs) { + var broadCastToAttrs = attrs; + var inputShape = broadCastToAttrs.inputShape; + var outputShape = broadCastToAttrs.shape; + var reps = Array.from(outputShape); + for (var i = inputShape.length - 1; i >= 0; i--) { + if (inputShape[i] === outputShape[i]) { + reps[i] = 1; + } + else if (inputShape[i] !== 1) { + throw new Error("broadcastTo(): [" + inputShape + "] cannot be broadcast to [" + outputShape + "]."); + } + } + var axes = []; + for (var i = 0; i < reps.length; i++) { + if (reps[i] > 1) { + axes.push(i); + } + } + return { x: function () { return sum$1(dy, axes, true /* keepDims */); } }; + } + }; + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Splits a `tf.Tensor` into sub tensors. + * + * If `numOrSizeSplits` is a number, splits `x` along dimension `axis` + * into `numOrSizeSplits` smaller tensors. + * Requires that `numOrSizeSplits` evenly divides `x.shape[axis]`. + * + * If `numOrSizeSplits` is a number array, splits `x` into + * `numOrSizeSplits.length` pieces. The shape of the `i`-th piece has the + * same size as `x` except along dimension `axis` where the size is + * `numOrSizeSplits[i]`. 
+ * + * ```js + * const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]); + * const [a, b] = tf.split(x, 2, 1); + * a.print(); + * b.print(); + * + * const [c, d, e] = tf.split(x, [1, 2, 1], 1); + * c.print(); + * d.print(); + * e.print(); + * ``` + * + * @param x The input tensor to split. + * @param numOrSizeSplits Either an integer indicating the number of + * splits along the axis or an array of integers containing the sizes of + * each output tensor along the axis. If a number then it must evenly divide + * `x.shape[axis]`; otherwise the sum of sizes must match `x.shape[axis]`. + * @param axis The dimension along which to split. Defaults to 0 (the first + * dim). + */ + /** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */ + function split_(x, numOrSizeSplits, axis) { + if (axis === void 0) { axis = 0; } + var $x = convertToTensor(x, 'x', 'split'); + var $axis = parseAxisParam(axis, $x.shape)[0]; + var splitSizes; + if (typeof (numOrSizeSplits) === 'number') { + assert($x.shape[$axis] % numOrSizeSplits === 0, function () { return 'Number of splits must evenly divide the axis.'; }); + splitSizes = + new Array(numOrSizeSplits).fill($x.shape[$axis] / numOrSizeSplits); + } + else { + assert($x.shape[$axis] === numOrSizeSplits.reduce(function (a, b) { return a + b; }), function () { return 'The sum of sizes must match the size of the axis dimension.'; }); + splitSizes = numOrSizeSplits; + } + var forward = function (backend, _) { + return backend.split($x, splitSizes, $axis); + }; + var inputs = { x: $x }; + var attr = { numOrSizeSplits: numOrSizeSplits, axis: axis }; + return ENGINE.runKernelFunc(forward, inputs, null /* grad */, SplitV, attr); + } + var split = op({ split_: split_ }); + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var concatGradConfig = { + kernelName: Concat, + saveAllInputs: true, + gradFunc: function (dy, saved, attrs) { + var shapes = saved.map(function (t) { return t.shape; }); + var axis = attrs.axis; + var $axis = parseAxisParam(axis, saved[0].shape)[0]; + var sizeSplits = shapes.map(function (s) { return s[$axis]; }); + var derTensors = split(dy, sizeSplits, $axis); + return derTensors.map(function (t) { return function () { return t; }; }); + } + }; + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Computes the derivative of the filter of a 2D convolution. + * + * @param x The input tensor, of rank 4 or rank 3 of shape + * [batch, height, width, inChannels]. If rank 3, batch of 1 is assumed. + * @param dy The dy image, of rank 4 or rank 3, of shape + * [batch, height, width, outDepth]. 
If rank 3, batch of 1 is assumed. + * @param filterShape The shape of the filter, length 4, + * [filterHeight, filterWidth, inDepth, outDepth]. + * @param strides The strides of the convolution: [strideHeight, + * strideWidth]. + * @param pad A string from: 'same', 'valid'. The type of padding algorithm + * used in the forward prop of the op. + * @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to + * "NHWC". Specify the data format of the input and output data. With the + * default format "NHWC", the data is stored in the order of: [batch, + * height, width, channels]. + * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. The + * rounding mode used when computing output dimensions if pad is a + * number. If none is provided, it will not round and error if the output + * is of fractional size. + */ + function conv2DBackpropFilter_(x, dy, filterShape, strides, pad, dataFormat, dimRoundingMode) { + if (dataFormat === void 0) { dataFormat = 'NHWC'; } + var x4D = x; + if (x.rank === 3) { + x4D = reshape(x, [1, x.shape[0], x.shape[1], x.shape[2]]); + } + var dy4D = dy; + if (dy4D.rank === 3) { + dy4D = reshape(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]); + } + assert(x4D.rank === 4, function () { return "Error in conv2dDerFilter: input must be rank 4, but got shape " + + (x4D.shape + "."); }); + assert(dy4D.rank === 4, function () { return "Error in conv2dDerFilter: dy must be rank 4, but got shape " + + (dy4D.shape + "."); }); + assert(filterShape.length === 4, function () { return "Error in conv2dDerFilter: filterShape must be length 4, but got " + + (filterShape + "."); }); + var inDepth = dataFormat === 'NHWC' ? x4D.shape[3] : x4D.shape[1]; + var outDepth = dataFormat === 'NHWC' ? 
dy4D.shape[3] : dy4D.shape[1]; + assert(inDepth === filterShape[2], function () { return "Error in conv2dDerFilter: depth of input " + inDepth + ") must " + + ("match input depth in filter (" + filterShape[2] + "."); }); + assert(outDepth === filterShape[3], function () { return "Error in conv2dDerFilter: depth of dy (" + outDepth + ") must " + + ("match output depth for filter (" + filterShape[3] + ")."); }); + if (dimRoundingMode != null) { + assert(isInt(pad), function () { return "Error in conv2dDerFilter: pad must be an integer when using, " + + ("dimRoundingMode " + dimRoundingMode + " but got pad " + pad + "."); }); + } + var forward = function (backend) { + var dilations = 1; + var $dataFormat = convertConv2DDataFormat(dataFormat); + var convInfo = computeConv2DInfo(x4D.shape, filterShape, strides, dilations, pad, dimRoundingMode, false, $dataFormat); + return backend.conv2dDerFilter(x4D, dy4D, convInfo); + }; + var inputs = { x: x4D, dy: dy4D }; + var attrs = { strides: strides, pad: pad, dataFormat: dataFormat, dimRoundingMode: dimRoundingMode }; + return ENGINE.runKernelFunc(forward, inputs, null, Conv2DBackpropFilter, attrs); + } + var conv2DBackpropFilter = op({ conv2DBackpropFilter_: conv2DBackpropFilter_ }); + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + /** + * Computes the derivative of the input of a 2D convolution. + * + * @param xShape The shape of the input: [batch, height, width, inDepth]. + * If length of 3, batch of 1 is assumed. + * @param dy The derivative of the output, of rank 4 or rank 3 of shape + * `[batch, outHeight, outWidth, outDepth]`. If rank 3, batch of 1 is + * assumed. + * @param filter The filter, rank 4, of shape + * `[filterHeight, filterWidth, inDepth, outDepth]`. + * @param strides The strides of the convolution: `[strideHeight, + * strideWidth]`. + * @param pad The type of padding algorithm used: + * - `same` and stride 1: output will be of same size as input, + * regardless of filter size. + * - `valid`: output will be smaller than input if filter is larger + * than 1x1. + * @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to + * "NHWC". Specify the data format of the input and output data. With the + * default format "NHWC", the data is stored in the order of: [batch, + * height, width, channels]. + * @param dimRoundingMode The rounding mode used when computing output + * dimensions if pad is a number. If none is provided, it will not round + * and error if the output is of fractional size. 
+ */ + function conv2DBackpropInput_(xShape, dy, filter, strides, pad, dataFormat, dimRoundingMode) { + if (dataFormat === void 0) { dataFormat = 'NHWC'; } + assert(xShape.length === dy.rank, function () { return "Length of inShape " + + ("(" + xShape.length + ") and rank of dy (" + dy.rank + ") must match"); }); + var xShape4D = xShape; + var dy4D = dy; + var reshapedTo4D = false; + if (dy.rank === 3) { + reshapedTo4D = true; + dy4D = reshape(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]); + xShape4D = [1, xShape[0], xShape[1], xShape[2]]; + } + assert(xShape4D.length === 4, function () { + return "Error in conv2dDerInput: inShape must be length 4, but got length " + + (xShape4D.length + "."); + }); + assert(dy4D.rank === 4, function () { return "Error in conv2dDerInput: dy must be rank 4, but got " + + ("rank " + dy4D.rank); }); + assert(filter.rank === 4, function () { return "Error in conv2dDerInput: filter must be rank 4, but got " + + ("rank " + filter.rank); }); + var inDepth = dataFormat === 'NHWC' ? xShape4D[3] : xShape4D[1]; + var outDepth = dataFormat === 'NHWC' ? 
dy4D.shape[3] : dy4D.shape[1]; + assert(inDepth === filter.shape[2], function () { return "Error in conv2dDerInput: depth of input (" + inDepth + ") must " + + ("match input depth for filter " + filter.shape[2] + "."); }); + assert(outDepth === filter.shape[3], function () { return "Error in conv2dDerInput: depth of output (" + outDepth + ") must " + + ("match output depth for filter " + filter.shape[3] + "."); }); + if (dimRoundingMode != null) { + assert(isInt(pad), function () { return "Error in conv2dDerInput: pad must be an integer when using, " + + ("dimRoundingMode " + dimRoundingMode + " but got pad " + pad + "."); }); + } + var forward = function (backend, save) { + var dilations = 1; + var $dataFormat = convertConv2DDataFormat(dataFormat); + var convInfo = computeConv2DInfo(xShape4D, filter.shape, strides, dilations, pad, dimRoundingMode, false, $dataFormat); + var res = backend.conv2dDerInput(dy4D, filter, convInfo); + save([dy4D, filter]); + return res; + }; + var inputs = { dy: dy4D, filter: filter }; + var attrs = { strides: strides, pad: pad, dataFormat: dataFormat, dimRoundingMode: dimRoundingMode }; + var res = ENGINE.runKernelFunc(forward, inputs, null /* grad */, Conv2DBackpropInput, attrs); + if (reshapedTo4D) { + return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]); + } + return res; + } + var conv2DBackpropInput = op({ conv2DBackpropInput_: conv2DBackpropInput_ }); + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var conv2DGradConfig = { + kernelName: Conv2D, + inputsToSave: ['x', 'filter'], + gradFunc: function (dy, saved, attrs) { + var _a = saved, x4D = _a[0], $filter = _a[1]; + var _b = attrs, dilations = _b.dilations, strides = _b.strides, pad = _b.pad, dataFormat = _b.dataFormat; + assert(tupleValuesAreOne(dilations), function () { return 'Error in gradient of conv2D: dilation rates greater than 1 ' + + ("are not yet supported in gradients. Got dilations '" + dilations + "'"); }); + return { + x: function () { + return conv2DBackpropInput(x4D.shape, dy, $filter, strides, pad, dataFormat); + }, + filter: function () { + return conv2DBackpropFilter(x4D, dy, $filter.shape, strides, pad, dataFormat); + } + }; + } + }; + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Computes a 2D convolution over the input x. + * + * @param x The input tensor, of rank 4 or rank 3, of shape + * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is + * assumed. + * @param filter The filter, rank 4, of shape + * `[filterHeight, filterWidth, inDepth, outDepth]`. 
+ * @param strides The strides of the convolution: `[strideHeight, + * strideWidth]`. + * @param pad The type of padding algorithm. + * - `same` and stride 1: output will be of same size as input, + * regardless of filter size. + * - `valid`: output will be smaller than input if filter is larger + * than 1x1. + * - For more info, see this guide: + * [https://www.tensorflow.org/api_guides/python/nn#Convolution]( + * https://www.tensorflow.org/api_guides/python/nn#Convolution) + * @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to + * "NHWC". Specify the data format of the input and output data. With the + * default format "NHWC", the data is stored in the order of: [batch, + * height, width, channels]. + * @param dilations The dilation rates: `[dilationHeight, dilationWidth]` + * in which we sample input values across the height and width dimensions + * in atrous convolution. Defaults to `[1, 1]`. If `dilations` is a single + * number, then `dilationHeight == dilationWidth`. If it is greater than + * 1, then all values of `strides` must be 1. + * @param dimRoundingMode The rounding mode used when computing output + * dimensions if pad is a number. If none is provided, it will not round + * and error if the output is of fractional size. 
+ */ + /** @doc {heading: 'Operations', subheading: 'Convolution'} */ + function conv2d_(x, filter, strides, pad, dataFormat, dilations, dimRoundingMode) { + if (dataFormat === void 0) { dataFormat = 'NHWC'; } + if (dilations === void 0) { dilations = [1, 1]; } + var $x = convertToTensor(x, 'x', 'conv2d'); + var $filter = convertToTensor(filter, 'filter', 'conv2d'); + var x4D = $x; + var reshapedTo4D = false; + if ($x.rank === 3) { + reshapedTo4D = true; + x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]); + } + assert(x4D.rank === 4, function () { return "Error in conv2d: input must be rank 4, but got rank " + x4D.rank + "."; }); + assert($filter.rank === 4, function () { return "Error in conv2d: filter must be rank 4, but got rank " + + ($filter.rank + "."); }); + if (dimRoundingMode != null) { + assert(isInt(pad), function () { return "Error in conv2d: pad must be an integer when using, " + + ("dimRoundingMode " + dimRoundingMode + " but got pad " + pad + "."); }); + } + var inDepth = dataFormat === 'NHWC' ? x4D.shape[3] : x4D.shape[1]; + assert(inDepth === $filter.shape[2], function () { return "Error in conv2d: depth of input (" + inDepth + ") must match " + + ("input depth for filter " + $filter.shape[2] + "."); }); + assert(eitherStridesOrDilationsAreOne(strides, dilations), function () { return 'Error in conv2D: Either strides or dilations must be 1. 
' + + ("Got strides " + strides + " and dilations '" + dilations + "'"); }); + var forward = function (backend, save) { + var $dataFormat = convertConv2DDataFormat(dataFormat); + var convInfo = computeConv2DInfo(x4D.shape, $filter.shape, strides, dilations, pad, dimRoundingMode, false, $dataFormat); + var res = backend.conv2d(x4D, $filter, convInfo); + save([x4D, $filter]); + return res; + }; + var inputs = { x: x4D, filter: $filter }; + var attrs = { strides: strides, pad: pad, dataFormat: dataFormat, dilations: dilations, dimRoundingMode: dimRoundingMode }; + var res = ENGINE.runKernelFunc(forward, inputs, null /* grad */, Conv2D, attrs); + if (reshapedTo4D) { + return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]); + } + return res; + } + var conv2d = op({ conv2d_: conv2d_ }); + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + var conv2DBackpropInputGradConfig = { + kernelName: Conv2DBackpropInput, + inputsToSave: ['dy', 'filter'], + gradFunc: function (ddx, saved, attrs) { + var _a = saved, dy = _a[0], filter = _a[1]; + var _b = attrs, strides = _b.strides, pad = _b.pad, dataFormat = _b.dataFormat, dimRoundingMode = _b.dimRoundingMode; + return { + dy: function () { return conv2d(ddx, filter, strides, pad, dataFormat, 1 /* dilations */, dimRoundingMode); }, + filter: function () { return conv2DBackpropFilter(ddx, dy, filter.shape, strides, pad, dataFormat, dimRoundingMode); } + }; + } + }; + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Computes the derivative of the filter of a 3D convolution. + * + * @param x The input tensor, of rank 5 or rank 4 of shape + * [batch, depth, height, width, inChannels]. If rank 4, batch of 1 is + * assumed. + * @param dy The dy image, of rank 5 or rank 4, of shape + * [batch, depth, height, width, outDepth]. If rank 4, batch of 1 is + * assumed. + * @param filterShape The shape of the filter, length 5, + * [filterDepth, filterHeight, filterWidth, inDepth, outDepth]. + * @param strides The strides of the convolution: [strideDepth, strideHeight, + * strideWidth]. 
+ * @param pad A string from: 'same', 'valid'. The type of padding algorithm + * used in the forward prop of the op. + */ + function conv3DBackpropFilter_(x, dy, filterShape, strides, pad) { + var x5D = x; + if (x.rank === 4) { + x5D = reshape(x, [1, x.shape[0], x.shape[1], x.shape[2], x.shape[3]]); + } + var dy5D = dy; + if (dy5D.rank === 4) { + dy5D = reshape(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2], dy.shape[3]]); + } + assert(x5D.rank === 5, function () { return "Error in conv3dDerFilter: input must be rank 5, but got shape " + + (x5D.shape + "."); }); + assert(dy5D.rank === 5, function () { return "Error in conv3dDerFilter: dy must be rank 5, but got shape " + + (dy5D.shape + "."); }); + assert(filterShape.length === 5, function () { return "Error in conv3dDerFilter: filterShape must be length 5, but got " + + (filterShape + "."); }); + assert(x5D.shape[4] === filterShape[3], function () { return "Error in conv3dDerFilter: depth of input " + x5D.shape[4] + ") must " + + ("match input depth in filter (" + filterShape[3] + "."); }); + assert(dy5D.shape[4] === filterShape[4], function () { return "Error in conv3dDerFilter: depth of dy (" + dy5D.shape[4] + ") must " + + ("match output depth for filter (" + filterShape[4] + ")."); }); + var forward = function (backend) { + var dilations = 1; + var convInfo = computeConv3DInfo(x5D.shape, filterShape, strides, dilations, pad); + return backend.conv3dDerFilter(x5D, dy5D, convInfo); + }; + var inputs = { x: x5D, y: dy5D }; + var attrs = { strides: strides, pad: pad }; + return ENGINE.runKernelFunc(forward, inputs, null, Conv3DBackpropFilterV2, attrs); + } + var conv3DBackpropFilter = op({ conv3DBackpropFilter_: conv3DBackpropFilter_ }); + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Computes the derivative of the input of a 3D convolution. + * + * @param xShape The shape of the input: [batch, depth, height, width, + * in_channels]. If length of 4, batch of 1 is assumed. + * @param dy The derivative of the output, of rank 5 or rank 4 of shape + * `[batch, outDepth, outHeight, outWidth, in_channels]`. + * If rank 4, batch of 1 is assumed. + * @param filter The filter, rank 5, of shape + * `[filterDepth, filterHeight, filterWidth, inDepth, outDepth]`. + * @param strides The strides of the convolution: `[strideDepth, strideHeight, + * strideWidth]`. + * @param pad The type of padding algorithm used: + * - `same` and stride 1: output will be of same size as input, + * regardless of filter size. + * - `valid`: output will be smaller than input if filter is larger + * than 1x1. 
+ */ + function conv3DBackpropInput_(xShape, dy, filter, strides, pad) { + assert(xShape.length === dy.rank, function () { return "Length of inShape " + + ("(" + xShape.length + ") and rank of dy (" + dy.rank + ") must match"); }); + var xShape5D = xShape; + var dy5D = dy; + var reshapedTo5D = false; + if (dy.rank === 4) { + reshapedTo5D = true; + dy5D = reshape(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2], dy.shape[3]]); + xShape5D = [1, xShape[0], xShape[1], xShape[2], xShape[3]]; + } + var inDepth = xShape5D[4]; + var outDepth = dy5D.shape[4]; + assert(xShape5D.length === 5, function () { + return "Error in conv3dDerInput: inShape must be length 5, but got length " + + (xShape5D.length + "."); + }); + assert(dy5D.rank === 5, function () { return "Error in conv3dDerInput: dy must be rank 5, but got " + + ("rank " + dy5D.rank); }); + assert(filter.rank === 5, function () { return "Error in conv3dDerInput: filter must be rank 5, but got " + + ("rank " + filter.rank); }); + assert(inDepth === filter.shape[3], function () { return "Error in conv3dDerInput: depth of input (" + inDepth + ") must " + + ("match input depth for filter " + filter.shape[3] + "."); }); + assert(outDepth === filter.shape[4], function () { return "Error in conv3dDerInput: depth of output (" + outDepth + ") must " + + ("match output depth for filter " + filter.shape[4] + "."); }); + var forward = function (backend) { + var dilations = 1; + var convInfo = computeConv3DInfo(xShape5D, filter.shape, strides, dilations, pad); + return backend.conv3dDerInput(dy5D, filter, convInfo); + }; + var inputs = { dy: dy5D }; + var attrs = { pad: pad }; + var res = ENGINE.runKernelFunc(forward, inputs, null, Conv3DBackpropInputV2, attrs); + if (reshapedTo5D) { + return reshape(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]); + } + return res; + } + var conv3DBackpropInput = op({ conv3DBackpropInput_: conv3DBackpropInput_ }); + + /** + * @license + * Copyright 2020 Google Inc. 
All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var conv3DGradConfig = { + kernelName: Conv3D, + inputsToSave: ['x', 'filter'], + gradFunc: function (dy, saved, attrs) { + var _a = attrs, dilations = _a.dilations, strides = _a.strides, pad = _a.pad; + assert(tupleValuesAreOne(dilations), function () { + return 'Error in gradient of conv3D: dilation rates greater than 1 are ' + + ("not yet supported in gradients. Got dilations '" + dilations + "'"); + }); + var x5D = saved[0], $filter = saved[1]; + return { + x: function () { return conv3DBackpropInput(x5D.shape, dy, $filter, strides, pad); }, + filter: function () { return conv3DBackpropFilter(x5D, dy, $filter.shape, strides, pad); } + }; + } + }; + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + /** + * Transposes the `tf.Tensor`. Permutes the dimensions according to `perm`. + * + * The returned `tf.Tensor`'s dimension `i` will correspond to the input + * dimension `perm[i]`. If `perm` is not given, it is set to `[n-1...0]`, + * where `n` is the rank of the input `tf.Tensor`. Hence by default, this + * operation performs a regular matrix transpose on 2-D input `tf.Tensor`s. + * + * ```js + * const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]); + * + * a.transpose().print(); // or tf.transpose(a) + * ``` + * + * @param x The tensor to transpose. + * @param perm The permutation of the dimensions of a. + */ + /** @doc {heading: 'Operations', subheading: 'Matrices'} */ + function transpose_(x, perm) { + var $x = convertToTensor(x, 'x', 'transpose'); + if (perm == null) { + perm = $x.shape.map(function (s, i) { return i; }).reverse(); + } + assert($x.rank === perm.length, function () { return "Error in transpose: rank of input " + $x.rank + " " + + ("must match length of perm " + perm + "."); }); + perm.forEach(function (axis) { + assert(axis >= 0 && axis < $x.rank, function () { return "All entries in 'perm' must be between 0 and " + ($x.rank - 1) + + (" but got " + perm); }); + }); + if ($x.rank <= 1) { + return $x.clone(); + } + var attrs = { perm: perm }; + return ENGINE.runKernelFunc(function (backend) { return backend.transpose($x, perm); }, { x: $x }, null /* gradient */, 'Transpose', attrs); + } + var transpose = op({ transpose_: transpose_ }); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Computes the cumulative sum of a `tf.Tensor` along `axis`. + * + * ```js + * const x = tf.tensor([1, 2, 3, 4]); + * x.cumsum().print(); + * ``` + * ```js + * const x = tf.tensor([[1, 2], [3, 4]]); + * x.cumsum().print(); + * ``` + * + * @param x The input tensor to be summed. + * @param axis The axis along which to sum. Optional. Defaults to 0. + * @param exclusive Whether to perform exclusive cumulative sum. Optional. + * Defaults to false. If set to true then the sum of each tensor entry + * does not include its own value, but only the values previous to it + * along the specified axis. + * @param reverse Whether to sum in the opposite direction. Optional. + * Defaults to false. 
+ */ + /** @doc {heading: 'Operations', subheading: 'Scan'} */ + function cumsum_(x, axis, exclusive, reverse) { + if (axis === void 0) { axis = 0; } + if (exclusive === void 0) { exclusive = false; } + if (reverse === void 0) { reverse = false; } + var $x = convertToTensor(x, 'x', 'cumsum'); + var forward = function (backend, save) { + var permutation = getAxesPermutation([axis], $x.rank); + var permutedX = $x; + if (permutation != null) { + permutedX = transpose($x, permutation); + } + var permutedAxis = getInnerMostAxes(1, $x.rank)[0]; + var value = backend.cumsum(permutedX, permutedAxis, exclusive, reverse); + save([$x]); + if (permutation != null) { + value = transpose(value, permutation); + } + return value; + }; + var inputs = { x: $x }; + var attrs = { axis: axis, exclusive: exclusive, reverse: reverse }; + return ENGINE.runKernelFunc(forward, inputs, null /* grad */, Cumsum, attrs); + } + var cumsum = op({ cumsum_: cumsum_ }); + + /** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + var cumsumGradConfig = { + kernelName: Cumsum, + inputsToSave: ['x'], + gradFunc: function (dy, saved, attrs) { + var x = saved[0]; + var _a = attrs, axis = _a.axis, exclusive = _a.exclusive, reverse = _a.reverse; + return { + x: function () { + var permutation = getAxesPermutation([axis], x.rank); + var out = cumsum(dy, axis, exclusive, !reverse); + if (permutation != null) { + out = transpose(out, permutation); + } + return out; + } + }; + } + }; + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function depthwiseConv2dNativeBackpropFilter_(x, dy, filterShape, convInfo) { + var x4D = x; + if (x.rank === 3) { + x4D = reshape(x, [1, x.shape[0], x.shape[1], x.shape[2]]); + } + var dy4D = dy; + if (dy4D.rank === 3) { + dy4D = reshape(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]); + } + var forward = function (backend) { + return backend.depthwiseConv2DDerFilter(x4D, dy4D, convInfo); + }; + var inputs = { x: x4D, dy: dy4D }; + return ENGINE.runKernelFunc(forward, inputs, null, DepthwiseConv2dNativeBackpropFilter); + } + var depthwiseConv2dNativeBackpropFilter = op({ depthwiseConv2dNativeBackpropFilter_: depthwiseConv2dNativeBackpropFilter_ }); + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. 
+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function depthwiseConv2dNativeBackpropInput_(xShape, dy, filter, convInfo) { + var dy4D = dy; + var reshapedTo4D = false; + if (dy.rank === 3) { + reshapedTo4D = true; + dy4D = reshape(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]); + } + var forward = function (backend) { + return backend.depthwiseConv2DDerInput(dy4D, filter, convInfo); + }; + var inputs = { dy: dy4D }; + var res = ENGINE.runKernelFunc(forward, inputs, null, DepthwiseConv2dNativeBackpropInput); + if (reshapedTo4D) { + return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]); + } + return res; + } + var depthwiseConv2dNativeBackpropInput = op({ depthwiseConv2dNativeBackpropInput_: depthwiseConv2dNativeBackpropInput_ }); + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + var depthwiseConv2dNativeGradConfig = { + kernelName: DepthwiseConv2dNative, + inputsToSave: ['x', 'filter'], + gradFunc: function (dy, saved, attrs) { + var _a = attrs, dilations = _a.dilations, strides = _a.strides, pad = _a.pad, dimRoundingMode = _a.dimRoundingMode; + var $dilations = dilations == null ? [1, 1] : dilations; + assert(tupleValuesAreOne($dilations), function () { return 'Error in gradient of depthwiseConv2dNative: dilation rates ' + + "greater than 1 are not yet supported. Got dilations " + + ("'" + $dilations + "'"); }); + var _b = saved, x = _b[0], filter = _b[1]; + assert(x.rank === 4, function () { return "Error in gradient of depthwiseConv2dNative: input must be " + + ("rank 4, but got rank " + x.rank + "."); }); + assert(filter.rank === 4, function () { return "Error in gradient of depthwiseConv2dNative: filter must be " + + ("rank 4, but got rank " + filter.rank + "."); }); + assert(x.shape[3] === filter.shape[2], function () { return "Error in gradient of depthwiseConv2d: number of input " + + ("channels (" + x.shape[3] + ") must match the inChannels dimension ") + + ("in filter " + filter.shape[2] + "."); }); + assert(eitherStridesOrDilationsAreOne(strides, $dilations), function () { return 'Error in gradient of depthwiseConv2d: Either strides or ' + + ("dilations must be 1. 
Got strides " + strides + " and dilations ") + + ("'" + $dilations + "'."); }); + if (dimRoundingMode != null) { + assert(isInt(pad), function () { + return "Error in depthwiseConv2d: pad must be an integer when using, " + + ("dimRoundingMode " + dimRoundingMode + " but got pad " + pad + "."); + }); + } + var convInfo = computeConv2DInfo(x.shape, filter.shape, strides, $dilations, pad, dimRoundingMode, true /* depthwise */); + return { + x: function () { + return depthwiseConv2dNativeBackpropInput(x.shape, dy, filter, convInfo); + }, + filter: function () { + return depthwiseConv2dNativeBackpropFilter(x, dy, filter.shape, convInfo); + }, + }; + } + }; + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Enables production mode which disables correctness checks in favor of + * performance. + */ + /** @doc {heading: 'Environment'} */ + function enableProdMode() { + env().set('PROD', true); + } + /** + * Enables debug mode which will log information about all executed kernels: + * the elapsed time of the kernel execution, as well as the rank, shape, and + * size of the output tensor. + * + * Debug mode will significantly slow down your application as it will + * download the result of every operation to the CPU. This should not be used in + * production. 
Debug mode does not affect the timing information of the kernel + * execution as we do not measure download time in the kernel execution time. + * + * See also: `tf.profile`, `tf.memory`. + */ + /** @doc {heading: 'Environment'} */ + function enableDebugMode() { + env().set('DEBUG', true); + } + /** Globally disables deprecation warnings */ + function disableDeprecationWarnings() { + env().set('DEPRECATION_WARNINGS_ENABLED', false); + console.warn("TensorFlow.js deprecation warnings have been disabled."); + } + /** Warn users about deprecated functionality. */ + function deprecationWarn(msg) { + if (env().getBool('DEPRECATION_WARNINGS_ENABLED')) { + console.warn(msg + ' You can disable deprecation warnings with ' + + 'tf.disableDeprecationWarnings().'); + } + } + /** + * Dispose all variables kept in backend engine. + */ + /** @doc {heading: 'Environment'} */ + function disposeVariables() { + ENGINE.disposeVariables(); + } + /** + * It returns the global engine that keeps track of all tensors and backends. + */ + /** @doc {heading: 'Environment'} */ + function engine() { + return ENGINE; + } + /** + * Returns memory info at the current time in the program. The result is an + * object with the following properties: + * + * - `numBytes`: Number of bytes allocated (undisposed) at this time. + * - `numTensors`: Number of unique tensors allocated. + * - `numDataBuffers`: Number of unique data buffers allocated + * (undisposed) at this time, which is ≤ the number of tensors + * (e.g. `a.reshape(newShape)` makes a new Tensor that shares the same + * data buffer with `a`). + * - `unreliable`: True if the memory usage is unreliable. See `reasons` when + * `unreliable` is true. + * - `reasons`: `string[]`, reasons why the memory is unreliable, present if + * `unreliable` is true. + * + * WebGL Properties: + * - `numBytesInGPU`: Number of bytes allocated (undisposed) in the GPU only at + * this time. 
+ */ + /** @doc {heading: 'Performance', subheading: 'Memory'} */ + function memory() { + return ENGINE.memory(); + } + /** + * Executes the provided function `f()` and returns a promise that resolves + * with information about the function's memory use: + * - `newBytes`: the number of new bytes allocated + * - `newTensors`: the number of new tensors created + * - `peakBytes`: the peak number of bytes allocated + * - `kernels`: an array of objects for each kernel involved that reports + * their input and output shapes, number of bytes used, and number of new + * tensors created. + * + * ```js + * const profile = await tf.profile(() => { + * const x = tf.tensor1d([1, 2, 3]); + * let x2 = x.square(); + * x2.dispose(); + * x2 = x.square(); + * x2.dispose(); + * return x; + * }); + * + * console.log(`newBytes: ${profile.newBytes}`); + * console.log(`newTensors: ${profile.newTensors}`); + * console.log(`byte usage over all kernels: ${profile.kernels.map(k => + * k.totalBytesSnapshot)}`); + * ``` + * + */ + /** @doc {heading: 'Performance', subheading: 'Profile'} */ + function profile(f) { + return ENGINE.profile(f); + } + /** + * Executes the provided function `fn` and after it is executed, cleans up all + * intermediate tensors allocated by `fn` except those returned by `fn`. + * `fn` must not return a Promise (async functions not allowed). The returned + * result can be a complex object. + * + * Using this method helps avoid memory leaks. In general, wrap calls to + * operations in `tf.tidy` for automatic memory cleanup. + * + * NOTE: Variables do *not* get cleaned up when inside a tidy(). If you want to + * dispose variables, please use `tf.disposeVariables` or call dispose() + * directly on variables. + * + * ```js + * // y = 2 ^ 2 + 1 + * const y = tf.tidy(() => { + * // a, b, and one will be cleaned up when the tidy ends. 
+ * const one = tf.scalar(1); + * const a = tf.scalar(2); + * const b = a.square(); + * + * console.log('numTensors (in tidy): ' + tf.memory().numTensors); + * + * // The value returned inside the tidy function will return + * // through the tidy, in this case to the variable y. + * return b.add(one); + * }); + * + * console.log('numTensors (outside tidy): ' + tf.memory().numTensors); + * y.print(); + * ``` + * + * @param nameOrFn The name of the closure, or the function to execute. + * If a name is provided, the 2nd argument should be the function. + * If debug mode is on, the timing and the memory usage of the function + * will be tracked and displayed on the console using the provided name. + * @param fn The function to execute. + */ + /** @doc {heading: 'Performance', subheading: 'Memory'} */ + function tidy(nameOrFn, fn) { + return ENGINE.tidy(nameOrFn, fn); + } + /** + * Disposes any `tf.Tensor`s found within the provided object. + * + * @param container an object that may be a `tf.Tensor` or may directly + * contain `tf.Tensor`s, such as a `Tensor[]` or `{key: Tensor, ...}`. If + * the object is not a `tf.Tensor` or does not contain `Tensors`, nothing + * happens. In general it is safe to pass any object here, except that + * `Promise`s are not supported. + */ + /** @doc {heading: 'Performance', subheading: 'Memory'} */ + function dispose(container) { + var tensors = getTensorsInContainer(container); + tensors.forEach(function (tensor) { return tensor.dispose(); }); + } + /** + * Keeps a `tf.Tensor` generated inside a `tf.tidy` from being disposed + * automatically. + * + * ```js + * let b; + * const y = tf.tidy(() => { + * const one = tf.scalar(1); + * const a = tf.scalar(2); + * + * // b will not be cleaned up by the tidy. a and one will be cleaned up + * // when the tidy ends. 
+ * b = tf.keep(a.square()); + * + * console.log('numTensors (in tidy): ' + tf.memory().numTensors); + * + * // The value returned inside the tidy function will return + * // through the tidy, in this case to the variable y. + * return b.add(one); + * }); + * + * console.log('numTensors (outside tidy): ' + tf.memory().numTensors); + * console.log('y:'); + * y.print(); + * console.log('b:'); + * b.print(); + * ``` + * + * @param result The tensor to keep from being disposed. + */ + /** @doc {heading: 'Performance', subheading: 'Memory'} */ + function keep(result) { + return ENGINE.keep(result); + } + /** + * Executes `f()` and returns a promise that resolves with timing + * information. + * + * The result is an object with the following properties: + * + * - `wallMs`: Wall execution time. + * - `kernelMs`: Kernel execution time, ignoring data transfer. If using the + * WebGL backend and the query timer extension is not available, this will + * return an error object. + * - On `WebGL` The following additional properties exist: + * - `uploadWaitMs`: CPU blocking time on texture uploads. + * - `downloadWaitMs`: CPU blocking time on texture downloads (readPixels). + * + * ```js + * const x = tf.randomNormal([20, 20]); + * const time = await tf.time(() => x.matMul(x)); + * + * console.log(`kernelMs: ${time.kernelMs}, wallTimeMs: ${time.wallMs}`); + * ``` + * + * @param f The function to execute and time. + */ + /** @doc {heading: 'Performance', subheading: 'Timing'} */ + function time(f) { + return ENGINE.time(f); + } + /** + * Sets the backend (cpu, webgl, wasm, etc) responsible for creating tensors and + * executing operations on those tensors. Returns a promise that resolves + * to a boolean if the backend initialization was successful. + * + * Note this disposes the current backend, if any, as well as any tensors + * associated with it. A new backend is initialized, even if it is of the + * same type as the previous one. 
+ * + * @param backendName The name of the backend. Currently supports + * `'webgl'|'cpu'` in the browser, `'tensorflow'` under node.js + * (requires tfjs-node), and `'wasm'` (requires tfjs-backend-wasm). + */ + /** @doc {heading: 'Backends'} */ + function setBackend(backendName) { + return ENGINE.setBackend(backendName); + } + /** + * Returns a promise that resolves when the currently selected backend (or the + * highest priority one) has initialized. Await this promise when you are using + * a backend that has async initialization. + */ + /** @doc {heading: 'Backends'} */ + function ready() { + return ENGINE.ready(); + } + /** + * Returns the current backend name (cpu, webgl, etc). The backend is + * responsible for creating tensors and executing operations on those tensors. + */ + /** @doc {heading: 'Backends'} */ + function getBackend() { + return ENGINE.backendName; + } + /** + * Removes a backend and the registered factory. + */ + /** @doc {heading: 'Backends'} */ + function removeBackend(name) { + ENGINE.removeBackend(name); + } + /** + * Finds the backend registered under the provided name. Returns null if the + * name is not in the registry, or the registration hasn't finished yet. + */ + function findBackend(name) { + return ENGINE.findBackend(name); + } + /** + * Finds the backend factory registered under the provided name. Returns a + * function that produces a new backend when called. Returns null if the name + * is not in the registry. + */ + function findBackendFactory(name) { + return ENGINE.findBackendFactory(name); + } + /** + * Registers a global backend. The registration should happen when importing + * a module file (e.g. when importing `backend_webgl.ts`), and is used for + * modular builds (e.g. custom tfjs bundle with only webgl support). + * + * @param factory The backend factory function. When called, it should + * return a backend instance, or a promise of an instance. 
+ * @param priority The priority of the backend (higher = more important). + * In case multiple backends are registered, the priority is used to find + * the best backend. Defaults to 1. + * @return False if there is already a registered backend under this name, true + * if not. + */ + /** @doc {heading: 'Backends'} */ + function registerBackend(name, factory, priority) { + if (priority === void 0) { priority = 1; } + return ENGINE.registerBackend(name, factory, priority); + } + /** + * Gets the current backend. If no backends have been initialized, this will + * attempt to initialize the best backend. Will throw an error if the highest + * priority backend has async initialization, in which case, you should call + * 'await tf.ready()' before running other code. + */ + /** @doc {heading: 'Backends'} */ + function backend() { + return ENGINE.backend; + } + /** + * Sets the global platform. + * + * @param platformName The name of this platform. + * @param platform A platform implementation. + */ + function setPlatform(platformName, platform) { + env().setPlatform(platformName, platform); + } + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Adds two `tf.Tensor`s element-wise, A + B. Supports broadcasting. 
+ * + * + * ```js + * const a = tf.tensor1d([1, 2, 3, 4]); + * const b = tf.tensor1d([10, 20, 30, 40]); + * + * a.add(b).print(); // or tf.add(a, b) + * ``` + * + * ```js + * // Broadcast add a with b. + * const a = tf.scalar(5); + * const b = tf.tensor1d([10, 20, 30, 40]); + * + * a.add(b).print(); // or tf.add(a, b) + * ``` + * @param a The first `tf.Tensor` to add. + * @param b The second `tf.Tensor` to add. Must have the same type as `a`. + */ + /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */ + function add_(a, b) { + var _a; + var $a = convertToTensor(a, 'a', 'add'); + var $b = convertToTensor(b, 'b', 'add'); + _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1]; + var forward = function (backend, save) { + var res = backend.add($a, $b); + save([$a, $b]); + return res; + }; + var inputs = { a: $a, b: $b }; + return ENGINE.runKernelFunc(forward, inputs, null /* gradient */, Add); + } + var add = op({ add_: add_ }); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Computes `-1 * x` element-wise. + * + * ```js + * const x = tf.tensor2d([1, 2, -2, 0], [2, 2]); + * + * x.neg().print(); // or tf.neg(x) + * ``` + * + * @param x The input tensor. 
+ */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function neg_(x) { + var $x = convertToTensor(x, 'x', 'neg'); + var grad = function (dy) { + return { x: function () { return dy.neg(); } }; + }; + var attrs = {}; + var inputsToSave = [$x]; + return ENGINE.runKernelFunc(function (backend) { return backend.neg($x); }, { x: $x }, grad, 'Neg', attrs, inputsToSave); + } + /** + * Computes ceiling of input `tf.Tensor` element-wise: `ceil(x)` + * + * ```js + * const x = tf.tensor1d([.6, 1.1, -3.3]); + * + * x.ceil().print(); // or tf.ceil(x) + * ``` + * @param x The input Tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function ceil_(x) { + var $x = convertToTensor(x, 'x', 'ceil'); + // TODO(manrajgrover): Return null for gradients when backprop supports it. + var grad = function (dy) { + return { $x: function () { return zerosLike(dy); } }; + }; + return ENGINE.runKernelFunc(function (backend) { return backend.ceil($x); }, { $x: $x }, grad); + } + /** + * Computes floor of input `tf.Tensor` element-wise: `floor(x)`. + * + * ```js + * const x = tf.tensor1d([.6, 1.1, -3.3]); + * + * x.floor().print(); // or tf.floor(x) + * ``` + * @param x The input tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function floor_(x) { + var $x = convertToTensor(x, 'x', 'floor'); + // TODO(nsthorat): Let gradients be null for cases where we want to stop + // backpropgation. + var grad = function (dy) { + return { $x: function () { return zerosLike(dy); } }; + }; + return ENGINE.runKernelFunc(function (backend) { return backend.floor($x); }, { $x: $x }, grad); + } + /** + * Returns an element-wise indication of the sign of a number. + * + * ```js + * const x = tf.tensor1d([.6, 1.1, -3.3, NaN, 0]); + * + * x.sign().print(); // or tf.sign(x) + * ``` + * @param x The input Tensor. 
+ */
+ /** @doc {heading: 'Operations', subheading: 'Basic math'} */
+ function sign_(x) {
+ var $x = convertToTensor(x, 'x', 'sign');
+ var grad = function (dy) {
+ return { $x: function () { return zerosLike(dy); } };
+ };
+ return ENGINE.runKernelFunc(function (backend) { return backend.sign($x); }, { $x: $x }, grad);
+ }
+ /**
+ * Returns which elements of x are NaN.
+ *
+ * ```js
+ * const x = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]);
+ *
+ * x.isNaN().print(); // or tf.isNaN(x)
+ * ```
+ * @param x The input Tensor.
+ */
+ /** @doc {heading: 'Operations', subheading: 'Basic math'} */
+ function isNaN_(x) {
+ var $x = convertToTensor(x, 'x', 'isNaN');
+ // TODO(nsthorat): Let gradients be null for cases where we want to stop
+ // backpropagation.
+ var grad = function (dy) {
+ return { $x: function () { return zerosLike(dy); } };
+ };
+ return ENGINE.runKernelFunc(function (backend) { return backend.isNaN($x); }, { $x: $x }, grad);
+ }
+ /**
+ * Returns which elements of x are Infinity or -Infinity.
+ *
+ * ```js
+ * const x = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]);
+ *
+ * x.isInf().print(); // or tf.isInf(x)
+ * ```
+ * @param x The input Tensor.
+ */
+ /** @doc {heading: 'Operations', subheading: 'Basic math'} */
+ function isInf_(x) {
+ var $x = convertToTensor(x, 'x', 'isInf');
+ // TODO(nsthorat): Let gradients be null for cases where we want to stop
+ // backpropagation.
+ var grad = function (dy) {
+ return { $x: function () { return zerosLike(dy); } };
+ };
+ return ENGINE.runKernelFunc(function (backend) { return backend.isInf($x); }, { $x: $x }, grad);
+ }
+ /**
+ * Returns which elements of x are finite.
+ *
+ * ```js
+ * const x = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]);
+ *
+ * x.isFinite().print(); // or tf.isFinite(x)
+ * ```
+ * @param x The input Tensor.
+ */
+ /** @doc {heading: 'Operations', subheading: 'Basic math'} */
+ function isFinite_(x) {
+ var $x = convertToTensor(x, 'x', 'isFinite');
+ // TODO(nsthorat): Let gradients be null for cases where we want to stop
+ // backpropagation.
+ var grad = function (dy) {
+ return { $x: function () { return zerosLike(dy); } };
+ };
+ return ENGINE.runKernelFunc(function (backend) { return backend.isFinite($x); }, { $x: $x }, grad);
+ }
+ /**
+ * Computes round of input `tf.Tensor` element-wise: `round(x)`.
+ * It implements banker's rounding.
+ *
+ * ```js
+ * const x = tf.tensor1d([.6, 1.1, -3.3]);
+ *
+ * x.round().print(); // or tf.round(x)
+ * ```
+ * @param x The input tensor.
+ */
+ /** @doc {heading: 'Operations', subheading: 'Basic math'} */
+ function round_(x) {
+ var $x = convertToTensor(x, 'x', 'round');
+ // TODO(nsthorat): Let gradients be null for cases where we want to stop
+ // backpropagation.
+ var grad = function (dy) {
+ return { $x: function () { return zerosLike(dy); } };
+ };
+ return ENGINE.runKernelFunc(function (backend) { return backend.round($x); }, { $x: $x }, grad);
+ }
+ /**
+ * Computes exponential of the input `tf.Tensor` element-wise. `e ^ x`
+ *
+ * ```js
+ * const x = tf.tensor1d([1, 2, -3]);
+ *
+ * x.exp().print(); // or tf.exp(x)
+ * ```
+ * @param x The input tensor.
+ */
+ /** @doc {heading: 'Operations', subheading: 'Basic math'} */
+ function exp_(x) {
+ var $x = convertToTensor(x, 'x', 'exp');
+ var bck = function (dy, saved) {
+ // tslint:disable-next-line: no-unnecessary-type-assertion
+ return { x: function () { return dy.mul(saved[0]); } };
+ };
+ var attrs = {};
+ var inputsToSave = [];
+ var outputsToSave = [true];
+ return ENGINE.runKernelFunc(function (backend, save) {
+ var y = backend.exp($x);
+ save([y]);
+ return y;
+ }, { x: $x }, bck, 'Exp', attrs, inputsToSave, outputsToSave);
+ }
+ /**
+ * Computes exponential of the input `tf.Tensor` minus one element-wise.
+ * `e ^ x - 1` + * + * ```js + * const x = tf.tensor1d([1, 2, -3]); + * + * x.expm1().print(); // or tf.expm1(x) + * ``` + * @param x The input tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function expm1_(x) { + var $x = convertToTensor(x, 'x', 'expm1'); + var grad = function (dy, saved) { + var $x = saved[0]; + return { $x: function () { return dy.mul($x.exp()); } }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.expm1($x); + save([$x]); + return res; + }, { $x: $x }, grad); + } + /** + * Computes natural logarithm of the input `tf.Tensor` element-wise: `ln(x)` + * + * ```js + * const x = tf.tensor1d([1, 2, Math.E]); + * + * x.log().print(); // or tf.log(x) + * ``` + * @param x The input tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function log_(x) { + var $x = convertToTensor(x, 'x', 'log'); + var grad = function (dy, saved) { + var $x = saved[0]; + return { x: function () { return dy.div($x.toFloat()); } }; + }; + var attrs = {}; + var inputsToSave = [$x]; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.log($x); + save([$x]); + return res; + }, { x: $x }, grad, 'Log', attrs, inputsToSave); + } + /** + * Computes natural logarithm of the input `tf.Tensor` plus one + * element-wise: `ln(1 + x)` + * + * ```js + * const x = tf.tensor1d([1, 2, Math.E - 1]); + * + * x.log1p().print(); // or tf.log1p(x) + * ``` + * @param x The input tensor. 
+ */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function log1p_(x) { + var $x = convertToTensor(x, 'x', 'log1p'); + var grad = function (dy, saved) { + var $x = saved[0]; + return { $x: function () { return dy.div($x.add(1)); } }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.log1p($x); + save([$x]); + return res; + }, { $x: $x }, grad); + } + /** + * Computes square root of the input `tf.Tensor` element-wise: `y = sqrt(x)` + * + * ```js + * const x = tf.tensor1d([1, 2, 4, -1]); + * + * x.sqrt().print(); // or tf.sqrt(x) + * ``` + * @param x The input tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function sqrt_(x) { + var $x = convertToTensor(x, 'x', 'sqrt'); + var grad = function (dy, saved) { + var $x = saved[0]; + return { x: function () { return dy.div($x.toFloat().sqrt().mul(2)); } }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.sqrt($x); + save([$x]); + return res; + }, { x: $x }, grad, 'Sqrt', {}); + } + /** + * Computes reciprocal of square root of the input `tf.Tensor` element-wise: + * `y = 1 / sqrt(x)` + * + * ```js + * const x = tf.tensor1d([1, 2, 4, -1]); + * + * x.rsqrt().print(); // or tf.rsqrt(x) + * ``` + * @param x The input tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function rsqrt_(x) { + var $x = convertToTensor(x, 'x', 'rsqrt'); + var grad = function (dy, saved) { + var $x = saved[0]; + return { x: function () { return dy.div($x.pow(1.5).mul(2)).neg(); } }; + }; + var inputsToSave = [$x]; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.rsqrt($x); + save([$x]); + return res; + }, { x: $x }, grad, 'Rsqrt', {} /* attrs */, inputsToSave); + } + /** + * Computes reciprocal of x element-wise: `1 / x` + * + * ```js + * const x = tf.tensor1d([0, 1, 2]); + * + * x.reciprocal().print(); // or tf.reciprocal(x) + * ``` + * @param x The input tensor. 
+ */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function reciprocal_(x) { + var $x = convertToTensor(x, 'x', 'reciprocal'); + var grad = function (dy, saved) { + var $x = saved[0]; + return { $x: function () { return dy.div($x.square().neg()); } }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.reciprocal($x); + save([$x]); + return res; + }, { $x: $x }, grad); + } + /** + * Computes absolute value element-wise: `abs(x)` + * + * ```js + * const x = tf.tensor1d([-1, 2, -3, 4]); + * + * x.abs().print(); // or tf.abs(x) + * ``` + * @param x The input `tf.Tensor`. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function abs_(x) { + var $x = convertToTensor(x, 'x', 'abs'); + if ($x.dtype === 'complex64') { + return ENGINE.runKernelFunc(function (backend) { return backend.complexAbs($x); }, { $x: $x }); + } + var grad = function (dy, saved) { + var $x = saved[0]; + return { x: function () { return dy.mul($x.toFloat().step(-1)); } }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.abs($x); + save([$x]); + return res; + }, { x: $x }, grad, 'Abs'); + } + /** + * Clips values element-wise. `max(min(x, clipValueMax), clipValueMin)` + * + * ```js + * const x = tf.tensor1d([-1, 2, -3, 4]); + * + * x.clipByValue(-2, 3).print(); // or tf.clipByValue(x, -2, 3) + * ``` + * @param x The input tensor. + * @param clipValueMin Lower-bound of range to be clipped to. + * @param clipValueMax Upper-bound of range to be clipped to. 
+ */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function clipByValue_(x, clipValueMin, clipValueMax) { + var $x = convertToTensor(x, 'x', 'clipByValue'); + assert((clipValueMin <= clipValueMax), function () { return "Error in clip: min (" + clipValueMin + ") must be " + + ("less than or equal to max (" + clipValueMax + ")."); }); + var grad = function (dy, saved) { + var $x = saved[0]; + return { + x: function () { return dy.where($x.greaterEqual(clipValueMin) + .logicalAnd($x.lessEqual(clipValueMax)), zerosLike(dy)); }, + }; + }; + var inputsToSave = [$x]; + var attr = { min: clipValueMin, max: clipValueMax }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.clip($x, clipValueMin, clipValueMax); + save([$x]); + return res; + }, { x: $x }, grad, 'ClipByValue', attr, inputsToSave); + } + /** + * Computes sigmoid element-wise, `1 / (1 + exp(-x))` + * + * ```js + * const x = tf.tensor1d([0, -1, 2, -3]); + * + * x.sigmoid().print(); // or tf.sigmoid(x) + * ``` + * @param x The input tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function sigmoid_(x) { + var $x = convertToTensor(x, 'x', 'sigmoid'); + var grad = function (dy, saved) { + var y = saved[0]; + return { x: function () { return dy.mul(y.mul(scalar(1).sub(y))); } }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var y = backend.sigmoid($x); + save([y]); + return y; + }, { x: $x }, grad, 'Sigmoid'); + } + /** + * Computes log sigmoid of the input `tf.Tensor` element-wise: + * `logSigmoid(x)`. For numerical stability, we use `-tf.softplus(-x)`. + * + * ```js + * const x = tf.tensor1d([0, 1, -1, .7]); + * + * x.logSigmoid().print(); // or tf.logSigmoid(x) + * ``` + * @param x The input tensor. 
+ */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function logSigmoid_(x) { + var $x = convertToTensor(x, 'x', 'logSigmoid'); + var grad = function (dy, saved) { + var $x = saved[0]; + return { $x: function () { return dy.mul($x.neg().sigmoid()); } }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.softplus($x.neg()).neg(); + save([$x]); + return res; + }, { $x: $x }, grad); + } + /** + * Computes softplus of the input `tf.Tensor` element-wise: `log(exp(x) + 1)` + * + * ```js + * const x = tf.tensor1d([0, 1, -1, .7]); + * + * x.softplus().print(); // or tf.softplus(x) + * ``` + * @param x The input tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function softplus_(x) { + var $x = convertToTensor(x, 'x', 'softplus'); + var grad = function (dy, saved) { + var $x = saved[0]; + return { $x: function () { return dy.mul($x.sigmoid()); } }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.softplus($x); + save([$x]); + return res; + }, { $x: $x }, grad); + } + /** + * Computes sin of the input Tensor element-wise: `sin(x)` + * + * ```js + * const x = tf.tensor1d([0, Math.PI / 2, Math.PI * 3 / 4]); + * + * x.sin().print(); // or tf.sin(x) + * ``` + * @param x The input tensor. 
+ */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function sin_(x) { + var $x = convertToTensor(x, 'x', 'sin'); + var grad = function (dy, saved) { + var $x = saved[0]; + return { x: function () { return $x.toFloat().cos().mul(dy); } }; + }; + var inputsToSave = [$x]; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.sin($x); + save([$x]); + return res; + }, { x: $x }, grad, 'Sin', {} /* attrs */, inputsToSave); + } + /** + * Computes cos of the input `tf.Tensor` element-wise: `cos(x)` + * + * ```js + * const x = tf.tensor1d([0, Math.PI / 2, Math.PI * 3 / 4]); + * + * x.cos().print(); // or tf.cos(x) + * ``` + * @param x The input tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function cos_(x) { + var $x = convertToTensor(x, 'x', 'cos'); + var grad = function (dy, saved) { + var $x = saved[0]; + return { x: function () { return $x.toFloat().sin().neg().mul(dy); } }; + }; + var inputsToSave = [$x]; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.cos($x); + save([$x]); + return res; + }, { x: $x }, grad, 'Cos', {} /* attrs */, inputsToSave); + } + /** + * Computes tan of the input `tf.Tensor` element-wise, `tan(x)` + * + * ```js + * const x = tf.tensor1d([0, Math.PI / 2, Math.PI * 3 / 4]); + * + * x.tan().print(); // or tf.tan(x) + * ``` + * @param x The input tensor. 
+ */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function tan_(x) { + var $x = convertToTensor(x, 'x', 'tan'); + var grad = function (dy, saved) { + var $x = saved[0]; + return { $x: function () { return dy.div($x.cos().square()); } }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.tan($x); + save([$x]); + return res; + }, { $x: $x }, grad); + } + /** + * Computes asin of the input `tf.Tensor` element-wise: `asin(x)` + * + * ```js + * const x = tf.tensor1d([0, 1, -1, .7]); + * + * x.asin().print(); // or tf.asin(x) + * ``` + * @param x The input tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function asin_(x) { + var $x = convertToTensor(x, 'x', 'asin'); + var grad = function (dy, saved) { + var $x = saved[0]; + return { + // tslint:disable-next-line: no-unnecessary-type-assertion + $x: function () { return dy.div(scalar(1).sub($x.toFloat().square()).sqrt()); } + }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.asin($x); + save([$x]); + return res; + }, { $x: $x }, grad); + } + /** + * Computes acos of the input `tf.Tensor` element-wise: `acos(x)` + * + * ```js + * const x = tf.tensor1d([0, 1, -1, .7]); + * + * x.acos().print(); // or tf.acos(x) + * ``` + * @param x The input tensor. 
+ */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function acos_(x) { + var $x = convertToTensor(x, 'x', 'acos'); + var grad = function (dy, saved) { + var $x = saved[0]; + return { + $x: function () { + var a = $x.toFloat().square(); + var b = scalar(1).sub(a).sqrt(); + // tslint:disable-next-line: no-unnecessary-type-assertion + return dy.div(b).neg(); + } + }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.acos($x); + save([$x]); + return res; + }, { $x: $x }, grad); + } + /** + * Computes atan of the input `tf.Tensor` element-wise: `atan(x)` + * + * ```js + * const x = tf.tensor1d([0, 1, -1, .7]); + * + * x.atan().print(); // or tf.atan(x) + * ``` + * @param x The input tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function atan_(x) { + var $x = convertToTensor(x, 'x', 'atan'); + var grad = function (dy, saved) { + var $x = saved[0]; + return { $x: function () { return dy.div($x.toFloat().square().add(1)); } }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.atan($x); + save([$x]); + return res; + }, { $x: $x }, grad); + } + /** + * Computes hyperbolic sin of the input `tf.Tensor` element-wise: `sinh(x)` + * + * ```js + * const x = tf.tensor1d([0, 1, -1, .7]); + * + * x.sinh().print(); // or tf.sinh(x) + * ``` + * @param x The input tensor. 
+ */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function sinh_(x) { + var $x = convertToTensor(x, 'x', 'sinh'); + var grad = function (dy, saved) { + var $x = saved[0]; + // tslint:disable-next-line: no-unnecessary-type-assertion + return { $x: function () { return $x.toFloat().cosh().mul(dy); } }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.sinh($x); + save([$x]); + return res; + }, { $x: $x }, grad); + } + /** + * Computes hyperbolic cos of the input `tf.Tensor` element-wise: `cosh(x)` + * + * ```js + * const x = tf.tensor1d([0, 1, -1, .7]); + * + * x.cosh().print(); // or tf.cosh(x) + * ``` + * @param x The input tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function cosh_(x) { + var $x = convertToTensor(x, 'x', 'cosh'); + var grad = function (dy, saved) { + var $x = saved[0]; + // tslint:disable-next-line: no-unnecessary-type-assertion + return { $x: function () { return $x.toFloat().sinh().mul(dy); } }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.cosh($x); + save([$x]); + return res; + }, { $x: $x }, grad); + } + /** + * Computes hyperbolic tangent of the input `tf.Tensor` element-wise: `tanh(x)` + * + * ```js + * const x = tf.tensor1d([0, 1, -1, 70]); + * + * x.tanh().print(); // or tf.tanh(x) + * ``` + * @param x The input tensor. 
+ */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function tanh_(x) { + var $x = convertToTensor(x, 'x', 'tanh'); + var grad = function (dy, saved) { + var y = saved[0]; + // tslint:disable-next-line: no-unnecessary-type-assertion + return { x: function () { return scalar(1).sub(y.square()).mul(dy); } }; + }; + var outputsToSave = [true]; + return ENGINE.runKernelFunc(function (backend, save) { + var y = backend.tanh($x); + save([y]); + return y; + }, { x: $x }, grad, 'Tanh', {} /* attrs */, null /* inputsToSave */, outputsToSave); + } + /** + * Computes inverse hyperbolic sin of the input `tf.Tensor` element-wise: + * `asinh(x)` + * + * ```js + * const x = tf.tensor1d([0, 1, -1, .7]); + * + * x.asinh().print(); // or tf.asinh(x) + * ``` + * @param x The input tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function asinh_(x) { + var $x = convertToTensor(x, 'x', 'asinh'); + var grad = function (dy, saved) { + var $x = saved[0]; + return { + $x: function () { + var a = scalar(1).add($x.toFloat().square()).sqrt(); + // tslint:disable-next-line: no-unnecessary-type-assertion + return dy.div(a); + } + }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.asinh($x); + save([$x]); + return res; + }, { $x: $x }, grad); + } + /** + * Computes the inverse hyperbolic cos of the input `tf.Tensor` element-wise: + * `acosh(x)` + * + * ```js + * const x = tf.tensor1d([10, 1, 3, 5.7]); + * + * x.acosh().print(); // or tf.acosh(x) + * ``` + * @param x The input tensor. 
+ */
+ /** @doc {heading: 'Operations', subheading: 'Basic math'} */
+ function acosh_(x) {
+ var $x = convertToTensor(x, 'x', 'acosh');
+ var grad = function (dy, saved) {
+ var $x = saved[0];
+ return {
+ $x: function () {
+ var a = $x.toFloat().square().sub(1).sqrt();
+ // tslint:disable-next-line: no-unnecessary-type-assertion
+ return dy.div(a);
+ }
+ };
+ };
+ return ENGINE.runKernelFunc(function (backend, save) {
+ var res = backend.acosh($x);
+ save([$x]);
+ return res;
+ }, { $x: $x }, grad);
+ }
+ /**
+ * Computes inverse hyperbolic tan of the input `tf.Tensor` element-wise:
+ * `atanh(x)`
+ *
+ * ```js
+ * const x = tf.tensor1d([0, .1, -.1, .7]);
+ *
+ * x.atanh().print(); // or tf.atanh(x)
+ * ```
+ * @param x The input tensor.
+ */
+ /** @doc {heading: 'Operations', subheading: 'Basic math'} */
+ function atanh_(x) {
+ var $x = convertToTensor(x, 'x', 'atanh');
+ var grad = function (dy, saved) {
+ var $x = saved[0];
+ return { $x: function () { return dy.div(scalar(1).sub($x.toFloat().square())); } };
+ };
+ return ENGINE.runKernelFunc(function (backend, save) {
+ var res = backend.atanh($x);
+ save([$x]);
+ return res;
+ }, { $x: $x }, grad);
+ }
+ /**
+ * Computes Gauss error function of the input `tf.Tensor` element-wise:
+ * `erf(x)`
+ *
+ * ```js
+ * const x = tf.tensor1d([0, .1, -.1, .7]);
+ *
+ * x.erf().print(); // or tf.erf(x);
+ * ```
+ * @param x The input tensor.
+ */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function erf_(x) { + var $x = convertToTensor(x, 'x', 'erf'); + assert($x.dtype === 'int32' || $x.dtype === 'float32', function () { return 'Input dtype must be `int32` or `float32`.'; }); + if ($x.dtype === 'int32') { + $x = $x.toFloat(); + } + var grad = function (dy, saved) { + var $x = saved[0]; + return { + $x: function () { return dy.mul($x.square().neg().exp().mul(2 / Math.sqrt(Math.PI))); } + }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.erf($x); + save([$x]); + return res; + }, { $x: $x }, grad); + } + /** + * Computes step of the input `tf.Tensor` element-wise: `x > 0 ? 1 : alpha * x` + * + * ```js + * const x = tf.tensor1d([0, 2, -1, -3]); + * + * x.step(.5).print(); // or tf.step(x, .5) + * ``` + * @param x The input tensor. + * @param alpha The gradient when input is negative. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function step_(x, alpha) { + if (alpha === void 0) { alpha = 0.0; } + var $x = convertToTensor(x, 'x', 'step'); + // TODO(manrajgrover): Return null for gradients when backprop supports + // it. 
+ var grad = function (dy) { + return { $x: function () { return zerosLike(dy); } }; + }; + return ENGINE.runKernelFunc(function (backend) { return backend.step($x, alpha); }, { $x: $x }, grad); + } + var abs = op({ abs_: abs_ }); + var acos = op({ acos_: acos_ }); + var acosh = op({ acosh_: acosh_ }); + var asin = op({ asin_: asin_ }); + var asinh = op({ asinh_: asinh_ }); + var atan = op({ atan_: atan_ }); + var atanh = op({ atanh_: atanh_ }); + var ceil = op({ ceil_: ceil_ }); + var clipByValue = op({ clipByValue_: clipByValue_ }); + var cos = op({ cos_: cos_ }); + var cosh = op({ cosh_: cosh_ }); + var erf = op({ erf_: erf_ }); + var exp = op({ exp_: exp_ }); + var expm1 = op({ expm1_: expm1_ }); + var floor = op({ floor_: floor_ }); + var log = op({ log_: log_ }); + var log1p = op({ log1p_: log1p_ }); + var logSigmoid = op({ logSigmoid_: logSigmoid_ }); + var neg = op({ neg_: neg_ }); + var reciprocal = op({ reciprocal_: reciprocal_ }); + var round = op({ round_: round_ }); + var rsqrt = op({ rsqrt_: rsqrt_ }); + var sigmoid = op({ sigmoid_: sigmoid_ }); + var sign = op({ sign_: sign_ }); + var isNaN$1 = op({ isNaN_: isNaN_ }); + var isInf = op({ isInf_: isInf_ }); + var isFinite$1 = op({ isFinite_: isFinite_ }); + var sin = op({ sin_: sin_ }); + var sinh = op({ sinh_: sinh_ }); + var softplus = op({ softplus_: softplus_ }); + var sqrt = op({ sqrt_: sqrt_ }); + var step = op({ step_: step_ }); + var tan = op({ tan_: tan_ }); + var tanh$1 = op({ tanh_: tanh_ }); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * @deprecated + * Adds two `tf.Tensor`s element-wise, A + B. + * + * Inputs must be the same shape. For broadcasting support, use add() instead. + * + * @param a The first Tensor to add element-wise. + * @param b The second Tensor to add element-wise. + */ + function addStrict_(a, b) { + deprecationWarn('strict variants of ops have been deprecated ' + + 'and will be removed in future'); + var $a = convertToTensor(a, 'a', 'addStrict'); + var $b = convertToTensor(b, 'b', 'addStrict'); + assertShapesMatch($a.shape, $b.shape, 'Error in addStrict: '); + return $a.add($b); + } + /** + * @deprecated + * Subtracts two `tf.Tensor`s element-wise, A - B. Inputs must + * be the same shape. + * + * For broadcasting support, use `tf.sub` instead. + * + * @param a The first Tensor to subtract element-wise. + * @param b The second Tensor to subtract element-wise. + */ + function subStrict_(a, b) { + deprecationWarn('strict variants of ops have been deprecated ' + + 'and will be removed in future'); + var $a = convertToTensor(a, 'a', 'subStrict'); + var $b = convertToTensor(b, 'b', 'subStrict'); + assertShapesMatch($a.shape, $b.shape, 'Error in subStrict: '); + return $a.sub($b); + } + /** + * Computes the power of one `tf.Tensor` to another. Supports broadcasting. + * + * Given a `tf.Tensor` x and a `tf.Tensor` y, this operation computes x^y for + * corresponding elements in x and y. The result's dtype will be the upcasted + * type of the `base` and `exp` dtypes. 
+ * + * ```js + * const a = tf.tensor([[2, 3], [4, 5]]) + * const b = tf.tensor([[1, 2], [3, 0]]).toInt(); + * + * a.pow(b).print(); // or tf.pow(a, b) + * ``` + * + * ```js + * const a = tf.tensor([[1, 2], [3, 4]]) + * const b = tf.tensor(2).toInt(); + * + * a.pow(b).print(); // or tf.pow(a, b) + * ``` + * We also expose `powStrict` which has the same signature as this op and + * asserts that `base` and `exp` are the same shape (does not broadcast). + * + * @param base The base `tf.Tensor` to pow element-wise. + * @param exp The exponent `tf.Tensor` to pow element-wise. + */ + /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */ + function pow_(base, exp) { + var _a; + var $base = convertToTensor(base, 'base', 'pow'); + var $exp = convertToTensor(exp, 'exp', 'pow'); + _a = makeTypesMatch($base, $exp), $base = _a[0], $exp = _a[1]; + var outShape = assertAndGetBroadcastShape($base.shape, $exp.shape); + var grad = function (dy, saved) { + var $base = saved[0], $exp = saved[1], y = saved[2]; + var derBase = function () { + var expFloat = $exp.toFloat(); + var res = dy.mul(expFloat.mul($base.pow(expFloat.sub(scalar(1))))); + var reduceAxes = getReductionAxes($base.shape, outShape); + if (reduceAxes.length > 0) { + res = res.sum(reduceAxes); + } + return res.reshape($base.shape); + }; + var derExp = function () { + var condition = $base.greater(0); + var logBase = $base.log().where(condition, zerosLike($base)); + var res = dy.mul(y.mul(logBase)); + var reduceAxes = getReductionAxes($exp.shape, outShape); + if (reduceAxes.length > 0) { + res = res.sum(reduceAxes); + } + return res.reshape($exp.shape); + }; + return { a: derBase, b: derExp }; + }; + var attrs = {}; + var inputsToSave = [$base, $exp]; + var outputsToSave = [true]; + return ENGINE.runKernelFunc(function (backend, save) { + var y = backend.pow($base, $exp); + save([$base, $exp, y]); + return y; + }, { a: $base, b: $exp }, grad, 'Pow', attrs, inputsToSave, outputsToSave); + } + /** + * @deprecated + 
* Computes the power of one `tf.Tensor` to another. Inputs must + * be the same shape. + * + * For broadcasting support, use `tf.pow` instead. + * + * @param base The base tensor to pow element-wise. + * @param exp The exponent tensor to pow element-wise. + */ + function powStrict_(base, exp) { + deprecationWarn('strict variants of ops have been deprecated ' + + 'and will be removed in future'); + assertShapesMatch(base.shape, exp.shape, 'Error in powStrict: '); + return base.pow(exp); + } + /** + * Multiplies two `tf.Tensor`s element-wise, A * B. Supports broadcasting. + * + * We also expose `tf.mulStrict` which has the same signature as this op and + * asserts that `a` and `b` are the same shape (does not broadcast). + * + * ```js + * const a = tf.tensor1d([1, 2, 3, 4]); + * const b = tf.tensor1d([2, 3, 4, 5]); + * + * a.mul(b).print(); // or tf.mul(a, b) + * ``` + * + * ```js + * // Broadcast mul a with b. + * const a = tf.tensor1d([1, 2, 3, 4]); + * const b = tf.scalar(5); + * + * a.mul(b).print(); // or tf.mul(a, b) + * ``` + * @param a The first tensor to multiply. + * @param b The second tensor to multiply. Must have the same dtype as `a`. 
+ */
+ /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
+ function mul_(a, b) {
+ var _a;
+ var $a = convertToTensor(a, 'a', 'mul');
+ var $b = convertToTensor(b, 'b', 'mul');
+ _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1];
+ var outShape = assertAndGetBroadcastShape($a.shape, $b.shape);
+ var der = function (dy, saved) {
+ var $a = saved[0], $b = saved[1];
+ var derA = function () {
+ var res = dy.mul($b.toFloat());
+ var reduceAxes = getReductionAxes($a.shape, outShape);
+ if (reduceAxes.length > 0) {
+ return res.sum(reduceAxes).reshape($a.shape);
+ }
+ return res;
+ };
+ var derB = function () {
+ var res = dy.mul($a.toFloat());
+ var reduceAxes = getReductionAxes($b.shape, outShape);
+ if (reduceAxes.length > 0) {
+ return res.sum(reduceAxes).reshape($b.shape);
+ }
+ return res;
+ };
+ return { a: derA, b: derB };
+ };
+ return ENGINE.runKernelFunc(function (backend, save) {
+ var res = backend.multiply($a, $b);
+ save([$a, $b]);
+ return res;
+ }, { a: $a, b: $b }, der, 'Mul');
+ }
+ /**
+ * @deprecated
+ * Multiplies two `tf.Tensor`s element-wise, A * B.
+ *
+ * Inputs must be the same shape. For broadcasting support, use `tf.mul`.
+ *
+ * @param a The first tensor to multiply.
+ * @param b The second tensor to multiply. Must have the same
+ * dtype as `a`.
+ */
+ function mulStrict_(a, b) {
+ deprecationWarn('strict variants of ops have been deprecated ' +
+ 'and will be removed in future');
+ var $a = convertToTensor(a, 'a', 'mul');
+ var $b = convertToTensor(b, 'b', 'mul');
+ assertShapesMatch($a.shape, $b.shape, 'Error in multiplyStrict: ');
+ return $a.mul($b);
+ }
+ /**
+ * Divides two `tf.Tensor`s element-wise, A / B. Supports broadcasting.
+ * The result is rounded with floor function.
+ *
+ *
+ * ```js
+ * const a = tf.tensor1d([1, 4, 9, 16]);
+ * const b = tf.tensor1d([1, 2, 3, 4]);
+ *
+ * a.floorDiv(b).print(); // or tf.floorDiv(a, b)
+ * ```
+ *
+ * ```js
+ * // Broadcast div a with b.
+ * const a = tf.tensor1d([2, 4, 6, 8]); + * const b = tf.scalar(2); + * + * a.floorDiv(b).print(); // or tf.floorDiv(a, b) + * ``` + * + * @param a The first tensor as the numerator. + * @param b The second tensor as the denominator. Must have the same dtype as + * `a`. + */ + /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */ + function floorDiv_(a, b) { + var _a; + var $a = convertToTensor(a, 'a', 'floorDiv'); + var $b = convertToTensor(b, 'b', 'floorDiv'); + _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1]; + var outShape = assertAndGetBroadcastShape($a.shape, $b.shape); + var der = function (dy, saved) { + var $a = saved[0], $b = saved[1]; + var derA = function () { + var res = dy.div($b.toFloat()); + var reduceAxes = getReductionAxes($a.shape, outShape); + if (reduceAxes.length > 0) { + return res.sum(reduceAxes).reshape($a.shape); + } + return res; + }; + var derB = function () { + var res = dy.mul($a.toFloat()); + var reduceAxes = getReductionAxes($b.shape, outShape); + if (reduceAxes.length > 0) { + res = res.sum(reduceAxes).reshape($b.shape); + } + var tmp = $b.square(); + return res.div(tmp.toFloat()).neg(); + }; + return { a: derA, b: derB }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.floorDiv($a, $b); + save([$a, $b]); + return res; + }, { a: $a, b: $b }, der, 'FloorDiv'); + } + /** + * @deprecated + * Divides two `tf.Tensor`s element-wise, A / B. Inputs must + * be the same shape. + * + * @param a The first tensor as the numerator for element-wise division. + * @param b The second tensor as the denominator for element-wise division. + */ + function divStrict_(a, b) { + deprecationWarn('strict variants of ops have been deprecated ' + + 'and will be removed in future'); + var $a = convertToTensor(a, 'a', 'div'); + var $b = convertToTensor(b, 'b', 'div'); + assertShapesMatch($a.shape, $b.shape, 'Error in divideStrict: '); + return $a.div($b); + } + /** + * Returns the mod of a and b element-wise. 
+ * `floor(x / y) * y + mod(x, y) = x` + * Supports broadcasting. + * + * We also expose `tf.modStrict` which has the same signature as this op and + * asserts that `a` and `b` are the same shape (does not broadcast). + * + * ```js + * const a = tf.tensor1d([1, 4, 3, 16]); + * const b = tf.tensor1d([1, 2, 9, 4]); + * + * a.mod(b).print(); // or tf.mod(a, b) + * ``` + * + * ```js + * // Broadcast a mod b. + * const a = tf.tensor1d([2, 4, 6, 8]); + * const b = tf.scalar(5); + * + * a.mod(b).print(); // or tf.mod(a, b) + * ``` + * + * @param a The first tensor. + * @param b The second tensor. Must have the same type as `a`. + */ + /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */ + function mod_(a, b) { + var _a; + var $a = convertToTensor(a, 'a', 'mod'); + var $b = convertToTensor(b, 'b', 'mod'); + _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1]; + var outShape = assertAndGetBroadcastShape($a.shape, $b.shape); + var der = function (dy, saved) { + var $a = saved[0], $b = saved[1]; + var derA = function () { + var reduceAxes = getReductionAxes($a.shape, outShape); + if (reduceAxes.length > 0) { + return dy.sum(reduceAxes).reshape($a.shape); + } + return dy; + }; + var derB = function () { + var res = dy.mul($a.div($b).floor().neg()); + var reduceAxes = getReductionAxes($b.shape, outShape); + if (reduceAxes.length > 0) { + return res.sum(reduceAxes).reshape($b.shape); + } + return res; + }; + return { $a: derA, $b: derB }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.mod($a, $b); + save([$a, $b]); + return res; + }, { $a: $a, $b: $b }, der); + } + /** + * @deprecated + * Returns the mod of a and b (`a < b ? a : b`) element-wise. Inputs must + * be the same shape. For broadcasting support, use mod(). + * + * @param a The first tensor. + * @param b The second tensor. Must have the same dtype as `a`. 
+ */ + function modStrict_(a, b) { + deprecationWarn('strict variants of ops have been deprecated ' + + 'and will be removed in future'); + var $a = convertToTensor(a, 'a', 'modStrict'); + var $b = convertToTensor(b, 'b', 'modStrict'); + assertShapesMatch($a.shape, $b.shape, 'Error in modStrict: '); + return $a.mod($b); + } + /** + * Returns the min of a and b (`a < b ? a : b`) element-wise. + * Supports broadcasting. + * + * We also expose `minimumStrict` which has the same signature as this op and + * asserts that `a` and `b` are the same shape (does not broadcast). + * + * ```js + * const a = tf.tensor1d([1, 4, 3, 16]); + * const b = tf.tensor1d([1, 2, 9, 4]); + * + * a.minimum(b).print(); // or tf.minimum(a, b) + * ``` + * + * ```js + * // Broadcast minimum a with b. + * const a = tf.tensor1d([2, 4, 6, 8]); + * const b = tf.scalar(5); + * + * a.minimum(b).print(); // or tf.minimum(a, b) + * ``` + * + * @param a The first tensor. + * @param b The second tensor. Must have the same type as `a`. + */ + /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */ + function minimum_(a, b) { + var _a; + var $a = convertToTensor(a, 'a', 'minimum'); + var $b = convertToTensor(b, 'b', 'minimum'); + _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1]; + if ($a.dtype === 'bool') { + $a = $a.toInt(); + $b = $b.toInt(); + } + assertAndGetBroadcastShape($a.shape, $b.shape); + var der = function (dy, saved) { + var $a = saved[0], $b = saved[1]; + var derA = function () { return dy.mul($a.lessEqual($b).toFloat()); }; + var derB = function () { return dy.mul($a.greater($b).toFloat()); }; + return { a: derA, b: derB }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.minimum($a, $b); + save([$a, $b]); + return res; + }, { a: $a, b: $b }, der, 'Minimum'); + } + /** + * @deprecated + * Returns the min of a and b (`a < b ? a : b`) element-wise. Inputs must + * be the same shape. For broadcasting support, use minimum(). 
+ * + * @param a The first tensor. + * @param b The second tensor. Must have the same dtype as `a`. + */ + function minimumStrict_(a, b) { + deprecationWarn('strict variants of ops have been deprecated ' + + 'and will be removed in future'); + var $a = convertToTensor(a, 'a', 'minimumStrict'); + var $b = convertToTensor(b, 'b', 'minimumStrict'); + assertShapesMatch($a.shape, $b.shape, 'Error in minimumStrict: '); + return $a.minimum($b); + } + /** + * Returns the max of a and b (`a > b ? a : b`) element-wise. + * Supports broadcasting. + * + * We also expose `tf.maximumStrict` which has the same signature as this op and + * asserts that `a` and `b` are the same shape (does not broadcast). + * + * ```js + * const a = tf.tensor1d([1, 4, 3, 16]); + * const b = tf.tensor1d([1, 2, 9, 4]); + * + * a.maximum(b).print(); // or tf.maximum(a, b) + * ``` + * + * ```js + * // Broadcast maximum a with b. + * const a = tf.tensor1d([2, 4, 6, 8]); + * const b = tf.scalar(5); + * + * a.maximum(b).print(); // or tf.maximum(a, b) + * ``` + * + * @param a The first tensor. + * @param b The second tensor. Must have the same type as `a`. + */ + /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */ + function maximum_(a, b) { + var _a; + var $a = convertToTensor(a, 'a', 'maximum'); + var $b = convertToTensor(b, 'b', 'maximum'); + _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1]; + if ($a.dtype === 'bool') { + $a = $a.toInt(); + $b = $b.toInt(); + } + assertAndGetBroadcastShape($a.shape, $b.shape); + var der = function (dy, saved) { + var $a = saved[0], $b = saved[1]; + var derA = function () { return dy.mul($a.greaterEqual($b).toFloat()); }; + var derB = function () { return dy.mul($a.less($b).toFloat()); }; + return { a: derA, b: derB }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.maximum($a, $b); + save([$a, $b]); + return res; + }, { a: $a, b: $b }, der, 'Maximum'); + } + /** + * @deprecated + * Returns the max of a and b (`a > b ? 
a : b`) element-wise. Inputs must + * be the same shape. For broadcasting support, use maximum(). + * + * @param a The first tensor. + * @param b The second tensor. Must have the same dtype as `a`. + */ + function maximumStrict_(a, b) { + deprecationWarn('strict variants of ops have been deprecated ' + + 'and will be removed in future'); + var $a = convertToTensor(a, 'a', 'maximumStrict'); + var $b = convertToTensor(b, 'b', 'maximumStrict'); + assertShapesMatch($a.shape, $b.shape, 'Error in maximumStrict: '); + return $a.maximum($b); + } + /** + * @deprecated + * Returns (a - b) * (a - b) element-wise. + * + * Inputs must be the same shape. For broadcasting support, use + * `tf.squaredDifference` instead. + * + * @param a The first tensor. + * @param b The second tensor. Must have the same type as `a`. + */ + function squaredDifferenceStrict_(a, b) { + deprecationWarn('strict variants of ops have been deprecated ' + + 'and will be removed in future'); + var $a = convertToTensor(a, 'a', 'squaredDifferenceStrict'); + var $b = convertToTensor(b, 'b', 'squaredDifferenceStrict'); + assertShapesMatch($a.shape, $b.shape, 'Error in squaredDifferenceStrict: '); + return $a.squaredDifference($b); + } + /** + * Computes arctangent of `tf.Tensor`s a / b element-wise: `atan2(a, b)`. + * Supports broadcasting. + * + * ```js + * const a = tf.tensor1d([1.0, 1.0, -1.0, .7]); + * const b = tf.tensor1d([2.0, 13.0, 3.5, .21]); + * + * tf.atan2(a, b).print() + * ``` + * + * @param a The first tensor. + * @param b The second tensor. Must have the same dtype as `a`. 
+ * + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function atan2_(a, b) { + var _a; + var $a = convertToTensor(a, 'a', 'atan2'); + var $b = convertToTensor(b, 'b', 'atan2'); + _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1]; + var outShape = assertAndGetBroadcastShape($a.shape, $b.shape); + var der = function (dy, saved) { + var $a = saved[0], $b = saved[1]; + var derA = function () { + var d = add($a.square(), $b.square()); + var res = dy.mul($b.div(d)); + var reduceAxes = getReductionAxes($a.shape, outShape); + if (reduceAxes.length > 0) { + res = res.sum(reduceAxes); + } + return res.reshape($a.shape); + }; + var derB = function () { + var d = add($a.square(), $b.square()); + var res = neg(dy.mul($a.div(d))); + var reduceAxes = getReductionAxes($b.shape, outShape); + if (reduceAxes.length > 0) { + res = res.sum(reduceAxes); + } + return res.reshape($b.shape); + }; + return { $a: derA, $b: derB }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.atan2($a, $b); + save([$a, $b]); + return res; + }, { $a: $a, $b: $b }, der); + } + var addStrict = op({ addStrict_: addStrict_ }); + var atan2 = op({ atan2_: atan2_ }); + var divStrict = op({ divStrict_: divStrict_ }); + var floorDiv = op({ floorDiv_: floorDiv_ }); + var maximum = op({ maximum_: maximum_ }); + var maximumStrict = op({ maximumStrict_: maximumStrict_ }); + var minimum = op({ minimum_: minimum_ }); + var minimumStrict = op({ minimumStrict_: minimumStrict_ }); + var mod = op({ mod_: mod_ }); + var modStrict = op({ modStrict_: modStrict_ }); + var mul = op({ mul_: mul_ }); + var mulStrict = op({ mulStrict_: mulStrict_ }); + var pow = op({ pow_: pow_ }); + var powStrict = op({ powStrict_: powStrict_ }); + var squaredDifferenceStrict = op({ squaredDifferenceStrict_: squaredDifferenceStrict_ }); + var subStrict = op({ subStrict_: subStrict_ }); + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. 
+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Divides two `tf.Tensor`s element-wise, A / B. Supports broadcasting. + * + * ```js + * const a = tf.tensor1d([1, 4, 9, 16]); + * const b = tf.tensor1d([1, 2, 3, 4]); + * + * a.div(b).print(); // or tf.div(a, b) + * ``` + * + * ```js + * // Broadcast div a with b. + * const a = tf.tensor1d([2, 4, 6, 8]); + * const b = tf.scalar(2); + * + * a.div(b).print(); // or tf.div(a, b) + * ``` + * + * @param a The first tensor as the numerator. + * @param b The second tensor as the denominator. Must have the same dtype as + * `a`. + */ + /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */ + function div_(a, b) { + var _a; + var $a = convertToTensor(a, 'a', 'div'); + var $b = convertToTensor(b, 'b', 'div'); + _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1]; + if ($a.dtype === 'int32' && $b.dtype === 'int32') { + return floorDiv($a, $b); + } + var forward = function (backend, save) { + var res = backend.realDivide($a, $b); + save([$a, $b]); + return res; + }; + var inputs = { a: $a, b: $b }; + var attrs = {}; + return ENGINE.runKernelFunc(forward, inputs, null /* gradient */, Div, attrs); + } + var div = op({ div_: div_ }); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. 
+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Computes square of `x` element-wise: `x ^ 2` + * + * ```js + * const x = tf.tensor1d([1, 2, Math.sqrt(2), -1]); + * + * x.square().print(); // or tf.square(x) + * ``` + * @param x The input Tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function square_(x) { + var $x = convertToTensor(x, 'x', 'square'); + var attrs = {}; + var inputsToSave = [$x]; + var outputsToSave = []; + return ENGINE.runKernelFunc(function (backend, save) { + save([$x]); + return backend.square($x); + }, { x: $x }, null /* grad */, 'Square', attrs, inputsToSave, outputsToSave); + } + var square = op({ square_: square_ }); + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + var divGradConfig = { + kernelName: Div, + inputsToSave: ['a', 'b'], + gradFunc: function (dy, saved) { + var a = saved[0], b = saved[1]; + var outShape = assertAndGetBroadcastShape(a.shape, b.shape); + var derA = function () { + var res = div(dy, b.toFloat()); + var reduceAxes = getReductionAxes(a.shape, outShape); + if (reduceAxes.length > 0) { + return sum$1(res, reduceAxes).reshape(a.shape); + } + return res; + }; + var derB = function () { + var res = mul(dy, a.toFloat()); + var reduceAxes = getReductionAxes(b.shape, outShape); + if (reduceAxes.length > 0) { + res = reshape(sum$1(res, reduceAxes), b.shape); + } + var tmp = square(b); + return neg(div(res, tmp.toFloat())); + }; + return { a: derA, b: derB }; + } + }; + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Subtracts two `tf.Tensor`s element-wise, A - B. Supports broadcasting. + * + * ```js + * const a = tf.tensor1d([10, 20, 30, 40]); + * const b = tf.tensor1d([1, 2, 3, 4]); + * + * a.sub(b).print(); // or tf.sub(a, b) + * ``` + * + * ```js + * // Broadcast subtract a with b. 
+ * const a = tf.tensor1d([10, 20, 30, 40]); + * const b = tf.scalar(5); + * + * a.sub(b).print(); // or tf.sub(a, b) + * ``` + * @param a The first `tf.Tensor` to subtract from. + * @param b The second `tf.Tensor` to be subtracted. Must have the same dtype as + * `a`. + */ + /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */ + function sub_(a, b) { + var _a; + var $a = convertToTensor(a, 'a', 'sub'); + var $b = convertToTensor(b, 'b', 'sub'); + _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1]; + var forward = function (backend, save) { + var res = backend.subtract($a, $b); + save([$a, $b]); + return res; + }; + var inputs = { a: $a, b: $b }; + return ENGINE.runKernelFunc(forward, inputs, null /* grad */, Sub); + } + var sub = op({ sub_: sub_ }); + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Construct a tensor by repeating it the number of times given by reps. + * + * This operation creates a new tensor by replicating `input` `reps` + * times. The output tensor's i'th dimension has `input.shape[i] * + * reps[i]` elements, and the values of `input` are replicated + * `reps[i]` times along the i'th dimension. For example, tiling + * `[a, b, c, d]` by `[2]` produces `[a, b, c, d, a, b, c, d]`. 
+ * + * ```js + * const a = tf.tensor1d([1, 2]); + * + * a.tile([2]).print(); // or a.tile([2]) + * ``` + * + * ```js + * const a = tf.tensor2d([1, 2, 3, 4], [2, 2]); + * + * a.tile([1, 2]).print(); // or a.tile([1, 2]) + * ``` + * @param x The tensor to tile. + * @param reps Determines the number of replications per dimension. + */ + /** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */ + function tile_(x, reps) { + var parseAs = null; + var $x = convertToTensor(x, 'x', 'tile', parseAs); + assert($x.rank === reps.length, function () { return "Error in transpose: rank of input " + $x.rank + " " + + ("must match length of reps " + reps + "."); }); + var forward = function (backend, save) { + var res = backend.tile($x, reps); + save([$x]); + return res; + }; + var inputsToSave = [$x]; + var inputs = { x: $x }; + var attrs = { reps: reps }; + return ENGINE.runKernelFunc(forward, inputs, null /* grad */, Tile, attrs, inputsToSave); + } + var tile = op({ tile_: tile_ }); + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + var fusedBatchNormGradConfig = { + kernelName: FusedBatchNorm, + inputsToSave: ['x', 'mean', 'variance', 'scale'], + gradFunc: function (dy, saved, attrs) { + var varianceEpsilon = attrs.varianceEpsilon; + var x = saved[0], mean = saved[1], variance = saved[2], scale = saved[3]; + var scaleValue = scale == null ? scalar(1) : scale; + var reductionAxes = getReductionAxes(mean.shape, x.shape); + var tileShape = []; + if (mean.rank === 1) { + for (var i = 0; i < x.shape.length - 1; ++i) { + tileShape.push(x.shape[i]); + } + tileShape.push(1); + } + var xMinusMean = sub(x, mean); + var dyTimesScaleValue = mul(dy, scaleValue); + var oneOverSqrtVariance = rsqrt(add(variance, scalar(varianceEpsilon))); + var minusHalfRCube = mul(mul(mul(oneOverSqrtVariance, oneOverSqrtVariance), oneOverSqrtVariance), scalar(-0.5)); + var derX = function () { + if (mean.rank === 1) { + return reshape(mul(mul(dy, tile(oneOverSqrtVariance.as4D(1, 1, 1, mean.shape[0]), tileShape)), scaleValue), x.shape); + } + else { + return reshape(mul(mul(dy, oneOverSqrtVariance), scaleValue), x.shape); + } + }; + var derMean = function () { + var meanDer = mul(mul(oneOverSqrtVariance, scalar(-1)), dyTimesScaleValue); + if (mean.rank === 1) { + meanDer = sum$1(meanDer, reductionAxes); + } + return reshape(meanDer, mean.shape); + }; + var derVariance = function () { + var varianceDer = mul(mul(minusHalfRCube, xMinusMean), dyTimesScaleValue); + if (mean.rank === 1) { + varianceDer = sum$1(varianceDer, reductionAxes); + } + return reshape(varianceDer, mean.shape); + }; + var derScale = function () { + var xMinusMean2TimesRsqrt = mul(xMinusMean, oneOverSqrtVariance); + var scaleDer = mul(dy, xMinusMean2TimesRsqrt); + if (mean.rank === 1) { + scaleDer = sum$1(scaleDer, reductionAxes); + } + return reshape(scaleDer, mean.shape); + }; + var derOffset = function () { + var offsetDer = dy; + if (mean.rank === 1) { + offsetDer = 
sum$1(offsetDer, reductionAxes); + } + return reshape(offsetDer, mean.shape); + }; + return { + x: derX, + mean: derMean, + variance: derVariance, + scale: derScale, + offset: derOffset + }; + } + }; + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var greaterEqualGradConfig = { + kernelName: GreaterEqual, + inputsToSave: ['a', 'b'], + gradFunc: function (dy, saved) { + var a = saved[0], b = saved[1]; + return { a: function () { return zerosLike(a); }, b: function () { return zerosLike(b); } }; + } + }; + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + var identityGradConfig = { + kernelName: Identity, + gradFunc: function (dy) { + return { x: function () { return dy.toFloat(); } }; + } + }; + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function localResponseNormalizationBackprop_(x, y, dy, depthRadius, bias, alpha, beta) { + if (depthRadius === void 0) { depthRadius = 5; } + if (bias === void 0) { bias = 1; } + if (alpha === void 0) { alpha = 1; } + if (beta === void 0) { beta = 0.5; } + var forward = function (backend) { + return backend.LRNGrad(dy, x, y, depthRadius, bias, alpha, beta); + }; + var inputs = { x: x, y: y, dy: dy }; + var attrs = { depthRadius: depthRadius, bias: bias, alpha: alpha, beta: beta }; + return ENGINE.runKernelFunc(forward, inputs, null /* grad */, LRNBackprop, attrs); + } + var localResponseNormalizationBackprop = op({ localResponseNormalizationBackprop_: localResponseNormalizationBackprop_ }); + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var lrnGradConfig = { + kernelName: LRN, + inputsToSave: ['x'], + outputsToSave: [true], + gradFunc: function (dy, saved, attrs) { + var _a = saved, x = _a[0], y = _a[1]; + var _b = attrs, depthRadius = _b.depthRadius, bias = _b.bias, alpha = _b.alpha, beta = _b.beta; + return { + x: function () { return localResponseNormalizationBackprop(x, y, dy, depthRadius, bias, alpha, beta); } + }; + } + }; + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + var maxGradConfig = { + kernelName: Max, + inputsToSave: ['x'], + outputsToSave: [true], + gradFunc: function (dy, saved, attrs) { + var maxAttrs = attrs; + var reductionIndices = maxAttrs.reductionIndices; + var x = saved[0], y = saved[1]; + var origAxes = parseAxisParam(reductionIndices, x.shape); + var permutedAxes = getAxesPermutation(origAxes, x.rank); + var maxGrad = gradForMinAndMax(dy, y, x, origAxes, permutedAxes); + return { + x: function () { + var out = maxGrad['x'](); + if (permutedAxes != null) { + out = transpose(out); + } + return out; + } + }; + } + }; + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Computes the backprop of a 3d max pool. + * + * @param dy The dy error, of rank 5 of shape + * [batchSize, depth, height, width, channels]. + * assumed. + * @param input The original input image, of rank 5 or rank 4 of shape + * [batchSize, depth, height, width, channels]. + * @param output The original output image, of rank 5 of shape + * [batchSize, outDepth, outHeight, outWidth, channels]. + * @param filterSize The filter size: + * `[filterDepth, filterHeight, filterWidth]`. + * `filterSize` is a single number, + * then `filterDepth == filterHeight == filterWidth`. 
+ * @param strides The strides of the pooling: + * `[strideDepth, strideHeight, strideWidth]`. If + * `strides` is a single number, then `strideHeight == strideWidth`. + * @param dilations Deprecated, this field will be gone in v3.0.0. + * The dilation rates: `[dilationDepth, dilationHeight, dilationWidth]` + * in which we sample input values across the depth, height and width + * dimensions in dilated pooling. + * Defaults to `[1, 1, 1]`. If `dilations` is a single number, + * then `dilationDepth == dilationHeight == dilationWidth`. + * If it is greater than 1, then all values of `strides` must be 1. + * @param pad A string from: 'same', 'valid'. The type of padding algorithm + * used in the forward prop of the op. + * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. The + * rounding mode used when computing output dimensions if pad is a + * number. If none is provided, it will not round and error if the output + * is of fractional size. + */ + function maxPool3dBackprop_(dy, input, output, filterSize, strides, dilations, pad, dimRoundingMode) { + if (dilations === void 0) { dilations = [1, 1, 1]; } + var $dy = convertToTensor(dy, 'dy', 'maxPool3dBackprop'); + var $input = convertToTensor(input, 'input', 'maxPool3dBackprop'); + var $output = convertToTensor(output, 'output', 'maxPool3dBackprop'); + var dy5D = $dy; + var input5D = $input; + var output5D = $output; + var reshapedTo5D = false; + if ($input.rank === 4) { + reshapedTo5D = true; + dy5D = reshape($dy, [1, $dy.shape[0], $dy.shape[1], $dy.shape[2], $dy.shape[3]]); + input5D = reshape($input, [ + 1, $input.shape[0], $input.shape[1], $input.shape[2], $input.shape[3] + ]); + output5D = reshape($output, [ + 1, $output.shape[0], $output.shape[1], $output.shape[2], $output.shape[3] + ]); + } + assert(dy5D.rank === 5, function () { return "Error in maxPool3dBackprop: dy must be rank 5 but got rank " + + (dy5D.rank + "."); }); + assert(input5D.rank === 5, function () { return "Error in 
maxPool3dBackprop: input must be rank 5 but got rank " + + (input5D.rank + "."); }); + assert(output5D.rank === 5, function () { return "Error in maxPool3dBackprop: output must be rank 5 but got rank " + + (output5D.rank + "."); }); + assert(eitherStridesOrDilationsAreOne(strides, dilations), function () { return 'Error in maxPool3dBackprop: Either strides or dilations ' + + ("must be 1. Got strides " + strides + " and dilations '" + dilations + "'"); }); + if (dimRoundingMode != null) { + assert(isInt(pad), function () { return "Error in maxPool3dBackprop: pad must be an integer when " + + ("using, dimRoundingMode " + dimRoundingMode + " but got pad " + pad + "."); }); + } + var forward = function (backend) { + var convInfo = computePool3DInfo(input5D.shape, filterSize, strides, dilations, pad, dimRoundingMode); + return backend.maxPool3dBackprop(dy5D, input5D, output5D, convInfo); + }; + var inputs = { dy: dy5D, input: input5D, output: output5D }; + var attrs = { filterSize: filterSize, strides: strides, dilations: dilations, pad: pad, dimRoundingMode: dimRoundingMode }; + var res = ENGINE.runKernelFunc(forward, inputs, null /* grad */, MaxPool3DBackprop, attrs); + if (reshapedTo5D) { + return reshape(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]); + } + return res; + } + var maxPool3dBackprop = op({ maxPool3dBackprop_: maxPool3dBackprop_ }); + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var maxPool3DGradConfig = { + kernelName: MaxPool3D, + inputsToSave: ['x'], + outputsToSave: [true], + gradFunc: function (dy, saved, attrs) { + var _a = saved, x = _a[0], y = _a[1]; + var _b = attrs, filterSize = _b.filterSize, strides = _b.strides, dilations = _b.dilations, pad = _b.pad, dimRoundingMode = _b.dimRoundingMode; + var $dilations = dilations == null ? [1, 1, 1] : dilations; + return { + x: function () { return maxPool3dBackprop(dy, x, y, filterSize, strides, $dilations, pad, dimRoundingMode); } + }; + } + }; + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Computes the backprop of a 2D max pool. + * + * @param dy The dy error, of rank 4 or rank 3 of shape + * [batchSize, height, width, channels]. If rank 3, batch of 1 is + * assumed. + * @param input The original input image, of rank 4, of shape + * [batchSize, height, width, channels]. + * @param output The original output image, of rank 4, of shape + * [batchSize, outHeight, outWidth, channels]. + * @param filterSize The filter size: `[filterHeight, filterWidth]`. 
If + * `filterSize` is a single number, then `filterHeight == filterWidth`. + * @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If + * `strides` is a single number, then `strideHeight == strideWidth`. + * @param pad A string from: 'same', 'valid'. The type of padding algorithm + * used in the forward prop of the op. + * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. The + * rounding mode used when computing output dimensions if pad is a + * number. If none is provided, it will not round and error if the output + * is of fractional size. + */ + function maxPoolBackprop_(dy, input, output, filterSize, strides, pad, dimRoundingMode) { + var $dy = convertToTensor(dy, 'dy', 'maxPoolBackprop'); + var $input = convertToTensor(input, 'input', 'maxPoolBackprop'); + var $output = convertToTensor(output, 'output', 'maxPoolBackprop'); + assert($input.rank === $dy.rank, function () { return "Rank of input (" + $input.rank + ") does not match rank of dy " + + ("(" + $dy.rank + ")"); }); + assert($dy.rank === 4, function () { return "Error in maxPoolBackprop: dy must be rank 4 but got rank " + + ($dy.rank + "."); }); + assert($input.rank === 4, function () { return "Error in maxPoolBackprop: input must be rank 4 but got rank " + + ($input.rank + "."); }); + if (dimRoundingMode != null) { + assert(isInt(pad), function () { return "Error in maxPoolBackprop: pad must be an integer when using, " + + ("dimRoundingMode " + dimRoundingMode + " but got pad " + pad + "."); }); + } + var forward = function (backend) { + var convInfo = computePool2DInfo($input.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode); + return backend.maxPoolBackprop($dy, $input, $output, convInfo); + }; + var inputs = { dy: $dy, input: $input, output: $output }; + var attrs = { filterSize: filterSize, strides: strides, pad: pad, dimRoundingMode: dimRoundingMode }; + return ENGINE.runKernelFunc(forward, inputs, null, MaxPoolBackprop, attrs); + } + var 
maxPoolBackprop = op({ maxPoolBackprop_: maxPoolBackprop_ }); + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var maxPoolGradConfig = { + kernelName: MaxPool, + inputsToSave: ['x'], + outputsToSave: [true], + gradFunc: function (dy, saved, attrs) { + var _a = saved, x = _a[0], y = _a[1]; + var _b = attrs, filterSize = _b.filterSize, strides = _b.strides, pad = _b.pad; + return { + x: function () { return maxPoolBackprop(dy, x, y, filterSize, strides, pad); } + }; + } + }; + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + var oneHotGradConfig = { + kernelName: OneHot, + inputsToSave: ['indices'], + gradFunc: function (dy, saved) { + var indices = saved[0]; + return { indices: function () { return zeros(indices.shape, 'float32'); } }; + } + }; + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var padV2GradConfig = { + kernelName: PadV2, + inputsToSave: ['x'], + gradFunc: function (dy, saved, attrs) { + // Pad introduces values around the original tensor, so the gradient + // slices the original shape out of the gradient. + var x = saved[0]; + var paddings = attrs.paddings; + var begin = paddings.map(function (p) { return p[0]; }); + return { x: function () { return dy.slice(begin, x.shape); } }; + } + }; + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * This operation reshapes the "batch" dimension 0 into `M + 1` dimensions of + * shape `blockShape + [batch]`, interleaves these blocks back into the grid + * defined by the spatial dimensions `[1, ..., M]`, to obtain a result with + * the same rank as the input. The spatial dimensions of this intermediate + * result are then optionally cropped according to `crops` to produce the + * output. This is the reverse of `tf.spaceToBatchND`. See below for a precise + * description. + * + * ```js + * const x = tf.tensor4d([1, 2, 3, 4], [4, 1, 1, 1]); + * const blockShape = [2, 2]; + * const crops = [[0, 0], [0, 0]]; + * + * x.batchToSpaceND(blockShape, crops).print(); + * ``` + * + * @param x A `tf.Tensor`. N-D with `x.shape` = `[batch] + spatialShape + + * remainingShape`, where spatialShape has `M` dimensions. + * @param blockShape A 1-D array. Must have shape `[M]`, all values must + * be >= 1. + * @param crops A 2-D array. Must have shape `[M, 2]`, all values must be >= 0. + * `crops[i] = [cropStart, cropEnd]` specifies the amount to crop from input + * dimension `i + 1`, which corresponds to spatial dimension `i`. It is required + * that `cropStart[i] + cropEnd[i] <= blockShape[i] * inputShape[i + 1]` + * + * This operation is equivalent to the following steps: + * + * 1. Reshape `x` to `reshaped` of shape: `[blockShape[0], ..., + * blockShape[M-1], batch / prod(blockShape), x.shape[1], ..., + * x.shape[N-1]]` + * + * 2. Permute dimensions of `reshaped`to produce `permuted` of shape `[batch / + * prod(blockShape),x.shape[1], blockShape[0], ..., x.shape[M], + * blockShape[M-1],x.shape[M+1], ..., x.shape[N-1]]` + * + * 3. 
Reshape `permuted` to produce `reshapedPermuted` of shape `[batch / + * prod(blockShape),x.shape[1] * blockShape[0], ..., x.shape[M] * + * blockShape[M-1],x.shape[M+1], ..., x.shape[N-1]]` + * + * 4. Crop the start and end of dimensions `[1, ..., M]` of `reshapedPermuted` + * according to `crops` to produce the output of shape: `[batch / + * prod(blockShape),x.shape[1] * blockShape[0] - crops[0,0] - crops[0,1], + * ..., x.shape[M] * blockShape[M-1] - crops[M-1,0] - + * crops[M-1,1],x.shape[M+1], ..., x.shape[N-1]]` + */ + /** @doc {heading: 'Tensors', subheading: 'Transformations'} */ + function batchToSpaceND_(x, blockShape, crops) { + var $x = convertToTensor(x, 'x', 'batchToSpaceND'); + var prod = blockShape.reduce(function (a, b) { return a * b; }); + assert($x.rank >= 1 + blockShape.length, function () { return "input rank is " + $x.rank + " but should be > than blockShape.length " + blockShape.length; }); + assert(crops.length === blockShape.length, function () { return "crops.length is " + crops.length + " but should be equal to blockShape.length " + blockShape.length; }); + assert($x.shape[0] % prod === 0, function () { return "input tensor batch is " + $x.shape[0] + " but is not divisible by the product of " + + ("the elements of blockShape " + blockShape.join(' * ') + " === " + prod); }); + var forward = function (backend) { + return backend.batchToSpaceND($x, blockShape, crops); + }; + var inputs = { x: $x }; + var attrs = { blockShape: blockShape, crops: crops }; + return ENGINE.runKernelFunc(forward, inputs, null /* gradient */, BatchToSpaceND, attrs); + } + var batchToSpaceND = op({ batchToSpaceND_: batchToSpaceND_ }); + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var spaceToBatchNDGradConfig = { + kernelName: SpaceToBatchND, + gradFunc: function (dy, saved, attrs) { + var _a = attrs, blockShape = _a.blockShape, paddings = _a.paddings; + return { x: function () { return batchToSpaceND(dy, blockShape, paddings); } }; + } + }; + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var splitVGradConfig = { + kernelName: SplitV, + gradFunc: function (dy, saved, attrs) { + var axis = attrs.axis; + return { x: function () { return concat(dy, axis); } }; + } + }; + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var squareGradConfig = { + kernelName: Square, + inputsToSave: ['x'], + gradFunc: function (dy, saved) { + var x = saved[0]; + return { x: function () { return mul(dy, mul(x.toFloat(), 2)); } }; + } + }; + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var squaredDifferenceGradConfig = { + kernelName: SquaredDifference, + inputsToSave: ['a', 'b'], + gradFunc: function (dy, saved) { + var a = saved[0], b = saved[1]; + var two = scalar(2); + var derA = function () { return mul(dy, mul(two, sub(a, b))); }; + var derB = function () { return mul(dy, mul(two, sub(b, a))); }; + return { a: derA, b: derB }; + } + }; + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. 
+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var subGradConfig = { + kernelName: Sub, + inputsToSave: ['a', 'b'], + gradFunc: function (dy, saved) { + var a = saved[0], b = saved[1]; + var outShape = assertAndGetBroadcastShape(a.shape, b.shape); + var derA = function () { + var res = dy; + var reduceAxes = getReductionAxes(a.shape, outShape); + if (reduceAxes.length > 0) { + res = sum$1(res, reduceAxes); + } + return reshape(res, a.shape); + }; + var derB = function () { + var res = dy; + var reduceAxes = getReductionAxes(b.shape, outShape); + if (reduceAxes.length > 0) { + res = sum$1(res, reduceAxes); + } + return reshape(neg(res), b.shape); + }; + return { a: derA, b: derB }; + } + }; + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + /** + * Pads a `tf.Tensor` with a given value and paddings. + * + * This operation currently only implements the `CONSTANT` mode. + * + * Also available are stricter rank-specific methods with the same signature + * as this method that assert that `paddings` is of given length. + * - `tf.pad1d` + * - `tf.pad2d` + * - `tf.pad3d` + * - `tf.pad4d` + * + * ```js + * const x = tf.tensor1d([1, 2, 3, 4]); + * x.pad([[1, 2]]).print(); + * ``` + * @param x The tensor to pad. + * @param paddings An array of length `R` (the rank of the tensor), where + * each element is a length-2 tuple of ints `[padBefore, padAfter]`, + * specifying how much to pad along each dimension of the tensor. + * @param constantValue The pad value to use. Defaults to 0. + */ + /** @doc {heading: 'Tensors', subheading: 'Transformations'} */ + function pad_(x, paddings, constantValue) { + if (constantValue === void 0) { constantValue = 0; } + var $x = convertToTensor(x, 'x', 'pad'); + if ($x.rank === 0) { + throw new Error('pad(scalar) is not defined. Pass non-scalar to pad'); + } + var forward = function (backend, save) { + save([$x]); + return backend.pad($x, paddings, constantValue); + }; + var attrs = { paddings: paddings, constantValue: constantValue }; + var inputs = { x: $x }; + return ENGINE.runKernelFunc(forward, inputs, null /* grad */, PadV2, attrs); + } + var pad = op({ pad_: pad_ }); + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function assertParamsValid(input, begin, size) { + assert(input.rank === begin.length, function () { return "Error in slice" + input.rank + "D: Length of begin " + begin + " must " + + ("match the rank of the array (" + input.rank + ")."); }); + assert(input.rank === size.length, function () { return "Error in slice" + input.rank + "D: Length of size " + size + " must " + + ("match the rank of the array (" + input.rank + ")."); }); + var _loop_1 = function (i) { + assert(begin[i] + size[i] <= input.shape[i], function () { return "Error in slice" + input.rank + "D: begin[" + i + "] + size[" + i + "] " + + ("(" + (begin[i] + size[i]) + ") would overflow input.shape[" + i + "] (" + input.shape[i] + ")"); }); + }; + for (var i = 0; i < input.rank; ++i) { + _loop_1(i); + } + } + /** Converts a binary mask to an array of axes. Used in stridedSlice(). */ + function maskToAxes(mask) { + var axes = []; + var axis = 0; + while (mask > 0) { + if (mask & 1) { + axes.push(axis); + } + mask /= 2; + axis++; + } + return axes; + } + /** Computes the output shape given the strided slice params. */ + function computeOutShape$1(begin, end, strides) { + var size = []; + for (var axis = 0; axis < begin.length; axis++) { + size[axis] = Math.ceil((end[axis] - begin[axis]) / strides[axis]); + } + return size; + } + // Creates full selection at the elided dimensions. If the dimension matches + // the ellipsis mask, override the current stride value. Otherwise, insert. 
+ function stridesWithElidedDims(strides, ellipsisInsertionIndex, numElidedAxes) { + var newStrides = strides.slice(); + for (var i = 0; i < numElidedAxes; i++) { + if (i === 0) { + newStrides[ellipsisInsertionIndex] = 1; + } + else { + newStrides.splice(ellipsisInsertionIndex, 0 /* num elements to delete */, 1 /* element to add */); + newStrides.pop(); + } + } + return newStrides; + } + // Creates full selection at the elided dimensions. If the dimension matches + // the ellipsis mask, override the current start value. Otherwise, insert. + function startIndicesWithElidedDims(startIndices, ellipsisInsertionIndex, numElidedAxes) { + var newIndices = startIndices.slice(); + for (var i = 0; i < numElidedAxes; i++) { + if (i === 0) { + newIndices[ellipsisInsertionIndex] = 0; + } + else { + newIndices.splice(ellipsisInsertionIndex, 0 /* num elements to delete */, 0 /* element to add */); + newIndices.pop(); + } + } + return newIndices; + } + // Creates full selection at the elided dimensions. If the dimension matches + // the ellipsis mask, override the current stop value. Otherwise, insert. 
+ function stopIndicesWithElidedDims(stopIndices, ellipsisInsertionIndex, numElidedAxes, inputShape) { + var newIndices = stopIndices.slice(); + for (var i = 0; i < numElidedAxes; i++) { + if (i === 0) { + newIndices[ellipsisInsertionIndex] = Number.MAX_SAFE_INTEGER; + } + else { + newIndices.splice(ellipsisInsertionIndex, 0 /* num elements to delete */, Number.MAX_SAFE_INTEGER /* element to add */); + newIndices.pop(); + } + } + for (var i = 0; i < newIndices.length; i++) { + newIndices[i] = clamp(0, newIndices[i], inputShape[i]); + } + return newIndices; + } + function stridesForAxis(strides, axis, ellipsisMask) { + var stride = strides[axis]; + if (ellipsisMask & (1 << axis) || stride == null) { + stride = 1; + } + return stride; + } + function startForAxis(beginMask, startIndices, strides, inputShape, axis, ellipsisMask) { + // Begin with the specified index + var start = startIndices[axis]; + var stride = strides[axis] || 1; + // Check the axis bit from right of masked axes, or the begin index is not set + // for the axis. + if (beginMask & 1 << axis || ellipsisMask & 1 << axis || start == null) { + if (stride > 0) { + // Forward iteration - use the first element. These values will get + // clamped below (Note: We could have set them to 0 and axis_size-1, but + // use lowest() and max() to maintain symmetry with StopForAxis()) + start = Number.MIN_SAFE_INTEGER; + } + else { + // Backward iteration - use the last element. + start = Number.MAX_SAFE_INTEGER; + } + } + // Handle negative indices + var axisSize = inputShape[axis]; + if (start < 0) { + start += axisSize; + } + // Clamping + start = clamp(0, start, axisSize - 1); + return start; + } + function stopForAxis(endMask, stopIndices, strides, inputShape, axis, ellipsisMask) { + // Begin with the specified index + var stop = stopIndices[axis]; + var stride = strides[axis] || 1; + // Check the axis bit from right of masked axes, or if the stop index is not + // set for this axis. 
+ if (endMask & (1 << axis) || ellipsisMask & (1 << axis) || stop == null) { + if (stride > 0) { + // Forward iteration - use the last element. These values will get + // clamped below + stop = Number.MAX_SAFE_INTEGER; + } + else { + // Backward iteration - use the first element. + stop = Number.MIN_SAFE_INTEGER; + } + } + // Handle negative indices + var axisSize = inputShape[axis]; + if (stop < 0) { + stop += axisSize; + } + // Clamping + // Because the end index points one past the last element, we need slightly + // different clamping ranges depending on the direction. + if (stride > 0) { + // Forward iteration + stop = clamp(0, stop, axisSize); + } + else { + // Backward iteration + stop = clamp(-1, stop, axisSize - 1); + } + return stop; + } + /** + * Returns true if the slice occupies a continous set of elements in the + * 'flat' space. + */ + function isSliceContinous(shape, begin, size) { + // Index of the first axis that has size > 1. + var firstNonOneAxis = size.length; + for (var i = 0; i < size.length; i++) { + if (size[i] > 1) { + firstNonOneAxis = i; + break; + } + } + for (var i = firstNonOneAxis + 1; i < size.length; i++) { + if (begin[i] > 0 || size[i] !== shape[i]) { + return false; + } + } + return true; + } + function computeFlatOffset(begin, strides) { + var flatOffset = begin.length > 0 ? 
begin[begin.length - 1] : 1; + for (var i = 0; i < begin.length - 1; i++) { + flatOffset += begin[i] * strides[i]; + } + return flatOffset; + } + + var slice_util = { + __proto__: null, + assertParamsValid: assertParamsValid, + maskToAxes: maskToAxes, + computeOutShape: computeOutShape$1, + stridesWithElidedDims: stridesWithElidedDims, + startIndicesWithElidedDims: startIndicesWithElidedDims, + stopIndicesWithElidedDims: stopIndicesWithElidedDims, + stridesForAxis: stridesForAxis, + startForAxis: startForAxis, + stopForAxis: stopForAxis, + isSliceContinous: isSliceContinous, + computeFlatOffset: computeFlatOffset + }; + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Extracts a 1D slice from 1D array starting at coordinates `begin` and is + * of length `size`. See `slice` for details. + */ + function slice1d_(x, begin, size) { + var $x = convertToTensor(x, 'x', 'slice1d'); + assert($x.rank === 1, function () { + return "slice1d expects a rank-1 tensor, but got a rank-" + $x.rank + " tensor"; + }); + return slice($x, [begin], [size]); + } + /** + * Extracts a 2D slice from a 2D array starting at coordinates `begin` and + * is of size `size`. See `slice` for details. 
+ */ + function slice2d_(x, begin, size) { + var $x = convertToTensor(x, 'x', 'slice2d'); + assert($x.rank === 2, function () { + return "slice2d expects a rank-2 tensor, but got a rank-" + $x.rank + " tensor"; + }); + return slice($x, begin, size); + } + /** + * Extracts a 3D slice from a 3D array starting at coordinates `begin` and + * is of size `size`. See `slice` for details. + */ + function slice3d_(x, begin, size) { + var $x = convertToTensor(x, 'x', 'slice3d'); + assert($x.rank === 3, function () { + return "slice3d expects a rank-3 tensor, but got a rank-" + $x.rank + " tensor"; + }); + return slice($x, begin, size); + } + /** + * Extracts a 4D slice from a 4D array starting at coordinates `begin` and + * is of size `size`. See `slice` for details. + */ + function slice4d_(x, begin, size) { + var $x = convertToTensor(x, 'x', 'slice4d'); + assert($x.rank === 4, function () { + return "slice4d expects a rank-4 tensor, but got a rank-" + $x.rank + " tensor"; + }); + return slice($x, begin, size); + } + /** + * Extracts a slice from a `tf.Tensor` starting at coordinates `begin` + * and is of size `size`. + * + * Also available are stricter rank-specific methods with the same signature + * as this method that assert that `x` is of the given rank: + * - `tf.slice1d` + * - `tf.slice2d` + * - `tf.slice3d` + * - `tf.slice4d` + * + * ```js + * const x = tf.tensor1d([1, 2, 3, 4]); + * + * x.slice([1], [2]).print(); + * ``` + * + * ```js + * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]); + * + * x.slice([1, 0], [1, 2]).print(); + * ``` + * @param x The input `tf.Tensor` to slice from. + * @param begin The coordinates to start the slice from. The length can be + * less than the rank of x - the rest of the axes will have implicit 0 as + * start. Can also be a single number, in which case it specifies the + * first axis. + * @param size The size of the slice. The length can be less than the rank of + * x - the rest of the axes will have implicit -1. 
A value of -1 requests + * the rest of the dimensions in the axis. Can also be a single number, + * in which case it specifies the size of the first axis. + */ + /** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */ + function slice_(x, begin, size) { + var $x = convertToTensor(x, 'x', 'slice'); + if ($x.rank === 0) { + throw new Error('Slicing scalar is not possible'); + } + // The following logic allows for more ergonomic calls. + var begin_; + if (typeof begin === 'number') { + begin_ = [begin].concat(new Array($x.rank - 1).fill(0)); + } + else if (begin.length < $x.rank) { + begin_ = begin.concat(new Array($x.rank - begin.length).fill(0)); + } + else { + begin_ = begin.slice(); + } + begin_.forEach(function (d) { + assert(d !== -1, function () { return 'slice() does not support negative begin indexing.'; }); + }); + var size_; + if (size == null) { + size_ = new Array($x.rank).fill(-1); + } + else if (typeof size === 'number') { + size_ = [size].concat(new Array($x.rank - 1).fill(-1)); + } + else if (size.length < $x.rank) { + size_ = size.concat(new Array($x.rank - size.length).fill(-1)); + } + else { + size_ = size; + } + size_ = size_.map(function (d, i) { + if (d >= 0) { + return d; + } + else { + assert(d === -1, function () { return "Negative size values should be exactly -1 but got " + + (d + " for the slice() size at index " + i + "."); }); + return $x.shape[i] - begin_[i]; + } + }); + assertParamsValid($x, begin_, size_); + var inputShape = $x.shape; + var grad = function (dy) { + // Create an Nx2 padding where the first column represents how many + // zeros are prepended (at start) for each dimension, and the second + // column indicates how many zeros are appended (at end). + // The number of zeros to append is the shape of the input + // elementwise-subtracted by both the begin vector and sizes vector. 
+ var paddings = []; + for (var i = 0; i < dy.rank; i++) { + paddings.push([begin_[i], inputShape[i] - begin_[i] - size_[i]]); + } + return { x: function () { return pad(dy, paddings); } }; + }; + var attrs = { begin: begin_, size: size_ }; + return ENGINE.runKernelFunc(function (backend) { return backend.slice($x, begin_, size_); }, { x: $x }, grad, 'Slice', attrs); + } + var slice = op({ slice_: slice_ }); + var slice1d = op({ slice1d_: slice1d_ }); + var slice2d = op({ slice2d_: slice2d_ }); + var slice3d = op({ slice3d_: slice3d_ }); + var slice4d = op({ slice4d_: slice4d_ }); + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var tileGradConfig = { + kernelName: Tile, + inputsToSave: ['x'], + gradFunc: function (dy, saved, attrs) { + var x = saved[0]; + var reps = attrs.reps; + var derX = function () { + var xGrad = zerosLike(x); + // TODO(cais): Maybe reduce memory footprint by avoiding repeated + // slicing. 
+ if (x.rank === 1) { + for (var i = 0; i < reps[0]; ++i) { + xGrad = add(xGrad, slice(dy, [i * x.shape[0]], [x.shape[0]])); + } + } + else if (x.rank === 2) { + for (var i = 0; i < reps[0]; ++i) { + for (var j = 0; j < reps[1]; ++j) { + xGrad = add(xGrad, slice(dy, [i * x.shape[0], j * x.shape[1]], [ + x.shape[0], x.shape[1] + ])); + } + } + } + else if (x.rank === 3) { + for (var i = 0; i < reps[0]; ++i) { + for (var j = 0; j < reps[1]; ++j) { + for (var k = 0; k < reps[2]; ++k) { + xGrad = + add(xGrad, slice(dy, [i * x.shape[0], j * x.shape[1], k * x.shape[2]], [x.shape[0], x.shape[1], x.shape[2]])); + } + } + } + } + else if (x.rank === 4) { + for (var i = 0; i < reps[0]; ++i) { + for (var j = 0; j < reps[1]; ++j) { + for (var k = 0; k < reps[2]; ++k) { + for (var l = 0; l < reps[3]; ++l) { + xGrad = + add(xGrad, slice(dy, [ + i * x.shape[0], j * x.shape[1], k * x.shape[2], + l * x.shape[3] + ], [x.shape[0], x.shape[1], x.shape[2], x.shape[3]])); + } + } + } + } + } + else { + throw new Error("Gradient for tile operation is not implemented for rank-" + + (x.rank + " tensors yet.")); + } + return xGrad; + }; + return { x: derX }; + }, + }; + + /** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + var transposeGradConfig = { + kernelName: Transpose, + gradFunc: function (dy, saved, attrs) { + var transposeAttrs = attrs; + var perm = transposeAttrs.perm; + var undoPerm = getUndoAxesPermutation(perm); + return { x: function () { return transpose(dy, undoPerm); } }; + } + }; + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + // Export all kernel configs here so that the package can auto register them + var gradConfigs = [ + addGradConfig, + addNGradConfig, + avgPoolGradConfig, + avgPool3DGradConfig, + batchMatMulGradConfig, + batchToSpaceNDGradConfig, + broadcastToGradConfig, + concatGradConfig, + conv2DGradConfig, + conv2DBackpropInputGradConfig, + conv3DGradConfig, + cumsumGradConfig, + depthwiseConv2dNativeGradConfig, + divGradConfig, + fusedBatchNormGradConfig, + greaterEqualGradConfig, + identityGradConfig, + lrnGradConfig, + oneHotGradConfig, + padV2GradConfig, + splitVGradConfig, + maxGradConfig, + spaceToBatchNDGradConfig, + maxGradConfig, + maxPoolGradConfig, + maxPool3DGradConfig, + oneHotGradConfig, + padV2GradConfig, + spaceToBatchNDGradConfig, + splitVGradConfig, + squareGradConfig, + squaredDifferenceGradConfig, + tileGradConfig, + transposeGradConfig, + subGradConfig + ]; + for (var _i = 0, gradConfigs_1 = gradConfigs; _i < gradConfigs_1.length; _i++) { + var gradientConfig = gradConfigs_1[_i]; + registerGradient(gradientConfig); + } + + /** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + var PlatformBrowser = /** @class */ (function () { + function PlatformBrowser() { + } + PlatformBrowser.prototype.fetch = function (path, init) { + return fetch(path, init); + }; + PlatformBrowser.prototype.now = function () { + return performance.now(); + }; + PlatformBrowser.prototype.encode = function (text, encoding) { + if (encoding !== 'utf-8' && encoding !== 'utf8') { + throw new Error("Browser's encoder only supports utf-8, but got " + encoding); + } + if (this.textEncoder == null) { + this.textEncoder = new TextEncoder(); + } + return this.textEncoder.encode(text); + }; + PlatformBrowser.prototype.decode = function (bytes, encoding) { + return new TextDecoder(encoding).decode(bytes); + }; + return PlatformBrowser; + }()); + if (env().get('IS_BROWSER')) { + env().setPlatform('browser', new PlatformBrowser()); + } + + /** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + // We are wrapping this within an object so it can be stubbed by Jasmine. 
+ var getNodeFetch = { + // tslint:disable-next-line:no-require-imports + importFetch: function () { return require('node-fetch'); } + }; + var systemFetch; + var PlatformNode = /** @class */ (function () { + function PlatformNode() { + // tslint:disable-next-line:no-require-imports + this.util = require('util'); + // According to the spec, the built-in encoder can do only UTF-8 encoding. + // https://developer.mozilla.org/en-US/docs/Web/API/TextEncoder/TextEncoder + this.textEncoder = new this.util.TextEncoder(); + } + PlatformNode.prototype.fetch = function (path, requestInits) { + if (env().global.fetch != null) { + return env().global.fetch(path, requestInits); + } + if (systemFetch == null) { + systemFetch = getNodeFetch.importFetch(); + } + return systemFetch(path, requestInits); + }; + PlatformNode.prototype.now = function () { + var time = process.hrtime(); + return time[0] * 1000 + time[1] / 1000000; + }; + PlatformNode.prototype.encode = function (text, encoding) { + if (encoding !== 'utf-8' && encoding !== 'utf8') { + throw new Error("Node built-in encoder only supports utf-8, but got " + encoding); + } + return this.textEncoder.encode(text); + }; + PlatformNode.prototype.decode = function (bytes, encoding) { + if (bytes.length === 0) { + return ''; + } + return new this.util.TextDecoder(encoding).decode(bytes); + }; + return PlatformNode; + }()); + if (env().get('IS_NODE')) { + env().setPlatform('node', new PlatformNode()); + } + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /* Type definitions for exporting and importing of models. */ + /** + * A map from Tensor dtype to number of bytes per element of the Tensor. + */ + var DTYPE_VALUE_SIZE_MAP = { + 'float32': 4, + 'int32': 4, + 'uint16': 2, + 'uint8': 1, + 'bool': 1, + }; + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** Number of bytes reserved for the length of the string. (32bit integer). */ + var NUM_BYTES_STRING_LENGTH = 4; + /** + * Encode a map from names to weight values as an ArrayBuffer, along with an + * `Array` of `WeightsManifestEntry` as specification of the encoded weights. + * + * This function does not perform sharding. + * + * This function is the reverse of `decodeWeights`. + * + * @param tensors A map ("dict") from names to tensors. + * @param group Group to which the weights belong (optional). + * @returns A `Promise` of + * - A flat `ArrayBuffer` with all the binary values of the `Tensor`s + * concatenated. + * - An `Array` of `WeightManifestEntry`s, carrying information including + * tensor names, `dtype`s and shapes. + * @throws Error: on unsupported tensor `dtype`. 
+ */ + function encodeWeights(tensors, group) { + return __awaiter(this, void 0, void 0, function () { + var specs, dataPromises, names, _loop_1, i, tensorValues; + var _this = this; + return __generator(this, function (_a) { + switch (_a.label) { + case 0: + specs = []; + dataPromises = []; + names = Array.isArray(tensors) ? + tensors.map(function (tensor) { return tensor.name; }) : + Object.keys(tensors); + _loop_1 = function (i) { + var name_1 = names[i]; + var t = Array.isArray(tensors) ? tensors[i].tensor : tensors[name_1]; + if (t.dtype !== 'float32' && t.dtype !== 'int32' && t.dtype !== 'bool' && + t.dtype !== 'string') { + throw new Error("Unsupported dtype in weight '" + name_1 + "': " + t.dtype); + } + var spec = { name: name_1, shape: t.shape, dtype: t.dtype }; + if (t.dtype === 'string') { + var utf8bytes = new Promise(function (resolve) { return __awaiter(_this, void 0, void 0, function () { + var vals, totalNumBytes, bytes, offset, i_1, val, bytesOfLength; + return __generator(this, function (_a) { + switch (_a.label) { + case 0: return [4 /*yield*/, t.bytes()]; + case 1: + vals = _a.sent(); + totalNumBytes = vals.reduce(function (p, c) { return p + c.length; }, 0) + + NUM_BYTES_STRING_LENGTH * vals.length; + bytes = new Uint8Array(totalNumBytes); + offset = 0; + for (i_1 = 0; i_1 < vals.length; i_1++) { + val = vals[i_1]; + bytesOfLength = new Uint8Array(new Uint32Array([val.length]).buffer); + bytes.set(bytesOfLength, offset); + offset += NUM_BYTES_STRING_LENGTH; + bytes.set(val, offset); + offset += val.length; + } + resolve(bytes); + return [2 /*return*/]; + } + }); + }); }); + dataPromises.push(utf8bytes); + } + else { + dataPromises.push(t.data()); + } + if (group != null) { + spec.group = group; + } + specs.push(spec); + }; + for (i = 0; i < names.length; ++i) { + _loop_1(i); + } + return [4 /*yield*/, Promise.all(dataPromises)]; + case 1: + tensorValues = _a.sent(); + return [2 /*return*/, { data: concatenateTypedArrays(tensorValues), specs: 
specs }]; + } + }); + }); + } + /** + * Decode flat ArrayBuffer as weights. + * + * This function does not handle sharding. + * + * This function is the reverse of `encodeWeights`. + * + * @param buffer A flat ArrayBuffer carrying the binary values of the tensors + * concatenated in the order specified in `specs`. + * @param specs Specifications of the names, dtypes and shapes of the tensors + * whose value are encoded by `buffer`. + * @return A map from tensor name to tensor value, with the names corresponding + * to names in `specs`. + * @throws Error, if any of the tensors has unsupported dtype. + */ + function decodeWeights(buffer, specs) { + // TODO(adarob, cais): Support quantization. + var out = {}; + var offset = 0; + for (var _i = 0, specs_1 = specs; _i < specs_1.length; _i++) { + var spec = specs_1[_i]; + var name_2 = spec.name; + var dtype = spec.dtype; + var shape = spec.shape; + var size = sizeFromShape(shape); + var values = void 0; + if ('quantization' in spec) { + var quantization = spec.quantization; + if (quantization.dtype !== 'uint8' && quantization.dtype !== 'uint16') { + throw new Error("Weight " + spec.name + " has unknown " + + ("quantization dtype " + quantization.dtype + ". ") + + "Supported quantization dtypes are: 'uint8' and 'uint16'."); + } + var quantizationSizeFactor = DTYPE_VALUE_SIZE_MAP[quantization.dtype]; + var byteBuffer = buffer.slice(offset, offset + size * quantizationSizeFactor); + var quantizedArray = (quantization.dtype === 'uint8') ? 
+ new Uint8Array(byteBuffer) : + new Uint16Array(byteBuffer); + if (dtype === 'float32') { + values = new Float32Array(quantizedArray.length); + for (var i = 0; i < quantizedArray.length; i++) { + var v = quantizedArray[i]; + values[i] = v * quantization.scale + quantization.min; + } + } + else if (dtype === 'int32') { + values = new Int32Array(quantizedArray.length); + for (var i = 0; i < quantizedArray.length; i++) { + var v = quantizedArray[i]; + values[i] = Math.round(v * quantization.scale + quantization.min); + } + } + else { + throw new Error("Unsupported dtype in weight '" + name_2 + "': " + dtype); + } + offset += size * quantizationSizeFactor; + } + else if (dtype === 'string') { + var size_1 = sizeFromShape(spec.shape); + values = []; + for (var i = 0; i < size_1; i++) { + var byteLength = new Uint32Array(buffer.slice(offset, offset + NUM_BYTES_STRING_LENGTH))[0]; + offset += NUM_BYTES_STRING_LENGTH; + var bytes = new Uint8Array(buffer.slice(offset, offset + byteLength)); + values.push(bytes); + offset += byteLength; + } + } + else { + var dtypeFactor = DTYPE_VALUE_SIZE_MAP[dtype]; + var byteBuffer = buffer.slice(offset, offset + size * dtypeFactor); + if (dtype === 'float32') { + values = new Float32Array(byteBuffer); + } + else if (dtype === 'int32') { + values = new Int32Array(byteBuffer); + } + else if (dtype === 'bool') { + values = new Uint8Array(byteBuffer); + } + else { + throw new Error("Unsupported dtype in weight '" + name_2 + "': " + dtype); + } + offset += size * dtypeFactor; + } + out[name_2] = tensor(values, shape, dtype); + } + return out; + } + /** + * Concatenate TypedArrays into an ArrayBuffer. + */ + function concatenateTypedArrays(xs) { + // TODO(adarob, cais): Support quantization. 
+ if (xs === null) { + throw new Error("Invalid input value: " + JSON.stringify(xs)); + } + var totalByteLength = 0; + // `normalizedXs` is here for this reason: a `TypedArray`'s `buffer' + // can have a different byte length from that of the `TypedArray` itself, + // for example, when the `TypedArray` is created from an offset in an + // `ArrayBuffer`. `normliazedXs` holds `TypedArray`s whose `buffer`s match + // the `TypedArray` in byte length. If an element of `xs` does not show + // this property, a new `TypedArray` that satisfy this property will be + // constructed and pushed into `normalizedXs`. + var normalizedXs = []; + xs.forEach(function (x) { + totalByteLength += x.byteLength; + // tslint:disable:no-any + normalizedXs.push(x.byteLength === x.buffer.byteLength ? x : + new x.constructor(x)); + if (!(x instanceof Float32Array || x instanceof Int32Array || + x instanceof Uint8Array)) { + throw new Error("Unsupported TypedArray subtype: " + x.constructor.name); + } + // tslint:enable:no-any + }); + var y = new Uint8Array(totalByteLength); + var offset = 0; + normalizedXs.forEach(function (x) { + y.set(new Uint8Array(x.buffer), offset); + offset += x.byteLength; + }); + return y.buffer; + } + // Use Buffer on Node.js instead of Blob/atob/btoa + var useNodeBuffer = typeof Buffer !== 'undefined' && + (typeof Blob === 'undefined' || typeof atob === 'undefined' || + typeof btoa === 'undefined'); + /** + * Calculate the byte length of a JavaScript string. + * + * Note that a JavaScript string can contain wide characters, therefore the + * length of the string is not necessarily equal to the byte length. + * + * @param str Input string. + * @returns Byte length. + */ + function stringByteLength(str) { + if (useNodeBuffer) { + return Buffer.byteLength(str); + } + return new Blob([str]).size; + } + /** + * Encode an ArrayBuffer as a base64 encoded string. + * + * @param buffer `ArrayBuffer` to be converted. + * @returns A string that base64-encodes `buffer`. 
+ */ + function arrayBufferToBase64String(buffer) { + if (useNodeBuffer) { + return Buffer.from(buffer).toString('base64'); + } + var buf = new Uint8Array(buffer); + var s = ''; + for (var i = 0, l = buf.length; i < l; i++) { + s += String.fromCharCode(buf[i]); + } + return btoa(s); + } + /** + * Decode a base64 string as an ArrayBuffer. + * + * @param str Base64 string. + * @returns Decoded `ArrayBuffer`. + */ + function base64StringToArrayBuffer(str) { + if (useNodeBuffer) { + var buf = Buffer.from(str, 'base64'); + return buf.buffer.slice(buf.byteOffset, buf.byteOffset + buf.byteLength); + } + var s = atob(str); + var buffer = new Uint8Array(s.length); + for (var i = 0; i < s.length; ++i) { + buffer.set([s.charCodeAt(i)], i); + } + return buffer.buffer; + } + /** + * Concatenate a number of ArrayBuffers into one. + * + * @param buffers A number of array buffers to concatenate. + * @returns Result of concatenating `buffers` in order. + */ + function concatenateArrayBuffers(buffers) { + if (buffers.length === 1) { + return buffers[0]; + } + var totalByteLength = 0; + buffers.forEach(function (buffer) { + totalByteLength += buffer.byteLength; + }); + var temp = new Uint8Array(totalByteLength); + var offset = 0; + buffers.forEach(function (buffer) { + temp.set(new Uint8Array(buffer), offset); + offset += buffer.byteLength; + }); + return temp.buffer; + } + /** + * Get the basename of a path. + * + * Behaves in a way analogous to Linux's basename command. + * + * @param path + */ + function basename(path) { + var SEPARATOR = '/'; + path = path.trim(); + while (path.endsWith(SEPARATOR)) { + path = path.slice(0, path.length - 1); + } + var items = path.split(SEPARATOR); + return items[items.length - 1]; + } + /** + * Populate ModelArtifactsInfo fields for a model with JSON topology. + * @param modelArtifacts + * @returns A ModelArtifactsInfo object. 
+ */ + function getModelArtifactsInfoForJSON(modelArtifacts) { + if (modelArtifacts.modelTopology instanceof ArrayBuffer) { + throw new Error('Expected JSON model topology, received ArrayBuffer.'); + } + return { + dateSaved: new Date(), + modelTopologyType: 'JSON', + modelTopologyBytes: modelArtifacts.modelTopology == null ? + 0 : + stringByteLength(JSON.stringify(modelArtifacts.modelTopology)), + weightSpecsBytes: modelArtifacts.weightSpecs == null ? + 0 : + stringByteLength(JSON.stringify(modelArtifacts.weightSpecs)), + weightDataBytes: modelArtifacts.weightData == null ? + 0 : + modelArtifacts.weightData.byteLength, + }; + } + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var IORouterRegistry = /** @class */ (function () { + function IORouterRegistry() { + this.saveRouters = []; + this.loadRouters = []; + } + IORouterRegistry.getInstance = function () { + if (IORouterRegistry.instance == null) { + IORouterRegistry.instance = new IORouterRegistry(); + } + return IORouterRegistry.instance; + }; + /** + * Register a save-handler router. + * + * @param saveRouter A function that maps a URL-like string onto an instance + * of `IOHandler` with the `save` method defined or `null`. 
+ */ + IORouterRegistry.registerSaveRouter = function (saveRouter) { + IORouterRegistry.getInstance().saveRouters.push(saveRouter); + }; + /** + * Register a load-handler router. + * + * @param loadRouter A function that maps a URL-like string onto an instance + * of `IOHandler` with the `load` method defined or `null`. + */ + IORouterRegistry.registerLoadRouter = function (loadRouter) { + IORouterRegistry.getInstance().loadRouters.push(loadRouter); + }; + /** + * Look up IOHandler for saving, given a URL-like string. + * + * @param url + * @returns If only one match is found, an instance of IOHandler with the + * `save` method defined. If no match is found, `null`. + * @throws Error, if more than one match is found. + */ + IORouterRegistry.getSaveHandlers = function (url) { + return IORouterRegistry.getHandlers(url, 'save'); + }; + /** + * Look up IOHandler for loading, given a URL-like string. + * + * @param url + * @param loadOptions Optional, custom load options. + * @returns All valid handlers for `url`, given the currently registered + * handler routers. + */ + IORouterRegistry.getLoadHandlers = function (url, loadOptions) { + return IORouterRegistry.getHandlers(url, 'load', loadOptions); + }; + IORouterRegistry.getHandlers = function (url, handlerType, loadOptions) { + var validHandlers = []; + var routers = handlerType === 'load' ? 
+ IORouterRegistry.getInstance().loadRouters : + IORouterRegistry.getInstance().saveRouters; + routers.forEach(function (router) { + var handler = router(url, loadOptions); + if (handler !== null) { + validHandlers.push(handler); + } + }); + return validHandlers; + }; + return IORouterRegistry; + }()); + var registerSaveRouter = function (loudRouter) { + return IORouterRegistry.registerSaveRouter(loudRouter); + }; + var registerLoadRouter = function (loudRouter) { + return IORouterRegistry.registerLoadRouter(loudRouter); + }; + var getSaveHandlers = function (url) { + return IORouterRegistry.getSaveHandlers(url); + }; + var getLoadHandlers = function (url, loadOptions) { + return IORouterRegistry.getLoadHandlers(url, loadOptions); + }; + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var URL_SCHEME_SUFFIX = '://'; + var ModelStoreManagerRegistry = /** @class */ (function () { + function ModelStoreManagerRegistry() { + this.managers = {}; + } + ModelStoreManagerRegistry.getInstance = function () { + if (ModelStoreManagerRegistry.instance == null) { + ModelStoreManagerRegistry.instance = new ModelStoreManagerRegistry(); + } + return ModelStoreManagerRegistry.instance; + }; + /** + * Register a save-handler router. 
+ * + * @param saveRouter A function that maps a URL-like string onto an instance + * of `IOHandler` with the `save` method defined or `null`. + */ + ModelStoreManagerRegistry.registerManager = function (scheme, manager) { + assert(scheme != null, function () { return 'scheme must not be undefined or null.'; }); + if (scheme.endsWith(URL_SCHEME_SUFFIX)) { + scheme = scheme.slice(0, scheme.indexOf(URL_SCHEME_SUFFIX)); + } + assert(scheme.length > 0, function () { return 'scheme must not be an empty string.'; }); + var registry = ModelStoreManagerRegistry.getInstance(); + assert(registry.managers[scheme] == null, function () { return "A model store manager is already registered for scheme '" + scheme + "'."; }); + registry.managers[scheme] = manager; + }; + ModelStoreManagerRegistry.getManager = function (scheme) { + var manager = this.getInstance().managers[scheme]; + if (manager == null) { + throw new Error("Cannot find model manager for scheme '" + scheme + "'"); + } + return manager; + }; + ModelStoreManagerRegistry.getSchemes = function () { + return Object.keys(this.getInstance().managers); + }; + return ModelStoreManagerRegistry; + }()); + /** + * Helper method for parsing a URL string into a scheme and a path. + * + * @param url E.g., 'localstorage://my-model' + * @returns A dictionary with two fields: scheme and path. + * Scheme: e.g., 'localstorage' in the example above. + * Path: e.g., 'my-model' in the example above. + */ + function parseURL(url) { + if (url.indexOf(URL_SCHEME_SUFFIX) === -1) { + throw new Error("The url string provided does not contain a scheme. 
" + + "Supported schemes are: " + + ("" + ModelStoreManagerRegistry.getSchemes().join(','))); + } + return { + scheme: url.split(URL_SCHEME_SUFFIX)[0], + path: url.split(URL_SCHEME_SUFFIX)[1], + }; + } + function cloneModelInternal(sourceURL, destURL, deleteSource) { + if (deleteSource === void 0) { deleteSource = false; } + return __awaiter(this, void 0, void 0, function () { + var loadHandlers, loadHandler, saveHandlers, saveHandler, sourceScheme, sourcePath, sameMedium, modelArtifacts, saveResult; + return __generator(this, function (_a) { + switch (_a.label) { + case 0: + assert(sourceURL !== destURL, function () { return "Old path and new path are the same: '" + sourceURL + "'"; }); + loadHandlers = IORouterRegistry.getLoadHandlers(sourceURL); + assert(loadHandlers.length > 0, function () { return "Copying failed because no load handler is found for source URL " + sourceURL + "."; }); + assert(loadHandlers.length < 2, function () { return "Copying failed because more than one (" + loadHandlers.length + ") " + + ("load handlers for source URL " + sourceURL + "."); }); + loadHandler = loadHandlers[0]; + saveHandlers = IORouterRegistry.getSaveHandlers(destURL); + assert(saveHandlers.length > 0, function () { return "Copying failed because no save handler is found for destination " + + ("URL " + destURL + "."); }); + assert(saveHandlers.length < 2, function () { return "Copying failed because more than one (" + loadHandlers.length + ") " + + ("save handlers for destination URL " + destURL + "."); }); + saveHandler = saveHandlers[0]; + sourceScheme = parseURL(sourceURL).scheme; + sourcePath = parseURL(sourceURL).path; + sameMedium = sourceScheme === parseURL(sourceURL).scheme; + return [4 /*yield*/, loadHandler.load()]; + case 1: + modelArtifacts = _a.sent(); + if (!(deleteSource && sameMedium)) return [3 /*break*/, 3]; + return [4 /*yield*/, ModelStoreManagerRegistry.getManager(sourceScheme) + .removeModel(sourcePath)]; + case 2: + _a.sent(); + _a.label = 3; + 
case 3: return [4 /*yield*/, saveHandler.save(modelArtifacts)]; + case 4: + saveResult = _a.sent(); + if (!(deleteSource && !sameMedium)) return [3 /*break*/, 6]; + return [4 /*yield*/, ModelStoreManagerRegistry.getManager(sourceScheme) + .removeModel(sourcePath)]; + case 5: + _a.sent(); + _a.label = 6; + case 6: return [2 /*return*/, saveResult.modelArtifactsInfo]; + } + }); + }); + } + /** + * List all models stored in registered storage mediums. + * + * For a web browser environment, the registered mediums are Local Storage and + * IndexedDB. + * + * ```js + * // First create and save a model. + * const model = tf.sequential(); + * model.add(tf.layers.dense( + * {units: 1, inputShape: [10], activation: 'sigmoid'})); + * await model.save('localstorage://demo/management/model1'); + * + * // Then list existing models. + * console.log(JSON.stringify(await tf.io.listModels())); + * + * // Delete the model. + * await tf.io.removeModel('localstorage://demo/management/model1'); + * + * // List models again. + * console.log(JSON.stringify(await tf.io.listModels())); + * ``` + * + * @returns A `Promise` of a dictionary mapping URLs of existing models to + * their model artifacts info. URLs include medium-specific schemes, e.g., + * 'indexeddb://my/model/1'. Model artifacts info include type of the + * model's topology, byte sizes of the topology, weights, etc. 
+ */ + /** + * @doc { + * heading: 'Models', + * subheading: 'Management', + * namespace: 'io', + * ignoreCI: true + * } + */ + function listModels() { + return __awaiter(this, void 0, void 0, function () { + var schemes, out, _i, schemes_1, scheme, schemeOut, path, url; + return __generator(this, function (_a) { + switch (_a.label) { + case 0: + schemes = ModelStoreManagerRegistry.getSchemes(); + out = {}; + _i = 0, schemes_1 = schemes; + _a.label = 1; + case 1: + if (!(_i < schemes_1.length)) return [3 /*break*/, 4]; + scheme = schemes_1[_i]; + return [4 /*yield*/, ModelStoreManagerRegistry.getManager(scheme).listModels()]; + case 2: + schemeOut = _a.sent(); + for (path in schemeOut) { + url = scheme + URL_SCHEME_SUFFIX + path; + out[url] = schemeOut[path]; + } + _a.label = 3; + case 3: + _i++; + return [3 /*break*/, 1]; + case 4: return [2 /*return*/, out]; + } + }); + }); + } + /** + * Remove a model specified by URL from a reigstered storage medium. + * + * ```js + * // First create and save a model. + * const model = tf.sequential(); + * model.add(tf.layers.dense( + * {units: 1, inputShape: [10], activation: 'sigmoid'})); + * await model.save('localstorage://demo/management/model1'); + * + * // Then list existing models. + * console.log(JSON.stringify(await tf.io.listModels())); + * + * // Delete the model. + * await tf.io.removeModel('localstorage://demo/management/model1'); + * + * // List models again. + * console.log(JSON.stringify(await tf.io.listModels())); + * ``` + * + * @param url A URL to a stored model, with a scheme prefix, e.g., + * 'localstorage://my-model-1', 'indexeddb://my/model/2'. + * @returns ModelArtifactsInfo of the deleted model (if and only if deletion + * is successful). + * @throws Error if deletion fails, e.g., if no model exists at `path`. 
+ */ + /** + * @doc { + * heading: 'Models', + * subheading: 'Management', + * namespace: 'io', + * ignoreCI: true + * } + */ + function removeModel(url) { + return __awaiter(this, void 0, void 0, function () { + var schemeAndPath, manager; + return __generator(this, function (_a) { + schemeAndPath = parseURL(url); + manager = ModelStoreManagerRegistry.getManager(schemeAndPath.scheme); + return [2 /*return*/, manager.removeModel(schemeAndPath.path)]; + }); + }); + } + /** + * Copy a model from one URL to another. + * + * This function supports: + * + * 1. Copying within a storage medium, e.g., + * `tf.io.copyModel('localstorage://model-1', 'localstorage://model-2')` + * 2. Copying between two storage mediums, e.g., + * `tf.io.copyModel('localstorage://model-1', 'indexeddb://model-1')` + * + * ```js + * // First create and save a model. + * const model = tf.sequential(); + * model.add(tf.layers.dense( + * {units: 1, inputShape: [10], activation: 'sigmoid'})); + * await model.save('localstorage://demo/management/model1'); + * + * // Then list existing models. + * console.log(JSON.stringify(await tf.io.listModels())); + * + * // Copy the model, from Local Storage to IndexedDB. + * await tf.io.copyModel( + * 'localstorage://demo/management/model1', + * 'indexeddb://demo/management/model1'); + * + * // List models again. + * console.log(JSON.stringify(await tf.io.listModels())); + * + * // Remove both models. + * await tf.io.removeModel('localstorage://demo/management/model1'); + * await tf.io.removeModel('indexeddb://demo/management/model1'); + * ``` + * + * @param sourceURL Source URL of copying. + * @param destURL Destination URL of copying. + * @returns ModelArtifactsInfo of the copied model (if and only if copying + * is successful). + * @throws Error if copying fails, e.g., if no model exists at `sourceURL`, or + * if `oldPath` and `newPath` are identical. 
+ */ + /** + * @doc { + * heading: 'Models', + * subheading: 'Management', + * namespace: 'io', + * ignoreCI: true + * } + */ + function copyModel(sourceURL, destURL) { + return __awaiter(this, void 0, void 0, function () { + var deleteSource; + return __generator(this, function (_a) { + deleteSource = false; + return [2 /*return*/, cloneModelInternal(sourceURL, destURL, deleteSource)]; + }); + }); + } + /** + * Move a model from one URL to another. + * + * This function supports: + * + * 1. Moving within a storage medium, e.g., + * `tf.io.moveModel('localstorage://model-1', 'localstorage://model-2')` + * 2. Moving between two storage mediums, e.g., + * `tf.io.moveModel('localstorage://model-1', 'indexeddb://model-1')` + * + * ```js + * // First create and save a model. + * const model = tf.sequential(); + * model.add(tf.layers.dense( + * {units: 1, inputShape: [10], activation: 'sigmoid'})); + * await model.save('localstorage://demo/management/model1'); + * + * // Then list existing models. + * console.log(JSON.stringify(await tf.io.listModels())); + * + * // Move the model, from Local Storage to IndexedDB. + * await tf.io.moveModel( + * 'localstorage://demo/management/model1', + * 'indexeddb://demo/management/model1'); + * + * // List models again. + * console.log(JSON.stringify(await tf.io.listModels())); + * + * // Remove the moved model. + * await tf.io.removeModel('indexeddb://demo/management/model1'); + * ``` + * + * @param sourceURL Source URL of moving. + * @param destURL Destination URL of moving. + * @returns ModelArtifactsInfo of the copied model (if and only if copying + * is successful). + * @throws Error if moving fails, e.g., if no model exists at `sourceURL`, or + * if `oldPath` and `newPath` are identical. 
+ */ + /** + * @doc { + * heading: 'Models', + * subheading: 'Management', + * namespace: 'io', + * ignoreCI: true + * } + */ + function moveModel(sourceURL, destURL) { + return __awaiter(this, void 0, void 0, function () { + var deleteSource; + return __generator(this, function (_a) { + deleteSource = true; + return [2 /*return*/, cloneModelInternal(sourceURL, destURL, deleteSource)]; + }); + }); + } + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var DATABASE_NAME = 'tensorflowjs'; + var DATABASE_VERSION = 1; + // Model data and ModelArtifactsInfo (metadata) are stored in two separate + // stores for efficient access of the list of stored models and their metadata. + // 1. The object store for model data: topology, weights and weight manifests. + var MODEL_STORE_NAME = 'models_store'; + // 2. The object store for ModelArtifactsInfo, including meta-information such + // as the type of topology (JSON vs binary), byte size of the topology, byte + // size of the weights, etc. + var INFO_STORE_NAME = 'model_info_store'; + function getIndexedDBFactory() { + if (!env().getBool('IS_BROWSER')) { + // TODO(cais): Add more info about what IOHandler subtypes are available. + // Maybe point to a doc page on the web and/or automatically determine + // the available IOHandlers and print them in the error message. 
+ throw new Error('Failed to obtain IndexedDB factory because the current environment' + + 'is not a web browser.'); + } + // tslint:disable-next-line:no-any + var theWindow = window || self; + var factory = theWindow.indexedDB || theWindow.mozIndexedDB || + theWindow.webkitIndexedDB || theWindow.msIndexedDB || + theWindow.shimIndexedDB; + if (factory == null) { + throw new Error('The current browser does not appear to support IndexedDB.'); + } + return factory; + } + function setUpDatabase(openRequest) { + var db = openRequest.result; + db.createObjectStore(MODEL_STORE_NAME, { keyPath: 'modelPath' }); + db.createObjectStore(INFO_STORE_NAME, { keyPath: 'modelPath' }); + } + /** + * IOHandler subclass: Browser IndexedDB. + * + * See the doc string of `browserIndexedDB` for more details. + */ + var BrowserIndexedDB = /** @class */ (function () { + function BrowserIndexedDB(modelPath) { + this.indexedDB = getIndexedDBFactory(); + if (modelPath == null || !modelPath) { + throw new Error('For IndexedDB, modelPath must not be null, undefined or empty.'); + } + this.modelPath = modelPath; + } + BrowserIndexedDB.prototype.save = function (modelArtifacts) { + return __awaiter(this, void 0, void 0, function () { + return __generator(this, function (_a) { + // TODO(cais): Support saving GraphDef models. + if (modelArtifacts.modelTopology instanceof ArrayBuffer) { + throw new Error('BrowserLocalStorage.save() does not support saving model topology ' + + 'in binary formats yet.'); + } + return [2 /*return*/, this.databaseAction(this.modelPath, modelArtifacts)]; + }); + }); + }; + BrowserIndexedDB.prototype.load = function () { + return __awaiter(this, void 0, void 0, function () { + return __generator(this, function (_a) { + return [2 /*return*/, this.databaseAction(this.modelPath)]; + }); + }); + }; + /** + * Perform database action to put model artifacts into or read model artifacts + * from IndexedDB object store. 
+ * + * Whether the action is put or get depends on whether `modelArtifacts` is + * specified. If it is specified, the action will be put; otherwise the action + * will be get. + * + * @param modelPath A unique string path for the model. + * @param modelArtifacts If specified, it will be the model artifacts to be + * stored in IndexedDB. + * @returns A `Promise` of `SaveResult`, if the action is put, or a `Promise` + * of `ModelArtifacts`, if the action is get. + */ + BrowserIndexedDB.prototype.databaseAction = function (modelPath, modelArtifacts) { + var _this = this; + return new Promise(function (resolve, reject) { + var openRequest = _this.indexedDB.open(DATABASE_NAME, DATABASE_VERSION); + openRequest.onupgradeneeded = function () { return setUpDatabase(openRequest); }; + openRequest.onsuccess = function () { + var db = openRequest.result; + if (modelArtifacts == null) { + // Read model out from object store. + var modelTx = db.transaction(MODEL_STORE_NAME, 'readonly'); + var modelStore = modelTx.objectStore(MODEL_STORE_NAME); + var getRequest_1 = modelStore.get(_this.modelPath); + getRequest_1.onsuccess = function () { + if (getRequest_1.result == null) { + db.close(); + return reject(new Error("Cannot find model with path '" + _this.modelPath + "' " + + "in IndexedDB.")); + } + else { + resolve(getRequest_1.result.modelArtifacts); + } + }; + getRequest_1.onerror = function (error) { + db.close(); + return reject(getRequest_1.error); + }; + modelTx.oncomplete = function () { return db.close(); }; + } + else { + // Put model into object store. + var modelArtifactsInfo_1 = getModelArtifactsInfoForJSON(modelArtifacts); + // First, put ModelArtifactsInfo into info store. 
+ var infoTx_1 = db.transaction(INFO_STORE_NAME, 'readwrite'); + var infoStore_1 = infoTx_1.objectStore(INFO_STORE_NAME); + var putInfoRequest_1 = infoStore_1.put({ modelPath: _this.modelPath, modelArtifactsInfo: modelArtifactsInfo_1 }); + var modelTx_1; + putInfoRequest_1.onsuccess = function () { + // Second, put model data into model store. + modelTx_1 = db.transaction(MODEL_STORE_NAME, 'readwrite'); + var modelStore = modelTx_1.objectStore(MODEL_STORE_NAME); + var putModelRequest = modelStore.put({ + modelPath: _this.modelPath, + modelArtifacts: modelArtifacts, + modelArtifactsInfo: modelArtifactsInfo_1 + }); + putModelRequest.onsuccess = function () { return resolve({ modelArtifactsInfo: modelArtifactsInfo_1 }); }; + putModelRequest.onerror = function (error) { + // If the put-model request fails, roll back the info entry as + // well. + infoStore_1 = infoTx_1.objectStore(INFO_STORE_NAME); + var deleteInfoRequest = infoStore_1.delete(_this.modelPath); + deleteInfoRequest.onsuccess = function () { + db.close(); + return reject(putModelRequest.error); + }; + deleteInfoRequest.onerror = function (error) { + db.close(); + return reject(putModelRequest.error); + }; + }; + }; + putInfoRequest_1.onerror = function (error) { + db.close(); + return reject(putInfoRequest_1.error); + }; + infoTx_1.oncomplete = function () { + if (modelTx_1 == null) { + db.close(); + } + else { + modelTx_1.oncomplete = function () { return db.close(); }; + } + }; + } + }; + openRequest.onerror = function (error) { return reject(openRequest.error); }; + }); + }; + BrowserIndexedDB.URL_SCHEME = 'indexeddb://'; + return BrowserIndexedDB; + }()); + var indexedDBRouter = function (url) { + if (!env().getBool('IS_BROWSER')) { + return null; + } + else { + if (!Array.isArray(url) && url.startsWith(BrowserIndexedDB.URL_SCHEME)) { + return browserIndexedDB(url.slice(BrowserIndexedDB.URL_SCHEME.length)); + } + else { + return null; + } + } + }; + 
IORouterRegistry.registerSaveRouter(indexedDBRouter); + IORouterRegistry.registerLoadRouter(indexedDBRouter); + /** + * Creates a browser IndexedDB IOHandler for saving and loading models. + * + * ```js + * const model = tf.sequential(); + * model.add( + * tf.layers.dense({units: 1, inputShape: [100], activation: 'sigmoid'})); + * + * const saveResult = await model.save('indexeddb://MyModel')); + * console.log(saveResult); + * ``` + * + * @param modelPath A unique identifier for the model to be saved. Must be a + * non-empty string. + * @returns An instance of `BrowserIndexedDB` (sublcass of `IOHandler`), + * which can be used with, e.g., `tf.Model.save`. + */ + function browserIndexedDB(modelPath) { + return new BrowserIndexedDB(modelPath); + } + function maybeStripScheme(key) { + return key.startsWith(BrowserIndexedDB.URL_SCHEME) ? + key.slice(BrowserIndexedDB.URL_SCHEME.length) : + key; + } + var BrowserIndexedDBManager = /** @class */ (function () { + function BrowserIndexedDBManager() { + this.indexedDB = getIndexedDBFactory(); + } + BrowserIndexedDBManager.prototype.listModels = function () { + return __awaiter(this, void 0, void 0, function () { + var _this = this; + return __generator(this, function (_a) { + return [2 /*return*/, new Promise(function (resolve, reject) { + var openRequest = _this.indexedDB.open(DATABASE_NAME, DATABASE_VERSION); + openRequest.onupgradeneeded = function () { return setUpDatabase(openRequest); }; + openRequest.onsuccess = function () { + var db = openRequest.result; + var tx = db.transaction(INFO_STORE_NAME, 'readonly'); + var store = tx.objectStore(INFO_STORE_NAME); + // tslint:disable:max-line-length + // Need to cast `store` as `any` here because TypeScript's DOM + // library does not have the `getAll()` method even though the + // method is supported in the latest version of most mainstream + // browsers: + // https://developer.mozilla.org/en-US/docs/Web/API/IDBObjectStore/getAll + // tslint:enable:max-line-length + // 
tslint:disable-next-line:no-any + var getAllInfoRequest = store.getAll(); + getAllInfoRequest.onsuccess = function () { + var out = {}; + for (var _i = 0, _a = getAllInfoRequest.result; _i < _a.length; _i++) { + var item = _a[_i]; + out[item.modelPath] = item.modelArtifactsInfo; + } + resolve(out); + }; + getAllInfoRequest.onerror = function (error) { + db.close(); + return reject(getAllInfoRequest.error); + }; + tx.oncomplete = function () { return db.close(); }; + }; + openRequest.onerror = function (error) { return reject(openRequest.error); }; + })]; + }); + }); + }; + BrowserIndexedDBManager.prototype.removeModel = function (path) { + return __awaiter(this, void 0, void 0, function () { + var _this = this; + return __generator(this, function (_a) { + path = maybeStripScheme(path); + return [2 /*return*/, new Promise(function (resolve, reject) { + var openRequest = _this.indexedDB.open(DATABASE_NAME, DATABASE_VERSION); + openRequest.onupgradeneeded = function () { return setUpDatabase(openRequest); }; + openRequest.onsuccess = function () { + var db = openRequest.result; + var infoTx = db.transaction(INFO_STORE_NAME, 'readwrite'); + var infoStore = infoTx.objectStore(INFO_STORE_NAME); + var getInfoRequest = infoStore.get(path); + var modelTx; + getInfoRequest.onsuccess = function () { + if (getInfoRequest.result == null) { + db.close(); + return reject(new Error("Cannot find model with path '" + path + "' " + + "in IndexedDB.")); + } + else { + // First, delete the entry in the info store. + var deleteInfoRequest = infoStore.delete(path); + var deleteModelData_1 = function () { + // Second, delete the entry in the model store. 
+ modelTx = db.transaction(MODEL_STORE_NAME, 'readwrite'); + var modelStore = modelTx.objectStore(MODEL_STORE_NAME); + var deleteModelRequest = modelStore.delete(path); + deleteModelRequest.onsuccess = function () { + return resolve(getInfoRequest.result.modelArtifactsInfo); + }; + deleteModelRequest.onerror = function (error) { + return reject(getInfoRequest.error); + }; + }; + // Proceed with deleting model data regardless of whether deletion + // of info data succeeds or not. + deleteInfoRequest.onsuccess = deleteModelData_1; + deleteInfoRequest.onerror = function (error) { + deleteModelData_1(); + db.close(); + return reject(getInfoRequest.error); + }; + } + }; + getInfoRequest.onerror = function (error) { + db.close(); + return reject(getInfoRequest.error); + }; + infoTx.oncomplete = function () { + if (modelTx == null) { + db.close(); + } + else { + modelTx.oncomplete = function () { return db.close(); }; + } + }; + }; + openRequest.onerror = function (error) { return reject(openRequest.error); }; + })]; + }); + }); + }; + return BrowserIndexedDBManager; + }()); + if (env().getBool('IS_BROWSER')) { + // Wrap the construction and registration, to guard against browsers that + // don't support Local Storage. + try { + ModelStoreManagerRegistry.registerManager(BrowserIndexedDB.URL_SCHEME, new BrowserIndexedDBManager()); + } + catch (err) { + } + } + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var PATH_SEPARATOR = '/'; + var PATH_PREFIX = 'tensorflowjs_models'; + var INFO_SUFFIX = 'info'; + var MODEL_TOPOLOGY_SUFFIX = 'model_topology'; + var WEIGHT_SPECS_SUFFIX = 'weight_specs'; + var WEIGHT_DATA_SUFFIX = 'weight_data'; + var MODEL_METADATA_SUFFIX = 'model_metadata'; + function getModelKeys(path) { + return { + info: [PATH_PREFIX, path, INFO_SUFFIX].join(PATH_SEPARATOR), + topology: [PATH_PREFIX, path, MODEL_TOPOLOGY_SUFFIX].join(PATH_SEPARATOR), + weightSpecs: [PATH_PREFIX, path, WEIGHT_SPECS_SUFFIX].join(PATH_SEPARATOR), + weightData: [PATH_PREFIX, path, WEIGHT_DATA_SUFFIX].join(PATH_SEPARATOR), + modelMetadata: [PATH_PREFIX, path, MODEL_METADATA_SUFFIX].join(PATH_SEPARATOR) + }; + } + /** + * Get model path from a local-storage key. + * + * E.g., 'tensorflowjs_models/my/model/1/info' --> 'my/model/1' + * + * @param key + */ + function getModelPathFromKey(key) { + var items = key.split(PATH_SEPARATOR); + if (items.length < 3) { + throw new Error("Invalid key format: " + key); + } + return items.slice(1, items.length - 1).join(PATH_SEPARATOR); + } + function maybeStripScheme$1(key) { + return key.startsWith(BrowserLocalStorage.URL_SCHEME) ? + key.slice(BrowserLocalStorage.URL_SCHEME.length) : + key; + } + /** + * IOHandler subclass: Browser Local Storage. + * + * See the doc string to `browserLocalStorage` for more details. + */ + var BrowserLocalStorage = /** @class */ (function () { + function BrowserLocalStorage(modelPath) { + if (!env().getBool('IS_BROWSER') || + typeof window === 'undefined' || + typeof window.localStorage === 'undefined') { + // TODO(cais): Add more info about what IOHandler subtypes are + // available. 
+ // Maybe point to a doc page on the web and/or automatically determine + // the available IOHandlers and print them in the error message. + throw new Error('The current environment does not support local storage.'); + } + this.LS = window.localStorage; + if (modelPath == null || !modelPath) { + throw new Error('For local storage, modelPath must not be null, undefined or empty.'); + } + this.modelPath = modelPath; + this.keys = getModelKeys(this.modelPath); + } + /** + * Save model artifacts to browser local storage. + * + * See the documentation to `browserLocalStorage` for details on the saved + * artifacts. + * + * @param modelArtifacts The model artifacts to be stored. + * @returns An instance of SaveResult. + */ + BrowserLocalStorage.prototype.save = function (modelArtifacts) { + return __awaiter(this, void 0, void 0, function () { + var topology, weightSpecs, modelArtifactsInfo; + return __generator(this, function (_a) { + if (modelArtifacts.modelTopology instanceof ArrayBuffer) { + throw new Error('BrowserLocalStorage.save() does not support saving model topology ' + + 'in binary formats yet.'); + } + else { + topology = JSON.stringify(modelArtifacts.modelTopology); + weightSpecs = JSON.stringify(modelArtifacts.weightSpecs); + modelArtifactsInfo = getModelArtifactsInfoForJSON(modelArtifacts); + try { + this.LS.setItem(this.keys.info, JSON.stringify(modelArtifactsInfo)); + this.LS.setItem(this.keys.topology, topology); + this.LS.setItem(this.keys.weightSpecs, weightSpecs); + this.LS.setItem(this.keys.weightData, arrayBufferToBase64String(modelArtifacts.weightData)); + this.LS.setItem(this.keys.modelMetadata, JSON.stringify({ + format: modelArtifacts.format, + generatedBy: modelArtifacts.generatedBy, + convertedBy: modelArtifacts.convertedBy, + userDefinedMetadata: modelArtifacts.userDefinedMetadata + })); + return [2 /*return*/, { modelArtifactsInfo: modelArtifactsInfo }]; + } + catch (err) { + // If saving failed, clean up all items saved so far. 
+ this.LS.removeItem(this.keys.info); + this.LS.removeItem(this.keys.topology); + this.LS.removeItem(this.keys.weightSpecs); + this.LS.removeItem(this.keys.weightData); + this.LS.removeItem(this.keys.modelMetadata); + throw new Error("Failed to save model '" + this.modelPath + "' to local storage: " + + "size quota being exceeded is a possible cause of this failure: " + + ("modelTopologyBytes=" + modelArtifactsInfo.modelTopologyBytes + ", ") + + ("weightSpecsBytes=" + modelArtifactsInfo.weightSpecsBytes + ", ") + + ("weightDataBytes=" + modelArtifactsInfo.weightDataBytes + ".")); + } + } + return [2 /*return*/]; + }); + }); + }; + /** + * Load a model from local storage. + * + * See the documentation to `browserLocalStorage` for details on the saved + * artifacts. + * + * @returns The loaded model (if loading succeeds). + */ + BrowserLocalStorage.prototype.load = function () { + return __awaiter(this, void 0, void 0, function () { + var info, out, topology, weightSpecs, metadataString, metadata, weightDataBase64; + return __generator(this, function (_a) { + info = JSON.parse(this.LS.getItem(this.keys.info)); + if (info == null) { + throw new Error("In local storage, there is no model with name '" + this.modelPath + "'"); + } + if (info.modelTopologyType !== 'JSON') { + throw new Error('BrowserLocalStorage does not support loading non-JSON model ' + + 'topology yet.'); + } + out = {}; + topology = JSON.parse(this.LS.getItem(this.keys.topology)); + if (topology == null) { + throw new Error("In local storage, the topology of model '" + this.modelPath + "' " + + "is missing."); + } + out.modelTopology = topology; + weightSpecs = JSON.parse(this.LS.getItem(this.keys.weightSpecs)); + if (weightSpecs == null) { + throw new Error("In local storage, the weight specs of model '" + this.modelPath + "' " + + "are missing."); + } + out.weightSpecs = weightSpecs; + metadataString = this.LS.getItem(this.keys.modelMetadata); + if (metadataString != null) { + metadata = 
JSON.parse(metadataString); + out.format = metadata['format']; + out.generatedBy = metadata['generatedBy']; + out.convertedBy = metadata['convertedBy']; + out.userDefinedMetadata = metadata['userDefinedMetadata']; + } + weightDataBase64 = this.LS.getItem(this.keys.weightData); + if (weightDataBase64 == null) { + throw new Error("In local storage, the binary weight values of model " + + ("'" + this.modelPath + "' are missing.")); + } + out.weightData = base64StringToArrayBuffer(weightDataBase64); + return [2 /*return*/, out]; + }); + }); + }; + BrowserLocalStorage.URL_SCHEME = 'localstorage://'; + return BrowserLocalStorage; + }()); + var localStorageRouter = function (url) { + if (!env().getBool('IS_BROWSER')) { + return null; + } + else { + if (!Array.isArray(url) && url.startsWith(BrowserLocalStorage.URL_SCHEME)) { + return browserLocalStorage(url.slice(BrowserLocalStorage.URL_SCHEME.length)); + } + else { + return null; + } + } + }; + IORouterRegistry.registerSaveRouter(localStorageRouter); + IORouterRegistry.registerLoadRouter(localStorageRouter); + /** + * Factory function for local storage IOHandler. + * + * This `IOHandler` supports both `save` and `load`. + * + * For each model's saved artifacts, four items are saved to local storage. + * - `${PATH_SEPARATOR}/${modelPath}/info`: Contains meta-info about the + * model, such as date saved, type of the topology, size in bytes, etc. + * - `${PATH_SEPARATOR}/${modelPath}/topology`: Model topology. For Keras- + * style models, this is a stringized JSON. + * - `${PATH_SEPARATOR}/${modelPath}/weight_specs`: Weight specs of the + * model, can be used to decode the saved binary weight values (see + * item below). + * - `${PATH_SEPARATOR}/${modelPath}/weight_data`: Concatenated binary + * weight values, stored as a base64-encoded string. + * + * Saving may throw an `Error` if the total size of the artifacts exceed the + * browser-specific quota. + * + * @param modelPath A unique identifier for the model to be saved. 
Must be a + * non-empty string. + * @returns An instance of `IOHandler`, which can be used with, e.g., + * `tf.Model.save`. + */ + function browserLocalStorage(modelPath) { + return new BrowserLocalStorage(modelPath); + } + var BrowserLocalStorageManager = /** @class */ (function () { + function BrowserLocalStorageManager() { + assert(env().getBool('IS_BROWSER'), function () { return 'Current environment is not a web browser'; }); + assert(typeof window === 'undefined' || + typeof window.localStorage !== 'undefined', function () { return 'Current browser does not appear to support localStorage'; }); + this.LS = window.localStorage; + } + BrowserLocalStorageManager.prototype.listModels = function () { + return __awaiter(this, void 0, void 0, function () { + var out, prefix, suffix, i, key, modelPath; + return __generator(this, function (_a) { + out = {}; + prefix = PATH_PREFIX + PATH_SEPARATOR; + suffix = PATH_SEPARATOR + INFO_SUFFIX; + for (i = 0; i < this.LS.length; ++i) { + key = this.LS.key(i); + if (key.startsWith(prefix) && key.endsWith(suffix)) { + modelPath = getModelPathFromKey(key); + out[modelPath] = JSON.parse(this.LS.getItem(key)); + } + } + return [2 /*return*/, out]; + }); + }); + }; + BrowserLocalStorageManager.prototype.removeModel = function (path) { + return __awaiter(this, void 0, void 0, function () { + var keys, info; + return __generator(this, function (_a) { + path = maybeStripScheme$1(path); + keys = getModelKeys(path); + if (this.LS.getItem(keys.info) == null) { + throw new Error("Cannot find model at path '" + path + "'"); + } + info = JSON.parse(this.LS.getItem(keys.info)); + this.LS.removeItem(keys.info); + this.LS.removeItem(keys.topology); + this.LS.removeItem(keys.weightSpecs); + this.LS.removeItem(keys.weightData); + return [2 /*return*/, info]; + }); + }); + }; + return BrowserLocalStorageManager; + }()); + if (env().getBool('IS_BROWSER')) { + // Wrap the construction and registration, to guard against browsers that + // don't 
support Local Storage. + try { + ModelStoreManagerRegistry.registerManager(BrowserLocalStorage.URL_SCHEME, new BrowserLocalStorageManager()); + } + catch (err) { + } + } + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var DEFAULT_FILE_NAME_PREFIX = 'model'; + var DEFAULT_JSON_EXTENSION_NAME = '.json'; + var DEFAULT_WEIGHT_DATA_EXTENSION_NAME = '.weights.bin'; + function defer(f) { + return new Promise(function (resolve) { return setTimeout(resolve); }).then(f); + } + var BrowserDownloads = /** @class */ (function () { + function BrowserDownloads(fileNamePrefix) { + if (!env().getBool('IS_BROWSER')) { + // TODO(cais): Provide info on what IOHandlers are available under the + // current environment. 
+ throw new Error('browserDownloads() cannot proceed because the current environment ' + + 'is not a browser.'); + } + if (fileNamePrefix.startsWith(BrowserDownloads.URL_SCHEME)) { + fileNamePrefix = fileNamePrefix.slice(BrowserDownloads.URL_SCHEME.length); + } + if (fileNamePrefix == null || fileNamePrefix.length === 0) { + fileNamePrefix = DEFAULT_FILE_NAME_PREFIX; + } + this.modelTopologyFileName = fileNamePrefix + DEFAULT_JSON_EXTENSION_NAME; + this.weightDataFileName = + fileNamePrefix + DEFAULT_WEIGHT_DATA_EXTENSION_NAME; + } + BrowserDownloads.prototype.save = function (modelArtifacts) { + return __awaiter(this, void 0, void 0, function () { + var weightsURL, weightsManifest, modelTopologyAndWeightManifest, modelTopologyAndWeightManifestURL, jsonAnchor_1, weightDataAnchor_1; + return __generator(this, function (_a) { + switch (_a.label) { + case 0: + if (typeof (document) === 'undefined') { + throw new Error('Browser downloads are not supported in ' + + 'this environment since `document` is not present'); + } + weightsURL = window.URL.createObjectURL(new Blob([modelArtifacts.weightData], { type: 'application/octet-stream' })); + if (!(modelArtifacts.modelTopology instanceof ArrayBuffer)) return [3 /*break*/, 1]; + throw new Error('BrowserDownloads.save() does not support saving model topology ' + + 'in binary formats yet.'); + case 1: + weightsManifest = [{ + paths: ['./' + this.weightDataFileName], + weights: modelArtifacts.weightSpecs + }]; + modelTopologyAndWeightManifest = { + modelTopology: modelArtifacts.modelTopology, + format: modelArtifacts.format, + generatedBy: modelArtifacts.generatedBy, + convertedBy: modelArtifacts.convertedBy, + weightsManifest: weightsManifest + }; + modelTopologyAndWeightManifestURL = window.URL.createObjectURL(new Blob([JSON.stringify(modelTopologyAndWeightManifest)], { type: 'application/json' })); + jsonAnchor_1 = this.jsonAnchor == null ? 
document.createElement('a') : + this.jsonAnchor; + jsonAnchor_1.download = this.modelTopologyFileName; + jsonAnchor_1.href = modelTopologyAndWeightManifestURL; + // Trigger downloads by evoking a click event on the download anchors. + // When multiple downloads are started synchronously, Firefox will only + // save the last one. + return [4 /*yield*/, defer(function () { return jsonAnchor_1.dispatchEvent(new MouseEvent('click')); })]; + case 2: + // Trigger downloads by evoking a click event on the download anchors. + // When multiple downloads are started synchronously, Firefox will only + // save the last one. + _a.sent(); + if (!(modelArtifacts.weightData != null)) return [3 /*break*/, 4]; + weightDataAnchor_1 = this.weightDataAnchor == null ? + document.createElement('a') : + this.weightDataAnchor; + weightDataAnchor_1.download = this.weightDataFileName; + weightDataAnchor_1.href = weightsURL; + return [4 /*yield*/, defer(function () { return weightDataAnchor_1.dispatchEvent(new MouseEvent('click')); })]; + case 3: + _a.sent(); + _a.label = 4; + case 4: return [2 /*return*/, { modelArtifactsInfo: getModelArtifactsInfoForJSON(modelArtifacts) }]; + } + }); + }); + }; + BrowserDownloads.URL_SCHEME = 'downloads://'; + return BrowserDownloads; + }()); + var BrowserFiles = /** @class */ (function () { + function BrowserFiles(files) { + if (files == null || files.length < 1) { + throw new Error("When calling browserFiles, at least 1 file is required, " + + ("but received " + files)); + } + this.files = files; + } + BrowserFiles.prototype.load = function () { + return __awaiter(this, void 0, void 0, function () { + var jsonFile, weightFiles; + var _this = this; + return __generator(this, function (_a) { + jsonFile = this.files[0]; + weightFiles = this.files.slice(1); + return [2 /*return*/, new Promise(function (resolve, reject) { + var jsonReader = new FileReader(); + jsonReader.onload = function (event) { + // tslint:disable-next-line:no-any + var modelJSON = 
JSON.parse(event.target.result); + var modelTopology = modelJSON.modelTopology; + if (modelTopology == null) { + reject(new Error("modelTopology field is missing from file " + jsonFile.name)); + return; + } + if (weightFiles.length === 0) { + resolve({ modelTopology: modelTopology }); + } + var weightsManifest = modelJSON.weightsManifest; + if (weightsManifest == null) { + reject(new Error("weightManifest field is missing from file " + jsonFile.name)); + return; + } + var pathToFile; + try { + pathToFile = + _this.checkManifestAndWeightFiles(weightsManifest, weightFiles); + } + catch (err) { + reject(err); + return; + } + var weightSpecs = []; + var paths = []; + var perFileBuffers = []; + weightsManifest.forEach(function (weightsGroup) { + weightsGroup.paths.forEach(function (path) { + paths.push(path); + perFileBuffers.push(null); + }); + weightSpecs.push.apply(weightSpecs, weightsGroup.weights); + }); + weightsManifest.forEach(function (weightsGroup) { + weightsGroup.paths.forEach(function (path) { + var weightFileReader = new FileReader(); + weightFileReader.onload = function (event) { + // tslint:disable-next-line:no-any + var weightData = event.target.result; + var index = paths.indexOf(path); + perFileBuffers[index] = weightData; + if (perFileBuffers.indexOf(null) === -1) { + resolve({ + modelTopology: modelTopology, + weightSpecs: weightSpecs, + weightData: concatenateArrayBuffers(perFileBuffers), + format: modelJSON.format, + generatedBy: modelJSON.generatedBy, + convertedBy: modelJSON.convertedBy, + userDefinedMetadata: modelJSON.userDefinedMetadata + }); + } + }; + weightFileReader.onerror = function (error) { + return reject("Failed to weights data from file of path '" + path + "'."); + }; + weightFileReader.readAsArrayBuffer(pathToFile[path]); + }); + }); + }; + jsonReader.onerror = function (error) { return reject("Failed to read model topology and weights manifest JSON " + + ("from file '" + jsonFile.name + "'. 
BrowserFiles supports loading ") + + "Keras-style tf.Model artifacts only."); }; + jsonReader.readAsText(jsonFile); + })]; + }); + }); + }; + /** + * Check the compatibility between weights manifest and weight files. + */ + BrowserFiles.prototype.checkManifestAndWeightFiles = function (manifest, files) { + var basenames = []; + var fileNames = files.map(function (file) { return basename(file.name); }); + var pathToFile = {}; + for (var _i = 0, manifest_1 = manifest; _i < manifest_1.length; _i++) { + var group = manifest_1[_i]; + group.paths.forEach(function (path) { + var pathBasename = basename(path); + if (basenames.indexOf(pathBasename) !== -1) { + throw new Error("Duplicate file basename found in weights manifest: " + + ("'" + pathBasename + "'")); + } + basenames.push(pathBasename); + if (fileNames.indexOf(pathBasename) === -1) { + throw new Error("Weight file with basename '" + pathBasename + "' is not provided."); + } + else { + pathToFile[path] = files[fileNames.indexOf(pathBasename)]; + } + }); + } + if (basenames.length !== files.length) { + throw new Error("Mismatch in the number of files in weights manifest " + + ("(" + basenames.length + ") and the number of weight files provided ") + + ("(" + files.length + ").")); + } + return pathToFile; + }; + return BrowserFiles; + }()); + var browserDownloadsRouter = function (url) { + if (!env().getBool('IS_BROWSER')) { + return null; + } + else { + if (!Array.isArray(url) && url.startsWith(BrowserDownloads.URL_SCHEME)) { + return browserDownloads(url.slice(BrowserDownloads.URL_SCHEME.length)); + } + else { + return null; + } + } + }; + IORouterRegistry.registerSaveRouter(browserDownloadsRouter); + /** + * Creates an IOHandler that triggers file downloads from the browser. + * + * The returned `IOHandler` instance can be used as model exporting methods such + * as `tf.Model.save` and supports only saving. 
+ * + * ```js + * const model = tf.sequential(); + * model.add(tf.layers.dense( + * {units: 1, inputShape: [10], activation: 'sigmoid'})); + * const saveResult = await model.save('downloads://mymodel'); + * // This will trigger downloading of two files: + * // 'mymodel.json' and 'mymodel.weights.bin'. + * console.log(saveResult); + * ``` + * + * @param fileNamePrefix Prefix name of the files to be downloaded. For use with + * `tf.Model`, `fileNamePrefix` should follow either of the following two + * formats: + * 1. `null` or `undefined`, in which case the default file + * names will be used: + * - 'model.json' for the JSON file containing the model topology and + * weights manifest. + * - 'model.weights.bin' for the binary file containing the binary weight + * values. + * 2. A single string or an Array of a single string, as the file name prefix. + * For example, if `'foo'` is provided, the downloaded JSON + * file and binary weights file will be named 'foo.json' and + * 'foo.weights.bin', respectively. + * @param config Additional configuration for triggering downloads. + * @returns An instance of `BrowserDownloads` `IOHandler`. + */ + /** + * @doc { + * heading: 'Models', + * subheading: 'Loading', + * namespace: 'io', + * ignoreCI: true + * } + */ + function browserDownloads(fileNamePrefix) { + if (fileNamePrefix === void 0) { fileNamePrefix = 'model'; } + return new BrowserDownloads(fileNamePrefix); + } + /** + * Creates an IOHandler that loads model artifacts from user-selected files. + * + * This method can be used for loading from files such as user-selected files + * in the browser. + * When used in conjunction with `tf.loadLayersModel`, an instance of + * `tf.LayersModel` (Keras-style) can be constructed from the loaded artifacts. + * + * ```js + * // Note: This code snippet won't run properly without the actual file input + * // elements in the HTML DOM. + * + * // Suppose there are two HTML file input (``) + * // elements. 
+ * const uploadJSONInput = document.getElementById('upload-json'); + * const uploadWeightsInput = document.getElementById('upload-weights'); + * const model = await tf.loadLayersModel(tf.io.browserFiles( + * [uploadJSONInput.files[0], uploadWeightsInput.files[0]])); + * ``` + * + * @param files `File`s to load from. Currently, this function supports only + * loading from files that contain Keras-style models (i.e., `tf.Model`s), for + * which an `Array` of `File`s is expected (in that order): + * - A JSON file containing the model topology and weight manifest. + * - Optionally, One or more binary files containing the binary weights. + * These files must have names that match the paths in the `weightsManifest` + * contained by the aforementioned JSON file, or errors will be thrown + * during loading. These weights files have the same format as the ones + * generated by `tensorflowjs_converter` that comes with the `tensorflowjs` + * Python PIP package. If no weights files are provided, only the model + * topology will be loaded from the JSON file above. + * @returns An instance of `Files` `IOHandler`. + */ + /** + * @doc { + * heading: 'Models', + * subheading: 'Loading', + * namespace: 'io', + * ignoreCI: true + * } + */ + function browserFiles(files) { + return new BrowserFiles(files); + } + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + /** + * Monitor Promise.all progress, fire onProgress callback function. + * + * @param promises Promise list going to be monitored + * @param onProgress Callback function. Fired when a promise resolved. + * @param startFraction Optional fraction start. Default to 0. + * @param endFraction Optional fraction end. Default to 1. + */ + function monitorPromisesProgress(promises, onProgress, startFraction, endFraction) { + checkPromises(promises); + startFraction = startFraction == null ? 0 : startFraction; + endFraction = endFraction == null ? 1 : endFraction; + checkFraction(startFraction, endFraction); + var resolvedPromise = 0; + var registerMonitor = function (promise) { + promise.then(function (value) { + var fraction = startFraction + + ++resolvedPromise / promises.length * (endFraction - startFraction); + // pass fraction as parameter to callback function. + onProgress(fraction); + return value; + }); + return promise; + }; + function checkPromises(promises) { + assert(promises != null && Array.isArray(promises) && promises.length > 0, function () { return 'promises must be a none empty array'; }); + } + function checkFraction(startFraction, endFraction) { + assert(startFraction >= 0 && startFraction <= 1, function () { return "Progress fraction must be in range [0, 1], but " + + ("got startFraction " + startFraction); }); + assert(endFraction >= 0 && endFraction <= 1, function () { return "Progress fraction must be in range [0, 1], but " + + ("got endFraction " + endFraction); }); + assert(endFraction >= startFraction, function () { return "startFraction must be no more than endFraction, but " + + ("got startFraction " + startFraction + " and endFraction ") + + ("" + endFraction); }); + } + return Promise.all(promises.map(registerMonitor)); + } + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. 
+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Reads binary weights data from a number of URLs. + * + * @param fetchURLs URLs to send the HTTP requests at, using `fetch` calls. + * @param requestOptions RequestInit (options) for the HTTP requests. + * @param fetchFunc Optional overriding value for the `window.fetch` function. + * @param onProgress Optional, progress callback function, fired periodically + * before the load is completed. + * @returns A `Promise` of an Array of `ArrayBuffer`. The Array has the same + * length as `fetchURLs`. + */ + function loadWeightsAsArrayBuffer(fetchURLs, loadOptions) { + return __awaiter(this, void 0, void 0, function () { + var fetchFunc, requests, fetchStartFraction, fetchEndFraction, responses, _a, bufferPromises, bufferStartFraction, bufferEndFraction, buffers, _b; + return __generator(this, function (_c) { + switch (_c.label) { + case 0: + if (loadOptions == null) { + loadOptions = {}; + } + fetchFunc = loadOptions.fetchFunc == null ? 
env().platform.fetch : + loadOptions.fetchFunc; + requests = fetchURLs.map(function (fetchURL) { + return fetchFunc(fetchURL, loadOptions.requestInit, { isBinary: true }); + }); + fetchStartFraction = 0; + fetchEndFraction = 0.5; + if (!(loadOptions.onProgress == null)) return [3 /*break*/, 2]; + return [4 /*yield*/, Promise.all(requests)]; + case 1: + _a = _c.sent(); + return [3 /*break*/, 4]; + case 2: return [4 /*yield*/, monitorPromisesProgress(requests, loadOptions.onProgress, fetchStartFraction, fetchEndFraction)]; + case 3: + _a = _c.sent(); + _c.label = 4; + case 4: + responses = _a; + bufferPromises = responses.map(function (response) { return response.arrayBuffer(); }); + bufferStartFraction = 0.5; + bufferEndFraction = 1; + if (!(loadOptions.onProgress == null)) return [3 /*break*/, 6]; + return [4 /*yield*/, Promise.all(bufferPromises)]; + case 5: + _b = _c.sent(); + return [3 /*break*/, 8]; + case 6: return [4 /*yield*/, monitorPromisesProgress(bufferPromises, loadOptions.onProgress, bufferStartFraction, bufferEndFraction)]; + case 7: + _b = _c.sent(); + _c.label = 8; + case 8: + buffers = _b; + return [2 /*return*/, buffers]; + } + }); + }); + } + /** + * Reads a weights manifest JSON configuration, fetches the weights and + * returns them as `Tensor`s. + * + * @param manifest The weights manifest JSON. + * @param filePathPrefix The path prefix for filenames given in the manifest. + * Defaults to the empty string. + * @param weightNames The names of the weights to be fetched. 
+ */ + function loadWeights(manifest, filePathPrefix, weightNames, requestInit) { + if (filePathPrefix === void 0) { filePathPrefix = ''; } + return __awaiter(this, void 0, void 0, function () { + var fetchWeights, loadWeights; + return __generator(this, function (_a) { + fetchWeights = function (fetchUrls) { + return loadWeightsAsArrayBuffer(fetchUrls, { requestInit: requestInit }); + }; + loadWeights = weightsLoaderFactory(fetchWeights); + return [2 /*return*/, loadWeights(manifest, filePathPrefix, weightNames)]; + }); + }); + } + /** + * Creates a function, which reads a weights manifest JSON configuration, + * fetches the weight files using the specified function and returns them as + * `Tensor`s. + * + * ```js + * // example for creating a nodejs weight loader, which reads the weight files + * // from disk using fs.readFileSync + * + * import * as fs from 'fs' + * + * const fetchWeightsFromDisk = (filePaths: string[]) => + * filePaths.map(filePath => fs.readFileSync(filePath).buffer) + * + * const loadWeights = tf.io.weightsLoaderFactory(fetchWeightsFromDisk) + * + * const manifest = JSON.parse( + * fs.readFileSync('./my_model-weights_manifest').toString() + * ) + * const weightMap = await loadWeights(manifest, './') + * ``` + * @param fetchWeightsFunction The function used for fetching the weight files. + * @returns Weight loading function. 
+ */ + function weightsLoaderFactory(fetchWeightsFunction) { + var _this = this; + return function (manifest, filePathPrefix, weightNames) { + if (filePathPrefix === void 0) { filePathPrefix = ''; } + return __awaiter(_this, void 0, void 0, function () { + var groupIndicesToFetchMap, groupWeightsToFetch, weightsFound, allManifestWeightNames, weightsNotFound, groupIndicesToFetch, fetchUrls, buffers, weightsTensorMap, bufferIndexOffset; + return __generator(this, function (_a) { + switch (_a.label) { + case 0: + groupIndicesToFetchMap = manifest.map(function () { return false; }); + groupWeightsToFetch = {}; + weightsFound = weightNames != null ? weightNames.map(function () { return false; }) : []; + allManifestWeightNames = []; + manifest.forEach(function (manifestGroupConfig, groupIndex) { + var groupOffset = 0; + manifestGroupConfig.weights.forEach(function (weightsEntry) { + var rawDtype = ('quantization' in weightsEntry) ? + weightsEntry.quantization.dtype : + weightsEntry.dtype; + var weightsBytes = DTYPE_VALUE_SIZE_MAP[rawDtype] * + sizeFromShape(weightsEntry.shape); + var enqueueWeightsForFetchingFn = function () { + groupIndicesToFetchMap[groupIndex] = true; + if (groupWeightsToFetch[groupIndex] == null) { + groupWeightsToFetch[groupIndex] = []; + } + groupWeightsToFetch[groupIndex].push({ + manifestEntry: weightsEntry, + groupOffset: groupOffset, + sizeBytes: weightsBytes + }); + }; + if (weightNames != null) { + weightNames.forEach(function (weightName, weightIndex) { + if (weightName === weightsEntry.name) { + enqueueWeightsForFetchingFn(); + weightsFound[weightIndex] = true; + } + }); + } + else { + enqueueWeightsForFetchingFn(); + } + allManifestWeightNames.push(weightsEntry.name); + groupOffset += weightsBytes; + }); + }); + if (!weightsFound.every(function (found) { return found; })) { + weightsNotFound = weightNames.filter(function (_, i) { return !weightsFound[i]; }); + throw new Error("Could not find weights in manifest with names: " + + 
(weightsNotFound.join(', ') + ". \n") + + "Manifest JSON has weights with names: " + + (allManifestWeightNames.join(', ') + ".")); + } + groupIndicesToFetch = groupIndicesToFetchMap.reduce(function (accumulator, shouldFetch, i) { + if (shouldFetch) { + accumulator.push(i); + } + return accumulator; + }, []); + fetchUrls = []; + groupIndicesToFetch.forEach(function (i) { + manifest[i].paths.forEach(function (filepath) { + var fetchUrl = filePathPrefix + + (!filePathPrefix.endsWith('/') ? '/' : '') + filepath; + fetchUrls.push(fetchUrl); + }); + }); + return [4 /*yield*/, fetchWeightsFunction(fetchUrls)]; + case 1: + buffers = _a.sent(); + weightsTensorMap = {}; + bufferIndexOffset = 0; + groupIndicesToFetch.forEach(function (i) { + var numBuffers = manifest[i].paths.length; + var groupBytes = 0; + for (var i_1 = 0; i_1 < numBuffers; i_1++) { + groupBytes += buffers[bufferIndexOffset + i_1].byteLength; + } + // Create a buffer for the whole group. + var groupBuffer = new ArrayBuffer(groupBytes); + var groupByteBuffer = new Uint8Array(groupBuffer); + var groupBufferOffset = 0; + for (var i_2 = 0; i_2 < numBuffers; i_2++) { + var buffer = new Uint8Array(buffers[bufferIndexOffset + i_2]); + groupByteBuffer.set(buffer, groupBufferOffset); + groupBufferOffset += buffer.byteLength; + } + var weightsEntries = groupWeightsToFetch[i]; + weightsEntries.forEach(function (weightsEntry) { + var byteBuffer = groupBuffer.slice(weightsEntry.groupOffset, weightsEntry.groupOffset + weightsEntry.sizeBytes); + var nameToTensorMap = decodeWeights(byteBuffer, [weightsEntry.manifestEntry]); + for (var name_1 in nameToTensorMap) { + weightsTensorMap[name_1] = nameToTensorMap[name_1]; + } + }); + bufferIndexOffset += numBuffers; + }); + return [2 /*return*/, weightsTensorMap]; + } + }); + }); + }; + } + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. 
+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var OCTET_STREAM_MIME_TYPE = 'application/octet-stream'; + var JSON_TYPE = 'application/json'; + var HTTPRequest = /** @class */ (function () { + function HTTPRequest(path, loadOptions) { + this.DEFAULT_METHOD = 'POST'; + if (loadOptions == null) { + loadOptions = {}; + } + this.weightPathPrefix = loadOptions.weightPathPrefix; + this.onProgress = loadOptions.onProgress; + if (loadOptions.fetchFunc != null) { + assert(typeof loadOptions.fetchFunc === 'function', function () { return 'Must pass a function that matches the signature of ' + + '`fetch` (see ' + + 'https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API)'; }); + this.fetch = loadOptions.fetchFunc; + } + else { + this.fetch = env().platform.fetch; + } + assert(path != null && path.length > 0, function () { return 'URL path for http must not be null, undefined or ' + + 'empty.'; }); + if (Array.isArray(path)) { + assert(path.length === 2, function () { return 'URL paths for http must have a length of 2, ' + + ("(actual length is " + path.length + ")."); }); + } + this.path = path; + if (loadOptions.requestInit != null && + loadOptions.requestInit.body != null) { + throw new Error('requestInit is expected to have no pre-existing body, but has one.'); + } + this.requestInit = loadOptions.requestInit || {}; + } + HTTPRequest.prototype.save = function (modelArtifacts) 
{ + return __awaiter(this, void 0, void 0, function () { + var init, weightsManifest, modelTopologyAndWeightManifest, response; + return __generator(this, function (_a) { + switch (_a.label) { + case 0: + if (modelArtifacts.modelTopology instanceof ArrayBuffer) { + throw new Error('BrowserHTTPRequest.save() does not support saving model topology ' + + 'in binary formats yet.'); + } + init = Object.assign({ method: this.DEFAULT_METHOD }, this.requestInit); + init.body = new FormData(); + weightsManifest = [{ + paths: ['./model.weights.bin'], + weights: modelArtifacts.weightSpecs, + }]; + modelTopologyAndWeightManifest = { + modelTopology: modelArtifacts.modelTopology, + format: modelArtifacts.format, + generatedBy: modelArtifacts.generatedBy, + convertedBy: modelArtifacts.convertedBy, + userDefinedMetadata: modelArtifacts.userDefinedMetadata, + weightsManifest: weightsManifest + }; + init.body.append('model.json', new Blob([JSON.stringify(modelTopologyAndWeightManifest)], { type: JSON_TYPE }), 'model.json'); + if (modelArtifacts.weightData != null) { + init.body.append('model.weights.bin', new Blob([modelArtifacts.weightData], { type: OCTET_STREAM_MIME_TYPE }), 'model.weights.bin'); + } + return [4 /*yield*/, this.fetch(this.path, init)]; + case 1: + response = _a.sent(); + if (response.ok) { + return [2 /*return*/, { + modelArtifactsInfo: getModelArtifactsInfoForJSON(modelArtifacts), + responses: [response], + }]; + } + else { + throw new Error("BrowserHTTPRequest.save() failed due to HTTP response status " + + (response.status + ".")); + } + } + }); + }); + }; + /** + * Load model artifacts via HTTP request(s). + * + * See the documentation to `tf.io.http` for details on the saved + * artifacts. + * + * @returns The loaded model artifacts (if loading succeeds). 
+ */ + HTTPRequest.prototype.load = function () { + return __awaiter(this, void 0, void 0, function () { + var modelConfigRequest, modelConfig, e_1, message, modelTopology, weightsManifest, generatedBy, convertedBy, format, userDefinedMetadata, weightSpecs, weightData, results; + return __generator(this, function (_a) { + switch (_a.label) { + case 0: return [4 /*yield*/, this.fetch(this.path, this.requestInit)]; + case 1: + modelConfigRequest = _a.sent(); + if (!modelConfigRequest.ok) { + throw new Error("Request to " + this.path + " failed with status code " + + (modelConfigRequest.status + ". Please verify this URL points to ") + + "the model JSON of the model to load."); + } + _a.label = 2; + case 2: + _a.trys.push([2, 4, , 5]); + return [4 /*yield*/, modelConfigRequest.json()]; + case 3: + modelConfig = _a.sent(); + return [3 /*break*/, 5]; + case 4: + e_1 = _a.sent(); + message = "Failed to parse model JSON of response from " + this.path + "."; + // TODO(nsthorat): Remove this after some time when we're comfortable that + // .pb files are mostly gone. + if (this.path.endsWith('.pb')) { + message += ' Your path contains a .pb file extension. ' + + 'Support for .pb models have been removed in TensorFlow.js 1.0 ' + + 'in favor of .json models. You can re-convert your Python ' + + 'TensorFlow model using the TensorFlow.js 1.0 conversion scripts ' + + 'or you can convert your.pb models with the \'pb2json\'' + + 'NPM script in the tensorflow/tfjs-converter repository.'; + } + else { + message += ' Please make sure the server is serving valid ' + + 'JSON for this request.'; + } + throw new Error(message); + case 5: + modelTopology = modelConfig.modelTopology; + weightsManifest = modelConfig.weightsManifest; + generatedBy = modelConfig.generatedBy; + convertedBy = modelConfig.convertedBy; + format = modelConfig.format; + userDefinedMetadata = modelConfig.userDefinedMetadata; + // We do not allow both modelTopology and weightsManifest to be missing. 
+ if (modelTopology == null && weightsManifest == null) { + throw new Error("The JSON from HTTP path " + this.path + " contains neither model " + + "topology or manifest for weights."); + } + if (!(weightsManifest != null)) return [3 /*break*/, 7]; + return [4 /*yield*/, this.loadWeights(weightsManifest)]; + case 6: + results = _a.sent(); + weightSpecs = results[0], weightData = results[1]; + _a.label = 7; + case 7: return [2 /*return*/, { + modelTopology: modelTopology, + weightSpecs: weightSpecs, + weightData: weightData, + userDefinedMetadata: userDefinedMetadata, + generatedBy: generatedBy, + convertedBy: convertedBy, + format: format + }]; + } + }); + }); + }; + HTTPRequest.prototype.loadWeights = function (weightsManifest) { + return __awaiter(this, void 0, void 0, function () { + var weightPath, _a, prefix, suffix, pathPrefix, weightSpecs, _i, weightsManifest_1, entry, fetchURLs, buffers; + return __generator(this, function (_b) { + switch (_b.label) { + case 0: + weightPath = Array.isArray(this.path) ? 
this.path[1] : this.path; + _a = parseUrl(weightPath), prefix = _a[0], suffix = _a[1]; + pathPrefix = this.weightPathPrefix || prefix; + weightSpecs = []; + for (_i = 0, weightsManifest_1 = weightsManifest; _i < weightsManifest_1.length; _i++) { + entry = weightsManifest_1[_i]; + weightSpecs.push.apply(weightSpecs, entry.weights); + } + fetchURLs = []; + weightsManifest.forEach(function (weightsGroup) { + weightsGroup.paths.forEach(function (path) { + fetchURLs.push(pathPrefix + path + suffix); + }); + }); + return [4 /*yield*/, loadWeightsAsArrayBuffer(fetchURLs, { + requestInit: this.requestInit, + fetchFunc: this.fetch, + onProgress: this.onProgress + })]; + case 1: + buffers = _b.sent(); + return [2 /*return*/, [weightSpecs, concatenateArrayBuffers(buffers)]]; + } + }); + }); + }; + HTTPRequest.URL_SCHEME_REGEX = /^https?:\/\//; + return HTTPRequest; + }()); + /** + * Extract the prefix and suffix of the url, where the prefix is the path before + * the last file, and suffix is the search params after the last file. + * ``` + * const url = 'http://tfhub.dev/model/1/tensorflowjs_model.pb?tfjs-format=file' + * [prefix, suffix] = parseUrl(url) + * // prefix = 'http://tfhub.dev/model/1/' + * // suffix = '?tfjs-format=file' + * ``` + * @param url the model url to be parsed. + */ + function parseUrl(url) { + var lastSlash = url.lastIndexOf('/'); + var lastSearchParam = url.lastIndexOf('?'); + var prefix = url.substring(0, lastSlash); + var suffix = lastSearchParam > lastSlash ? url.substring(lastSearchParam) : ''; + return [prefix + '/', suffix]; + } + function isHTTPScheme(url) { + return url.match(HTTPRequest.URL_SCHEME_REGEX) != null; + } + var httpRouter = function (url, loadOptions) { + if (typeof fetch === 'undefined' && + (loadOptions == null || loadOptions.fetchFunc == null)) { + // `http` uses `fetch` or `node-fetch`, if one wants to use it in + // an environment that is not the browser or node they have to setup a + // global fetch polyfill. 
+ return null; + } + else { + var isHTTP = true; + if (Array.isArray(url)) { + isHTTP = url.every(function (urlItem) { return isHTTPScheme(urlItem); }); + } + else { + isHTTP = isHTTPScheme(url); + } + if (isHTTP) { + return http(url, loadOptions); + } + } + return null; + }; + IORouterRegistry.registerSaveRouter(httpRouter); + IORouterRegistry.registerLoadRouter(httpRouter); + /** + * Creates an IOHandler subtype that sends model artifacts to HTTP server. + * + * An HTTP request of the `multipart/form-data` mime type will be sent to the + * `path` URL. The form data includes artifacts that represent the topology + * and/or weights of the model. In the case of Keras-style `tf.Model`, two + * blobs (files) exist in form-data: + * - A JSON file consisting of `modelTopology` and `weightsManifest`. + * - A binary weights file consisting of the concatenated weight values. + * These files are in the same format as the one generated by + * [tfjs_converter](https://js.tensorflow.org/tutorials/import-keras.html). + * + * The following code snippet exemplifies the client-side code that uses this + * function: + * + * ```js + * const model = tf.sequential(); + * model.add( + * tf.layers.dense({units: 1, inputShape: [100], activation: 'sigmoid'})); + * + * const saveResult = await model.save(tf.io.http( + * 'http://model-server:5000/upload', {requestInit: {method: 'PUT'}})); + * console.log(saveResult); + * ``` + * + * If the default `POST` method is to be used, without any custom parameters + * such as headers, you can simply pass an HTTP or HTTPS URL to `model.save`: + * + * ```js + * const saveResult = await model.save('http://model-server:5000/upload'); + * ``` + * + * The following GitHub Gist + * https://gist.github.com/dsmilkov/1b6046fd6132d7408d5257b0976f7864 + * implements a server based on [flask](https://github.com/pallets/flask) that + * can receive the request. 
Upon receiving the model artifacts via the requst, + * this particular server reconsistutes instances of [Keras + * Models](https://keras.io/models/model/) in memory. + * + * + * @param path A URL path to the model. + * Can be an absolute HTTP path (e.g., + * 'http://localhost:8000/model-upload)') or a relative path (e.g., + * './model-upload'). + * @param requestInit Request configurations to be used when sending + * HTTP request to server using `fetch`. It can contain fields such as + * `method`, `credentials`, `headers`, `mode`, etc. See + * https://developer.mozilla.org/en-US/docs/Web/API/Request/Request + * for more information. `requestInit` must not have a body, because the + * body will be set by TensorFlow.js. File blobs representing the model + * topology (filename: 'model.json') and the weights of the model (filename: + * 'model.weights.bin') will be appended to the body. If `requestInit` has a + * `body`, an Error will be thrown. + * @param loadOptions Optional configuration for the loading. It includes the + * following fields: + * - weightPathPrefix Optional, this specifies the path prefix for weight + * files, by default this is calculated from the path param. + * - fetchFunc Optional, custom `fetch` function. E.g., in Node.js, + * the `fetch` from node-fetch can be used here. + * - onProgress Optional, progress callback function, fired periodically + * before the load is completed. + * @returns An instance of `IOHandler`. + */ + /** + * @doc { + * heading: 'Models', + * subheading: 'Loading', + * namespace: 'io', + * ignoreCI: true + * } + */ + function http(path, loadOptions) { + return new HTTPRequest(path, loadOptions); + } + /** + * Deprecated. Use `tf.io.http`. + * @param path + * @param loadOptions + */ + function browserHTTPRequest(path, loadOptions) { + return http(path, loadOptions); + } + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. 
+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var PassthroughLoader = /** @class */ (function () { + function PassthroughLoader(modelArtifacts) { + this.modelArtifacts = modelArtifacts; + } + PassthroughLoader.prototype.load = function () { + return __awaiter(this, void 0, void 0, function () { + return __generator(this, function (_a) { + return [2 /*return*/, this.modelArtifacts]; + }); + }); + }; + return PassthroughLoader; + }()); + var PassthroughSaver = /** @class */ (function () { + function PassthroughSaver(saveHandler) { + this.saveHandler = saveHandler; + } + PassthroughSaver.prototype.save = function (modelArtifacts) { + return __awaiter(this, void 0, void 0, function () { + return __generator(this, function (_a) { + return [2 /*return*/, this.saveHandler(modelArtifacts)]; + }); + }); + }; + return PassthroughSaver; + }()); + /** + * Creates an IOHandler that loads model artifacts from memory. + * + * When used in conjunction with `tf.loadLayersModel`, an instance of + * `tf.LayersModel` (Keras-style) can be constructed from the loaded artifacts. + * + * ```js + * const model = await tf.loadLayersModel(tf.io.fromMemory( + * modelTopology, weightSpecs, weightData)); + * ``` + * + * @param modelArtifacts a object containing model topology (i.e., parsed from + * the JSON format). 
+ * @param weightSpecs An array of `WeightsManifestEntry` objects describing the + * names, shapes, types, and quantization of the weight data. + * @param weightData A single `ArrayBuffer` containing the weight data, + * concatenated in the order described by the weightSpecs. + * @param trainingConfig Model training configuration. Optional. + * + * @returns A passthrough `IOHandler` that simply loads the provided data. + */ + function fromMemory(modelArtifacts, weightSpecs, weightData, trainingConfig) { + if (arguments.length === 1) { + var isModelArtifacts = modelArtifacts.modelTopology != null || + modelArtifacts.weightSpecs != null; + if (isModelArtifacts) { + return new PassthroughLoader(modelArtifacts); + } + else { + // Legacy support: with only modelTopology. + // TODO(cais): Remove this deprecated API. + console.warn('Please call tf.io.fromMemory() with only one argument. ' + + 'The argument should be of type ModelArtifacts. ' + + 'The multi-argument signature of tf.io.fromMemory() has been ' + + 'deprecated and will be removed in a future release.'); + return new PassthroughLoader({ modelTopology: modelArtifacts }); + } + } + else { + // Legacy support. + // TODO(cais): Remove this deprecated API. + console.warn('Please call tf.io.fromMemory() with only one argument. ' + + 'The argument should be of type ModelArtifacts. ' + + 'The multi-argument signature of tf.io.fromMemory() has been ' + + 'deprecated and will be removed in a future release.'); + return new PassthroughLoader({ + modelTopology: modelArtifacts, + weightSpecs: weightSpecs, + weightData: weightData, + trainingConfig: trainingConfig + }); + } + } + /** + * Creates an IOHandler that passes saved model artifacts to a callback. + * + * ```js + * function handleSave(artifacts) { + * // ... do something with the artifacts ... 
+ * return {modelArtifactsInfo: {...}, ...}; + * } + * + * const saveResult = model.save(tf.io.withSaveHandler(handleSave)); + * ``` + * + * @param saveHandler A function that accepts a `ModelArtifacts` and returns a + * `SaveResult`. + */ + function withSaveHandler(saveHandler) { + return new PassthroughSaver(saveHandler); + } + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + + var io = { + __proto__: null, + browserFiles: browserFiles, + browserHTTPRequest: browserHTTPRequest, + concatenateArrayBuffers: concatenateArrayBuffers, + decodeWeights: decodeWeights, + encodeWeights: encodeWeights, + fromMemory: fromMemory, + getLoadHandlers: getLoadHandlers, + getModelArtifactsInfoForJSON: getModelArtifactsInfoForJSON, + getSaveHandlers: getSaveHandlers, + http: http, + isHTTPScheme: isHTTPScheme, + loadWeights: loadWeights, + registerLoadRouter: registerLoadRouter, + registerSaveRouter: registerSaveRouter, + weightsLoaderFactory: weightsLoaderFactory, + withSaveHandler: withSaveHandler, + copyModel: copyModel, + listModels: listModels, + moveModel: moveModel, + removeModel: removeModel + }; + + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Creates a one-hot `tf.Tensor`. The locations represented by `indices` take + * value `onValue` (defaults to 1), while all other locations take value + * `offValue` (defaults to 0). If `indices` is rank `R`, the output has rank + * `R+1` with the last axis of size `depth`. + * + * ```js + * tf.oneHot(tf.tensor1d([0, 1], 'int32'), 3).print(); + * ``` + * + * @param indices `tf.Tensor` of indices with dtype `int32`. + * @param depth The depth of the one hot dimension. + * @param onValue A number used to fill in the output when the index matches + * the location. + * @param offValue A number used to fill in the output when the index does + * not match the location. 
+ */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + function oneHot_(indices, depth, onValue, offValue) { + if (onValue === void 0) { onValue = 1; } + if (offValue === void 0) { offValue = 0; } + if (depth < 2) { + throw new Error("Error in oneHot: depth must be >=2, but it is " + depth); + } + var $indices = convertToTensor(indices, 'indices', 'oneHot', 'int32'); + var outShape = $indices.shape.concat([depth]); + $indices = $indices.flatten(); + var forward = function (backend, save) { + save([$indices]); + return reshape(backend.oneHot($indices, depth, onValue, offValue), outShape); + }; + var inputs = { indices: $indices }; + var attrs = { depth: depth, onValue: onValue, offValue: offValue }; + return ENGINE.runKernelFunc(forward, inputs, null /* grad */, OneHot, attrs); + } + var oneHot = op({ oneHot_: oneHot_ }); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Computes the confusion matrix from true labels and predicted labels. 
+ * + * ```js + * const labels = tf.tensor1d([0, 1, 2, 1, 0], 'int32'); + * const predictions = tf.tensor1d([0, 2, 2, 1, 0], 'int32'); + * const numClasses = 3; + * const out = tf.math.confusionMatrix(labels, predictions, numClasses); + * out.print(); + * // Expected output matrix: + * // [[2, 0, 0], + * // [0, 1, 1], + * // [0, 0, 1]] + * ``` + * + * @param labels The target labels, assumed to be 0-based integers + * for the classes. The shape is `[numExamples]`, where + * `numExamples` is the number of examples included. + * @param predictions The predicted classes, assumed to be + * 0-based integers for the classes. Must have the same shape as `labels`. + * @param numClasses Number of all classes, as an integer. + * Its value must be larger than the largest element in `labels` and + * `predictions`. + * @returns The confusion matrix as a int32-type 2D tensor. The value at + * row `r` and column `c` is the number of times examples of actual class + * `r` were predicted as class `c`. + */ + /** @doc {heading: 'Operations', subheading: 'Evaluation'} */ + function confusionMatrix_(labels, predictions, numClasses) { + var $labels = convertToTensor(labels, 'labels', 'confusionMatrix'); + var $predictions = convertToTensor(predictions, 'predictions', 'confusionMatrix'); + assert(numClasses == null || numClasses > 0 && Number.isInteger(numClasses), function () { return "If provided, numClasses must be a positive integer, " + + ("but got " + numClasses); }); + assert($labels.rank === 1, function () { return "Expected the rank of labels to be 1, but got " + $labels.rank; }); + assert($predictions.rank === 1, function () { return "Expected the rank of predictions to be 1, " + + ("but got " + $predictions.rank); }); + assert($labels.shape[0] === $predictions.shape[0], function () { return "Mismatch in the number of examples: " + + ($labels.shape[0] + " vs. " + $predictions.shape[0] + ". 
") + + "Labels and predictions should have the same number of elements."; }); + assert(numClasses > 0 && Number.isInteger(numClasses), function () { return "numClasses is required to be a positive integer, but got " + + ("" + numClasses); }); + // TODO(cais): In the future, if oneHot supports tensors inputs for + // `numClasses`, `confusionMatrix` can make `numClasses` optional. + var oneHotLabels = oneHot($labels.asType('int32'), numClasses); + var oneHotPredictions = oneHot($predictions.asType('int32'), numClasses); + var oneHotLabelsT = oneHotLabels.transpose(); + return oneHotLabelsT.matMul(oneHotPredictions).asType('int32'); + } + var confusionMatrix = op({ confusionMatrix_: confusionMatrix_ }); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + + var math = { + __proto__: null, + confusionMatrix: confusionMatrix + }; + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var fromPixels2DContext; + /** + * Creates a `tf.Tensor` from an image. + * + * ```js + * const image = new ImageData(1, 1); + * image.data[0] = 100; + * image.data[1] = 150; + * image.data[2] = 200; + * image.data[3] = 255; + * + * tf.browser.fromPixels(image).print(); + * ``` + * + * @param pixels The input image to construct the tensor from. The + * supported image types are all 4-channel. You can also pass in an image + * object with following attributes: + * `{data: Uint8Array; width: number; height: number}` + * @param numChannels The number of channels of the output tensor. A + * numChannels value less than 4 allows you to ignore channels. Defaults to + * 3 (ignores alpha channel of input image). + */ + /** @doc {heading: 'Browser', namespace: 'browser', ignoreCI: true} */ + function fromPixels_(pixels, numChannels) { + if (numChannels === void 0) { numChannels = 3; } + // Sanity checks. 
+ if (numChannels > 4) { + throw new Error('Cannot construct Tensor with more than 4 channels from pixels.'); + } + if (pixels == null) { + throw new Error('pixels passed to tf.browser.fromPixels() can not be null'); + } + var isPixelData = false; + var isImageData = false; + var isVideo = false; + var isImage = false; + var isCanvasLike = false; + if (pixels.data instanceof Uint8Array) { + isPixelData = true; + } + else if (typeof (ImageData) !== 'undefined' && pixels instanceof ImageData) { + isImageData = true; + } + else if (typeof (HTMLVideoElement) !== 'undefined' && + pixels instanceof HTMLVideoElement) { + isVideo = true; + } + else if (typeof (HTMLImageElement) !== 'undefined' && + pixels instanceof HTMLImageElement) { + isImage = true; + // tslint:disable-next-line: no-any + } + else if (pixels.getContext != null) { + isCanvasLike = true; + } + else { + throw new Error('pixels passed to tf.browser.fromPixels() must be either an ' + + "HTMLVideoElement, HTMLImageElement, HTMLCanvasElement, ImageData " + + "in browser, or OffscreenCanvas, ImageData in webworker" + + " or {data: Uint32Array, width: number, height: number}, " + + ("but was " + pixels.constructor.name)); + } + if (isVideo) { + var HAVE_CURRENT_DATA_READY_STATE = 2; + if (isVideo && + pixels.readyState < + HAVE_CURRENT_DATA_READY_STATE) { + throw new Error('The video element has not loaded data yet. Please wait for ' + + '`loadeddata` event on the